-
Notifications
You must be signed in to change notification settings - Fork 9.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add DyHead * move and update DYReLU * update * replace stack with sum to reduce memory * clean and update * update to align inference accuracy (incomplete) * fix pad * update to align training accuracy and pick #6867 * add README and metafile * update docs * resolve comments * revert picking 6867 * update README.md * update metafile.yml * resolve comments and update urls
- Loading branch information
Showing
14 changed files
with
604 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# DyHead | ||
|
||
> [Dynamic Head: Unifying Object Detection Heads with Attentions](https://arxiv.org/abs/2106.08322) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
The complex nature of combining localization and classification in object detection has resulted in the flourished development of methods. Previous works tried to improve the performance in various object detection heads but failed to present a unified view. In this paper, we present a novel dynamic head framework to unify object detection heads with attentions. By coherently combining multiple self-attention mechanisms between feature levels for scale-awareness, among spatial locations for spatial-awareness, and within output channels for task-awareness, the proposed approach significantly improves the representation ability of object detection heads without any computational overhead. Further experiments demonstrate that the effectiveness and efficiency of the proposed dynamic head on the COCO benchmark. With a standard ResNeXt-101-DCN backbone, we largely improve the performance over popular object detectors and achieve a new state-of-the-art at 54.0 AP. Furthermore, with latest transformer backbone and extra data, we can push current best COCO result to a new record at 60.6 AP. | ||
|
||
<div align=center> | ||
<img src="https://user-images.githubusercontent.com/42844407/149169448-fcafb6d0-b866-41cc-9422-94de9f1e1761.png" height="300"/> | ||
</div> | ||
|
||
## Results and Models | ||
|
||
| Method | Backbone | Style | Setting | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download | | ||
|:------:|:--------:|:-------:|:------------:|:-------:|:--------:|:--------------:|:------:|:------:|:--------:| | ||
| ATSS | R-50 | caffe | reproduction | 1x | 5.4 | 13.2 | 42.5 | [config](./atss_r50_caffe_fpn_dyhead_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939-162888e6.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939.log.json) | | ||
| ATSS | R-50 | pytorch | simple | 1x | 4.9 | 13.7 | 43.3 | [config](./atss_r50_fpn_dyhead_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314.log.json) | | ||
|
||
- We trained the above models with 4 GPUs and 4 `samples_per_gpu`. | ||
- The `reproduction` setting aims to reproduce the official implementation based on Detectron2. | ||
- The `simple` setting serves as a minimum example to use DyHead in MMDetection. Specifically, | ||
- it adds `DyHead` to `neck` after `FPN` | ||
- it sets `stacked_convs=0` to `bbox_head` | ||
- The `simple` setting achieves higher AP than the original implementation. | ||
We have not conduct ablation study between the two settings. | ||
`dict(type='Pad', size_divisor=128)` may further improve AP by prefer spatial alignment across pyramid levels, although large padding reduces efficiency. | ||
|
||
## Relation to Other Methods | ||
|
||
- DyHead can be regarded as an improved [SEPC](https://arxiv.org/abs/2005.03101) with [DyReLU modules](https://arxiv.org/abs/2003.10027) and simplified [SE blocks](https://arxiv.org/abs/1709.01507). | ||
- Xiyang Dai et al., the author team of DyHead, adopt it for [Dynamic DETR](https://openaccess.thecvf.com/content/ICCV2021/html/Dai_Dynamic_DETR_End-to-End_Object_Detection_With_Dynamic_Attention_ICCV_2021_paper.html). | ||
The description of Dynamic Encoder in Sec. 3.2 will help you understand DyHead. | ||
|
||
## Citation | ||
|
||
```latex | ||
@inproceedings{DyHead_CVPR2021, | ||
author = {Dai, Xiyang and Chen, Yinpeng and Xiao, Bin and Chen, Dongdong and Liu, Mengchen and Yuan, Lu and Zhang, Lei}, | ||
title = {Dynamic Head: Unifying Object Detection Heads With Attentions}, | ||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, | ||
year = {2021} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
_base_ = [ | ||
'../_base_/datasets/coco_detection.py', | ||
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' | ||
] | ||
model = dict( | ||
type='ATSS', | ||
backbone=dict( | ||
type='ResNet', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(0, 1, 2, 3), | ||
frozen_stages=1, | ||
norm_cfg=dict(type='BN', requires_grad=False), | ||
norm_eval=True, | ||
style='caffe', | ||
init_cfg=dict( | ||
type='Pretrained', | ||
checkpoint='open-mmlab://detectron2/resnet50_caffe')), | ||
neck=[ | ||
dict( | ||
type='FPN', | ||
in_channels=[256, 512, 1024, 2048], | ||
out_channels=256, | ||
start_level=1, | ||
add_extra_convs='on_output', | ||
num_outs=5), | ||
dict( | ||
type='DyHead', | ||
in_channels=256, | ||
out_channels=256, | ||
num_blocks=6, | ||
# disable zero_init_offset to follow official implementation | ||
zero_init_offset=False) | ||
], | ||
bbox_head=dict( | ||
type='ATSSHead', | ||
num_classes=80, | ||
in_channels=256, | ||
pred_kernel_size=1, # follow DyHead official implementation | ||
stacked_convs=0, | ||
feat_channels=256, | ||
anchor_generator=dict( | ||
type='AnchorGenerator', | ||
ratios=[1.0], | ||
octave_base_scale=8, | ||
scales_per_octave=1, | ||
strides=[8, 16, 32, 64, 128], | ||
center_offset=0.5), # follow DyHead official implementation | ||
bbox_coder=dict( | ||
type='DeltaXYWHBBoxCoder', | ||
target_means=[.0, .0, .0, .0], | ||
target_stds=[0.1, 0.1, 0.2, 0.2]), | ||
loss_cls=dict( | ||
type='FocalLoss', | ||
use_sigmoid=True, | ||
gamma=2.0, | ||
alpha=0.25, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='GIoULoss', loss_weight=2.0), | ||
loss_centerness=dict( | ||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), | ||
# training and testing settings | ||
train_cfg=dict( | ||
assigner=dict(type='ATSSAssigner', topk=9), | ||
allowed_border=-1, | ||
pos_weight=-1, | ||
debug=False), | ||
test_cfg=dict( | ||
nms_pre=1000, | ||
min_bbox_size=0, | ||
score_thr=0.05, | ||
nms=dict(type='nms', iou_threshold=0.6), | ||
max_per_img=100)) | ||
# optimizer | ||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) | ||
|
||
# use caffe img_norm, size_divisor=128, pillow resize | ||
img_norm_cfg = dict( | ||
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict( | ||
type='Resize', | ||
img_scale=(1333, 800), | ||
keep_ratio=True, | ||
backend='pillow'), | ||
dict(type='RandomFlip', flip_ratio=0.5), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=128), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(1333, 800), | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True, backend='pillow'), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=128), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
data = dict( | ||
train=dict(pipeline=train_pipeline), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
_base_ = [ | ||
'../_base_/datasets/coco_detection.py', | ||
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' | ||
] | ||
model = dict( | ||
type='ATSS', | ||
backbone=dict( | ||
type='ResNet', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(0, 1, 2, 3), | ||
frozen_stages=1, | ||
norm_cfg=dict(type='BN', requires_grad=True), | ||
norm_eval=True, | ||
style='pytorch', | ||
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), | ||
neck=[ | ||
dict( | ||
type='FPN', | ||
in_channels=[256, 512, 1024, 2048], | ||
out_channels=256, | ||
start_level=1, | ||
add_extra_convs='on_output', | ||
num_outs=5), | ||
dict(type='DyHead', in_channels=256, out_channels=256, num_blocks=6) | ||
], | ||
bbox_head=dict( | ||
type='ATSSHead', | ||
num_classes=80, | ||
in_channels=256, | ||
stacked_convs=0, | ||
feat_channels=256, | ||
anchor_generator=dict( | ||
type='AnchorGenerator', | ||
ratios=[1.0], | ||
octave_base_scale=8, | ||
scales_per_octave=1, | ||
strides=[8, 16, 32, 64, 128]), | ||
bbox_coder=dict( | ||
type='DeltaXYWHBBoxCoder', | ||
target_means=[.0, .0, .0, .0], | ||
target_stds=[0.1, 0.1, 0.2, 0.2]), | ||
loss_cls=dict( | ||
type='FocalLoss', | ||
use_sigmoid=True, | ||
gamma=2.0, | ||
alpha=0.25, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='GIoULoss', loss_weight=2.0), | ||
loss_centerness=dict( | ||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), | ||
# training and testing settings | ||
train_cfg=dict( | ||
assigner=dict(type='ATSSAssigner', topk=9), | ||
allowed_border=-1, | ||
pos_weight=-1, | ||
debug=False), | ||
test_cfg=dict( | ||
nms_pre=1000, | ||
min_bbox_size=0, | ||
score_thr=0.05, | ||
nms=dict(type='nms', iou_threshold=0.6), | ||
max_per_img=100)) | ||
# optimizer | ||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
Collections: | ||
- Name: DyHead | ||
Metadata: | ||
Training Data: COCO | ||
Training Techniques: | ||
- SGD with Momentum | ||
- Weight Decay | ||
Training Resources: 4x T4 GPUs | ||
Architecture: | ||
- ATSS | ||
- DyHead | ||
- FPN | ||
- ResNet | ||
- Deformable Convolution | ||
- Pyramid Convolution | ||
Paper: | ||
URL: https://arxiv.org/abs/2106.08322 | ||
Title: 'Dynamic Head: Unifying Object Detection Heads with Attentions' | ||
README: configs/dyhead/README.md | ||
Code: | ||
URL: https://github.com/open-mmlab/mmdetection/blob/v2.22.0/mmdet/models/necks/dyhead.py#L130 | ||
Version: v2.22.0 | ||
|
||
Models: | ||
- Name: atss_r50_caffe_fpn_dyhead_1x_coco | ||
In Collection: DyHead | ||
Config: configs/dyhead/atss_r50_caffe_fpn_dyhead_1x_coco.py | ||
Metadata: | ||
Training Memory (GB): 5.4 | ||
inference time (ms/im): | ||
- value: 75.7 | ||
hardware: V100 | ||
backend: PyTorch | ||
batch size: 1 | ||
mode: FP32 | ||
resolution: (800, 1333) | ||
Epochs: 12 | ||
Results: | ||
- Task: Object Detection | ||
Dataset: COCO | ||
Metrics: | ||
box AP: 42.5 | ||
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939-162888e6.pth | ||
|
||
- Name: atss_r50_fpn_dyhead_1x_coco | ||
In Collection: DyHead | ||
Config: configs/dyhead/atss_r50_fpn_dyhead_1x_coco.py | ||
Metadata: | ||
Training Memory (GB): 4.9 | ||
inference time (ms/im): | ||
- value: 73.1 | ||
hardware: V100 | ||
backend: PyTorch | ||
batch size: 1 | ||
mode: FP32 | ||
resolution: (800, 1333) | ||
Epochs: 12 | ||
Results: | ||
- Task: Object Detection | ||
Dataset: COCO | ||
Metrics: | ||
box AP: 43.3 | ||
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.