[Feature] Support hand3d inferencer #2729

Merged 5 commits on Sep 27, 2023
@@ -8,6 +8,7 @@ Collections:
 Models:
 - Config: configs/hand_3d_keypoint/internet/interhand3d/internet_res50_4xb16-20e_interhand3d-256x256.py
   In Collection: InterNet
+  Alias: hand3d
   Metadata:
     Architecture: &id001
     - InterNet
14 changes: 14 additions & 0 deletions demo/docs/en/3d_hand_demo.md
@@ -36,3 +36,17 @@ python demo/hand3d_internet_demo.py \
     --save-predictions \
     --output-root vis_results
 ```
+
+### 3D Hand Pose Estimation with Inferencer
+
+The Inferencer provides a convenient interface for inference, allowing customization using model aliases instead of configuration files and checkpoint paths. It supports various input formats, including image paths, video paths, image folder paths, and webcams. Below is an example command:
+
+```shell
+python demo/inferencer_demo.py tests/data/interhand2.6m/image29590.jpg --pose3d hand3d --vis-out-dir vis_results/hand3d
+```
+
+This command runs inference on the image and saves the visualization results in the `vis_results/hand3d` directory.
+
+<img src="https://github.com/open-mmlab/mmpose/assets/26127467/29218285-aff6-455f-9763-39e8539eae61" alt="Image 1" height="300"/>
+
+In addition, the Inferencer supports saving predicted poses. For more information, please refer to the [inferencer documentation](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html#inferencer-a-unified-inference-interface).
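The documentation added above shows only the command-line form. A minimal Python sketch of the same call follows, assuming MMPose with this change is installed and that `MMPoseInferencer` accepts the `hand3d` alias through its `pose3d` argument, mirroring the `--pose3d` and `--vis-out-dir` flags. The helper name `run_hand3d_inference` is hypothetical and is not part of the PR.

```python
def run_hand3d_inference(image_path, vis_out_dir='vis_results/hand3d'):
    """Run the hand3d alias on one image and return the first result.

    Assumption: MMPose (with the hand3d alias available) is installed.
    The import is deferred so this module loads without MMPose.
    """
    from mmpose.apis import MMPoseInferencer

    # Select the model by alias rather than by config and checkpoint paths.
    inferencer = MMPoseInferencer(pose3d='hand3d')

    # Calling the inferencer returns a generator; each item is one result.
    result_generator = inferencer(image_path, vis_out_dir=vis_out_dir)
    return next(result_generator)
```

Calling `run_hand3d_inference('tests/data/interhand2.6m/image29590.jpg')` should then be roughly equivalent to the shell command above.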
5 changes: 4 additions & 1 deletion demo/hand3d_internet_demo.py
@@ -268,7 +268,10 @@ def main():
                     instance_info=pred_instances_list),
                 f,
                 indent='\t')
-        print(f'predictions have been saved at {args.pred_save_path}')
+        print_log(
+            f'predictions have been saved at {args.pred_save_path}',
+            logger='current',
+            level=logging.INFO)

     if output_file is not None:
         input_type = input_type.replace('webcam', 'video')
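This hunk replaces a bare `print` with a leveled log call. The sketch below shows the same idea using only the standard `logging` module. It is a stand-in for illustration, not MMEngine's `print_log`, and `report_saved` and `ListHandler` are hypothetical names.

```python
import logging


def report_saved(path, logger=None):
    """Log where predictions were saved (stdlib stand-in for print_log)."""
    if logger is None:
        logger = logging.getLogger('demo')
    logger.log(logging.INFO, 'predictions have been saved at %s', path)


class ListHandler(logging.Handler):
    """Collect formatted messages so the output can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(f'{record.levelname}: {record.getMessage()}')


handler = ListHandler()
demo_logger = logging.getLogger('demo')
demo_logger.setLevel(logging.INFO)
demo_logger.addHandler(handler)

report_saved('vis_results/predictions.json', demo_logger)
```

Unlike `print`, the message now carries a level and goes through whatever handlers the logger has, so it can be filtered or redirected.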
7 changes: 4 additions & 3 deletions docs/en/user_guides/inference.md
@@ -279,9 +279,10 @@ The MMPose library has predefined aliases for several frequently used models. Th

 The following table lists the available 3D model aliases and their corresponding configuration names:

-| Alias   | Configuration Name                | Task                     | 3D Pose Estimator | 2D Pose Estimator | Detector |
-| ------- | --------------------------------- | ------------------------ | ----------------- | ----------------- | -------- |
-| human3d | vid_pl_motionbert_8xb32-120e_h36m | Human 3D pose estimation | MotionBert        | RTMPose-m         | RTMDet-m |
+| Alias   | Configuration Name                           | Task                     | 3D Pose Estimator | 2D Pose Estimator | Detector    |
+| ------- | -------------------------------------------- | ------------------------ | ----------------- | ----------------- | ----------- |
+| human3d | vid_pl_motionbert_8xb32-120e_h36m            | Human 3D pose estimation | MotionBert        | RTMPose-m         | RTMDet-m    |
+| hand3d  | internet_res50_4xb16-20e_interhand3d-256x256 | Hand 3D pose estimation  | InterNet          | -                 | whole image |

 In addition, users can utilize the CLI tool to display all available aliases with the following command:

7 changes: 4 additions & 3 deletions docs/zh_cn/user_guides/inference.md
@@ -267,9 +267,10 @@ MMPose provides a set of predefined aliases for commonly used models. When initializing [MMPoseIn

 The following table lists the available 3D pose estimation model aliases and their corresponding configuration files:

-| Alias   | Configuration File Name           | Task                     | 3D Pose Estimation Model | 2D Pose Estimation Model | Detection Model |
-| ------- | --------------------------------- | ------------------------ | ------------------------ | ------------------------ | --------------- |
-| human3d | vid_pl_motionbert_8xb32-120e_h36m | 3D human pose estimation | MotionBert               | RTMPose-m                | RTMDet-m        |
+| Alias   | Configuration File Name                      | Task                       | 3D Pose Estimation Model | 2D Pose Estimation Model | Detection Model |
+| ------- | -------------------------------------------- | -------------------------- | ------------------------ | ------------------------ | --------------- |
+| human3d | vid_pl_motionbert_8xb32-120e_h36m            | 3D human pose estimation   | MotionBert               | RTMPose-m                | RTMDet-m        |
+| hand3d  | internet_res50_4xb16-20e_interhand3d-256x256 | 3D hand keypoint detection | InterNet                 | -                        | whole image     |

 In addition, users can use the command-line interface tool to display all available aliases with the following command:

3 changes: 2 additions & 1 deletion mmpose/apis/inferencers/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .hand3d_inferencer import Hand3DInferencer
 from .mmpose_inferencer import MMPoseInferencer
 from .pose2d_inferencer import Pose2DInferencer
 from .pose3d_inferencer import Pose3DInferencer
 from .utils import get_model_aliases

 __all__ = [
     'Pose2DInferencer', 'MMPoseInferencer', 'get_model_aliases',
-    'Pose3DInferencer'
+    'Pose3DInferencer', 'Hand3DInferencer'
 ]
217 changes: 183 additions & 34 deletions mmpose/apis/inferencers/base_mmpose_inferencer.py
@@ -4,7 +4,7 @@
 import os
 from collections import defaultdict
 from typing import (Callable, Dict, Generator, Iterable, List, Optional,
-                    Sequence, Union)
+                    Sequence, Tuple, Union)

 import cv2
 import mmcv
@@ -15,15 +15,23 @@
 from mmengine.dataset import Compose
 from mmengine.fileio import (get_file_backend, isdir, join_path,
                              list_dir_or_file)
-from mmengine.infer.infer import BaseInferencer
+from mmengine.infer.infer import BaseInferencer, ModelType
 from mmengine.logging import print_log
 from mmengine.registry import init_default_scope
 from mmengine.runner.checkpoint import _load_checkpoint_to_model
 from mmengine.structures import InstanceData
 from mmengine.utils import mkdir_or_exist

 from mmpose.apis.inference import dataset_meta_from_config
+from mmpose.registry import DATASETS
 from mmpose.structures import PoseDataSample, split_instances
+from .utils import default_det_models
+
+try:
+    from mmdet.apis.det_inferencer import DetInferencer
+    has_mmdet = True
+except (ImportError, ModuleNotFoundError):
+    has_mmdet = False

 InstanceList = List[InstanceData]
 InputType = Union[str, np.ndarray]
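The `try`/`except` block in this hunk makes MMDetection an optional dependency: the import is attempted once, and a flag records the outcome so the error is raised only when a detector is actually needed. A generic sketch of that pattern follows, using the standard library only; `optional_import`, `build_detector`, and the module name `hypothetical_missing_pkg_xyz` are illustrative assumptions. Since `ModuleNotFoundError` is a subclass of `ImportError`, catching `ImportError` alone covers both cases.

```python
import importlib


def optional_import(module_name):
    """Return (module, True) if importable, else (None, False)."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also catches ModuleNotFoundError
        return None, False
    return module, True


# 'json' is always available; the second name is deliberately bogus.
json_module, has_json = optional_import('json')
missing_module, has_missing = optional_import('hypothetical_missing_pkg_xyz')


def build_detector():
    """Fail only when the optional feature is actually requested."""
    if not has_missing:
        raise RuntimeError('hypothetical_missing_pkg_xyz is required here.')
    return missing_module
```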
@@ -45,6 +53,44 @@ class BaseMMPoseInferencer(BaseInferencer):
     }
     postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

+    def _init_detector(
+        self,
+        det_model: Optional[Union[ModelType, str]] = None,
+        det_weights: Optional[str] = None,
+        det_cat_ids: Optional[Union[int, Tuple]] = None,
+        device: Optional[str] = None,
+    ):
+        object_type = DATASETS.get(self.cfg.dataset_type).__module__.split(
+            'datasets.')[-1].split('.')[0].lower()
+
+        if det_model in ('whole_image', 'whole-image') or \
+            (det_model is None and
+             object_type not in default_det_models):
+            self.detector = None
+
+        else:
+            det_scope = 'mmdet'
+            if det_model is None:
+                det_info = default_det_models[object_type]
+                det_model, det_weights, det_cat_ids = det_info[
+                    'model'], det_info['weights'], det_info['cat_ids']
+            elif os.path.exists(det_model):
+                det_cfg = Config.fromfile(det_model)
+                det_scope = det_cfg.default_scope
+
+            if has_mmdet:
+                self.detector = DetInferencer(
+                    det_model, det_weights, device=device, scope=det_scope)
+            else:
+                raise RuntimeError(
+                    'MMDetection (v3.0.0 or above) is required to build '
+                    'inferencers for top-down pose estimation models.')
+
+            if isinstance(det_cat_ids, (tuple, list)):
+                self.det_cat_ids = det_cat_ids
+            else:
+                self.det_cat_ids = (det_cat_ids, )
+
     def _load_weights_to_model(self, model: nn.Module,
                                checkpoint: Optional[dict],
                                cfg: Optional[ConfigType]) -> None:
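The decision logic in `_init_detector` can be followed without MMPose or MMDetection. The sketch below reproduces three steps from the hunk above as plain functions: deriving the object type from a dataset module path, choosing between whole-image inference and a detector, and normalizing category ids. The function names, the module paths, and the contents of `defaults` are hypothetical stand-ins, not the real registry or `default_det_models`.

```python
def object_type_from_module(module_path):
    """Take the package name that follows the last 'datasets.' segment."""
    return module_path.split('datasets.')[-1].split('.')[0].lower()


def select_detector(det_model, object_type, default_det_models):
    """Return None for whole-image inference, else a detector spec dict."""
    whole_image = det_model in ('whole_image', 'whole-image')
    no_default = det_model is None and object_type not in default_det_models
    if whole_image or no_default:
        return None
    if det_model is None:
        # Fall back to the default detector registered for this type.
        return dict(default_det_models[object_type])
    return {'model': det_model, 'weights': None, 'cat_ids': None}


def normalize_cat_ids(det_cat_ids):
    """Keep tuples and lists as they are; wrap a single id in a tuple."""
    if isinstance(det_cat_ids, (tuple, list)):
        return det_cat_ids
    return (det_cat_ids, )


# Hypothetical stand-ins for illustration only.
defaults = {'body': {'model': 'some-detector', 'weights': None,
                     'cat_ids': (0, )}}
hand_module = 'pkg.datasets.datasets.hand3d.some_hand_dataset'
body_module = 'pkg.datasets.datasets.body.some_body_dataset'

hand_type = object_type_from_module(hand_module)
body_type = object_type_from_module(body_module)
```

Under these assumptions, a dataset type with no default detector resolves to `None`, which is consistent with the `whole image` entry in the alias table for `hand3d`.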
@@ -266,6 +312,101 @@ def preprocess(self,
             # only supports inference with batch size 1
             yield self.collate_fn(data_infos), [input]

+    def __call__(
+        self,
+        inputs: InputsType,
+        return_datasamples: bool = False,
+        batch_size: int = 1,
+        out_dir: Optional[str] = None,
+        **kwargs,
+    ) -> dict:
+        """Call the inferencer.
+
+        Args:
+            inputs (InputsType): Inputs for the inferencer.
+            return_datasamples (bool): Whether to return results as
+                :obj:`BaseDataElement`. Defaults to False.
+            batch_size (int): Batch size. Defaults to 1.
+            out_dir (str, optional): directory to save visualization
+                results and predictions. Will be overridden if vis_out_dir or
+                pred_out_dir are given. Defaults to None.
+            **kwargs: Keyword arguments passed to :meth:`preprocess`,
+                :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+                Each key in kwargs should be in the corresponding set of
+                ``preprocess_kwargs``, ``forward_kwargs``,
+                ``visualize_kwargs`` and ``postprocess_kwargs``.
+
+        Returns:
+            dict: Inference and visualization results.
+        """
+        if out_dir is not None:
+            if 'vis_out_dir' not in kwargs:
+                kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
+            if 'pred_out_dir' not in kwargs:
+                kwargs['pred_out_dir'] = f'{out_dir}/predictions'
+
+        (
+            preprocess_kwargs,
+            forward_kwargs,
+            visualize_kwargs,
+            postprocess_kwargs,
+        ) = self._dispatch_kwargs(**kwargs)
+
+        self.update_model_visualizer_settings(**kwargs)
+
+        # preprocessing
+        if isinstance(inputs, str) and inputs.startswith('webcam'):
+            inputs = self._get_webcam_inputs(inputs)
+            batch_size = 1
+            if not visualize_kwargs.get('show', False):
+                print_log(
+                    'The display mode is closed when using webcam '
+                    'input. It will be turned on automatically.',
+                    logger='current',
+                    level=logging.WARNING)
+                visualize_kwargs['show'] = True
+        else:
+            inputs = self._inputs_to_list(inputs)
+
+        # check the compatibility between inputs/outputs
+        if not self._video_input and len(inputs) > 0:
+            vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
+            if vis_out_dir is not None:
+                _, file_extension = os.path.splitext(vis_out_dir)
+                assert not file_extension, f'the argument `vis_out_dir` ' \
+                    f'should be a folder while the input contains multiple ' \
+                    f'images, but got {vis_out_dir}'
+
+        if 'bbox_thr' in self.forward_kwargs:
+            forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
+        inputs = self.preprocess(
+            inputs, batch_size=batch_size, **preprocess_kwargs)
+
+        preds = []
+
+        for proc_inputs, ori_inputs in inputs:
+            preds = self.forward(proc_inputs, **forward_kwargs)
+
+            visualization = self.visualize(ori_inputs, preds,
+                                           **visualize_kwargs)
+            results = self.postprocess(
+                preds,
+                visualization,
+                return_datasamples=return_datasamples,
+                **postprocess_kwargs)
+            yield results
+
+        if self._video_input:
+            self._finalize_video_processing(
+                postprocess_kwargs.get('pred_out_dir', ''))
+
+        # In 3D Inferencers, some intermediate results (e.g. 2d keypoints)
+        # will be temporarily stored in `self._buffer`. It's essential to
+        # clear this information to prevent any interference with subsequent
+        # inferences.
+        if hasattr(self, '_buffer'):
+            self._buffer.clear()
+
     def visualize(self,
                   inputs: list,
                   preds: List[PoseDataSample],
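The `__call__` method added in this hunk combines three ideas: `out_dir` fills in output directories only when the caller has not set them, keyword arguments are routed to the stage that declares them, and results are yielded one input at a time, so the cleanup after the loop runs only once the generator is exhausted. The toy class below is a simplified sketch of that control flow. `ToyInferencer` and its attributes are hypothetical, and its `_dispatch_kwargs` is only a rough imitation that, as an assumption, skips any validation of unknown keys.

```python
class ToyInferencer:
    """Stdlib-only stand-in that mirrors the control flow of __call__."""

    preprocess_kwargs = {'bbox_thr'}
    visualize_kwargs = {'vis_out_dir', 'show'}
    postprocess_kwargs = {'pred_out_dir'}

    def _dispatch_kwargs(self, **kwargs):
        # Route each keyword to the stage that declares it.
        groups = []
        for names in (self.preprocess_kwargs, self.visualize_kwargs,
                      self.postprocess_kwargs):
            groups.append({k: v for k, v in kwargs.items() if k in names})
        return groups

    def __call__(self, inputs, out_dir=None, **kwargs):
        # out_dir only fills in directories the caller did not set.
        if out_dir is not None:
            kwargs.setdefault('vis_out_dir', f'{out_dir}/visualizations')
            kwargs.setdefault('pred_out_dir', f'{out_dir}/predictions')

        pre_kw, vis_kw, post_kw = self._dispatch_kwargs(**kwargs)

        for item in inputs:
            # One result per input, produced lazily.
            yield {
                'input': item,
                'vis_out_dir': vis_kw.get('vis_out_dir'),
                'pred_out_dir': post_kw.get('pred_out_dir'),
            }
        self.finished = True  # runs only after the generator is exhausted


toy = ToyInferencer()
results = list(toy(['a.jpg', 'b.jpg'], out_dir='out', vis_out_dir='custom'))
```

Because the method is a generator, nothing runs until the caller iterates over it, for example with `next()` or `list()`.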
@@ -349,44 +490,52 @@ def visualize(self,
             results.append(visualization)

             if vis_out_dir:
-                out_img = mmcv.rgb2bgr(visualization)
-                _, file_extension = os.path.splitext(vis_out_dir)
-                if file_extension:
-                    dir_name = os.path.dirname(vis_out_dir)
-                    file_name = os.path.basename(vis_out_dir)
-                else:
-                    dir_name = vis_out_dir
-                    file_name = None
-                mkdir_or_exist(dir_name)
-
-                if self._video_input:
-
-                    if self.video_info['writer'] is None:
-                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-                        if file_name is None:
-                            file_name = os.path.basename(
-                                self.video_info['name'])
-                        out_file = join_path(dir_name, file_name)
-                        self.video_info['output_file'] = out_file
-                        self.video_info['writer'] = cv2.VideoWriter(
-                            out_file, fourcc, self.video_info['fps'],
-                            (visualization.shape[1], visualization.shape[0]))
-                    self.video_info['writer'].write(out_img)
-
-                else:
-                    file_name = file_name if file_name else img_name
-                    out_file = join_path(dir_name, file_name)
-                    mmcv.imwrite(out_img, out_file)
-                    print_log(
-                        f'the output image has been saved at {out_file}',
-                        logger='current',
-                        level=logging.INFO)
+                self.save_visualization(
+                    visualization,
+                    vis_out_dir,
+                    img_name=img_name,
+                )

         if return_vis:
             return results
         else:
             return []

+    def save_visualization(self, visualization, vis_out_dir, img_name=None):
+        out_img = mmcv.rgb2bgr(visualization)
+        _, file_extension = os.path.splitext(vis_out_dir)
+        if file_extension:
+            dir_name = os.path.dirname(vis_out_dir)
+            file_name = os.path.basename(vis_out_dir)
+        else:
+            dir_name = vis_out_dir
+            file_name = None
+        mkdir_or_exist(dir_name)
+
+        if self._video_input:
+
+            if self.video_info['writer'] is None:
+                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                if file_name is None:
+                    file_name = os.path.basename(self.video_info['name'])
+                out_file = join_path(dir_name, file_name)
+                self.video_info['output_file'] = out_file
+                self.video_info['writer'] = cv2.VideoWriter(
+                    out_file, fourcc, self.video_info['fps'],
+                    (visualization.shape[1], visualization.shape[0]))
+            self.video_info['writer'].write(out_img)
+
+        else:
+            if file_name is None:
+                file_name = img_name if img_name else 'visualization.jpg'
+
+            out_file = join_path(dir_name, file_name)
+            mmcv.imwrite(out_img, out_file)
+            print_log(
+                f'the output image has been saved at {out_file}',
+                logger='current',
+                level=logging.INFO)
+
     def postprocess(
         self,
         preds: List[PoseDataSample],
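The `save_visualization` helper extracted in this hunk treats `vis_out_dir` as a file path when it has an extension and as a folder otherwise. The image branch of that rule is sketched below with the standard library only. `resolve_image_output` is a hypothetical name, and `posixpath` is used in place of `join_path` so the result does not depend on the platform.

```python
import posixpath


def resolve_image_output(vis_out_dir, img_name=None):
    """Return (directory, output file) for an image input."""
    _, file_extension = posixpath.splitext(vis_out_dir)
    if file_extension:
        # vis_out_dir names a file: use it as given.
        dir_name = posixpath.dirname(vis_out_dir)
        file_name = posixpath.basename(vis_out_dir)
    else:
        # vis_out_dir names a folder: fall back to the input name,
        # then to a fixed default.
        dir_name = vis_out_dir
        file_name = img_name if img_name else 'visualization.jpg'
    return dir_name, posixpath.join(dir_name, file_name)
```

One consequence of the extension check is that a folder whose name contains a dot, such as `results.v2`, is interpreted as a file path.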