Skip to content

Commit

Permalink
Add build tracker based on registry
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#3867

as titled

Reviewed By: zhanghang1989

Differential Revision: D32685315

fbshipit-source-id: b21a662a5538dd6d8a260d154a3f4b88fcf179c8
  • Loading branch information
wenliangzhao2018 authored and facebook-github-bot committed Jan 10, 2022
1 parent e1166a1 commit 932f25a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
19 changes: 13 additions & 6 deletions detectron2/tracking/base_tracker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
from detectron2.config import CfgNode as CfgNode_
from detectron2.config import configurable
from detectron2.structures import Instances
from ..structures import Instances
from detectron2.utils.registry import Registry
from ..config.config import CfgNode as CfgNode_
from detectron2.config import configurable


TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = """
Expand Down Expand Up @@ -51,7 +52,13 @@ def update(self, predictions: Instances) -> Instances:

def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
"""
Build a semantic segmentation head from `cfg.MODEL.SEM_SEG_HEAD.NAME`.
Build a tracker head from `cfg.TRACKER_HEADS.TRACKER_NAME`.
Args:
cfg: D2 CfgNode, config file with tracker information
Return:
tracker object
"""
name = cfg.TRACKING.TRACKER_NAME
return TRACKER_HEADS_REGISTRY.get(name)(cfg)
name = cfg.TRACKER_HEADS.TRACKER_NAME
tracker_class = TRACKER_HEADS_REGISTRY.get(name)
return tracker_class(cfg)
7 changes: 4 additions & 3 deletions tests/tracking/test_bbox_iou_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import torch
import numpy as np

from detectron2.tracking.bbox_iou_tracker import BBoxIOUTracker
from typing import Dict
from detectron2.structures import Boxes, Instances
from detectron2.config import instantiate, CfgNode as CfgNode_
from detectron2.tracking.base_tracker import build_tracker_head
from detectron2.tracking.bbox_iou_tracker import BBoxIOUTracker # noqa
from copy import deepcopy


Expand Down Expand Up @@ -85,15 +86,15 @@ def test_init(self):
def test_from_config(self):
cfg = CfgNode_()
cfg.TRACKER_HEADS = CfgNode_()
cfg.TRACKER_HEADS.TRACKER_NAME = "BBoxIOUTracker"
cfg.TRACKER_HEADS.VIDEO_HEIGHT = int(self._img_size[0])
cfg.TRACKER_HEADS.VIDEO_WIDTH = int(self._img_size[1])
cfg.TRACKER_HEADS.MAX_NUM_INSTANCES = self._max_num_instances
cfg.TRACKER_HEADS.MAX_LOST_FRAME_COUNT = self._max_lost_frame_count
cfg.TRACKER_HEADS.MIN_BOX_REL_DIM = self._min_box_rel_dim
cfg.TRACKER_HEADS.MIN_INSTANCE_PERIOD = self._min_instance_period
cfg.TRACKER_HEADS.TRACK_IOU_THRESHOLD = self._track_iou_threshold
input_arg = BBoxIOUTracker.from_config(cfg)
tracker = instantiate(input_arg)
tracker = build_tracker_head(cfg)
self.assertTrue(tracker._video_height == self._img_size[0])

def test_initialize_extra_fields(self):
Expand Down

0 comments on commit 932f25a

Please sign in to comment.