Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable DINO to OTX - Step 2. Upgrade Deformable DETR to DINO #2266

Merged
merged 11 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file.
- Add custom max iou assigner to prevent CPU OOM when large annotations are used (<https://github.com/openvinotoolkit/training_extensions/pull/2228>)
- Auto train type detection for Semi-SL, Self-SL and Incremental: "--train-type" now is optional (https://github.com/openvinotoolkit/training_extensions/pull/2195)
- Add new object detector Deformable DETR (<https://github.com/openvinotoolkit/training_extensions/pull/2249>)
- Add new object detector DINO (<https://github.com/openvinotoolkit/training_extensions/pull/2266>)

### Enhancements

Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/adapters/mmdet/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
#

from . import assigners, backbones, dense_heads, detectors, heads, losses, necks, roi_heads
from . import assigners, backbones, dense_heads, detectors, heads, layers, losses, necks, roi_heads

__all__ = ["assigners", "backbones", "dense_heads", "detectors", "heads", "losses", "necks", "roi_heads"]
__all__ = ["assigners", "backbones", "dense_heads", "detectors", "heads", "layers", "losses", "necks", "roi_heads"]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .custom_atss_detector import CustomATSS
from .custom_deformable_detr_detector import CustomDeformableDETR
from .custom_dino_detector import CustomDINO
from .custom_maskrcnn_detector import CustomMaskRCNN
from .custom_maskrcnn_tile_optimized import CustomMaskRCNNTileOptimized
from .custom_single_stage_detector import CustomSingleStageDetector
Expand All @@ -18,6 +19,7 @@
__all__ = [
"CustomATSS",
"CustomDeformableDETR",
"CustomDINO",
"CustomMaskRCNN",
"CustomSingleStageDetector",
"CustomTwoStageDetector",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""OTX DINO Class for mmdetection detectors."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import functools

from mmdet.models.builder import DETECTORS

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
ActivationMapHook,
FeatureVectorHook,
)
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomDeformableDETR

logger = get_logger()


@DETECTORS.register_module()
class CustomDINO(CustomDeformableDETR):
    """Custom DINO detector.

    Builds on ``CustomDeformableDETR`` and registers a load-state-dict pre-hook
    that renames the keys of an mmdet3.x style checkpoint before the weights
    are loaded (see ``load_state_dict_pre_hook``).
    """

    def __init__(self, *args, task_adapt=None, **kwargs):
        super().__init__(*args, task_adapt=task_adapt, **kwargs)
        # The hook is a staticmethod, so this instance is bound explicitly as
        # its leading ``model`` argument.
        self._register_load_state_dict_pre_hook(
            functools.partial(
                self.load_state_dict_pre_hook,
                self,
            )
        )

    @staticmethod
    def load_state_dict_pre_hook(model, ckpt_dict, *args, **kwargs):
        """Modify mmdet3.x version's weights before weight loading.

        The checkpoint is treated as coming from mmdet3.x when it holds a
        top-level ``"level_embed"`` key. In that case the keys of ``ckpt_dict``
        are renamed in place; otherwise the dict is left untouched.

        Args:
            model: Model instance bound at registration time (not used here).
            ckpt_dict: State dict about to be loaded; modified in place.
            *args: Remaining positional hook arguments (ignored).
            **kwargs: Remaining keyword hook arguments (ignored).
        """
        # Membership test instead of peeking at the first key: it does not
        # raise on an empty state dict and does not depend on key ordering.
        if "level_embed" not in ckpt_dict:
            return

        logger.info("----------------- CustomDINO.load_state_dict_pre_hook() called")
        # This ckpt_dict comes from mmdet3.x
        ckpt_dict["bbox_head.transformer.level_embeds"] = ckpt_dict.pop("level_embed")

        # Collect the renames first; the dict must not change size while it is iterated.
        replaced_params = {}
        for param in ckpt_dict:
            new_param = None
            if "encoder" in param or "decoder" in param:
                new_param = "bbox_head.transformer." + param
                new_param = new_param.replace("self_attn", "attentions.0")
                new_param = new_param.replace("cross_attn", "attentions.1")
                new_param = new_param.replace("ffn", "ffns.0")
            elif param == "query_embedding.weight":
                new_param = "bbox_head." + param
            elif param == "dn_query_generator.label_embedding.weight":
                new_param = "bbox_head.transformer." + param
            elif "memory_trans" in param:
                new_param = "bbox_head.transformer." + param
                new_param = new_param.replace("memory_trans_fc", "enc_output")
                new_param = new_param.replace("memory_trans_norm", "enc_output_norm")
            if new_param is not None:
                replaced_params[param] = new_param

        for origin, new in replaced_params.items():
            ckpt_dict[new] = ckpt_dict.pop(origin)


if is_mmdeploy_enabled():
    from mmdeploy.core import FUNCTION_REWRITER

    @FUNCTION_REWRITER.register_rewriter(
        "otx.algorithms.detection.adapters.mmdet.models.detectors.custom_dino_detector.CustomDINO.simple_test"
    )
    def custom_dino__simple_test(ctx, self, img, img_metas, **kwargs):
        """Rewritten ``CustomDINO.simple_test`` registered with mmdeploy.

        Overwrites the shape entries of the first image meta with plain ints,
        runs feature extraction and the bbox head, and returns the bbox
        results. When ``ctx.cfg["dump_features"]`` is truthy, a feature vector
        and a saliency map are appended to the returned tuple.
        """
        # Only the first meta entry is read and updated.
        meta = img_metas[0]
        img_h = int(meta["img_shape"][0])
        img_w = int(meta["img_shape"][1])
        meta["batch_input_shape"] = (img_h, img_w)
        meta["img_shape"] = (img_h, img_w, 3)

        feats = self.extract_feat(img)
        # No ground truth is passed here; placeholder lists are sized by len(feats).
        num_placeholders = len(feats)
        gt_bboxes = [None] * num_placeholders
        gt_labels = [None] * num_placeholders
        hidden_states, references, enc_output_class, enc_output_coord, _ = self.bbox_head.forward_transformer(
            feats, gt_bboxes, gt_labels, img_metas
        )
        cls_scores, bbox_preds = self.bbox_head(hidden_states, references)
        bbox_results = self.bbox_head.get_bboxes(
            cls_scores, bbox_preds, enc_output_class, enc_output_coord, img_metas=img_metas, **kwargs
        )

        if not ctx.cfg["dump_features"]:
            return bbox_results

        feature_vector = FeatureVectorHook.func(feats)
        saliency_map = ActivationMapHook.func(cls_scores)
        return (*bbox_results, feature_vector, saliency_map)
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,27 @@
from .cross_dataset_detector_head import CrossDatasetDetectorHead
from .custom_anchor_generator import SSDAnchorGeneratorClustered
from .custom_atss_head import CustomATSSHead, CustomATSSHeadTrackingLossDynamics
from .custom_dino_head import CustomDINOHead
from .custom_fcn_mask_head import CustomFCNMaskHead
from .custom_retina_head import CustomRetinaHead
from .custom_roi_head import CustomRoIHead
from .custom_ssd_head import CustomSSDHead
from .custom_vfnet_head import CustomVFNetHead
from .custom_yolox_head import CustomYOLOXHead
from .detr_head import DETRHeadExtension

__all__ = [
"CrossDatasetDetectorHead",
"SSDAnchorGeneratorClustered",
"CustomATSSHead",
"CustomDINOHead",
"CustomFCNMaskHead",
"CustomRetinaHead",
"CustomSSDHead",
"CustomRoIHead",
"CustomVFNetHead",
"CustomYOLOXHead",
"DETRHeadExtension",
# Loss dynamics tracking
"CustomATSSHeadTrackingLossDynamics",
]
Loading