-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
strong_sort.py
99 lines (85 loc) · 3.07 KB
/
strong_sort.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
import numpy as np
from boxmot.appearance.reid_multibackend import ReIDDetectMultiBackend
from boxmot.motion.cmc import get_cmc_method
from boxmot.trackers.strongsort.sort.detection import Detection
from boxmot.trackers.strongsort.sort.tracker import Tracker
from boxmot.utils.matching import NearestNeighborDistanceMetric
from boxmot.utils.ops import xyxy2tlwh
class StrongSORT(object):
    """Multi-object tracker combining a ReID appearance model with a track manager.

    Each call to :meth:`update` extracts an appearance feature for every
    detection with ``self.model``, applies camera-motion compensation
    (``self.cmc``) to the existing tracks, and lets ``self.tracker`` associate
    the detections with its tracks.
    """

    def __init__(
        self,
        model_weights,
        device,
        fp16,
        max_dist=0.2,
        max_iou_dist=0.7,
        max_age=30,
        n_init=1,
        nn_budget=100,
        mc_lambda=0.995,
        ema_alpha=0.9,
    ):
        """Build the appearance model, the tracker and the motion compensator.

        Args:
            model_weights: Forwarded to ``ReIDDetectMultiBackend`` as ``weights``.
            device: Forwarded to ``ReIDDetectMultiBackend``.
            fp16: Forwarded to ``ReIDDetectMultiBackend``.
            max_dist: Second argument of the "cosine"
                ``NearestNeighborDistanceMetric`` (presumably the matching
                threshold -- confirm against that class).
            max_iou_dist: Forwarded to ``Tracker``.
            max_age: Forwarded to ``Tracker``.
            n_init: Forwarded to ``Tracker``.
            nn_budget: Third argument of ``NearestNeighborDistanceMetric``
                (presumably the per-track feature budget -- confirm).
            mc_lambda: Forwarded to ``Tracker``.
            ema_alpha: Forwarded to ``Tracker``.
        """
        self.model = ReIDDetectMultiBackend(
            weights=model_weights,
            device=device,
            fp16=fp16
        )
        self.tracker = Tracker(
            metric=NearestNeighborDistanceMetric("cosine", max_dist, nn_budget),
            max_iou_dist=max_iou_dist,
            max_age=max_age,
            n_init=n_init,
            mc_lambda=mc_lambda,
            ema_alpha=ema_alpha,
        )
        # The camera-motion-compensation method is looked up by name and
        # instantiated immediately.
        self.cmc = get_cmc_method('ecc')()

    def update(self, dets, img):
        """Feed one frame of detections to the tracker and return active tracks.

        Args:
            dets: ``np.ndarray`` of shape ``(N, 6)``. Columns 0-3 are read as
                the box corners (x1, y1, x2, y2), column 4 as the confidence
                and column 5 as the class.
            img: ``np.ndarray`` holding the frame the detections belong to.

        Returns:
            ``np.ndarray`` of shape ``(M, 8)`` with one row
            ``[x1, y1, x2, y2, id, conf, cls, det_ind]`` per track that is
            confirmed and was updated in this call. ``det_ind`` is the value
            stored on the track (the row index of a detection in ``dets`` is
            what gets handed to the tracker). When no track qualifies the
            result is an empty ``(0, 8)`` array, so column indexing stays valid.

        Raises:
            AssertionError: If ``dets`` or ``img`` is not an ``np.ndarray``, or
                ``dets`` is not two-dimensional with six columns.
        """
        assert isinstance(
            dets, np.ndarray
        ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray"
        assert isinstance(
            img, np.ndarray
        ), f"Unsupported 'img' input format '{type(img)}', valid format is np.ndarray"
        assert (
            len(dets.shape) == 2
        ), "Unsupported 'dets' dimensions, valid number of dimensions is two"
        assert (
            dets.shape[1] == 6
        ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6"

        # Append each detection's row index as a 7th column so it can be
        # carried through the tracker and reported back in the output.
        dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])
        xyxy = dets[:, 0:4]
        confs = dets[:, 4]
        clss = dets[:, 5]
        det_ind = dets[:, 6]

        # Camera-motion compensation only matters once there are tracks to move.
        if len(self.tracker.tracks) >= 1:
            warp_matrix = self.cmc.apply(img, xyxy)
            for track in self.tracker.tracks:
                track.camera_update(warp_matrix)

        # extract appearance information for each detection
        features = self.model.get_features(xyxy, img)

        tlwh = xyxy2tlwh(xyxy)
        # Loop variable is named ``ind`` so it does not rebind ``det_ind``.
        detections = [
            Detection(box, conf, cls, ind, feat)
            for box, conf, cls, ind, feat in zip(tlwh, confs, clss, det_ind, features)
        ]

        # update tracker
        self.tracker.predict()
        self.tracker.update(detections)

        # output bbox identities
        rows = []
        for track in self.tracker.tracks:
            # Report only confirmed tracks that were matched in this call.
            if not track.is_confirmed() or track.time_since_update >= 1:
                continue

            x1, y1, x2, y2 = track.to_tlbr()
            rows.append(
                np.concatenate(
                    ([x1, y1, x2, y2], [track.id], [track.conf], [track.cls], [track.det_ind])
                ).reshape(1, -1)
            )

        if rows:
            return np.concatenate(rows)
        # Keep the column count of the non-empty result so callers can slice
        # columns without special-casing frames that have no tracks.
        return np.empty((0, 8))