From 21181f67d9b6b1a34ecca34ab4f05aad3c7f5035 Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Wed, 12 Apr 2023 17:11:14 +0800 Subject: [PATCH] [Enhance] Update video read/write process in demos (#2192) --- demo/bottomup_demo.py | 70 ++++---- demo/inferencer_demo.py | 6 + demo/topdown_demo_with_mmdet.py | 65 ++++---- demo/webcam_api_demo.py | 3 +- ...ose_estimation.py => human_animal_pose.py} | 0 demo/webcam_cfg/human_pose.py | 102 ++++++++++++ docs/en/user_guides/inference.md | 4 + .../inferencers/base_mmpose_inferencer.py | 154 ++++++++---------- mmpose/apis/inferencers/mmpose_inferencer.py | 51 +++--- mmpose/apis/inferencers/pose2d_inferencer.py | 101 +++++++----- .../apis/inferencers/utils/get_model_alias.py | 4 +- projects/yolox-pose/README.md | 4 +- 12 files changed, 346 insertions(+), 218 deletions(-) rename demo/webcam_cfg/{pose_estimation.py => human_animal_pose.py} (100%) create mode 100644 demo/webcam_cfg/human_pose.py diff --git a/demo/bottomup_demo.py b/demo/bottomup_demo.py index 8d27a17178..47d6f7ccf4 100644 --- a/demo/bottomup_demo.py +++ b/demo/bottomup_demo.py @@ -1,32 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. import mimetypes import os -import tempfile from argparse import ArgumentParser +import cv2 import json_tricks as json import mmcv import mmengine +import numpy as np +from mmengine.utils import track_iter_progress from mmpose.apis import inference_bottomup, init_model from mmpose.registry import VISUALIZERS from mmpose.structures import split_instances -def process_one_image(args, img_path, pose_estimator, visualizer, - show_interval): +def process_one_image(args, img, pose_estimator, visualizer, show_interval): """Visualize predicted keypoints (and heatmaps) of one image.""" # inference a single image - batch_results = inference_bottomup(pose_estimator, img_path) + batch_results = inference_bottomup(pose_estimator, img) results = batch_results[0] # show the results - img = mmcv.imread(img_path, channel_order='rgb') - - out_file = None - if args.output_root: - out_file = f'{args.output_root}/{os.path.basename(img_path)}' + if isinstance(img, str): + img = mmcv.imread(img, channel_order='rgb') + elif isinstance(img, np.ndarray): + img = mmcv.bgr2rgb(img) visualizer.add_datasample( 'result', @@ -38,8 +38,7 @@ def process_one_image(args, img_path, pose_estimator, visualizer, show_kpt_idx=args.show_kpt_idx, show=args.show, wait_time=show_interval, - out_file=out_file, - kpt_score_thr=args.kpt_thr) + kpt_thr=args.kpt_thr) return results.pred_instances @@ -97,8 +96,11 @@ def main(): args = parse_args() assert args.show or (args.output_root != '') assert args.input != '' + output_file = None if args.output_root: mmengine.mkdir_or_exist(args.output_root) + output_file = os.path.join(args.output_root, + os.path.basename(args.input)) if args.save_predictions: assert args.output_root != '' args.pred_save_path = f'{args.output_root}/results_' \ @@ -128,36 +130,40 @@ def main(): args, args.input, model, visualizer, show_interval=0) pred_instances_list = split_instances(pred_instances) + if output_file: + img_vis = visualizer.get_image() + mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file) + elif input_type == 'video': - tmp_folder = tempfile.TemporaryDirectory() - video = mmcv.VideoReader(args.input) - progressbar = mmengine.ProgressBar(len(video)) - video.cvt2frames(tmp_folder.name, show_progress=False) - output_root = args.output_root - args.output_root = tmp_folder.name + video_reader = mmcv.VideoReader(args.input) + video_writer = None + pred_instances_list = [] - for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)): + for frame_id, frame in enumerate(track_iter_progress(video_reader)): pred_instances = process_one_image( - args, - f'{tmp_folder.name}/{img_fname}', - model, - visualizer, - show_interval=1) - progressbar.update() + args, frame, model, visualizer, show_interval=0.001) + pred_instances_list.append( dict( frame_id=frame_id, instances=split_instances(pred_instances))) - if output_root: - mmcv.frames2video( - tmp_folder.name, - f'{output_root}/{os.path.basename(args.input)}', - fps=video.fps, - fourcc='mp4v', - show_progress=False) - tmp_folder.cleanup() + if output_file: + frame_vis = visualizer.get_image() + if video_writer is None: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # the size of the image with visualization may vary + # depending on the presence of heatmaps + video_writer = cv2.VideoWriter(output_file, fourcc, + video_reader.fps, + (frame_vis.shape[1], + frame_vis.shape[0])) + + video_writer.write(mmcv.rgb2bgr(frame_vis)) + + if video_writer: + video_writer.release() else: args.save_predictions = False diff --git a/demo/inferencer_demo.py b/demo/inferencer_demo.py index e0609596ec..73bd9c5bc3 100644 --- a/demo/inferencer_demo.py +++ b/demo/inferencer_demo.py @@ -60,6 +60,11 @@ def parse_args(): '--draw-bbox', action='store_true', help='Whether to draw the bounding boxes.') + parser.add_argument( + '--draw-heatmap', + action='store_true', + default=False, + help='Whether to draw the predicted heatmaps.') parser.add_argument( '--bbox-thr', type=float, @@ -104,6 +109,7 @@ def parse_args(): 'det_weights', 'det_cat_ids' ] init_args = {} + init_args['output_heatmaps'] = call_args.pop('draw_heatmap') for init_kw in init_kws: init_args[init_kw] = call_args.pop(init_kw) diff --git a/demo/topdown_demo_with_mmdet.py b/demo/topdown_demo_with_mmdet.py index f0938000f1..03658c5a7a 100644 --- a/demo/topdown_demo_with_mmdet.py +++ b/demo/topdown_demo_with_mmdet.py @@ -1,13 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import mimetypes import os -import tempfile from argparse import ArgumentParser +import cv2 import json_tricks as json import mmcv import mmengine import numpy as np +from mmengine.utils import track_iter_progress from mmpose.apis import inference_topdown from mmpose.apis import init_model as init_pose_estimator @@ -23,12 +24,12 @@ has_mmdet = False -def process_one_image(args, img_path, detector, pose_estimator, visualizer, +def process_one_image(args, img, detector, pose_estimator, visualizer, show_interval): """Visualize predicted keypoints (and heatmaps) of one image.""" # predict bbox - det_result = inference_detector(detector, img_path) + det_result = inference_detector(detector, img) pred_instance = det_result.pred_instances.cpu().numpy() bboxes = np.concatenate( (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) @@ -37,15 +38,14 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer, bboxes = bboxes[nms(bboxes, args.nms_thr), :4] # predict keypoints - pose_results = inference_topdown(pose_estimator, img_path, bboxes) + pose_results = inference_topdown(pose_estimator, img, bboxes) data_samples = merge_data_samples(pose_results) # show the results - img = mmcv.imread(img_path, channel_order='rgb') - - out_file = None - if args.output_root: - out_file = f'{args.output_root}/{os.path.basename(img_path)}' + if isinstance(img, str): + img = mmcv.imread(img, channel_order='rgb') + elif isinstance(img, np.ndarray): + img = mmcv.bgr2rgb(img) visualizer.add_datasample( 'result', @@ -58,7 +58,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer, skeleton_style=args.skeleton_style, show=args.show, wait_time=show_interval, - out_file=out_file, kpt_thr=args.kpt_thr) # if there is no instance detected, return None @@ -154,8 +153,11 @@ def main(): assert args.input != '' assert args.det_config is not None assert args.det_checkpoint is not None + output_file = None if args.output_root: mmengine.mkdir_or_exist(args.output_root) + output_file = os.path.join(args.output_root, + os.path.basename(args.input)) if args.save_predictions: assert args.output_root != '' args.pred_save_path = f'{args.output_root}/results_' \ @@ -196,38 +198,45 @@ def main(): show_interval=0) pred_instances_list = split_instances(pred_instances) + if output_file: + img_vis = visualizer.get_image() + mmcv.imwrite(mmcv.rgb2bgr(img_vis), output_file) + elif input_type == 'video': - tmp_folder = tempfile.TemporaryDirectory() - video = mmcv.VideoReader(args.input) - progressbar = mmengine.ProgressBar(len(video)) - video.cvt2frames(tmp_folder.name, show_progress=False) - output_root = args.output_root - args.output_root = tmp_folder.name + video_reader = mmcv.VideoReader(args.input) + video_writer = None + pred_instances_list = [] - for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)): + for frame_id, frame in enumerate(track_iter_progress(video_reader)): pred_instances = process_one_image( args, - f'{tmp_folder.name}/{img_fname}', + frame, detector, pose_estimator, visualizer, - show_interval=1) + show_interval=0.001) - progressbar.update() pred_instances_list.append( dict( frame_id=frame_id, instances=split_instances(pred_instances))) - if output_root: - mmcv.frames2video( - tmp_folder.name, - f'{output_root}/{os.path.basename(args.input)}', - fps=video.fps, - fourcc='mp4v', - show_progress=False) - tmp_folder.cleanup() + if output_file: + frame_vis = visualizer.get_image() + if video_writer is None: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # the size of the image with visualization may vary + # depending on the presence of heatmaps + video_writer = cv2.VideoWriter(output_file, fourcc, + video_reader.fps, + (frame_vis.shape[1], + frame_vis.shape[0])) + + video_writer.write(mmcv.rgb2bgr(frame_vis)) + + if video_writer: + video_writer.release() else: args.save_predictions = False diff --git a/demo/webcam_api_demo.py b/demo/webcam_api_demo.py index 7e047ea6b5..7d7ad263b1 100644 --- a/demo/webcam_api_demo.py +++ b/demo/webcam_api_demo.py @@ -13,8 +13,7 @@ def parse_args(): parser = ArgumentParser('Webcam executor configs') parser.add_argument( - '--config', type=str, default='demo/webcam_cfg/pose_estimation.py') - + '--config', type=str, default='demo/webcam_cfg/human_pose.py') parser.add_argument( '--cfg-options', nargs='+', diff --git a/demo/webcam_cfg/pose_estimation.py b/demo/webcam_cfg/human_animal_pose.py similarity index 100% rename from demo/webcam_cfg/pose_estimation.py rename to demo/webcam_cfg/human_animal_pose.py diff --git a/demo/webcam_cfg/human_pose.py b/demo/webcam_cfg/human_pose.py new file mode 100644 index 0000000000..d1bac5722a --- /dev/null +++ b/demo/webcam_cfg/human_pose.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +executor_cfg = dict( + # Basic configurations of the executor + name='Pose Estimation', + camera_id=0, + # Define nodes. + # The configuration of a node usually includes: + # 1. 'type': Node class name + # 2. 'name': Node name + # 3. I/O buffers (e.g. 'input_buffer', 'output_buffer'): specify the + # input and output buffer names. This may depend on the node class. + # 4. 'enable_key': assign a hot-key to toggle enable/disable this node. + # This may depend on the node class. + # 5. Other class-specific arguments + nodes=[ + # 'DetectorNode': + # This node performs object detection from the frame image using an + # MMDetection model. + dict( + type='DetectorNode', + name='detector', + model_config='projects/rtmpose/rtmdet/person/' + 'rtmdet_nano_320-8xb32_coco-person.py', + model_checkpoint='https://download.openmmlab.com/mmpose/v1/' + 'projects/rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth', # noqa + input_buffer='_input_', # `_input_` is an executor-reserved buffer + output_buffer='det_result'), + # 'TopdownPoseEstimatorNode': + # This node performs keypoint detection from the frame image using an + # MMPose top-down model. Detection results is needed. + dict( + type='TopdownPoseEstimatorNode', + name='human pose estimator', + model_config='projects/rtmpose/rtmpose/body_2d_keypoint/' + 'rtmpose-t_8xb256-420e_coco-256x192.py', + model_checkpoint='https://download.openmmlab.com/mmpose/v1/' + 'projects/rtmpose/rtmpose-tiny_simcc-aic-coco_pt-aic-coco_420e-256x192-cfc8f33d_20230126.pth', # noqa + labels=['person'], + input_buffer='det_result', + output_buffer='human_pose'), + # 'ObjectAssignerNode': + # This node binds the latest model inference result with the current + # frame. (This means the frame image and inference result may be + # asynchronous). + dict( + type='ObjectAssignerNode', + name='object assigner', + frame_buffer='_frame_', # `_frame_` is an executor-reserved buffer + object_buffer='human_pose', + output_buffer='frame'), + # 'ObjectVisualizerNode': + # This node draw the pose visualization result in the frame image. + # Pose results is needed. + dict( + type='ObjectVisualizerNode', + name='object visualizer', + enable_key='v', + enable=True, + show_bbox=True, + must_have_keypoint=False, + show_keypoint=True, + input_buffer='frame', + output_buffer='vis'), + # 'NoticeBoardNode': + # This node show a notice board with given content, e.g. help + # information. + dict( + type='NoticeBoardNode', + name='instruction', + enable_key='h', + enable=True, + input_buffer='vis', + output_buffer='vis_notice', + content_lines=[ + 'This is a demo for pose visualization and simple image ' + 'effects. Have fun!', '', 'Hot-keys:', + '"v": Pose estimation result visualization', + '"h": Show help information', + '"m": Show diagnostic information', '"q": Exit' + ], + ), + # 'MonitorNode': + # This node show diagnostic information in the frame image. It can + # be used for debugging or monitoring system resource status. + dict( + type='MonitorNode', + name='monitor', + enable_key='m', + enable=False, + input_buffer='vis_notice', + output_buffer='display'), + # 'RecorderNode': + # This node save the output video into a file. + dict( + type='RecorderNode', + name='recorder', + out_video_file='webcam_api_demo.mp4', + input_buffer='display', + output_buffer='_display_' + # `_display_` is an executor-reserved buffer + ) + ]) diff --git a/docs/en/user_guides/inference.md b/docs/en/user_guides/inference.md index de61a7b446..737aa87106 100644 --- a/docs/en/user_guides/inference.md +++ b/docs/en/user_guides/inference.md @@ -75,6 +75,8 @@ inferencer = MMPoseInferencer( The complere list of model alias can be found in the [Model Alias](#model-alias) section. +**Custom Object Detector for Top-down Pose Estimation Models** + In addition, top-down pose estimators also require an object detection model. The inferencer is capable of inferring the instance type for models trained with datasets supported in MMPose, and subsequently constructing the necessary object detection model. Alternatively, users may also manually specify the detection model using the following methods: ```python @@ -107,6 +109,8 @@ inferencer = MMPoseInferencer( ) ``` +To perform top-down pose estimation on cropped images containing a single object, users can set `det_model='whole_image'`. This bypasses the object detector initialization, creating a bounding box that matches the input image size and directly sending the entire image to the top-down pose estimator. + ### Dump Results After performing pose estimation, you might want to save the results for further analysis or processing. This section will guide you through saving the predicted keypoints and visualizations to your local machine. diff --git a/mmpose/apis/inferencers/base_mmpose_inferencer.py b/mmpose/apis/inferencers/base_mmpose_inferencer.py index 29c9ae33fa..167b30276a 100644 --- a/mmpose/apis/inferencers/base_mmpose_inferencer.py +++ b/mmpose/apis/inferencers/base_mmpose_inferencer.py @@ -1,12 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import mimetypes import os -import shutil -import tempfile import warnings from collections import defaultdict -from typing import (Any, Callable, Dict, Generator, List, Optional, Sequence, - Union) +from typing import (Callable, Dict, Generator, Iterable, List, Optional, + Sequence, Union) import cv2 import mmcv @@ -20,6 +18,7 @@ from mmengine.infer.infer import BaseInferencer from mmengine.runner.checkpoint import _load_checkpoint_to_model from mmengine.structures import InstanceData +from mmengine.utils import mkdir_or_exist from mmpose.apis.inference import dataset_meta_from_config from mmpose.structures import PoseDataSample, split_instances @@ -83,7 +82,7 @@ def _load_weights_to_model(self, model: nn.Module, model.dataset_meta = dataset_meta_from_config( cfg, dataset_mode='train') - def _inputs_to_list(self, inputs: InputsType) -> list: + def _inputs_to_list(self, inputs: InputsType) -> Iterable: """Preprocess the inputs to a list. Preprocess inputs to a list according to its type: @@ -126,21 +125,24 @@ def _inputs_to_list(self, inputs: InputsType) -> list: input_type = mimetypes.guess_type(inputs)[0].split('/')[0] if input_type == 'video': self._video_input = True - # split video frames into a temporary folder - frame_folder = tempfile.TemporaryDirectory() video = mmcv.VideoReader(inputs) self.video_info = dict( fps=video.fps, name=os.path.basename(inputs), - frame_folder=frame_folder) - video.cvt2frames(frame_folder.name, show_progress=False) - frames = sorted(list_dir_or_file(frame_folder.name)) - inputs = [join_path(frame_folder.name, f) for f in frames] + writer=None, + predictions=[]) + inputs = video + elif input_type == 'image': + inputs = [inputs] + else: + raise ValueError(f'Expected input to be an image, video, ' + f'or folder, but received {inputs} of ' + f'type {input_type}.') - if not isinstance(inputs, (list, tuple)): + elif isinstance(inputs, np.ndarray): inputs = [inputs] - return list(inputs) + return inputs def _get_webcam_inputs(self, inputs: str) -> Generator: """Sets up and returns a generator function that reads frames from a @@ -182,7 +184,8 @@ def _get_webcam_inputs(self, inputs: str) -> Generator: # Set video input flag and metadata. self._video_input = True - self.video_info = dict(fps=10, name='webcam.mp4', frame_folder=None) + self.video_info = dict( + fps=10, name='webcam.mp4', writer=None, predictions=[]) # Set up webcam reader generator function. self._window_closing = False @@ -288,37 +291,28 @@ def visualize(self, if isinstance(single_input, str): img = mmcv.imread(single_input, channel_order='rgb') elif isinstance(single_input, np.ndarray): - img = mmcv.bgr2rgb(single_input.copy()) + img = mmcv.bgr2rgb(single_input) else: raise ValueError('Unsupported input type: ' f'{type(single_input)}') img_name = os.path.basename(pred.metainfo['img_path']) - - if vis_out_dir: - if self._video_input: - out_file = join_path(vis_out_dir, 'vis_frames', img_name) - else: - out_file = join_path(vis_out_dir, img_name) - else: - out_file = None + window_name = window_name if window_name else img_name # since visualization and inference utilize the same process, # the wait time is reduced when a video input is utilized, # thereby eliminating the issue of inference getting stuck. wait_time = 1e-5 if self._video_input else wait_time - window_name = window_name if window_name else img_name - visualization = self.visualizer.add_datasample( window_name, img, pred, draw_gt=False, draw_bbox=draw_bbox, + draw_heatmap=True, show=show, wait_time=wait_time, - out_file=out_file, kpt_thr=kpt_thr) results.append(visualization) @@ -332,6 +326,26 @@ def visualize(self, window_close_event_handler ) + if vis_out_dir: + out_img = mmcv.rgb2bgr(visualization) + + if self._video_input: + + if self.video_info['writer'] is None: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + mkdir_or_exist(vis_out_dir) + out_file = join_path( + vis_out_dir, + os.path.basename(self.video_info['name'])) + self.video_info['writer'] = cv2.VideoWriter( + out_file, fourcc, self.video_info['fps'], + (visualization.shape[1], visualization.shape[0])) + self.video_info['writer'].write(out_img) + + else: + out_file = join_path(vis_out_dir, img_name) + mmcv.imwrite(out_img, out_file) + if return_vis: return results else: @@ -384,66 +398,42 @@ def postprocess( result_dict['predictions'].append(pred) if pred_out_dir != '': - if self._video_input: - pred_out_dir = join_path(pred_out_dir, 'pred_frames') - for pred, data_sample in zip(result_dict['predictions'], preds): - fname = os.path.splitext( - os.path.basename( - data_sample.metainfo['img_path']))[0] + '.json' - mmengine.dump( - pred, join_path(pred_out_dir, fname), indent=' ') + if self._video_input: + self.video_info['predictions'].append(pred) + else: + fname = os.path.splitext( + os.path.basename( + data_sample.metainfo['img_path']))[0] + '.json' + mmengine.dump( + pred, join_path(pred_out_dir, fname), indent=' ') return result_dict - def _merge_outputs(self, vis_out_dir: str, pred_out_dir: str, - **kwargs: Dict[str, Any]) -> None: - """Merge the visualized frames and predicted instance outputs and save - them. + def _finalize_video_processing( + self, + pred_out_dir: str = '', + ): + """Finalize video processing by releasing the video writer and saving + predictions to a file. - Args: - vis_out_dir (str): Path to the directory where the visualized - frames are saved. - pred_out_dir (str): Path to the directory where the predicted - instance outputs are saved. - **kwargs: Other arguments that are not used in this method. + This method should be called after completing the video processing. It + releases the video writer, if it exists, and saves the predictions to a + JSON file if a prediction output directory is provided. """ - assert self._video_input - - if vis_out_dir != '': - vis_frame_out_dir = join_path(vis_out_dir, 'vis_frames') - if not isdir(vis_frame_out_dir) or len( - os.listdir(vis_frame_out_dir)) == 0: - warnings.warn( - f'{vis_frame_out_dir} does not exist or is empty.') - else: - mmcv.frames2video( - vis_frame_out_dir, - join_path(vis_out_dir, self.video_info['name']), - fps=self.video_info['fps'], - fourcc='mp4v', - show_progress=False) - shutil.rmtree(vis_frame_out_dir) - if pred_out_dir != '': - pred_frame_out_dir = join_path(pred_out_dir, 'pred_frames') - if not isdir(pred_frame_out_dir) or len( - os.listdir(pred_frame_out_dir)) == 0: - warnings.warn( - f'{pred_frame_out_dir} does not exist or is empty.') - else: - predictions = [] - pred_files = list_dir_or_file(pred_frame_out_dir) - for frame_id, pred_file in enumerate(sorted(pred_files)): - predictions.append({ - 'frame_id': - frame_id, - 'instances': - mmengine.load( - join_path(pred_frame_out_dir, pred_file)) - }) - fname = os.path.splitext( - os.path.basename(self.video_info['name']))[0] + '.json' - mmengine.dump( - predictions, join_path(pred_out_dir, fname), indent=' ') - shutil.rmtree(pred_frame_out_dir) + # Release the video writer if it exists + if self.video_info['writer'] is not None: + self.video_info['writer'].release() + + # Save predictions + if pred_out_dir: + fname = os.path.splitext( + os.path.basename(self.video_info['name']))[0] + '.json' + predictions = [ + dict(frame_id=i, instances=pred) + for i, pred in enumerate(self.video_info['predictions']) + ] + + mmengine.dump( + predictions, join_path(pred_out_dir, fname), indent=' ') diff --git a/mmpose/apis/inferencers/mmpose_inferencer.py b/mmpose/apis/inferencers/mmpose_inferencer.py index ab0aeb9618..845b3d066a 100644 --- a/mmpose/apis/inferencers/mmpose_inferencer.py +++ b/mmpose/apis/inferencers/mmpose_inferencer.py @@ -5,7 +5,6 @@ import numpy as np import torch from mmengine.config import Config, ConfigDict -from mmengine.fileio import join_path from mmengine.infer.infer import ModelType from mmengine.structures import InstanceData from rich.progress import track @@ -51,6 +50,9 @@ class MMPoseInferencer(BaseMMPoseInferencer): model. Defaults to None. det_cat_ids(int or list[int], optional): Category id for detection model. Defaults to None. + output_heatmaps (bool, optional): Flag to visualize predicted + heatmaps. If set to None, the default setting from the model + config will be used. Default is None. """ preprocess_kwargs: set = {'bbox_thr', 'nms_thr'} @@ -74,17 +76,18 @@ def __init__(self, scope: str = 'mmpose', det_model: Optional[Union[ModelType, str]] = None, det_weights: Optional[str] = None, - det_cat_ids: Optional[Union[int, List]] = None) -> None: + det_cat_ids: Optional[Union[int, List]] = None, + output_heatmaps: Optional[bool] = None) -> None: if pose2d is None: raise ValueError('2d pose estimation algorithm should provided.') self.visualizer = None + self.inferencers = dict() if pose2d is not None: - self.pose2d_inferencer = Pose2DInferencer(pose2d, pose2d_weights, - device, scope, det_model, - det_weights, det_cat_ids) - self.mode = 'pose2d' + self.inferencers['pose2d'] = Pose2DInferencer( + pose2d, pose2d_weights, device, scope, det_model, det_weights, + det_cat_ids, output_heatmaps) def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): """Process the inputs into a model-feedable format. @@ -100,10 +103,10 @@ def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): for i, input in enumerate(inputs): data_batch = {} - if 'pose2d' in self.mode: - data_infos = self.pose2d_inferencer.preprocess_single( + if 'pose2d' in self.inferencers: + data_infos = self.inferencers['pose2d'].preprocess_single( input, index=i, **kwargs) - data_batch['pose2d'] = self.pose2d_inferencer.collate_fn( + data_batch['pose2d'] = self.inferencers['pose2d'].collate_fn( data_infos) # only supports inference with batch size 1 yield data_batch, [input] @@ -119,10 +122,8 @@ def forward(self, inputs: InputType, **forward_kwargs) -> PredType: Dict: The prediction results. Possibly with keys "pose2d". """ result = {} - if self.mode == 'pose2d': - data_samples = self.pose2d_inferencer.forward( - inputs['pose2d'], **forward_kwargs) - result['pose2d'] = data_samples + for mode, inferencer in self.inferencers.items(): + result[mode] = inferencer.forward(inputs[mode], **forward_kwargs) return result @@ -179,11 +180,16 @@ def __call__( inputs = self.preprocess( inputs, batch_size=batch_size, **preprocess_kwargs) + # forward forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1) + for inferencer in self.inferencers.values(): + inferencer._video_input = self._video_input + if self._video_input: + inferencer.video_info = self.video_info preds = [] - if 'pose2d' not in self.mode or not hasattr(self.pose2d_inferencer, - 'detector'): + if 'pose2d' not in self.inferencers or not hasattr( + self.inferencers['pose2d'], 'detector'): inputs = track(inputs, description='Inference') for proc_inputs, ori_inputs in inputs: @@ -195,9 +201,9 @@ def __call__( **postprocess_kwargs) yield results - # merge visualization and prediction results if self._video_input: - self._merge_outputs(**visualize_kwargs, **postprocess_kwargs) + self._finalize_video_processing( + postprocess_kwargs.get('pred_out_dir', '')) def visualize(self, inputs: InputsType, preds: PredType, **kwargs) -> List[np.ndarray]: @@ -222,16 +228,11 @@ def visualize(self, inputs: InputsType, preds: PredType, List[np.ndarray]: Visualization results. """ - if 'pose2d' in self.mode: + if 'pose2d' in self.inferencers: window_name = '' if self._video_input: window_name = self.video_info['name'] - if kwargs.get('vis_out_dir', ''): - kwargs['vis_out_dir'] = join_path(kwargs['vis_out_dir'], - 'vis_frames') - if kwargs.get('show', False): - kwargs['wait_time'] = 1e-5 - return self.pose2d_inferencer.visualize( + return self.inferencers['pose2d'].visualize( inputs, preds['pose2d'], window_name=window_name, @@ -275,6 +276,6 @@ def postprocess( as strings and numbers. """ - if 'pose2d' in self.mode: + if 'pose2d' in self.inferencers: return super().postprocess(preds['pose2d'], visualization, return_datasample, pred_out_dir) diff --git a/mmpose/apis/inferencers/pose2d_inferencer.py b/mmpose/apis/inferencers/pose2d_inferencer.py index 63244b1f95..adf80543a5 100644 --- a/mmpose/apis/inferencers/pose2d_inferencer.py +++ b/mmpose/apis/inferencers/pose2d_inferencer.py @@ -55,12 +55,15 @@ class Pose2DInferencer(BaseMMPoseInferencer): device (str, optional): Device to run inference. If None, the available device will be automatically used. Defaults to None. scope (str, optional): The scope of the model. Defaults to "mmpose". - det_model(str, optional): Config path or alias of detection model. + det_model (str, optional): Config path or alias of detection model. Defaults to None. - det_weights(str, optional): Path to the checkpoints of detection + det_weights (str, optional): Path to the checkpoints of detection model. Defaults to None. - det_cat_ids(int or list[int], optional): Category id for + det_cat_ids (int or list[int], optional): Category id for detection model. Defaults to None. + output_heatmaps (bool, optional): Flag to visualize predicted + heatmaps. If set to None, the default setting from the model + config will be used. Default is None. """ preprocess_kwargs: set = {'bbox_thr', 'nms_thr'} @@ -84,44 +87,48 @@ def __init__(self, scope: Optional[str] = 'mmpose', det_model: Optional[Union[ModelType, str]] = None, det_weights: Optional[str] = None, - det_cat_ids: Optional[Union[int, Tuple]] = None) -> None: + det_cat_ids: Optional[Union[int, Tuple]] = None, + output_heatmaps: Optional[bool] = None) -> None: init_default_scope(scope) super().__init__( model=model, weights=weights, device=device, scope=scope) self.model = revert_sync_batchnorm(self.model) + if output_heatmaps is not None: + self.model.test_cfg['output_heatmaps'] = output_heatmaps # assign dataset metainfo to self.visualizer self.visualizer.set_dataset_meta(self.model.dataset_meta) # initialize detector for top-down models if self.cfg.data_mode == 'topdown': - det_scope = 'mmdet' - if det_model is None: - det_model = DATASETS.get( - self.cfg.dataset_type).__module__.split( - 'datasets.')[-1].split('.')[0].lower() - det_info = default_det_models[det_model] - det_model, det_weights, det_cat_ids = det_info[ - 'model'], det_info['weights'], det_info['cat_ids'] - elif os.path.exists(det_model): - det_cfg = Config.fromfile(det_model) - det_scope = det_cfg.default_scope - - if has_mmdet: - self.detector = DetInferencer( - det_model, det_weights, device=device, scope=det_scope) - self.detector.model = revert_sync_batchnorm( - self.detector.model) + if det_model != 'whole_image': + det_scope = 'mmdet' + if det_model is None: + det_model = DATASETS.get( + self.cfg.dataset_type).__module__.split( + 'datasets.')[-1].split('.')[0].lower() + det_info = default_det_models[det_model] + det_model, det_weights, det_cat_ids = det_info[ + 'model'], det_info['weights'], det_info['cat_ids'] + elif os.path.exists(det_model): + det_cfg = Config.fromfile(det_model) + det_scope = det_cfg.default_scope + + if has_mmdet: + self.detector = DetInferencer( + det_model, det_weights, device=device, scope=det_scope) + else: + raise RuntimeError( + 'MMDetection (v3.0.0 or above) is required to build ' + 'inferencers for top-down pose estimation models.') + + if isinstance(det_cat_ids, (tuple, list)): + self.det_cat_ids = det_cat_ids + else: + self.det_cat_ids = (det_cat_ids, ) else: - raise RuntimeError( - 'MMDetection (v3.0.0rc6 or above) is required to ' - 'build inferencers for top-down pose estimation models.') - - if isinstance(det_cat_ids, (tuple, list)): - self.det_cat_ids = det_cat_ids - else: - self.det_cat_ids = (det_cat_ids, ) + self.detector = None self._video_input = False @@ -151,20 +158,24 @@ def preprocess_single(self, data_info.update(self.model.dataset_meta) if self.cfg.data_mode == 'topdown': - det_results = self.detector( - input, return_datasample=True)['predictions'] - pred_instance = det_results[0].pred_instances.cpu().numpy() - bboxes = np.concatenate( - (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) - - label_mask = np.zeros(len(bboxes), dtype=np.uint8) - for cat_id in self.det_cat_ids: - label_mask = np.logical_or(label_mask, - pred_instance.labels == cat_id) - - bboxes = bboxes[np.logical_and(label_mask, - pred_instance.scores > bbox_thr)] - bboxes = bboxes[nms(bboxes, nms_thr)] + if self.detector is not None: + det_results = self.detector( + input, return_datasample=True)['predictions'] + pred_instance = det_results[0].pred_instances.cpu().numpy() + bboxes = np.concatenate( + (pred_instance.bboxes, pred_instance.scores[:, None]), + axis=1) + + label_mask = np.zeros(len(bboxes), dtype=np.uint8) + for cat_id in self.det_cat_ids: + label_mask = np.logical_or(label_mask, + pred_instance.labels == cat_id) + + bboxes = bboxes[np.logical_and( + label_mask, pred_instance.scores > bbox_thr)] + bboxes = bboxes[nms(bboxes, nms_thr)] + else: + bboxes = [] data_infos = [] if len(bboxes) > 0: @@ -270,6 +281,6 @@ def __call__( **postprocess_kwargs) yield results - # merge visualization and prediction results if self._video_input: - self._merge_outputs(**visualize_kwargs, **postprocess_kwargs) + self._finalize_video_processing( + postprocess_kwargs.get('pred_out_dir', '')) diff --git a/mmpose/apis/inferencers/utils/get_model_alias.py b/mmpose/apis/inferencers/utils/get_model_alias.py index 8e8f85910c..49de6528d6 100644 --- a/mmpose/apis/inferencers/utils/get_model_alias.py +++ b/mmpose/apis/inferencers/utils/get_model_alias.py @@ -30,8 +30,8 @@ def get_model_aliases(scope: str = 'mmpose') -> Dict[str, str]: model_alias_dict[alias] = model_cfg['Name'] else: raise ValueError( - 'encounter an unexpected alias type. Please ' - 'raise an issue at https://github.com/open-mmlab/mmpose/issues ' # noqa + 'encounter an unexpected alias type. Please raise an ' + 'issue at https://github.com/open-mmlab/mmpose/issues ' 'to announce us') return model_alias_dict diff --git a/projects/yolox-pose/README.md b/projects/yolox-pose/README.md index e880301ae6..4dd3aa1c70 100644 --- a/projects/yolox-pose/README.md +++ b/projects/yolox-pose/README.md @@ -16,7 +16,7 @@ This project implements a YOLOX-based human pose estimator, utilizing the approa - [MMYOLO](https://github.com/open-mmlab/mmyolo) v0.5.0 or higher - [MMPose](https://github.com/open-mmlab/mmpose) v1.0.0rc1 or higher -All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `yolox-pose/` root directory, run the following line to add the current directory to `PYTHONPATH`: +All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. **In `yolox-pose/` root directory**, run the following line to add the current directory to `PYTHONPATH`: ```shell export PYTHONPATH=`pwd`:$PYTHONPATH @@ -91,7 +91,7 @@ Results on COCO val2017 | Model | Input Size | AP | AP50 | AP75 | AR | AR50 | Download | | :-------------------------------------------------------------: | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :----------------------------------------------------------------------: | -| [YOLOX-tiny-Pose](./configs/yolox-pose_tiny_4xb64-300e_coco.py) | 640 | 0.518 | 0.799 | 0.545 | 0.566 | 0.841 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco-c47dd83b_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco_20230321.json) | +| [YOLOX-tiny-Pose](./configs/yolox-pose_tiny_4xb64-300e_coco.py) | 416 | 0.518 | 0.799 | 0.545 | 0.566 | 0.841 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco-c47dd83b_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_tiny_4xb64-300e_coco_20230321.json) | | [YOLOX-s-Pose](./configs/yolox-pose_s_8xb32-300e_coco.py) | 640 | 0.632 | 0.875 | 0.692 | 0.676 | 0.907 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco-9f5e3924_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_s_8xb32-300e_coco_20230321.json) | | [YOLOX-m-Pose](./configs/yolox-pose_m_4xb64-300e_coco.py) | 640 | 0.685 | 0.897 | 0.753 | 0.727 | 0.925 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_m_4xb64-300e_coco-cbd11d30_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_m_4xb64-300e_coco_20230321.json) | | [YOLOX-l-Pose](./configs/yolox-pose_l_4xb64-300e_coco.py) | 640 | 0.706 | 0.907 | 0.775 | 0.747 | 0.934 | [model](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_l_4xb64-300e_coco-122e4cf8_20230321.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/yolox-pose/yolox-pose_l_4xb64-300e_coco_20230321.json) |