Skip to content

Commit

Permalink
Merge b163af0 into c6cd195
Browse files Browse the repository at this point in the history
  • Loading branch information
xin-li-67 authored Apr 21, 2023
2 parents c6cd195 + b163af0 commit 9ef24e7
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 0 deletions.
17 changes: 17 additions & 0 deletions configs/face_2d_keypoint/topdown_regression/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Top-down regression-based pose estimation

Top-down methods divide the task into two stages: object detection, followed by single-object pose estimation given object bounding boxes. At the 2nd stage, regression based methods directly regress the keypoint coordinates given the features extracted from the bounding box area, following the paradigm introduced in [Deeppose: Human pose estimation via deep neural networks](http://openaccess.thecvf.com/content_cvpr_2014/html/Toshev_DeepPose_Human_Pose_2014_CVPR_paper.html).

<div align=center>
<img src="https://user-images.githubusercontent.com/15977946/146515040-a82a8a29-d6bc-42f1-a2ab-7dfa610ce363.png">
</div>

## Results and Models

### WFLW Dataset

Result on WFLW test set

| Model | Input Size | NME<sub>*test*</sub> | NME<sub>*pose*</sub> | NME<sub>*illumination*</sub> | NME<sub>*occlusion*</sub> | NME<sub>*blur*</sub> | NME<sub>*makeup*</sub> | NME<sub>*expression*</sub> | ckpt | log |
| :--------- | :--------: | :------------------: | :------------------: | :--------------------------: | :-----------------------: | :------------------: | :--------------------: | :------------------------: | :--------: | :-------: |
| [ResNet-50](/configs/face_2d_keypoint/topdown_regression/wflw/td-reg_res50_8x64e-210e_wflw-256x256.py) | 256x256 | 4.85 | 8.50 | 4.81 | 5.69 | 5.45 | 4.82 | 5.20 | [ckpt](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256-92d0ba7f_20210303.pth) | [log](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_20210303.log.json) |
58 changes: 58 additions & 0 deletions configs/face_2d_keypoint/topdown_regression/wflw/resnet_wflw.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_cvpr_2014/html/Toshev_DeepPose_Human_Pose_2014_CVPR_paper.html">DeepPose (CVPR'2014)</a></summary>

```bibtex
@inproceedings{toshev2014deeppose,
title={Deeppose: Human pose estimation via deep neural networks},
author={Toshev, Alexander and Szegedy, Christian},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={1653--1660},
year={2014}
}
```

</details>

<!-- [BACKBONE] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html">ResNet (CVPR'2016)</a></summary>

```bibtex
@inproceedings{he2016deep,
title={Deep residual learning for image recognition},
author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={770--778},
year={2016}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_cvpr_2018/html/Wu_Look_at_Boundary_CVPR_2018_paper.html">WFLW (CVPR'2018)</a></summary>

```bibtex
@inproceedings{wu2018look,
title={Look at boundary: A boundary-aware face alignment algorithm},
author={Wu, Wayne and Qian, Chen and Yang, Shuo and Wang, Quan and Cai, Yici and Zhou, Qiang},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={2129--2138},
year={2018}
}
```

</details>

Results on WFLW dataset

The model is trained on WFLW train.

| Arch | Input Size | NME<sub>*test*</sub> | NME<sub>*pose*</sub> | NME<sub>*illumination*</sub> | NME<sub>*occlusion*</sub> | NME<sub>*blur*</sub> | NME<sub>*makeup*</sub> | NME<sub>*expression*</sub> | ckpt | log |
| :--------- | :--------: | :------------------: | :------------------: | :--------------------------: | :-----------------------: | :------------------: | :--------------------: | :------------------------: | :--------: | :-------: |
| [deeppose_res50](/configs/face_2d_keypoint/topdown_regression/wflw/td-reg_res50_8x64e-210e_wflw-256x256.py) | 256x256 | 4.85 | 8.50 | 4.81 | 5.69 | 5.45 | 4.82 | 5.20 | [ckpt](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256-92d0ba7f_20210303.pth) | [log](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_20210303.log.json) |
21 changes: 21 additions & 0 deletions configs/face_2d_keypoint/topdown_regression/wflw/resnet_wflw.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
Models:
- Config: configs/face_2d_keypoint/topdown_regression/wflw/td-reg_res50_8x64e-210e_wflw-256x256.py
In Collection: ResNet
Metadata:
Architecture:
- DeepPose
- ResNet
Training Data: WFLW
Name: td-reg_res50_8x64e-210e_wflw-256x256
Results:
- Dataset: WFLW
Metrics:
NME blur: 5.45
NME expression: 5.2
NME illumination: 4.81
NME makeup: 4.82
NME occlusion: 5.69
NME pose: 8.5
NME test: 4.85
Task: Face 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256-92d0ba7f_20210303.pth
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=1)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=5e-4,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=210,
milestones=[170, 200],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings
codec = dict(type='RegressionLabel', input_size=(256, 256))

# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='RegressionHead',
in_channels=2048,
num_joints=98,
loss=dict(type='SmoothL1Loss', use_target_weight=True),
decoder=codec),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
shift_coords=True,
))

# base dataset settings
dataset_type = 'WFLWDataset'
data_mode = 'topdown'
data_root = 'data/wflw/'

# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(
type='RandomBBoxTransform',
scale_factor=[0.75, 1.25],
rotate_factor=60),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]

# dataloaders
train_dataloader = dict(
batch_size=64,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/face_landmarks_wflw_train.json',
data_prefix=dict(img='images/'),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/face_landmarks_wflw_test.json',
data_prefix=dict(img='images/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# hooks
default_hooks = dict(checkpoint=dict(save_best='NME', rule='greater'))

# evaluators
val_evaluator = dict(
type='NME',
norm_mode='keypoint_distance',
)
test_evaluator = val_evaluator

0 comments on commit 9ef24e7

Please sign in to comment.