[Feature] Support YOLO-Pose (#2020)
Ben-Louis authored Mar 22, 2023
1 parent 45486ea commit de58466
Showing 21 changed files with 1,765 additions and 13 deletions.
9 changes: 7 additions & 2 deletions demo/inferencer_demo.py

```diff
@@ -37,6 +37,11 @@ def parse_args():
         nargs='+',
         default=None,
         help='Category id for detection model.')
+    parser.add_argument(
+        '--scope',
+        type=str,
+        default='mmpose',
+        help='Scope where modules are defined.')
     parser.add_argument(
         '--device',
         type=str,
@@ -83,8 +88,8 @@ def parse_args():
     call_args = vars(parser.parse_args())
 
     init_kws = [
-        'pose2d', 'pose2d_weights', 'device', 'det_model', 'det_weights',
-        'det_cat_ids'
+        'pose2d', 'pose2d_weights', 'scope', 'device', 'det_model',
+        'det_weights', 'det_cat_ids'
     ]
     init_args = {}
     for init_kw in init_kws:
```

24 changes: 13 additions & 11 deletions projects/README.md

```diff
@@ -30,20 +30,22 @@ We also provide some documentation listed below to help you get started:
 
 ## Project List
 
-- [MMPose4AIGC](./mmpose4aigc)
+- **[:zap:RTMPose](./rtmpose)**: Real-Time Multi-Person Pose Estimation toolkit based on MMPose
 
-  This project will demonstrate how to use MMPose to generate skeleton images for pose guided AI image generation.
+  <div align="center">
+  <img src="https://user-images.githubusercontent.com/15977946/225229448-36ff568d-a723-4248-bb19-2df4044ff8e8.png" width=800 height=200 />
+  </div><br/>
 
-  <div align=center>
-  <img src="https://user-images.githubusercontent.com/13503330/222403836-c65ba905-4bdd-4a44-834c-ff8d5959649d.png" width=1000 height=200/>
-  </div>
+- **[:art:MMPose4AIGC](./mmpose4aigc)**: Guide AI image generation with MMPose
 
-- [RTMPose](./rtmpose)
+  <div align=center>
+  <img src="https://user-images.githubusercontent.com/13503330/222403836-c65ba905-4bdd-4a44-834c-ff8d5959649d.png" width="800"/>
+  </div><br/>
 
-  Real-Time Multi-Person Pose Estimation toolkit based on MMPose
+- **[:bulb:YOLOX-Pose](./yolox-pose)**: Enhancing YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity Loss
 
-  <div align="center">
-  <img width=1000 height=200 src="https://user-images.githubusercontent.com/15977946/225229448-36ff568d-a723-4248-bb19-2df4044ff8e8.png"/>
-  </div>
+  <div align=center>
+  <img src="https://user-images.githubusercontent.com/26127467/226655503-3cee746e-6e42-40be-82ae-6e7cae2a4c7e.jpg" width="800" style="width: 800px; height: 200px; object-fit: cover"/>
+  </div><br/>
 
-- **And we can't wait to see what you contribute next!**
+- **What's next? Join the rank of <span style="color:blue"> *MMPose contributors* </span> by creating a new project**!
```

124 changes: 124 additions & 0 deletions projects/yolox-pose/README.md
@@ -0,0 +1,124 @@
# YOLOX-Pose

This project implements a YOLOX-based human pose estimator, following the approach outlined in **YOLO-Pose: Enhancing YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity Loss** (CVPRW 2022). The estimator is lightweight and fast, making it well-suited for crowded scenes.

<img src="https://user-images.githubusercontent.com/26127467/226655503-3cee746e-6e42-40be-82ae-6e7cae2a4c7e.jpg" alt><br>

## Usage

### Prerequisites

- Python 3.7 or higher
- PyTorch 1.6 or higher
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.6.0 or higher
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4 or higher
- [MMDetection](https://github.com/open-mmlab/mmdetection) v3.0.0rc6 or higher
- [MMYOLO](https://github.com/open-mmlab/mmyolo) v0.5.0 or higher
- [MMPose](https://github.com/open-mmlab/mmpose) v1.0.0rc1 or higher

All the commands below rely on `PYTHONPATH` being configured correctly: it should include the project directory so that Python can locate the module files. From the `yolox-pose/` root directory, run the following line to add the current directory to `PYTHONPATH`:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
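
If you work from Python directly (e.g. in a notebook), a minimal sketch of the equivalent setup, with the project path below a placeholder you should adapt:

```python
# Equivalent setup from Python (placeholder path, adjust to your checkout):
# the configs' custom_imports expect the local `models` and `datasets`
# packages to be importable, which requires the project root on sys.path.
import sys

sys.path.insert(0, '/path/to/mmpose/projects/yolox-pose')
```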

### Inference

Users can apply YOLOX-Pose models to estimate human poses using the inferencer found in the MMPose core package. Use the command below:

```shell
python demo/inferencer_demo.py $INPUTS \
    --pose2d $CONFIG --pose2d-weights $CHECKPOINT --scope mmyolo \
    [--show] [--vis-out-dir $VIS_OUT_DIR] [--pred-out-dir $PRED_OUT_DIR]
```

For more information on using the inferencer, please see [this document](https://mmpose.readthedocs.io/en/1.x/user_guides/inference.html#out-of-the-box-inferencer).

Here is a CLI example:

```shell
python demo/inferencer_demo.py ../../tests/data/coco/000000000785.jpg \
    --pose2d configs/yolox-pose_s_8xb32-300e_coco.py \
    --pose2d-weights https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco-9f5e3924_20230321.pth \
    --scope mmyolo --vis-out-dir vis_results
```

This will create an output image `vis_results/000000000785.jpg`, which looks like this:

<img src="https://user-images.githubusercontent.com/26127467/226552585-19b91294-9751-4599-98e7-5dae071a1761.jpg" height="360px" alt><br>

### Training & Testing

#### Data Preparation

Prepare the COCO dataset following the [instructions](https://mmpose.readthedocs.io/en/1.x/dataset_zoo/2d_body_keypoint.html#coco).
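
As a quick sanity check that the data landed where the configs expect it, you can load the annotations directly; a sketch, with the path assumed to follow the linked guide:

```python
# Sanity-check sketch (path assumed per the linked preparation guide).
from pycocotools.coco import COCO

coco = COCO('data/coco/annotations/person_keypoints_val2017.json')
print(f'{len(coco.getImgIds())} images, '
      f'{len(coco.getAnnIds())} person annotations')
```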

#### Commands

**To train with multiple GPUs:**

```shell
bash tools/dist_train.sh $CONFIG 8 --amp
```

**To train with Slurm:**

```shell
bash tools/slurm_train.sh $PARTITION $JOBNAME $CONFIG $WORKDIR --amp
```

**To test with single GPU:**

```shell
python tools/test.py $CONFIG $CHECKPOINT
```

**To test with multiple GPUs:**

```shell
bash tools/dist_test.sh $CONFIG $CHECKPOINT 8
```

**To test with multiple GPUs via Slurm:**

```shell
bash tools/slurm_test.sh $PARTITION $JOBNAME $CONFIG $CHECKPOINT
```

### Results

Results on COCO val2017

| Model | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | Download |
| :-------------------------------------------------------------: | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :----------------------------------------------------------------------: |
| [YOLOX-tiny-Pose](./configs/yolox-pose_tiny_4xb64-300e_coco.py) | 640 | 0.477 | 0.756 | 0.506 | 0.547 | 0.802 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco-c47dd83b_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco_20230321.json) |
| [YOLOX-s-Pose](./configs/yolox-pose_s_8xb32-300e_coco.py) | 640 | 0.595 | 0.836 | 0.653 | 0.658 | 0.878 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco-9f5e3924_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco_20230321.json) |
| [YOLOX-m-Pose](./configs/yolox-pose_m_4xb64-300e_coco.py) | 640 | 0.659 | 0.870 | 0.729 | 0.713 | 0.903 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_m_4xb64-300e_coco-cbd11d30_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_m_4xb64-300e_coco_20230321.json) |
| [YOLOX-l-Pose](./configs/yolox-pose_l_4xb64-300e_coco.py) | 640 | 0.679 | 0.882 | 0.749 | 0.733 | 0.911 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_l_4xb64-300e_coco-122e4cf8_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_l_4xb64-300e_coco_20230321.json) |

We have only trained models at an input size of 640, as we could not reproduce the performance gain reported in the paper when increasing the input size from 640 to 960. We warmly welcome contributions that successfully reproduce the results from the paper!

## Citation

If this project benefits your work, please consider citing the original paper:

```bibtex
@inproceedings{maji2022yolo,
  title={YOLO-Pose: Enhancing YOLO for Multi Person Pose Estimation Using Object Keypoint Similarity Loss},
  author={Maji, Debapriya and Nagori, Soyeb and Mathew, Manu and Poddar, Deepak},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={2637--2646},
  year={2022}
}
```

Please also cite our work:

```bibtex
@misc{mmpose2020,
  title={OpenMMLab Pose Estimation Toolbox and Benchmark},
  author={MMPose Contributors},
  howpublished={\url{https://github.com/open-mmlab/mmpose}},
  year={2020}
}
```
1 change: 1 addition & 0 deletions projects/yolox-pose/configs/_base_/datasets
41 changes: 41 additions & 0 deletions projects/yolox-pose/configs/_base_/default_runtime.py
@@ -0,0 +1,41 @@
```python
default_scope = 'mmyolo'
custom_imports = dict(imports=['models', 'datasets'])

# hooks
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='mmpose.PoseVisualizationHook', enable=False),
)

# multi-processing backend
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

# visualizer
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='mmpose.PoseLocalVisualizer',
    vis_backends=vis_backends,
    name='visualizer')

# logger
log_processor = dict(
    type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
log_level = 'INFO'
load_from = None
resume = False

# file I/O backend
file_client_args = dict(backend='disk')

# training/validation/testing progress
train_cfg = dict()
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
```
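
One detail worth noting in this config: `default_scope = 'mmyolo'` determines which registry unprefixed type names such as `'IterTimerHook'` resolve in, while `mmpose.`-prefixed types such as the visualization hook reach across into MMPose's registry. A rough sketch of the mechanism, assuming MMEngine's `DefaultScope` helper behaves as documented:

```python
# Illustration (not part of this commit) of how scope resolution works.
from mmengine.registry import DefaultScope

# Within this scope, an unprefixed type such as 'IterTimerHook' is looked
# up in MMYOLO's registries first, while a prefixed type such as
# 'mmpose.PoseVisualizationHook' explicitly targets the mmpose registry.
with DefaultScope.overwrite_default_scope('mmyolo'):
    pass  # build hooks / models / datasets registered under mmyolo
```
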
18 changes: 18 additions & 0 deletions projects/yolox-pose/configs/yolox-pose_l_4xb64-300e_coco.py
@@ -0,0 +1,18 @@
```python
_base_ = ['./yolox-pose_s_8xb32-300e_coco.py']

# model settings
model = dict(
    init_cfg=dict(checkpoint='https://download.openmmlab.com/mmyolo/v0/yolox/'
                  'yolox_l_fast_8xb8-300e_coco/yolox_l_fast_8xb8-300e_'
                  'coco_20230213_160715-c731eb1c.pth'),
    backbone=dict(
        deepen_factor=1.0,
        widen_factor=1.0,
    ),
    neck=dict(
        deepen_factor=1.0,
        widen_factor=1.0,
    ),
    bbox_head=dict(head_module=dict(widen_factor=1.0)))

train_dataloader = dict(batch_size=64)
```
18 changes: 18 additions & 0 deletions projects/yolox-pose/configs/yolox-pose_m_4xb64-300e_coco.py
@@ -0,0 +1,18 @@
```python
_base_ = ['./yolox-pose_s_8xb32-300e_coco.py']

# model settings
model = dict(
    init_cfg=dict(checkpoint='https://download.openmmlab.com/mmyolo/v0/yolox/'
                  'yolox_m_fast_8xb32-300e-rtmdet-hyp_coco/yolox_m_fast_8xb32'
                  '-300e-rtmdet-hyp_coco_20230210_144328-e657e182.pth'),
    backbone=dict(
        deepen_factor=0.67,
        widen_factor=0.75,
    ),
    neck=dict(
        deepen_factor=0.67,
        widen_factor=0.75,
    ),
    bbox_head=dict(head_module=dict(widen_factor=0.75)))

train_dataloader = dict(batch_size=64)
```
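
Both configs above override only the depth/width scaling factors and the init checkpoint relative to the `s` baseline. As an illustration of the same `_base_` inheritance pattern, a hypothetical custom-width variant could look like this (the factors below are made up for the example):

```python
# Hypothetical config (not part of this commit): a custom-width variant
# derived from the s baseline via the same _base_ inheritance pattern.
_base_ = ['./yolox-pose_s_8xb32-300e_coco.py']

custom_width = 0.5  # assumed scaling factor, chosen for illustration

model = dict(
    backbone=dict(deepen_factor=0.33, widen_factor=custom_width),
    neck=dict(deepen_factor=0.33, widen_factor=custom_width),
    bbox_head=dict(head_module=dict(widen_factor=custom_width)))
```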