forked from facebookresearch/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: facebookresearch#3870 Add a Hungarian tracker base class Reviewed By: zhanghang1989 Differential Revision: D32690379 fbshipit-source-id: 91266fce6660f0413935b5556f76ed7fe2eaca0c
- Loading branch information
1 parent
932f25a
commit 38628d1
Showing
2 changed files
with
265 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2004-present Facebook. All Rights Reserved. | ||
import copy | ||
|
||
import numpy as np | ||
import torch | ||
from detectron2.structures import Boxes, Instances | ||
|
||
from .base_tracker import BaseTracker | ||
from scipy.optimize import linear_sum_assignment | ||
from ..config.config import CfgNode as CfgNode_ | ||
from typing import Dict | ||
from detectron2.config import configurable | ||
|
||
|
||
class BaseHungarianTracker(BaseTracker): | ||
""" | ||
A base class for all Hungarian trackers | ||
""" | ||
|
||
@configurable | ||
def __init__( | ||
self, | ||
video_height: int, | ||
video_width: int, | ||
max_num_instances: int = 200, | ||
max_lost_frame_count: int = 0, | ||
min_box_rel_dim: float = 0.02, | ||
min_instance_period: int = 1, | ||
**kwargs | ||
): | ||
super().__init__(**kwargs) | ||
self._video_height = video_height | ||
self._video_width = video_width | ||
self._max_num_instances = max_num_instances | ||
self._max_lost_frame_count = max_lost_frame_count | ||
self._min_box_rel_dim = min_box_rel_dim | ||
self._min_instance_period = min_instance_period | ||
|
||
@classmethod | ||
def from_config(cls, cfg: CfgNode_) -> Dict: | ||
raise NotImplementedError("Calling HungarianTracker::from_config") | ||
|
||
def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray: | ||
raise NotImplementedError("Calling HungarianTracker::build_matrix") | ||
|
||
def update(self, instances: Instances) -> Instances: | ||
if instances.has("pred_keypoints"): | ||
raise NotImplementedError("Need to add support for keypoints") | ||
instances = self._initialize_extra_fields(instances) | ||
if self._prev_instances is not None: | ||
self._untracked_prev_idx = set(range(len(self._prev_instances))) | ||
cost_matrix = self.build_cost_matrix(instances, self._prev_instances) | ||
matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix) | ||
instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx) | ||
instances = self._process_unmatched_idx(instances, matched_idx) | ||
instances = self._process_unmatched_prev_idx(instances, matched_prev_idx) | ||
self._prev_instances = copy.deepcopy(instances) | ||
return instances | ||
|
||
def _initialize_extra_fields(self, instances: Instances) -> Instances: | ||
""" | ||
If input instances don't have ID, ID_period, lost_frame_count fields, | ||
this method is used to initialize these fields. | ||
Args: | ||
instances: D2 Instances, for predictions of the current frame | ||
Return: | ||
D2 Instances with extra fields added | ||
""" | ||
if not instances.has("ID"): | ||
instances.set("ID", [None] * len(instances)) | ||
if not instances.has("ID_period"): | ||
instances.set("ID_period", [None] * len(instances)) | ||
if not instances.has("lost_frame_count"): | ||
instances.set("lost_frame_count", [None] * len(instances)) | ||
if self._prev_instances is None: | ||
instances.ID = list(range(len(instances))) | ||
self._id_count += len(instances) | ||
instances.ID_period = [1] * len(instances) | ||
instances.lost_frame_count = [0] * len(instances) | ||
return instances | ||
|
||
def _process_matched_idx( | ||
self, | ||
instances: Instances, | ||
matched_idx: np.ndarray, | ||
matched_prev_idx: np.ndarray | ||
) -> Instances: | ||
assert matched_idx.size == matched_prev_idx.size | ||
for i in range(matched_idx.size): | ||
instances.ID[matched_idx[i]] = self._prev_instances.ID[matched_prev_idx[i]] | ||
instances.ID_period[matched_idx[i]] = \ | ||
self._prev_instances.ID_period[matched_prev_idx[i]] + 1 | ||
instances.lost_frame_count[matched_idx[i]] = 0 | ||
return instances | ||
|
||
def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances: | ||
untracked_idx = set(range(len(instances))).difference(set(matched_idx)) | ||
for idx in untracked_idx: | ||
instances.ID[idx] = self._id_count | ||
self._id_count += 1 | ||
instances.ID_period[idx] = 1 | ||
instances.lost_frame_count[idx] = 0 | ||
return instances | ||
|
||
def _process_unmatched_prev_idx( | ||
self, | ||
instances: Instances, | ||
matched_prev_idx: | ||
np.ndarray | ||
) -> Instances: | ||
untracked_instances = Instances( | ||
image_size=instances.image_size, | ||
pred_boxes=[], | ||
pred_masks=[], | ||
pred_classes=[], | ||
scores=[], | ||
ID=[], | ||
ID_period=[], | ||
lost_frame_count=[], | ||
) | ||
prev_bboxes = list(self._prev_instances.pred_boxes) | ||
prev_classes = list(self._prev_instances.pred_classes) | ||
prev_scores = list(self._prev_instances.scores) | ||
prev_ID_period = self._prev_instances.ID_period | ||
if instances.has("pred_masks"): | ||
prev_masks = list(self._prev_instances.pred_masks) | ||
untracked_prev_idx = set(range(len(self._prev_instances))).difference(set(matched_prev_idx)) | ||
for idx in untracked_prev_idx: | ||
x_left, y_top, x_right, y_bot = prev_bboxes[idx] | ||
if ( | ||
(1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim) | ||
or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim) | ||
or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count | ||
or prev_ID_period[idx] <= self._min_instance_period | ||
): | ||
continue | ||
untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy())) | ||
untracked_instances.pred_classes.append(int(prev_classes[idx])) | ||
untracked_instances.scores.append(float(prev_scores[idx])) | ||
untracked_instances.ID.append(self._prev_instances.ID[idx]) | ||
untracked_instances.ID_period.append(self._prev_instances.ID_period[idx]) | ||
untracked_instances.lost_frame_count.append( | ||
self._prev_instances.lost_frame_count[idx] + 1 | ||
) | ||
if instances.has("pred_masks"): | ||
untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8)) | ||
|
||
untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes)) | ||
untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes) | ||
untracked_instances.scores = torch.FloatTensor(untracked_instances.scores) | ||
if instances.has("pred_masks"): | ||
untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks) | ||
else: | ||
untracked_instances.remove("pred_masks") | ||
|
||
return Instances.cat( | ||
[ | ||
instances, | ||
untracked_instances, | ||
] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
import unittest | ||
from typing import Dict | ||
|
||
import numpy as np | ||
import torch | ||
from detectron2.config import instantiate | ||
from detectron2.structures import Boxes, Instances | ||
|
||
|
||
class TestBaseHungarianTracker(unittest.TestCase): | ||
def setUp(self): | ||
self._img_size = np.array([600, 800]) | ||
self._prev_boxes = np.array( | ||
[ | ||
[101, 101, 200, 200], | ||
[301, 301, 450, 450], | ||
] | ||
).astype(np.float32) | ||
self._prev_scores = np.array([0.9, 0.9]) | ||
self._prev_classes = np.array([1, 1]) | ||
self._prev_masks = np.ones((2, 600, 800)).astype("uint8") | ||
self._curr_boxes = np.array( | ||
[ | ||
[302, 303, 451, 452], | ||
[101, 102, 201, 203], | ||
] | ||
).astype(np.float32) | ||
self._curr_scores = np.array([0.95, 0.85]) | ||
self._curr_classes = np.array([1, 1]) | ||
self._curr_masks = np.ones((2, 600, 800)).astype("uint8") | ||
|
||
self._prev_instances = { | ||
"image_size": self._img_size, | ||
"pred_boxes": self._prev_boxes, | ||
"scores": self._prev_scores, | ||
"pred_classes": self._prev_classes, | ||
"pred_masks": self._prev_masks, | ||
} | ||
self._prev_instances = self._convertDictPredictionToInstance(self._prev_instances) | ||
self._curr_instances = { | ||
"image_size": self._img_size, | ||
"pred_boxes": self._curr_boxes, | ||
"scores": self._curr_scores, | ||
"pred_classes": self._curr_classes, | ||
"pred_masks": self._curr_masks, | ||
} | ||
self._curr_instances = self._convertDictPredictionToInstance(self._curr_instances) | ||
|
||
self._max_num_instances = 200 | ||
self._max_lost_frame_count = 0 | ||
self._min_box_rel_dim = 0.02 | ||
self._min_instance_period = 1 | ||
self._track_iou_threshold = 0.5 | ||
|
||
def _convertDictPredictionToInstance(self, prediction: Dict) -> Instances: | ||
""" | ||
convert prediction from Dict to D2 Instances format | ||
""" | ||
res = Instances( | ||
image_size=torch.IntTensor(prediction["image_size"]), | ||
pred_boxes=Boxes(torch.FloatTensor(prediction["pred_boxes"])), | ||
pred_masks=torch.IntTensor(prediction["pred_masks"]), | ||
pred_classes=torch.IntTensor(prediction["pred_classes"]), | ||
scores=torch.FloatTensor(prediction["scores"]), | ||
) | ||
return res | ||
|
||
def test_init(self): | ||
cfg = { | ||
"_target_": "detectron2.tracking.hungarian_tracker.BaseHungarianTracker", | ||
"video_height": self._img_size[0], | ||
"video_width": self._img_size[1], | ||
"max_num_instances": self._max_num_instances, | ||
"max_lost_frame_count": self._max_lost_frame_count, | ||
"min_box_rel_dim": self._min_box_rel_dim, | ||
"min_instance_period": self._min_instance_period, | ||
"track_iou_threshold": self._track_iou_threshold | ||
} | ||
tracker = instantiate(cfg) | ||
self.assertTrue(tracker._video_height == self._img_size[0]) | ||
|
||
def test_initialize_extra_fields(self): | ||
cfg = { | ||
"_target_": "detectron2.tracking.hungarian_tracker.BaseHungarianTracker", | ||
"video_height": self._img_size[0], | ||
"video_width": self._img_size[1], | ||
"max_num_instances": self._max_num_instances, | ||
"max_lost_frame_count": self._max_lost_frame_count, | ||
"min_box_rel_dim": self._min_box_rel_dim, | ||
"min_instance_period": self._min_instance_period, | ||
"track_iou_threshold": self._track_iou_threshold | ||
} | ||
tracker = instantiate(cfg) | ||
instances = tracker._initialize_extra_fields(self._curr_instances) | ||
self.assertTrue(instances.has("ID")) | ||
self.assertTrue(instances.has("ID_period")) | ||
self.assertTrue(instances.has("lost_frame_count")) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |