From 7df3e22cb6ecd2051b5f10c1c62a6ec02896f301 Mon Sep 17 00:00:00 2001 From: lupeng Date: Tue, 13 Jun 2023 11:50:43 +0800 Subject: [PATCH] support openpose style visualization with inferencer --- demo/inferencer_demo.py | 7 +++- .../inferencers/base_mmpose_inferencer.py | 17 ++++----- mmpose/apis/inferencers/mmpose_inferencer.py | 23 ++++-------- mmpose/apis/inferencers/pose2d_inferencer.py | 36 +++++++++++++++---- mmpose/apis/inferencers/pose3d_inferencer.py | 2 ++ 5 files changed, 53 insertions(+), 32 deletions(-) diff --git a/demo/inferencer_demo.py b/demo/inferencer_demo.py index d7bbbb5b52..348eea05d5 100644 --- a/demo/inferencer_demo.py +++ b/demo/inferencer_demo.py @@ -120,6 +120,12 @@ def parse_args(): type=int, default=1, help='Link thickness for visualization.') + parser.add_argument( + '--skeleton-style', + default='mmpose', + type=str, + choices=['mmpose', 'openpose'], + help='Skeleton style selection') parser.add_argument( '--vis-out-dir', type=str, @@ -142,7 +148,6 @@ def parse_args(): 'det_weights', 'det_cat_ids', 'pose3d', 'pose3d_weights' ] init_args = {} - init_args['output_heatmaps'] = call_args.pop('draw_heatmap') for init_kw in init_kws: init_args[init_kw] = call_args.pop(init_kw) diff --git a/mmpose/apis/inferencers/base_mmpose_inferencer.py b/mmpose/apis/inferencers/base_mmpose_inferencer.py index f914793086..bb1590dc27 100644 --- a/mmpose/apis/inferencers/base_mmpose_inferencer.py +++ b/mmpose/apis/inferencers/base_mmpose_inferencer.py @@ -216,9 +216,6 @@ def _webcam_reader() -> Generator: return _webcam_reader() - def _visualization_window_on_close(self, event): - self._window_closing = True - def _init_pipeline(self, cfg: ConfigType) -> Callable: """Initialize the test pipeline. @@ -233,6 +230,12 @@ def _init_pipeline(self, cfg: ConfigType) -> Callable: init_default_scope(cfg.get('default_scope', 'mmpose')) return Compose(cfg.test_dataloader.dataset.pipeline) + def update_model_visualizer_settings(self, **kwargs): + """Update the settings of models and visualizer according to inference + arguments.""" + + pass + def preprocess(self, inputs: InputsType, batch_size: int = 1, @@ -268,8 +271,7 @@ def visualize(self, kpt_thr: float = 0.3, vis_out_dir: str = '', window_name: str = '', - window_close_event_handler: Optional[Callable] = None - ) -> List[np.ndarray]: + **kwargs) -> List[np.ndarray]: """Visualize predictions. Args: @@ -289,7 +291,6 @@ def visualize(self, results w/o predictions. If left as empty, no file will be saved. Defaults to ''. window_name (str, optional): Title of display window. - window_close_event_handler (callable, optional): Returns: List[np.ndarray]: Visualization results. @@ -329,10 +330,10 @@ def visualize(self, pred, draw_gt=False, draw_bbox=draw_bbox, - draw_heatmap=True, show=show, wait_time=wait_time, - kpt_thr=kpt_thr) + kpt_thr=kpt_thr, + **kwargs) results.append(visualization) if vis_out_dir: diff --git a/mmpose/apis/inferencers/mmpose_inferencer.py b/mmpose/apis/inferencers/mmpose_inferencer.py index d7050272f6..916f83889a 100644 --- a/mmpose/apis/inferencers/mmpose_inferencer.py +++ b/mmpose/apis/inferencers/mmpose_inferencer.py @@ -60,14 +60,8 @@ class MMPoseInferencer(BaseMMPoseInferencer): } forward_kwargs: set = {'rebase_keypoint_height'} visualize_kwargs: set = { - 'return_vis', - 'show', - 'wait_time', - 'draw_bbox', - 'radius', - 'thickness', - 'kpt_thr', - 'vis_out_dir', + 'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness', + 'kpt_thr', 'vis_out_dir', 'skeleton_style', 'draw_heatmap' } postprocess_kwargs: set = {'pred_out_dir'} @@ -80,8 +74,7 @@ def __init__(self, scope: str = 'mmpose', det_model: Optional[Union[ModelType, str]] = None, det_weights: Optional[str] = None, - det_cat_ids: Optional[Union[int, List]] = None, - output_heatmaps: Optional[bool] = None) -> None: + det_cat_ids: Optional[Union[int, List]] = None) -> None: self.visualizer = None if pose3d is not None: @@ -92,7 +85,7 @@ def __init__(self, elif pose2d is not None: self.inferencer = Pose2DInferencer(pose2d, pose2d_weights, device, scope, det_model, det_weights, - det_cat_ids, output_heatmaps) + det_cat_ids) else: raise ValueError('Either 2d or 3d pose estimation algorithm ' 'should be provided.') @@ -177,6 +170,8 @@ def __call__( postprocess_kwargs, ) = self._dispatch_kwargs(**kwargs) + self.inferencer.update_model_visualizer_settings(**kwargs) + # preprocessing if isinstance(inputs, str) and inputs.startswith('webcam'): inputs = self.inferencer._get_webcam_inputs(inputs) @@ -240,8 +235,4 @@ def visualize(self, inputs: InputsType, preds: PredType, window_name = self.inferencer.video_info['name'] return self.inferencer.visualize( - inputs, - preds, - window_name=window_name, - window_close_event_handler=self._visualization_window_on_close, - **kwargs) + inputs, preds, window_name=window_name, **kwargs) diff --git a/mmpose/apis/inferencers/pose2d_inferencer.py b/mmpose/apis/inferencers/pose2d_inferencer.py index 1e8e8d7550..3ac923e9f0 100644 --- a/mmpose/apis/inferencers/pose2d_inferencer.py +++ b/mmpose/apis/inferencers/pose2d_inferencer.py @@ -60,9 +60,6 @@ class Pose2DInferencer(BaseMMPoseInferencer): model. Defaults to None. det_cat_ids (int or list[int], optional): Category id for detection model. Defaults to None. - output_heatmaps (bool, optional): Flag to visualize predicted - heatmaps. If set to None, the default setting from the model - config will be used. Default is None. """ preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'} @@ -76,6 +73,8 @@ class Pose2DInferencer(BaseMMPoseInferencer): 'thickness', 'kpt_thr', 'vis_out_dir', + 'skeleton_style', + 'draw_heatmap', } postprocess_kwargs: set = {'pred_out_dir'} @@ -86,15 +85,12 @@ def __init__(self, scope: Optional[str] = 'mmpose', det_model: Optional[Union[ModelType, str]] = None, det_weights: Optional[str] = None, - det_cat_ids: Optional[Union[int, Tuple]] = None, - output_heatmaps: Optional[bool] = None) -> None: + det_cat_ids: Optional[Union[int, Tuple]] = None) -> None: init_default_scope(scope) super().__init__( model=model, weights=weights, device=device, scope=scope) self.model = revert_sync_batchnorm(self.model) - if output_heatmaps is not None: - self.model.test_cfg['output_heatmaps'] = output_heatmaps # assign dataset metainfo to self.visualizer self.visualizer.set_dataset_meta(self.model.dataset_meta) @@ -134,6 +130,30 @@ def __init__(self, self._video_input = False + def update_model_visualizer_settings(self, + draw_heatmap: bool = False, + skeleton_style: str = 'mmpose', + **kwargs) -> None: + """Update the settings of models and visualizer according to inference + arguments. + + Args: + draw_heatmaps (bool, optional): Flag to visualize predicted + heatmaps. If not provided, it defaults to False. + skeleton_style (str, optional): Skeleton style selection. Valid + options are 'mmpose' and 'openpose'. Defaults to 'mmpose'. + """ + self.model.test_cfg['output_heatmaps'] = draw_heatmap + + if skeleton_style not in ['mmpose', 'openpose']: + raise ValueError('`skeleton_style` must be either \'mmpose\' ' + 'or \'openpose\'') + + if skeleton_style == 'openpose': + self.visualizer.set_dataset_meta(self.model.dataset_meta, + skeleton_style) + self.visualizer.backend = 'matplotlib' + def preprocess_single(self, input: InputType, index: int, @@ -274,6 +294,8 @@ def __call__( postprocess_kwargs, ) = self._dispatch_kwargs(**kwargs) + self.update_model_visualizer_settings(**kwargs) + # preprocessing if isinstance(inputs, str) and inputs.startswith('webcam'): inputs = self._get_webcam_inputs(inputs) diff --git a/mmpose/apis/inferencers/pose3d_inferencer.py b/mmpose/apis/inferencers/pose3d_inferencer.py index d30302cfa2..d5b2a2998d 100644 --- a/mmpose/apis/inferencers/pose3d_inferencer.py +++ b/mmpose/apis/inferencers/pose3d_inferencer.py @@ -384,6 +384,8 @@ def __call__( postprocess_kwargs, ) = self._dispatch_kwargs(**kwargs) + self.update_model_visualizer_settings(**kwargs) + # preprocessing if isinstance(inputs, str) and inputs.startswith('webcam'): inputs = self._get_webcam_inputs(inputs)