Skip to content

Commit

Permalink
support edpose
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuYi-Up committed Sep 13, 2023
1 parent 85f0e78 commit 537f084
Show file tree
Hide file tree
Showing 21 changed files with 3,668 additions and 9 deletions.
59 changes: 59 additions & 0 deletions configs/body_2d_keypoint/edpose/coco/edpose_coco.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/pdf/2302.01593.pdf">ED-Pose (ICLR'2023)</a></summary>

```bibtex
@inproceedings{
yang2023explicit,
title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation},
author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang},
booktitle={International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=s4WVupnJjmX}
}
```

</details>

<!-- [BACKBONE] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html">ResNet (CVPR'2016)</a></summary>

```bibtex
@inproceedings{he2016deep,
title={Deep residual learning for image recognition},
author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={770--778},
year={2016}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48">COCO (ECCV'2014)</a></summary>

```bibtex
@inproceedings{lin2014microsoft,
title={Microsoft coco: Common objects in context},
author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
booktitle={European conference on computer vision},
pages={740--755},
year={2014},
organization={Springer}
}
```

</details>

Results on COCO val2017

| Arch | BackBone | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :-------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :--------------------------------------------: | :-------------------------------------------: |
| [edpose_res50_coco](/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py) | ResNet-50 | 0.716 | 0.897 | 0.783 | 0.793 | 0.943 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.json) |
| | | | | | | | | |
25 changes: 25 additions & 0 deletions configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Collections:
- Name: ED-Pose
Paper:
Title: Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation
URL: https://arxiv.org/pdf/2302.01593.pdf
README: https://github.com/open-mmlab/mmpose/blob/main/docs/src/papers/algorithms/edpose.md
Models:
- Config: configs/body_2d_keypoint/edpose/coco/edpose_resnet50_coco.py
In Collection: ED-Pose
Metadata:
Architecture: &id001
- ED-Pose
- ResNet
Training Data: COCO
Name: edpose_resnet50_coco
Results:
- Dataset: COCO
Metrics:
AP: 0.716
[email protected]: 0.897
[email protected]: 0.783
AR: 0.793
[email protected]: 0.943
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth
217 changes: 217 additions & 0 deletions configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=140, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=1e-3,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=140,
milestones=[90, 120],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=80)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
type='EDPoseLabel', num_select=50, num_body_points=17, not_to_xyxy=False)

# model settings
model = dict(
type='BottomupPoseEstimator',
data_preprocessor=dict(
type='BatchShapeDataPreprocessor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
bgr_to_rgb=True,
pad_size_divisor=1,
normalize_bakend='pillow'),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='FrozenBatchNorm2d', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
head=dict(
type='EDPoseHead',
num_queries=900,
num_feature_levels=4,
num_body_points=17,
as_two_stage=True,
encoder=dict(
num_layers=6,
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.0))),
decoder=dict(
num_layers=6,
embed_dims=256,
layer_cfg=dict( # DeformableDetrTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
batch_first=True),
cross_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256,
batch_first=True),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=2048, ffn_drop=0.1)),
query_dim=4,
num_feature_levels=4,
num_group=100,
num_dn=100,
num_box_decoder_layers=2,
return_intermediate=True),
out_head=dict(num_classes=2),
positional_encoding=dict(
num_pos_feats=128,
temperatureH=20,
temperatureW=20,
normalize=True),
denosing_cfg=dict(
dn_box_noise_scale=0.4,
dn_label_noise_ratio=0.5,
dn_labelbook_size=100,
dn_attn_mask_type_list=['match2dn', 'dn2dn', 'group2group']),
data_decoder=codec),
test_cfg=dict(Pmultiscale_test=False, flip_test=False, num_select=50),
train_cfg=dict())

# enable DDP training when rescore net is used
find_unused_parameters = True

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'bottomup'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='RandomFlip', direction='horizontal'),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='BottomupRandomChoiceResize',
# The radio of all image in train dataset < 7
# follow the original implement
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='BottomupRandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='BottomupRandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(type='PackPoseInputs'),
]

val_pipeline = [
dict(type='LoadImage', imdecode_backend='pillow'),
dict(
type='BottomupRandomChoiceResize',
scales=[(800, 1333)],
keep_ratio=True,
backend='pillow'),
dict(
type='PackPoseInputs',
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
'img_shape', 'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
'skeleton_links'))
]

# data loaders
train_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=4,
num_workers=8,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_val2017.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json',
nms_mode='none',
score_mode='keypoint',
)
test_evaluator = val_evaluator
31 changes: 31 additions & 0 deletions docs/src/papers/algorithms/edpose.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/pdf/2302.01593.pdf">ED-Pose (ICLR'2023)</a></summary>

```bibtex
@inproceedings{
yang2023explicit,
title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation},
author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang},
booktitle={International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=s4WVupnJjmX}
}
```

</details>

## Abstract

<!-- [ABSTRACT] -->

This paper presents a novel end-to-end framework with Explicit box Detection for multi-person Pose estimation, called ED-Pose, where it unifies the contextual learning between human-level (global) and keypoint-level (local) information. Different from previous one-stage methods, ED-Pose re-considers this task as two explicit box detection processes with a unified representation and regression supervision. First, we introduce a human detection decoder from encoded tokens to extract global features. It can provide a good initialization for the latter keypoint detection, making the training process converge fast. Second, to bring in contextual information near keypoints, we regard pose estimation as a keypoint box detection problem to learn both box positions and contents for each keypoint. A human-to-keypoint detection decoder adopts an interactive learning strategy between human and keypoint features to further enhance global and local feature aggregation. In general, ED-Pose is conceptually simple without post-processing and dense heatmap supervision. It demonstrates its effectiveness and efficiency compared with both two-stage and one-stage methods. Notably, explicit box detection boosts the pose estimation performance by 4.5 AP on COCO and 9.9 AP on CrowdPose. For the first time, as a fully end-to-end framework with a L1 regression loss, ED-Pose surpasses heatmap-based Top-down methods under the same backbone by 1.2 AP on COCO and achieves the state-of-the-art with 76.6 AP on CrowdPose without bells and whistles. Code is available at https://github.com/IDEA-Research/ED-Pose.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/IDEA-Research/ED-Pose/raw/master/figs/edpose_git.jpg">
</div>
3 changes: 2 additions & 1 deletion mmpose/codecs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .associative_embedding import AssociativeEmbedding
from .decoupled_heatmap import DecoupledHeatmap
from .edpose_label import EDPoseLabel
from .image_pose_lifting import ImagePoseLifting
from .integral_regression_label import IntegralRegressionLabel
from .megvii_heatmap import MegviiHeatmap
Expand All @@ -16,5 +17,5 @@
'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel',
'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR',
'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting',
'MotionBERTLabel'
'MotionBERTLabel', 'EDPoseLabel'
]
Loading

0 comments on commit 537f084

Please sign in to comment.