Skip to content

Commit

Permalink
[Feature] Support DyHead (#6823)
Browse files Browse the repository at this point in the history
* add DyHead

* move and update DYReLU

* update

* replace stack with sum to reduce memory

* clean and update

* update to align inference accuracy (incomplete)

* fix pad

* update to align training accuracy and pick #6867

* add README and metafile

* update docs

* resolve comments

* revert picking 6867

* update README.md

* update metafile.yml

* resolve comments and update urls
  • Loading branch information
shinya7y authored Feb 22, 2022
1 parent d556a86 commit f3b1647
Show file tree
Hide file tree
Showing 14 changed files with 604 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
<li><a href="configs/carafe">CARAFE (ICCV'2019)</a></li>
<li><a href="configs/fpg">FPG (ArXiv'2020)</a></li>
<li><a href="configs/groie">GRoIE (ICPR'2020)</a></li>
<li><a href="configs/dyhead">DyHead (CVPR'2021)</a></li>
</ul>
</td>
<td>
Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
<li><a href="configs/carafe">CARAFE (ICCV'2019)</a></li>
<li><a href="configs/fpg">FPG (ArXiv'2020)</a></li>
<li><a href="configs/groie">GRoIE (ICPR'2020)</a></li>
<li><a href="configs/dyhead">DyHead (CVPR'2021)</a></li>
</ul>
</td>
<td>
Expand Down
46 changes: 46 additions & 0 deletions configs/dyhead/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# DyHead

> [Dynamic Head: Unifying Object Detection Heads with Attentions](https://arxiv.org/abs/2106.08322)
<!-- [ALGORITHM] -->

## Abstract

The complex nature of combining localization and classification in object detection has resulted in the flourished development of methods. Previous works tried to improve the performance in various object detection heads but failed to present a unified view. In this paper, we present a novel dynamic head framework to unify object detection heads with attentions. By coherently combining multiple self-attention mechanisms between feature levels for scale-awareness, among spatial locations for spatial-awareness, and within output channels for task-awareness, the proposed approach significantly improves the representation ability of object detection heads without any computational overhead. Further experiments demonstrate that the effectiveness and efficiency of the proposed dynamic head on the COCO benchmark. With a standard ResNeXt-101-DCN backbone, we largely improve the performance over popular object detectors and achieve a new state-of-the-art at 54.0 AP. Furthermore, with latest transformer backbone and extra data, we can push current best COCO result to a new record at 60.6 AP.

<div align=center>
<img src="https://user-images.githubusercontent.com/42844407/149169448-fcafb6d0-b866-41cc-9422-94de9f1e1761.png" height="300"/>
</div>

## Results and Models

| Method | Backbone | Style | Setting | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
|:------:|:--------:|:-------:|:------------:|:-------:|:--------:|:--------------:|:------:|:------:|:--------:|
| ATSS | R-50 | caffe | reproduction | 1x | 5.4 | 13.2 | 42.5 | [config](./atss_r50_caffe_fpn_dyhead_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939-162888e6.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939.log.json) |
| ATSS | R-50 | pytorch | simple | 1x | 4.9 | 13.7 | 43.3 | [config](./atss_r50_fpn_dyhead_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314.log.json) |

- We trained the above models with 4 GPUs and 4 `samples_per_gpu`.
- The `reproduction` setting aims to reproduce the official implementation based on Detectron2.
- The `simple` setting serves as a minimum example to use DyHead in MMDetection. Specifically,
- it adds `DyHead` to `neck` after `FPN`
- it sets `stacked_convs=0` to `bbox_head`
- The `simple` setting achieves higher AP than the original implementation.
We have not conduct ablation study between the two settings.
`dict(type='Pad', size_divisor=128)` may further improve AP by prefer spatial alignment across pyramid levels, although large padding reduces efficiency.

## Relation to Other Methods

- DyHead can be regarded as an improved [SEPC](https://arxiv.org/abs/2005.03101) with [DyReLU modules](https://arxiv.org/abs/2003.10027) and simplified [SE blocks](https://arxiv.org/abs/1709.01507).
- Xiyang Dai et al., the author team of DyHead, adopt it for [Dynamic DETR](https://openaccess.thecvf.com/content/ICCV2021/html/Dai_Dynamic_DETR_End-to-End_Object_Detection_With_Dynamic_Attention_ICCV_2021_paper.html).
The description of Dynamic Encoder in Sec. 3.2 will help you understand DyHead.

## Citation

```latex
@inproceedings{DyHead_CVPR2021,
author = {Dai, Xiyang and Chen, Yinpeng and Xiao, Bin and Chen, Dongdong and Liu, Mengchen and Yuan, Lu and Zhang, Lei},
title = {Dynamic Head: Unifying Object Detection Heads With Attentions},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2021}
}
```
112 changes: 112 additions & 0 deletions configs/dyhead/atss_r50_caffe_fpn_dyhead_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='ATSS',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
neck=[
dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
dict(
type='DyHead',
in_channels=256,
out_channels=256,
num_blocks=6,
# disable zero_init_offset to follow official implementation
zero_init_offset=False)
],
bbox_head=dict(
type='ATSSHead',
num_classes=80,
in_channels=256,
pred_kernel_size=1, # follow DyHead official implementation
stacked_convs=0,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128],
center_offset=0.5), # follow DyHead official implementation
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

# use caffe img_norm, size_divisor=128, pillow resize
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=(1333, 800),
keep_ratio=True,
backend='pillow'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=128),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True, backend='pillow'),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=128),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
65 changes: 65 additions & 0 deletions configs/dyhead/atss_r50_fpn_dyhead_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='ATSS',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=[
dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
dict(type='DyHead', in_channels=256, out_channels=256, num_blocks=6)
],
bbox_head=dict(
type='ATSSHead',
num_classes=80,
in_channels=256,
stacked_convs=0,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
63 changes: 63 additions & 0 deletions configs/dyhead/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
Collections:
- Name: DyHead
Metadata:
Training Data: COCO
Training Techniques:
- SGD with Momentum
- Weight Decay
Training Resources: 4x T4 GPUs
Architecture:
- ATSS
- DyHead
- FPN
- ResNet
- Deformable Convolution
- Pyramid Convolution
Paper:
URL: https://arxiv.org/abs/2106.08322
Title: 'Dynamic Head: Unifying Object Detection Heads with Attentions'
README: configs/dyhead/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection/blob/v2.22.0/mmdet/models/necks/dyhead.py#L130
Version: v2.22.0

Models:
- Name: atss_r50_caffe_fpn_dyhead_1x_coco
In Collection: DyHead
Config: configs/dyhead/atss_r50_caffe_fpn_dyhead_1x_coco.py
Metadata:
Training Memory (GB): 5.4
inference time (ms/im):
- value: 75.7
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (800, 1333)
Epochs: 12
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 42.5
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939-162888e6.pth

- Name: atss_r50_fpn_dyhead_1x_coco
In Collection: DyHead
Config: configs/dyhead/atss_r50_fpn_dyhead_1x_coco.py
Metadata:
Training Memory (GB): 4.9
inference time (ms/im):
- value: 73.1
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (800, 1333)
Epochs: 12
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 43.3
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth
17 changes: 13 additions & 4 deletions mmdet/models/dense_heads/atss_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ATSSHead(AnchorHead):
def __init__(self,
num_classes,
in_channels,
pred_kernel_size=3,
stacked_convs=4,
conv_cfg=None,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
Expand All @@ -42,6 +43,7 @@ def __init__(self,
std=0.01,
bias_prob=0.01)),
**kwargs):
self.pred_kernel_size = pred_kernel_size
self.stacked_convs = stacked_convs
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
Expand Down Expand Up @@ -85,15 +87,22 @@ def _init_layers(self):
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
pred_pad_size = self.pred_kernel_size // 2
self.atss_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.pred_kernel_size,
padding=pred_pad_size)
self.atss_reg = nn.Conv2d(
self.feat_channels, self.num_base_priors * 4, 3, padding=1)
self.feat_channels,
self.num_base_priors * 4,
self.pred_kernel_size,
padding=pred_pad_size)
self.atss_centerness = nn.Conv2d(
self.feat_channels, self.num_base_priors * 1, 3, padding=1)
self.feat_channels,
self.num_base_priors * 1,
self.pred_kernel_size,
padding=pred_pad_size)
self.scales = nn.ModuleList(
[Scale(1.0) for _ in self.prior_generator.strides])

Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/necks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .channel_mapper import ChannelMapper
from .ct_resnet_neck import CTResNetNeck
from .dilated_encoder import DilatedEncoder
from .dyhead import DyHead
from .fpg import FPG
from .fpn import FPN
from .fpn_carafe import FPN_CARAFE
Expand All @@ -18,5 +19,5 @@
__all__ = [
'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder',
'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN'
'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead'
]
Loading

0 comments on commit f3b1647

Please sign in to comment.