# eval_test.py (forked from cfchen-duke/ProtoPNet)
import os

import argh
import gin
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from torchvision import transforms
from tqdm import tqdm

# noinspection PyUnresolvedReferences
from segmentation import train  # not used directly; likely needed so gin can resolve configurables
from segmentation.constants import CITYSCAPES_19_EVAL_CATEGORIES, PASCAL_ID_MAPPING, CITYSCAPES_CATEGORIES, \
    CITYSCAPES_ID_2_LABEL
from settings import data_path, log


def run_evaluation(model_name: str, training_phase: str, batch_size: int = 2, pascal: bool = False,
                   margin: int = 0):
    """Run test-set inference and save per-image segmentation predictions as PNG files."""
    model_path = os.path.join(os.environ['RESULTS_DIR'], model_name)
    config_path = os.path.join(model_path, 'config.gin')
    gin.parse_config_file(config_path)

    if training_phase == 'pruned':
        checkpoint_path = os.path.join(model_path, 'pruned/checkpoints/push_last.pth')
    else:
        checkpoint_path = os.path.join(model_path, f'checkpoints/{training_phase}_last.pth')

    log(f'Loading model from {checkpoint_path}')
    ppnet = torch.load(checkpoint_path)  # , map_location=torch.device('cpu'))
    ppnet = ppnet.cuda()
    ppnet.eval()

    # standard ImageNet normalization statistics
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
    ])

    img_dir = os.path.join(data_path, f'img_with_margin_{margin}/test')
    all_img_files = [p for p in os.listdir(img_dir) if p.endswith('.npy')]

    # TODO pascal
    # build a mapping from the model's contiguous class IDs back to source dataset label IDs
    ID_MAPPING = PASCAL_ID_MAPPING if pascal else CITYSCAPES_19_EVAL_CATEGORIES
    OUR_ID_2_SOURCE_ID = {v: k for k, v in ID_MAPPING.items()}
    if not pascal:
        OUR_ID_2_SOURCE_ID[0] = 0
        rev_origin = {v: k for k, v in CITYSCAPES_ID_2_LABEL.items()}
        OUR_ID_2_SOURCE_ID = {k: rev_origin[CITYSCAPES_CATEGORIES[v]] for k, v in OUR_ID_2_SOURCE_ID.items()}
    OUR_ID_2_SOURCE_ID = np.vectorize(OUR_ID_2_SOURCE_ID.get)
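    # Illustration (hypothetical values; the real mappings live in segmentation.constants):
    # if the dict maps {0: 0, 1: 7, 2: 8}, then after np.vectorize,
    # OUR_ID_2_SOURCE_ID(np.array([1, 2])) returns array([7, 8]).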

    RESULTS_DIR = os.path.join(model_path, f'evaluation/test/{training_phase}')
    os.makedirs(RESULTS_DIR, exist_ok=True)

    np.random.shuffle(all_img_files)
    n_batches = int(np.ceil(len(all_img_files) / batch_size))
    batched_img_files = np.array_split(all_img_files, n_batches)
    # batched_img_files = batched_img_files[:50]
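    # Note: np.array_split tolerates uneven division (e.g. 5 files into 2 batches
    # gives batches of 3 and 2), so no image is dropped.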

    with torch.no_grad():
        for batch_img_files in tqdm(batched_img_files, desc='evaluating'):
            img_tensors = []
            img_arrays = []
            for img_file in batch_img_files:
                img = np.load(os.path.join(img_dir, img_file)).astype(np.uint8)
                if margin != 0:
                    img = img[margin:-margin, margin:-margin]
                img_arrays.append(img)

                if pascal:
                    img_shape = (513, 513)
                else:
                    img_shape = img.shape

                img_tensor = transform(img)
                if pascal:
                    # resize PASCAL inputs to the fixed 513x513 resolution
                    img_tensor = torch.nn.functional.interpolate(img_tensor.unsqueeze(0), size=img_shape,
                                                                 mode='bilinear', align_corners=False)[0]
                img_tensors.append(img_tensor)

            img_tensors = torch.stack(img_tensors, dim=0).cuda()
            batch_logits, batch_distances = ppnet.forward(img_tensors)
            del batch_distances, img_tensor

            # (N, H, W, C) -> (N, C, H, W)
            batch_logits = batch_logits.permute(0, 3, 1, 2)

            for sample_i in range(len(batch_img_files)):
                img = img_arrays[sample_i]

                # upsample logits to the original image resolution before taking argmax
                logits = torch.unsqueeze(batch_logits[sample_i], 0)
                logits = F.interpolate(logits, size=(img.shape[0], img.shape[1]),
                                       mode='bilinear', align_corners=False)[0]
                pred = torch.argmax(logits, dim=0).cpu().detach().numpy()

                # shift by one: the model does not predict class 0 ('void')
                pred = pred + 1
                pred = OUR_ID_2_SOURCE_ID(pred)

                pred_img = Image.fromarray(np.uint8(pred))
                img_id = batch_img_files[sample_i].split('.')[0]
                pred_img.convert("L").save(os.path.join(RESULTS_DIR, f'{img_id}.png'))


if __name__ == '__main__':
    argh.dispatch_command(run_evaluation)
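
# Example invocation via argh's CLI dispatch (a sketch; the model name, phase, and
# RESULTS_DIR values below are illustrative, not values shipped with the repo):
#
#   RESULTS_DIR=/path/to/results python eval_test.py my_ppnet_run warm --batch-size 2
#   RESULTS_DIR=/path/to/results python eval_test.py my_ppnet_run pruned --pascal
#
# Predictions land in <RESULTS_DIR>/<model_name>/evaluation/test/<training_phase>/
# as 8-bit grayscale PNGs holding source-dataset label IDs.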