[Enhance] Support saving predictions in topdown demos (open-mmlab#1814)
* support saving predictions in topdown demos

* update docs
Ben-Louis authored and ly015 committed Feb 21, 2023
1 parent fe3c7c0 commit 5a6e726
Showing 10 changed files with 127 additions and 27 deletions.
14 changes: 13 additions & 1 deletion demo/docs/2d_animal_demo.md
@@ -9,7 +9,7 @@ python demo/topdown_demo_with_mmdet.py \
${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--input ${INPUT_PATH} --det-cat-id ${DET_CAT_ID} \
[--show] [--output-root ${OUTPUT_DIR}] \
[--show] [--output-root ${OUTPUT_DIR}] [--save-predictions] \
[--draw-heatmap ${DRAW_HEATMAP}] [--radius ${KPT_RADIUS}] \
[--kpt-thr ${KPT_SCORE_THR}] [--bbox-thr ${BBOX_SCORE_THR}] \
[--device ${GPU_ID or CPU}]
@@ -53,6 +53,18 @@ python demo/topdown_demo_with_mmdet.py \
--output-root vis_results --draw-heatmap --det-cat-id=15
```

To save predicted results on disk:

```shell
python demo/topdown_demo_with_mmdet.py \
demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py \
https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
configs/animal_2d_keypoint/topdown_heatmap/animalpose/td-hm_hrnet-w32_8xb64-210e_animalpose-256x256.py \
https://download.openmmlab.com/mmpose/animal/hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth \
--input tests/data/animalpose/ca110.jpeg \
--output-root vis_results --save-predictions --draw-heatmap --det-cat-id=15
```
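
The predictions are written to a single JSON file under the output root, named `results_<input stem>.json` by the script; for the command above that would be `vis_results/results_ca110.json`. A minimal sketch for loading the saved results back, assuming `json_tricks` is installed (the demo script itself depends on it):

```python
# Minimal sketch: read back the predictions written by --save-predictions.
# The path matches the example command above; adjust it for your own input/output.
import json_tricks as json

with open('vis_results/results_ca110.json') as f:
    instances = json.load(f)  # one dict per detected animal instance

for i, instance in enumerate(instances):
    # each dict holds 'keypoints', 'keypoint_scores' and, when available, 'bbox'/'bbox_score'
    print(f'instance {i}: {len(instance["keypoints"])} keypoints')
```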

To run demos on CPU:

```shell
4 changes: 3 additions & 1 deletion demo/docs/2d_face_demo.md
@@ -14,7 +14,7 @@ For more details, please refer to [face_recognition](https://github.com/ageitgey
python demo/topdown_face_demo.py \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--input ${INPUT_PATH} [--output-root ${OUTPUT_DIR}] \
[--show] [--device ${GPU_ID or CPU}] \
[--show] [--device ${GPU_ID or CPU}] [--save-predictions] \
[--draw-heatmap ${DRAW_HEATMAP}] [--radius ${KPT_RADIUS}] \
[--kpt-thr ${KPT_SCORE_THR}]
```
@@ -46,6 +46,8 @@ python demo/topdown_face_demo.py \
--draw-heatmap --output-root vis_results
```

To save the predicted results on disk, please specify `--save-predictions`.

To run demos on CPU:

```shell
4 changes: 3 additions & 1 deletion demo/docs/2d_hand_demo.md
@@ -11,7 +11,7 @@ python demo/topdown_demo_with_mmdet.py \
${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--input ${INPUT_PATH} [--output-root ${OUTPUT_DIR}] \
[--show] [--device ${GPU_ID or CPU}] \
[--show] [--device ${GPU_ID or CPU}] [--save-predictions] \
[--draw-heatmap ${DRAW_HEATMAP}] [--radius ${KPT_RADIUS}] \
[--kpt-thr ${KPT_SCORE_THR}] [--bbox-thr ${BBOX_SCORE_THR}]

@@ -48,6 +48,8 @@ python demo/topdown_demo_with_mmdet.py \
--output-root vis_results --show --draw-heatmap
```

To save the predicted results on disk, please specify `--save-predictions`.

To run demos on CPU:

```shell
8 changes: 5 additions & 3 deletions demo/docs/2d_human_pose_demo.md
@@ -57,9 +57,9 @@ python demo/topdown_demo_with_mmdet.py \
${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--input ${INPUT_PATH} \
--output-root ${OUTPUT_DIR} \
[--show --draw-heatmap --device ${GPU_ID or CPU}] \
[--bbox-thr ${BBOX_SCORE_THR} --kpt-thr ${KPT_SCORE_THR}]
[--output-root ${OUTPUT_DIR}] [--save-predictions] \
[--show] [--draw-heatmap] [--device ${GPU_ID or CPU}] \
[--bbox-thr ${BBOX_SCORE_THR}] [--kpt-thr ${KPT_SCORE_THR}]
```

Example:
@@ -78,6 +78,8 @@ Visualization result:

<img src="https://user-images.githubusercontent.com/87690686/187824368-1f1631c3-52bf-4b45-bf9a-a70cd6551e1a.jpg" height="500px" alt><br>

To save the predicted results on disk, please specify `--save-predictions`.

### 2D Human Pose Top-Down Video Demo

The above demo script can also take video as input, running MMDetection for human detection and MMPose for pose estimation. The difference is that the `${INPUT_PATH}` for videos can be a local path or a **URL** link to a video file.
8 changes: 5 additions & 3 deletions demo/docs/2d_wholebody_pose_demo.md
@@ -47,9 +47,9 @@ python demo/topdown_demo_with_mmdet.py \
${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--input ${INPUT_PATH} \
--output-root ${OUTPUT_DIR} \
[--show --draw-heatmap --device ${GPU_ID or CPU}] \
[--bbox-thr ${BBOX_SCORE_THR} --kpt-thr ${KPT_SCORE_THR}]
[--output-root ${OUTPUT_DIR}] [--save-predictions] \
[--show] [--draw-heatmap] [--device ${GPU_ID or CPU}] \
[--bbox-thr ${BBOX_SCORE_THR}] [--kpt-thr ${KPT_SCORE_THR}]
```

Examples:
@@ -64,6 +64,8 @@ python demo/topdown_demo_with_mmdet.py \
--output-root vis_results/ --show
```

To save the predicted results on disk, please specify `--save-predictions`.

### 2D Human Whole-Body Pose Top-Down Video Demo

The above demo script can also take video as input, running MMDetection for human detection and MMPose for pose estimation.
44 changes: 37 additions & 7 deletions demo/topdown_demo_with_mmdet.py
@@ -4,6 +4,7 @@
import tempfile
from argparse import ArgumentParser

import json_tricks as json
import mmcv
import mmengine
import numpy as np
@@ -12,7 +13,7 @@
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import register_all_modules as register_mmpose_modules

try:
@@ -23,9 +24,9 @@
has_mmdet = False


def visualize_img(args, img_path, detector, pose_estimator, visualizer,
show_interval):
"""Visualize predicted keypoints (and heatmaps) of one image."""
def infer_and_visualize_image(args, img_path, detector, pose_estimator,
visualizer, show_interval):
"""Predict the keypoints of one image, and visualize the results."""

# predict bbox
register_mmdet_modules()
@@ -61,6 +62,8 @@ def visualize_img(args, img_path, detector, pose_estimator, visualizer,
out_file=out_file,
kpt_score_thr=args.kpt_thr)

return data_samples.pred_instances


def main():
"""Visualize the demo images.
@@ -85,6 +88,11 @@ def main():
default='',
help='root of the output img file. '
'Default not saving the visualization images.')
parser.add_argument(
'--save-predictions',
action='store_true',
default=False,
help='whether to save predicted results')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
@@ -132,6 +140,10 @@ def main():
assert args.det_checkpoint is not None
if args.output_root:
mmengine.mkdir_or_exist(args.output_root)
if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

# build detector
register_mmdet_modules()
@@ -157,29 +169,41 @@ def main():

input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
if input_type == 'image':
visualize_img(
pred_instances = infer_and_visualize_image(
args,
args.input,
detector,
pose_estimator,
visualizer,
show_interval=0)
if args.save_predictions:
with open(args.pred_save_path, 'w') as f:
json.dump(split_instances(pred_instances), f, indent='\t')
print(f'predictions have been saved at {args.pred_save_path}')

elif input_type == 'video':
tmp_folder = tempfile.TemporaryDirectory()
video = mmcv.VideoReader(args.input)
progressbar = mmengine.ProgressBar(len(video))
video.cvt2frames(tmp_folder.name, show_progress=False)
output_root = args.output_root
args.output_root = tmp_folder.name
for img_fname in os.listdir(tmp_folder.name):
visualize_img(
pred_instances_list = []

for frame_id, img_fname in enumerate(sorted(os.listdir(tmp_folder.name))):
pred_instances = infer_and_visualize_image(
args,
f'{tmp_folder.name}/{img_fname}',
detector,
pose_estimator,
visualizer,
show_interval=1)
progressbar.update()
pred_instances_list.append(
dict(
frame_id=frame_id,
instances=split_instances(pred_instances)))

if output_root:
mmcv.frames2video(
tmp_folder.name,
@@ -188,6 +212,12 @@ def main():
fourcc='mp4v',
show_progress=False)
tmp_folder.cleanup()

if args.save_predictions:
with open(args.pred_save_path, 'w') as f:
json.dump(pred_instances_list, f, indent='\t')
print(f'predictions have been saved at {args.pred_save_path}')

else:
raise ValueError(
f'file {os.path.basename(args.input)} has invalid format.')
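
For video input, the script above collects one entry per frame (a `frame_id` plus the per-instance results from `split_instances`) and dumps the whole list at the end. A minimal sketch for consuming that file; the path is an illustrative assumption and depends on `--input` and `--output-root`:

```python
# Minimal sketch: iterate over the per-frame predictions saved for a video input.
# 'vis_results/results_demo.json' is an assumed path used only for illustration.
import json_tricks as json

with open('vis_results/results_demo.json') as f:
    frames = json.load(f)  # list of {'frame_id': int, 'instances': [...]}

for frame in frames:
    # every instance carries 'keypoints' and 'keypoint_scores' (plus bbox fields if present)
    print(f"frame {frame['frame_id']}: {len(frame['instances'])} instance(s)")
```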
43 changes: 37 additions & 6 deletions demo/topdown_face_demo.py
@@ -4,6 +4,7 @@
import tempfile
from argparse import ArgumentParser

import json_tricks as json
import mmcv
import mmengine
import numpy as np
@@ -12,7 +13,7 @@
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import register_all_modules as register_mmpose_modules

try:
Expand All @@ -38,8 +39,9 @@ def process_face_det_results(face_det_results):
return person_results


def visualize_img(args, img_path, pose_estimator, visualizer, show_interval):
"""Visualize predicted keypoints (and heatmaps) of one image."""
def infer_and_visualize_image(args, img_path, pose_estimator, visualizer,
show_interval):
"""Predict the keypoints of one image, and visualize the results."""

# predict bbox
image = face_recognition.load_image_file(img_path)
@@ -73,6 +75,8 @@ def visualize_img(args, img_path, pose_estimator, visualizer, show_interval):
out_file=out_file,
kpt_score_thr=args.kpt_thr)

return data_samples.pred_instances


def main():
"""Visualize the demo images.
@@ -95,6 +99,11 @@ def main():
default='',
help='root of the output img file. '
'Default not saving the visualization images.')
parser.add_argument(
'--save-predictions',
action='store_true',
default=False,
help='whether to save predicted results')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
@@ -132,6 +141,10 @@ def main():
assert args.input != ''
if args.output_root:
mmengine.mkdir_or_exist(args.output_root)
if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

# build pose estimator
register_mmpose_modules()
@@ -153,23 +166,35 @@ def main():

input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
if input_type == 'image':
visualize_img(
pred_instances = infer_and_visualize_image(
args, args.input, pose_estimator, visualizer, show_interval=0)
if args.save_predictions:
with open(args.pred_save_path, 'w') as f:
json.dump(split_instances(pred_instances), f, indent='\t')
print(f'predictions have been saved at {args.pred_save_path}')

elif input_type == 'video':
tmp_folder = tempfile.TemporaryDirectory()
video = mmcv.VideoReader(args.input)
progressbar = mmengine.ProgressBar(len(video))
video.cvt2frames(tmp_folder.name, show_progress=False)
output_root = args.output_root
args.output_root = tmp_folder.name
for img_fname in os.listdir(tmp_folder.name):
visualize_img(
pred_instances_list = []

for frame_id, img_fname in enumerate(sorted(os.listdir(tmp_folder.name))):
pred_instances = infer_and_visualize_image(
args,
f'{tmp_folder.name}/{img_fname}',
pose_estimator,
visualizer,
show_interval=1)
progressbar.update()
pred_instances_list.append(
dict(
frame_id=frame_id,
instances=split_instances(pred_instances)))

if output_root:
mmcv.frames2video(
tmp_folder.name,
@@ -178,6 +203,12 @@ def main():
fourcc='mp4v',
show_progress=False)
tmp_folder.cleanup()

if args.save_predictions:
with open(args.pred_save_path, 'w') as f:
json.dump(pred_instances_list, f, indent='\t')
print(f'predictions have been saved at {args.pred_save_path}')

else:
raise ValueError(
f'file {os.path.basename(args.input)} has invalid format.')
6 changes: 3 additions & 3 deletions demo/webcam_cfg/pose_estimation.py
@@ -85,9 +85,9 @@
enable=False,
input_buffer='vis',
output_buffer='vis_sunglasses'),
# # 'BigeyeEffectNode':
# # This node draw the big-eye effetc in the frame image.
# # Pose results is needed.
# 'BigeyeEffectNode':
# This node draws the big-eye effect in the frame image.
# Pose results are needed.
dict(
type='BigeyeEffectNode',
name='big-eye',
4 changes: 2 additions & 2 deletions mmpose/structures/__init__.py
@@ -5,11 +5,11 @@
from .keypoint import flip_keypoints
from .multilevel_pixel_data import MultilevelPixelData
from .pose_data_sample import PoseDataSample
from .utils import merge_data_samples, revert_heatmap
from .utils import merge_data_samples, revert_heatmap, split_instances

__all__ = [
'PoseDataSample', 'MultilevelPixelData', 'bbox_cs2xywh', 'bbox_cs2xyxy',
'bbox_xywh2cs', 'bbox_xywh2xyxy', 'bbox_xyxy2cs', 'bbox_xyxy2xywh',
'flip_bbox', 'get_udp_warp_matrix', 'get_warp_matrix', 'flip_keypoints',
'merge_data_samples', 'revert_heatmap'
'merge_data_samples', 'revert_heatmap', 'split_instances'
]
19 changes: 19 additions & 0 deletions mmpose/structures/utils.py
@@ -113,3 +113,22 @@ def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape):
heatmap = heatmap.transpose(2, 0, 1)

return heatmap


def split_instances(instances: InstanceData):
    """Convert instances into a list where each element is a dict that
    contains the information of one instance."""
    results = []

    for i in range(len(instances.keypoints)):
        result = dict(
            keypoints=instances.keypoints[i].tolist(),
            keypoint_scores=instances.keypoint_scores[i].tolist(),
        )
        if 'bboxes' in instances:
            result['bbox'] = instances.bboxes[i].tolist()
            if 'bbox_scores' in instances:
                result['bbox_score'] = instances.bbox_scores[i]
        results.append(result)

    return results
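
A quick usage sketch for `split_instances`; the array shapes are illustrative assumptions (two instances with 17 keypoints each), and `InstanceData` is taken from `mmengine.structures`:

```python
# Illustrative only: build a toy InstanceData and split it into per-instance dicts.
import numpy as np
from mmengine.structures import InstanceData

from mmpose.structures import split_instances

instances = InstanceData()
instances.keypoints = np.random.rand(2, 17, 2)      # two instances, 17 (x, y) keypoints
instances.keypoint_scores = np.random.rand(2, 17)   # per-keypoint confidence
instances.bboxes = np.random.rand(2, 4)             # dummy boxes
instances.bbox_scores = np.random.rand(2)           # dummy box scores

for item in split_instances(instances):
    print(sorted(item))  # ['bbox', 'bbox_score', 'keypoint_scores', 'keypoints']
```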
