Skip to content

Commit

Permalink
Add base class
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3835

As base class for D2 (facebookresearch@11528ce) tracking module

Reviewed By: zhanghang1989

Differential Revision: D32671844

fbshipit-source-id: a83c3629b7cc326333be7cccf14c1ff95bcfe433
  • Loading branch information
wenliangzhao2018 authored and facebook-github-bot committed Dec 30, 2021
1 parent e3053c1 commit f7bc78e
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions detectron2/tracking/base_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
from ..structures import Instances
from detectron2.utils.registry import Registry
from ..config.config import CfgNode as CfgNode_
from detectron2.config import configurable


TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = """
Registry for tracking classes.
"""


class BaseTracker(object):
"""
A parent class for all trackers
"""
@configurable
def __init__(self, **kwargs):
self._prev_instances = None # (D2)instances for previous frame
self._matched_idx = set() # indices in prev_instances found matching
self._matched_ID = set() # idendities in prev_instances found matching
self._untracked_prev_idx = set() # indices in prev_instances not found matching
self._id_count = 0 # used to assign new id

@classmethod
def from_config(cls, cfg: CfgNode_):
raise NotImplementedError("Calling BaseTracker::from_config")

def update(self, predictions: Instances) -> Instances:
"""
Args:
predictions: D2 Instances for predictions of the current frame
Return:
D2 Instances for predictions of the current frame with ID assigned
_prev_instances and instances will have the following fields:
.pred_boxes (shape=[N, 4])
.scores (shape=[N,])
.pred_classes (shape=[N,])
.pred_keypoints (shape=[N, M, 3], Optional)
.pred_masks (shape=List[2D_MASK], Optional) 2D_MASK: shape=[H, W]
.ID (shape=[N,])
N: # of detected bboxes
H and W: height and width of 2D mask
"""
raise NotImplementedError("Calling BaseTracker::update")


def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
"""
Build a semantic segmentation head from `cfg.MODEL.SEM_SEG_HEAD.NAME`.
"""
name = cfg.TRACKING.TRACKER_NAME
return TRACKER_HEADS_REGISTRY.get(name)(cfg)

0 comments on commit f7bc78e

Please sign in to comment.