Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Support Openpose skeleton style in inferencer #2456

Merged
merged 1 commit into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion demo/inferencer_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def parse_args():
type=int,
default=1,
help='Link thickness for visualization.')
parser.add_argument(
'--skeleton-style',
default='mmpose',
type=str,
choices=['mmpose', 'openpose'],
help='Skeleton style selection')
parser.add_argument(
'--vis-out-dir',
type=str,
Expand All @@ -142,7 +148,6 @@ def parse_args():
'det_weights', 'det_cat_ids', 'pose3d', 'pose3d_weights'
]
init_args = {}
init_args['output_heatmaps'] = call_args.pop('draw_heatmap')
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)

Expand Down
17 changes: 9 additions & 8 deletions mmpose/apis/inferencers/base_mmpose_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,6 @@ def _webcam_reader() -> Generator:

return _webcam_reader()

def _visualization_window_on_close(self, event):
self._window_closing = True

def _init_pipeline(self, cfg: ConfigType) -> Callable:
"""Initialize the test pipeline.

Expand All @@ -233,6 +230,12 @@ def _init_pipeline(self, cfg: ConfigType) -> Callable:
init_default_scope(cfg.get('default_scope', 'mmpose'))
return Compose(cfg.test_dataloader.dataset.pipeline)

def update_model_visualizer_settings(self, **kwargs):
"""Update the settings of models and visualizer according to inference
arguments."""

pass

def preprocess(self,
inputs: InputsType,
batch_size: int = 1,
Expand Down Expand Up @@ -268,8 +271,7 @@ def visualize(self,
kpt_thr: float = 0.3,
vis_out_dir: str = '',
window_name: str = '',
window_close_event_handler: Optional[Callable] = None
) -> List[np.ndarray]:
**kwargs) -> List[np.ndarray]:
"""Visualize predictions.

Args:
Expand All @@ -289,7 +291,6 @@ def visualize(self,
results w/o predictions. If left as empty, no file will
be saved. Defaults to ''.
window_name (str, optional): Title of display window.
window_close_event_handler (callable, optional):

Returns:
List[np.ndarray]: Visualization results.
Expand Down Expand Up @@ -329,10 +330,10 @@ def visualize(self,
pred,
draw_gt=False,
draw_bbox=draw_bbox,
draw_heatmap=True,
show=show,
wait_time=wait_time,
kpt_thr=kpt_thr)
kpt_thr=kpt_thr,
**kwargs)
results.append(visualization)

if vis_out_dir:
Expand Down
23 changes: 7 additions & 16 deletions mmpose/apis/inferencers/mmpose_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,8 @@ class MMPoseInferencer(BaseMMPoseInferencer):
}
forward_kwargs: set = {'rebase_keypoint_height'}
visualize_kwargs: set = {
'return_vis',
'show',
'wait_time',
'draw_bbox',
'radius',
'thickness',
'kpt_thr',
'vis_out_dir',
'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
'kpt_thr', 'vis_out_dir', 'skeleton_style', 'draw_heatmap'
}
postprocess_kwargs: set = {'pred_out_dir'}

Expand All @@ -80,8 +74,7 @@ def __init__(self,
scope: str = 'mmpose',
det_model: Optional[Union[ModelType, str]] = None,
det_weights: Optional[str] = None,
det_cat_ids: Optional[Union[int, List]] = None,
output_heatmaps: Optional[bool] = None) -> None:
det_cat_ids: Optional[Union[int, List]] = None) -> None:

self.visualizer = None
if pose3d is not None:
Expand All @@ -92,7 +85,7 @@ def __init__(self,
elif pose2d is not None:
self.inferencer = Pose2DInferencer(pose2d, pose2d_weights, device,
scope, det_model, det_weights,
det_cat_ids, output_heatmaps)
det_cat_ids)
else:
raise ValueError('Either 2d or 3d pose estimation algorithm '
'should be provided.')
Expand Down Expand Up @@ -177,6 +170,8 @@ def __call__(
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)

self.inferencer.update_model_visualizer_settings(**kwargs)

# preprocessing
if isinstance(inputs, str) and inputs.startswith('webcam'):
inputs = self.inferencer._get_webcam_inputs(inputs)
Expand Down Expand Up @@ -240,8 +235,4 @@ def visualize(self, inputs: InputsType, preds: PredType,
window_name = self.inferencer.video_info['name']

return self.inferencer.visualize(
inputs,
preds,
window_name=window_name,
window_close_event_handler=self._visualization_window_on_close,
**kwargs)
inputs, preds, window_name=window_name, **kwargs)
36 changes: 29 additions & 7 deletions mmpose/apis/inferencers/pose2d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ class Pose2DInferencer(BaseMMPoseInferencer):
model. Defaults to None.
det_cat_ids (int or list[int], optional): Category id for
detection model. Defaults to None.
output_heatmaps (bool, optional): Flag to visualize predicted
heatmaps. If set to None, the default setting from the model
config will be used. Default is None.
"""

preprocess_kwargs: set = {'bbox_thr', 'nms_thr', 'bboxes'}
Expand All @@ -76,6 +73,8 @@ class Pose2DInferencer(BaseMMPoseInferencer):
'thickness',
'kpt_thr',
'vis_out_dir',
'skeleton_style',
'draw_heatmap',
}
postprocess_kwargs: set = {'pred_out_dir'}

Expand All @@ -86,15 +85,12 @@ def __init__(self,
scope: Optional[str] = 'mmpose',
det_model: Optional[Union[ModelType, str]] = None,
det_weights: Optional[str] = None,
det_cat_ids: Optional[Union[int, Tuple]] = None,
output_heatmaps: Optional[bool] = None) -> None:
det_cat_ids: Optional[Union[int, Tuple]] = None) -> None:

init_default_scope(scope)
super().__init__(
model=model, weights=weights, device=device, scope=scope)
self.model = revert_sync_batchnorm(self.model)
if output_heatmaps is not None:
self.model.test_cfg['output_heatmaps'] = output_heatmaps

# assign dataset metainfo to self.visualizer
self.visualizer.set_dataset_meta(self.model.dataset_meta)
Expand Down Expand Up @@ -134,6 +130,30 @@ def __init__(self,

self._video_input = False

def update_model_visualizer_settings(self,
draw_heatmap: bool = False,
skeleton_style: str = 'mmpose',
**kwargs) -> None:
"""Update the settings of models and visualizer according to inference
arguments.

Args:
draw_heatmaps (bool, optional): Flag to visualize predicted
heatmaps. If not provided, it defaults to False.
skeleton_style (str, optional): Skeleton style selection. Valid
options are 'mmpose' and 'openpose'. Defaults to 'mmpose'.
"""
self.model.test_cfg['output_heatmaps'] = draw_heatmap

if skeleton_style not in ['mmpose', 'openpose']:
raise ValueError('`skeleton_style` must be either \'mmpose\' '
'or \'openpose\'')

if skeleton_style == 'openpose':
self.visualizer.set_dataset_meta(self.model.dataset_meta,
skeleton_style)
self.visualizer.backend = 'matplotlib'

def preprocess_single(self,
input: InputType,
index: int,
Expand Down Expand Up @@ -274,6 +294,8 @@ def __call__(
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)

self.update_model_visualizer_settings(**kwargs)

# preprocessing
if isinstance(inputs, str) and inputs.startswith('webcam'):
inputs = self._get_webcam_inputs(inputs)
Expand Down
2 changes: 2 additions & 0 deletions mmpose/apis/inferencers/pose3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ def __call__(
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)

self.update_model_visualizer_settings(**kwargs)

# preprocessing
if isinstance(inputs, str) and inputs.startswith('webcam'):
inputs = self._get_webcam_inputs(inputs)
Expand Down