Skip to content

Commit

Permalink
Merge pull request #2299 from open-mmlab/dev-1.x
Browse files Browse the repository at this point in the history
Dev 1.x
  • Loading branch information
Tau-J authored Apr 24, 2023
2 parents 5fe1a0f + aa89a7c commit 9dda210
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 9 deletions.
35 changes: 35 additions & 0 deletions .github/ISSUE_TEMPLATE/3-documentation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: 📚 Documentation
description: Report an issue related to the documentation.
labels: "docs"
title: "[Docs] "

body:
- type: markdown
attributes:
value: |
## Note
For general usage questions or idea discussions, please post it to our [**Forum**](https://github.com/open-mmlab/mmpose/discussions)
Please fill in as **much** of the following form as you're able to. **The clearer the description, the shorter it will take to solve it.**
- type: textarea
attributes:
label: 📚 The doc issue
description: >
A clear and concise description the issue.
validations:
required: true

- type: textarea
attributes:
label: Suggest a potential alternative/fix
description: >
Tell us how we could improve the documentation in this regard.
- type: markdown
attributes:
value: |
## Acknowledgement
Thanks for taking the time to fill out this report.
If you have already identified the reason, we strongly appreciate you creating a new PR to fix it [**here**](https://github.com/open-mmlab/mmpose/pulls)!
Please refer to [**Contribution Guide**](https://mmpose.readthedocs.io/en/latest/contribution_guide.html) for contributing.
9 changes: 0 additions & 9 deletions .github/ISSUE_TEMPLATE/3-general_questions.yml

This file was deleted.

1 change: 1 addition & 0 deletions configs/face_2d_keypoint/topdown_regression/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ Result on WFLW test set
| Model | Input Size | NME | ckpt | log |
| :-------------------------------------------------------------- | :--------: | :--: | :------------------------------------------------------------: | :-----------------------------------------------------------: |
| [ResNet-50](/configs/face_2d_keypoint/topdown_regression/wflw/td-reg_res50_8xb64-210e_wflw-256x256.py) | 256x256 | 4.88 | [ckpt](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256-92d0ba7f_20210303.pth) | [log](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_20210303.log.json) |
| [ResNet-50+SoftWingLoss](/configs/face_2d_keypoint/topdown_regression/wflw/td-reg_res50_softwingloss_8xb64-210e_wflw-256x256.py) | 256x256 | 4.67 | [ckpt](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_softwingloss-4d34f22a_20211212.pth) | [log](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_softwingloss_20211212.log.json) |
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_cvpr_2014/html/Toshev_DeepPose_Human_Pose_2014_CVPR_paper.html">DeepPose (CVPR'2014)</a></summary>

```bibtex
@inproceedings{toshev2014deeppose,
title={Deeppose: Human pose estimation via deep neural networks},
author={Toshev, Alexander and Szegedy, Christian},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={1653--1660},
year={2014}
}
```

</details>

<!-- [BACKBONE] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html">ResNet (CVPR'2016)</a></summary>

```bibtex
@inproceedings{he2016deep,
title={Deep residual learning for image recognition},
author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={770--778},
year={2016}
}
```

</details>

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://ieeexplore.ieee.org/document/9442331/">SoftWingloss (TIP'2021)</a></summary>

```bibtex
@article{lin2021structure,
title={Structure-Coherent Deep Feature Learning for Robust Face Alignment},
author={Lin, Chunze and Zhu, Beier and Wang, Quan and Liao, Renjie and Qian, Chen and Lu, Jiwen and Zhou, Jie},
journal={IEEE Transactions on Image Processing},
year={2021},
publisher={IEEE}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_cvpr_2018/html/Wu_Look_at_Boundary_CVPR_2018_paper.html">WFLW (CVPR'2018)</a></summary>

```bibtex
@inproceedings{wu2018look,
title={Look at boundary: A boundary-aware face alignment algorithm},
author={Wu, Wayne and Qian, Chen and Yang, Shuo and Wang, Quan and Cai, Yici and Zhou, Qiang},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={2129--2138},
year={2018}
}
```

</details>

Results on WFLW dataset

The model is trained on WFLW train set.

| Model | Input Size | NME | ckpt | log |
| :-------------------------------------------------------------- | :--------: | :--: | :------------------------------------------------------------: | :-----------------------------------------------------------: |
| [ResNet-50+SoftWingLoss](/configs/face_2d_keypoint/topdown_regression/wflw/td-reg_res50_softwingloss_8xb64-210e_wflw-256x256.py) | 256x256 | 4.44 | [ckpt](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_softwingloss-4d34f22a_20211212.pth) | [log](https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_softwingloss_20211212.log.json) |
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Models:
- Config: configs/face_2d_keypoint/topdown_regression/wflw/td-reg_res50_softwingloss_8xb64-210e_wflw-256x256.py
In Collection: ResNet
Metadata:
Architecture:
- DeepPose
- ResNet
- SoftWingloss
Training Data: WFLW
Name: td-reg_res50_softwingloss_8xb64-210e_wflw-256x256
Results:
- Dataset: WFLW
Metrics:
NME: 4.44
Task: Face 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/face/deeppose/deeppose_res50_wflw_256x256_softwingloss-4d34f22a_20211212.pth
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=5e-4,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=210,
milestones=[170, 200],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings
codec = dict(type='RegressionLabel', input_size=(256, 256))

# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='RegressionHead',
in_channels=2048,
num_joints=98,
loss=dict(type='SoftWingLoss', use_target_weight=True),
decoder=codec),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
shift_coords=True,
))

# base dataset settings
dataset_type = 'WFLWDataset'
data_mode = 'topdown'
data_root = 'data/wflw/'

# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(
type='RandomBBoxTransform',
scale_factor=[0.75, 1.25],
rotate_factor=60),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]

# dataloaders
train_dataloader = dict(
batch_size=64,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/face_landmarks_wflw_train.json',
data_prefix=dict(img='images/'),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/face_landmarks_wflw_test.json',
data_prefix=dict(img='images/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# hooks
default_hooks = dict(checkpoint=dict(save_best='NME', rule='less'))

# evaluators
val_evaluator = dict(
type='NME',
norm_mode='keypoint_distance',
)
test_evaluator = val_evaluator

0 comments on commit 9dda210

Please sign in to comment.