Skip to content

Commit

Permalink
Enable DINO to OTX - Step 2. Upgrade Deformable DETR to DINO (#2266)
Browse files Browse the repository at this point in the history
* Add DINO

* Modify docstrings

* Add mmengine to detection requirements

* Add unit tests

* Add intg test

* Update CHANGELOG.md

* Change description of config files for DINO

* Modify unit tests

* Reflect reviews

* Reflect Reviews

* Update unit tests
  • Loading branch information
jaegukhyun authored Jun 27, 2023
1 parent bf30d6e commit f974b41
Show file tree
Hide file tree
Showing 20 changed files with 2,554 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file.
- Add custom max iou assigner to prevent CPU OOM when large annotations are used (<https://github.com/openvinotoolkit/training_extensions/pull/2228>)
- Auto train type detection for Semi-SL, Self-SL and Incremental: "--train-type" now is optional (https://github.com/openvinotoolkit/training_extensions/pull/2195)
- Add new object detector Deformable DETR (<https://github.com/openvinotoolkit/training_extensions/pull/2249>)
- Add new object detector DINO (<https://github.com/openvinotoolkit/training_extensions/pull/2266>)

### Enhancements

Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/adapters/mmdet/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
#

from . import assigners, backbones, dense_heads, detectors, heads, losses, necks, roi_heads
from . import assigners, backbones, dense_heads, detectors, heads, layers, losses, necks, roi_heads

__all__ = ["assigners", "backbones", "dense_heads", "detectors", "heads", "losses", "necks", "roi_heads"]
__all__ = ["assigners", "backbones", "dense_heads", "detectors", "heads", "layers", "losses", "necks", "roi_heads"]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .custom_atss_detector import CustomATSS
from .custom_deformable_detr_detector import CustomDeformableDETR
from .custom_dino_detector import CustomDINO
from .custom_maskrcnn_detector import CustomMaskRCNN
from .custom_maskrcnn_tile_optimized import CustomMaskRCNNTileOptimized
from .custom_single_stage_detector import CustomSingleStageDetector
Expand All @@ -18,6 +19,7 @@
__all__ = [
"CustomATSS",
"CustomDeformableDETR",
"CustomDINO",
"CustomMaskRCNN",
"CustomSingleStageDetector",
"CustomTwoStageDetector",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""OTX DINO Class for mmdetection detectors."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmdet.models.builder import DETECTORS

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
ActivationMapHook,
FeatureVectorHook,
)
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomDeformableDETR

logger = get_logger()


@DETECTORS.register_module()
class CustomDINO(CustomDeformableDETR):
    """Custom DINO detector.

    Extends ``CustomDeformableDETR`` and registers a load-state-dict pre-hook
    that renames the parameters of an mmdet3.x-style checkpoint to the names
    used by this detector before the weights are loaded.
    """

    def __init__(self, *args, task_adapt=None, **kwargs):
        """Build the detector and register the checkpoint conversion pre-hook.

        Args:
            *args: Positional arguments forwarded to ``CustomDeformableDETR``.
            task_adapt: Forwarded unchanged to ``CustomDeformableDETR``.
            **kwargs: Keyword arguments forwarded to ``CustomDeformableDETR``.
        """
        super().__init__(*args, task_adapt=task_adapt, **kwargs)
        self._register_load_state_dict_pre_hook(
            self.load_state_dict_pre_hook,
        )

    @staticmethod
    def load_state_dict_pre_hook(ckpt_dict, *args, **kwargs):
        """Modify mmdet3.x version's weights before weight loading.

        The checkpoint is treated as an mmdet3.x one only when its first key is
        ``"level_embed"``; any other checkpoint (including an empty one) is left
        untouched. Matching checkpoints are renamed in place.

        Args:
            ckpt_dict: State dict about to be loaded; mutated in place.
            *args: Remaining pre-hook arguments, unused.
            **kwargs: Remaining pre-hook keyword arguments, unused.
        """
        # ``next(iter(...), None)`` reads the first key without copying every key
        # into a list and without raising IndexError on an empty state dict.
        if next(iter(ckpt_dict), None) != "level_embed":
            return

        logger.info("----------------- CustomDINO.load_state_dict_pre_hook() called")
        # This ckpt_dict comes from mmdet3.x
        ckpt_dict["bbox_head.transformer.level_embeds"] = ckpt_dict.pop("level_embed")

        # Collect the renames first: the dict must not change size while it is
        # being iterated.
        replaced_params = {}
        for param in ckpt_dict:
            new_param = None
            if "encoder" in param or "decoder" in param:
                new_param = "bbox_head.transformer." + param
                new_param = new_param.replace("self_attn", "attentions.0")
                new_param = new_param.replace("cross_attn", "attentions.1")
                new_param = new_param.replace("ffn", "ffns.0")
            elif param == "query_embedding.weight":
                new_param = "bbox_head." + param
            elif param == "dn_query_generator.label_embedding.weight":
                new_param = "bbox_head.transformer." + param
            elif "memory_trans" in param:
                new_param = "bbox_head.transformer." + param
                new_param = new_param.replace("memory_trans_fc", "enc_output")
                new_param = new_param.replace("memory_trans_norm", "enc_output_norm")
            if new_param is not None:
                replaced_params[param] = new_param

        for origin, new in replaced_params.items():
            ckpt_dict[new] = ckpt_dict.pop(origin)


if is_mmdeploy_enabled():
    from mmdeploy.core import FUNCTION_REWRITER

    @FUNCTION_REWRITER.register_rewriter(
        "otx.algorithms.detection.adapters.mmdet.models.detectors.custom_dino_detector.CustomDINO.simple_test"
    )
    def custom_dino__simple_test(ctx, self, img, img_metas, **kwargs):
        """Rewritten ``CustomDINO.simple_test`` registered with mmdeploy.

        Overwrites the shape entries of the first image meta, runs the feature
        extractor and the bbox head, and returns the bbox results. When
        ``ctx.cfg["dump_features"]`` is truthy, a feature vector and a saliency
        map are appended to the returned tuple.
        """
        first_meta = img_metas[0]
        height = int(first_meta["img_shape"][0])
        width = int(first_meta["img_shape"][1])
        first_meta["batch_input_shape"] = (height, width)
        first_meta["img_shape"] = (height, width, 3)

        feats = self.extract_feat(img)
        # No ground truth is passed: one ``None`` placeholder per element of ``feats``.
        num_placeholders = len(feats)
        gt_bboxes = [None] * num_placeholders
        gt_labels = [None] * num_placeholders

        transformer_outputs = self.bbox_head.forward_transformer(feats, gt_bboxes, gt_labels, img_metas)
        hidden_states, references, enc_output_class, enc_output_coord, _ = transformer_outputs
        cls_scores, bbox_preds = self.bbox_head(hidden_states, references)
        bbox_results = self.bbox_head.get_bboxes(
            cls_scores, bbox_preds, enc_output_class, enc_output_coord, img_metas=img_metas, **kwargs
        )

        if not ctx.cfg["dump_features"]:
            return bbox_results

        feature_vector = FeatureVectorHook.func(feats)
        saliency_map = ActivationMapHook.func(cls_scores)
        return (*bbox_results, feature_vector, saliency_map)
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,27 @@
from .cross_dataset_detector_head import CrossDatasetDetectorHead
from .custom_anchor_generator import SSDAnchorGeneratorClustered
from .custom_atss_head import CustomATSSHead, CustomATSSHeadTrackingLossDynamics
from .custom_dino_head import CustomDINOHead
from .custom_fcn_mask_head import CustomFCNMaskHead
from .custom_retina_head import CustomRetinaHead
from .custom_roi_head import CustomRoIHead
from .custom_ssd_head import CustomSSDHead
from .custom_vfnet_head import CustomVFNetHead
from .custom_yolox_head import CustomYOLOXHead
from .detr_head import DETRHeadExtension

__all__ = [
"CrossDatasetDetectorHead",
"SSDAnchorGeneratorClustered",
"CustomATSSHead",
"CustomDINOHead",
"CustomFCNMaskHead",
"CustomRetinaHead",
"CustomSSDHead",
"CustomRoIHead",
"CustomVFNetHead",
"CustomYOLOXHead",
"DETRHeadExtension",
# Loss dynamics tracking
"CustomATSSHeadTrackingLossDynamics",
]
Loading

0 comments on commit f974b41

Please sign in to comment.