Skip to content

Commit

Permalink
[Enahnce] Support openpose style visualization with inferencer (#2456)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Jun 13, 2023
1 parent 2b80fce commit bf3d9ee
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 32 deletions.
7 changes: 6 additions & 1 deletion demo/inferencer_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def parse_args():
type=int,
default=1,
help='Link thickness for visualization.')
parser.add_argument(
'--skeleton-style',
default='mmpose',
type=str,
choices=['mmpose', 'openpose'],
help='Skeleton style selection')
parser.add_argument(
'--vis-out-dir',
type=str,
Expand All @@ -142,7 +148,6 @@ def parse_args():
'det_weights', 'det_cat_ids', 'pose3d', 'pose3d_weights'
]
init_args = {}
init_args['output_heatmaps'] = call_args.pop('draw_heatmap')
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)

Expand Down
17 changes: 9 additions & 8 deletions mmpose/apis/inferencers/base_mmpose_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,6 @@ def _webcam_reader() -> Generator:

return _webcam_reader()

def _visualization_window_on_close(self, event):
self._window_closing = True

def _init_pipeline(self, cfg: ConfigType) -> Callable:
"""Initialize the test pipeline.
Expand All @@ -233,6 +230,12 @@ def _init_pipeline(self, cfg: ConfigType) -> Callable:
init_default_scope(cfg.get('default_scope', 'mmpose'))
return Compose(cfg.test_dataloader.dataset.pipeline)

def update_model_visualizer_settings(self, **kwargs):
"""Update the settings of models and visualizer according to inference
arguments."""

pass

def preprocess(self,
inputs: InputsType,
batch_size: int = 1,
Expand Down Expand Up @@ -268,8 +271,7 @@ def visualize(self,
kpt_thr: float = 0.3,
vis_out_dir: str = '',
window_name: str = '',
window_close_event_handler: Optional[Callable] = None
) -> List[np.ndarray]:
**kwargs) -> List[np.ndarray]:
"""Visualize predictions.
Args:
Expand All @@ -289,7 +291,6 @@ def visualize(self,
results w/o predictions. If left as empty, no file will
be saved. Defaults to ''.
window_name (str, optional): Title of display window.
window_close_event_handler (callable, optional):
Returns:
List[np.ndarray]: Visualization results.
Expand Down Expand Up @@ -329,10 +330,10 @@ def visualize(self,
pred,
draw_gt=False,
draw_bbox=draw_bbox,
draw_heatmap=True,
show=show,
wait_time=wait_time,
kpt_thr=kpt_thr)
kpt_thr=kpt_thr,
**kwargs)
results.append(visualization)

if vis_out_dir:
Expand Down
23 changes: 7 additions & 16 deletions mmpose/apis/inferencers/mmpose_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,8 @@ class MMPoseInferencer(BaseMMPoseInferencer):
}
forward_kwargs: set = {'rebase_keypoint_height'}
visualize_kwargs: set = {
'return_vis',
'show',
'wait_time',
'draw_bbox',
'radius',
'thickness',
'kpt_thr',
'vis_out_dir',
'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
'kpt_thr', 'vis_out_dir', 'skeleton_style', 'draw_heatmap'
}
postprocess_kwargs: set = {'pred_out_dir'}

Expand All @@ -80,8 +74,7 @@ def __init__(self,
scope: str = 'mmpose',
det_model: Optional[Union[ModelType, str]] = None,
det_weights: Optional[str] = None,
det_cat_ids: Optional[Union[int, List]] = None,
output_heatmaps: Optional[bool] = None) -> None:
det_cat_ids: Optional[Union[int, List]] = None) -> None:

self.visualizer = None
if pose3d is not None:
Expand All @@ -92,7 +85,7 @@ def __init__(self,
elif pose2d is not None:
self.inferencer = Pose2DInferencer(pose2d, pose2d_weights, device,
scope, det_model, det_weights,
det_cat_ids, output_heatmaps)
det_cat_ids)
else:
raise ValueError('Either 2d or 3d pose estimation algorithm '
'should be provided.')
Expand Down Expand Up @@ -177,6 +170,8 @@ def __call__(
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)

self.inferencer.update_model_visualizer_settings(**kwargs)

# preprocessing
if isinstance(inputs, str) and inputs.startswith('webcam'):
inputs = self.inferencer._get_webcam_inputs(inputs)
Expand Down Expand Up @@ -240,8 +235,4 @@ def visualize(self, inputs: InputsType, preds: PredType,
window_name = self.inferencer.video_info['name']

return self.inferencer.visualize(
inputs,
preds,
window_name=window_name,
window_close_event_handler=self._visualization_window_on_close,
**kwargs)
inputs, preds, window_name=window_name, **kwargs)
36 changes: 29 additions & 7 deletions mmpose/apis/inferencers/pose2d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ class Pose2DInferencer(BaseMMPoseInferencer):
model. Defaults to None.
det_cat_ids (int or list[int], optional): Category id for
detection model. Defaults to None.
output_heatmaps (bool, optional): Flag to visualize predicted
heatmaps. If set to None, the default setting from the model
config will be used. Default is None.
"""

preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
Expand All @@ -76,6 +73,8 @@ class Pose2DInferencer(BaseMMPoseInferencer):
'thickness',
'kpt_thr',
'vis_out_dir',
'skeleton_style',
'draw_heatmap',
}
postprocess_kwargs: set = {'pred_out_dir'}

Expand All @@ -86,15 +85,12 @@ def __init__(self,
scope: Optional[str] = 'mmpose',
det_model: Optional[Union[ModelType, str]] = None,
det_weights: Optional[str] = None,
det_cat_ids: Optional[Union[int, Tuple]] = None,
output_heatmaps: Optional[bool] = None) -> None:
det_cat_ids: Optional[Union[int, Tuple]] = None) -> None:

init_default_scope(scope)
super().__init__(
model=model, weights=weights, device=device, scope=scope)
self.model = revert_sync_batchnorm(self.model)
if output_heatmaps is not None:
self.model.test_cfg['output_heatmaps'] = output_heatmaps

# assign dataset metainfo to self.visualizer
self.visualizer.set_dataset_meta(self.model.dataset_meta)
Expand Down Expand Up @@ -134,6 +130,30 @@ def __init__(self,

self._video_input = False

def update_model_visualizer_settings(self,
draw_heatmap: bool = False,
skeleton_style: str = 'mmpose',
**kwargs) -> None:
"""Update the settings of models and visualizer according to inference
arguments.
Args:
draw_heatmaps (bool, optional): Flag to visualize predicted
heatmaps. If not provided, it defaults to False.
skeleton_style (str, optional): Skeleton style selection. Valid
options are 'mmpose' and 'openpose'. Defaults to 'mmpose'.
"""
self.model.test_cfg['output_heatmaps'] = draw_heatmap

if skeleton_style not in ['mmpose', 'openpose']:
raise ValueError('`skeleton_style` must be either \'mmpose\' '
'or \'openpose\'')

if skeleton_style == 'openpose':
self.visualizer.set_dataset_meta(self.model.dataset_meta,
skeleton_style)
self.visualizer.backend = 'matplotlib'

def preprocess_single(self,
input: InputType,
index: int,
Expand Down Expand Up @@ -274,6 +294,8 @@ def __call__(
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)

self.update_model_visualizer_settings(**kwargs)

# preprocessing
if isinstance(inputs, str) and inputs.startswith('webcam'):
inputs = self._get_webcam_inputs(inputs)
Expand Down
2 changes: 2 additions & 0 deletions mmpose/apis/inferencers/pose3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ def __call__(
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)

self.update_model_visualizer_settings(**kwargs)

# preprocessing
if isinstance(inputs, str) and inputs.startswith('webcam'):
inputs = self._get_webcam_inputs(inputs)
Expand Down

0 comments on commit bf3d9ee

Please sign in to comment.