-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_symnet_multi.py
98 lines (62 loc) · 2.5 KB
/
test_symnet_multi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from utils.base_solver import BaseSolver
import os, logging, importlib, re, copy, random, tqdm, argparse
import os.path as osp
import cPickle as pickle
import numpy as np
from collections import defaultdict
import tensorflow as tf
import tensorflow.contrib.slim as slim
import torch
from utils import config as cfg
from utils import dataset, utils
from utils.evaluator import Multi_Evaluator
from run_symnet_multi import make_parser
from test_symnet import SolverWrapper as BaseSolverWrapper
def main():
    """Entry point: evaluate the ``symnet_multi`` network on the test split.

    Parses the command line with ``make_parser`` (from ``run_symnet_multi``)
    and overrides ``network`` to ``'symnet_multi'``. It then builds the
    ``'test'`` dataloader and imports ``network.symnet_multi`` to construct
    the model. Finally it runs ``SolverWrapper.trainval_model`` inside a
    session obtained from ``utils.create_session``.
    """
    main_logger = logging.getLogger('MAIN')

    cli_args = make_parser().parse_args()
    # Whatever network the parser produced is replaced: this script only
    # evaluates the multi variant.
    cli_args.network = 'symnet_multi'
    utils.display_args(cli_args, main_logger)

    main_logger.info("Loading dataset")
    loader = dataset.get_dataloader(
        cli_args.data, 'test', batchsize=cli_args.test_bz, args=cli_args)

    main_logger.info("Loading network and solver")
    # Network modules are looked up by name under the ``network`` package.
    net_module = importlib.import_module('network.' + cli_args.network)
    model = net_module.Network(loader, cli_args)

    with utils.create_session() as session:
        solver = SolverWrapper(model, loader, cli_args)
        solver.trainval_model(session, cli_args.epoch)
class SolverWrapper(BaseSolverWrapper):
    """Test-time driver for the ``symnet_multi`` network.

    Inherits graph construction (``construct_graph``), variable
    initialization (``initialize``) and logging (``logger``) from
    ``test_symnet.SolverWrapper``. It overrides ``trainval_model`` to score
    every entry of the score-op dict with ``Multi_Evaluator`` (mAP / mAUC).
    """

    def trainval_model(self, sess, max_epoch):
        """Run the whole test set through the network and print per-key metrics.

        Args:
            sess: TensorFlow session. The graph is built and initialized in
                it, then finalized.
            max_epoch: Unused here. It is kept so the signature matches the
                caller in ``main``, and presumably the base-class method.
                TODO confirm.

        For each key of the score-op dict returned by ``construct_graph``,
        prints one line ``"<key>: <formatted result>"``. The result holds
        mAP and mAUC computed over all test batches, plus
        ``self.args.name`` and ``self.args.epoch``.
        """
        logger = self.logger('test_model')
        logger.info('Begin testing')

        # The second value (a summary op) is not needed at test time.
        score_op, _ = self.construct_graph(sess)
        self.initialize(sess)
        # Freeze the graph so nothing can add ops to it inside the test loop.
        sess.graph.finalize()

        evaluator = Multi_Evaluator()
        all_attr = []                  # per-batch label arrays
        all_pred = defaultdict(list)   # score-op key -> per-batch predictions

        for batch in tqdm.tqdm(self.test_dataloader,
                               total=len(self.test_dataloader),
                               postfix='test'):
            predictions = self.network.test_step(sess, batch, score_op)
            # batch[1] presumably holds the ground-truth attribute labels as a
            # numpy array (the original passed it to torch.from_numpy) -- confirm.
            # Keep it as numpy: it is only ever fed to np.concatenate.
            all_attr.append(np.asarray(batch[1]))
            for key in score_op:
                all_pred[key].append(predictions[key])

        if all_attr:
            # The ground truth is shared by every prediction key, so
            # concatenate it once rather than once per key.
            gt_attr = np.concatenate(all_attr, 0)
            for name, pred in all_pred.items():
                mAP, mAUC = evaluator(np.concatenate(pred, 0), gt_attr)
                report_dict = {
                    'mAP': mAP,
                    'mAUC': mAUC,
                    'name': self.args.name,
                    'epoch': self.args.epoch,
                }
                print("%s: " % name + utils.formated_multi_result(report_dict))
        else:
            # np.concatenate([]) would raise; report the empty loader instead.
            logger.warning('Test dataloader yielded no batches; nothing to evaluate.')

        logger.info('Finished.')
if __name__ == "__main__":
    # Run the evaluation only when executed as a script; importing this
    # module just defines main() and SolverWrapper.
    main()