Skip to content

Commit

Permalink
[Enhance] Update video read/write process in demos (#2192)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Apr 12, 2023
1 parent 1a9f3aa commit 21181f6
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 218 deletions.
70 changes: 38 additions & 32 deletions demo/bottomup_demo.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mimetypes
import os
import tempfile
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.utils import track_iter_progress

from mmpose.apis import inference_bottomup, init_model
from mmpose.registry import VISUALIZERS
from mmpose.structures import split_instances


def process_one_image(args, img_path, pose_estimator, visualizer,
show_interval):
def process_one_image(args, img, pose_estimator, visualizer, show_interval):
"""Visualize predicted keypoints (and heatmaps) of one image."""

# inference a single image
batch_results = inference_bottomup(pose_estimator, img_path)
batch_results = inference_bottomup(pose_estimator, img)
results = batch_results[0]

# show the results
img = mmcv.imread(img_path, channel_order='rgb')

out_file = None
if args.output_root:
out_file = f'{args.output_root}/{os.path.basename(img_path)}'
if isinstance(img, str):
img = mmcv.imread(img, channel_order='rgb')
elif isinstance(img, np.ndarray):
img = mmcv.bgr2rgb(img)

visualizer.add_datasample(
'result',
Expand All @@ -38,8 +38,7 @@ def process_one_image(args, img_path, pose_estimator, visualizer,
show_kpt_idx=args.show_kpt_idx,
show=args.show,
wait_time=show_interval,
out_file=out_file,
kpt_score_thr=args.kpt_thr)
kpt_thr=args.kpt_thr)

return results.pred_instances

Expand Down Expand Up @@ -97,8 +96,11 @@ def main():
args = parse_args()
assert args.show or (args.output_root != '')
assert args.input != ''
output_file = None
if args.output_root:
mmengine.mkdir_or_exist(args.output_root)
output_file = os.path.join(args.output_root,
os.path.basename(args.input))
if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
Expand Down Expand Up @@ -128,36 +130,40 @@ def main():
args, args.input, model, visualizer, show_interval=0)
pred_instances_list = split_instances(pred_instances)

if output_file:
img_vis = visualizer.get_image()
mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

elif input_type == 'video':
tmp_folder = tempfile.TemporaryDirectory()
video = mmcv.VideoReader(args.input)
progressbar = mmengine.ProgressBar(len(video))
video.cvt2frames(tmp_folder.name, show_progress=False)
output_root = args.output_root
args.output_root = tmp_folder.name
video_reader = mmcv.VideoReader(args.input)
video_writer = None

pred_instances_list = []

for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
for frame_id, frame in enumerate(track_iter_progress(video_reader)):
pred_instances = process_one_image(
args,
f'{tmp_folder.name}/{img_fname}',
model,
visualizer,
show_interval=1)
progressbar.update()
args, frame, model, visualizer, show_interval=0.001)

pred_instances_list.append(
dict(
frame_id=frame_id,
instances=split_instances(pred_instances)))

if output_root:
mmcv.frames2video(
tmp_folder.name,
f'{output_root}/{os.path.basename(args.input)}',
fps=video.fps,
fourcc='mp4v',
show_progress=False)
tmp_folder.cleanup()
if output_file:
frame_vis = visualizer.get_image()
if video_writer is None:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# the size of the image with visualization may vary
# depending on the presence of heatmaps
video_writer = cv2.VideoWriter(output_file, fourcc,
video_reader.fps,
(frame_vis.shape[1],
frame_vis.shape[0]))

video_writer.write(mmcv.rgb2bgr(frame_vis))

if video_writer:
video_writer.release()

else:
args.save_predictions = False
Expand Down
6 changes: 6 additions & 0 deletions demo/inferencer_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def parse_args():
'--draw-bbox',
action='store_true',
help='Whether to draw the bounding boxes.')
parser.add_argument(
'--draw-heatmap',
action='store_true',
default=False,
help='Whether to draw the predicted heatmaps.')
parser.add_argument(
'--bbox-thr',
type=float,
Expand Down Expand Up @@ -104,6 +109,7 @@ def parse_args():
'det_weights', 'det_cat_ids'
]
init_args = {}
init_args['output_heatmaps'] = call_args.pop('draw_heatmap')
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)

Expand Down
65 changes: 37 additions & 28 deletions demo/topdown_demo_with_mmdet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mimetypes
import os
import tempfile
from argparse import ArgumentParser

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.utils import track_iter_progress

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
Expand All @@ -23,12 +24,12 @@
has_mmdet = False


def process_one_image(args, img_path, detector, pose_estimator, visualizer,
def process_one_image(args, img, detector, pose_estimator, visualizer,
show_interval):
"""Visualize predicted keypoints (and heatmaps) of one image."""

# predict bbox
det_result = inference_detector(detector, img_path)
det_result = inference_detector(detector, img)
pred_instance = det_result.pred_instances.cpu().numpy()
bboxes = np.concatenate(
(pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
Expand All @@ -37,15 +38,14 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
bboxes = bboxes[nms(bboxes, args.nms_thr), :4]

# predict keypoints
pose_results = inference_topdown(pose_estimator, img_path, bboxes)
pose_results = inference_topdown(pose_estimator, img, bboxes)
data_samples = merge_data_samples(pose_results)

# show the results
img = mmcv.imread(img_path, channel_order='rgb')

out_file = None
if args.output_root:
out_file = f'{args.output_root}/{os.path.basename(img_path)}'
if isinstance(img, str):
img = mmcv.imread(img, channel_order='rgb')
elif isinstance(img, np.ndarray):
img = mmcv.bgr2rgb(img)

visualizer.add_datasample(
'result',
Expand All @@ -58,7 +58,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
skeleton_style=args.skeleton_style,
show=args.show,
wait_time=show_interval,
out_file=out_file,
kpt_thr=args.kpt_thr)

# if there is no instance detected, return None
Expand Down Expand Up @@ -154,8 +153,11 @@ def main():
assert args.input != ''
assert args.det_config is not None
assert args.det_checkpoint is not None
output_file = None
if args.output_root:
mmengine.mkdir_or_exist(args.output_root)
output_file = os.path.join(args.output_root,
os.path.basename(args.input))
if args.save_predictions:
assert args.output_root != ''
args.pred_save_path = f'{args.output_root}/results_' \
Expand Down Expand Up @@ -196,38 +198,45 @@ def main():
show_interval=0)
pred_instances_list = split_instances(pred_instances)

if output_file:
img_vis = visualizer.get_image()
mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file)

elif input_type == 'video':
tmp_folder = tempfile.TemporaryDirectory()
video = mmcv.VideoReader(args.input)
progressbar = mmengine.ProgressBar(len(video))
video.cvt2frames(tmp_folder.name, show_progress=False)
output_root = args.output_root
args.output_root = tmp_folder.name
video_reader = mmcv.VideoReader(args.input)
video_writer = None

pred_instances_list = []

for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
for frame_id, frame in enumerate(track_iter_progress(video_reader)):
pred_instances = process_one_image(
args,
f'{tmp_folder.name}/{img_fname}',
frame,
detector,
pose_estimator,
visualizer,
show_interval=1)
show_interval=0.001)

progressbar.update()
pred_instances_list.append(
dict(
frame_id=frame_id,
instances=split_instances(pred_instances)))

if output_root:
mmcv.frames2video(
tmp_folder.name,
f'{output_root}/{os.path.basename(args.input)}',
fps=video.fps,
fourcc='mp4v',
show_progress=False)
tmp_folder.cleanup()
if output_file:
frame_vis = visualizer.get_image()
if video_writer is None:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# the size of the image with visualization may vary
# depending on the presence of heatmaps
video_writer = cv2.VideoWriter(output_file, fourcc,
video_reader.fps,
(frame_vis.shape[1],
frame_vis.shape[0]))

video_writer.write(mmcv.rgb2bgr(frame_vis))

if video_writer:
video_writer.release()

else:
args.save_predictions = False
Expand Down
3 changes: 1 addition & 2 deletions demo/webcam_api_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
def parse_args():
parser = ArgumentParser('Webcam executor configs')
parser.add_argument(
'--config', type=str, default='demo/webcam_cfg/pose_estimation.py')

'--config', type=str, default='demo/webcam_cfg/human_pose.py')
parser.add_argument(
'--cfg-options',
nargs='+',
Expand Down
File renamed without changes.
102 changes: 102 additions & 0 deletions demo/webcam_cfg/human_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Webcam demo configuration: builds an executor whose node pipeline performs
# real-time human pose estimation (detection -> top-down pose -> visualization
# -> on-screen notices/monitoring -> video recording).
executor_cfg = dict(
    # Basic configurations of the executor
    name='Pose Estimation',
    camera_id=0,  # index of the webcam device to open
    # Define nodes.
    # The configuration of a node usually includes:
    #   1. 'type': Node class name
    #   2. 'name': Node name
    #   3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the
    #      input and output buffer names. This may depend on the node class.
    #   4. 'enable_key': assign a hot-key to toggle enable/disable this node.
    #      This may depend on the node class.
    #   5. Other class-specific arguments
    nodes=[
        # 'DetectorNode':
        # This node performs object detection from the frame image using an
        # MMDetection model.
        dict(
            type='DetectorNode',
            name='detector',
            model_config='projects/rtmpose/rtmdet/person/'
            'rtmdet_nano_320-8xb32_coco-person.py',
            model_checkpoint='https://download.openmmlab.com/mmpose/v1/'
            'projects/rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth',  # noqa
            input_buffer='_input_',  # `_input_` is an executor-reserved buffer
            output_buffer='det_result'),
        # 'TopdownPoseEstimatorNode':
        # This node performs keypoint detection from the frame image using an
        # MMPose top-down model. Detection results are needed.
        dict(
            type='TopdownPoseEstimatorNode',
            name='human pose estimator',
            model_config='projects/rtmpose/rtmpose/body_2d_keypoint/'
            'rtmpose-t_8xb256-420e_coco-256x192.py',
            model_checkpoint='https://download.openmmlab.com/mmpose/v1/'
            'projects/rtmpose/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth',  # noqa
            labels=['person'],  # only estimate pose for detected persons
            input_buffer='det_result',
            output_buffer='human_pose'),
        # 'ObjectAssignerNode':
        # This node binds the latest model inference result with the current
        # frame. (This means the frame image and inference result may be
        # asynchronous).
        dict(
            type='ObjectAssignerNode',
            name='object assigner',
            frame_buffer='_frame_',  # `_frame_` is an executor-reserved buffer
            object_buffer='human_pose',
            output_buffer='frame'),
        # 'ObjectVisualizerNode':
        # This node draws the pose visualization result in the frame image.
        # Pose results are needed.
        dict(
            type='ObjectVisualizerNode',
            name='object visualizer',
            enable_key='v',  # press 'v' to toggle visualization
            enable=True,
            show_bbox=True,
            must_have_keypoint=False,
            show_keypoint=True,
            input_buffer='frame',
            output_buffer='vis'),
        # 'NoticeBoardNode':
        # This node shows a notice board with given content, e.g. help
        # information.
        dict(
            type='NoticeBoardNode',
            name='instruction',
            enable_key='h',  # press 'h' to toggle the help board
            enable=True,
            input_buffer='vis',
            output_buffer='vis_notice',
            content_lines=[
                'This is a demo for pose visualization and simple image '
                'effects. Have fun!', '', 'Hot-keys:',
                '"v": Pose estimation result visualization',
                '"h": Show help information',
                '"m": Show diagnostic information', '"q": Exit'
            ],
        ),
        # 'MonitorNode':
        # This node shows diagnostic information in the frame image. It can
        # be used for debugging or monitoring system resource status.
        dict(
            type='MonitorNode',
            name='monitor',
            enable_key='m',  # press 'm' to toggle the monitor overlay
            enable=False,  # disabled by default
            input_buffer='vis_notice',
            output_buffer='display'),
        # 'RecorderNode':
        # This node saves the output video into a file.
        dict(
            type='RecorderNode',
            name='recorder',
            out_video_file='webcam_api_demo.mp4',
            input_buffer='display',
            output_buffer='_display_'
            # `_display_` is an executor-reserved buffer
        )
    ])
Loading

0 comments on commit 21181f6

Please sign in to comment.