forked from facebookresearch/detectron2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: facebookresearch#3835 As base class for D2 (facebookresearch@11528ce) tracking module Reviewed By: zhanghang1989 Differential Revision: D32671844 fbshipit-source-id: a83c3629b7cc326333be7cccf14c1ff95bcfe433
- Loading branch information
1 parent
e3053c1
commit f7bc78e
Showing
1 changed file
with
57 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2004-present Facebook. All Rights Reserved. | ||
from ..structures import Instances | ||
from detectron2.utils.registry import Registry | ||
from ..config.config import CfgNode as CfgNode_ | ||
from detectron2.config import configurable | ||
|
||
|
||
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS") | ||
TRACKER_HEADS_REGISTRY.__doc__ = """ | ||
Registry for tracking classes. | ||
""" | ||
|
||
|
||
class BaseTracker(object): | ||
""" | ||
A parent class for all trackers | ||
""" | ||
@configurable | ||
def __init__(self, **kwargs): | ||
self._prev_instances = None # (D2)instances for previous frame | ||
self._matched_idx = set() # indices in prev_instances found matching | ||
self._matched_ID = set() # idendities in prev_instances found matching | ||
self._untracked_prev_idx = set() # indices in prev_instances not found matching | ||
self._id_count = 0 # used to assign new id | ||
|
||
@classmethod | ||
def from_config(cls, cfg: CfgNode_): | ||
raise NotImplementedError("Calling BaseTracker::from_config") | ||
|
||
def update(self, predictions: Instances) -> Instances: | ||
""" | ||
Args: | ||
predictions: D2 Instances for predictions of the current frame | ||
Return: | ||
D2 Instances for predictions of the current frame with ID assigned | ||
_prev_instances and instances will have the following fields: | ||
.pred_boxes (shape=[N, 4]) | ||
.scores (shape=[N,]) | ||
.pred_classes (shape=[N,]) | ||
.pred_keypoints (shape=[N, M, 3], Optional) | ||
.pred_masks (shape=List[2D_MASK], Optional) 2D_MASK: shape=[H, W] | ||
.ID (shape=[N,]) | ||
N: # of detected bboxes | ||
H and W: height and width of 2D mask | ||
""" | ||
raise NotImplementedError("Calling BaseTracker::update") | ||
|
||
|
||
def build_tracker_head(cfg: CfgNode_) -> BaseTracker: | ||
""" | ||
Build a semantic segmentation head from `cfg.MODEL.SEM_SEG_HEAD.NAME`. | ||
""" | ||
name = cfg.TRACKING.TRACKER_NAME | ||
return TRACKER_HEADS_REGISTRY.get(name)(cfg) |