# run_pruning.py (forked from cfchen-duke/ProtoPNet)
import os

import argh
import gin
import torch
import torch.utils.data

import prune
import save  # provides save_model_w_condition, used below
from preprocess import preprocess
from segmentation.data_module import PatchClassificationDataModule
from log import create_logger
from segmentation.dataset import PatchClassificationDataset
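

# Roughly, following the original ProtoPNet pruning procedure: for each prototype,
# its k nearest latent patches in the (unnormalized) push set are found, and the
# prototype is removed when fewer than `prune_threshold` of them come from the
# prototype's own class. The defaults (k=6, prune_threshold=3) match ProtoPNet;
# see prune.prune_prototypes for the exact criterion.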
def run_pruning(config_name: str, experiment_name: str, k: int = 6, prune_threshold: int = 3):
    gin.parse_config_file(f'segmentation/configs/{config_name}.gin', skip_unknown=True)
    gin.parse_config_file(os.path.join(os.environ['RESULTS_DIR'], experiment_name, 'config.gin'),
                          skip_unknown=True)

    model_path = os.path.join(os.environ['RESULTS_DIR'], experiment_name, 'checkpoints/push_last.pth')
    output_dir = os.path.join(os.environ['RESULTS_DIR'], experiment_name, 'pruned')
    os.makedirs(output_dir, exist_ok=True)

    log, logclose = create_logger(log_filename=os.path.join(output_dir, 'prune.log'))

    ppnet = torch.load(model_path)
    ppnet = ppnet.cuda()
    ppnet_multi = torch.nn.DataParallel(ppnet)
    # load the data
    # TODO use configurable value for model_image_size here
    data_module = PatchClassificationDataModule(batch_size=1)

    # TODO: implement test here for segmentation
    # test_loader = data_module.val_dataloader(batch_size=1)

    # push set: needed for pruning because it is unnormalized
    train_push_loader = data_module.train_push_dataloader(batch_size=1)
    push_dataset = PatchClassificationDataset(
        split_key='train',
        is_eval=True,
        push_prototypes=True
    )
    def preprocess_push_input(x):
        return preprocess(x, mean=train_push_loader.dataset.mean, std=train_push_loader.dataset.std)

    # log('test set size: {0}'.format(len(test_loader.dataset)))
    log('push set size: {0}'.format(len(train_push_loader.dataset)))
    # prune prototypes
    log('prune')
    with torch.no_grad():
        # accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
        #                 class_specific=class_specific, log=log)
        # log(f"Accuracy before pruning: {accu}")
        prune.prune_prototypes(dataset=push_dataset,
                               prototype_network_parallel=ppnet_multi,
                               k=k,
                               prune_threshold=prune_threshold,
                               preprocess_input_function=preprocess_push_input,  # normalize
                               original_model_dir=output_dir,
                               epoch_number=0,
                               # model_name=None,
                               log=log,
                               copy_prototype_imgs=False)
        # accu = tnt.test(model=ppnet_multi, dataloader=test_loader,
        #                 class_specific=class_specific, log=log)
        # log(f"Accuracy after pruning: {accu}")
    save.save_model_w_condition(model=ppnet, model_dir=output_dir,
                                model_name='pruned',
                                accu=1.0,
                                target_accu=0.0, log=log)
    logclose()
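

# Example invocation (a sketch; assumes RESULTS_DIR points at the experiments directory
# and that <experiment_name>/checkpoints/push_last.pth exists from a prior push step):
#   python run_pruning.py <config_name> <experiment_name> --k 6 --prune-threshold 3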
if __name__ == '__main__':
    argh.dispatch_command(run_pruning)