[CodeCamp2023-503] Add the DDQ algorithm to mmdetection #10772

Merged · 13 commits · Aug 14, 2023
34 changes: 34 additions & 0 deletions configs/ddq/README.md
@@ -0,0 +1,34 @@
# DDQ

[Dense Distinct Query for End-to-End Object Detection](https://arxiv.org/abs/2303.12776)

## Abstract

One-to-one label assignment in object detection has successfully obviated the need for non-maximum suppression (NMS) as postprocessing and makes the pipeline end-to-end. However, it triggers a new dilemma as the widely used sparse queries cannot guarantee a high recall, while dense queries inevitably bring more similar queries and encounter optimization difficulties. As both sparse and dense queries are problematic, then what are the expected queries in end-to-end object detection? This paper shows that the solution should be Dense Distinct Queries (DDQ). Concretely, we first lay dense queries like traditional detectors and then select distinct ones for one-to-one assignments. DDQ blends the advantages of traditional and recent end-to-end detectors and significantly improves the performance of various detectors including FCN, R-CNN, and DETRs. Most impressively, DDQ-DETR achieves 52.1 AP on MS-COCO dataset within 12 epochs using a ResNet-50 backbone, outperforming all existing detectors in the same setting. DDQ also shares the benefit of end-to-end detectors in crowded scenes and achieves 93.8 AP on CrowdHuman. We hope DDQ can inspire researchers to consider the complementarity between traditional methods and end-to-end detectors.

![ddq_arch](https://github.com/open-mmlab/mmdetection/assets/33146359/5ca9f11b-b6f3-454f-a2d1-3009ee337bbc)
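
To make the selection step concrete, here is a minimal sketch (ours, not the mmdetection implementation) of picking distinct queries from dense candidates with class-agnostic NMS; it mirrors the `dqs_cfg=dict(type='nms', iou_threshold=0.8)` setting used in the configs below. The function name and signature are illustrative only.

```python
# Minimal sketch of DDQ's distinct query selection (DQS), not the
# actual mmdetection code: dense candidate boxes are pruned with
# class-agnostic NMS so that only distinct queries enter the
# one-to-one (Hungarian) assignment.
import torch
from torchvision.ops import nms


def select_distinct_queries(boxes: torch.Tensor,
                            scores: torch.Tensor,
                            iou_threshold: float = 0.8,
                            num_queries: int = 900) -> torch.Tensor:
    """Return indices of distinct queries among dense candidates.

    boxes:  (N, 4) candidate boxes in (x1, y1, x2, y2) format.
    scores: (N,)   class-agnostic confidence per candidate.
    """
    # Class-agnostic NMS removes near-duplicate queries; the IoU
    # threshold (0.8 in the configs) controls how "distinct" the
    # surviving queries must be.
    keep = nms(boxes, scores, iou_threshold)
    # Keep at most `num_queries` queries, highest score first
    # (nms already returns indices sorted by decreasing score).
    return keep[:num_queries]
```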

## Results and Models

| Model | Backbone | Lr schd | Augmentation | box AP(val) | Config | Download |
| :-------------: | :------: | :-----: | :----------: | :---------: | :------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| DDQ DETR-4scale | R-50 | 12e | DETR | 51.4 | [config](./ddq-detr-4scale_r50_8xb2-12e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq-detr-4scale_r50_8xb2-12e_coco/ddq-detr-4scale_r50_8xb2-12e_coco_20230809_170711-42528127.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq-detr-4scale_r50_8xb2-12e_coco/ddq-detr-4scale_r50_8xb2-12e_coco_20230809_170711.log.json) |
| DDQ DETR-5scale | R-50 | 12e | DETR | 52.1 | [config](./ddq-detr-5scale_r50_8xb2-12e_coco.py) | [model\*](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_5scale_coco_1x.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_5scale_coco_1x_20230319_103307.log) |
| DDQ DETR-4scale | Swin-L | 30e | DETR | 58.7 | [config](./ddq-detr-4scale_swinl_8xb2-30e_coco.py) | [model\*](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_swinl_30e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/ddq/ddq_detr_swinl_30e_20230316_221721_20230318_143554.log) |

**Note:** Models marked "\*" were not trained by us; they are taken from the [DDQ official repository](https://github.com/jshilong/DDQ).
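
For a quick try of the released models, inference through mmdetection's high-level `DetInferencer` API (available in mmdet v3.x) might look like the sketch below; the local file paths are placeholders to adjust for your setup.

```python
# Hedged example: running inference with the released DDQ DETR-4scale
# checkpoint via mmdetection's high-level API (mmdet >= 3.0).
# Paths below are placeholders; adjust them to your local setup.
from mmdet.apis import DetInferencer

inferencer = DetInferencer(
    model='configs/ddq/ddq-detr-4scale_r50_8xb2-12e_coco.py',
    weights='ddq-detr-4scale_r50_8xb2-12e_coco_20230809_170711-42528127.pth',
    device='cuda:0')
# Visualized predictions are written to `outputs/`.
inferencer('demo/demo.jpg', out_dir='outputs/')
```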

## Citation

We provide the config files for DDQ: [Dense Distinct Query for End-to-End Object Detection](https://arxiv.org/abs/2303.12776).

```latex
@InProceedings{Zhang_2023_CVPR,
author = {Zhang, Shilong and Wang, Xinjiang and Wang, Jiaqi and Pang, Jiangmiao and Lyu, Chengqi and Zhang, Wenwei and Luo, Ping and Chen, Kai},
title = {Dense Distinct Query for End-to-End Object Detection},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2023},
pages = {7329-7338}
}
```
170 changes: 170 additions & 0 deletions configs/ddq/ddq-detr-4scale_r50_8xb2-12e_coco.py
@@ -0,0 +1,170 @@
_base_ = [
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
model = dict(
type='DDQDETR',
num_queries=900, # num_matching_queries
# ratio of num_dense queries to num_queries
dense_topk_ratio=1.5,
with_box_refine=True,
as_two_stage=True,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
# encoder class name: DeformableDetrTransformerEncoder
encoder=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_levels=4,
dropout=0.0), # 0.1 for DeformDETR
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048, # 1024 for DeformDETR
ffn_drop=0.0))), # 0.1 for DeformDETR
# decoder class name: DDQTransformerDecoder
decoder=dict(
# `num_layers` >= 2, because attention masks of the last
# `num_layers` - 1 layers are used for distinct query selection
num_layers=6,
return_intermediate=True,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_heads=8,
dropout=0.0), # 0.1 for DeformDETR
cross_attn_cfg=dict(embed_dims=256, num_levels=4,
dropout=0.0), # 0.1 for DeformDETR
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048, # 1024 for DeformDETR
ffn_drop=0.0)), # 0.1 for DeformDETR
post_norm_cfg=None),
positional_encoding=dict(
num_feats=128,
normalize=True,
offset=0.0, # -0.5 for DeformDETR
temperature=20), # 10000 for DeformDETR
bbox_head=dict(
type='DDQDETRHead',
num_classes=80,
sync_cls_avg_factor=True,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
dn_cfg=dict(
label_noise_scale=0.5,
box_noise_scale=1.0,
group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
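    # Distinct Queries Selection (DQS): class-agnostic NMS applied to
    # the dense queries, so only distinct queries are kept for
    # one-to-one assignment.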
dqs_cfg=dict(type='nms', iou_threshold=0.8),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='FocalLossCost', weight=2.0),
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
])),
test_cfg=dict(max_per_img=300))

train_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
                # The aspect ratio of all images in the train dataset
                # is < 7, following the original implementation
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(type='PackDetInputs')
]

train_dataloader = dict(
dataset=dict(
filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.05),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)}))

# learning policy
max_epochs = 12
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=False,
begin=0,
end=2000),
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[11],
gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
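
For reference, training with the config above can also be launched programmatically through mmengine's `Runner`; the usual entry point is `tools/train.py`, so the snippet below is only a minimal single-process sketch, with the `work_dir` path as a placeholder.

```python
# Minimal single-process training sketch using mmengine's Runner
# (the usual route is `tools/train.py`; this is an illustration only).
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/ddq/ddq-detr-4scale_r50_8xb2-12e_coco.py')
cfg.work_dir = './work_dirs/ddq-detr-4scale_r50'  # where logs/checkpoints go
runner = Runner.from_cfg(cfg)
runner.train()
```
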
177 changes: 177 additions & 0 deletions configs/ddq/ddq-detr-4scale_swinl_8xb2-30e_coco.py
@@ -0,0 +1,177 @@
_base_ = [
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth' # noqa: E501
model = dict(
type='DDQDETR',
num_queries=900, # num_matching_queries
# ratio of num_dense queries to num_queries
dense_topk_ratio=1.5,
with_box_refine=True,
as_two_stage=True,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1),
backbone=dict(
type='SwinTransformer',
pretrain_img_size=384,
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(1, 2, 3),
with_cp=False,
convert_weights=True,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
neck=dict(
type='ChannelMapper',
in_channels=[384, 768, 1536],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
# encoder class name: DeformableDetrTransformerEncoder
encoder=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_levels=4,
dropout=0.0), # 0.1 for DeformDETR
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048, # 1024 for DeformDETR
ffn_drop=0.0))), # 0.1 for DeformDETR
# decoder class name: DDQTransformerDecoder
decoder=dict(
num_layers=6,
return_intermediate=True,
layer_cfg=dict(
self_attn_cfg=dict(embed_dims=256, num_heads=8,
dropout=0.0), # 0.1 for DeformDETR
cross_attn_cfg=dict(embed_dims=256, num_levels=4,
dropout=0.0), # 0.1 for DeformDETR
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048, # 1024 for DeformDETR
ffn_drop=0.0)), # 0.1 for DeformDETR
post_norm_cfg=None),
positional_encoding=dict(
num_feats=128,
normalize=True,
offset=0.0, # -0.5 for DeformDETR
temperature=20), # 10000 for DeformDETR
bbox_head=dict(
type='DDQDETRHead',
num_classes=80,
sync_cls_avg_factor=True,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
dn_cfg=dict(
label_noise_scale=0.5,
box_noise_scale=1.0,
group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
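    # Distinct Queries Selection (DQS): class-agnostic NMS applied to
    # the dense queries, so only distinct queries are kept for
    # one-to-one assignment.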
dqs_cfg=dict(type='nms', iou_threshold=0.8),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='FocalLossCost', weight=2.0),
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
])),
test_cfg=dict(max_per_img=300))

train_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
                # The aspect ratio of all images in the train dataset
                # is < 7, following the original implementation
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(type='PackDetInputs')
]

train_dataloader = dict(
dataset=dict(
filter_cfg=dict(filter_empty_gt=False), pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.05),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.05)}))

# learning policy
max_epochs = 30
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=False,
begin=0,
end=2000),
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[20, 26],
gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)
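
For context, `auto_scale_lr` applies the linear scaling rule when enabled: the optimizer learning rate is scaled by the ratio of the actual total batch size to `base_batch_size`. A tiny illustration:

```python
# Linear scaling rule used by `auto_scale_lr` (when enabled): the
# optimizer LR is scaled by actual_batch_size / base_batch_size.
base_lr = 0.0002          # from optim_wrapper above
base_batch_size = 16      # 8 GPUs x 2 samples per GPU
actual_batch_size = 32    # e.g. 16 GPUs x 2 samples per GPU

scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)  # 0.0004
```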