forked from NirAharon/BoT-SORT
-
Notifications
You must be signed in to change notification settings - Fork 249
/
reid_evaluation.py
143 lines (113 loc) · 4.84 KB
/
reid_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# encoding: utf-8
"""
@author: liaoxingyu
@contact: [email protected]
"""
import copy
import logging
import time
import itertools
from collections import OrderedDict
import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from fast_reid.fastreid.utils import comm
from fast_reid.fastreid.utils.compute_dist import build_dist
from .evaluator import DatasetEvaluator
from .query_expansion import aqe
from .rank_cylib import compile_helper
logger = logging.getLogger(__name__)


class ReidEvaluator(DatasetEvaluator):
    """Evaluate person re-identification results (CMC, mAP, mINP, optional ROC).

    Predictions are accumulated batch by batch via :meth:`process`. At
    :meth:`evaluate` time the buffered features are concatenated in arrival
    order; the first ``num_query`` rows are treated as queries and all
    remaining rows as the gallery, so the data loader must yield every query
    sample before any gallery sample.
    """

    def __init__(self, cfg, num_query, output_dir=None):
        """
        Args:
            cfg: config node; ``cfg.TEST.*`` selects AQE, re-ranking, the
                distance metric and ROC evaluation.
            num_query: number of leading samples that form the query set.
            output_dir: stored on the instance; not read by this class.
        """
        self.cfg = cfg
        self._num_query = num_query
        self._output_dir = output_dir
        self._cpu_device = torch.device('cpu')
        self._predictions = []

        self._compile_dependencies()

    def reset(self):
        """Discard all buffered predictions."""
        self._predictions = []

    def process(self, inputs, outputs):
        """Buffer one batch of embeddings and labels on the CPU.

        Args:
            inputs: dict providing ``'targets'`` (person ids) and ``'camids'``
                (camera ids) tensors for the batch.
            outputs: embedding tensor for the batch; cast to float32. Presumably
                shaped (batch, dim) -- TODO confirm against the model head.
        """
        prediction = {
            'feats': outputs.to(self._cpu_device, torch.float32),
            'pids': inputs['targets'].to(self._cpu_device),
            'camids': inputs['camids'].to(self._cpu_device),
        }
        self._predictions.append(prediction)

    def evaluate(self):
        """Compute ReID metrics over all buffered predictions.

        In distributed runs the predictions are gathered onto the main process;
        every other rank returns an empty dict. An empty dict is also returned
        (with a warning) when no predictions were collected.

        Returns:
            dict: ``Rank-1``, ``Rank-5``, ``Rank-10``, ``mAP``, ``mINP`` and
            ``metric`` (mean of mAP and Rank-1), all scaled to percent; plus
            ``TPR@FPR=...`` entries (raw fractions, not percent) when
            ``cfg.TEST.ROC.ENABLED`` is set.
        """
        predictions = self._gather_predictions()
        if predictions is None:
            # Non-main rank in a distributed run: only the main process evaluates.
            return {}
        if not predictions:
            # torch.cat would raise on an empty list; report instead of crashing.
            logger.warning("[ReidEvaluator] Did not receive valid predictions.")
            return {}

        features = torch.cat([p['feats'] for p in predictions], dim=0)
        pids = torch.cat([p['pids'] for p in predictions], dim=0).numpy()
        camids = torch.cat([p['camids'] for p in predictions], dim=0).numpy()

        # Query / gallery split relies on the loader emitting all queries first.
        nq = self._num_query
        query_features, gallery_features = features[:nq], features[nq:]
        query_pids, gallery_pids = pids[:nq], pids[nq:]
        query_camids, gallery_camids = camids[:nq], camids[nq:]

        self._results = OrderedDict()

        dist = self._compute_dist(query_features, gallery_features)

        from .rank import evaluate_rank
        cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)

        mAP = np.mean(all_AP)
        mINP = np.mean(all_INP)
        for r in [1, 5, 10]:
            # Guard against a CMC curve with fewer than r entries; presumably
            # evaluate_rank truncates it for very small galleries -- TODO confirm.
            if r <= len(cmc):
                self._results['Rank-{}'.format(r)] = cmc[r - 1] * 100
        self._results['mAP'] = mAP * 100
        self._results['mINP'] = mINP * 100
        self._results["metric"] = (mAP + cmc[0]) / 2 * 100

        if self.cfg.TEST.ROC.ENABLED:
            self._add_roc_results(dist, query_pids, gallery_pids, query_camids, gallery_camids)

        return copy.deepcopy(self._results)

    def _gather_predictions(self):
        """Return all buffered predictions on the main process, or ``None`` on other ranks."""
        if comm.get_world_size() > 1:
            comm.synchronize()
            gathered = comm.gather(self._predictions, dst=0)
            if not comm.is_main_process():
                return None
            # gathered holds one list of batch dicts per rank; flatten it.
            return list(itertools.chain.from_iterable(gathered))
        return self._predictions

    def _compute_dist(self, query_features, gallery_features):
        """Build the query x gallery distance matrix, applying AQE / re-ranking per cfg."""
        if self.cfg.TEST.AQE.ENABLED:
            logger.info("Test with AQE setting")
            qe_time = self.cfg.TEST.AQE.QE_TIME
            qe_k = self.cfg.TEST.AQE.QE_K
            alpha = self.cfg.TEST.AQE.ALPHA
            query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)

        dist = build_dist(query_features, gallery_features, self.cfg.TEST.METRIC)

        if self.cfg.TEST.RERANK.ENABLED:
            logger.info("Test with rerank setting")
            k1 = self.cfg.TEST.RERANK.K1
            k2 = self.cfg.TEST.RERANK.K2
            lambda_value = self.cfg.TEST.RERANK.LAMBDA

            if self.cfg.TEST.METRIC == "cosine":
                # Re-rank on L2-normalised features when the base metric is cosine.
                query_features = F.normalize(query_features, dim=1)
                gallery_features = F.normalize(gallery_features, dim=1)

            rerank_dist = build_dist(query_features, gallery_features, metric="jaccard", k1=k1, k2=k2)
            # lambda weights the base distance, (1 - lambda) the re-ranked one.
            dist = rerank_dist * (1 - lambda_value) + dist * lambda_value
        return dist

    def _add_roc_results(self, dist, query_pids, gallery_pids, query_camids, gallery_camids):
        """Add ``TPR@FPR=...`` entries to ``self._results`` from a ROC curve over ``dist``."""
        from .roc import evaluate_roc
        scores, labels = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
        fprs, tprs, _ = metrics.roc_curve(labels, scores)

        for fpr in [1e-4, 1e-3, 1e-2]:
            # Report the TPR at the operating point whose FPR is closest to the target.
            ind = np.argmin(np.abs(fprs - fpr))
            self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]

    def _compile_dependencies(self):
        """Build the Cython rank-evaluation extension on the main process if it is missing.

        Afterwards every rank calls ``comm.synchronize()`` (presumably a
        barrier) so no rank moves on before compilation has finished.
        """
        # Results are only evaluated on rank 0, so only rank 0 needs the Cython tool.
        if comm.is_main_process():
            try:
                from .rank_cylib.rank_cy import evaluate_cy  # noqa: F401  (import probe only)
            except ImportError:
                start_time = time.time()
                logger.info("> compiling reid evaluation cython tool")

                compile_helper()

                logger.info(
                    ">>> done with reid evaluation cython tool. Compilation time: {:.3f} "
                    "seconds".format(time.time() - start_time))
        comm.synchronize()