Support H-DINO (#10893)
Co-authored-by: huanghaian <[email protected]>
PhoenixZ810 and hhaAndroid authored Sep 19, 2023
1 parent dfe7a57 commit 9915a5e
Showing 6 changed files with 469 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mmdet/models/dense_heads/deformable_detr_head.py
@@ -83,7 +83,7 @@ def init_weights(self) -> None:
                 nn.init.constant_(m[-1].bias.data[2:], 0.0)
 
     def forward(self, hidden_states: Tensor,
-                references: List[Tensor]) -> Tuple[Tensor]:
+                references: List[Tensor]) -> Tuple[Tensor, Tensor]:
         """Forward function.
         Args:
35 changes: 35 additions & 0 deletions projects/HDINO/README.md
@@ -0,0 +1,35 @@
# H-DETR

> [DETRs with Hybrid Matching](https://arxiv.org/abs/2207.13080)
<!-- [ALGORITHM] -->

## Abstract

One-to-one set matching is a key design for DETR to establish its end-to-end capability, so that object detection does not require a hand-crafted NMS (non-maximum suppression) to remove duplicate detections. This end-to-end signature is important for the versatility of DETR, and it has been generalized to broader vision tasks. However, we note that there are few queries assigned as positive samples and the one-to-one set matching significantly reduces the training efficacy of positive samples. We propose a simple yet effective method based on a hybrid matching scheme that combines the original one-to-one matching branch with an auxiliary one-to-many matching branch during training. Our hybrid strategy has been shown to significantly improve accuracy. In inference, only the original one-to-one match branch is used, thus maintaining the end-to-end merit and the same inference efficiency of DETR. The method is named H-DETR, and it shows that a wide range of representative DETR methods can be consistently improved across a wide range of visual tasks, including DeformableDETR, PETRv2, PETR, and TransTrack, among others.

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/254f3037-1ca8-4d0c-8f3e-45d8ec3f9abc"/>
</div>
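
A minimal sketch of the one-to-many idea from the abstract, assuming the auxiliary branch is realized by repeating each ground-truth object `k_one2many` times and then running the usual one-to-one assignment over the repeated set. This is an assumption about the head (`h_dino_head.py` is not shown here), and `repeat_targets` is a hypothetical helper, not part of this project.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]


def repeat_targets(labels: List[int], boxes: List[Box],
                   k_one2many: int) -> Tuple[List[int], List[Box]]:
    """Repeat every ground-truth object `k_one2many` times.

    One-to-one matching over the repeated set lets up to `k_one2many`
    queries be assigned to each original object.
    """
    if k_one2many < 1:
        raise ValueError('k_one2many must be a positive integer')
    # Sequence repetition tiles the whole list: [a, b] * 2 -> [a, b, a, b].
    return labels * k_one2many, boxes * k_one2many


# Two hypothetical objects with (cx, cy, w, h) boxes.
labels = [3, 17]
boxes = [(0.5, 0.5, 0.2, 0.3), (0.1, 0.2, 0.1, 0.1)]
many_labels, many_boxes = repeat_targets(labels, boxes, k_one2many=2)
print(len(labels), len(many_labels))
```

The one-to-one branch would keep using the unrepeated `labels` and `boxes`, so only the auxiliary branch sees the enlarged target set.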

## Results and Models

| Backbone | Model | Lr schd | box AP | Config | Download |
| :------: | :-----------: | :-----: | :----: | :--------------------------------------------: | :------: |
| R-50 | H-DINO-4scale | 12e | 48.0 | [config](./h-dino-4scale_r50_8xb2-12e_coco.py) | |

### NOTE

1. This implementation applies the `Hybrid Matching` algorithm to `DINO` rather than to `Deformable DETR`.
2. We found that directly applying Hybrid Matching to DINO results in a significant decrease in performance. If you have any insights or suggestions, please feel free to comment or submit a pull request (PR).
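
The config uses `num_queries=1800`, split into 900 one-to-one and 900 one-to-many queries. Below is a minimal sketch of the extra self-attention masking that `HDINO.pre_decoder` applies during training, assuming the convention that `True` marks a blocked query pair. `hybrid_attn_mask` is a hypothetical helper written for illustration, and DINO's own denoising masking is left out.

```python
from typing import List


def hybrid_attn_mask(num_dn: int, num_one2one: int,
                     num_one2many: int) -> List[List[bool]]:
    """Return a square mask where True blocks query i from attending to j.

    Query order is assumed to be: denoising, one-to-one, one-to-many.
    Only the hybrid-matching rule is modelled here: one-to-one queries
    are blocked from attending to one-to-many queries.
    """
    total = num_dn + num_one2one + num_one2many
    one2one_end = num_dn + num_one2one
    mask = [[False] * total for _ in range(total)]
    for i in range(num_dn, one2one_end):      # rows: one-to-one queries
        for j in range(one2one_end, total):   # columns: one-to-many queries
            mask[i][j] = True
    return mask


# Tiny example: 1 denoising, 2 one-to-one and 2 one-to-many queries.
mask = hybrid_attn_mask(num_dn=1, num_one2one=2, num_one2many=2)
for row in mask:
    print(''.join('x' if blocked else '.' for blocked in row))
```

In this sketch the one-to-many rows stay unmasked, which mirrors the source: the line that would block the reverse direction is commented out in `pre_decoder`.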

## Citation

```latex
@article{jia2022detrs,
title={DETRs with Hybrid Matching},
author={Jia, Ding and Yuan, Yuhui and He, Haodi and Wu, Xiaopei and Yu, Haojun and Lin, Weihong and Sun, Lei and Zhang, Chao and Hu, Han},
journal={arXiv preprint arXiv:2207.13080},
year={2022}
}
```
4 changes: 4 additions & 0 deletions projects/HDINO/__init__.py
@@ -0,0 +1,4 @@
from .h_dino import HDINO
from .h_dino_head import HybridDINOHead

__all__ = ['HDINO', 'HybridDINOHead']
168 changes: 168 additions & 0 deletions projects/HDINO/h-dino-4scale_r50_8xb2-12e_coco.py
@@ -0,0 +1,168 @@
_base_ = [
    '../../configs/_base_/datasets/coco_detection.py',
    '../../configs/_base_/default_runtime.py'
]

custom_imports = dict(imports=['projects.HDINO'], allow_failed_imports=False)

model = dict(
    type='HDINO',
    num_queries=1800,  # num_total_queries: 900+900
    with_box_refine=True,
    as_two_stage=True,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=1),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='ChannelMapper',
        in_channels=[512, 1024, 2048],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    encoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_levels=4,
                               dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,  # 1024 for DeformDETR
                ffn_drop=0.0))),  # 0.1 for DeformDETR
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_heads=8,
                               dropout=0.0),  # 0.1 for DeformDETR
            cross_attn_cfg=dict(embed_dims=256, num_levels=4,
                                dropout=0.0),  # 0.1 for DeformDETR
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,  # 1024 for DeformDETR
                ffn_drop=0.0)),  # 0.1 for DeformDETR
        post_norm_cfg=None),
    positional_encoding=dict(
        num_feats=128,
        normalize=True,
        offset=0.0,  # -0.5 for DeformDETR
        temperature=20),  # 10000 for DeformDETR
    bbox_head=dict(
        type='HybridDINOHead',
        num_classes=80,
        sync_cls_avg_factor=True,
        num_query_one2one=900,
        k_one2many=2,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),  # 2.0 in DeformDETR
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    dn_cfg=dict(
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))  # 100 for DeformDETR

# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default setting in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomChoice',
        transforms=[
            [
                dict(
                    type='RandomChoiceResize',
                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                            (736, 1333), (768, 1333), (800, 1333)],
                    keep_ratio=True)
            ],
            [
                dict(
                    type='RandomChoiceResize',
                    # The aspect ratio of all images in the train dataset
                    # is < 7, following the original implementation
                    scales=[(400, 4200), (500, 4200), (600, 4200)],
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='RandomChoiceResize',
                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                            (736, 1333), (768, 1333), (800, 1333)],
                    keep_ratio=True)
            ]
        ]),
    dict(type='PackDetInputs')
]
train_dataloader = dict(
    dataset=dict(
        filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.0001,  # 0.0002 for DeformDETR
        weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})
)  # custom_keys contains sampling_offsets and reference_points in DeformDETR  # noqa

# learning policy
max_epochs = 12
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[11],
        gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
149 changes: 149 additions & 0 deletions projects/HDINO/h_dino.py
@@ -0,0 +1,149 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from torch import Tensor, nn
from torch.nn.init import normal_

from mmdet.models.detectors import DINO, DeformableDETR
from mmdet.models.detectors.deformable_detr import \
    MultiScaleDeformableAttention
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList
from mmdet.utils import OptConfigType


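@MODELS.register_module()
class HDINO(DINO):
    """DINO detector with hybrid (one-to-one plus one-to-many) matching."""

    def __init__(self,
                 *args,
                 bbox_head: OptConfigType = None,
                 **kwargs) -> None:
        # Selects how matching queries are built in `pre_decoder`:
        # 1 maps detached encoder memory through `query_map`; any other
        # value embeds the top-k proposal coordinates via `pos_trans_fc`.
        self.method = 0
        self.num_query_one2one = bbox_head['num_query_one2one']
        super(HDINO, self).__init__(*args, bbox_head=bbox_head, **kwargs)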

    def _init_layers(self) -> None:
        super(HDINO, self)._init_layers()
        self.query_embedding = None
        if self.method == 1:
            self.query_map = nn.Linear(self.embed_dims, self.embed_dims)
        else:
            self.pos_trans_fc = nn.Linear(self.embed_dims * 2, self.embed_dims)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        normal_(self.level_embed)

        if self.method == 1:
            nn.init.xavier_uniform_(self.query_map.weight)
        else:
            nn.init.xavier_uniform_(self.pos_trans_fc.weight)

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[dict, dict]:

        bs, _, c = memory.shape
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers].out_features

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](
                output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        # NOTE The DINO selects top-k proposals according to scores of
        # multi-class classification, while DeformDETR, where the input
        # is `enc_outputs_class[..., 0]` selects according to scores of
        # binary classification.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]
        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()

        # We only made changes here.
        # -------------------------------------
        if self.method == 1:
            map_memory = self.query_map(memory.detach())
            query = torch.gather(
                map_memory, 1,
                topk_indices.unsqueeze(-1).repeat(1, 1, self.embed_dims))
        else:
            pos_trans_out = self.pos_trans_fc(
                self.get_proposal_pos_embed(topk_coords_unact))
            query = self.pos_trans_norm(pos_trans_out)
        # -------------------------------------

        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        # NOTE DINO calculates encoder losses on scores and coordinates
        # of selected top-k encoder queries, while DeformDETR is of all
        # encoder queries.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()

        # We only made changes here.
        # -------------------------------------
        if self.training:
            # train: num_denoising_queries + num_query_one2one
            # + num_query_one2many
            dn_mask = decoder_inputs_dict['dn_mask']
            num_denoising_queries = head_inputs_dict['dn_meta'][
                'num_denoising_queries']
            num_query_one2one = num_denoising_queries + self.num_query_one2one
            # dn_mask[num_query_one2one:, :num_query_one2one] = True
            dn_mask[num_denoising_queries:num_query_one2one,
                    num_query_one2one:] = True
            decoder_inputs_dict['dn_mask'] = dn_mask
        else:
            # test: num_query_one2one
            # + num_query_one2many
            # NOTE: `query` and `reference_points` are batch-first
            # (bs, num_queries, ...), so the slices below index the batch
            # dimension rather than the query dimension. Keeping only the
            # one-to-one queries would presumably need
            # `[:, :self.num_query_one2one]` instead.
            query = decoder_inputs_dict['query']
            reference_points = decoder_inputs_dict['reference_points']
            num_query_one2many = self.num_queries - self.num_query_one2one
            decoder_inputs_dict['query'] = query[:num_query_one2many]
            decoder_inputs_dict[
                'reference_points'] = reference_points[:num_query_one2many]
        # -------------------------------------
        return decoder_inputs_dict, head_inputs_dict