From 7232a79d11d5b43c82c72bfb4c98731b9e78dd3e Mon Sep 17 00:00:00 2001 From: Yifan Lareina WU Date: Wed, 19 Apr 2023 14:59:21 +0800 Subject: [PATCH 1/8] [Refactor] Add 3d body base and human3.6m dataset (#2246) --- demo/bottomup_demo.py | 1 - demo/topdown_demo_with_mmdet.py | 1 - mmpose/datasets/datasets/base/__init__.py | 3 +- .../datasets/base/base_mocap_dataset.py | 404 ++++++++++++++++++ mmpose/datasets/datasets/body3d/__init__.py | 4 + .../datasets/datasets/body3d/h36m_dataset.py | 259 +++++++++++ mmpose/visualization/__init__.py | 3 +- mmpose/visualization/fast_visualizer.py | 78 ++++ projects/rtmpose/README.md | 35 ++ projects/rtmpose/README_CN.md | 35 ++ .../S1_Directions_1.54138969_000001.jpg | Bin .../S5_SittingDown.54138969_002061.jpg | Bin .../S7_Greeting.55011271_000396.jpg | Bin .../S8_WalkDog_1.55011271_000026.jpg | Bin .../test_body_datasets/test_h36m_dataset.py | 175 ++++++++ .../test_fast_visualizer.py | 71 +++ 16 files changed, 1065 insertions(+), 4 deletions(-) create mode 100644 mmpose/datasets/datasets/base/base_mocap_dataset.py create mode 100644 mmpose/datasets/datasets/body3d/__init__.py create mode 100644 mmpose/datasets/datasets/body3d/h36m_dataset.py create mode 100644 mmpose/visualization/fast_visualizer.py rename tests/data/h36m/{ => S1/S1_Directions_1.54138969}/S1_Directions_1.54138969_000001.jpg (100%) rename tests/data/h36m/{ => S5/S5_SittingDown.54138969}/S5_SittingDown.54138969_002061.jpg (100%) rename tests/data/h36m/{ => S7/S7_Greeting.55011271}/S7_Greeting.55011271_000396.jpg (100%) rename tests/data/h36m/{ => S8/S8_WalkDog_1.55011271}/S8_WalkDog_1.55011271_000026.jpg (100%) create mode 100644 tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py create mode 100644 tests/test_visualization/test_fast_visualizer.py diff --git a/demo/bottomup_demo.py b/demo/bottomup_demo.py index 3d6fee7a03..c6778c637f 100644 --- a/demo/bottomup_demo.py +++ b/demo/bottomup_demo.py @@ -11,7 +11,6 @@ import numpy as np from mmpose.apis import inference_bottomup, init_model -from mmpose.registry import VISUALIZERS from mmpose.structures import split_instances diff --git a/demo/topdown_demo_with_mmdet.py b/demo/topdown_demo_with_mmdet.py index 38f4e92e4e..442c4e812c 100644 --- a/demo/topdown_demo_with_mmdet.py +++ b/demo/topdown_demo_with_mmdet.py @@ -13,7 +13,6 @@ from mmpose.apis import inference_topdown from mmpose.apis import init_model as init_pose_estimator from mmpose.evaluation.functional import nms -from mmpose.registry import VISUALIZERS from mmpose.structures import merge_data_samples, split_instances from mmpose.utils import adapt_mmdet_pipeline diff --git a/mmpose/datasets/datasets/base/__init__.py b/mmpose/datasets/datasets/base/__init__.py index 23bb4efb48..810440530e 100644 --- a/mmpose/datasets/datasets/base/__init__.py +++ b/mmpose/datasets/datasets/base/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_coco_style_dataset import BaseCocoStyleDataset +from .base_mocap_dataset import BaseMocapDataset -__all__ = ['BaseCocoStyleDataset'] +__all__ = ['BaseCocoStyleDataset', 'BaseMocapDataset'] diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py new file mode 100644 index 0000000000..877fe01909 --- /dev/null +++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py @@ -0,0 +1,404 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp
+from copy import deepcopy
+from itertools import filterfalse, groupby
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from mmengine.dataset import BaseDataset, force_full_init
+from mmengine.fileio import exists, get_local_path, load
+from mmengine.utils import is_abs
+from PIL import Image
+
+from mmpose.registry import DATASETS
+from ..utils import parse_pose_metainfo
+
+
+@DATASETS.register_module()
+class BaseMocapDataset(BaseDataset):
+    """Base class for 3d body datasets.
+
+    Args:
+        ann_file (str): Annotation file path. Default: ''.
+        seq_len (int): Number of frames in a sequence. Default: 1.
+        causal (bool): If set to ``True``, the rightmost input frame will be
+            the target frame. Otherwise, the middle input frame will be the
+            target frame. Default: ``True``.
+        subset_frac (float): The fraction to reduce dataset size. If set to 1,
+            the dataset size is not reduced. Default: 1.
+        camera_param_file (str): Cameras' parameters file. Default: ``None``.
+        data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+            ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+            one instance; while in ``'bottomup'`` mode, each data sample
+            contains all instances in an image. Default: ``'topdown'``.
+        metainfo (dict, optional): Meta information for dataset, such as class
+            information. Default: ``None``.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Default: ``None``.
+        data_prefix (dict, optional): Prefix for training data.
+            Default: ``dict(img='')``.
+        filter_cfg (dict, optional): Config for filtering data.
+            Default: ``None``.
+        indices (int or Sequence[int], optional): Support using first few
+            data in the annotation file to facilitate training/testing on a
+            smaller dataset. Default: ``None`` which means using all
+            ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects. When enabled, data loader workers can use
+            shared RAM from the master process instead of making a copy.
+            Default: ``True``.
+        pipeline (list, optional): Processing pipeline. Default: [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Default: ``False``.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the meta
+            information of the dataset is needed, which is not necessary to
+            load the annotation file. ``BaseDataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``.
+            Default: ``False``.
+        max_refetch (int, optional): The maximum number of extra cycles to
+            fetch a valid image if ``BaseDataset.prepare_data`` gets a
+            ``None`` image. Default: 1000.
+    """
+
+    METAINFO: dict = dict()
+
+    def __init__(self,
+                 ann_file: str = '',
+                 seq_len: int = 1,
+                 causal: bool = True,
+                 subset_frac: float = 1.0,
+                 camera_param_file: Optional[str] = None,
+                 data_mode: str = 'topdown',
+                 metainfo: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 data_prefix: dict = dict(img=''),
+                 filter_cfg: Optional[dict] = None,
+                 indices: Optional[Union[int, Sequence[int]]] = None,
+                 serialize_data: bool = True,
+                 pipeline: List[Union[dict, Callable]] = [],
+                 test_mode: bool = False,
+                 lazy_init: bool = False,
+                 max_refetch: int = 1000):
+
+        if data_mode not in {'topdown', 'bottomup'}:
+            raise ValueError(
+                f'{self.__class__.__name__} got invalid data_mode: '
+                f'{data_mode}. Should be "topdown" or "bottomup".')
+        self.data_mode = data_mode
+
+        _ann_file = ann_file
+        if not is_abs(_ann_file):
+            _ann_file = osp.join(data_root, _ann_file)
+        assert exists(_ann_file), 'Annotation file does not exist.'
+        with get_local_path(_ann_file) as local_path:
+            self.ann_data = np.load(local_path)
+
+        self.camera_param_file = camera_param_file
+        if self.camera_param_file:
+            if not is_abs(self.camera_param_file):
+                self.camera_param_file = osp.join(data_root,
+                                                  self.camera_param_file)
+            assert exists(self.camera_param_file)
+            self.camera_param = load(self.camera_param_file)
+
+        self.seq_len = seq_len
+        self.causal = causal
+
+        assert 0 < subset_frac <= 1, (
+            f'Unsupported `subset_frac` {subset_frac}. Supported range '
+            'is (0, 1].')
+        self.subset_frac = subset_frac
+
+        self.sequence_indices = self.get_sequence_indices()
+
+        super().__init__(
+            ann_file=ann_file,
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            filter_cfg=filter_cfg,
+            indices=indices,
+            serialize_data=serialize_data,
+            pipeline=pipeline,
+            test_mode=test_mode,
+            lazy_init=lazy_init,
+            max_refetch=max_refetch)
+
+    @classmethod
+    def _load_metainfo(cls, metainfo: dict = None) -> dict:
+        """Collect meta information from the dictionary of meta.
+
+        Args:
+            metainfo (dict): Raw data of pose meta information.
+
+        Returns:
+            dict: Parsed meta information.
+        """
+
+        if metainfo is None:
+            metainfo = deepcopy(cls.METAINFO)
+
+        if not isinstance(metainfo, dict):
+            raise TypeError(
+                f'metainfo should be a dict, but got {type(metainfo)}')
+
+        # parse pose metainfo if it has been assigned
+        if metainfo:
+            metainfo = parse_pose_metainfo(metainfo)
+        return metainfo
+
+    @force_full_init
+    def prepare_data(self, idx) -> Any:
+        """Get data processed by ``self.pipeline``.
+
+        :class:`BaseMocapDataset` overrides this method from
+        :class:`mmengine.dataset.BaseDataset` to add the metainfo into
+        the ``data_info`` before it is passed to the pipeline.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
+        """
+        data_info = self.get_data_info(idx)
+
+        return self.pipeline(data_info)
+
+    def get_data_info(self, idx: int) -> dict:
+        """Get data info by index.
+
+        Args:
+            idx (int): Index of data info.
+
+        Returns:
+            dict: Data info.
+        """
+        data_info = super().get_data_info(idx)
+
+        # Add metainfo items that are required in the pipeline and the model
+        metainfo_keys = [
+            'upper_body_ids', 'lower_body_ids', 'flip_pairs',
+            'dataset_keypoint_weights', 'flip_indices', 'skeleton_links'
+        ]
+
+        for key in metainfo_keys:
+            assert key not in data_info, (
+                f'"{key}" is a reserved key for `metainfo`, but already '
+                'exists in the `data_info`.')
+
+            data_info[key] = deepcopy(self._metainfo[key])
+
+        return data_info
+
+    def load_data_list(self) -> List[dict]:
+        """Load the data list from the loaded annotation data."""
+
+        instance_list, image_list = self._load_annotations()
+
+        if self.data_mode == 'topdown':
+            data_list = self._get_topdown_data_infos(instance_list)
+        else:
+            data_list = self._get_bottomup_data_infos(instance_list,
+                                                      image_list)
+
+        return data_list
+
+    def get_img_info(self, img_idx, img_name):
+        try:
+            with get_local_path(osp.join(self.data_prefix['img'],
+                                         img_name)) as local_path:
+                im = Image.open(local_path)
+                w, h = im.size
+                im.close()
+        except:  # noqa: E722
+            return None
+
+        img = {
+            'file_name': img_name,
+            'height': h,
+            'width': w,
+            'id': img_idx,
+            'img_id': img_idx,
+            'img_path': osp.join(self.data_prefix['img'], img_name),
+        }
+        return img
+
+    def get_sequence_indices(self) -> List[List[int]]:
+        """Build sequence indices.
+
+        The default method creates sample indices in which each sample is a
+        single frame (i.e. seq_len=1). Override this method in the subclass
+        to define how frames are sampled to form data samples.
+
+        Outputs:
+            sample_indices: the frame indices of each sample.
+                For a sample, all frames will be treated as an input sequence,
+                and the ground-truth pose of the last frame will be the
+                target.
+        """
+        sequence_indices = []
+        if self.seq_len == 1:
+            num_imgs = len(self.ann_data['imgname'])
+            sequence_indices = [[idx] for idx in range(num_imgs)]
+        else:
+            raise NotImplementedError('Multi-frame data sample unsupported!')
+        return sequence_indices
+
+    def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
+        """Load data from the pre-processed annotation file (``.npz``)."""
+        num_keypoints = self.metainfo['num_keypoints']
+
+        img_names = self.ann_data['imgname']
+        num_imgs = len(img_names)
+
+        if 'S' in self.ann_data.keys():
+            kpts_3d = self.ann_data['S']
+        else:
+            kpts_3d = np.zeros((num_imgs, num_keypoints, 4), dtype=np.float32)
+
+        if 'part' in self.ann_data.keys():
+            kpts_2d = self.ann_data['part']
+        else:
+            kpts_2d = np.zeros((num_imgs, num_keypoints, 3), dtype=np.float32)
+
+        if 'center' in self.ann_data.keys():
+            centers = self.ann_data['center']
+        else:
+            centers = np.zeros((num_imgs, 2), dtype=np.float32)
+
+        if 'scale' in self.ann_data.keys():
+            scales = self.ann_data['scale'].astype(np.float32)
+        else:
+            scales = np.zeros(num_imgs, dtype=np.float32)
+
+        instance_list = []
+        image_list = []
+
+        for idx, frame_ids in enumerate(self.sequence_indices):
+            assert len(frame_ids) == self.seq_len
+
+            _img_names = img_names[frame_ids]
+
+            _keypoints = kpts_2d[frame_ids].astype(np.float32)
+            keypoints = _keypoints[..., :2]
+            keypoints_visible = _keypoints[..., 2]
+
+            _keypoints_3d = kpts_3d[frame_ids].astype(np.float32)
+            keypoints_3d = _keypoints_3d[..., :3]
+            keypoints_3d_visible = _keypoints_3d[..., 3]
+
+            target_idx = -1 if self.causal else int(self.seq_len) // 2
+
+            instance_info = {
+                'num_keypoints': num_keypoints,
+                'keypoints': keypoints,
+                'keypoints_visible': keypoints_visible,
+                'keypoints_3d': keypoints_3d,
+                'keypoints_3d_visible': keypoints_3d_visible,
+                'scale': scales[idx],
+                'center': centers[idx].astype(np.float32).reshape(1, -1),
+                'id': idx,
+                'category_id': 1,
+                'iscrowd': 0,
+                'img_paths': list(_img_names),
+                'img_ids': frame_ids,
+                'target': keypoints_3d[target_idx],
+                'target_visible': keypoints_3d_visible[target_idx],
+                'target_img_id': frame_ids[target_idx],
+            }
+
+            if self.camera_param_file:
+                _cam_param = self.get_camera_param(_img_names[0])
+                instance_info['camera_param'] = _cam_param
+
+            instance_list.append(instance_info)
+
+        for idx, imgname in enumerate(img_names):
+            img_info = self.get_img_info(idx, imgname)
+            image_list.append(img_info)
+
+        return instance_list, image_list
+
+    def get_camera_param(self, imgname):
+        """Get camera parameters of a frame by its image name.
+
+        Override this method to specify how to get camera parameters.
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def _is_valid_instance(data_info: Dict) -> bool:
+        """Check a data info is an instance with valid bbox and keypoint
+        annotations."""
+        # crowd annotation
+        if 'iscrowd' in data_info and data_info['iscrowd']:
+            return False
+        # invalid keypoint number
+        if 'num_keypoints' in data_info and data_info['num_keypoints'] == 0:
+            return False
+        # invalid keypoint coordinates
+        if 'keypoints' in data_info:
+            if np.max(data_info['keypoints']) <= 0:
+                return False
+        return True
+
+    def _get_topdown_data_infos(self,
+                                instance_list: List[Dict]) -> List[Dict]:
+        """Organize the data list in top-down mode."""
+        # sanitize data samples
+        data_list_tp = list(filter(self._is_valid_instance, instance_list))
+
+        return data_list_tp
+
+    def _get_bottomup_data_infos(self, instance_list: List[Dict],
+                                 image_list: List[Dict]) -> List[Dict]:
+        """Organize the data list in bottom-up mode."""
+
+        # bottom-up data list
+        data_list_bu = []
+
+        used_img_ids = set()
+
+        # group instances by img_id
+        for img_ids, data_infos in groupby(instance_list,
+                                           lambda x: x['img_ids']):
+            for img_id in img_ids:
+                used_img_ids.add(img_id)
+            data_infos = list(data_infos)
+
+            # image data
+            img_paths = data_infos[0]['img_paths']
+            data_info_bu = {
+                'img_ids': img_ids,
+                'img_paths': img_paths,
+            }
+
+            for key in data_infos[0].keys():
+                if key not in data_info_bu:
+                    seq = [d[key] for d in data_infos]
+                    if isinstance(seq[0], np.ndarray):
+                        seq = np.concatenate(seq, axis=0)
+                    data_info_bu[key] = seq
+
+            # The segmentation annotation of invalid objects will be used
+            # to generate valid region mask in the pipeline.
+            invalid_segs = []
+            for data_info_invalid in filterfalse(self._is_valid_instance,
+                                                 data_infos):
+                if 'segmentation' in data_info_invalid:
+                    invalid_segs.append(data_info_invalid['segmentation'])
+            data_info_bu['invalid_segs'] = invalid_segs
+
+            data_list_bu.append(data_info_bu)
+
+        # add images without instance for evaluation
+        if self.test_mode:
+            for img_info in image_list:
+                if img_info['img_id'] not in used_img_ids:
+                    data_info_bu = {
+                        'img_ids': [img_info['img_id']],
+                        'img_paths': [img_info['img_path']],
+                        'id': list(),
+                    }
+                    data_list_bu.append(data_info_bu)
+
+        return data_list_bu
diff --git a/mmpose/datasets/datasets/body3d/__init__.py b/mmpose/datasets/datasets/body3d/__init__.py
new file mode 100644
index 0000000000..d5afeca578
--- /dev/null
+++ b/mmpose/datasets/datasets/body3d/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .h36m_dataset import Human36mDataset
+
+__all__ = ['Human36mDataset']
diff --git a/mmpose/datasets/datasets/body3d/h36m_dataset.py b/mmpose/datasets/datasets/body3d/h36m_dataset.py
new file mode 100644
index 0000000000..60094aa254
--- /dev/null
+++ b/mmpose/datasets/datasets/body3d/h36m_dataset.py
@@ -0,0 +1,259 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import defaultdict
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from mmengine.fileio import exists, get_local_path
+from mmengine.utils import is_abs
+
+from mmpose.datasets.datasets import BaseMocapDataset
+from mmpose.registry import DATASETS
+
+
+@DATASETS.register_module()
+class Human36mDataset(BaseMocapDataset):
+    """Human3.6M dataset for 3D human pose estimation.
+
+    "Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human
+    Sensing in Natural Environments", TPAMI`2014.
+    More details can be found in the `paper
+    `__.
+
+    Human3.6M keypoint indexes::
+
+        0: 'root (pelvis)',
+        1: 'right_hip',
+        2: 'right_knee',
+        3: 'right_foot',
+        4: 'left_hip',
+        5: 'left_knee',
+        6: 'left_foot',
+        7: 'spine',
+        8: 'thorax',
+        9: 'neck_base',
+        10: 'head',
+        11: 'left_shoulder',
+        12: 'left_elbow',
+        13: 'left_wrist',
+        14: 'right_shoulder',
+        15: 'right_elbow',
+        16: 'right_wrist'
+
+    Args:
+        ann_file (str): Annotation file path. Default: ''.
+        seq_len (int): Number of frames in a sequence. Default: 1.
+        seq_step (int): The interval for extracting frames from the video.
+            Default: 1.
+        pad_video_seq (bool): Whether to pad the video so that poses will be
+            predicted for every frame in the video. Default: ``False``.
+        causal (bool): If set to ``True``, the rightmost input frame will be
+            the target frame. Otherwise, the middle input frame will be the
+            target frame. Default: ``True``.
+        subset_frac (float): The fraction to reduce dataset size. If set to 1,
+            the dataset size is not reduced. Default: 1.
+        keypoint_2d_src (str): Specifies the source of the 2D keypoint
+            information, which should be one of the following options:
+
+            - ``'gt'``: load from the annotation file
+            - ``'detection'``: load from a detection
+              result file of 2D keypoints
+            - ``'pipeline'``: the information will be generated by the
+              pipeline
+
+            Default: ``'gt'``.
+        keypoint_2d_det_file (str, optional): The 2D keypoint detection file.
+            If set, the 2D keypoints loaded from this file will be used
+            instead of ground-truth keypoints. This setting is only valid
+            when ``keypoint_2d_src`` is ``'detection'``. Default: ``None``.
+        camera_param_file (str): Cameras' parameters file. Default: ``None``.
+        data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+            ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+            one instance; while in ``'bottomup'`` mode, each data sample
+            contains all instances in an image. Default: ``'topdown'``.
+        metainfo (dict, optional): Meta information for dataset, such as class
+            information. Default: ``None``.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Default: ``None``.
+        data_prefix (dict, optional): Prefix for training data.
+            Default: ``dict(img='')``.
+        filter_cfg (dict, optional): Config for filtering data.
+            Default: ``None``.
+        indices (int or Sequence[int], optional): Support using first few
+            data in the annotation file to facilitate training/testing on a
+            smaller dataset. Default: ``None`` which means using all
+            ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects. When enabled, data loader workers can use
+            shared RAM from the master process instead of making a copy.
+            Default: ``True``.
+        pipeline (list, optional): Processing pipeline. Default: [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Default: ``False``.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the meta
+            information of the dataset is needed, which is not necessary to
+            load the annotation file. ``BaseDataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``.
+            Default: ``False``.
+        max_refetch (int, optional): The maximum number of extra cycles to
+            fetch a valid image if ``BaseDataset.prepare_data`` gets a
+            ``None`` image. Default: 1000.
+    """
+
+    METAINFO: dict = dict(from_file='configs/_base_/datasets/h36m.py')
+    SUPPORTED_keypoint_2d_src = {'gt', 'detection', 'pipeline'}
+
+    def __init__(self,
+                 ann_file: str = '',
+                 seq_len: int = 1,
+                 seq_step: int = 1,
+                 pad_video_seq: bool = False,
+                 causal: bool = True,
+                 subset_frac: float = 1.0,
+                 keypoint_2d_src: str = 'gt',
+                 keypoint_2d_det_file: Optional[str] = None,
+                 camera_param_file: Optional[str] = None,
+                 data_mode: str = 'topdown',
+                 metainfo: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 data_prefix: dict = dict(img=''),
+                 filter_cfg: Optional[dict] = None,
+                 indices: Optional[Union[int, Sequence[int]]] = None,
+                 serialize_data: bool = True,
+                 pipeline: List[Union[dict, Callable]] = [],
+                 test_mode: bool = False,
+                 lazy_init: bool = False,
+                 max_refetch: int = 1000):
+        # check keypoint_2d_src
+        self.keypoint_2d_src = keypoint_2d_src
+        if self.keypoint_2d_src not in self.SUPPORTED_keypoint_2d_src:
+            raise ValueError(
+                f'Unsupported `keypoint_2d_src` "{self.keypoint_2d_src}". '
+                f'Supported options are {self.SUPPORTED_keypoint_2d_src}')
+
+        if keypoint_2d_det_file:
+            if not is_abs(keypoint_2d_det_file):
+                self.keypoint_2d_det_file = osp.join(data_root,
+                                                     keypoint_2d_det_file)
+            else:
+                self.keypoint_2d_det_file = keypoint_2d_det_file
+
+        self.seq_step = seq_step
+        self.pad_video_seq = pad_video_seq
+
+        super().__init__(
+            ann_file=ann_file,
+            seq_len=seq_len,
+            causal=causal,
+            subset_frac=subset_frac,
+            camera_param_file=camera_param_file,
+            data_mode=data_mode,
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            filter_cfg=filter_cfg,
+            indices=indices,
+            serialize_data=serialize_data,
+            pipeline=pipeline,
+            test_mode=test_mode,
+            lazy_init=lazy_init,
+            max_refetch=max_refetch)
+
+    def get_sequence_indices(self) -> List[List[int]]:
+        """Split original videos into sequences and build frame indices.
+
+        This method overrides the default one in the base class.
+        """
+        imgnames = self.ann_data['imgname']
+        video_frames = defaultdict(list)
+        for idx, imgname in enumerate(imgnames):
+            subj, action, camera = self._parse_h36m_imgname(imgname)
+            video_frames[(subj, action, camera)].append(idx)
+
+        # build sample indices
+        sequence_indices = []
+        _len = (self.seq_len - 1) * self.seq_step + 1
+        _step = self.seq_step
+        for _, _indices in sorted(video_frames.items()):
+            n_frame = len(_indices)
+
+            if self.pad_video_seq:
+                # Pad the sequence so that every frame in the sequence will be
+                # predicted.
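+                # For example, with seq_len=3, seq_step=1 and causal=False
+                # (frames_left = frames_right = 1), the first frame f0 of a
+                # video is padded to [f0, f0, f1] by repeating the boundary
+                # frame, so f0 still sits at the target (middle) position.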
+                if self.causal:
+                    frames_left = self.seq_len - 1
+                    frames_right = 0
+                else:
+                    frames_left = (self.seq_len - 1) // 2
+                    frames_right = frames_left
+                for i in range(n_frame):
+                    pad_left = max(0, frames_left - i // _step)
+                    pad_right = max(0,
+                                    frames_right - (n_frame - 1 - i) // _step)
+                    start = max(i % _step, i - frames_left * _step)
+                    end = min(n_frame - (n_frame - 1 - i) % _step,
+                              i + frames_right * _step + 1)
+                    sequence_indices.append([_indices[0]] * pad_left +
+                                            _indices[start:end:_step] +
+                                            [_indices[-1]] * pad_right)
+            else:
+                seqs_from_video = [
+                    _indices[i:(i + _len):_step]
+                    for i in range(0, n_frame - _len + 1)
+                ]
+                sequence_indices.extend(seqs_from_video)
+
+        # reduce dataset size if needed
+        subset_size = int(len(sequence_indices) * self.subset_frac)
+        start = np.random.randint(0, len(sequence_indices) - subset_size + 1)
+        end = start + subset_size
+
+        return sequence_indices[start:end]
+
+    def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
+        instance_list, image_list = super()._load_annotations()
+
+        h36m_data = self.ann_data
+        kpts_3d = h36m_data['S']
+
+        if self.keypoint_2d_src == 'detection':
+            assert exists(self.keypoint_2d_det_file)
+            kpts_2d = self._load_keypoint_2d_detection(
+                self.keypoint_2d_det_file)
+            assert kpts_2d.shape[0] == kpts_3d.shape[0]
+            assert kpts_2d.shape[2] == 3
+
+            for idx, frame_ids in enumerate(self.sequence_indices):
+                kpt_2d = kpts_2d[frame_ids].astype(np.float32)
+                keypoints = kpt_2d[..., :2]
+                keypoints_visible = kpt_2d[..., 2]
+                instance_list[idx].update({
+                    'keypoints':
+                    keypoints,
+                    'keypoints_visible':
+                    keypoints_visible
+                })
+
+        return instance_list, image_list
+
+    @staticmethod
+    def _parse_h36m_imgname(imgname) -> Tuple[str, str, str]:
+        """Parse imgname to get information of subject, action and camera.
+
+        A typical h36m image filename is like:
+        S1_Directions_1.54138969_000001.jpg
+        """
+        subj, rest = osp.basename(imgname).split('_', 1)
+        action, rest = rest.split('.', 1)
+        camera, rest = rest.split('_', 1)
+        return subj, action, camera
+
+    def get_camera_param(self, imgname) -> dict:
+        """Get camera parameters of a frame by its image name."""
+        assert hasattr(self, 'camera_param')
+        subj, _, camera = self._parse_h36m_imgname(imgname)
+        return self.camera_param[(subj, camera)]
+
+    def _load_keypoint_2d_detection(self, det_file):
+        """Load 2D joint detection results from file."""
+        with get_local_path(det_file) as local_path:
+            kpts_2d = np.load(local_path).astype(np.float32)
+
+        return kpts_2d
diff --git a/mmpose/visualization/__init__.py b/mmpose/visualization/__init__.py
index 357d40a707..73fbd645a9 100644
--- a/mmpose/visualization/__init__.py
+++ b/mmpose/visualization/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .fast_visualizer import FastVisualizer
 from .local_visualizer import PoseLocalVisualizer
 
-__all__ = ['PoseLocalVisualizer']
+__all__ = ['PoseLocalVisualizer', 'FastVisualizer']
diff --git a/mmpose/visualization/fast_visualizer.py b/mmpose/visualization/fast_visualizer.py
new file mode 100644
index 0000000000..fa0cb38527
--- /dev/null
+++ b/mmpose/visualization/fast_visualizer.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+
+
+class FastVisualizer:
+    """MMPose Fast Visualizer.
+
+    A simple yet fast visualizer for video/webcam inference.
+
+    Args:
+        metainfo (dict): pose meta information
+        radius (int, optional): Keypoint radius for visualization.
+            Defaults to 6.
+        line_width (int, optional): Link width for visualization.
+            Defaults to 3.
+        kpt_thr (float, optional): Threshold for keypoints' confidence score,
+            keypoints with score below this value will not be drawn.
+            Defaults to 0.3.
+    """
+
+    def __init__(self, metainfo, radius=6, line_width=3, kpt_thr=0.3):
+        self.radius = radius
+        self.line_width = line_width
+        self.kpt_thr = kpt_thr
+
+        self.keypoint_id2name = metainfo['keypoint_id2name']
+        self.keypoint_name2id = metainfo['keypoint_name2id']
+        self.keypoint_colors = metainfo['keypoint_colors']
+        self.skeleton_links = metainfo['skeleton_links']
+        self.skeleton_link_colors = metainfo['skeleton_link_colors']
+
+    def draw_pose(self, img, instances):
+        """Draw pose estimations on the given image.
+
+        This method draws keypoints and skeleton links on the input image
+        using the provided instances.
+
+        Args:
+            img (numpy.ndarray): The input image on which to
+                draw the pose estimations.
+            instances (object): An object containing detected instances'
+                information, including keypoints and keypoint_scores.
+
+        Returns:
+            None: The input image will be modified in place.
+        """
+
+        if instances is None:
+            print('no instance detected')
+            return
+
+        keypoints = instances.keypoints
+        scores = instances.keypoint_scores
+
+        for kpts, score in zip(keypoints, scores):
+            for sk_id, sk in enumerate(self.skeleton_links):
+                if score[sk[0]] < self.kpt_thr or score[sk[1]] < self.kpt_thr:
+                    # skip the link that should not be drawn
+                    continue
+
+                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+
+                color = self.skeleton_link_colors[sk_id].tolist()
+                cv2.line(img, pos1, pos2, color, thickness=self.line_width)
+
+            for kid, kpt in enumerate(kpts):
+                if score[kid] < self.kpt_thr:
+                    # skip the point that should not be drawn
+                    continue
+
+                x_coord, y_coord = int(kpt[0]), int(kpt[1])
+
+                color = self.keypoint_colors[kid].tolist()
+                cv2.circle(img, (int(x_coord), int(y_coord)), self.radius,
+                           color, -1)
+                cv2.circle(img, (int(x_coord), int(y_coord)), self.radius,
+                           (255, 255, 255))
diff --git a/projects/rtmpose/README.md b/projects/rtmpose/README.md
index 4d6f4e6d94..b070f24d1e 100644
--- a/projects/rtmpose/README.md
+++ b/projects/rtmpose/README.md
@@ -593,6 +593,41 @@ set-ExecutionPolicy RemoteSigned
 example\cpp\build\Release
 ```
 
+### MMPose demo scripts
+
+MMPose provides demo scripts to conduct [inference with existing models](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html).
+
+**Note:**
+
+- Inference with PyTorch cannot reach the real inference speed of RTMPose; it is only intended for verifying model predictions.
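+
+The same models can also be called through the Python APIs used by the demo scripts. A minimal sketch (the config and checkpoint paths are placeholders to be replaced with your own):
+
+```python
+from mmpose.apis import inference_topdown, init_model
+from mmpose.structures import merge_data_samples
+
+# build the pose estimator from a config and checkpoint (placeholder paths)
+model = init_model(
+    'projects/rtmpose/rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py',
+    '{PATH_TO_CHECKPOINT}/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth',
+    device='cpu')
+
+# run top-down inference on a single image; when no bounding boxes are
+# given, the whole image is used as one person bounding box
+results = inference_topdown(model, 'demo.jpg')
+pred = merge_data_samples(results)
+print(pred.pred_instances.keypoints.shape)  # (num_instances, num_keypoints, 2)
+```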
+ +```shell +# go to the mmpose folder +cd ${PATH_TO_MMPOSE} + +# inference with rtmdet +python demo/topdown_demo_with_mmdet.py \ + projects/rtmpose/rtmdet/person/rtmdet_nano_320-8xb32_coco-person.py \ + {PATH_TO_CHECKPOINT}/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth \ + projects/rtmpose/rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py \ + {PATH_TO_CHECKPOINT}/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth \ + --input {YOUR_TEST_IMG_or_VIDEO} \ + --show + +# inference with webcam +python demo/topdown_demo_with_mmdet.py \ + projects/rtmpose/rtmdet/person/rtmdet_nano_320-8xb32_coco-person.py \ + {PATH_TO_CHECKPOINT}/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth \ + projects/rtmpose/rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py \ + {PATH_TO_CHECKPOINT}/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth \ + --input webcam \ + --show +``` + +Result is as follows: + +![topdown_inference_with_rtmdet](https://user-images.githubusercontent.com/13503330/220005020-06bdf37f-6817-4681-a2c8-9dd55e4fbf1e.png) + ## 👨‍🏫 How to Train [🔝](#-table-of-contents) Please refer to [Train and Test](https://mmpose.readthedocs.io/en/latest/user_guides/train_and_test.html). diff --git a/projects/rtmpose/README_CN.md b/projects/rtmpose/README_CN.md index 7abafc25c4..01f5240fed 100644 --- a/projects/rtmpose/README_CN.md +++ b/projects/rtmpose/README_CN.md @@ -584,6 +584,41 @@ set-ExecutionPolicy RemoteSigned example\cpp\build\Release ``` +### MMPose demo 脚本 + +通过 MMPose 提供的 demo 脚本可以基于 Pytorch 快速进行[模型推理](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html)和效果验证。 + +**提示:** + +- 基于 Pytorch 推理并不能达到 RTMPose 模型的真实推理速度,只用于模型效果验证。 + +```shell +# 前往 mmpose 目录 +cd ${PATH_TO_MMPOSE} + +# RTMDet 与 RTMPose 联合推理 +python demo/topdown_demo_with_mmdet.py \ + projects/rtmpose/rtmdet/person/rtmdet_nano_320-8xb32_coco-person.py \ + {PATH_TO_CHECKPOINT}/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth \ + projects/rtmpose/rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py \ + {PATH_TO_CHECKPOINT}/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth \ + --input {YOUR_TEST_IMG_or_VIDEO} \ + --show + +# 摄像头推理 +python demo/topdown_demo_with_mmdet.py \ + projects/rtmpose/rtmdet/person/rtmdet_nano_320-8xb32_coco-person.py \ + {PATH_TO_CHECKPOINT}/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth \ + projects/rtmpose/rtmpose/body_2d_keypoint/rtmpose-m_8xb256-420e_coco-256x192.py \ + {PATH_TO_CHECKPOINT}/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth \ + --input webcam \ + --show +``` + +效果展示: + +![topdown_inference_with_rtmdet](https://user-images.githubusercontent.com/13503330/220005020-06bdf37f-6817-4681-a2c8-9dd55e4fbf1e.png) + ## 👨‍🏫 模型训练 [🔝](#-table-of-contents) 请参考 [训练与测试](https://mmpose.readthedocs.io/en/latest/user_guides/train_and_test.html) 进行 RTMPose 的训练。 diff --git a/tests/data/h36m/S1_Directions_1.54138969_000001.jpg b/tests/data/h36m/S1/S1_Directions_1.54138969/S1_Directions_1.54138969_000001.jpg similarity index 100% rename from tests/data/h36m/S1_Directions_1.54138969_000001.jpg rename to tests/data/h36m/S1/S1_Directions_1.54138969/S1_Directions_1.54138969_000001.jpg diff --git a/tests/data/h36m/S5_SittingDown.54138969_002061.jpg b/tests/data/h36m/S5/S5_SittingDown.54138969/S5_SittingDown.54138969_002061.jpg similarity index 100% rename from tests/data/h36m/S5_SittingDown.54138969_002061.jpg rename to 
tests/data/h36m/S5/S5_SittingDown.54138969/S5_SittingDown.54138969_002061.jpg diff --git a/tests/data/h36m/S7_Greeting.55011271_000396.jpg b/tests/data/h36m/S7/S7_Greeting.55011271/S7_Greeting.55011271_000396.jpg similarity index 100% rename from tests/data/h36m/S7_Greeting.55011271_000396.jpg rename to tests/data/h36m/S7/S7_Greeting.55011271/S7_Greeting.55011271_000396.jpg diff --git a/tests/data/h36m/S8_WalkDog_1.55011271_000026.jpg b/tests/data/h36m/S8/S8_WalkDog_1.55011271/S8_WalkDog_1.55011271_000026.jpg similarity index 100% rename from tests/data/h36m/S8_WalkDog_1.55011271_000026.jpg rename to tests/data/h36m/S8/S8_WalkDog_1.55011271/S8_WalkDog_1.55011271_000026.jpg diff --git a/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py b/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py new file mode 100644 index 0000000000..88944dc11f --- /dev/null +++ b/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py @@ -0,0 +1,175 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np + +from mmpose.datasets.datasets.body3d import Human36mDataset + + +class TestH36MDataset(TestCase): + + def build_h36m_dataset(self, **kwargs): + + cfg = dict( + ann_file='test_h36m_body3d.npz', + data_mode='topdown', + data_root='tests/data/h36m', + pipeline=[], + test_mode=False) + + cfg.update(kwargs) + return Human36mDataset(**cfg) + + def check_data_info_keys(self, + data_info: dict, + data_mode: str = 'topdown'): + if data_mode == 'topdown': + expected_keys = dict( + img_ids=list, + img_paths=list, + keypoints=np.ndarray, + keypoints_3d=np.ndarray, + scale=np.float32, + center=np.ndarray, + id=int) + elif data_mode == 'bottomup': + expected_keys = dict( + img_ids=list, + img_paths=list, + keypoints=np.ndarray, + keypoints_3d=np.ndarray, + scale=list, + center=np.ndarray, + invalid_segs=list, + id=list) + else: + raise ValueError(f'Invalid data_mode {data_mode}') + + for key, type_ in expected_keys.items(): + self.assertIn(key, data_info) + self.assertIsInstance(data_info[key], type_, key) + + def check_metainfo_keys(self, metainfo: dict): + expected_keys = dict( + dataset_name=str, + num_keypoints=int, + keypoint_id2name=dict, + keypoint_name2id=dict, + upper_body_ids=list, + lower_body_ids=list, + flip_indices=list, + flip_pairs=list, + keypoint_colors=np.ndarray, + num_skeleton_links=int, + skeleton_links=list, + skeleton_link_colors=np.ndarray, + dataset_keypoint_weights=np.ndarray) + + for key, type_ in expected_keys.items(): + self.assertIn(key, metainfo) + self.assertIsInstance(metainfo[key], type_, key) + + def test_metainfo(self): + dataset = self.build_h36m_dataset() + self.check_metainfo_keys(dataset.metainfo) + # test dataset_name + self.assertEqual(dataset.metainfo['dataset_name'], 'h36m') + + # test number of keypoints + num_keypoints = 17 + self.assertEqual(dataset.metainfo['num_keypoints'], num_keypoints) + self.assertEqual( + len(dataset.metainfo['keypoint_colors']), num_keypoints) + self.assertEqual( + len(dataset.metainfo['dataset_keypoint_weights']), num_keypoints) + + # test some extra metainfo + self.assertEqual( + len(dataset.metainfo['skeleton_links']), + len(dataset.metainfo['skeleton_link_colors'])) + + def test_topdown(self): + # test topdown training + dataset = self.build_h36m_dataset(data_mode='topdown') + self.assertEqual(len(dataset), 4) + self.check_data_info_keys(dataset[0]) + + # test topdown testing + dataset = self.build_h36m_dataset(data_mode='topdown', 
test_mode=True)
+        self.assertEqual(len(dataset), 4)
+        self.check_data_info_keys(dataset[0])
+
+        # test topdown training with camera file
+        dataset = self.build_h36m_dataset(
+            data_mode='topdown', camera_param_file='cameras.pkl')
+        self.assertEqual(len(dataset), 4)
+        self.check_data_info_keys(dataset[0])
+
+        # test topdown training with sequence config
+        dataset = self.build_h36m_dataset(
+            data_mode='topdown',
+            seq_len=27,
+            seq_step=1,
+            causal=False,
+            pad_video_seq=True,
+            camera_param_file='cameras.pkl')
+        self.assertEqual(len(dataset), 4)
+        self.check_data_info_keys(dataset[0])
+
+        # test topdown testing with 2d keypoint detection file and
+        # sequence config
+        dataset = self.build_h36m_dataset(
+            data_mode='topdown',
+            seq_len=27,
+            seq_step=1,
+            causal=False,
+            pad_video_seq=True,
+            test_mode=True,
+            keypoint_2d_src='detection',
+            keypoint_2d_det_file='test_h36m_2d_detection.npy')
+        self.assertEqual(len(dataset), 4)
+        self.check_data_info_keys(dataset[0])
+
+    def test_bottomup(self):
+        # test bottomup training
+        dataset = self.build_h36m_dataset(data_mode='bottomup')
+        self.assertEqual(len(dataset), 4)
+        self.check_data_info_keys(dataset[0], data_mode='bottomup')
+
+        # test bottomup training with sequence config
+        dataset = self.build_h36m_dataset(
+            data_mode='bottomup',
+            seq_len=27,
+            seq_step=1,
+            causal=False,
+            pad_video_seq=True)
+        self.assertEqual(len(dataset), 4)
+        self.check_data_info_keys(dataset[0], data_mode='bottomup')
+
+        # test bottomup testing
+        dataset = self.build_h36m_dataset(data_mode='bottomup', test_mode=True)
+        self.assertEqual(len(dataset), 4)
+        self.check_data_info_keys(dataset[0], data_mode='bottomup')
+
+    def test_exceptions_and_warnings(self):
+
+        with self.assertRaisesRegex(ValueError, 'got invalid data_mode'):
+            _ = self.build_h36m_dataset(data_mode='invalid')
+
+        SUPPORTED_keypoint_2d_src = {'gt', 'detection', 'pipeline'}
+        with self.assertRaisesRegex(
+                ValueError, 'Unsupported `keypoint_2d_src` "invalid". '
+                f'Supported options are {SUPPORTED_keypoint_2d_src}'):
+            _ = self.build_h36m_dataset(
+                data_mode='topdown',
+                test_mode=False,
+                keypoint_2d_src='invalid')
+
+        with self.assertRaisesRegex(AssertionError,
+                                    'Annotation file does not exist'):
+            _ = self.build_h36m_dataset(
+                data_mode='topdown', test_mode=False, ann_file='invalid')
+
+        with self.assertRaisesRegex(AssertionError,
+                                    'Unsupported `subset_frac` 2.'):
+            _ = self.build_h36m_dataset(data_mode='topdown', subset_frac=2)
diff --git a/tests/test_visualization/test_fast_visualizer.py b/tests/test_visualization/test_fast_visualizer.py
new file mode 100644
index 0000000000..f4a24ca1f9
--- /dev/null
+++ b/tests/test_visualization/test_fast_visualizer.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase + +import numpy as np + +from mmpose.visualization import FastVisualizer + + +class TestFastVisualizer(TestCase): + + def setUp(self): + self.metainfo = { + 'keypoint_id2name': { + 0: 'nose', + 1: 'left_eye', + 2: 'right_eye' + }, + 'keypoint_name2id': { + 'nose': 0, + 'left_eye': 1, + 'right_eye': 2 + }, + 'keypoint_colors': np.array([[255, 0, 0], [0, 255, 0], [0, 0, + 255]]), + 'skeleton_links': [(0, 1), (1, 2)], + 'skeleton_link_colors': np.array([[255, 255, 0], [255, 0, 255]]) + } + self.visualizer = FastVisualizer(self.metainfo) + + def test_init(self): + self.assertEqual(self.visualizer.radius, 6) + self.assertEqual(self.visualizer.line_width, 3) + self.assertEqual(self.visualizer.kpt_thr, 0.3) + self.assertEqual(self.visualizer.keypoint_id2name, + self.metainfo['keypoint_id2name']) + self.assertEqual(self.visualizer.keypoint_name2id, + self.metainfo['keypoint_name2id']) + np.testing.assert_array_equal(self.visualizer.keypoint_colors, + self.metainfo['keypoint_colors']) + self.assertEqual(self.visualizer.skeleton_links, + self.metainfo['skeleton_links']) + np.testing.assert_array_equal(self.visualizer.skeleton_link_colors, + self.metainfo['skeleton_link_colors']) + + def test_draw_pose(self): + img = np.zeros((480, 640, 3), dtype=np.uint8) + instances = type('Instances', (object, ), {})() + instances.keypoints = np.array([[[100, 100], [200, 200], [300, 300]]], + dtype=np.float32) + instances.keypoint_scores = np.array([[0.5, 0.5, 0.5]], + dtype=np.float32) + + self.visualizer.draw_pose(img, instances) + + # Check if keypoints are drawn + self.assertNotEqual(img[100, 100].tolist(), [0, 0, 0]) + self.assertNotEqual(img[200, 200].tolist(), [0, 0, 0]) + self.assertNotEqual(img[300, 300].tolist(), [0, 0, 0]) + + # Check if skeleton links are drawn + self.assertNotEqual(img[150, 150].tolist(), [0, 0, 0]) + self.assertNotEqual(img[250, 250].tolist(), [0, 0, 0]) + + def test_draw_pose_with_none_instances(self): + img = np.zeros((480, 640, 3), dtype=np.uint8) + instances = None + + self.visualizer.draw_pose(img, instances) + + # Check if the image is still empty (black) + self.assertEqual(np.count_nonzero(img), 0) From 98adecfc7b8246b4f7b288b9cc4ba6c278abb754 Mon Sep 17 00:00:00 2001 From: Yifan Lareina WU Date: Tue, 25 Apr 2023 15:09:55 +0800 Subject: [PATCH 2/8] [Refactor] Add codec for 3d body (#2251) --- mmpose/codecs/__init__.py | 4 +- mmpose/codecs/image_pose_lifting.py | 203 ++++++++++++++++++ mmpose/codecs/regression_label.py | 2 +- mmpose/codecs/video_pose_lifting.py | 201 +++++++++++++++++ mmpose/datasets/transforms/__init__.py | 3 +- mmpose/datasets/transforms/formatting.py | 10 +- .../datasets/transforms/pose3d_transforms.py | 102 +++++++++ mmpose/structures/keypoint/__init__.py | 4 +- mmpose/structures/keypoint/transforms.py | 57 +++++ tests/test_codecs/test_image_pose_lifting.py | 154 +++++++++++++ tests/test_codecs/test_video_pose_lifting.py | 160 ++++++++++++++ .../test_transforms/test_pose3d_transforms.py | 150 +++++++++++++ 12 files changed, 1042 insertions(+), 8 deletions(-) create mode 100644 mmpose/codecs/image_pose_lifting.py create mode 100644 mmpose/codecs/video_pose_lifting.py create mode 100644 mmpose/datasets/transforms/pose3d_transforms.py create mode 100644 tests/test_codecs/test_image_pose_lifting.py create mode 100644 tests/test_codecs/test_video_pose_lifting.py create mode 100644 tests/test_datasets/test_transforms/test_pose3d_transforms.py diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py index 
a88ebac701..cdbd8feb0c 100644 --- a/mmpose/codecs/__init__.py +++ b/mmpose/codecs/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .associative_embedding import AssociativeEmbedding from .decoupled_heatmap import DecoupledHeatmap +from .image_pose_lifting import ImagePoseLifting from .integral_regression_label import IntegralRegressionLabel from .megvii_heatmap import MegviiHeatmap from .msra_heatmap import MSRAHeatmap @@ -8,9 +9,10 @@ from .simcc_label import SimCCLabel from .spr import SPR from .udp_heatmap import UDPHeatmap +from .video_pose_lifting import VideoPoseLifting __all__ = [ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel', 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR', - 'DecoupledHeatmap' + 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting' ] diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py new file mode 100644 index 0000000000..93530cf15d --- /dev/null +++ b/mmpose/codecs/image_pose_lifting.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class ImagePoseLifting(BaseKeypointCodec): + r"""Generate keypoint coordinates for pose lifter. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - target dimension: C + + Args: + num_keypoints (int): The number of keypoints in the dataset. + root_index (int): Root keypoint index in the pose. + remove_root (bool): If true, remove the root keypoint from the pose. + Default: ``False``. + save_index (bool): If true, store the root position separated from the + original pose. Default: ``False``. + keypoints_mean (np.ndarray, optional): Mean values of keypoints + coordinates in shape (K, D). + keypoints_std (np.ndarray, optional): Std values of keypoints + coordinates in shape (K, D). + target_mean (np.ndarray, optional): Mean values of target coordinates + in shape (K, C). + target_std (np.ndarray, optional): Std values of target coordinates + in shape (K, C). + """ + + auxiliary_encode_keys = {'target', 'target_visible'} + + def __init__(self, + num_keypoints: int, + root_index: int, + remove_root: bool = False, + save_index: bool = False, + keypoints_mean: Optional[np.ndarray] = None, + keypoints_std: Optional[np.ndarray] = None, + target_mean: Optional[np.ndarray] = None, + target_std: Optional[np.ndarray] = None): + super().__init__() + + self.num_keypoints = num_keypoints + self.root_index = root_index + self.remove_root = remove_root + self.save_index = save_index + if keypoints_mean is not None and keypoints_std is not None: + assert keypoints_mean.shape == keypoints_std.shape + if target_mean is not None and target_std is not None: + assert target_mean.shape == target_std.shape + self.keypoints_mean = keypoints_mean + self.keypoints_std = keypoints_std + self.target_mean = target_mean + self.target_std = target_std + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + target: Optional[np.ndarray] = None, + target_visible: Optional[np.ndarray] = None) -> dict: + """Encoding keypoints from input image space to normalized space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D). + keypoints_visible (np.ndarray, optional): Keypoint visibilities in + shape (N, K). 
+            target (np.ndarray, optional): Target coordinate in shape (K, C).
+            target_visible (np.ndarray, optional): Target visibility in shape
+                (K, ).
+
+        Returns:
+            encoded (dict): Contains the following items:
+
+                - keypoint_labels (np.ndarray): The processed keypoints in
+                  shape (K * D, N) where D is 2 for 2d coordinates.
+                - target_label (np.ndarray): The processed target coordinate
+                  in shape (K, C) or (K-1, C).
+                - target_weights (np.ndarray): The target weights in shape
+                  (K, ) or (K-1, ).
+                - trajectory_weights (np.ndarray): The trajectory weights in
+                  shape (K, ).
+                - target_root (np.ndarray): The root coordinate of target in
+                  shape (C, ).
+
+                In addition, there are some optional items it may contain:
+
+                - target_root_removed (bool): Indicate whether the root of
+                  target is removed. Added if ``self.remove_root`` is
+                  ``True``.
+                - target_root_index (int): An integer indicating the index of
+                  root. Added if ``self.remove_root`` and ``self.save_index``
+                  are ``True``.
+        """
+        if keypoints_visible is None:
+            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+        if target is None:
+            target = keypoints[0]
+
+        # set initial value for `target_weights` and `trajectory_weights`
+        if target_visible is None:
+            target_visible = np.ones(target.shape[:-1], dtype=np.float32)
+            target_weights = target_visible
+            trajectory_weights = (1 / target[:, 2])
+        else:
+            valid = target_visible > 0.5
+            target_weights = np.where(valid, 1., 0.).astype(np.float32)
+            trajectory_weights = target_weights
+
+        encoded = dict()
+
+        # Zero-center the target pose around a given root keypoint
+        assert target.ndim >= 2 and target.shape[-2] > self.root_index, \
+            f'Got invalid joint shape {target.shape}'
+
+        root = target[..., self.root_index, :]
+        target_label = target - root
+
+        if self.remove_root:
+            target_label = np.delete(target_label, self.root_index, axis=-2)
+            assert target_weights.ndim in {1, 2}
+            axis_to_remove = -2 if target_weights.ndim == 2 else -1
+            target_weights = np.delete(
+                target_weights, self.root_index, axis=axis_to_remove)
+            # Add a flag to avoid later transforms that rely on the root
+            # joint or the original joint index
+            encoded['target_root_removed'] = True
+
+        # Save the root index which is necessary to restore the global pose
+        if self.save_index:
+            encoded['target_root_index'] = self.root_index
+
+        # Normalize the 2D keypoint coordinate with mean and std
+        keypoint_labels = keypoints.copy()
+        if self.keypoints_mean is not None and self.keypoints_std is not None:
+            keypoints_shape = keypoints.shape
+            assert self.keypoints_mean.shape == keypoints_shape[1:]
+
+            keypoint_labels = (keypoint_labels -
+                               self.keypoints_mean) / self.keypoints_std
+        if self.target_mean is not None and self.target_std is not None:
+            target_shape = target_label.shape
+            assert self.target_mean.shape == target_shape
+
+            target_label = (target_label - self.target_mean) / self.target_std
+
+        # Generate reshaped keypoint coordinates
+        assert keypoint_labels.ndim in {2, 3}
+        if keypoint_labels.ndim == 2:
+            keypoint_labels = keypoint_labels[None, ...]
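+        # keypoint_labels now has shape (N, K, D); the transpose and reshape
+        # below flatten it to (K * D, N), so each column stacks the 2D
+        # coordinates of all keypoints for one instance.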
+ + N = keypoint_labels.shape[0] + keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) + + encoded['keypoint_labels'] = keypoint_labels + encoded['target_label'] = target_label + encoded['target_weights'] = target_weights + encoded['trajectory_weights'] = trajectory_weights + encoded['target_root'] = root + + return encoded + + def decode(self, + encoded: np.ndarray, + restore_global_position: bool = False, + target_root: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from normalized space to input image + space. + + Args: + encoded (np.ndarray): Coordinates in shape (N, K, C). + restore_global_position (bool): Whether to restore global position. + Default: ``False``. + target_root (np.ndarray, optional): The target root coordinate. + Default: ``None``. + + Returns: + keypoints (np.ndarray): Decoded coordinates in shape (N, K, C). + scores (np.ndarray): The keypoint scores in shape (N, K). + """ + keypoints = encoded.copy() + + if self.target_mean is not None and self.target_std is not None: + assert self.target_mean.shape == keypoints.shape[1:] + keypoints = keypoints * self.target_std + self.target_mean + + if restore_global_position: + assert target_root is not None + keypoints = keypoints + np.expand_dims(target_root, axis=0) + if self.remove_root: + keypoints = np.insert( + keypoints, self.root_index, target_root, axis=1) + scores = np.ones(keypoints.shape[:-1], dtype=np.float32) + + return keypoints, scores diff --git a/mmpose/codecs/regression_label.py b/mmpose/codecs/regression_label.py index 9ae385d2d9..f79195beb4 100644 --- a/mmpose/codecs/regression_label.py +++ b/mmpose/codecs/regression_label.py @@ -78,7 +78,7 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: Returns: tuple: - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D) - - socres (np.ndarray): The keypoint scores in shape (N, K). + - scores (np.ndarray): The keypoint scores in shape (N, K). It usually represents the confidence of the keypoint prediction """ diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py new file mode 100644 index 0000000000..fbb6ad429c --- /dev/null +++ b/mmpose/codecs/video_pose_lifting.py @@ -0,0 +1,201 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from copy import deepcopy +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class VideoPoseLifting(BaseKeypointCodec): + r"""Generate keypoint coordinates for pose lifter. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - target dimension: C + + Args: + num_keypoints (int): The number of keypoints in the dataset. + zero_center: Whether to zero-center the target around root. Default: + ``True``. + root_index (int): Root keypoint index in the pose. Default: 0. + remove_root (bool): If true, remove the root keypoint from the pose. + Default: ``False``. + save_index (bool): If true, store the root position separated from the + original pose. Default: ``False``. + normalize_camera (bool): Whether to normalize camera intrinsics. + Default: ``False``. 
+    """
+
+    auxiliary_encode_keys = {'target', 'target_visible', 'camera_param'}
+
+    def __init__(self,
+                 num_keypoints: int,
+                 zero_center: bool = True,
+                 root_index: int = 0,
+                 remove_root: bool = False,
+                 save_index: bool = False,
+                 normalize_camera: bool = False):
+        super().__init__()
+
+        self.num_keypoints = num_keypoints
+        self.zero_center = zero_center
+        self.root_index = root_index
+        self.remove_root = remove_root
+        self.save_index = save_index
+        self.normalize_camera = normalize_camera
+
+    def encode(self,
+               keypoints: np.ndarray,
+               keypoints_visible: Optional[np.ndarray] = None,
+               target: Optional[np.ndarray] = None,
+               target_visible: Optional[np.ndarray] = None,
+               camera_param: Optional[dict] = None) -> dict:
+        """Encoding keypoints from input image space to normalized space.
+
+        Args:
+            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D).
+            keypoints_visible (np.ndarray, optional): Keypoint visibilities in
+                shape (N, K).
+            target (np.ndarray, optional): Target coordinate in shape (K, C).
+            target_visible (np.ndarray, optional): Target visibility in shape
+                (K, ).
+            camera_param (dict, optional): The camera parameter dictionary.
+
+        Returns:
+            encoded (dict): Contains the following items:
+
+                - keypoint_labels (np.ndarray): The processed keypoints in
+                  shape (K * D, N) where D is 2 for 2d coordinates.
+                - target_label (np.ndarray): The processed target coordinate
+                  in shape (K, C) or (K-1, C).
+                - target_weights (np.ndarray): The target weights in shape
+                  (K, ) or (K-1, ).
+                - trajectory_weights (np.ndarray): The trajectory weights in
+                  shape (K, ).
+
+                In addition, there are some optional items it may contain:
+
+                - target_root (np.ndarray): The root coordinate of target in
+                  shape (C, ). Exists if ``self.zero_center`` is ``True``.
+                - target_root_removed (bool): Indicate whether the root of
+                  target is removed. Exists if ``self.remove_root`` is
+                  ``True``.
+                - target_root_index (int): An integer indicating the index of
+                  root. Exists if ``self.remove_root`` and ``self.save_index``
+                  are ``True``.
+                - camera_param (dict): The updated camera parameter
+                  dictionary. Exists if ``self.normalize_camera`` is
+                  ``True``.
+        """
+        if keypoints_visible is None:
+            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+        if target is None:
+            target = keypoints[0]
+
+        # set initial value for `target_weights` and `trajectory_weights`
+        if target_visible is None:
+            target_visible = np.ones(target.shape[:-1], dtype=np.float32)
+            target_weights = target_visible
+            trajectory_weights = (1 / target[:, 2])
+        else:
+            valid = target_visible > 0.5
+            target_weights = np.where(valid, 1., 0.).astype(np.float32)
+            trajectory_weights = target_weights
+
+        if camera_param is None:
+            camera_param = dict()
+
+        encoded = dict()
+
+        target_label = target.copy()
+        # Zero-center the target pose around a given root keypoint
+        if self.zero_center:
+            assert target.ndim >= 2 and target.shape[-2] > self.root_index, \
+                f'Got invalid joint shape {target.shape}'
+
+            root = target[..., self.root_index, :]
+            target_label = target_label - root
+            encoded['target_root'] = root
+
+            if self.remove_root:
+                target_label = np.delete(
+                    target_label, self.root_index, axis=-2)
+                assert target_weights.ndim in {1, 2}
+                axis_to_remove = -2 if target_weights.ndim == 2 else -1
+                target_weights = np.delete(
+                    target_weights, self.root_index, axis=axis_to_remove)
+                # Add a flag to avoid later transforms that rely on the root
+                # joint or the original joint index
+                encoded['target_root_removed'] = True
+
+                # Save the root index for restoring the global pose
+                if self.save_index:
+                    encoded['target_root_index'] = self.root_index
+
+        # Normalize the 2D keypoint coordinate with image width and height
+        _camera_param = deepcopy(camera_param)
+        assert 'w' in _camera_param and 'h' in _camera_param
+        center = np.array([0.5 * _camera_param['w'], 0.5 * _camera_param['h']],
+                          dtype=np.float32)
+        scale = np.array(0.5 * _camera_param['w'], dtype=np.float32)
+
+        keypoint_labels = (keypoints - center) / scale
+
+        assert keypoint_labels.ndim in {2, 3}
+        if keypoint_labels.ndim == 2:
+            keypoint_labels = keypoint_labels[None, ...]
+
+        if self.normalize_camera:
+            assert 'f' in _camera_param and 'c' in _camera_param
+            _camera_param['f'] = _camera_param['f'] / scale
+            _camera_param['c'] = (_camera_param['c'] - center[:, None]) / scale
+            encoded['camera_param'] = _camera_param
+
+        # Generate reshaped keypoint coordinates
+        N = keypoint_labels.shape[0]
+        keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N)
+
+        encoded['keypoint_labels'] = keypoint_labels
+        encoded['target_label'] = target_label
+        encoded['target_weights'] = target_weights
+        encoded['trajectory_weights'] = trajectory_weights
+
+        return encoded
+
+    def decode(self,
+               encoded: np.ndarray,
+               restore_global_position: bool = False,
+               target_root: Optional[np.ndarray] = None
+               ) -> Tuple[np.ndarray, np.ndarray]:
+        """Decode keypoint coordinates from normalized space to input image
+        space.
+
+        Args:
+            encoded (np.ndarray): Coordinates in shape (1, K, C).
+            restore_global_position (bool): Whether to restore global
+                position. Default: ``False``.
+            target_root (np.ndarray, optional): The target root coordinate.
+                Default: ``None``.
+
+        Returns:
+            keypoints (np.ndarray): Decoded coordinates in shape (1, K, C).
+            scores (np.ndarray): The keypoint scores in shape (1, K).
+ """ + keypoints = encoded.copy() + + if restore_global_position: + assert target_root is not None + keypoints = keypoints + np.expand_dims(target_root, axis=0) + if self.remove_root: + keypoints = np.insert( + keypoints, self.root_index, target_root, axis=1) + scores = np.ones(keypoints.shape[:-1], dtype=np.float32) + + return keypoints, scores diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py index 61dae74b8c..7ccbf7dac2 100644 --- a/mmpose/datasets/transforms/__init__.py +++ b/mmpose/datasets/transforms/__init__.py @@ -8,6 +8,7 @@ from .converting import KeypointConverter from .formatting import PackPoseInputs from .loading import LoadImage +from .pose3d_transforms import RandomFlipAroundRoot from .topdown_transforms import TopdownAffine __all__ = [ @@ -15,5 +16,5 @@ 'RandomHalfBody', 'TopdownAffine', 'Albumentation', 'PhotometricDistortion', 'PackPoseInputs', 'LoadImage', 'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize', - 'GenerateTarget', 'KeypointConverter' + 'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot' ] diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index dd9ad522f2..14be378e19 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -89,6 +89,8 @@ class PackPoseInputs(BaseTransform): 'bbox_score': 'bbox_scores', 'keypoints': 'keypoints', 'keypoints_visible': 'keypoints_visible', + 'target': 'target', + 'target_visible': 'target_visible', } # items in `label_mapping_table` will be packed into @@ -137,10 +139,12 @@ def transform(self, results: dict) -> dict: - 'data_samples' (obj:`PoseDataSample`): The annotation info of the sample. """ - # Pack image(s) + # Pack image(s) or 2d keypoints if 'img' in results: img = results['img'] - img_tensor = image_to_tensor(img) + inputs_tensor = image_to_tensor(img) + elif 'keypoints_3d' in results and 'keypoints' in results: + inputs_tensor = results['keypoints'] data_sample = PoseDataSample() @@ -202,7 +206,7 @@ def transform(self, results: dict) -> dict: data_sample.set_metainfo(img_meta) packed_results = dict() - packed_results['inputs'] = img_tensor + packed_results['inputs'] = inputs_tensor packed_results['data_samples'] = data_sample return packed_results diff --git a/mmpose/datasets/transforms/pose3d_transforms.py b/mmpose/datasets/transforms/pose3d_transforms.py new file mode 100644 index 0000000000..4f86c247b9 --- /dev/null +++ b/mmpose/datasets/transforms/pose3d_transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Dict + +import numpy as np +from mmcv.transforms import BaseTransform + +from mmpose.registry import TRANSFORMS +from mmpose.structures.keypoint import flip_keypoints_custom_center + + +@TRANSFORMS.register_module() +class RandomFlipAroundRoot(BaseTransform): + """Data augmentation with random horizontal joint flip around a root joint. + + Args: + keypoints_flip_cfg (dict): Configurations of the + ``flip_keypoints_custom_center`` function for ``keypoints``. Please + refer to the docstring of the ``flip_keypoints_custom_center`` + function for more details. + target_flip_cfg (dict): Configurations of the + ``flip_keypoints_custom_center`` function for ``target``. Please + refer to the docstring of the ``flip_keypoints_custom_center`` + function for more details. + flip_prob (float): Probability of flip. Default: 0.5. 
+        flip_camera (bool): Whether to flip horizontal distortion coefficients.
+            Default: ``False``.
+
+    Required keys:
+        keypoints
+        target
+
+    Modified keys:
+        (keypoints, keypoints_visible, target, target_visible, camera_param)
+    """
+
+    def __init__(self,
+                 keypoints_flip_cfg,
+                 target_flip_cfg,
+                 flip_prob=0.5,
+                 flip_camera=False):
+        self.keypoints_flip_cfg = keypoints_flip_cfg
+        self.target_flip_cfg = target_flip_cfg
+        self.flip_prob = flip_prob
+        self.flip_camera = flip_camera
+
+    def transform(self, results: Dict) -> dict:
+        """The transform function of :class:`RandomFlipAroundRoot`.
+
+        See ``transform()`` method of :class:`BaseTransform` for details.
+
+        Args:
+            results (dict): The result dict
+
+        Returns:
+            dict: The result dict.
+        """
+
+        keypoints = results['keypoints']
+        if 'keypoints_visible' in results:
+            keypoints_visible = results['keypoints_visible']
+        else:
+            keypoints_visible = np.ones(keypoints.shape[:-1], dtype=np.float32)
+        target = results['target']
+        if 'target_visible' in results:
+            target_visible = results['target_visible']
+        else:
+            target_visible = np.ones(target.shape[:-1], dtype=np.float32)
+
+        if np.random.rand() <= self.flip_prob:
+            if 'flip_indices' not in results:
+                flip_indices = list(range(keypoints.shape[-2]))
+            else:
+                flip_indices = results['flip_indices']
+
+            # flip joint coordinates
+            keypoints, keypoints_visible = flip_keypoints_custom_center(
+                keypoints, keypoints_visible, flip_indices,
+                **self.keypoints_flip_cfg)
+            target, target_visible = flip_keypoints_custom_center(
+                target, target_visible, flip_indices, **self.target_flip_cfg)
+
+            results['keypoints'] = keypoints
+            results['keypoints_visible'] = keypoints_visible
+            results['target'] = target
+            results['target_visible'] = target_visible
+
+            # flip horizontal distortion coefficients
+            if self.flip_camera:
+                assert 'camera_param' in results, \
+                    'Camera parameters are missing.'
+                _camera_param = deepcopy(results['camera_param'])
+
+                assert 'c' in _camera_param
+                _camera_param['c'][0] *= -1
+
+                if 'p' in _camera_param:
+                    _camera_param['p'][0] *= -1
+
+                results['camera_param'].update(_camera_param)
+
+        return results
diff --git a/mmpose/structures/keypoint/__init__.py b/mmpose/structures/keypoint/__init__.py
index b8d5a24c7a..12ee96cf7c 100644
--- a/mmpose/structures/keypoint/__init__.py
+++ b/mmpose/structures/keypoint/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .transforms import flip_keypoints
+from .transforms import flip_keypoints, flip_keypoints_custom_center
 
-__all__ = ['flip_keypoints']
+__all__ = ['flip_keypoints', 'flip_keypoints_custom_center']
diff --git a/mmpose/structures/keypoint/transforms.py b/mmpose/structures/keypoint/transforms.py
index 99adaa1306..b50da4f8fe 100644
--- a/mmpose/structures/keypoint/transforms.py
+++ b/mmpose/structures/keypoint/transforms.py
@@ -62,3 +62,60 @@ def flip_keypoints(keypoints: np.ndarray,
         keypoints = [w, h] - keypoints - 1
 
     return keypoints, keypoints_visible
+
+
+def flip_keypoints_custom_center(keypoints: np.ndarray,
+                                 keypoints_visible: np.ndarray,
+                                 flip_indices: List[int],
+                                 center_mode: str = 'static',
+                                 center_x: float = 0.5,
+                                 center_index: int = 0):
+    """Flip human joints horizontally.
+
+    Note:
+        - num_keypoint: K
+        - dimension: D
+
+    Args:
+        keypoints (np.ndarray([..., K, D])): Coordinates of keypoints.
+        keypoints_visible (np.ndarray([..., K])): Visibility of keypoints.
+        flip_indices (list[int]): The indices to flip the keypoints.
+        center_mode (str): The mode to set the center location on the x-axis
+            to flip around. Options are:
+
+            - static: use a static x value (see center_x also)
+            - root: use a root joint (see center_index also)
+
+            Default: ``'static'``.
+        center_x (float): Set the x-axis location of the flip center. Only used
+            when ``center_mode`` is ``'static'``. Default: 0.5.
+        center_index (int): Set the index of the root joint, whose x location
+            will be used as the flip center. Only used when ``center_mode`` is
+            ``'root'``. Default: 0.
+
+    Returns:
+        tuple: The flipped joints and their visibilities.
+
+        - keypoints_flipped (np.ndarray([..., K, D])): Flipped joint
+          coordinates.
+        - keypoints_visible_flipped (np.ndarray([..., K])): Flipped joint
+          visibilities.
+    """
+
+    assert keypoints.ndim >= 2, f'Invalid pose shape {keypoints.shape}'
+
+    allowed_center_mode = {'static', 'root'}
+    assert center_mode in allowed_center_mode, 'Got invalid center_mode ' \
+        f'{center_mode}, allowed choices are {allowed_center_mode}'
+
+    if center_mode == 'static':
+        x_c = center_x
+    elif center_mode == 'root':
+        assert keypoints.shape[-2] > center_index
+        x_c = keypoints[..., center_index, 0]
+
+    keypoints_flipped = keypoints.copy()
+    keypoints_visible_flipped = keypoints_visible.copy()
+    # Swap left-right parts
+    for left, right in enumerate(flip_indices):
+        keypoints_flipped[..., left, :] = keypoints[..., right, :]
+        keypoints_visible_flipped[..., left] = keypoints_visible[..., right]
+
+    # Flip horizontally
+    keypoints_flipped[..., 0] = x_c * 2 - keypoints_flipped[..., 0]
+    return keypoints_flipped, keypoints_visible_flipped
diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py
new file mode 100644
index 0000000000..c54bf12d1e
--- /dev/null
+++ b/tests/test_codecs/test_image_pose_lifting.py
@@ -0,0 +1,154 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import numpy as np
+
+from mmpose.codecs import ImagePoseLifting
+from mmpose.registry import KEYPOINT_CODECS
+
+
+class TestImagePoseLifting(TestCase):
+
+    def setUp(self) -> None:
+        keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256]
+        keypoints = np.round(keypoints).astype(np.float32)
+        keypoints_visible = np.random.randint(2, size=(1, 17))
+        target = (0.1 + 0.8 * np.random.rand(17, 3))
+        target_visible = np.random.randint(2, size=(17, ))
+        encoded_wo_sigma = np.random.rand(1, 17, 3)
+
+        self.keypoints_mean = np.random.rand(17, 2).astype(np.float32)
+        self.keypoints_std = np.random.rand(17, 2).astype(np.float32) + 1e-6
+        self.target_mean = np.random.rand(17, 3).astype(np.float32)
+        self.target_std = np.random.rand(17, 3).astype(np.float32) + 1e-6
+
+        self.data = dict(
+            keypoints=keypoints,
+            keypoints_visible=keypoints_visible,
+            target=target,
+            target_visible=target_visible,
+            encoded_wo_sigma=encoded_wo_sigma)
+
+    def build_pose_lifting_label(self, **kwargs):
+        cfg = dict(type='ImagePoseLifting', num_keypoints=17, root_index=0)
+        cfg.update(kwargs)
+        return KEYPOINT_CODECS.build(cfg)
+
+    def test_build(self):
+        codec = self.build_pose_lifting_label()
+        self.assertIsInstance(codec, ImagePoseLifting)
+
+    def test_encode(self):
+        keypoints = self.data['keypoints']
+        keypoints_visible = self.data['keypoints_visible']
+        target = self.data['target']
+        target_visible = self.data['target_visible']
+
+        # test default settings
+        codec = self.build_pose_lifting_label()
+        encoded = codec.encode(keypoints, keypoints_visible, target,
+                               target_visible)
+
+        self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1))
+        self.assertEqual(encoded['target_label'].shape, (17, 3))
+        self.assertEqual(encoded['target_weights'].shape, (17, ))
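+        # Sanity check (assumes the codec's default contract): with
+        # root_index=0 and no mean/std normalization, the stored root is
+        # the raw first target joint.
+        self.assertTrue(np.allclose(encoded['target_root'], target[0]))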
self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test removing root + codec = self.build_pose_lifting_label( + remove_root=True, save_index=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + self.assertTrue('target_root_removed' in encoded + and 'target_root_index' in encoded) + self.assertEqual(encoded['target_weights'].shape, (16, )) + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (16, 3)) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test normalization + codec = self.build_pose_lifting_label( + keypoints_mean=self.keypoints_mean, + keypoints_std=self.keypoints_std, + target_mean=self.target_mean, + target_std=self.target_std) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (17, 3)) + + def test_decode(self): + target = self.data['target'] + encoded_wo_sigma = self.data['encoded_wo_sigma'] + + codec = self.build_pose_lifting_label() + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 17, 3)) + self.assertEqual(scores.shape, (1, 17)) + + codec = self.build_pose_lifting_label(remove_root=True) + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 18, 3)) + self.assertEqual(scores.shape, (1, 18)) + + def test_cicular_verification(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + target = self.data['target'] + target_visible = self.data['target_visible'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + + # test removing root + codec = self.build_pose_lifting_label(remove_root=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + + # test normalization + codec = self.build_pose_lifting_label( + keypoints_mean=self.keypoints_mean, + keypoints_std=self.keypoints_std, + target_mean=self.target_mean, + target_std=self.target_std) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py new file mode 100644 index 0000000000..da404360bb --- /dev/null +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +from unittest import TestCase + +import numpy as np +from mmengine.fileio import load + +from mmpose.codecs import VideoPoseLifting +from mmpose.registry import KEYPOINT_CODECS + + +class TestVideoPoseLifting(TestCase): + + def get_camera_param(self, imgname, camera_param) -> dict: + """Get camera parameters of a frame by its image name.""" + subj, rest = osp.basename(imgname).split('_', 1) + action, rest = rest.split('.', 1) + camera, rest = rest.split('_', 1) + return camera_param[(subj, camera)] + + def build_pose_lifting_label(self, **kwargs): + cfg = dict(type='VideoPoseLifting', num_keypoints=17) + cfg.update(kwargs) + return KEYPOINT_CODECS.build(cfg) + + def setUp(self) -> None: + keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256] + keypoints = np.round(keypoints).astype(np.float32) + keypoints_visible = np.random.randint(2, size=(1, 17)) + target = (0.1 + 0.8 * np.random.rand(17, 3)) + target_visible = np.random.randint(2, size=(17, )) + encoded_wo_sigma = np.random.rand(1, 17, 3) + + camera_param = load('tests/data/h36m/cameras.pkl') + camera_param = self.get_camera_param( + 'S1/S1_Directions_1.54138969/S1_Directions_1.54138969_000001.jpg', + camera_param) + + self.data = dict( + keypoints=keypoints, + keypoints_visible=keypoints_visible, + target=target, + target_visible=target_visible, + camera_param=camera_param, + encoded_wo_sigma=encoded_wo_sigma) + + def test_build(self): + codec = self.build_pose_lifting_label() + self.assertIsInstance(codec, VideoPoseLifting) + + def test_encode(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + target = self.data['target'] + target_visible = self.data['target_visible'] + camera_param = self.data['camera_param'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (17, 3)) + self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test not zero-centering + codec = self.build_pose_lifting_label(zero_center=False) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (17, 3)) + self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + + # test removing root + codec = self.build_pose_lifting_label( + remove_root=True, save_index=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertTrue('target_root_removed' in encoded + and 'target_root_index' in encoded) + self.assertEqual(encoded['target_weights'].shape, (16, )) + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (16, 3)) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test normalizing camera + codec = self.build_pose_lifting_label(normalize_camera=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertTrue('camera_param' in encoded) + scale = np.array(0.5 * camera_param['w'], dtype=np.float32) + self.assertTrue( + np.allclose( + 
camera_param['f'] / scale, + encoded['camera_param']['f'], + atol=4.)) + + def test_decode(self): + target = self.data['target'] + encoded_wo_sigma = self.data['encoded_wo_sigma'] + + codec = self.build_pose_lifting_label() + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 17, 3)) + self.assertEqual(scores.shape, (1, 17)) + + codec = self.build_pose_lifting_label(remove_root=True) + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 18, 3)) + self.assertEqual(scores.shape, (1, 18)) + + def test_cicular_verification(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + target = self.data['target'] + target_visible = self.data['target_visible'] + camera_param = self.data['camera_param'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + + # test removing root + codec = self.build_pose_lifting_label(remove_root=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) diff --git a/tests/test_datasets/test_transforms/test_pose3d_transforms.py b/tests/test_datasets/test_transforms/test_pose3d_transforms.py new file mode 100644 index 0000000000..16118db272 --- /dev/null +++ b/tests/test_datasets/test_transforms/test_pose3d_transforms.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from copy import deepcopy +from unittest import TestCase + +import numpy as np +from mmengine.fileio import load + +from mmpose.datasets.transforms import RandomFlipAroundRoot + + +def get_h36m_sample(): + + def _parse_h36m_imgname(imgname): + """Parse imgname to get information of subject, action and camera. 
+
+        A typical h36m image filename is like:
+        S1_Directions_1.54138969_000001.jpg
+        """
+        subj, rest = osp.basename(imgname).split('_', 1)
+        action, rest = rest.split('.', 1)
+        camera, rest = rest.split('_', 1)
+        return subj, action, camera
+
+    ann_file = 'tests/data/h36m/test_h36m_body3d.npz'
+    camera_param_file = 'tests/data/h36m/cameras.pkl'
+
+    data = np.load(ann_file)
+    cameras = load(camera_param_file)
+
+    imgnames = data['imgname']
+    keypoints = data['part'].astype(np.float32)
+    keypoints_3d = data['S'].astype(np.float32)
+    centers = data['center'].astype(np.float32)
+    scales = data['scale'].astype(np.float32)
+
+    idx = 0
+    target_idx = 0
+
+    data_info = {
+        'keypoints': keypoints[idx, :, :2].reshape(1, -1, 2),
+        'keypoints_visible': keypoints[idx, :, 2].reshape(1, -1),
+        'keypoints_3d': keypoints_3d[idx, :, :3].reshape(1, -1, 3),
+        'keypoints_3d_visible': keypoints_3d[idx, :, 3].reshape(1, -1),
+        'scale': scales[idx],
+        'center': centers[idx].astype(np.float32).reshape(1, -1),
+        'id': idx,
+        'img_ids': [idx],
+        'img_paths': [imgnames[idx]],
+        'category_id': 1,
+        'iscrowd': 0,
+        'sample_idx': idx,
+        'target': keypoints_3d[target_idx, :, :3],
+        'target_visible': keypoints_3d[target_idx, :, 3],
+        'target_img_path': osp.join('tests/data/h36m', imgnames[target_idx]),
+    }
+
+    # add camera parameters
+    subj, _, camera = _parse_h36m_imgname(imgnames[idx])
+    data_info['camera_param'] = cameras[(subj, camera)]
+
+    # add ann_info
+    ann_info = {}
+    ann_info['num_keypoints'] = 17
+    ann_info['dataset_keypoint_weights'] = np.full(17, 1.0, dtype=np.float32)
+    ann_info['flip_pairs'] = [[1, 4], [2, 5], [3, 6], [11, 14], [12, 15],
+                              [13, 16]]
+    ann_info['skeleton_links'] = []
+    ann_info['upper_body_ids'] = (0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
+    ann_info['lower_body_ids'] = (1, 2, 3, 4, 5, 6)
+    ann_info['flip_indices'] = [
+        0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13
+    ]
+
+    data_info.update(ann_info)
+
+    return data_info
+
+
+class TestRandomFlipAroundRoot(TestCase):
+
+    def setUp(self):
+        self.data_info = get_h36m_sample()
+        self.keypoints_flip_cfg = dict(center_mode='static', center_x=0.)
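+        # (The 2D keypoints above are mirrored about the static line x=0,
+        # while the 3D target below is mirrored about the x-coordinate of
+        # its root joint.)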
+ self.target_flip_cfg = dict(center_mode='root', center_index=0) + + def test_init(self): + _ = RandomFlipAroundRoot( + self.keypoints_flip_cfg, + self.target_flip_cfg, + flip_prob=0.5, + flip_camera=False) + + def test_transform(self): + kpts1 = self.data_info['keypoints'] + kpts_vis1 = self.data_info['keypoints_visible'] + tar1 = self.data_info['target'] + tar_vis1 = self.data_info['target_visible'] + + transform = RandomFlipAroundRoot( + self.keypoints_flip_cfg, self.target_flip_cfg, flip_prob=1) + results = deepcopy(self.data_info) + results = transform(results) + + kpts2 = results['keypoints'] + kpts_vis2 = results['keypoints_visible'] + tar2 = results['target'] + tar_vis2 = results['target_visible'] + + self.assertEqual(kpts_vis2.shape, (1, 17)) + self.assertEqual(tar_vis2.shape, (17, )) + self.assertEqual(kpts2.shape, (1, 17, 2)) + self.assertEqual(tar2.shape, (17, 3)) + + flip_indices = [ + 0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13 + ] + for left, right in enumerate(flip_indices): + self.assertTrue( + np.allclose(-kpts1[0][left][:1], kpts2[0][right][:1], atol=4.)) + self.assertTrue( + np.allclose(kpts1[0][left][1:], kpts2[0][right][1:], atol=4.)) + self.assertTrue( + np.allclose(tar1[left][1:], tar2[right][1:], atol=4.)) + + self.assertTrue( + np.allclose(kpts_vis1[0][left], kpts_vis2[0][right], atol=4.)) + self.assertTrue( + np.allclose(tar_vis1[left], tar_vis2[right], atol=4.)) + + # test camera flipping + transform = RandomFlipAroundRoot( + self.keypoints_flip_cfg, + self.target_flip_cfg, + flip_prob=1, + flip_camera=True) + results = deepcopy(self.data_info) + results = transform(results) + + camera2 = results['camera_param'] + self.assertTrue( + np.allclose( + -self.data_info['camera_param']['c'][0], + camera2['c'][0], + atol=4.)) + self.assertTrue( + np.allclose( + -self.data_info['camera_param']['p'][0], + camera2['p'][0], + atol=4.)) From 20ef71465a4f9fd225a034cc9ddd218a63e732f5 Mon Sep 17 00:00:00 2001 From: Tau Date: Tue, 25 Apr 2023 19:49:25 +0800 Subject: [PATCH 3/8] [Refactor] Add PoseLifter, TemporalRegressionHead, TrajectoryRegressionHead (#2311) --- mmpose/codecs/image_pose_lifting.py | 4 +- mmpose/codecs/video_pose_lifting.py | 4 +- mmpose/models/heads/__init__.py | 6 +- .../models/heads/regression_heads/__init__.py | 4 + .../temporal_regression_head.py | 134 +++++++ .../trajectory_regression_head.py | 134 +++++++ mmpose/models/pose_estimators/__init__.py | 3 +- mmpose/models/pose_estimators/pose_lifter.py | 332 ++++++++++++++++++ tests/test_codecs/test_image_pose_lifting.py | 11 +- tests/test_codecs/test_video_pose_lifting.py | 10 +- 10 files changed, 616 insertions(+), 26 deletions(-) create mode 100644 mmpose/models/heads/regression_heads/temporal_regression_head.py create mode 100644 mmpose/models/heads/regression_heads/trajectory_regression_head.py create mode 100644 mmpose/models/pose_estimators/pose_lifter.py diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 93530cf15d..bf5b009c5d 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -169,7 +169,6 @@ def encode(self, def decode(self, encoded: np.ndarray, - restore_global_position: bool = False, target_root: Optional[np.ndarray] = None ) -> Tuple[np.ndarray, np.ndarray]: """Decode keypoint coordinates from normalized space to input image @@ -192,8 +191,7 @@ def decode(self, assert self.target_mean.shape == keypoints.shape[1:] keypoints = keypoints * self.target_std + self.target_mean - if 
restore_global_position:
-            assert target_root is not None
+        if target_root is not None:
             keypoints = keypoints + np.expand_dims(target_root, axis=0)
             if self.remove_root:
                 keypoints = np.insert(
diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py
index fbb6ad429c..ceda999f85 100644
--- a/mmpose/codecs/video_pose_lifting.py
+++ b/mmpose/codecs/video_pose_lifting.py
@@ -171,7 +171,6 @@ def encode(self,
 
     def decode(self,
                encoded: np.ndarray,
-               restore_global_position: bool = False,
                target_root: Optional[np.ndarray] = None
                ) -> Tuple[np.ndarray, np.ndarray]:
         """Decode keypoint coordinates from normalized space to input image
@@ -190,8 +189,7 @@ def decode(self,
         """
         keypoints = encoded.copy()
 
-        if restore_global_position:
-            assert target_root is not None
+        if target_root is not None:
             keypoints = keypoints + np.expand_dims(target_root, axis=0)
             if self.remove_root:
                 keypoints = np.insert(
diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py
index 8b4d988a5f..75a626569b 100644
--- a/mmpose/models/heads/__init__.py
+++ b/mmpose/models/heads/__init__.py
@@ -5,10 +5,12 @@
                             HeatmapHead, MSPNHead, ViPNASHead)
 from .hybrid_heads import DEKRHead
 from .regression_heads import (DSNTHead, IntegralRegressionHead,
-                               RegressionHead, RLEHead)
+                               RegressionHead, RLEHead, TemporalRegressionHead,
+                               TrajectoryRegressionHead)
 
 __all__ = [
     'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead',
     'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead',
-    'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'CIDHead', 'RTMCCHead'
+    'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'CIDHead', 'RTMCCHead',
+    'TemporalRegressionHead', 'TrajectoryRegressionHead'
 ]
diff --git a/mmpose/models/heads/regression_heads/__init__.py b/mmpose/models/heads/regression_heads/__init__.py
index f2a5027b1b..ce9cd5e1b0 100644
--- a/mmpose/models/heads/regression_heads/__init__.py
+++ b/mmpose/models/heads/regression_heads/__init__.py
@@ -3,10 +3,14 @@
 from .integral_regression_head import IntegralRegressionHead
 from .regression_head import RegressionHead
 from .rle_head import RLEHead
+from .temporal_regression_head import TemporalRegressionHead
+from .trajectory_regression_head import TrajectoryRegressionHead
 
 __all__ = [
     'RegressionHead',
     'IntegralRegressionHead',
     'DSNTHead',
     'RLEHead',
+    'TemporalRegressionHead',
+    'TrajectoryRegressionHead',
 ]
diff --git a/mmpose/models/heads/regression_heads/temporal_regression_head.py b/mmpose/models/heads/regression_heads/temporal_regression_head.py
new file mode 100644
index 0000000000..5bc019f948
--- /dev/null
+++ b/mmpose/models/heads/regression_heads/temporal_regression_head.py
@@ -0,0 +1,134 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import keypoint_pck_accuracy
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
+                                 Predictions)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class TemporalRegressionHead(BaseHead):
+    """Temporal Regression head of `VideoPose3D`_ by Pavllo et al. (CVPR'2019).
+
+    Args:
+        in_channels (int | sequence[int]): Number of input channels
+        num_joints (int): Number of joints
+        loss (Config): Config for keypoint loss. Defaults to use
+            :class:`MSELoss`
+        decoder (Config, optional): The decoder config that controls decoding
+            keypoint coordinates from the network output. Defaults to ``None``
+        init_cfg (Config, optional): Config to control the initialization. See
+            :attr:`default_init_cfg` for default settings
+
+    .. _`VideoPose3D`: https://arxiv.org/abs/1811.11742
+    """
+
+    _version = 2
+
+    def __init__(self,
+                 in_channels: Union[int, Sequence[int]],
+                 num_joints: int,
+                 loss: ConfigType = dict(
+                     type='MSELoss', use_target_weight=True),
+                 decoder: OptConfigType = None,
+                 init_cfg: OptConfigType = None):
+
+        if init_cfg is None:
+            init_cfg = self.default_init_cfg
+
+        super().__init__(init_cfg)
+
+        self.in_channels = in_channels
+        self.num_joints = num_joints
+        self.loss_module = MODELS.build(loss)
+        if decoder is not None:
+            self.decoder = KEYPOINT_CODECS.build(decoder)
+        else:
+            self.decoder = None
+
+        # Define the prediction layer: a 1x1 Conv1d mapping features to
+        # joint coordinates
+        self.conv = nn.Conv1d(in_channels, self.num_joints * 3, 1)
+
+    def forward(self, feats: Tuple[Tensor]) -> Tensor:
+        """Forward the network. The input is multi scale feature maps and the
+        output is the coordinates.
+
+        Args:
+            feats (Tuple[Tensor]): Multi scale feature maps.
+
+        Returns:
+            Tensor: Output coordinates (and sigmas, optional).
+        """
+        x = feats[-1]
+
+        x = self.conv(x)
+
+        return x.reshape(-1, self.num_joints, 3)
+
+    def predict(self,
+                feats: Tuple[Tensor],
+                batch_data_samples: OptSampleList,
+                test_cfg: ConfigType = {}) -> Predictions:
+        """Predict results from outputs."""
+
+        batch_coords = self.forward(feats)  # (B, K, D)
+
+        batch_coords.unsqueeze_(dim=1)  # (B, N, K, D)
+
+        # Restore global position with target_root
+        target_root = batch_data_samples[0].metainfo.get('target_root', None)
+        if target_root is not None:
+            target_root = torch.stack([
+                torch.from_numpy(d.metainfo['target_root'])
+                for d in batch_data_samples
+            ])
+
+        preds = self.decode(batch_coords, target_root)
+
+        return preds
+
+    def loss(self,
+             inputs: Tuple[Tensor],
+             batch_data_samples: OptSampleList,
+             train_cfg: ConfigType = {}) -> dict:
+        """Calculate losses from a batch of inputs and data samples."""
+
+        pred_outputs = self.forward(inputs)
+
+        keypoint_labels = torch.cat(
+            [d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
+        keypoint_weights = torch.cat([
+            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
+        ])
+
+        # calculate losses
+        losses = dict()
+        loss = self.loss_module(pred_outputs, keypoint_labels,
+                                keypoint_weights.unsqueeze(-1))
+
+        losses.update(loss_pose3d=loss)
+
+        # calculate accuracy
+        _, avg_acc, _ = keypoint_pck_accuracy(
+            pred=to_numpy(pred_outputs),
+            gt=to_numpy(keypoint_labels),
+            mask=to_numpy(keypoint_weights) > 0,
+            thr=0.05,
+            norm_factor=np.ones((pred_outputs.size(0), 2), dtype=np.float32))
+
+        mpjpe_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
+        losses.update(mpjpe=mpjpe_pose)
+
+        return losses
+
+    @property
+    def default_init_cfg(self):
+        init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
+        return init_cfg
diff --git a/mmpose/models/heads/regression_heads/trajectory_regression_head.py b/mmpose/models/heads/regression_heads/trajectory_regression_head.py
new file mode 100644
index 0000000000..76626fb3dc
--- /dev/null
+++ b/mmpose/models/heads/regression_heads/trajectory_regression_head.py
@@ -0,0 +1,134 @@
+# Copyright (c) OpenMMLab. All rights reserved.
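+# NOTE: This head mirrors `TemporalRegressionHead` but regresses the global
+# root trajectory and is supervised with a trajectory loss (MPJPELoss by
+# default).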
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from mmpose.evaluation.functional import keypoint_pck_accuracy
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
+                                 Predictions)
+from ..base_head import BaseHead
+
+OptIntSeq = Optional[Sequence[int]]
+
+
+@MODELS.register_module()
+class TrajectoryRegressionHead(BaseHead):
+    """Trajectory Regression head of `VideoPose3D`_ by Pavllo et al.
+    (CVPR'2019).
+
+    Args:
+        in_channels (int | sequence[int]): Number of input channels
+        num_joints (int): Number of joints
+        loss (Config): Config for trajectory loss. Defaults to use
+            :class:`MPJPELoss`
+        decoder (Config, optional): The decoder config that controls decoding
+            keypoint coordinates from the network output. Defaults to ``None``
+        init_cfg (Config, optional): Config to control the initialization. See
+            :attr:`default_init_cfg` for default settings
+
+    .. _`VideoPose3D`: https://arxiv.org/abs/1811.11742
+    """
+
+    _version = 2
+
+    def __init__(self,
+                 in_channels: Union[int, Sequence[int]],
+                 num_joints: int,
+                 loss: ConfigType = dict(
+                     type='MPJPELoss', use_target_weight=True),
+                 decoder: OptConfigType = None,
+                 init_cfg: OptConfigType = None):
+
+        if init_cfg is None:
+            init_cfg = self.default_init_cfg
+
+        super().__init__(init_cfg)
+
+        self.in_channels = in_channels
+        self.num_joints = num_joints
+        self.loss_module = MODELS.build(loss)
+        if decoder is not None:
+            self.decoder = KEYPOINT_CODECS.build(decoder)
+        else:
+            self.decoder = None
+
+        # Define the prediction layer: a 1x1 Conv1d mapping features to
+        # joint coordinates
+        self.conv = nn.Conv1d(in_channels, self.num_joints * 3, 1)
+
+    def forward(self, feats: Tuple[Tensor]) -> Tensor:
+        """Forward the network. The input is multi scale feature maps and the
+        output is the coordinates.
+
+        Args:
+            feats (Tuple[Tensor]): Multi scale feature maps.
+
+        Returns:
+            Tensor: Output coordinates (and sigmas, optional).
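+
+        Note (a shape sketch, assuming the last feature map has shape
+        ``(B, in_channels, T)``): the 1x1 ``Conv1d`` yields
+        ``(B, K * 3, T)``, which is reshaped to ``(B * T, K, 3)``; with the
+        temporal dimension reduced to ``T = 1`` this is simply ``(B, K, 3)``.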
+ """ + x = feats[-1] + + x = self.conv(x) + + return x.reshape(-1, self.num_joints, 3) + + def predict(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from outputs.""" + + batch_coords = self.forward(feats) # (B, K, D) + + batch_coords.unsqueeze_(dim=1) # (B, N, K, D) + + # Restore global position with target_root + target_root = batch_data_samples[0].metainfo.get('target_root', None) + if target_root is not None: + target_root = torch.stack( + [m['target_root'] for m in batch_data_samples[0].metainfo]) + + preds = self.decode(batch_coords, target_root) + + return preds + + def loss(self, + inputs: Union[Tensor, Tuple[Tensor]], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_outputs = self.forward(inputs) + + keypoint_labels = torch.cat( + [d.gt_instance_labels.keypoint_labels for d in batch_data_samples]) + trjectory_weights = torch.cat([ + d.gt_instance_labels.trjectory_weights for d in batch_data_samples + ]) + + # calculate losses + losses = dict() + loss = self.loss_module(pred_outputs, keypoint_labels, + trjectory_weights.unsqueeze(-1)) + + losses.update(loss_traj=loss) + + # calculate accuracy + _, avg_acc, _ = keypoint_pck_accuracy( + pred=to_numpy(pred_outputs), + gt=to_numpy(keypoint_labels), + mask=to_numpy(trjectory_weights) > 0, + thr=0.05, + norm_factor=np.ones((pred_outputs.size(0), 2), dtype=np.float32)) + + mpjpe_traj = torch.tensor(avg_acc, device=keypoint_labels.device) + losses.update(mpjpe_traj=mpjpe_traj) + + return losses + + @property + def default_init_cfg(self): + init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + return init_cfg diff --git a/mmpose/models/pose_estimators/__init__.py b/mmpose/models/pose_estimators/__init__.py index 6ead1a979e..c5287e0c2c 100644 --- a/mmpose/models/pose_estimators/__init__.py +++ b/mmpose/models/pose_estimators/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .bottomup import BottomupPoseEstimator +from .pose_lifter import PoseLifter from .topdown import TopdownPoseEstimator -__all__ = ['TopdownPoseEstimator', 'BottomupPoseEstimator'] +__all__ = ['TopdownPoseEstimator', 'BottomupPoseEstimator', 'PoseLifter'] diff --git a/mmpose/models/pose_estimators/pose_lifter.py b/mmpose/models/pose_estimators/pose_lifter.py new file mode 100644 index 0000000000..5b0abf3690 --- /dev/null +++ b/mmpose/models/pose_estimators/pose_lifter.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import zip_longest +from typing import Tuple, Union + +from torch import Tensor + +from mmpose.models.utils import check_and_update_config +from mmpose.registry import MODELS +from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, + Optional, OptMultiConfig, OptSampleList, + PixelDataList, SampleList) +from .base import BasePoseEstimator + + +@MODELS.register_module() +class PoseLifter(BasePoseEstimator): + """Base class for pose lifter. + + Args: + backbone (dict): The backbone config + neck (dict, optional): The neck config. Defaults to ``None`` + head (dict, optional): The head config. Defaults to ``None`` + traj_backbone (dict, optional): The backbone config for trajectory + model. Defaults to ``None`` + traj_neck (dict, optional): The neck config for trajectory model. + Defaults to ``None`` + traj_head (dict, optional): The head config for trajectory model. 
+            Defaults to ``None``
+        semi_loss (dict, optional): The semi-supervised loss config.
+            Defaults to ``None``
+        train_cfg (dict, optional): The runtime config for training process.
+            Defaults to ``None``
+        test_cfg (dict, optional): The runtime config for testing process.
+            Defaults to ``None``
+        data_preprocessor (dict, optional): The data preprocessing config to
+            build the instance of :class:`BaseDataPreprocessor`. Defaults to
+            ``None``
+        init_cfg (dict, optional): The config to control the initialization.
+            Defaults to ``None``
+        metainfo (dict): Meta information for dataset, such as keypoints
+            definition and properties. If set, the metainfo of the input data
+            batch will be overridden. For more details, please refer to
+            https://mmpose.readthedocs.io/en/latest/user_guides/
+            prepare_datasets.html#create-a-custom-dataset-info-
+            config-file-for-the-dataset. Defaults to ``None``
+    """
+
+    def __init__(self,
+                 backbone: ConfigType,
+                 neck: OptConfigType = None,
+                 head: OptConfigType = None,
+                 traj_backbone: OptConfigType = None,
+                 traj_neck: OptConfigType = None,
+                 traj_head: OptConfigType = None,
+                 semi_loss: OptConfigType = None,
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 data_preprocessor: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None,
+                 metainfo: Optional[dict] = None):
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg,
+            metainfo=metainfo)
+
+        # trajectory model
+        self.share_backbone = False
+        if traj_head is not None:
+            if traj_backbone is not None:
+                self.traj_backbone = MODELS.build(traj_backbone)
+            else:
+                self.share_backbone = True
+
+            # the PR #2108 and #2126 modified the interface of neck and head.
+            # The following function automatically detects outdated
+            # configurations and updates them accordingly, while also providing
+            # clear and concise information on the changes made.
+            traj_neck, traj_head = check_and_update_config(
+                traj_neck, traj_head)
+
+            if traj_neck is not None:
+                self.traj_neck = MODELS.build(traj_neck)
+
+            self.traj_head = MODELS.build(traj_head)
+
+        # semi-supervised loss
+        self.semi_supervised = semi_loss is not None
+        if self.semi_supervised:
+            assert any([head, traj_head])
+            self.semi_loss = MODELS.build(semi_loss)
+
+    @property
+    def with_traj_backbone(self):
+        """bool: Whether the pose lifter has trajectory backbone."""
+        return hasattr(self, 'traj_backbone') and \
+            self.traj_backbone is not None
+
+    @property
+    def with_traj_neck(self):
+        """bool: Whether the pose lifter has trajectory neck."""
+        return hasattr(self, 'traj_neck') and self.traj_neck is not None
+
+    @property
+    def with_traj(self):
+        """bool: Whether the pose lifter has trajectory head."""
+        return hasattr(self, 'traj_head')
+
+    @property
+    def causal(self):
+        """bool: Whether the pose lifter is causal."""
+        if hasattr(self.backbone, 'causal'):
+            return self.backbone.causal
+        else:
+            raise AttributeError('A PoseLifter\'s backbone should have '
+                                 'the bool attribute "causal" to indicate '
+                                 'if it performs causal inference.')
+
+    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
+        """Extract features.
+
+        Args:
+            inputs (Tensor): Input keypoint sequence tensor with shape
+                (N, K, C, T).
+
+        Returns:
+            tuple[Tensor]: Multi-level features that may have various
+            resolutions.
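+            When a trajectory head is configured, a tuple ``(x, traj_x)``
+            of pose features and trajectory features is returned instead.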
+ """ + # supervised learning + # pose model + feats = self.backbone(inputs) + if self.with_neck: + x = self.neck(feats) + + # trajectory model + if self.with_traj: + if self.share_backbone: + traj_x = feats + else: + traj_x = self.traj_backbone(inputs) + + if self.with_traj_neck: + traj_x = self.traj_neck(traj_x) + return x, traj_x + else: + return x + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None + ) -> Union[Tensor, Tuple[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + inputs (Tensor): Inputs with shape (N, K, C, T). + + Returns: + Union[Tensor | Tuple[Tensor]]: forward output of the network. + """ + feats = self.extract_feat(inputs) + + if self.with_traj: + # forward with trajectory model + x, traj_x = feats + if self.with_head: + x = self.head.forward(x) + + traj_x = self.traj_head.forward(traj_x) + return x, traj_x + else: + # forward without trajectory model + x = feats + if self.with_head: + x = self.head.forward(x) + return x + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Inputs with shape (N, K, C, T). + data_samples (List[:obj:`PoseDataSample`]): The batch + data samples. + + Returns: + dict: A dictionary of losses. + """ + feats = self.extract_feat(inputs) + + losses = {} + + if self.with_traj: + x, traj_x = feats + # loss of trajectory model + losses.update( + self.traj_head.loss( + traj_x, data_samples, train_cfg=self.train_cfg)) + else: + x = feats + + if self.with_head: + # loss of pose model + losses.update( + self.head.loss(x, data_samples, train_cfg=self.train_cfg)) + + # TODO: support semi-supervised learning + if self.semi_supervised: + losses.update(semi_loss=self.semi_loss(inputs, data_samples)) + + return losses + + def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, K, C, T). + data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + + Returns: + list[:obj:`PoseDataSample`]: The pose estimation results of the + input images. 
+            The return value is ``PoseDataSample`` instances with
+            ``pred_instances`` and ``pred_fields`` (optional) fields, and
+            ``pred_instances`` usually contains the following keys:
+
+                - keypoints (Tensor): predicted keypoint coordinates in shape
+                    (num_instances, K, D) where K is the keypoint number and D
+                    is the keypoint dimension
+                - keypoint_scores (Tensor): predicted keypoint scores in shape
+                    (num_instances, K)
+        """
+        assert self.with_head, (
+            'The model must have head to perform prediction.')
+
+        feats = self.extract_feat(inputs)
+
+        pose_preds, batch_pred_instances, batch_pred_fields = None, None, None
+        traj_preds, batch_traj_instances, batch_traj_fields = None, None, None
+        if self.with_traj:
+            x, traj_x = feats
+            traj_preds = self.traj_head.predict(
+                traj_x, data_samples, test_cfg=self.test_cfg)
+        else:
+            x = feats
+
+        if self.with_head:
+            pose_preds = self.head.predict(
+                x, data_samples, test_cfg=self.test_cfg)
+
+        if isinstance(pose_preds, tuple):
+            batch_pred_instances, batch_pred_fields = pose_preds
+        else:
+            batch_pred_instances = pose_preds
+
+        if isinstance(traj_preds, tuple):
+            batch_traj_instances, batch_traj_fields = traj_preds
+        else:
+            batch_traj_instances = traj_preds
+
+        results = self.add_pred_to_datasample(batch_pred_instances,
+                                              batch_pred_fields,
+                                              batch_traj_instances,
+                                              batch_traj_fields, data_samples)
+
+        return results
+
+    def add_pred_to_datasample(
+        self,
+        batch_pred_instances: InstanceList,
+        batch_pred_fields: Optional[PixelDataList],
+        batch_traj_instances: InstanceList,
+        batch_traj_fields: Optional[PixelDataList],
+        batch_data_samples: SampleList,
+    ) -> SampleList:
+        """Add predictions into data samples.
+
+        Args:
+            batch_pred_instances (List[InstanceData]): The predicted instances
+                of the input data batch
+            batch_pred_fields (List[PixelData], optional): The predicted
+                fields (e.g. heatmaps) of the input batch
+            batch_traj_instances (List[InstanceData]): The predicted instances
+                of the input data batch
+            batch_traj_fields (List[PixelData], optional): The predicted
+                fields (e.g. heatmaps) of the input batch
+            batch_data_samples (List[PoseDataSample]): The input data batch
+
+        Returns:
+            List[PoseDataSample]: A list of data samples where the predictions
+            are stored in the ``pred_instances`` field of each data sample.
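+
+            Note: the prediction lists are iterated side by side via
+            ``zip_longest`` below, which pads any shorter (or empty) list
+            with ``None``.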
+ """ + assert len(batch_pred_instances) == len(batch_data_samples) + if batch_pred_fields is None: + batch_pred_fields, batch_traj_fields = [], [] + output_keypoint_indices = self.test_cfg.get('output_keypoint_indices', + None) + + for (pred_instances, pred_fields, traj_instances, traj_fields, + data_sample) in zip_longest(batch_pred_instances, + batch_pred_fields, + batch_traj_instances, + batch_traj_fields, + batch_data_samples): + + if output_keypoint_indices is not None: + # select output keypoints with given indices + num_keypoints = pred_instances.keypoints.shape[1] + for key, value in pred_instances.all_items(): + if key.startswith('keypoint'): + pred_instances.set_field( + value[:, output_keypoint_indices], key) + + data_sample.pred_instances = pred_instances + + if pred_fields is not None: + if output_keypoint_indices is not None: + # select output heatmap channels with keypoint indices + # when the number of heatmap channel matches num_keypoints + for key, value in pred_fields.all_items(): + if value.shape[0] != num_keypoints: + continue + pred_fields.set_field(value[output_keypoint_indices], + key) + data_sample.pred_fields = pred_fields + + return batch_data_samples diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py index c54bf12d1e..1c3eaf0b9f 100644 --- a/tests/test_codecs/test_image_pose_lifting.py +++ b/tests/test_codecs/test_image_pose_lifting.py @@ -87,9 +87,7 @@ def test_decode(self): codec = self.build_pose_lifting_label() decoded, scores = codec.decode( - encoded_wo_sigma, - restore_global_position=True, - target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=target[..., 0, :]) self.assertEqual(decoded.shape, (1, 17, 3)) self.assertEqual(scores.shape, (1, 17)) @@ -97,9 +95,7 @@ def test_decode(self): codec = self.build_pose_lifting_label(remove_root=True) decoded, scores = codec.decode( - encoded_wo_sigma, - restore_global_position=True, - target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=target[..., 0, :]) self.assertEqual(decoded.shape, (1, 18, 3)) self.assertEqual(scores.shape, (1, 18)) @@ -117,7 +113,6 @@ def test_cicular_verification(self): _keypoints, _ = codec.decode( np.expand_dims(encoded['target_label'], axis=0), - restore_global_position=True, target_root=target[..., 0, :]) self.assertTrue( @@ -130,7 +125,6 @@ def test_cicular_verification(self): _keypoints, _ = codec.decode( np.expand_dims(encoded['target_label'], axis=0), - restore_global_position=True, target_root=target[..., 0, :]) self.assertTrue( @@ -147,7 +141,6 @@ def test_cicular_verification(self): _keypoints, _ = codec.decode( np.expand_dims(encoded['target_label'], axis=0), - restore_global_position=True, target_root=target[..., 0, :]) self.assertTrue( diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py index da404360bb..58fc2cd29b 100644 --- a/tests/test_codecs/test_video_pose_lifting.py +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -109,9 +109,7 @@ def test_decode(self): codec = self.build_pose_lifting_label() decoded, scores = codec.decode( - encoded_wo_sigma, - restore_global_position=True, - target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=target[..., 0, :]) self.assertEqual(decoded.shape, (1, 17, 3)) self.assertEqual(scores.shape, (1, 17)) @@ -119,9 +117,7 @@ def test_decode(self): codec = self.build_pose_lifting_label(remove_root=True) decoded, scores = codec.decode( - encoded_wo_sigma, - restore_global_position=True, - 
target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=target[..., 0, :]) self.assertEqual(decoded.shape, (1, 18, 3)) self.assertEqual(scores.shape, (1, 18)) @@ -140,7 +136,6 @@ def test_cicular_verification(self): _keypoints, _ = codec.decode( np.expand_dims(encoded['target_label'], axis=0), - restore_global_position=True, target_root=target[..., 0, :]) self.assertTrue( @@ -153,7 +148,6 @@ def test_cicular_verification(self): _keypoints, _ = codec.decode( np.expand_dims(encoded['target_label'], axis=0), - restore_global_position=True, target_root=target[..., 0, :]) self.assertTrue( From 0cff9c3b07e13ac09143dad51a874e99ef5e1578 Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Wed, 26 Apr 2023 14:00:09 +0800 Subject: [PATCH 4/8] [Refactor] add mpjpe metric (#2247) --- .../datasets/base/base_mocap_dataset.py | 1 + mmpose/evaluation/functional/__init__.py | 6 +- mmpose/evaluation/functional/keypoint_eval.py | 54 +++ mmpose/evaluation/functional/mesh_eval.py | 66 ++++ mmpose/evaluation/metrics/__init__.py | 3 +- .../evaluation/metrics/keypoint_3d_metrics.py | 129 +++++++ .../test_functional/test_keypoint_eval.py | 357 ++++++++++-------- .../test_metrics/test_keypoint_3d_metrics.py | 70 ++++ 8 files changed, 528 insertions(+), 158 deletions(-) create mode 100644 mmpose/evaluation/functional/mesh_eval.py create mode 100644 mmpose/evaluation/metrics/keypoint_3d_metrics.py create mode 100644 tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py index 877fe01909..412cfcf985 100644 --- a/mmpose/datasets/datasets/base/base_mocap_dataset.py +++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py @@ -304,6 +304,7 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: 'target': keypoints_3d[target_idx], 'target_visible': keypoints_3d_visible[target_idx], 'target_img_id': frame_ids[target_idx], + 'target_img_path': _img_names[target_idx], } if self.camera_param_file: diff --git a/mmpose/evaluation/functional/__init__.py b/mmpose/evaluation/functional/__init__.py index 2c4a8b5d1e..49f243163c 100644 --- a/mmpose/evaluation/functional/__init__.py +++ b/mmpose/evaluation/functional/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .keypoint_eval import (keypoint_auc, keypoint_epe, keypoint_nme, - keypoint_pck_accuracy, +from .keypoint_eval import (keypoint_auc, keypoint_epe, keypoint_mpjpe, + keypoint_nme, keypoint_pck_accuracy, multilabel_classification_accuracy, pose_pck_accuracy, simcc_pck_accuracy) from .nms import nms, oks_nms, soft_oks_nms @@ -8,5 +8,5 @@ __all__ = [ 'keypoint_pck_accuracy', 'keypoint_auc', 'keypoint_nme', 'keypoint_epe', 'pose_pck_accuracy', 'multilabel_classification_accuracy', - 'simcc_pck_accuracy', 'nms', 'oks_nms', 'soft_oks_nms' + 'simcc_pck_accuracy', 'nms', 'oks_nms', 'soft_oks_nms', 'keypoint_mpjpe' ] diff --git a/mmpose/evaluation/functional/keypoint_eval.py b/mmpose/evaluation/functional/keypoint_eval.py index 060243357b..9c0d6998ee 100644 --- a/mmpose/evaluation/functional/keypoint_eval.py +++ b/mmpose/evaluation/functional/keypoint_eval.py @@ -4,6 +4,7 @@ import numpy as np from mmpose.codecs.utils import get_heatmap_maximum, get_simcc_maximum +from .mesh_eval import compute_similarity_transform def _calc_distances(preds: np.ndarray, gts: np.ndarray, mask: np.ndarray, @@ -318,3 +319,56 @@ def multilabel_classification_accuracy(pred: np.ndarray, # only if it's correct for all labels. 
     acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean()
     return acc
+
+
+def keypoint_mpjpe(pred: np.ndarray,
+                   gt: np.ndarray,
+                   mask: np.ndarray,
+                   alignment: str = 'none'):
+    """Calculate the mean per-joint position error (MPJPE). Depending on
+    ``alignment``, this yields the plain MPJPE, the scale-aligned N-MPJPE,
+    or the Procrustes-aligned P-MPJPE.
+
+    Note:
+        - batch_size: N
+        - num_keypoints: K
+        - keypoint_dims: C
+
+    Args:
+        pred (np.ndarray): Predicted keypoint location with shape [N, K, C].
+        gt (np.ndarray): Groundtruth keypoint location with shape [N, K, C].
+        mask (np.ndarray): Visibility of the target with shape [N, K].
+            False for invisible joints, and True for visible.
+            Invisible joints will be ignored for accuracy calculation.
+        alignment (str, optional): method to align the prediction with the
+            groundtruth. Supported options are:
+
+            - ``'none'``: no alignment will be applied
+            - ``'scale'``: align in the least-square sense in scale
+            - ``'procrustes'``: align in the least-square sense in
+              scale, rotation and translation.
+
+    Returns:
+        float: The mean per-joint position error over all visible joints,
+        computed after the specified alignment.
+    """
+    assert mask.any()
+
+    if alignment == 'none':
+        pass
+    elif alignment == 'procrustes':
+        pred = np.stack([
+            compute_similarity_transform(pred_i, gt_i)
+            for pred_i, gt_i in zip(pred, gt)
+        ])
+    elif alignment == 'scale':
+        pred_dot_pred = np.einsum('nkc,nkc->n', pred, pred)
+        pred_dot_gt = np.einsum('nkc,nkc->n', pred, gt)
+        scale_factor = pred_dot_gt / pred_dot_pred
+        pred = pred * scale_factor[:, None, None]
+    else:
+        raise ValueError(f'Invalid value for alignment: {alignment}')
+    error = np.linalg.norm(pred - gt, ord=2, axis=-1)[mask].mean()
+
+    return error
diff --git a/mmpose/evaluation/functional/mesh_eval.py b/mmpose/evaluation/functional/mesh_eval.py
new file mode 100644
index 0000000000..683b4539b2
--- /dev/null
+++ b/mmpose/evaluation/functional/mesh_eval.py
@@ -0,0 +1,66 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/akanazawa/hmr
+# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
+# ------------------------------------------------------------------------------
+
+import numpy as np
+
+
+def compute_similarity_transform(source_points, target_points):
+    """Computes a similarity transform (sR, t) that takes a set of 3D points
+    source_points (N x 3) closest to a set of 3D points target_points, where
+    R is a 3x3 rotation matrix, t is a 3x1 translation vector and s is a
+    scale factor, and returns the transformed 3D points source_points_hat
+    (N x 3); i.e. it solves the orthogonal Procrustes problem.
+
+    Note:
+        Points number: N
+
+    Args:
+        source_points (np.ndarray): Source point set with shape [N, 3].
+        target_points (np.ndarray): Target point set with shape [N, 3].
+
+    Returns:
+        np.ndarray: Transformed source point set with shape [N, 3].
+    """
+
+    assert target_points.shape[0] == source_points.shape[0]
+    assert target_points.shape[1] == 3 and source_points.shape[1] == 3
+
+    source_points = source_points.T
+    target_points = target_points.T
+
+    # 1. Remove mean.
+    mu1 = source_points.mean(axis=1, keepdims=True)
+    mu2 = target_points.mean(axis=1, keepdims=True)
+    X1 = source_points - mu1
+    X2 = target_points - mu2
+
+    # 2. Compute variance of X1 used for scale.
+    var1 = np.sum(X1**2)
+
+    # 3. The outer product of X1 and X2.
+    K = X1.dot(X2.T)
+
+    # 4. The solution that maximizes trace(R'K) is R=U*V', where U and V are
+    # the singular vectors of K.
+    U, _, Vh = np.linalg.svd(K)
+    V = Vh.T
+    # Construct Z that fixes the orientation of R to get det(R)=1.
+    Z = np.eye(U.shape[0])
+    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
+    # Construct R.
+    R = V.dot(Z.dot(U.T))
+
+    # 5. Recover scale.
+    scale = np.trace(R.dot(K)) / var1
+
+    # 6. Recover translation.
+    t = mu2 - scale * (R.dot(mu1))
+
+    # 7. Transform the source points:
+    source_points_hat = scale * R.dot(source_points) + t
+
+    source_points_hat = source_points_hat.T
+
+    return source_points_hat
diff --git a/mmpose/evaluation/metrics/__init__.py b/mmpose/evaluation/metrics/__init__.py
index f02c353ef7..ac7e21b5cc 100644
--- a/mmpose/evaluation/metrics/__init__.py
+++ b/mmpose/evaluation/metrics/__init__.py
@@ -3,11 +3,12 @@
 from .coco_wholebody_metric import CocoWholeBodyMetric
 from .keypoint_2d_metrics import (AUC, EPE, NME, JhmdbPCKAccuracy,
                                   MpiiPCKAccuracy, PCKAccuracy)
+from .keypoint_3d_metrics import MPJPE
 from .keypoint_partition_metric import KeypointPartitionMetric
 from .posetrack18_metric import PoseTrack18Metric
 
 __all__ = [
     'CocoMetric', 'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'AUC',
     'EPE', 'NME', 'PoseTrack18Metric', 'CocoWholeBodyMetric',
-    'KeypointPartitionMetric'
+    'KeypointPartitionMetric', 'MPJPE'
 ]
diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py
new file mode 100644
index 0000000000..c65dbb1998
--- /dev/null
+++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from os import path as osp
+from typing import Dict, Optional, Sequence
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger
+
+from mmpose.registry import METRICS
+from ..functional import keypoint_mpjpe
+
+
+@METRICS.register_module()
+class MPJPE(BaseMetric):
+    """MPJPE evaluation metric.
+
+    Calculate the mean per-joint position error (MPJPE) of keypoints.
+
+    Note:
+        - length of dataset: N
+        - num_keypoints: K
+        - number of keypoint dimensions: D (typically D = 3)
+
+    Args:
+        mode (str): Method to align the prediction with the
+            ground truth. Supported options are:
+
+            - ``'mpjpe'``: no alignment will be applied
+            - ``'p-mpjpe'``: align in the least-square sense in scale,
+              rotation, and translation (Procrustes analysis)
+            - ``'n-mpjpe'``: align in the least-square sense in scale
+
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be ``'cpu'`` or
+            ``'gpu'``. Default: ``'cpu'``.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, ``self.default_prefix``
+            will be used instead. Default: ``None``.
+    """
+
+    ALIGNMENT = {'mpjpe': 'none', 'p-mpjpe': 'procrustes', 'n-mpjpe': 'scale'}
+
+    def __init__(self,
+                 mode: str = 'mpjpe',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        allowed_modes = self.ALIGNMENT.keys()
+        if mode not in allowed_modes:
+            raise KeyError("`mode` should be 'mpjpe', 'p-mpjpe', or "
+                           f"'n-mpjpe', but got '{mode}'.")
+
+        self.mode = mode
+
+    def process(self, data_batch: Sequence[dict],
+                data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions. The processed
The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model. + """ + for data_sample in data_samples: + # predicted keypoints coordinates, [1, K, D] + pred_coords = data_sample['pred_instances']['keypoints'] + # ground truth data_info + gt = data_sample['gt_instances'] + # ground truth keypoints coordinates, [1, K, D] + gt_coords = gt['target'] + # ground truth keypoints_visible, [1, K, 1] + mask = gt['target_visible'].astype(bool).reshape(1, -1) + # instance action + img_path = data_sample['target_img_path'] + _, rest = osp.basename(img_path).split('_', 1) + action, _ = rest.split('.', 1) + + result = { + 'pred_coords': pred_coords, + 'gt_coords': gt_coords, + 'mask': mask, + 'action': action + } + + self.results.append(result) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are the corresponding results. + """ + logger: MMLogger = MMLogger.get_current_instance() + + # pred_coords: [N, K, D] + pred_coords = np.concatenate( + [result['pred_coords'] for result in results]) + # gt_coords: [N, K, D] + gt_coords = np.concatenate([result['gt_coords'] for result in results]) + # mask: [N, K] + mask = np.concatenate([result['mask'] for result in results]) + # action_category_indices: Dict[str, List[int]] + action_category_indices = defaultdict(list) + for idx, result in enumerate(results): + action_category = result['action'].split('_')[0] + action_category_indices[action_category].append(idx) + + error_name = self.mode.upper() + + logger.info(f'Evaluating {self.mode.upper()}...') + metrics = dict() + + metrics[error_name] = keypoint_mpjpe(pred_coords, gt_coords, mask, + self.ALIGNMENT[self.mode]) + + # Per-action errors must use the same alignment as the overall metric. + for action_category, indices in action_category_indices.items(): + metrics[f'{error_name}_{action_category}'] = keypoint_mpjpe( + pred_coords[indices], gt_coords[indices], mask[indices], + self.ALIGNMENT[self.mode]) + + return metrics diff --git a/tests/test_evaluation/test_functional/test_keypoint_eval.py b/tests/test_evaluation/test_functional/test_keypoint_eval.py index 2234c8e547..47ede83921 --- a/tests/test_evaluation/test_functional/test_keypoint_eval.py +++ b/tests/test_evaluation/test_functional/test_keypoint_eval.py @@ -1,163 +1,212 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase + import numpy as np from numpy.testing import assert_array_almost_equal from mmpose.evaluation.functional import (keypoint_auc, keypoint_epe, - keypoint_nme, keypoint_pck_accuracy, + keypoint_mpjpe, keypoint_nme, + keypoint_pck_accuracy, multilabel_classification_accuracy, pose_pck_accuracy) -def test_keypoint_pck_accuracy(): - output = np.zeros((2, 5, 2)) - target = np.zeros((2, 5, 2)) - mask = np.array([[True, True, False, True, True], - [True, True, False, True, True]]) - thr = np.full((2, 2), 10, dtype=np.float32) - # first channel - output[0, 0] = [10, 0] - target[0, 0] = [10, 0] - # second channel - output[0, 1] = [20, 20] - target[0, 1] = [10, 10] - # third channel - output[0, 2] = [0, 0] - target[0, 2] = [-1, 0] - # fourth channel - output[0, 3] = [30, 30] - target[0, 3] = [30, 30] - # fifth channel - output[0, 4] = [0, 10] - target[0, 4] = [0, 10] - - acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5, thr) - - assert_array_almost_equal(acc, np.array([1, 0.5, -1, 1, 1]), decimal=4) - assert abs(avg_acc - 0.875) < 1e-4 - assert abs(cnt - 4) < 1e-4 - - acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5, - np.zeros((2, 2))) - assert_array_almost_equal(acc, np.array([-1, -1, -1, -1, -1]), decimal=4) - assert abs(avg_acc) < 1e-4 - assert abs(cnt) < 1e-4 - - acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5, - np.array([[0, 0], [10, 10]])) - assert_array_almost_equal(acc, np.array([1, 1, -1, 1, 1]), decimal=4) - assert abs(avg_acc - 1) < 1e-4 - assert abs(cnt - 4) < 1e-4 - - -def test_keypoint_auc(): - output = np.zeros((1, 5, 2)) - target = np.zeros((1, 5, 2)) - mask = np.array([[True, True, False, True, True]]) - # first channel - output[0, 0] = [10, 4] - target[0, 0] = [10, 0] - # second channel - output[0, 1] = [10, 18] - target[0, 1] = [10, 10] - # third channel - output[0, 2] = [0, 0] - target[0, 2] = [0, -1] - # fourth channel - output[0, 3] = [40, 40] - target[0, 3] = [30, 30] - # fifth channel - output[0, 4] = [20, 10] - target[0, 4] = [0, 10] - - auc = keypoint_auc(output, target, mask, 20, 4) - assert abs(auc - 0.375) < 1e-4 - - -def test_keypoint_epe(): - output = np.zeros((1, 5, 2)) - target = np.zeros((1, 5, 2)) - mask = np.array([[True, True, False, True, True]]) - # first channel - output[0, 0] = [10, 4] - target[0, 0] = [10, 0] - # second channel - output[0, 1] = [10, 18] - target[0, 1] = [10, 10] - # third channel - output[0, 2] = [0, 0] - target[0, 2] = [-1, -1] - # fourth channel - output[0, 3] = [40, 40] - target[0, 3] = [30, 30] - # fifth channel - output[0, 4] = [20, 10] - target[0, 4] = [0, 10] - - epe = keypoint_epe(output, target, mask) - assert abs(epe - 11.5355339) < 1e-4 - - -def test_keypoint_nme(): - output = np.zeros((1, 5, 2)) - target = np.zeros((1, 5, 2)) - mask = np.array([[True, True, False, True, True]]) - # first channel - output[0, 0] = [10, 4] - target[0, 0] = [10, 0] - # second channel - output[0, 1] = [10, 18] - target[0, 1] = [10, 10] - # third channel - output[0, 2] = [0, 0] - target[0, 2] = [-1, -1] - # fourth channel - output[0, 3] = [40, 40] - target[0, 3] = [30, 30] - # fifth channel - output[0, 4] = [20, 10] - target[0, 4] = [0, 10] - - normalize_factor = np.ones((output.shape[0], output.shape[2])) - - nme = keypoint_nme(output, target, mask, normalize_factor) - assert abs(nme - 11.5355339) < 1e-4 - - -def test_pose_pck_accuracy(): - output = np.zeros((1, 5, 64, 64), dtype=np.float32) - target = np.zeros((1, 5, 64, 64), dtype=np.float32) - mask = 
np.array([[True, True, False, False, False]]) - # first channel - output[0, 0, 20, 20] = 1 - target[0, 0, 10, 10] = 1 - # second channel - output[0, 1, 30, 30] = 1 - target[0, 1, 30, 30] = 1 - - acc, avg_acc, cnt = pose_pck_accuracy(output, target, mask) - - assert_array_almost_equal(acc, np.array([0, 1, -1, -1, -1]), decimal=4) - assert abs(avg_acc - 0.5) < 1e-4 - assert abs(cnt - 2) < 1e-4 - - -def test_multilabel_classification_accuracy(): - output = np.array([[0.7, 0.8, 0.4], [0.8, 0.1, 0.1]]) - target = np.array([[1, 0, 0], [1, 0, 1]]) - mask = np.array([[True, True, True], [True, True, True]]) - thr = 0.5 - acc = multilabel_classification_accuracy(output, target, mask, thr) - assert acc == 0 - - output = np.array([[0.7, 0.2, 0.4], [0.8, 0.1, 0.9]]) - thr = 0.5 - acc = multilabel_classification_accuracy(output, target, mask, thr) - assert acc == 1 - - thr = 0.3 - acc = multilabel_classification_accuracy(output, target, mask, thr) - assert acc == 0.5 - - mask = np.array([[True, True, False], [True, True, True]]) - acc = multilabel_classification_accuracy(output, target, mask, thr) - assert acc == 1 +class TestKeypointEval(TestCase): + + def test_keypoint_pck_accuracy(self): + + output = np.zeros((2, 5, 2)) + target = np.zeros((2, 5, 2)) + mask = np.array([[True, True, False, True, True], + [True, True, False, True, True]]) + + # first channel + output[0, 0] = [10, 0] + target[0, 0] = [10, 0] + # second channel + output[0, 1] = [20, 20] + target[0, 1] = [10, 10] + # third channel + output[0, 2] = [0, 0] + target[0, 2] = [-1, 0] + # fourth channel + output[0, 3] = [30, 30] + target[0, 3] = [30, 30] + # fifth channel + output[0, 4] = [0, 10] + target[0, 4] = [0, 10] + + thr = np.full((2, 2), 10, dtype=np.float32) + + acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5, + thr) + + assert_array_almost_equal(acc, np.array([1, 0.5, -1, 1, 1]), decimal=4) + self.assertAlmostEqual(avg_acc, 0.875, delta=1e-4) + self.assertAlmostEqual(cnt, 4, delta=1e-4) + + acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5, + np.zeros((2, 2))) + assert_array_almost_equal( + acc, np.array([-1, -1, -1, -1, -1]), decimal=4) + self.assertAlmostEqual(avg_acc, 0, delta=1e-4) + self.assertAlmostEqual(cnt, 0, delta=1e-4) + + acc, avg_acc, cnt = keypoint_pck_accuracy(output, target, mask, 0.5, + np.array([[0, 0], [10, 10]])) + assert_array_almost_equal(acc, np.array([1, 1, -1, 1, 1]), decimal=4) + self.assertAlmostEqual(avg_acc, 1, delta=1e-4) + self.assertAlmostEqual(cnt, 4, delta=1e-4) + + def test_keypoint_auc(self): + output = np.zeros((1, 5, 2)) + target = np.zeros((1, 5, 2)) + mask = np.array([[True, True, False, True, True]]) + # first channel + output[0, 0] = [10, 4] + target[0, 0] = [10, 0] + # second channel + output[0, 1] = [10, 18] + target[0, 1] = [10, 10] + # third channel + output[0, 2] = [0, 0] + target[0, 2] = [0, -1] + # fourth channel + output[0, 3] = [40, 40] + target[0, 3] = [30, 30] + # fifth channel + output[0, 4] = [20, 10] + target[0, 4] = [0, 10] + + auc = keypoint_auc(output, target, mask, 20, 4) + self.assertAlmostEqual(auc, 0.375, delta=1e-4) + + def test_keypoint_epe(self): + output = np.zeros((1, 5, 2)) + target = np.zeros((1, 5, 2)) + mask = np.array([[True, True, False, True, True]]) + # first channel + output[0, 0] = [10, 4] + target[0, 0] = [10, 0] + # second channel + output[0, 1] = [10, 18] + target[0, 1] = [10, 10] + # third channel + output[0, 2] = [0, 0] + target[0, 2] = [-1, -1] + # fourth channel + output[0, 3] = [40, 40] + target[0, 3] = [30, 
30] + # fifth channel + output[0, 4] = [20, 10] + target[0, 4] = [0, 10] + + epe = keypoint_epe(output, target, mask) + self.assertAlmostEqual(epe, 11.5355339, delta=1e-4) + + def test_keypoint_nme(self): + output = np.zeros((1, 5, 2)) + target = np.zeros((1, 5, 2)) + mask = np.array([[True, True, False, True, True]]) + # first channel + output[0, 0] = [10, 4] + target[0, 0] = [10, 0] + # second channel + output[0, 1] = [10, 18] + target[0, 1] = [10, 10] + # third channel + output[0, 2] = [0, 0] + target[0, 2] = [-1, -1] + # fourth channel + output[0, 3] = [40, 40] + target[0, 3] = [30, 30] + # fifth channel + output[0, 4] = [20, 10] + target[0, 4] = [0, 10] + + normalize_factor = np.ones((output.shape[0], output.shape[2])) + + nme = keypoint_nme(output, target, mask, normalize_factor) + self.assertAlmostEqual(nme, 11.5355339, delta=1e-4) + + def test_pose_pck_accuracy(self): + output = np.zeros((1, 5, 64, 64), dtype=np.float32) + target = np.zeros((1, 5, 64, 64), dtype=np.float32) + mask = np.array([[True, True, False, False, False]]) + # first channel + output[0, 0, 20, 20] = 1 + target[0, 0, 10, 10] = 1 + # second channel + output[0, 1, 30, 30] = 1 + target[0, 1, 30, 30] = 1 + + acc, avg_acc, cnt = pose_pck_accuracy(output, target, mask) + + assert_array_almost_equal(acc, np.array([0, 1, -1, -1, -1]), decimal=4) + self.assertAlmostEqual(avg_acc, 0.5, delta=1e-4) + self.assertAlmostEqual(cnt, 2, delta=1e-4) + + def test_multilabel_classification_accuracy(self): + output = np.array([[0.7, 0.8, 0.4], [0.8, 0.1, 0.1]]) + target = np.array([[1, 0, 0], [1, 0, 1]]) + mask = np.array([[True, True, True], [True, True, True]]) + thr = 0.5 + acc = multilabel_classification_accuracy(output, target, mask, thr) + self.assertEqual(acc, 0) + + output = np.array([[0.7, 0.2, 0.4], [0.8, 0.1, 0.9]]) + thr = 0.5 + acc = multilabel_classification_accuracy(output, target, mask, thr) + self.assertEqual(acc, 1) + + thr = 0.3 + acc = multilabel_classification_accuracy(output, target, mask, thr) + self.assertEqual(acc, 0.5) + + mask = np.array([[True, True, False], [True, True, True]]) + acc = multilabel_classification_accuracy(output, target, mask, thr) + self.assertEqual(acc, 1) + + def test_keypoint_mpjpe(self): + output = np.zeros((2, 5, 3)) + target = np.zeros((2, 5, 3)) + mask = np.array([[True, True, False, True, True], + [True, True, False, True, True]]) + + # first channel + output[0, 0] = [1, 0, 0] + target[0, 0] = [1, 0, 0] + output[1, 0] = [1, 0, 0] + target[1, 0] = [1, 1, 0] + # second channel + output[0, 1] = [2, 2, 0] + target[0, 1] = [1, 1, 1] + output[1, 1] = [2, 2, 1] + target[1, 1] = [1, 0, 1] + # third channel + output[0, 2] = [0, 0, -1] + target[0, 2] = [-1, 0, 0] + output[1, 2] = [-1, 0, 0] + target[1, 2] = [-1, 0, 0] + # fourth channel + output[0, 3] = [3, 3, 1] + target[0, 3] = [3, 3, 1] + output[1, 3] = [0, 0, 3] + target[1, 3] = [0, 0, 3] + # fifth channel + output[0, 4] = [0, 1, 1] + target[0, 4] = [0, 1, 0] + output[1, 4] = [0, 0, 1] + target[1, 4] = [1, 1, 0] + + mpjpe = keypoint_mpjpe(output, target, mask) + self.assertAlmostEqual(mpjpe, 0.9625211990796929, delta=1e-4) + + p_mpjpe = keypoint_mpjpe(output, target, mask, 'procrustes') + self.assertAlmostEqual(p_mpjpe, 1.0047897634604497, delta=1e-4) + + s_mpjpe = keypoint_mpjpe(output, target, mask, 'scale') + self.assertAlmostEqual(s_mpjpe, 1.0277129678465953, delta=1e-4) + + with self.assertRaises(ValueError): + _ = keypoint_mpjpe(output, target, mask, 'alignment') diff --git 
a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py new file mode 100644 index 0000000000..40da092cad --- /dev/null +++ b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +from mmengine.structures import InstanceData + +from mmpose.evaluation import MPJPE +from mmpose.structures import PoseDataSample + + +class TestMPJPE(TestCase): + + def setUp(self): + """Setup variables used in every test method.""" + self.batch_size = 8 + num_keypoints = 15 + self.data_batch = [] + self.data_samples = [] + + for i in range(self.batch_size): + gt_instances = InstanceData() + keypoints = np.random.random((1, num_keypoints, 3)) + gt_instances.target = keypoints + gt_instances.target_visible = np.ones( + (1, num_keypoints, 1)).astype(bool) + + pred_instances = InstanceData() + pred_instances.keypoints = keypoints + np.random.normal( + 0, 0.01, keypoints.shape) + + data = {'inputs': None} + data_sample = PoseDataSample( + gt_instances=gt_instances, pred_instances=pred_instances) + data_sample.set_metainfo( + dict(target_img_path='tests/data/h36m/S7/' + 'S7_Greeting.55011271/S7_Greeting.55011271_000396.jpg')) + + self.data_batch.append(data) + self.data_samples.append(data_sample.to_dict()) + + def test_init(self): + """Test metric init method.""" + # Test invalid mode + with self.assertRaisesRegex( + KeyError, "`mode` should be 'mpjpe', 'p-mpjpe', or 'n-mpjpe', " + "but got 'invalid'."): + MPJPE(mode='invalid') + + def test_evaluate(self): + """Test MPJPE evaluation metric.""" + mpjpe_metric = MPJPE(mode='mpjpe') + mpjpe_metric.process(self.data_batch, self.data_samples) + mpjpe = mpjpe_metric.evaluate(self.batch_size) + self.assertIsInstance(mpjpe, dict) + self.assertIn('MPJPE', mpjpe) + self.assertTrue(mpjpe['MPJPE'] >= 0) + + p_mpjpe_metric = MPJPE(mode='p-mpjpe') + p_mpjpe_metric.process(self.data_batch, self.data_samples) + p_mpjpe = p_mpjpe_metric.evaluate(self.batch_size) + self.assertIsInstance(p_mpjpe, dict) + self.assertIn('P-MPJPE', p_mpjpe) + self.assertTrue(p_mpjpe['P-MPJPE'] >= 0) + + n_mpjpe_metric = MPJPE(mode='n-mpjpe') + n_mpjpe_metric.process(self.data_batch, self.data_samples) + n_mpjpe = n_mpjpe_metric.evaluate(self.batch_size) + self.assertIsInstance(n_mpjpe, dict) + self.assertIn('N-MPJPE', n_mpjpe) + self.assertTrue(n_mpjpe['N-MPJPE'] >= 0) From 729d37db295ab5a1086b265dda0a4b114da8909b Mon Sep 17 00:00:00 2001 From: Yifan Lareina WU Date: Wed, 26 Apr 2023 17:08:15 +0800 Subject: [PATCH 5/8] multiple changes (#2314) --- mmpose/codecs/image_pose_lifting.py | 89 +++++++++--------- mmpose/codecs/video_pose_lifting.py | 94 ++++++++++--------- .../datasets/base/base_mocap_dataset.py | 6 +- mmpose/datasets/transforms/formatting.py | 17 +++- .../datasets/transforms/pose3d_transforms.py | 29 +++--- mmpose/evaluation/functional/keypoint_eval.py | 1 + .../evaluation/metrics/keypoint_3d_metrics.py | 4 +- .../temporal_regression_head.py | 43 ++++++--- .../trajectory_regression_head.py | 40 +++++--- tests/test_codecs/test_image_pose_lifting.py | 77 +++++++-------- tests/test_codecs/test_video_pose_lifting.py | 72 +++++++------- .../test_transforms/test_pose3d_transforms.py | 12 +-- .../test_metrics/test_keypoint_3d_metrics.py | 4 +- 13 files changed, 274 insertions(+), 214 deletions(-) diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 
bf5b009c5d..1a02cda17e 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -16,7 +16,7 @@ class ImagePoseLifting(BaseKeypointCodec): - instance number: N - keypoint number: K - keypoint dimension: D - - target dimension: C + - pose-lifting target dimension: C Args: num_keypoints (int): The number of keypoints in the dataset. @@ -29,13 +29,13 @@ class ImagePoseLifting(BaseKeypointCodec): coordinates in shape (K, D). keypoints_std (np.ndarray, optional): Std values of keypoints coordinates in shape (K, D). - target_mean (np.ndarray, optional): Mean values of target coordinates - in shape (K, C). - target_std (np.ndarray, optional): Std values of target coordinates - in shape (K, C). + target_mean (np.ndarray, optional): Mean values of pose-lifting target + coordinates in shape (K, C). + target_std (np.ndarray, optional): Std values of pose-lifting target + coordinates in shape (K, C). """ - auxiliary_encode_keys = {'target', 'target_visible'} + auxiliary_encode_keys = {'lifting_target', 'lifting_target_visible'} def __init__(self, num_keypoints: int, @@ -64,27 +64,28 @@ def __init__, def encode(self, keypoints: np.ndarray, keypoints_visible: Optional[np.ndarray] = None, - target: Optional[np.ndarray] = None, - target_visible: Optional[np.ndarray] = None) -> dict: + lifting_target: Optional[np.ndarray] = None, + lifting_target_visible: Optional[np.ndarray] = None) -> dict: """Encoding keypoints from input image space to normalized space. Args: keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D). keypoints_visible (np.ndarray, optional): Keypoint visibilities in shape (N, K). - target (np.ndarray, optional): Target coordinate in shape (K, C). - target_visible (np.ndarray, optional): Target coordinate in shape - (K, ). + lifting_target (np.ndarray, optional): 3d target coordinate in + shape (K, C). + lifting_target_visible (np.ndarray, optional): Target visibility in + shape (K, ). Returns: encoded (dict): Contains the following items: - keypoint_labels (np.ndarray): The processed keypoints in shape (K * D, N) where D is 2 for 2d coordinates. - - target_label: The processed target coordinate in shape (K, C) - or (K-1, C). - - target_weights (np.ndarray): The target weights in shape - (K, ) or (K-1, ). + - lifting_target_label: The processed target coordinate in + shape (K, C) or (K-1, C). + - lifting_target_weights (np.ndarray): The target weights in + shape (K, ) or (K-1, ). - trajectory_weights (np.ndarray): The trajectory weights in shape (K, ). - target_root (np.ndarray): The root coordinate of target in @@ -93,7 +94,8 @@ In addition, there are some optional items it may contain: - target_root_removed (bool): Indicate whether the root of - target is removed. Added if ``self.remove_root`` is ``True``. + pose lifting target is removed. Added if ``self.remove_root`` + is ``True``. - target_root_index (int): An integer indicating the index of root. Added if ``self.remove_root`` and ``self.save_index`` are ``True``. 
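The hunks that follow implement the recipe the docstring above describes: zero-center the 3d target around a chosen root joint, then optionally delete the root row. A minimal NumPy sketch of that recipe, assuming an H36M-style 17-keypoint skeleton and ``root_index = 0`` (both illustrative, not part of this patch):

```python
import numpy as np

# Illustrative inputs: 17 keypoints with 3d coordinates, root at index 0.
lifting_target = np.random.rand(17, 3).astype(np.float32)
root_index = 0

# Zero-center the target pose around the root keypoint.
root = lifting_target[..., root_index, :]
lifting_target_label = lifting_target - root

# With `remove_root=True`, the (now all-zero) root row is dropped as well.
lifting_target_label = np.delete(lifting_target_label, root_index, axis=-2)

print(lifting_target_label.shape)  # (16, 3)
```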
@@ -101,34 +103,38 @@ if keypoints_visible is None: keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) - if target is None: - target = keypoints[0] + if lifting_target is None: + lifting_target = keypoints[0] - # set initial value for `target_weights` and `trajectory_weights` - if target_visible is None: - target_visible = np.ones(target.shape[:-1], dtype=np.float32) - target_weights = target_visible - trajectory_weights = (1 / target[:, 2]) + # set initial value for `lifting_target_weights` + # and `trajectory_weights` + if lifting_target_visible is None: + lifting_target_visible = np.ones( + lifting_target.shape[:-1], dtype=np.float32) + lifting_target_weights = lifting_target_visible + trajectory_weights = (1 / lifting_target[:, 2]) else: - valid = target_visible > 0.5 - target_weights = np.where(valid, 1., 0.).astype(np.float32) - trajectory_weights = target_weights + valid = lifting_target_visible > 0.5 + lifting_target_weights = np.where(valid, 1., 0.).astype(np.float32) + trajectory_weights = lifting_target_weights encoded = dict() # Zero-center the target pose around a given root keypoint - assert target.ndim >= 2 and target.shape[-2] > self.root_index, \ - f'Got invalid joint shape {target.shape}' + assert (lifting_target.ndim >= 2 and + lifting_target.shape[-2] > self.root_index), \ + f'Got invalid joint shape {lifting_target.shape}' - root = target[..., self.root_index, :] - target_label = target - root + root = lifting_target[..., self.root_index, :] + lifting_target_label = lifting_target - root if self.remove_root: - target_label = np.delete(target_label, self.root_index, axis=-2) - assert target_weights.ndim in {1, 2} - axis_to_remove = -2 if target_weights.ndim == 2 else -1 - target_weights = np.delete( - target_weights, self.root_index, axis=axis_to_remove) + lifting_target_label = np.delete( + lifting_target_label, self.root_index, axis=-2) + assert lifting_target_weights.ndim in {1, 2} + axis_to_remove = -2 if lifting_target_weights.ndim == 2 else -1 + lifting_target_weights = np.delete( + lifting_target_weights, self.root_index, axis=axis_to_remove) # Add a flag to avoid later transforms that rely on the root # joint or the original joint index encoded['target_root_removed'] = True @@ -146,10 +152,11 @@ def encode, keypoint_labels = (keypoint_labels - self.keypoints_mean) / self.keypoints_std if self.target_mean is not None and self.target_std is not None: - target_shape = target_label.shape + target_shape = lifting_target_label.shape assert self.target_mean.shape == target_shape - target_label = (target_label - self.target_mean) / self.target_std + lifting_target_label = (lifting_target_label - + self.target_mean) / self.target_std # Generate reshaped keypoint coordinates assert keypoint_labels.ndim in {2, 3} @@ -160,8 +167,8 @@ def encode, keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) encoded['keypoint_labels'] = keypoint_labels - encoded['target_label'] = target_label - encoded['target_weights'] = target_weights + encoded['lifting_target_label'] = lifting_target_label + encoded['lifting_target_weights'] = lifting_target_weights encoded['trajectory_weights'] = trajectory_weights encoded['target_root'] = root @@ -176,8 +183,6 @@ def decode, Args: encoded (np.ndarray): Coordinates in shape (N, K, C). - restore_global_position (bool): Whether to restore global position. - Default: ``False``. target_root (np.ndarray, optional): The target root coordinate. Default: ``None``. 
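``decode`` inverts these steps. A short sketch of the inverse mapping, using illustrative stand-ins for the codec's ``target_mean``/``target_std`` statistics and the stored root coordinate (shapes and values are assumptions for the sketch):

```python
import numpy as np

encoded = np.random.rand(1, 16, 3).astype(np.float32)  # model output, root removed
target_mean = np.zeros((16, 3), dtype=np.float32)      # illustrative statistics
target_std = np.ones((16, 3), dtype=np.float32)
target_root = np.array([0.1, 0.2, 4.0], dtype=np.float32)
root_index = 0

# 1. Undo the normalization applied in `encode`.
keypoints = encoded * target_std + target_mean

# 2. Restore the global position by adding the root coordinate back.
keypoints = keypoints + np.expand_dims(target_root, axis=0)

# 3. Re-insert the root joint that `remove_root=True` deleted.
keypoints = np.insert(keypoints, root_index, target_root, axis=-2)

print(keypoints.shape)  # (1, 17, 3)
```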
@@ -191,7 +196,7 @@ def decode, assert self.target_mean.shape == keypoints.shape[1:] keypoints = keypoints * self.target_std + self.target_mean - if target_root is not None: + if target_root is not None and target_root.size > 0: keypoints = keypoints + np.expand_dims(target_root, axis=0) if self.remove_root: keypoints = np.insert( diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index ceda999f85..0331aad544 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -18,7 +18,7 @@ class VideoPoseLifting(BaseKeypointCodec): - instance number: N - keypoint number: K - keypoint dimension: D - - target dimension: C + - pose-lifting target dimension: C Args: num_keypoints (int): The number of keypoints in the dataset. @@ -33,7 +33,9 @@ class VideoPoseLifting(BaseKeypointCodec): Default: ``False``. """ - auxiliary_encode_keys = {'target', 'target_visible', 'camera_param'} + auxiliary_encode_keys = { + 'lifting_target', 'lifting_target_visible', 'camera_param' + } def __init__(self, num_keypoints: int, @@ -54,8 +56,8 @@ def __init__, def encode(self, keypoints: np.ndarray, keypoints_visible: Optional[np.ndarray] = None, - target: Optional[np.ndarray] = None, - target_visible: Optional[np.ndarray] = None, + lifting_target: Optional[np.ndarray] = None, + lifting_target_visible: Optional[np.ndarray] = None, camera_param: Optional[dict] = None) -> dict: """Encoding keypoints from input image space to normalized space. @@ -63,9 +65,10 @@ def encode, keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D). keypoints_visible (np.ndarray, optional): Keypoint visibilities in shape (N, K). - target (np.ndarray, optional): Target coordinate in shape (K, C). - target_visible (np.ndarray, optional): Target coordinate in shape - (K, ). + lifting_target (np.ndarray, optional): 3d target coordinate in + shape (K, C). + lifting_target_visible (np.ndarray, optional): Target visibility in + shape (K, ). camera_param (dict, optional): The camera parameter dictionary. Returns: @@ -73,10 +76,10 @@ def encode, - keypoint_labels (np.ndarray): The processed keypoints in shape (K * D, N) where D is 2 for 2d coordinates. - - target_label: The processed target coordinate in shape (K, C) - or (K-1, C). - - target_weights (np.ndarray): The target weights in shape - (K, ) or (K-1, ). + - lifting_target_label: The processed target coordinate in + shape (K, C) or (K-1, C). + - lifting_target_weights (np.ndarray): The target weights in + shape (K, ) or (K-1, ). - trajectory_weights (np.ndarray): The trajectory weights in shape (K, ). @@ -85,8 +88,8 @@ def encode, - target_root (np.ndarray): The root coordinate of target in shape (C, ). Exists if ``self.zero_center`` is ``True``. - target_root_removed (bool): Indicate whether the root of - target is removed. Exists if ``self.remove_root`` is - ``True``. + pose-lifting target is removed. Exists if + ``self.remove_root`` is ``True``. - target_root_index (int): An integer indicating the index of root. Exists if ``self.remove_root`` and ``self.save_index`` are ``True``. 
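``VideoPoseLifting.encode`` shares the weighting logic of its image counterpart in the hunk below: binarized visibility for ``lifting_target_weights``, and inverse depth for ``trajectory_weights`` when no visibility annotation exists. A sketch of that logic (shapes and values illustrative):

```python
import numpy as np

def make_weights(lifting_target, lifting_target_visible=None):
    """Sketch of the weight derivation used by both lifting codecs."""
    if lifting_target_visible is None:
        # No annotation: weight every joint, weight the trajectory by 1 / z.
        weights = np.ones(lifting_target.shape[:-1], dtype=np.float32)
        trajectory = 1 / lifting_target[:, 2]
    else:
        # Binarize the visibility flags; invisible joints get zero weight.
        valid = lifting_target_visible > 0.5
        weights = np.where(valid, 1., 0.).astype(np.float32)
        trajectory = weights
    return weights, trajectory

target = np.random.rand(17, 3).astype(np.float32) + 1.0  # keep z > 0
visible = np.array([1.0] * 15 + [0.0] * 2, dtype=np.float32)
weights, trajectory = make_weights(target, visible)
print(weights.sum(), trajectory.shape)  # 15.0 (17,)
```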
@@ -96,41 +99,46 @@ def encode, if keypoints_visible is None: keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) - if target is None: - target = keypoints[0] + if lifting_target is None: + lifting_target = keypoints[0] - # set initial value for `target_weights` and `trajectory_weights` - if target_visible is None: - target_visible = np.ones(target.shape[:-1], dtype=np.float32) - target_weights = target_visible - trajectory_weights = (1 / target[:, 2]) + # set initial value for `lifting_target_weights` + # and `trajectory_weights` + if lifting_target_visible is None: + lifting_target_visible = np.ones( + lifting_target.shape[:-1], dtype=np.float32) + lifting_target_weights = lifting_target_visible + trajectory_weights = (1 / lifting_target[:, 2]) else: - valid = target_visible > 0.5 - target_weights = np.where(valid, 1., 0.).astype(np.float32) - trajectory_weights = target_weights + valid = lifting_target_visible > 0.5 + lifting_target_weights = np.where(valid, 1., 0.).astype(np.float32) + trajectory_weights = lifting_target_weights if camera_param is None: camera_param = dict() encoded = dict() - target_label = target.copy() + lifting_target_label = lifting_target.copy() # Zero-center the target pose around a given root keypoint if self.zero_center: - assert target.ndim >= 2 and target.shape[-2] > self.root_index, \ - f'Got invalid joint shape {target.shape}' + assert (lifting_target.ndim >= 2 and + lifting_target.shape[-2] > self.root_index), \ + f'Got invalid joint shape {lifting_target.shape}' - root = target[..., self.root_index, :] - target_label = target_label - root + root = lifting_target[..., self.root_index, :] + lifting_target_label = lifting_target_label - root encoded['target_root'] = root if self.remove_root: - target_label = np.delete( - target_label, self.root_index, axis=-2) - assert target_weights.ndim in {1, 2} - axis_to_remove = -2 if target_weights.ndim == 2 else -1 - target_weights = np.delete( - target_weights, self.root_index, axis=axis_to_remove) + lifting_target_label = np.delete( + lifting_target_label, self.root_index, axis=-2) + assert lifting_target_weights.ndim in {1, 2} + axis_to_remove = -2 if lifting_target_weights.ndim == 2 else -1 + lifting_target_weights = np.delete( + lifting_target_weights, + self.root_index, + axis=axis_to_remove) # Add a flag to avoid later transforms that rely on the root # joint or the original joint index encoded['target_root_removed'] = True @@ -163,8 +171,8 @@ def encode, keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) encoded['keypoint_labels'] = keypoint_labels - encoded['target_label'] = target_label - encoded['target_weights'] = target_weights + encoded['lifting_target_label'] = lifting_target_label + encoded['lifting_target_weights'] = lifting_target_weights encoded['trajectory_weights'] = trajectory_weights return encoded @@ -177,19 +185,17 @@ def decode, space. Args: - encoded (np.ndarray): Coordinates in shape (1, K, C). - restore_global_position (bool): Whether to restore global position. - Default: ``False``. - target_root (np.ndarray, optional): The target root coordinate. - Default: ``None``. + encoded (np.ndarray): Coordinates in shape (N, K, C). + target_root (np.ndarray, optional): The pose-lifting target root + coordinate. Default: ``None``. Returns: - keypoints (np.ndarray): Decoded coordinates in shape (1, K, C). - scores (np.ndarray): The keypoint scores in shape (1, K). + keypoints (np.ndarray): Decoded coordinates in shape (N, K, C). 
+ scores (np.ndarray): The keypoint scores in shape (N, K). """ keypoints = encoded.copy() - if target_root is not None: + if target_root is not None and target_root.size > 0: keypoints = keypoints + np.expand_dims(target_root, axis=0) if self.remove_root: keypoints = np.insert( diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py index 412cfcf985..d671a6ae94 100644 --- a/mmpose/datasets/datasets/base/base_mocap_dataset.py +++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py @@ -301,9 +301,8 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: 'iscrowd': 0, 'img_paths': list(_img_names), 'img_ids': frame_ids, - 'target': keypoints_3d[target_idx], - 'target_visible': keypoints_3d_visible[target_idx], - 'target_img_id': frame_ids[target_idx], + 'lifting_target': keypoints_3d[target_idx], + 'lifting_target_visible': keypoints_3d_visible[target_idx], 'target_img_path': _img_names[target_idx], } @@ -392,7 +391,6 @@ def _get_bottomup_data_infos(self, instance_list: List[Dict], # add images without instance for evaluation if self.test_mode: - print(image_list) for img_info in image_list: if img_info['img_id'] not in used_img_ids: data_info_bu = { diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index 14be378e19..6b09b2c770 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -89,8 +89,8 @@ class PackPoseInputs(BaseTransform): 'bbox_score': 'bbox_scores', 'keypoints': 'keypoints', 'keypoints_visible': 'keypoints_visible', - 'target': 'target', - 'target_visible': 'target_visible', + 'lifting_target': 'lifting_target', + 'lifting_target_visible': 'lifting_target_visible', } # items in `label_mapping_table` will be packed into @@ -98,6 +98,9 @@ # will be used for computing losses label_mapping_table = { 'keypoint_labels': 'keypoint_labels', + 'lifting_target_label': 'lifting_target_label', + 'lifting_target_weights': 'lifting_target_weights', + 'trajectory_weights': 'trajectory_weights', 'keypoint_x_labels': 'keypoint_x_labels', 'keypoint_y_labels': 'keypoint_y_labels', 'keypoint_weights': 'keypoint_weights', @@ -139,11 +142,12 @@ def transform(self, results: dict) -> dict: - 'data_samples' (obj:`PoseDataSample`): The annotation info of the sample. 
""" - # Pack image(s) or 2d keypoints + # Pack image(s) for 2d pose estimation if 'img' in results: img = results['img'] inputs_tensor = image_to_tensor(img) - elif 'keypoints_3d' in results and 'keypoints' in results: + # Pack keypoints for 3d pose-lifting + elif 'lifting_target' in results and 'keypoints' in results: inputs_tensor = results['keypoints'] data_sample = PoseDataSample() @@ -166,6 +170,11 @@ def transform(self, results: dict) -> dict: gt_instance_labels = InstanceData() for key, packed_key in self.label_mapping_table.items(): if key in results: + # For pose-lifting, store only target-related fields + if 'lifting_target' in results and key in { + 'keypoint_labels', 'keypoint_weights' + }: + continue if isinstance(results[key], list): # A list of labels is usually generated by combined # multiple encoders (See ``GenerateTarget`` in diff --git a/mmpose/datasets/transforms/pose3d_transforms.py b/mmpose/datasets/transforms/pose3d_transforms.py index 4f86c247b9..e6559fa398 100644 --- a/mmpose/datasets/transforms/pose3d_transforms.py +++ b/mmpose/datasets/transforms/pose3d_transforms.py @@ -19,19 +19,20 @@ class RandomFlipAroundRoot(BaseTransform): refer to the docstring of the ``flip_keypoints_custom_center`` function for more details. target_flip_cfg (dict): Configurations of the - ``flip_keypoints_custom_center`` function for ``target``. Please - refer to the docstring of the ``flip_keypoints_custom_center`` - function for more details. + ``flip_keypoints_custom_center`` function for ``lifting_target``. + Please refer to the docstring of the + ``flip_keypoints_custom_center`` function for more details. flip_prob (float): Probability of flip. Default: 0.5. flip_camera (bool): Whether to flip horizontal distortion coefficients. Default: ``False``. 
Required keys: keypoints - target + lifting_target Modified keys: - (keypoints, keypoints_visible, target, target_visible, camera_param) + (keypoints, keypoints_visible, lifting_target, lifting_target_visible, + camera_param) """ def __init__(self, @@ -61,11 +62,12 @@ def transform(self, results: Dict) -> dict: keypoints_visible = results['keypoints_visible'] else: keypoints_visible = np.ones(keypoints.shape[:-1], dtype=np.float32) - target = results['target'] - if 'target_visible' in results: - target_visible = results['target_visible'] + lifting_target = results['lifting_target'] + if 'lifting_target_visible' in results: + lifting_target_visible = results['lifting_target_visible'] else: - target_visible = np.ones(target.shape[:-1], dtype=np.float32) + lifting_target_visible = np.ones( + lifting_target.shape[:-1], dtype=np.float32) if np.random.rand() <= self.flip_prob: if 'flip_indices' not in results: @@ -77,13 +79,14 @@ def transform(self, results: Dict) -> dict: keypoints, keypoints_visible = flip_keypoints_custom_center( keypoints, keypoints_visible, flip_indices, **self.keypoints_flip_cfg) - target, target_visible = flip_keypoints_custom_center( - target, target_visible, flip_indices, **self.target_flip_cfg) + lifting_target, lifting_target_visible = flip_keypoints_custom_center( # noqa + lifting_target, lifting_target_visible, flip_indices, + **self.target_flip_cfg) results['keypoints'] = keypoints results['keypoints_visible'] = keypoints_visible - results['target'] = target - results['target_visible'] = target_visible + results['lifting_target'] = lifting_target + results['lifting_target_visible'] = lifting_target_visible # flip horizontal distortion coefficients if self.flip_camera: diff --git a/mmpose/evaluation/functional/keypoint_eval.py b/mmpose/evaluation/functional/keypoint_eval.py index 9c0d6998ee..3c689f3b00 100644 --- a/mmpose/evaluation/functional/keypoint_eval.py +++ b/mmpose/evaluation/functional/keypoint_eval.py @@ -346,6 +346,7 @@ def keypoint_mpjpe(pred: np.ndarray, - ``'scale'``: align in the least-square sense in scale - ``'procrustes'``: align in the least-square sense in scale, rotation and translation. + Returns: tuple: A tuple containing joint position errors diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py index c65dbb1998..0b313d4d3f 100644 --- a/mmpose/evaluation/metrics/keypoint_3d_metrics.py +++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py @@ -72,9 +72,9 @@ def process(self, data_batch: Sequence[dict], # ground truth data_info gt = data_sample['gt_instances'] # ground truth keypoints coordinates, [1, K, D] - gt_coords = gt['target'] + gt_coords = gt['lifting_target'] # ground truth keypoints_visible, [1, K, 1] - mask = gt['target_visible'].astype(bool).reshape(1, -1) + mask = gt['lifting_target_visible'].astype(bool).reshape(1, -1) # instance action img_path = data_sample['target_img_path'] _, rest = osp.basename(img_path).split('_', 1) diff --git a/mmpose/models/heads/regression_heads/temporal_regression_head.py b/mmpose/models/heads/regression_heads/temporal_regression_head.py index 5bc019f948..a33de19594 100644 --- a/mmpose/models/heads/regression_heads/temporal_regression_head.py +++ b/mmpose/models/heads/regression_heads/temporal_regression_head.py @@ -66,7 +66,7 @@ def forward(self, feats: Tuple[Tensor]) -> Tensor: feats (Tuple[Tensor]): Multi scale feature maps. Returns: - Tensor: output coordinates(and sigmas[optional]). + Tensor: Output coordinates (and sigmas[optional]). 
""" x = feats[-1] @@ -78,7 +78,16 @@ def predict(self, feats: Tuple[Tensor], batch_data_samples: OptSampleList, test_cfg: ConfigType = {}) -> Predictions: - """Predict results from outputs.""" + """Predict results from outputs. + + Returns: + preds (sequence[InstanceData]): Prediction results. + Each contains the following fields: + + - keypoints: Predicted keypoints of shape (B, N, K, D). + - keypoint_scores: Scores of predicted keypoints of shape + (B, N, K). + """ batch_coords = self.forward(feats) # (B, K, D) @@ -89,8 +98,13 @@ def predict(self, if target_root is not None: target_root = torch.stack( [m['target_root'] for m in batch_data_samples[0].metainfo]) + else: + target_root = torch.stack([ + torch.empty((0), dtype=torch.float32) + for _ in batch_data_samples[0].metainfo + ]) - preds = self.decode(batch_coords, target_root) + preds = self.decode((batch_coords, target_root)) return preds @@ -102,28 +116,31 @@ def loss(self, pred_outputs = self.forward(inputs) - keypoint_labels = torch.cat( - [d.gt_instance_labels.keypoint_labels for d in batch_data_samples]) - keypoint_weights = torch.cat([ - d.gt_instance_labels.keypoint_weights for d in batch_data_samples + lifting_target_label = torch.cat([ + d.gt_instance_labels.lifting_target_label + for d in batch_data_samples + ]) + lifting_target_weights = torch.cat([ + d.gt_instance_labels.lifting_target_weights + for d in batch_data_samples ]) # calculate losses losses = dict() - loss = self.loss_module(pred_outputs, keypoint_labels, - keypoint_weights.unsqueeze(-1)) + loss = self.loss_module(pred_outputs, lifting_target_label, + lifting_target_weights.unsqueeze(-1)) losses.update(loss_pose3d=loss) # calculate accuracy _, avg_acc, _ = keypoint_pck_accuracy( pred=to_numpy(pred_outputs), - gt=to_numpy(keypoint_labels), - mask=to_numpy(keypoint_weights) > 0, + gt=to_numpy(lifting_target_label), + mask=to_numpy(lifting_target_weights) > 0, thr=0.05, - norm_factor=np.ones((pred_outputs.size(0), 2), dtype=np.float32)) + norm_factor=np.ones((pred_outputs.size(0), 3), dtype=np.float32)) - mpjpe_pose = torch.tensor(avg_acc, device=keypoint_labels.device) + mpjpe_pose = torch.tensor(avg_acc, device=lifting_target_label.device) losses.update(mpjpe=mpjpe_pose) return losses diff --git a/mmpose/models/heads/regression_heads/trajectory_regression_head.py b/mmpose/models/heads/regression_heads/trajectory_regression_head.py index 76626fb3dc..0b72ae3155 100644 --- a/mmpose/models/heads/regression_heads/trajectory_regression_head.py +++ b/mmpose/models/heads/regression_heads/trajectory_regression_head.py @@ -78,7 +78,16 @@ def predict(self, feats: Tuple[Tensor], batch_data_samples: OptSampleList, test_cfg: ConfigType = {}) -> Predictions: - """Predict results from outputs.""" + """Predict results from outputs. + + Returns: + preds (sequence[InstanceData]): Prediction results. + Each contains the following fields: + + - keypoints: Predicted keypoints of shape (B, N, K, D). + - keypoint_scores: Scores of predicted keypoints of shape + (B, N, K). 
+ """ batch_coords = self.forward(feats) # (B, K, D) @@ -89,8 +98,13 @@ def predict(self, if target_root is not None: target_root = torch.stack( [m['target_root'] for m in batch_data_samples[0].metainfo]) + else: + target_root = torch.stack([ + torch.empty((0), dtype=torch.float32) + for _ in batch_data_samples[0].metainfo + ]) - preds = self.decode(batch_coords, target_root) + preds = self.decode((batch_coords, target_root)) return preds @@ -102,28 +116,30 @@ def loss(self, pred_outputs = self.forward(inputs) - keypoint_labels = torch.cat( - [d.gt_instance_labels.keypoint_labels for d in batch_data_samples]) - trjectory_weights = torch.cat([ - d.gt_instance_labels.trjectory_weights for d in batch_data_samples + lifting_target_label = torch.cat([ + d.gt_instance_labels.lifting_target_label + for d in batch_data_samples + ]) + trajectory_weights = torch.cat([ + d.gt_instance_labels.trajectory_weights for d in batch_data_samples ]) # calculate losses losses = dict() - loss = self.loss_module(pred_outputs, keypoint_labels, - trjectory_weights.unsqueeze(-1)) + loss = self.loss_module(pred_outputs, lifting_target_label, + trajectory_weights.unsqueeze(-1)) losses.update(loss_traj=loss) # calculate accuracy _, avg_acc, _ = keypoint_pck_accuracy( pred=to_numpy(pred_outputs), - gt=to_numpy(keypoint_labels), - mask=to_numpy(trjectory_weights) > 0, + gt=to_numpy(lifting_target_label), + mask=to_numpy(trajectory_weights) > 0, thr=0.05, - norm_factor=np.ones((pred_outputs.size(0), 2), dtype=np.float32)) + norm_factor=np.ones((pred_outputs.size(0), 3), dtype=np.float32)) - mpjpe_traj = torch.tensor(avg_acc, device=keypoint_labels.device) + mpjpe_traj = torch.tensor(avg_acc, device=lifting_target_label.device) losses.update(mpjpe_traj=mpjpe_traj) return losses diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py index 1c3eaf0b9f..78a4262834 100644 --- a/tests/test_codecs/test_image_pose_lifting.py +++ b/tests/test_codecs/test_image_pose_lifting.py @@ -13,8 +13,8 @@ def setUp(self) -> None: keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256] keypoints = np.round(keypoints).astype(np.float32) keypoints_visible = np.random.randint(2, size=(1, 17)) - target = (0.1 + 0.8 * np.random.rand(17, 3)) - target_visible = np.random.randint(2, size=(17, )) + lifting_target = (0.1 + 0.8 * np.random.rand(17, 3)) + lifting_target_visible = np.random.randint(2, size=(17, )) encoded_wo_sigma = np.random.rand(1, 17, 3) self.keypoints_mean = np.random.rand(17, 2).astype(np.float32) @@ -25,8 +25,8 @@ def setUp(self) -> None: self.data = dict( keypoints=keypoints, keypoints_visible=keypoints_visible, - target=target, - target_visible=target_visible, + lifting_target=lifting_target, + lifting_target_visible=lifting_target_visible, encoded_wo_sigma=encoded_wo_sigma) def build_pose_lifting_label(self, **kwargs): @@ -41,31 +41,31 @@ def test_build(self): def test_encode(self): keypoints = self.data['keypoints'] keypoints_visible = self.data['keypoints_visible'] - target = self.data['target'] - target_visible = self.data['target_visible'] + lifting_target = self.data['lifting_target'] + lifting_target_visible = self.data['lifting_target_visible'] # test default settings codec = self.build_pose_lifting_label() - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible) self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) - 
self.assertEqual(encoded['target_label'].shape, (17, 3)) - self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) self.assertEqual(encoded['trajectory_weights'].shape, (17, )) self.assertEqual(encoded['target_root'].shape, (3, )) # test removing root codec = self.build_pose_lifting_label( remove_root=True, save_index=True) - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible) self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) - self.assertEqual(encoded['target_weights'].shape, (16, )) + self.assertEqual(encoded['lifting_target_weights'].shape, (16, )) self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) - self.assertEqual(encoded['target_label'].shape, (16, 3)) + self.assertEqual(encoded['lifting_target_label'].shape, (16, 3)) self.assertEqual(encoded['target_root'].shape, (3, )) # test normalization @@ -74,20 +74,20 @@ def test_encode(self): keypoints_std=self.keypoints_std, target_mean=self.target_mean, target_std=self.target_std) - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible) self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) - self.assertEqual(encoded['target_label'].shape, (17, 3)) + self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) def test_decode(self): - target = self.data['target'] + lifting_target = self.data['lifting_target'] encoded_wo_sigma = self.data['encoded_wo_sigma'] codec = self.build_pose_lifting_label() decoded, scores = codec.decode( - encoded_wo_sigma, target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=lifting_target[..., 0, :]) self.assertEqual(decoded.shape, (1, 17, 3)) self.assertEqual(scores.shape, (1, 17)) @@ -95,7 +95,7 @@ def test_decode(self): codec = self.build_pose_lifting_label(remove_root=True) decoded, scores = codec.decode( - encoded_wo_sigma, target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=lifting_target[..., 0, :]) self.assertEqual(decoded.shape, (1, 18, 3)) self.assertEqual(scores.shape, (1, 18)) @@ -103,32 +103,34 @@ def test_decode(self): def test_cicular_verification(self): keypoints = self.data['keypoints'] keypoints_visible = self.data['keypoints_visible'] - target = self.data['target'] - target_visible = self.data['target_visible'] + lifting_target = self.data['lifting_target'] + lifting_target_visible = self.data['lifting_target_visible'] # test default settings codec = self.build_pose_lifting_label() - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible) _keypoints, _ = codec.decode( - np.expand_dims(encoded['target_label'], axis=0), - target_root=target[..., 0, :]) + np.expand_dims(encoded['lifting_target_label'], axis=0), + target_root=lifting_target[..., 0, :]) self.assertTrue( - np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + np.allclose( + np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) # test removing root codec = self.build_pose_lifting_label(remove_root=True) - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible) + encoded = codec.encode(keypoints, keypoints_visible, 
lifting_target, + lifting_target_visible) _keypoints, _ = codec.decode( - np.expand_dims(encoded['target_label'], axis=0), - target_root=target[..., 0, :]) + np.expand_dims(encoded['lifting_target_label'], axis=0), + target_root=lifting_target[..., 0, :]) self.assertTrue( - np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + np.allclose( + np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) # test normalization codec = self.build_pose_lifting_label( @@ -136,12 +138,13 @@ def test_cicular_verification(self): keypoints_std=self.keypoints_std, target_mean=self.target_mean, target_std=self.target_std) - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible) _keypoints, _ = codec.decode( - np.expand_dims(encoded['target_label'], axis=0), - target_root=target[..., 0, :]) + np.expand_dims(encoded['lifting_target_label'], axis=0), + target_root=lifting_target[..., 0, :]) self.assertTrue( - np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + np.allclose( + np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py index 58fc2cd29b..05fc10ee95 100644 --- a/tests/test_codecs/test_video_pose_lifting.py +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -27,8 +27,8 @@ def setUp(self) -> None: keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256] keypoints = np.round(keypoints).astype(np.float32) keypoints_visible = np.random.randint(2, size=(1, 17)) - target = (0.1 + 0.8 * np.random.rand(17, 3)) - target_visible = np.random.randint(2, size=(17, )) + lifting_target = (0.1 + 0.8 * np.random.rand(17, 3)) + lifting_target_visible = np.random.randint(2, size=(17, )) encoded_wo_sigma = np.random.rand(1, 17, 3) camera_param = load('tests/data/h36m/cameras.pkl') @@ -39,8 +39,8 @@ def setUp(self) -> None: self.data = dict( keypoints=keypoints, keypoints_visible=keypoints_visible, - target=target, - target_visible=target_visible, + lifting_target=lifting_target, + lifting_target_visible=lifting_target_visible, camera_param=camera_param, encoded_wo_sigma=encoded_wo_sigma) @@ -51,48 +51,48 @@ def test_build(self): def test_encode(self): keypoints = self.data['keypoints'] keypoints_visible = self.data['keypoints_visible'] - target = self.data['target'] - target_visible = self.data['target_visible'] + lifting_target = self.data['lifting_target'] + lifting_target_visible = self.data['lifting_target_visible'] camera_param = self.data['camera_param'] # test default settings codec = self.build_pose_lifting_label() - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible, camera_param) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) - self.assertEqual(encoded['target_label'].shape, (17, 3)) - self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) self.assertEqual(encoded['trajectory_weights'].shape, (17, )) self.assertEqual(encoded['target_root'].shape, (3, )) # test not zero-centering codec = self.build_pose_lifting_label(zero_center=False) - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible, camera_param) + encoded = 
codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) - self.assertEqual(encoded['target_label'].shape, (17, 3)) - self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) self.assertEqual(encoded['trajectory_weights'].shape, (17, )) # test removing root codec = self.build_pose_lifting_label( remove_root=True, save_index=True) - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible, camera_param) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) - self.assertEqual(encoded['target_weights'].shape, (16, )) + self.assertEqual(encoded['lifting_target_weights'].shape, (16, )) self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) - self.assertEqual(encoded['target_label'].shape, (16, 3)) + self.assertEqual(encoded['lifting_target_label'].shape, (16, 3)) self.assertEqual(encoded['target_root'].shape, (3, )) # test normalizing camera codec = self.build_pose_lifting_label(normalize_camera=True) - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible, camera_param) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) self.assertTrue('camera_param' in encoded) scale = np.array(0.5 * camera_param['w'], dtype=np.float32) @@ -103,13 +103,13 @@ def test_encode(self): atol=4.)) def test_decode(self): - target = self.data['target'] + lifting_target = self.data['lifting_target'] encoded_wo_sigma = self.data['encoded_wo_sigma'] codec = self.build_pose_lifting_label() decoded, scores = codec.decode( - encoded_wo_sigma, target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=lifting_target[..., 0, :]) self.assertEqual(decoded.shape, (1, 17, 3)) self.assertEqual(scores.shape, (1, 17)) @@ -117,7 +117,7 @@ def test_decode(self): codec = self.build_pose_lifting_label(remove_root=True) decoded, scores = codec.decode( - encoded_wo_sigma, target_root=target[..., 0, :]) + encoded_wo_sigma, target_root=lifting_target[..., 0, :]) self.assertEqual(decoded.shape, (1, 18, 3)) self.assertEqual(scores.shape, (1, 18)) @@ -125,30 +125,32 @@ def test_decode(self): def test_cicular_verification(self): keypoints = self.data['keypoints'] keypoints_visible = self.data['keypoints_visible'] - target = self.data['target'] - target_visible = self.data['target_visible'] + lifting_target = self.data['lifting_target'] + lifting_target_visible = self.data['lifting_target_visible'] camera_param = self.data['camera_param'] # test default settings codec = self.build_pose_lifting_label() - encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible, camera_param) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) _keypoints, _ = codec.decode( - np.expand_dims(encoded['target_label'], axis=0), - target_root=target[..., 0, :]) + np.expand_dims(encoded['lifting_target_label'], axis=0), + target_root=lifting_target[..., 0, :]) self.assertTrue( - np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + np.allclose( + np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) # test removing root codec = self.build_pose_lifting_label(remove_root=True) - 
encoded = codec.encode(keypoints, keypoints_visible, target, - target_visible, camera_param) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) _keypoints, _ = codec.decode( - np.expand_dims(encoded['target_label'], axis=0), - target_root=target[..., 0, :]) + np.expand_dims(encoded['lifting_target_label'], axis=0), + target_root=lifting_target[..., 0, :]) self.assertTrue( - np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + np.allclose( + np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) diff --git a/tests/test_datasets/test_transforms/test_pose3d_transforms.py b/tests/test_datasets/test_transforms/test_pose3d_transforms.py index 16118db272..5f5d5aa096 100644 --- a/tests/test_datasets/test_transforms/test_pose3d_transforms.py +++ b/tests/test_datasets/test_transforms/test_pose3d_transforms.py @@ -50,8 +50,8 @@ def _parse_h36m_imgname(imgname): 'category_id': 1, 'iscrowd': 0, 'sample_idx': idx, - 'target': keypoints_3d[target_idx, :, :3], - 'target_visible': keypoints_3d[target_idx, :, 3], + 'lifting_target': keypoints_3d[target_idx, :, :3], + 'lifting_target_visible': keypoints_3d[target_idx, :, 3], 'target_img_path': osp.join('tests/data/h36m', imgnames[target_idx]), } @@ -94,8 +94,8 @@ def test_init(self): def test_transform(self): kpts1 = self.data_info['keypoints'] kpts_vis1 = self.data_info['keypoints_visible'] - tar1 = self.data_info['target'] - tar_vis1 = self.data_info['target_visible'] + tar1 = self.data_info['lifting_target'] + tar_vis1 = self.data_info['lifting_target_visible'] transform = RandomFlipAroundRoot( self.keypoints_flip_cfg, self.target_flip_cfg, flip_prob=1) @@ -104,8 +104,8 @@ def test_transform(self): kpts2 = results['keypoints'] kpts_vis2 = results['keypoints_visible'] - tar2 = results['target'] - tar_vis2 = results['target_visible'] + tar2 = results['lifting_target'] + tar_vis2 = results['lifting_target_visible'] self.assertEqual(kpts_vis2.shape, (1, 17)) self.assertEqual(tar_vis2.shape, (17, )) diff --git a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py index 40da092cad..d51d493cbc 100644 --- a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py +++ b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py @@ -20,8 +20,8 @@ def setUp(self): for i in range(self.batch_size): gt_instances = InstanceData() keypoints = np.random.random((1, num_keypoints, 3)) - gt_instances.target = keypoints - gt_instances.target_visible = np.ones( + gt_instances.lifting_target = keypoints + gt_instances.lifting_target_visible = np.ones( (1, num_keypoints, 1)).astype(bool) pred_instances = InstanceData() From b709a783bb8dc5251e59225eec2e10aacc708824 Mon Sep 17 00:00:00 2001 From: Yifan Lareina WU Date: Sat, 27 May 2023 12:23:03 +0800 Subject: [PATCH 6/8] [Refactor] Support VideoPose (#2328) --- .../video_pose_lift/README.md | 17 + ...pose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py | 132 ++++ ...e3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py | 132 ++++ ...videopose3d-243frm-supv_8xb128-80e_h36m.py | 128 ++++ ...-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py | 119 ++++ ...opose3d-27frm-semi-supv_8xb64-200e_h36m.py | 117 ++++ ..._videopose3d-27frm-supv_8xb128-80e_h36m.py | 128 ++++ ..._videopose3d-81frm-supv_8xb128-80e_h36m.py | 128 ++++ .../video_pose_lift/h36m/videopose3d_h36m.yml | 102 ++++ .../video_pose_lift/h36m/videpose3d_h36m.md | 67 +++ demo/body3d_pose_lifter_demo.py | 466 +++++++++++++++ 
 demo/docs/3d_human_pose_demo.md               |  74 +++
 mmpose/apis/__init__.py                       |   9 +-
 mmpose/apis/inference.py                      |  33 +
 mmpose/apis/inference_3d.py                   | 252 ++++++++
 mmpose/apis/inference_tracking.py             | 103 ++++
 mmpose/codecs/image_pose_lifting.py           |   3 -
 mmpose/codecs/video_pose_lifting.py           |   7 +-
 mmpose/datasets/datasets/__init__.py          |   1 +
 mmpose/datasets/datasets/utils.py             |   5 +
 mmpose/datasets/transforms/formatting.py      |  37 +-
 .../evaluation/metrics/keypoint_3d_metrics.py |   4 +-
 .../temporal_regression_head.py               |   8 +-
 .../trajectory_regression_head.py             |   8 +-
 mmpose/models/pose_estimators/base.py         |   2 +
 mmpose/models/pose_estimators/pose_lifter.py  |  16 +-
 mmpose/visualization/__init__.py              |   3 +-
 mmpose/visualization/local_visualizer_3d.py   | 563 ++++++++++++++++++
 tests/test_codecs/test_image_pose_lifting.py  |   6 +-
 tests/test_codecs/test_video_pose_lifting.py  |   6 +-
 .../test_metrics/test_keypoint_3d_metrics.py  |   4 +-
 31 files changed, 2646 insertions(+), 34 deletions(-)
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/README.md
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/videopose3d_h36m.yml
 create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/videpose3d_h36m.md
 create mode 100644 demo/body3d_pose_lifter_demo.py
 create mode 100644 demo/docs/3d_human_pose_demo.md
 create mode 100644 mmpose/apis/inference_3d.py
 create mode 100644 mmpose/apis/inference_tracking.py
 create mode 100644 mmpose/visualization/local_visualizer_3d.py

diff --git a/configs/body_3d_keypoint/video_pose_lift/README.md b/configs/body_3d_keypoint/video_pose_lift/README.md
new file mode 100644
index 0000000000..c23b69ea7f
--- /dev/null
+++ b/configs/body_3d_keypoint/video_pose_lift/README.md
@@ -0,0 +1,17 @@
+# 3D human pose estimation in video with temporal convolutions and semi-supervised training
+
+Building on the success of 2D human pose estimation, VideoPose3D directly "lifts" a sequence of 2D keypoints to 3D keypoints.
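+
+For intuition, here is a minimal, illustrative sketch of the lifting idea (the
+class name and layer sizes below are hypothetical; the configs in this folder
+realize the same idea with the `TCN` backbone and `TemporalRegressionHead`):
+a stack of dilated 1D convolutions over the time axis consumes a window of
+flattened 2D keypoints and regresses the 3D pose of the window's target frame.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class TinyLifter(nn.Module):
+    """Toy 2D-to-3D lifter: dilated temporal convs over a keypoint window."""
+
+    def __init__(self, num_keypoints=17, channels=256):
+        super().__init__()
+        # each frame contributes 2 * K input channels (flattened x, y)
+        self.net = nn.Sequential(
+            nn.Conv1d(2 * num_keypoints, channels, kernel_size=3),
+            nn.ReLU(),
+            nn.Conv1d(channels, channels, kernel_size=3, dilation=3),
+            nn.ReLU(),
+        )
+        # regress a 3D pose (3 * K values) per remaining time step
+        self.head = nn.Conv1d(channels, 3 * num_keypoints, kernel_size=1)
+
+    def forward(self, kpts_2d):
+        # kpts_2d: (N, 2*K, T), a temporal window of 2D keypoints
+        out = self.head(self.net(kpts_2d))   # (N, 3*K, T')
+        return out[..., out.shape[-1] // 2]  # 3D pose of the center frame
+
+
+seq = torch.randn(1, 2 * 17, 9)  # receptive field of this toy stack is 9
+print(TinyLifter()(seq).shape)   # torch.Size([1, 51])
+```
+
+Dilated convolutions make the receptive field grow multiplicatively with
+depth, which is how the 27-, 81- and 243-frame variants below trade temporal
+context for computation.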
+ +## Results and Models + +### Human3.6m Dataset + +| Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | +| :------------------------------------------------------ | :-------------: | :---: | :-----: | :-----: | :------------------------------------------------------: | :-----------------------------------------------------: | +| [VideoPose3D-supervised](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 27 | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D-supervised](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 81 | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D-supervised](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 243 | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D-supervised-CPN](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 1 | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | +| [VideoPose3D-supervised-CPN](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 243 | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | +| [VideoPose3D-semi-supervised](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 27 | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | +| [VideoPose3D-semi-supervised-CPN](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 27 | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py new file mode 100644 index 
0000000000..0cbf89142d --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py @@ -0,0 +1,132 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=80, val_interval=10) + +# optimizer +optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-4)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.98, end=80, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=1024) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=4, + kernel_sizes=(1, 1, 1, 1, 1), + dropout=0.25, + use_stride_conv=True, + ), + head=dict( + type='TemporalRegressionHead', + in_channels=1024, + num_joints=17, + loss=dict(type='MPJPELoss'), + decoder=codec, + )) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(), + target_flip_cfg=dict(), + ), + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] + +# data loaders +train_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train.npz', + seq_len=1, + causal=False, + pad_video_seq=False, + keypoint_2d_src='detection', + keypoint_2d_det_file='joint_2d_det_files/cpn_ft_h36m_dbb_train.npy', + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + ), +) +val_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=1, + causal=False, + pad_video_seq=False, + keypoint_2d_src='detection', + keypoint_2d_det_file='joint_2d_det_files/cpn_ft_h36m_dbb_test.npy', + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe'), + dict(type='MPJPE', mode='p-mpjpe') +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py new file mode 100644 index 0000000000..3ef3df570b --- /dev/null +++ 
b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py @@ -0,0 +1,132 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=200, val_interval=10) + +# optimizer +optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-4)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.98, end=200, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=1024) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=4, + kernel_sizes=(3, 3, 3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + head=dict( + type='TemporalRegressionHead', + in_channels=1024, + num_joints=17, + loss=dict(type='MPJPELoss'), + decoder=codec, + )) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(), + target_flip_cfg=dict(), + ), + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] + +# data loaders +train_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train.npz', + seq_len=243, + causal=False, + pad_video_seq=True, + keypoint_2d_src='detection', + keypoint_2d_det_file='joint_2d_det_files/cpn_ft_h36m_dbb_train.npy', + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + ), +) +val_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=243, + causal=False, + pad_video_seq=True, + keypoint_2d_src='detection', + keypoint_2d_det_file='joint_2d_det_files/cpn_ft_h36m_dbb_test.npy', + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe'), + dict(type='MPJPE', mode='p-mpjpe') +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py new file mode 100644 index 0000000000..0f311ac5cf --- /dev/null +++ 
b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py @@ -0,0 +1,128 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=80, val_interval=10) + +# optimizer +optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.975, end=80, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=1024) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=4, + kernel_sizes=(3, 3, 3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + head=dict( + type='TemporalRegressionHead', + in_channels=1024, + num_joints=17, + loss=dict(type='MPJPELoss'), + decoder=codec, + )) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(), + target_flip_cfg=dict(), + ), + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] + +# data loaders +train_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train.npz', + seq_len=243, + causal=False, + pad_video_seq=True, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + ), +) +val_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=243, + causal=False, + pad_video_seq=True, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe'), + dict(type='MPJPE', mode='p-mpjpe') +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py new file mode 100644 index 0000000000..08bcda8ed7 --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py @@ -0,0 +1,119 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', 
vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = None + +# optimizer + +# learning policy + +auto_scale_lr = dict(base_batch_size=1024) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=2, + kernel_sizes=(3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + head=dict( + type='TemporalRegressionHead', + in_channels=1024, + num_joints=17, + loss=dict(type='MPJPELoss'), + decoder=codec, + ), + traj_backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=2, + kernel_sizes=(3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + traj_head=dict( + type='TrajectoryRegressionHead', + in_channels=1024, + num_joints=1, + loss=dict(type='MPJPELoss', use_target_weight=True), + decoder=codec, + ), + semi_loss=dict( + type='SemiSupervisionLoss', + joint_parents=[0, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15], + warmup_iterations=1311376 // 64 // 8 * 5), +) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] + +# data loaders +val_dataloader = dict( + batch_size=64, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=27, + causal=False, + pad_video_seq=True, + keypoint_2d_src='detection', + keypoint_2d_det_file='joint_2d_det_files/cpn_ft_h36m_dbb_test.npy', + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe'), + dict(type='MPJPE', mode='p-mpjpe'), + dict(type='MPJPE', mode='n-mpjpe') +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py new file mode 100644 index 0000000000..d145f05b17 --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py @@ -0,0 +1,117 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = None + +# optimizer + +# learning policy + +auto_scale_lr = dict(base_batch_size=1024) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + 
type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=2, + kernel_sizes=(3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + head=dict( + type='TemporalRegressionHead', + in_channels=1024, + num_joints=17, + loss=dict(type='MPJPELoss'), + decoder=codec, + ), + traj_backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=2, + kernel_sizes=(3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + traj_head=dict( + type='TrajectoryRegressionHead', + in_channels=1024, + num_joints=1, + loss=dict(type='MPJPELoss', use_target_weight=True), + decoder=codec, + ), + semi_loss=dict( + type='SemiSupervisionLoss', + joint_parents=[0, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15], + warmup_iterations=1311376 // 64 // 8 * 5), +) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] + +# data loaders +val_dataloader = dict( + batch_size=64, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=27, + causal=False, + pad_video_seq=True, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe'), + dict(type='MPJPE', mode='p-mpjpe'), + dict(type='MPJPE', mode='n-mpjpe') +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py new file mode 100644 index 0000000000..2589b493a6 --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py @@ -0,0 +1,128 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=80, val_interval=10) + +# optimizer +optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.975, end=80, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=1024) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=2, + kernel_sizes=(3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + head=dict( + type='TemporalRegressionHead', + in_channels=1024, + num_joints=17, + loss=dict(type='MPJPELoss'), + decoder=codec, + )) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(), + target_flip_cfg=dict(), 
+ ), + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] + +# data loaders +train_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train.npz', + seq_len=27, + causal=False, + pad_video_seq=True, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + ), +) +val_dataloader = dict( + batch_size=128, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=27, + causal=False, + pad_video_seq=True, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe'), + dict(type='MPJPE', mode='p-mpjpe') +] +test_evaluator = val_evaluator diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py new file mode 100644 index 0000000000..f2c27e423d --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py @@ -0,0 +1,128 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = dict(max_epochs=80, val_interval=10) + +# optimizer +optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3)) + +# learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.975, end=80, by_epoch=True) +] + +auto_scale_lr = dict(base_batch_size=1024) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='TCN', + in_channels=2 * 17, + stem_channels=1024, + num_blocks=3, + kernel_sizes=(3, 3, 3, 3), + dropout=0.25, + use_stride_conv=True, + ), + head=dict( + type='TemporalRegressionHead', + in_channels=1024, + num_joints=17, + loss=dict(type='MPJPELoss'), + decoder=codec, + )) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +train_pipeline = [ + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg=dict(), + target_flip_cfg=dict(), + ), + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 
'target_img_path', 'flip_indices',
+                   'target_root'))
+]
+
+# data loaders
+train_dataloader = dict(
+    batch_size=128,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        ann_file='annotation_body3d/fps50/h36m_train.npz',
+        seq_len=81,
+        causal=False,
+        pad_video_seq=True,
+        camera_param_file='annotation_body3d/cameras.pkl',
+        data_root=data_root,
+        data_prefix=dict(img='images/'),
+        pipeline=train_pipeline,
+    ),
+)
+val_dataloader = dict(
+    batch_size=128,
+    num_workers=2,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        ann_file='annotation_body3d/fps50/h36m_test.npz',
+        seq_len=81,
+        causal=False,
+        pad_video_seq=True,
+        camera_param_file='annotation_body3d/cameras.pkl',
+        data_root=data_root,
+        data_prefix=dict(img='images/'),
+        pipeline=val_pipeline,
+        test_mode=True,
+    ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = [
+    dict(type='MPJPE', mode='mpjpe'),
+    dict(type='MPJPE', mode='p-mpjpe')
+]
+test_evaluator = val_evaluator
diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/videopose3d_h36m.yml b/configs/body_3d_keypoint/video_pose_lift/h36m/videopose3d_h36m.yml
new file mode 100644
index 0000000000..0703111b1b
--- /dev/null
+++ b/configs/body_3d_keypoint/video_pose_lift/h36m/videopose3d_h36m.yml
@@ -0,0 +1,102 @@
+Collections:
+- Name: VideoPose3D
+  Paper:
+    Title: 3d human pose estimation in video with temporal convolutions and semi-supervised
+      training
+    URL: http://openaccess.thecvf.com/content_CVPR_2019/html/Pavllo_3D_Human_Pose_Estimation_in_Video_With_Temporal_Convolutions_and_CVPR_2019_paper.html
+  README: https://github.com/open-mmlab/mmpose/blob/main/docs/en/papers/algorithms/videopose3d.md
+Models:
+- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py
+  In Collection: VideoPose3D
+  Metadata:
+    Architecture: &id001
+    - VideoPose3D
+    Training Data: Human3.6M
+  Name: vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m
+  Results:
+  - Dataset: Human3.6M
+    Metrics:
+      MPJPE: 40.0
+      P-MPJPE: 30.1
+    Task: Body 3D Keypoint
+  Weights: https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth
+- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py
+  In Collection: VideoPose3D
+  Metadata:
+    Architecture: *id001
+    Training Data: Human3.6M
+  Name: vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m
+  Results:
+  - Dataset: Human3.6M
+    Metrics:
+      MPJPE: 38.9
+      P-MPJPE: 29.2
+    Task: Body 3D Keypoint
+  Weights: https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth
+- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py
+  In Collection: VideoPose3D
+  Metadata:
+    Architecture: *id001
+    Training Data: Human3.6M
+  Name: vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m
+  Results:
+  - Dataset: Human3.6M
+    Metrics:
+      MPJPE: 37.6
+      P-MPJPE: 28.3
+    Task: Body 3D Keypoint
+  Weights: https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth
+- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py
+  In Collection: VideoPose3D
+  Metadata:
+    Architecture: *id001
+    Training Data: Human3.6M
+  Name:
vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m + Results: + - Dataset: Human3.6M + Metrics: + MPJPE: 52.9 + P-MPJPE: 41.3 + Task: Body 3D Keypoint + Weights: https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth +- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py + In Collection: VideoPose3D + Metadata: + Architecture: *id001 + Training Data: Human3.6M + Name: vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m + Results: + - Dataset: Human3.6M + Metrics: + MPJPE: 47.9 + P-MPJPE: 38.0 + Task: Body 3D Keypoint + Weights: https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth +- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py + In Collection: VideoPose3D + Metadata: + Architecture: *id001 + Training Data: Human3.6M + Name: vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m + Results: + - Dataset: Human3.6M + Metrics: + MPJPE: 58.1 + N-MPJPE: 54.7 + P-MPJPE: 42.8 + Task: Body 3D Keypoint + Weights: https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth +- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py + In Collection: VideoPose3D + Metadata: + Architecture: *id001 + Training Data: Human3.6M + Name: vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m + Results: + - Dataset: Human3.6M + Metrics: + MPJPE: 67.4 + N-MPJPE: 63.2 + P-MPJPE: 50.1 + Task: Body 3D Keypoint + Weights: https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/videpose3d_h36m.md b/configs/body_3d_keypoint/video_pose_lift/h36m/videpose3d_h36m.md new file mode 100644 index 0000000000..c36ef29df9 --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/videpose3d_h36m.md @@ -0,0 +1,67 @@ + + +
+ +VideoPose3D (CVPR'2019) + +```bibtex +@inproceedings{pavllo20193d, +title={3d human pose estimation in video with temporal convolutions and semi-supervised training}, +author={Pavllo, Dario and Feichtenhofer, Christoph and Grangier, David and Auli, Michael}, +booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, +pages={7753--7762}, +year={2019} +} +``` + +
+ + + +
+Human3.6M (TPAMI'2014) + +```bibtex +@article{h36m_pami, +author = {Ionescu, Catalin and Papava, Dragos and Olaru, Vlad and Sminchisescu, Cristian}, +title = {Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human Sensing in Natural Environments}, +journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, +publisher = {IEEE Computer Society}, +volume = {36}, +number = {7}, +pages = {1325-1339}, +month = {jul}, +year = {2014} +} +``` + +
+ +Testing results on Human3.6M dataset with ground truth 2D detections, supervised training + +| Arch | Receptive Field | MPJPE | P-MPJPE | ckpt | log | +| :--------------------------------------------------------- | :-------------: | :---: | :-----: | :--------------------------------------------------------: | :-------------------------------------------------------: | +| [VideoPose3D](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 27 | 40.1 | 30.1 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 81 | 39.1 | 29.3 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 243 | | | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | + +Testing results on Human3.6M dataset with CPN 2D detections1, supervised training + +| Arch | Receptive Field | MPJPE | P-MPJPE | ckpt | log | +| :--------------------------------------------------------- | :-------------: | :---: | :-----: | :--------------------------------------------------------: | :-------------------------------------------------------: | +| [VideoPose3D](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 1 | 53.0 | 41.3 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | +| [VideoPose3D](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 243 | | | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | + +Testing results on Human3.6M dataset with ground truth 2D detections, semi-supervised training + +| Training Data | Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | +| :------------ | :-------------------------------------------------: | :-------------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-------------------------------------------------: | +| 10% S1 | [VideoPose3D](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 27 | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | 
[log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | + +Testing results on Human3.6M dataset with CPN 2D detections1, semi-supervised training + +| Training Data | Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | +| :------------ | :-------------------------------------------------: | :-------------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-------------------------------------------------: | +| 10% S1 | [VideoPose3D](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 27 | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | + +1 CPN 2D detections are provided by [official repo](https://github.com/facebookresearch/VideoPose3D/blob/master/DATASETS.md). The reformatted version used in this repository can be downloaded from [train_detection](https://download.openmmlab.com/mmpose/body3d/videopose/cpn_ft_h36m_dbb_train.npy) and [test_detection](https://download.openmmlab.com/mmpose/body3d/videopose/cpn_ft_h36m_dbb_test.npy). diff --git a/demo/body3d_pose_lifter_demo.py b/demo/body3d_pose_lifter_demo.py new file mode 100644 index 0000000000..8b834d08ca --- /dev/null +++ b/demo/body3d_pose_lifter_demo.py @@ -0,0 +1,466 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +from argparse import ArgumentParser +from functools import partial + +import cv2 +import mmcv +import numpy as np +from mmengine.registry import init_default_scope +from mmengine.structures import InstanceData + +from mmpose.apis import (_track_by_iou, _track_by_oks, collect_multi_frames, + extract_pose_sequence, inference_pose_lifter_model, + inference_topdown, init_model) +from mmpose.models.pose_estimators import PoseLifter +from mmpose.models.pose_estimators.topdown import TopdownPoseEstimator +from mmpose.registry import VISUALIZERS +from mmpose.structures import PoseDataSample, merge_data_samples +from mmpose.utils import adapt_mmdet_pipeline + +try: + from mmdet.apis import inference_detector, init_detector + has_mmdet = True +except (ImportError, ModuleNotFoundError): + has_mmdet = False + + +def convert_keypoint_definition(keypoints, pose_det_dataset, + pose_lift_dataset): + """Convert pose det dataset keypoints definition to pose lifter dataset + keypoints definition, so that they are compatible with the definitions + required for 3D pose lifting. + + Args: + keypoints (ndarray[N, K, 2 or 3]): 2D keypoints to be transformed. + pose_det_dataset, (str): Name of the dataset for 2D pose detector. + pose_lift_dataset (str): Name of the dataset for pose lifter model. + + Returns: + ndarray[K, 2 or 3]: the transformed 2D keypoints. + """ + assert pose_lift_dataset in [ + 'Human36mDataset'], '`pose_lift_dataset` should be ' \ + f'`Human36mDataset`, but got {pose_lift_dataset}.' 
+ + coco_style_datasets = [ + 'CocoDataset', 'PoseTrack18VideoDataset', 'PoseTrack18Dataset' + ] + keypoints_new = np.zeros((keypoints.shape[0], 17, keypoints.shape[2]), + dtype=keypoints.dtype) + if pose_lift_dataset == 'Human36mDataset': + if pose_det_dataset in ['Human36mDataset']: + keypoints_new = keypoints + elif pose_det_dataset in coco_style_datasets: + # pelvis (root) is in the middle of l_hip and r_hip + keypoints_new[:, 0] = (keypoints[:, 11] + keypoints[:, 12]) / 2 + # thorax is in the middle of l_shoulder and r_shoulder + keypoints_new[:, 8] = (keypoints[:, 5] + keypoints[:, 6]) / 2 + # spine is in the middle of thorax and pelvis + keypoints_new[:, + 7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2 + # in COCO, head is in the middle of l_eye and r_eye + # in PoseTrack18, head is in the middle of head_bottom and head_top + keypoints_new[:, 10] = (keypoints[:, 1] + keypoints[:, 2]) / 2 + # rearrange other keypoints + keypoints_new[:, [1, 2, 3, 4, 5, 6, 9, 11, 12, 13, 14, 15, 16]] = \ + keypoints[:, [12, 14, 16, 11, 13, 15, 0, 5, 7, 9, 6, 8, 10]] + elif pose_det_dataset in ['AicDataset']: + # pelvis (root) is in the middle of l_hip and r_hip + keypoints_new[:, 0] = (keypoints[:, 9] + keypoints[:, 6]) / 2 + # thorax is in the middle of l_shoulder and r_shoulder + keypoints_new[:, 8] = (keypoints[:, 3] + keypoints[:, 0]) / 2 + # spine is in the middle of thorax and pelvis + keypoints_new[:, + 7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2 + # neck base (top end of neck) is 1/4 the way from + # neck (bottom end of neck) to head top + keypoints_new[:, 9] = (3 * keypoints[:, 13] + keypoints[:, 12]) / 4 + # head (spherical centre of head) is 7/12 the way from + # neck (bottom end of neck) to head top + keypoints_new[:, 10] = (5 * keypoints[:, 13] + + 7 * keypoints[:, 12]) / 12 + + keypoints_new[:, [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16]] = \ + keypoints[:, [6, 7, 8, 9, 10, 11, 3, 4, 5, 0, 1, 2]] + elif pose_det_dataset in ['CrowdPoseDataset']: + # pelvis (root) is in the middle of l_hip and r_hip + keypoints_new[:, 0] = (keypoints[:, 6] + keypoints[:, 7]) / 2 + # thorax is in the middle of l_shoulder and r_shoulder + keypoints_new[:, 8] = (keypoints[:, 0] + keypoints[:, 1]) / 2 + # spine is in the middle of thorax and pelvis + keypoints_new[:, + 7] = (keypoints_new[:, 0] + keypoints_new[:, 8]) / 2 + # neck base (top end of neck) is 1/4 the way from + # neck (bottom end of neck) to head top + keypoints_new[:, 9] = (3 * keypoints[:, 13] + keypoints[:, 12]) / 4 + # head (spherical centre of head) is 7/12 the way from + # neck (bottom end of neck) to head top + keypoints_new[:, 10] = (5 * keypoints[:, 13] + + 7 * keypoints[:, 12]) / 12 + + keypoints_new[:, [1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 16]] = \ + keypoints[:, [7, 9, 11, 6, 8, 10, 0, 2, 4, 1, 3, 5]] + else: + raise NotImplementedError( + f'unsupported conversion between {pose_lift_dataset} and ' + f'{pose_det_dataset}') + + return keypoints_new + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument('det_config', help='Config file for detection') + parser.add_argument('det_checkpoint', help='Checkpoint file for detection') + parser.add_argument( + 'pose_estimator_config', + type=str, + default=None, + help='Config file for the 1st stage 2D pose estimator') + parser.add_argument( + 'pose_estimator_checkpoint', + type=str, + default=None, + help='Checkpoint file for the 1st stage 2D pose estimator') + parser.add_argument( + 'pose_lifter_config', + help='Config file for the 2nd stage pose lifter model') + 
parser.add_argument(
+        'pose_lifter_checkpoint',
+        help='Checkpoint file for the 2nd stage pose lifter model')
+    parser.add_argument('--input', type=str, default='', help='Video path')
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        default=False,
+        help='Whether to show visualizations')
+    parser.add_argument(
+        '--rebase-keypoint-height',
+        action='store_true',
+        help='Rebase the predicted 3D pose so its lowest keypoint has a '
+        'height of 0 (landing on the ground). This is useful for '
+        'visualization when the model does not predict the global position '
+        'of the 3D pose.')
+    parser.add_argument(
+        '--norm-pose-2d',
+        action='store_true',
+        help='Scale the bbox (along with the 2D pose) to the average bbox '
+        'scale of the dataset, and move the bbox (along with the 2D pose) to '
+        'the average bbox center of the dataset. This is useful when bbox '
+        'is small, especially in multi-person scenarios.')
+    parser.add_argument(
+        '--output-root',
+        type=str,
+        default='',
+        help='Root of the output video file. '
+        'Default not saving the visualization video.')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--det-cat-id',
+        type=int,
+        default=0,
+        help='Category id for bounding box detection model')
+    parser.add_argument(
+        '--bbox-thr',
+        type=float,
+        default=0.9,
+        help='Bounding box score threshold')
+    parser.add_argument(
+        '--kpt-thr',
+        type=float,
+        default=0.3,
+        help='Keypoint score threshold for visualization')
+    parser.add_argument(
+        '--use-oks-tracking', action='store_true', help='Using OKS tracking')
+    parser.add_argument(
+        '--tracking-thr', type=float, default=0.3, help='Tracking threshold')
+    parser.add_argument(
+        '--thickness',
+        type=int,
+        default=1,
+        help='Link thickness for visualization')
+    parser.add_argument(
+        '--radius',
+        type=int,
+        default=3,
+        help='Keypoint radius for visualization')
+    parser.add_argument(
+        '--use-multi-frames',
+        action='store_true',
+        default=False,
+        help='whether to use multi frames for inference in the 2D pose '
+        'detection stage. Default: False.')
+    parser.add_argument(
+        '--online',
+        action='store_true',
+        default=False,
+        help='Inference mode. If set to True, can not use future frame '
+        'information when using multi frames for inference in the 2D pose '
+        'detection stage. Default: False.')
+
+    args = parser.parse_args()
+    return args
+
+
+def get_area(results):
+    for i, data_sample in enumerate(results):
+        pred_instance = data_sample.pred_instances.cpu().numpy()
+        if 'bboxes' in pred_instance:
+            bboxes = pred_instance.bboxes
+            results[i].pred_instances.set_field(
+                np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+                          for bbox in bboxes]), 'areas')
+        else:
+            keypoints = pred_instance.keypoints
+            areas, bboxes = [], []
+            for keypoint in keypoints:
+                xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10)
+                xmax = np.max(keypoint[:, 0])
+                ymin = np.min(keypoint[:, 1][keypoint[:, 1] > 0], initial=1e10)
+                ymax = np.max(keypoint[:, 1])
+                areas.append((xmax - xmin) * (ymax - ymin))
+                bboxes.append([xmin, ymin, xmax, ymax])
+            results[i].pred_instances.areas = np.array(areas)
+            results[i].pred_instances.bboxes = np.array(bboxes)
+    return results
+
+
+def main():
+    assert has_mmdet, 'Please install mmdet to run the demo.'
+
+    args = parse_args()
+
+    assert args.show or (args.output_root != '')
+    assert args.input != ''
+    assert args.det_config is not None
+    assert args.det_checkpoint is not None
+
+    detector = init_detector(
+        args.det_config, args.det_checkpoint, device=args.device.lower())
+    detector.cfg = adapt_mmdet_pipeline(detector.cfg)
+
+    pose_estimator = init_model(
+        args.pose_estimator_config,
+        args.pose_estimator_checkpoint,
+        device=args.device.lower())
+
+    assert isinstance(pose_estimator, TopdownPoseEstimator), \
+        'Only "TopDown" model is supported for the 1st stage ' \
+        '(2D pose detection)'
+
+    det_kpt_color = pose_estimator.dataset_meta.get('keypoint_colors', None)
+    det_dataset_skeleton = pose_estimator.dataset_meta.get(
+        'skeleton_links', None)
+    det_dataset_link_color = pose_estimator.dataset_meta.get(
+        'skeleton_link_colors', None)
+
+    # frame index offsets for inference, used in multi-frame inference setting
+    if args.use_multi_frames:
+        assert 'frame_indices_test' in \
+            pose_estimator.cfg.test_dataloader.dataset
+        indices = pose_estimator.cfg.test_dataloader.dataset[
+            'frame_indices_test']
+
+    pose_det_dataset = pose_estimator.cfg.test_dataloader.dataset
+
+    pose_lifter = init_model(
+        args.pose_lifter_config,
+        args.pose_lifter_checkpoint,
+        device=args.device.lower())
+
+    assert isinstance(pose_lifter, PoseLifter), \
+        'Only "PoseLifter" model is supported for the 2nd stage ' \
+        '(2D-to-3D lifting)'
+    pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset
+
+    pose_lifter.cfg.visualizer.radius = args.radius
+    pose_lifter.cfg.visualizer.line_width = args.thickness
+    local_visualizer = VISUALIZERS.build(pose_lifter.cfg.visualizer)
+
+    # the dataset_meta is loaded from the checkpoint
+    local_visualizer.set_dataset_meta(pose_lifter.dataset_meta)
+
+    init_default_scope(pose_lifter.cfg.get('default_scope', 'mmpose'))
+
+    if args.output_root == '':
+        save_out_video = False
+    else:
+        os.makedirs(args.output_root, exist_ok=True)
+        save_out_video = True
+
+    # initialize unconditionally so the final `if video_writer:` check
+    # does not raise a NameError when no output video is saved
+    video_writer = None
+    if save_out_video:
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+
+    pose_est_results_list = []
+    next_id = 0
+    pose_est_results = []
+
+    video = cv2.VideoCapture(args.input)
+    assert video.isOpened(), f'Failed to load video file {args.input}'
+
+    (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
+    if int(major_ver) < 3:
+        fps = video.get(cv2.cv.CV_CAP_PROP_FPS)
+        width = video.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
+        height = video.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
+    else:
+        fps = video.get(cv2.CAP_PROP_FPS)
+        width = video.get(cv2.CAP_PROP_FRAME_WIDTH)
+        height = video.get(cv2.CAP_PROP_FRAME_HEIGHT)
+
+    frame_idx = -1
+
+    while video.isOpened():
+        success, frame = video.read()
+        frame_idx += 1
+
+        if not success:
+            break
+
+        pose_est_results_last = pose_est_results
+
+        # First stage: 2D pose detection
+        # test a single image, the resulting box is (x1, y1, x2, y2)
+        det_result = inference_detector(detector, frame)
+        pred_instance = det_result.pred_instances.cpu().numpy()
+
+        bboxes = pred_instance.bboxes
+        bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
+                                       pred_instance.scores > args.bbox_thr)]
+
+        if args.use_multi_frames:
+            frames = collect_multi_frames(video, frame_idx, indices,
+                                          args.online)
+
+        # make person results for current image
+        pose_est_results = inference_topdown(
+            pose_estimator, frames if args.use_multi_frames else frame, bboxes)
+
+        pose_est_results = get_area(pose_est_results)
+        if args.use_oks_tracking:
+            _track = partial(_track_by_oks)
+        else:
+            _track =
_track_by_iou
+
+        for i, result in enumerate(pose_est_results):
+            track_id, pose_est_results_last, match_result = _track(
+                result, pose_est_results_last, args.tracking_thr)
+            if track_id == -1:
+                pred_instances = result.pred_instances.cpu().numpy()
+                keypoints = pred_instances.keypoints
+                if np.count_nonzero(keypoints[:, :, 1]) >= 3:
+                    pose_est_results[i].set_field(next_id, 'track_id')
+                    next_id += 1
+                else:
+                    # If the number of keypoints detected is small,
+                    # delete that person instance.
+                    keypoints[:, :, 1] = -10
+                    pose_est_results[i].pred_instances.set_field(
+                        keypoints, 'keypoints')
+                    bboxes = pred_instances.bboxes * 0
+                    pose_est_results[i].pred_instances.set_field(
+                        bboxes, 'bboxes')
+                    pose_est_results[i].set_field(-1, 'track_id')
+                    pose_est_results[i].set_field(pred_instances,
+                                                  'pred_instances')
+            else:
+                pose_est_results[i].set_field(track_id, 'track_id')
+
+        del match_result
+
+        pose_est_results_converted = []
+        for pose_est_result in pose_est_results:
+            pose_est_result_converted = PoseDataSample()
+            gt_instances = InstanceData()
+            pred_instances = InstanceData()
+            for k in pose_est_result.gt_instances.keys():
+                gt_instances.set_field(pose_est_result.gt_instances[k], k)
+            for k in pose_est_result.pred_instances.keys():
+                pred_instances.set_field(pose_est_result.pred_instances[k], k)
+            pose_est_result_converted.gt_instances = gt_instances
+            pose_est_result_converted.pred_instances = pred_instances
+            pose_est_result_converted.track_id = pose_est_result.track_id
+            pose_est_results_converted.append(pose_est_result_converted)
+
+        for i, result in enumerate(pose_est_results_converted):
+            keypoints = result.pred_instances.keypoints
+            keypoints = convert_keypoint_definition(keypoints,
+                                                    pose_det_dataset['type'],
+                                                    pose_lift_dataset['type'])
+            pose_est_results_converted[i].pred_instances.keypoints = keypoints
+
+        pose_est_results_list.append(pose_est_results_converted.copy())
+
+        # extract and pad input pose2d sequence
+        pose_results_2d = extract_pose_sequence(
+            pose_est_results_list,
+            frame_idx=frame_idx,
+            causal=pose_lifter.causal,
+            seq_len=pose_lift_dataset.get('seq_len', 1),
+            step=pose_lift_dataset.get('seq_step', 1))
+
+        # Second stage: Pose lifting
+        # 2D-to-3D pose lifting
+        pose_lift_results = inference_pose_lifter_model(
+            pose_lifter,
+            pose_results_2d,
+            image_size=(width, height),
+            norm_pose_2d=args.norm_pose_2d)
+
+        # Pose processing
+        for idx, pose_lift_res in enumerate(pose_lift_results):
+            gt_instances = pose_lift_res.gt_instances
+
+            # use the loop index `idx` (not the stale `i` left over from the
+            # loops above) to match each lifted pose with its 2D counterpart
+            pose_lift_res.track_id = pose_est_results_converted[idx].get(
+                'track_id', 1e4)
+
+            pred_instances = pose_lift_res.pred_instances
+            keypoints = pred_instances.keypoints
+
+            keypoints = keypoints[..., [0, 2, 1]]
+            keypoints[..., 0] = -keypoints[..., 0]
+            keypoints[..., 2] = -keypoints[..., 2]
+
+            # rebase height (z-axis)
+            if args.rebase_keypoint_height:
+                keypoints[..., 2] -= np.min(
+                    keypoints[..., 2], axis=-1, keepdims=True)
+
+            pose_lift_results[idx].pred_instances.keypoints = keypoints
+
+        pose_lift_results = sorted(
+            pose_lift_results, key=lambda x: x.get('track_id', 1e4))
+
+        pred_3d_data_samples = merge_data_samples(pose_lift_results)
+
+        # Visualization
+        frame = mmcv.bgr2rgb(frame)
+
+        det_data_sample = merge_data_samples(pose_est_results)
+
+        local_visualizer.add_datasample(
+            'result',
+            frame,
+            data_sample=pred_3d_data_samples,
+            det_data_sample=det_data_sample,
+            draw_gt=False,
+            det_kpt_color=det_kpt_color,
+            det_dataset_skeleton=det_dataset_skeleton,
+            det_dataset_link_color=det_dataset_link_color,
+            show=args.show,
+
draw_bbox=True, + kpt_thr=args.kpt_thr, + wait_time=0.001) + + frame_vis = local_visualizer.get_image() + + if save_out_video: + if video_writer is None: + # the size of the image with visualization may vary + # depending on the presence of heatmaps + video_writer = cv2.VideoWriter( + osp.join(args.output_root, + f'vis_{osp.basename(args.input)}'), fourcc, fps, + (frame_vis.shape[1], frame_vis.shape[0])) + + video_writer.write(mmcv.rgb2bgr(frame_vis)) + + video.release() + + if video_writer: + video_writer.release() + + +if __name__ == '__main__': + main() diff --git a/demo/docs/3d_human_pose_demo.md b/demo/docs/3d_human_pose_demo.md new file mode 100644 index 0000000000..eb2eab92ae --- /dev/null +++ b/demo/docs/3d_human_pose_demo.md @@ -0,0 +1,74 @@ +## 3D Human Pose Demo + +
+
+### 3D Human Pose Two-stage Estimation Video Demo
+
+#### Using mmdet for human bounding box detection and a top-down model for the 1st stage (2D pose detection), then inferring the 2nd stage (2D-to-3D lifting)
+
+Assume that you have already installed [mmdet](https://github.com/open-mmlab/mmdetection).
+
+```shell
+python demo/body3d_pose_lifter_demo.py \
+${MMDET_CONFIG_FILE} \
+${MMDET_CHECKPOINT_FILE} \
+${MMPOSE_CONFIG_FILE_2D} \
+${MMPOSE_CHECKPOINT_FILE_2D} \
+${MMPOSE_CONFIG_FILE_3D} \
+${MMPOSE_CHECKPOINT_FILE_3D} \
+--input ${VIDEO_PATH} \
+[--show] \
+[--rebase-keypoint-height] \
+[--norm-pose-2d] \
+[--output-root ${OUT_VIDEO_ROOT}] \
+[--device ${GPU_ID or CPU}] \
+[--det-cat-id DET_CAT_ID] \
+[--bbox-thr BBOX_THR] \
+[--kpt-thr KPT_THR] \
+[--use-oks-tracking] \
+[--tracking-thr TRACKING_THR] \
+[--thickness THICKNESS] \
+[--radius RADIUS] \
+[--use-multi-frames] [--online]
+```
+
+Note that
+
+1. `${VIDEO_PATH}` can be the local path or **URL** link to video file.
+
+2. You can turn on the `[--use-multi-frames]` option to use multi frames for inference in the 2D pose detection stage.
+
+3. If the `[--online]` option is set to **True**, future frame information can **not** be used when using multi frames for inference in the 2D pose detection stage.
+
+Examples:
+
+During 2D pose detection, for single-frame inference that does not rely on extra frames to get the final results of the current frame, try this:
+
+```shell
+python demo/body3d_pose_lifter_demo.py \
+demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py \
+https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
+configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w48_8xb32-210e_coco-256x192.py \
+https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth \
+configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py \
+https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth \
+--input https://user-images.githubusercontent.com/87690686/164970135-b14e424c-765a-4180-9bc8-fa8d6abc5510.mp4 \
+--output-root vis_results \
+--rebase-keypoint-height
+```
+
+During 2D pose detection, for multi-frame inference that relies on extra frames to get the final results of the current frame, try this:
+
+```shell
+python demo/body3d_pose_lifter_demo.py \
+demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py \
+https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
+configs/body_2d_keypoint/topdown_heatmap/posetrack18/td-hm_hrnet-w48_8xb64-20e_posetrack18-384x288.py \
+https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_posetrack18_384x288-5fd6d3ff_20211130.pth \
+configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py \
+https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth \
+--input https://user-images.githubusercontent.com/87690686/164970135-b14e424c-765a-4180-9bc8-fa8d6abc5510.mp4 \
+--output-root vis_results \
+--rebase-keypoint-height \
+--use-multi-frames --online
+```
diff --git a/mmpose/apis/__init__.py b/mmpose/apis/__init__.py
index ff7149e453..dcce33742c 100644
--- a/mmpose/apis/__init__.py
+++ b/mmpose/apis/__init__.py
@@ -1,8 +1,13 @@
 # Copyright (c) OpenMMLab.
All rights reserved. -from .inference import inference_bottomup, inference_topdown, init_model +from .inference import (collect_multi_frames, inference_bottomup, + inference_topdown, init_model) +from .inference_3d import extract_pose_sequence, inference_pose_lifter_model +from .inference_tracking import _compute_iou, _track_by_iou, _track_by_oks from .inferencers import MMPoseInferencer, Pose2DInferencer __all__ = [ 'init_model', 'inference_topdown', 'inference_bottomup', - 'Pose2DInferencer', 'MMPoseInferencer' + 'collect_multi_frames', 'Pose2DInferencer', 'MMPoseInferencer', + '_track_by_iou', '_track_by_oks', '_compute_iou', + 'inference_pose_lifter_model', 'extract_pose_sequence' ] diff --git a/mmpose/apis/inference.py b/mmpose/apis/inference.py index 6763d318d5..7f733fff45 100644 --- a/mmpose/apis/inference.py +++ b/mmpose/apis/inference.py @@ -223,3 +223,36 @@ def inference_bottomup(model: nn.Module, img: Union[np.ndarray, str]): results = model.test_step(batch) return results + + +def collect_multi_frames(video, frame_id, indices, online=False): + """Collect multi frames from the video. + + Args: + video (mmcv.VideoReader): A VideoReader of the input video file. + frame_id (int): index of the current frame + indices (list(int)): index offsets of the frames to collect + online (bool): inference mode, if set to True, can not use future + frame information. + + Returns: + list(ndarray): multi frames collected from the input video file. + """ + num_frames = len(video) + frames = [] + # put the current frame at first + frames.append(video[frame_id]) + # use multi frames for inference + for idx in indices: + # skip current frame + if idx == 0: + continue + support_idx = frame_id + idx + # online mode, can not use future frame information + if online: + support_idx = np.clip(support_idx, 0, frame_id) + else: + support_idx = np.clip(support_idx, 0, num_frames - 1) + frames.append(video[support_idx]) + + return frames diff --git a/mmpose/apis/inference_3d.py b/mmpose/apis/inference_3d.py new file mode 100644 index 0000000000..2ab81b20a4 --- /dev/null +++ b/mmpose/apis/inference_3d.py @@ -0,0 +1,252 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from mmengine.dataset import Compose, pseudo_collate +from mmengine.registry import init_default_scope +from mmengine.structures import InstanceData + +from mmpose.structures import PoseDataSample + + +def extract_pose_sequence(pose_results, frame_idx, causal, seq_len, step=1): + """Extract the target frame from 2D pose results, and pad the sequence to a + fixed length. + + Args: + pose_results (List[List[:obj:`PoseDataSample`]]): Multi-frame pose + detection results stored in a list. + frame_idx (int): The index of the frame in the original video. + causal (bool): If True, the target frame is the last frame in + a sequence. Otherwise, the target frame is in the middle of + a sequence. + seq_len (int): The number of frames in the input sequence. + step (int): Step size to extract frames from the video. + + Returns: + List[List[:obj:`PoseDataSample`]]: Multi-frame pose detection results + stored in a nested list with a length of seq_len. 
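+
+    Example::
+
+        >>> # Illustrative call: real elements are per-frame lists of
+        >>> # `PoseDataSample`, but any objects demonstrate the padding.
+        >>> extract_pose_sequence(
+        ...     ['f0', 'f1', 'f2', 'f3', 'f4'],
+        ...     frame_idx=1, causal=False, seq_len=5)
+        ['f0', 'f0', 'f1', 'f2', 'f3']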
+    """
+    if causal:
+        frames_left = seq_len - 1
+        frames_right = 0
+    else:
+        frames_left = (seq_len - 1) // 2
+        frames_right = frames_left
+    num_frames = len(pose_results)
+
+    # get the padded sequence
+    pad_left = max(0, frames_left - frame_idx // step)
+    pad_right = max(0, frames_right - (num_frames - 1 - frame_idx) // step)
+    start = max(frame_idx % step, frame_idx - frames_left * step)
+    end = min(num_frames - (num_frames - 1 - frame_idx) % step,
+              frame_idx + frames_right * step + 1)
+    pose_results_seq = [pose_results[0]] * pad_left + \
+        pose_results[start:end:step] + [pose_results[-1]] * pad_right
+    return pose_results_seq
+
+
+def _collate_pose_sequence(pose_results_2d,
+                           with_track_id=True,
+                           target_frame=-1):
+    """Reorganize multi-frame pose detection results into individual pose
+    sequences.
+
+    Note:
+        - The temporal length of the pose detection results: T
+        - The number of the person instances: N
+        - The number of the keypoints: K
+        - The channel number of each keypoint: C
+
+    Args:
+        pose_results_2d (List[List[:obj:`PoseDataSample`]]): Multi-frame pose
+            detection results stored in a nested list. Each element of the
+            outer list is the pose detection results of a single frame, and
+            each element of the inner list is the pose information of one
+            person, which contains:
+
+                - keypoints (ndarray[K, 2 or 3]): x, y, [score]
+                - track_id (int): unique id of each person, required when
+                    ``with_track_id==True``
+
+        with_track_id (bool): If True, the element in ``pose_results_2d`` is
+            expected to contain "track_id", which will be used to gather the
+            pose sequence of a person from multiple frames. Otherwise, the
+            pose results in each frame are expected to have a consistent
+            number and order of identities. Default is True.
+        target_frame (int): The index of the target frame. Default: -1.
+
+    Returns:
+        List[:obj:`PoseDataSample`]: Individual pose sequences, with length N.
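+
+        Each returned sample stacks the matched per-frame keypoints into
+        ``pred_instances.keypoints`` with shape ``(1, T, K, C)``; frames
+        where a track id is missing are filled by replicating the closest
+        matched frame.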
+ """ + T = len(pose_results_2d) + assert T > 0 + + target_frame = (T + target_frame) % T # convert negative index to positive + + N = len( + pose_results_2d[target_frame]) # use identities in the target frame + if N == 0: + return [] + + B, K, C = pose_results_2d[target_frame][0].pred_instances.keypoints.shape + + track_ids = None + if with_track_id: + track_ids = [res.track_id for res in pose_results_2d[target_frame]] + + pose_sequences = [] + for idx in range(N): + pose_seq = PoseDataSample() + gt_instances = InstanceData() + pred_instances = InstanceData() + + for k in pose_results_2d[target_frame][idx].gt_instances.keys(): + gt_instances.set_field( + pose_results_2d[target_frame][idx].gt_instances[k], k) + for k in pose_results_2d[target_frame][idx].pred_instances.keys(): + if k != 'keypoints': + pred_instances.set_field( + pose_results_2d[target_frame][idx].pred_instances[k], k) + pose_seq.pred_instances = pred_instances + pose_seq.gt_instances = gt_instances + + if not with_track_id: + pose_seq.pred_instances.keypoints = np.stack([ + frame[idx].pred_instances.keypoints + for frame in pose_results_2d + ], + axis=1) + else: + keypoints = np.zeros((B, T, K, C), dtype=np.float32) + keypoints[:, target_frame] = pose_results_2d[target_frame][ + idx].pred_instances.keypoints + # find the left most frame containing track_ids[idx] + for frame_idx in range(target_frame - 1, -1, -1): + contains_idx = False + for res in pose_results_2d[frame_idx]: + if res.track_id == track_ids[idx]: + keypoints[:, frame_idx] = res.pred_instances.keypoints + contains_idx = True + break + if not contains_idx: + # replicate the left most frame + keypoints[:, :frame_idx + 1] = keypoints[:, frame_idx + 1] + break + # find the right most frame containing track_idx[idx] + for frame_idx in range(target_frame + 1, T): + contains_idx = False + for res in pose_results_2d[frame_idx]: + if res.track_id == track_ids[idx]: + keypoints[:, frame_idx] = res.pred_instances.keypoints + contains_idx = True + break + if not contains_idx: + # replicate the right most frame + keypoints[:, frame_idx + 1:] = keypoints[:, frame_idx] + break + pose_seq.pred_instances.keypoints = keypoints + pose_sequences.append(pose_seq) + + return pose_sequences + + +def inference_pose_lifter_model(model, + pose_results_2d, + with_track_id=True, + image_size=None, + norm_pose_2d=False): + """Inference 3D pose from 2D pose sequences using a pose lifter model. + + Args: + model (nn.Module): The loaded pose lifter model + pose_results_2d (List[List[:obj:`PoseDataSample`]]): The 2D pose + sequences stored in a nested list. + with_track_id: If True, the element in pose_results_2d is expected to + contain "track_id", which will be used to gather the pose sequence + of a person from multiple frames. Otherwise, the pose results in + each frame are expected to have a consistent number and order of + identities. Default is True. + image_size (tuple|list): image width, image height. If None, image size + will not be contained in dict ``data``. + norm_pose_2d (bool): If True, scale the bbox (along with the 2D + pose) to the average bbox scale of the dataset, and move the bbox + (along with the 2D pose) to the average bbox center of the dataset. + + Returns: + List[:obj:`PoseDataSample`]: 3D pose inference results. Specifically, + the predicted keypoints and scores are saved at + ``data_sample.pred_instances.keypoints_3d``. 
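+
+    Example::
+
+        >>> # `lifter` comes from `init_model`; `pose_seq_2d` is the
+        >>> # List[List[PoseDataSample]] window built by
+        >>> # `extract_pose_sequence`.
+        >>> results = inference_pose_lifter_model(
+        ...     lifter, pose_seq_2d, image_size=(1920, 1080),
+        ...     norm_pose_2d=True)
+        >>> keypoints_3d = results[0].pred_instances.keypoints_3d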
+ """ + init_default_scope(model.cfg.get('default_scope', 'mmpose')) + pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline) + + causal = model.causal + target_idx = -1 if causal else len(pose_results_2d) // 2 + + dataset_info = model.dataset_meta + if dataset_info is not None: + if 'stats_info' in dataset_info: + bbox_center = dataset_info['stats_info']['bbox_center'] + bbox_scale = dataset_info['stats_info']['bbox_scale'] + else: + bbox_center = None + bbox_scale = None + + for i, pose_res in enumerate(pose_results_2d): + keypoints = [] + for j, data_sample in enumerate(pose_res): + keypoint = np.squeeze(data_sample.pred_instances.keypoints, axis=0) + if norm_pose_2d: + bbox = np.squeeze(data_sample.pred_instances.bboxes) + center = np.array([[(bbox[0] + bbox[2]) / 2, + (bbox[1] + bbox[3]) / 2]]) + scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + keypoints.append((keypoint[:, :2] - center) / scale * + bbox_scale + bbox_center) + else: + keypoints.append(keypoint[:, :2]) + pose_results_2d[i][j].pred_instances.keypoints = np.array( + keypoints) + + pose_sequences_2d = _collate_pose_sequence(pose_results_2d, with_track_id, + target_idx) + + if not pose_sequences_2d: + return [] + + data_list = [] + for i, pose_seq in enumerate(pose_sequences_2d): + data_info = dict() + + keypoints_2d = pose_seq.pred_instances.keypoints + keypoints_2d = np.squeeze( + keypoints_2d) if keypoints_2d.ndim == 4 else keypoints_2d + + T, K, C = keypoints_2d.shape + + data_info['keypoints'] = keypoints_2d + data_info['keypoints_visible'] = np.ones(( + T, + K, + ), dtype=np.float32) + data_info['lifting_target'] = np.zeros((K, 3), dtype=np.float32) + data_info['lifting_target_visible'] = np.ones((K, 1), dtype=np.float32) + + if image_size is not None: + assert len(image_size) == 2 + data_info['camera_param'] = dict(w=image_size[0], h=image_size[1]) + + data_info.update(model.dataset_meta) + data_list.append(pipeline(data_info)) + + if data_list: + # collate data list into a batch, which is a dict with following keys: + # batch['inputs']: a list of input images + # batch['data_samples']: a list of :obj:`PoseDataSample` + batch = pseudo_collate(data_list) + with torch.no_grad(): + results = model.test_step(batch) + else: + results = [] + + return results diff --git a/mmpose/apis/inference_tracking.py b/mmpose/apis/inference_tracking.py new file mode 100644 index 0000000000..c823adcfc7 --- /dev/null +++ b/mmpose/apis/inference_tracking.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np + +from mmpose.evaluation.functional.nms import oks_iou + + +def _compute_iou(bboxA, bboxB): + """Compute the Intersection over Union (IoU) between two boxes . + + Args: + bboxA (list): The first bbox info (left, top, right, bottom, score). + bboxB (list): The second bbox info (left, top, right, bottom, score). + + Returns: + float: The IoU value. 
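+
+    Example::
+
+        >>> # Two 10x10 boxes overlapping on a 5x5 patch:
+        >>> # IoU = 25 / (100 + 100 - 25)
+        >>> _compute_iou([0., 0., 10., 10., 1.0], [5., 5., 15., 15., 0.9])
+        0.14285714285714285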
+ """ + + x1 = max(bboxA[0], bboxB[0]) + y1 = max(bboxA[1], bboxB[1]) + x2 = min(bboxA[2], bboxB[2]) + y2 = min(bboxA[3], bboxB[3]) + + inter_area = max(0, x2 - x1) * max(0, y2 - y1) + + bboxA_area = (bboxA[2] - bboxA[0]) * (bboxA[3] - bboxA[1]) + bboxB_area = (bboxB[2] - bboxB[0]) * (bboxB[3] - bboxB[1]) + union_area = float(bboxA_area + bboxB_area - inter_area) + if union_area == 0: + union_area = 1e-5 + warnings.warn('union_area=0 is unexpected') + + iou = inter_area / union_area + + return iou + + +def _track_by_iou(res, results_last, thr): + """Get track id using IoU tracking greedily.""" + + bbox = list(np.squeeze(res.pred_instances.bboxes, axis=0)) + + max_iou_score = -1 + max_index = -1 + match_result = {} + for index, res_last in enumerate(results_last): + bbox_last = list(np.squeeze(res_last.pred_instances.bboxes, axis=0)) + + iou_score = _compute_iou(bbox, bbox_last) + if iou_score > max_iou_score: + max_iou_score = iou_score + max_index = index + + if max_iou_score > thr: + track_id = results_last[max_index].track_id + match_result = results_last[max_index] + del results_last[max_index] + else: + track_id = -1 + + return track_id, results_last, match_result + + +def _track_by_oks(res, results_last, thr, sigmas=None): + """Get track id using OKS tracking greedily.""" + keypoint = np.concatenate((res.pred_instances.keypoints, + res.pred_instances.keypoint_scores[:, :, None]), + axis=2) + keypoint = np.squeeze(keypoint, axis=0).reshape((-1)) + area = np.squeeze(res.pred_instances.areas, axis=0) + max_index = -1 + match_result = {} + + if len(results_last) == 0: + return -1, results_last, match_result + + keypoints_last = np.array([ + np.squeeze( + np.concatenate( + (res_last.pred_instances.keypoints, + res_last.pred_instances.keypoint_scores[:, :, None]), + axis=2), + axis=0).reshape((-1)) for res_last in results_last + ]) + area_last = np.array([ + np.squeeze(res_last.pred_instances.areas, axis=0) + for res_last in results_last + ]) + + oks_score = oks_iou( + keypoint, keypoints_last, area, area_last, sigmas=sigmas) + + max_index = np.argmax(oks_score) + + if oks_score[max_index] > thr: + track_id = results_last[max_index].track_id + match_result = results_last[max_index] + del results_last[max_index] + else: + track_id = -1 + + return track_id, results_last, match_result diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 1a02cda17e..64bf925997 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -163,9 +163,6 @@ def encode(self, if keypoint_labels.ndim == 2: keypoint_labels = keypoint_labels[None, ...] - N = keypoint_labels.shape[0] - keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) - encoded['keypoint_labels'] = keypoint_labels encoded['lifting_target_label'] = lifting_target_label encoded['lifting_target_weights'] = lifting_target_weights diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index 0331aad544..56cf35fa2d 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -28,7 +28,8 @@ class VideoPoseLifting(BaseKeypointCodec): remove_root (bool): If true, remove the root keypoint from the pose. Default: ``False``. save_index (bool): If true, store the root position separated from the - original pose. Default: ``False``. + original pose, only takes effect if ``remove_root`` is ``True``. + Default: ``False``. normalize_camera (bool): Whether to normalize camera intrinsics. Default: ``False``. 
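+
+    Note:
+        With ``remove_root=True`` the root keypoint is dropped from
+        ``lifting_target_label`` (e.g. 17 -> 16 joints for Human3.6M), and
+        ``save_index=True`` additionally records ``target_root_index`` so
+        the root can be restored when decoding.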
""" @@ -166,10 +167,6 @@ def encode(self, _camera_param['c'] = (_camera_param['c'] - center[:, None]) / scale encoded['camera_param'] = _camera_param - # Generate reshaped keypoint coordinates - N = keypoint_labels.shape[0] - keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) - encoded['keypoint_labels'] = keypoint_labels encoded['lifting_target_label'] = lifting_target_label encoded['lifting_target_weights'] = lifting_target_weights diff --git a/mmpose/datasets/datasets/__init__.py b/mmpose/datasets/datasets/__init__.py index 03a0f493ca..9f5801753f 100644 --- a/mmpose/datasets/datasets/__init__.py +++ b/mmpose/datasets/datasets/__init__.py @@ -2,6 +2,7 @@ from .animal import * # noqa: F401, F403 from .base import * # noqa: F401, F403 from .body import * # noqa: F401, F403 +from .body3d import * # noqa: F401, F403 from .face import * # noqa: F401, F403 from .fashion import * # noqa: F401, F403 from .hand import * # noqa: F401, F403 diff --git a/mmpose/datasets/datasets/utils.py b/mmpose/datasets/datasets/utils.py index 5140126163..7433a168b9 100644 --- a/mmpose/datasets/datasets/utils.py +++ b/mmpose/datasets/datasets/utils.py @@ -174,6 +174,11 @@ def parse_pose_metainfo(metainfo: dict): metainfo['joint_weights'], dtype=np.float32) parsed['sigmas'] = np.array(metainfo['sigmas'], dtype=np.float32) + if 'stats_info' in metainfo: + parsed['stats_info'] = {} + for name, val in metainfo['stats_info'].items(): + parsed['stats_info'][name] = np.array(val, dtype=np.float32) + # formatting def _map(src, mapping: dict): if isinstance(src, (list, tuple)): diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index 6b09b2c770..403147120d 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -37,6 +37,31 @@ def image_to_tensor(img: Union[np.ndarray, return tensor +def keypoints_to_tensor(keypoints: Union[np.ndarray, Sequence[np.ndarray]] + ) -> torch.torch.Tensor: + """Translate keypoints or sequence of keypoints to tensor. Multiple + keypoints tensors will be stacked. + + Args: + keypoints (np.ndarray | Sequence[np.ndarray]): The keypoints or + keypoints sequence. + + Returns: + torch.Tensor: The output tensor. + """ + if isinstance(keypoints, np.ndarray): + keypoints = np.ascontiguousarray(keypoints) + N = keypoints.shape[0] + keypoints = keypoints.transpose(1, 2, 0).reshape(-1, N) + tensor = torch.from_numpy(keypoints).contiguous() + else: + assert is_seq_of(keypoints, np.ndarray) + tensor = torch.stack( + [keypoints_to_tensor(_keypoints) for _keypoints in keypoints]) + + return tensor + + @TRANSFORMS.register_module() class PackPoseInputs(BaseTransform): """Pack the inputs data for pose estimation. 
@@ -148,7 +173,11 @@ def transform(self, results: dict) -> dict: inputs_tensor = image_to_tensor(img) # Pack keypoints for 3d pose-lifting elif 'lifting_target' in results and 'keypoints' in results: - inputs_tensor = results['keypoints'] + if 'keypoint_labels' in results: + keypoints = results['keypoint_labels'] + else: + keypoints = results['keypoints'] + inputs_tensor = keypoints_to_tensor(keypoints) data_sample = PoseDataSample() @@ -156,6 +185,10 @@ def transform(self, results: dict) -> dict: gt_instances = InstanceData() for key, packed_key in self.instance_mapping_table.items(): if key in results: + if 'lifting_target' in results and key in { + 'keypoints', 'keypoints_visible' + }: + continue gt_instances.set_field(results[key], packed_key) # pack `transformed_keypoints` for visualizing data transform @@ -171,7 +204,7 @@ def transform(self, results: dict) -> dict: for key, packed_key in self.label_mapping_table.items(): if key in results: # For pose-lifting, store only target-related fields - if 'lifting_target' in results and key in { + if 'lifting_target_label' in results and key in { 'keypoint_labels', 'keypoint_weights' }: continue diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py index 0b313d4d3f..e945650c30 100644 --- a/mmpose/evaluation/metrics/keypoint_3d_metrics.py +++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py @@ -104,8 +104,10 @@ def compute_metrics(self, results: list) -> Dict[str, float]: # pred_coords: [N, K, D] pred_coords = np.concatenate( [result['pred_coords'] for result in results]) + if pred_coords.ndim == 4 and pred_coords.shape[1] == 1: + pred_coords = np.squeeze(pred_coords, axis=1) # gt_coords: [N, K, D] - gt_coords = np.concatenate([result['gt_coords'] for result in results]) + gt_coords = np.stack([result['gt_coords'] for result in results]) # mask: [N, K] mask = np.concatenate([result['mask'] for result in results]) # action_category_indices: Dict[List[int]] diff --git a/mmpose/models/heads/regression_heads/temporal_regression_head.py b/mmpose/models/heads/regression_heads/temporal_regression_head.py index a33de19594..ac76316842 100644 --- a/mmpose/models/heads/regression_heads/temporal_regression_head.py +++ b/mmpose/models/heads/regression_heads/temporal_regression_head.py @@ -91,13 +91,13 @@ def predict(self, batch_coords = self.forward(feats) # (B, K, D) - batch_coords.unsqueeze_(dim=1) # (B, N, K, D) - # Restore global position with target_root target_root = batch_data_samples[0].metainfo.get('target_root', None) if target_root is not None: - target_root = torch.stack( - [m['target_root'] for m in batch_data_samples[0].metainfo]) + target_root = torch.stack([ + torch.from_numpy(b.metainfo['target_root']) + for b in batch_data_samples + ]) else: target_root = torch.stack([ torch.empty((0), dtype=torch.float32) diff --git a/mmpose/models/heads/regression_heads/trajectory_regression_head.py b/mmpose/models/heads/regression_heads/trajectory_regression_head.py index 0b72ae3155..adfd7353d3 100644 --- a/mmpose/models/heads/regression_heads/trajectory_regression_head.py +++ b/mmpose/models/heads/regression_heads/trajectory_regression_head.py @@ -91,13 +91,13 @@ def predict(self, batch_coords = self.forward(feats) # (B, K, D) - batch_coords.unsqueeze_(dim=1) # (B, N, K, D) - # Restore global position with target_root target_root = batch_data_samples[0].metainfo.get('target_root', None) if target_root is not None: - target_root = torch.stack( - [m['target_root'] for m in 
batch_data_samples[0].metainfo]) + target_root = torch.stack([ + torch.from_numpy(b.metainfo['target_root']) + for b in batch_data_samples + ]) else: target_root = torch.stack([ torch.empty((0), dtype=torch.float32) diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py index 73d60de93a..0ae921d0ec 100644 --- a/mmpose/models/pose_estimators/base.py +++ b/mmpose/models/pose_estimators/base.py @@ -130,6 +130,8 @@ def forward(self, - If ``mode='loss'``, return a dict of tensor(s) which is the loss function value """ + if isinstance(inputs, list): + inputs = torch.stack(inputs) if mode == 'loss': return self.loss(inputs, data_samples) elif mode == 'predict': diff --git a/mmpose/models/pose_estimators/pose_lifter.py b/mmpose/models/pose_estimators/pose_lifter.py index 5b0abf3690..5bad3dde3c 100644 --- a/mmpose/models/pose_estimators/pose_lifter.py +++ b/mmpose/models/pose_estimators/pose_lifter.py @@ -134,7 +134,7 @@ def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]: # pose model feats = self.backbone(inputs) if self.with_neck: - x = self.neck(feats) + feats = self.neck(feats) # trajectory model if self.with_traj: @@ -145,9 +145,9 @@ def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]: if self.with_traj_neck: traj_x = self.traj_neck(traj_x) - return x, traj_x + return feats, traj_x else: - return x + return feats def _forward(self, inputs: Tensor, @@ -218,8 +218,14 @@ def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList: """Predict results from a batch of inputs and data samples with post- processing. + Note: + - batch_size: B + - num_input_keypoints: K + - input_keypoint_dim: C + - input_sequence_len: T + Args: - inputs (Tensor): Inputs with shape (N, K, C, T). + inputs (Tensor): Inputs with shape like (B, K, C, T). data_samples (List[:obj:`PoseDataSample`]): The batch data samples @@ -298,6 +304,8 @@ def add_pred_to_datasample( assert len(batch_pred_instances) == len(batch_data_samples) if batch_pred_fields is None: batch_pred_fields, batch_traj_fields = [], [] + if batch_traj_instances is None: + batch_traj_instances = [] output_keypoint_indices = self.test_cfg.get('output_keypoint_indices', None) diff --git a/mmpose/visualization/__init__.py b/mmpose/visualization/__init__.py index 73fbd645a9..4a18e8bc5b 100644 --- a/mmpose/visualization/__init__.py +++ b/mmpose/visualization/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .fast_visualizer import FastVisualizer from .local_visualizer import PoseLocalVisualizer +from .local_visualizer_3d import Pose3dLocalVisualizer -__all__ = ['PoseLocalVisualizer', 'FastVisualizer'] +__all__ = ['PoseLocalVisualizer', 'FastVisualizer', 'Pose3dLocalVisualizer'] diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py new file mode 100644 index 0000000000..3a0cfc1cb3 --- /dev/null +++ b/mmpose/visualization/local_visualizer_3d.py @@ -0,0 +1,563 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import mmcv +import numpy as np +from matplotlib import pyplot as plt +from mmengine.dist import master_only +from mmengine.structures import InstanceData + +from mmpose.registry import VISUALIZERS +from mmpose.structures import PoseDataSample +from . import PoseLocalVisualizer + + +@VISUALIZERS.register_module() +class Pose3dLocalVisualizer(PoseLocalVisualizer): + """MMPose 3d Local Visualizer. + + Args: + name (str): Name of the instance. 
Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to ``None`` + vis_backends (list, optional): Visual backend config list. Defaults to + ``None`` + save_dir (str, optional): Save file dir for all storage backends. + If it is ``None``, the backend storage will not save any data. + Defaults to ``None`` + bbox_color (str, tuple(int), optional): Color of bbox lines. + The tuple of color should be in BGR order. Defaults to ``'green'`` + kpt_color (str, tuple(tuple(int)), optional): Color of keypoints. + The tuple of color should be in BGR order. Defaults to ``'red'`` + link_color (str, tuple(tuple(int)), optional): Color of skeleton. + The tuple of color should be in BGR order. Defaults to ``None`` + line_width (int, float): The width of lines. Defaults to 1 + radius (int, float): The radius of keypoints. Defaults to 4 + show_keypoint_weight (bool): Whether to adjust the transparency + of keypoints according to their score. Defaults to ``False`` + alpha (int, float): The transparency of bboxes. Defaults to ``0.8`` + """ + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + bbox_color: Optional[Union[str, Tuple[int]]] = 'green', + kpt_color: Optional[Union[str, Tuple[Tuple[int]]]] = 'red', + link_color: Optional[Union[str, Tuple[Tuple[int]]]] = None, + text_color: Optional[Union[str, + Tuple[int]]] = (255, 255, 255), + skeleton: Optional[Union[List, Tuple]] = None, + line_width: Union[int, float] = 1, + radius: Union[int, float] = 3, + show_keypoint_weight: bool = False, + alpha: float = 0.8): + super().__init__(name, image, vis_backends, save_dir, bbox_color, + kpt_color, link_color, text_color, skeleton, + line_width, radius, show_keypoint_weight, alpha) + + def _draw_3d_data_samples( + self, + image: np.ndarray, + pose_samples: PoseDataSample, + draw_gt: bool = True, + kpt_thr: float = 0.3, + num_instances=-1, + axis_azimuth: float = 70, + axis_limit: float = 1.7, + axis_dist: float = 10.0, + axis_elev: float = 15.0, + ): + """Draw keypoints and skeletons (optional) of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + draw_gt (bool): Whether to draw GT PoseDataSample. Default to + ``True`` + kpt_thr (float, optional): Minimum threshold of keypoints + to be shown. Default: 0.3. + num_instances (int): Number of instances to be shown in 3D. If + smaller than 0, all the instances in the pose_result will be + shown. Otherwise, pad or truncate the pose_result to a length + of num_instances. + axis_azimuth (float): axis azimuth angle for 3D visualizations. + axis_dist (float): axis distance for 3D visualizations. + axis_elev (float): axis elevation view angle for 3D visualizations. + axis_limit (float): The axis limit to visualize 3d pose. The xyz + range will be set as: + - x: [x_c - axis_limit/2, x_c + axis_limit/2] + - y: [y_c - axis_limit/2, y_c + axis_limit/2] + - z: [0, axis_limit] + Where x_c, y_c is the mean value of x and y coordinates + + Returns: + Tuple(np.ndarray): the drawn image which channel is RGB. 
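+
+        Note:
+            When ``draw_gt`` is ``True`` the figure width and the number of
+            subplots are doubled, so prediction and ground-truth skeletons
+            are rendered side by side on one canvas.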
+        """
+        vis_height, vis_width, _ = image.shape
+
+        if 'pred_instances' in pose_samples:
+            pred_instances = pose_samples.pred_instances
+        else:
+            pred_instances = InstanceData()
+        if num_instances < 0:
+            if 'keypoints' in pred_instances:
+                num_instances = len(pred_instances)
+            else:
+                num_instances = 0
+        else:
+            if len(pred_instances) > num_instances:
+                # truncate every instance-level field to `num_instances`
+                for k in pred_instances.keys():
+                    new_val = pred_instances[k][:num_instances]
+                    pose_samples.pred_instances.set_field(new_val, k)
+            elif num_instances > len(pred_instances):
+                num_instances = len(pred_instances)
+
+        num_fig = num_instances
+        if draw_gt:
+            vis_width *= 2
+            num_fig *= 2
+
+        plt.ioff()
+        fig = plt.figure(
+            figsize=(vis_width * 0.01, vis_height * num_instances * 0.01))
+
+        def _draw_3d_instances_kpts(keypoints,
+                                    scores,
+                                    keypoints_visible,
+                                    fig_idx,
+                                    title=None):
+
+            for idx, (kpts, score, visible) in enumerate(
+                    zip(keypoints, scores, keypoints_visible)):
+
+                valid = score >= kpt_thr
+
+                ax = fig.add_subplot(
+                    1, num_fig, fig_idx * (idx + 1), projection='3d')
+                ax.view_init(elev=axis_elev, azim=axis_azimuth)
+
+                x_c = np.mean(kpts[valid, 0]) if valid.any() else 0
+                y_c = np.mean(kpts[valid, 1]) if valid.any() else 0
+
+                ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2])
+                ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2])
+                ax.set_zlim3d([0, axis_limit])
+                ax.set_xlabel('x')
+                ax.set_ylabel('y')
+                ax.set_zlabel('z')
+                ax.set_aspect('auto')
+                ax.set_xticks([])
+                ax.set_yticks([])
+                ax.set_zticks([])
+                ax.set_xticklabels([])
+                ax.set_yticklabels([])
+                ax.set_zticklabels([])
+                ax.dist = axis_dist
+
+                ax.scatter([0], [0], [0], marker='o', color='red')
+
+                kpts = np.array(kpts, copy=False)
+
+                if self.kpt_color is None or isinstance(self.kpt_color, str):
+                    kpt_color = [self.kpt_color] * len(kpts)
+                elif len(self.kpt_color) == len(kpts):
+                    kpt_color = self.kpt_color
+                else:
+                    raise ValueError(
+                        f'the length of kpt_color '
+                        f'({len(self.kpt_color)}) does not match '
+                        f'that of keypoints ({len(kpts)})')
+
+                kpts = kpts[valid]
+                x_3d, y_3d, z_3d = np.split(kpts[:, :3], [1, 2], axis=1)
+
+                kpt_color = kpt_color[valid][..., ::-1] / 255.
+
+                ax.scatter(x_3d, y_3d, z_3d, marker='o', color=kpt_color)
+
+                for kpt_idx in range(len(x_3d)):
+                    ax.text(x_3d[kpt_idx][0], y_3d[kpt_idx][0],
+                            z_3d[kpt_idx][0], str(kpt_idx))
+
+                if self.skeleton is not None and self.link_color is not None:
+                    if self.link_color is None or isinstance(
+                            self.link_color, str):
+                        link_color = [self.link_color] * len(self.skeleton)
+                    elif len(self.link_color) == len(self.skeleton):
+                        link_color = self.link_color
+                    else:
+                        raise ValueError(
+                            f'the length of link_color '
+                            f'({len(self.link_color)}) does not match '
+                            f'that of skeleton ({len(self.skeleton)})')
+
+                    for sk_id, sk in enumerate(self.skeleton):
+                        sk_indices = [_i for _i in sk]
+                        xs_3d = kpts[sk_indices, 0]
+                        ys_3d = kpts[sk_indices, 1]
+                        zs_3d = kpts[sk_indices, 2]
+                        kpt_score = score[sk_indices]
+                        if kpt_score.min() > kpt_thr:
+                            # matplotlib uses RGB color in [0, 1] value range
+                            _color = link_color[sk_id][::-1] / 255.
+ ax.plot( + xs_3d, ys_3d, zs_3d, color=_color, zdir='z') + + if title: + ax.set_title(f'{title} ({idx})') + + if 'keypoints' in pred_instances: + keypoints = pred_instances.get('keypoints', + pred_instances.keypoints) + + if 'keypoint_scores' in pred_instances: + scores = pred_instances.keypoint_scores + else: + scores = np.ones(keypoints.shape[:-1]) + + if 'keypoints_visible' in pred_instances: + keypoints_visible = pred_instances.keypoints_visible + else: + keypoints_visible = np.ones(keypoints.shape[:-1]) + + _draw_3d_instances_kpts(keypoints, scores, keypoints_visible, 1, + 'Prediction') + + if draw_gt and 'gt_instances' in pose_samples: + gt_instances = pose_samples.gt_instances + if 'lifting_target' in gt_instances: + keypoints = gt_instances.get('lifting_target', + gt_instances.lifting_target) + scores = np.ones(keypoints.shape[:-1]) + + if 'lifting_target_visible' in gt_instances: + keypoints_visible = gt_instances.lifting_target_visible + else: + keypoints_visible = np.ones(keypoints.shape[:-1]) + + _draw_3d_instances_kpts(keypoints, scores, keypoints_visible, + 2, 'Ground Truth') + + # convert figure to numpy array + fig.tight_layout() + fig.canvas.draw() + + pred_img_data = fig.canvas.tostring_rgb() + pred_img_data = np.frombuffer( + fig.canvas.tostring_rgb(), dtype=np.uint8) + + if not pred_img_data.any(): + pred_img_data = np.full((vis_height, vis_width, 3), 255) + else: + pred_img_data = pred_img_data.reshape(vis_height, vis_width, -1) + + plt.close(fig) + + return pred_img_data + + def _draw_instances_kpts( + self, + image: np.ndarray, + instances: InstanceData, + kpt_thr: float = 0.3, + show_kpt_idx: bool = False, + skeleton_style: str = 'mmpose', + det_kpt_color: Optional[Union[str, Tuple[Tuple[int]]]] = None, + det_dataset_skeleton: Optional[List] = None, + det_dataset_link_color: Optional[np.ndarray] = None): + """Draw keypoints and skeletons (optional) of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + instances (:obj:`InstanceData`): Data structure for + instance-level annotations or predictions. + kpt_thr (float, optional): Minimum threshold of keypoints + to be shown. Default: 0.3. + show_kpt_idx (bool): Whether to show the index of keypoints. + Defaults to ``False`` + skeleton_style (str): Skeleton style selection. Defaults to + ``'mmpose'`` + det_kpt_color (str, tuple(tuple(int)), optional): Keypoints + color info for detection. Defaults to ``None`` + det_dataset_skeleton (list): Skeleton info for detection. Defaults + to ``None`` + det_dataset_link_color (list): Link color for detection. Defaults + to ``None`` + + Returns: + np.ndarray: the drawn image which channel is RGB. 
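+
+        Note:
+            With ``skeleton_style='openpose'``, a neck keypoint is
+            synthesized as the mean of the two shoulder keypoints (COCO
+            indices 5 and 6) and the keypoint order is remapped to the
+            OpenPose convention before drawing.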
+ """ + + self.set_image(image) + img_h, img_w, _ = image.shape + + if 'keypoints' in instances: + keypoints = instances.get('transformed_keypoints', + instances.keypoints) + + if 'keypoint_scores' in instances: + scores = instances.keypoint_scores + else: + scores = np.ones(keypoints.shape[:-1]) + + if 'keypoints_visible' in instances: + keypoints_visible = instances.keypoints_visible + else: + keypoints_visible = np.ones(keypoints.shape[:-1]) + + if skeleton_style == 'openpose': + keypoints_info = np.concatenate( + (keypoints, scores[..., None], keypoints_visible[..., + None]), + axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > kpt_thr, + keypoints_info[:, 6, 2:4] > kpt_thr).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores, keypoints_visible = keypoints_info[ + ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] + + kpt_color = self.kpt_color + if det_kpt_color is not None: + kpt_color = det_kpt_color + + for kpts, score, visible in zip(keypoints, scores, + keypoints_visible): + kpts = np.array(kpts, copy=False) + + if kpt_color is None or isinstance(kpt_color, str): + kpt_color = [kpt_color] * len(kpts) + elif len(kpt_color) == len(kpts): + kpt_color = kpt_color + else: + raise ValueError(f'the length of kpt_color ' + f'({len(kpt_color)}) does not matches ' + f'that of keypoints ({len(kpts)})') + + # draw each point on image + for kid, kpt in enumerate(kpts): + if score[kid] < kpt_thr or not visible[ + kid] or kpt_color[kid] is None: + # skip the point that should not be drawn + continue + + color = kpt_color[kid] + if not isinstance(color, str): + color = tuple(int(c) for c in color) + transparency = self.alpha + if self.show_keypoint_weight: + transparency *= max(0, min(1, score[kid])) + self.draw_circles( + kpt, + radius=np.array([self.radius]), + face_colors=color, + edge_colors=color, + alpha=transparency, + line_widths=self.radius) + if show_kpt_idx: + self.draw_texts( + str(kid), + kpt, + colors=color, + font_sizes=self.radius * 3, + vertical_alignments='bottom', + horizontal_alignments='center') + + # draw links + skeleton = self.skeleton + if det_dataset_skeleton is not None: + skeleton = det_dataset_skeleton + link_color = self.link_color + if det_dataset_link_color is not None: + link_color = det_dataset_link_color + if skeleton is not None and link_color is not None: + if link_color is None or isinstance(link_color, str): + link_color = [link_color] * len(skeleton) + elif len(link_color) == len(skeleton): + link_color = link_color + else: + raise ValueError( + f'the length of link_color ' + f'({len(link_color)}) does not matches ' + f'that of skeleton ({len(skeleton)})') + + for sk_id, sk in enumerate(skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + if not (visible[sk[0]] and visible[sk[1]]): + continue + + if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 + or pos1[1] >= img_h or pos2[0] <= 0 + or pos2[0] >= img_w or pos2[1] <= 0 + or pos2[1] >= img_h or score[sk[0]] < kpt_thr + or score[sk[1]] < kpt_thr + or link_color[sk_id] is None): + # skip the 
link that should not be drawn + continue + X = np.array((pos1[0], pos2[0])) + Y = np.array((pos1[1], pos2[1])) + color = link_color[sk_id] + if not isinstance(color, str): + color = tuple(int(c) for c in color) + transparency = self.alpha + if self.show_keypoint_weight: + transparency *= max( + 0, min(1, 0.5 * (score[sk[0]] + score[sk[1]]))) + + if skeleton_style == 'openpose': + mX = np.mean(X) + mY = np.mean(Y) + length = ((Y[0] - Y[1])**2 + (X[0] - X[1])**2)**0.5 + angle = math.degrees( + math.atan2(Y[0] - Y[1], X[0] - X[1])) + stickwidth = 2 + polygons = cv2.ellipse2Poly( + (int(mX), int(mY)), + (int(length / 2), int(stickwidth)), int(angle), + 0, 360, 1) + + self.draw_polygons( + polygons, + edge_colors=color, + face_colors=color, + alpha=transparency) + + else: + self.draw_lines( + X, Y, color, line_widths=self.line_width) + + return self.get_image() + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: PoseDataSample, + det_data_sample: Optional[PoseDataSample] = None, + draw_gt: bool = True, + draw_pred: bool = True, + draw_2d: bool = True, + det_kpt_color: Optional[Union[str, Tuple[Tuple[int]]]] = None, + det_dataset_skeleton: Optional[Union[str, + Tuple[Tuple[int]]]] = None, + det_dataset_link_color: Optional[np.ndarray] = None, + draw_bbox: bool = False, + show_kpt_idx: bool = False, + skeleton_style: str = 'mmpose', + show: bool = False, + wait_time: float = 0, + out_file: Optional[str] = None, + kpt_thr: float = 0.3, + step: int = 0) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. t is usually used when the display + is not available. + + Args: + name (str): The image identifier + image (np.ndarray): The image to draw + data_sample (:obj:`PoseDataSample`): The 3d data sample + to visualize + det_data_sample (:obj:`PoseDataSample`, optional): The 2d detection + data sample to visualize + draw_gt (bool): Whether to draw GT PoseDataSample. Default to + ``True`` + draw_pred (bool): Whether to draw Prediction PoseDataSample. + Defaults to ``True`` + draw_2d (bool): Whether to draw 2d detection results. Defaults to + ``True`` + det_kpt_color (str, tuple(tuple(int)), optional): Keypoints color + info for detection. Defaults to ``None`` + det_dataset_skeleton (np.ndarray, optional): The skeleton link info + for detection data. Default to ``None`` + det_dataset_link_color (str, tuple(tuple(int)), optional): Link + color for detection. Defaults to ``None`` + draw_bbox (bool): Whether to draw bounding boxes. Default to + ``False`` + show_kpt_idx (bool): Whether to show the index of keypoints. + Defaults to ``False`` + skeleton_style (str): Skeleton style selection. Defaults to + ``'mmpose'`` + show (bool): Whether to display the drawn image. Default to + ``False`` + wait_time (float): The interval of show (s). Defaults to 0 + out_file (str): Path to output file. Defaults to ``None`` + kpt_thr (float, optional): Minimum threshold of keypoints + to be shown. Default: 0.3. + step (int): Global step value to record. 
Defaults to 0 + """ + + det_img_data = None + gt_img_data = None + + if draw_2d: + det_img_data = image.copy() + + # draw bboxes & keypoints + if 'pred_instances' in det_data_sample: + det_img_data = self._draw_instances_kpts( + det_img_data, det_data_sample.pred_instances, kpt_thr, + show_kpt_idx, skeleton_style, det_kpt_color, + det_dataset_skeleton, det_dataset_link_color) + if draw_bbox: + det_img_data = self._draw_instances_bbox( + det_img_data, det_data_sample.pred_instances) + + pred_img_data = self._draw_3d_data_samples( + image.copy(), data_sample, draw_gt=draw_gt) + + # merge visualization results + if det_img_data is not None and gt_img_data is not None: + drawn_img = np.concatenate( + (det_img_data, pred_img_data, gt_img_data), axis=1) + elif det_img_data is not None: + drawn_img = np.concatenate((det_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = np.concatenate((det_img_data, gt_img_data), axis=1) + else: + drawn_img = pred_img_data + + # It is convenient for users to obtain the drawn image. + # For example, the user wants to obtain the drawn image and + # save it as a video during video inference. + self.set_image(drawn_img) + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(drawn_img[..., ::-1], out_file) + else: + # save drawn_img to backends + self.add_image(name, drawn_img, step) + + return self.get_image() diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py index 78a4262834..bb94786c32 100644 --- a/tests/test_codecs/test_image_pose_lifting.py +++ b/tests/test_codecs/test_image_pose_lifting.py @@ -49,7 +49,7 @@ def test_encode(self): encoded = codec.encode(keypoints, keypoints_visible, lifting_target, lifting_target_visible) - self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) self.assertEqual(encoded['trajectory_weights'].shape, (17, )) @@ -64,7 +64,7 @@ def test_encode(self): self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) self.assertEqual(encoded['lifting_target_weights'].shape, (16, )) - self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) self.assertEqual(encoded['lifting_target_label'].shape, (16, 3)) self.assertEqual(encoded['target_root'].shape, (3, )) @@ -77,7 +77,7 @@ def test_encode(self): encoded = codec.encode(keypoints, keypoints_visible, lifting_target, lifting_target_visible) - self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) def test_decode(self): diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py index 05fc10ee95..cc58292d0c 100644 --- a/tests/test_codecs/test_video_pose_lifting.py +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -60,7 +60,7 @@ def test_encode(self): encoded = codec.encode(keypoints, keypoints_visible, lifting_target, lifting_target_visible, camera_param) - self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) 
self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) self.assertEqual(encoded['trajectory_weights'].shape, (17, )) @@ -71,7 +71,7 @@ def test_encode(self): encoded = codec.encode(keypoints, keypoints_visible, lifting_target, lifting_target_visible, camera_param) - self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) self.assertEqual(encoded['trajectory_weights'].shape, (17, )) @@ -85,7 +85,7 @@ def test_encode(self): self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) self.assertEqual(encoded['lifting_target_weights'].shape, (16, )) - self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) self.assertEqual(encoded['lifting_target_label'].shape, (16, 3)) self.assertEqual(encoded['target_root'].shape, (3, )) diff --git a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py index d51d493cbc..8289b09d0f 100644 --- a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py +++ b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py @@ -20,9 +20,9 @@ def setUp(self): for i in range(self.batch_size): gt_instances = InstanceData() keypoints = np.random.random((1, num_keypoints, 3)) - gt_instances.lifting_target = keypoints + gt_instances.lifting_target = np.random.random((num_keypoints, 3)) gt_instances.lifting_target_visible = np.ones( - (1, num_keypoints, 1)).astype(bool) + (num_keypoints, 1)).astype(bool) pred_instances = InstanceData() pred_instances.keypoints = keypoints + np.random.normal( From d51ea65e55e382bd72ffcb06208b114e7d6e28b5 Mon Sep 17 00:00:00 2001 From: Yifan Lareina WU Date: Wed, 31 May 2023 12:47:46 +0800 Subject: [PATCH 7/8] [Fix] Fix bugs in 3d human pose demo (#2413) --- demo/body3d_pose_lifter_demo.py | 2 +- mmpose/apis/inference_3d.py | 27 ++++++++++--------- mmpose/visualization/local_visualizer_3d.py | 30 ++++++++++----------- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/demo/body3d_pose_lifter_demo.py b/demo/body3d_pose_lifter_demo.py index 8b834d08ca..02e3014f21 100644 --- a/demo/body3d_pose_lifter_demo.py +++ b/demo/body3d_pose_lifter_demo.py @@ -386,7 +386,7 @@ def main(): pose_results_2d = extract_pose_sequence( pose_est_results_list, frame_idx=frame_idx, - causal=pose_lifter.causal, + causal=pose_lift_dataset.get('causal', False), seq_len=pose_lift_dataset.get('seq_len', 1), step=pose_lift_dataset.get('seq_step', 1)) diff --git a/mmpose/apis/inference_3d.py b/mmpose/apis/inference_3d.py index 2ab81b20a4..f89c33c9ea 100644 --- a/mmpose/apis/inference_3d.py +++ b/mmpose/apis/inference_3d.py @@ -180,7 +180,7 @@ def inference_pose_lifter_model(model, init_default_scope(model.cfg.get('default_scope', 'mmpose')) pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline) - causal = model.causal + causal = model.cfg.test_dataloader.dataset.get('causal', False) target_idx = -1 if causal else len(pose_results_2d) // 2 dataset_info = model.dataset_meta @@ -193,18 +193,21 @@ def inference_pose_lifter_model(model, bbox_scale = None for i, pose_res in enumerate(pose_results_2d): - keypoints = [] for j, data_sample in enumerate(pose_res): - keypoint = np.squeeze(data_sample.pred_instances.keypoints, axis=0) - if norm_pose_2d: - bbox = 
np.squeeze(data_sample.pred_instances.bboxes) - center = np.array([[(bbox[0] + bbox[2]) / 2, - (bbox[1] + bbox[3]) / 2]]) - scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) - keypoints.append((keypoint[:, :2] - center) / scale * - bbox_scale + bbox_center) - else: - keypoints.append(keypoint[:, :2]) + kpts = data_sample.pred_instances.keypoints + bboxes = data_sample.pred_instances.bboxes + keypoints = [] + for k in range(len(kpts)): + kpt = kpts[k] + if norm_pose_2d: + bbox = bboxes[k] + center = np.array([[(bbox[0] + bbox[2]) / 2, + (bbox[1] + bbox[3]) / 2]]) + scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + keypoints.append((kpt[:, :2] - center) / scale * + bbox_scale + bbox_center) + else: + keypoints.append(kpt[:, :2]) pose_results_2d[i][j].pred_instances.keypoints = np.array( keypoints) diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py index 3a0cfc1cb3..8eccd10bf9 100644 --- a/mmpose/visualization/local_visualizer_3d.py +++ b/mmpose/visualization/local_visualizer_3d.py @@ -124,7 +124,7 @@ def _draw_3d_data_samples( plt.ioff() fig = plt.figure( - figsize=(vis_width * 0.01, vis_height * num_instances * 0.01)) + figsize=(vis_width * num_instances * 0.01, vis_height * 0.01)) def _draw_3d_instances_kpts(keypoints, scores, @@ -135,21 +135,13 @@ def _draw_3d_instances_kpts(keypoints, for idx, (kpts, score, visible) in enumerate( zip(keypoints, scores, keypoints_visible)): - valid = score >= kpt_thr + valid = np.logical_and(score >= kpt_thr, + np.any(~np.isnan(kpts), axis=-1)) ax = fig.add_subplot( 1, num_fig, fig_idx * (idx + 1), projection='3d') ax.view_init(elev=axis_elev, azim=axis_azimuth) - - x_c = np.mean(kpts[valid, 0]) if valid.any() else 0 - y_c = np.mean(kpts[valid, 1]) if valid.any() else 0 - - ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2]) - ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2]) ax.set_zlim3d([0, axis_limit]) - ax.set_xlabel('x') - ax.set_ylabel('y') - ax.set_zlabel('z') ax.set_aspect('auto') ax.set_xticks([]) ax.set_yticks([]) @@ -157,9 +149,16 @@ def _draw_3d_instances_kpts(keypoints, ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_zticklabels([]) + ax.scatter([0], [0], [0], marker='o', color='red') + if title: + ax.set_title(f'{title} ({idx})') ax.dist = axis_dist - ax.scatter([0], [0], [0], marker='o', color='red') + x_c = np.mean(kpts[valid, 0]) if valid.any() else 0 + y_c = np.mean(kpts[valid, 1]) if valid.any() else 0 + + ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2]) + ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2]) kpts = np.array(kpts, copy=False) @@ -208,9 +207,6 @@ def _draw_3d_instances_kpts(keypoints, ax.plot( xs_3d, ys_3d, zs_3d, color=_color, zdir='z') - if title: - ax.set_title(f'{title} ({idx})') - if 'keypoints' in pred_instances: keypoints = pred_instances.get('keypoints', pred_instances.keypoints) @@ -254,7 +250,9 @@ def _draw_3d_instances_kpts(keypoints, if not pred_img_data.any(): pred_img_data = np.full((vis_height, vis_width, 3), 255) else: - pred_img_data = pred_img_data.reshape(vis_height, vis_width, -1) + pred_img_data = pred_img_data.reshape(vis_height, + vis_width * num_instances, + -1) plt.close(fig) From 56b0c1a5f351da2210fa2f3bb144f01084bb5df4 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Thu, 1 Jun 2023 10:47:07 +0800 Subject: [PATCH 8/8] fix problems during rebasing --- demo/bottomup_demo.py | 1 + demo/topdown_demo_with_mmdet.py | 1 + mmpose/apis/inference_3d.py | 2 +- mmpose/visualization/local_visualizer_3d.py | 4 +++- 4 
files changed, 6 insertions(+), 2 deletions(-) diff --git a/demo/bottomup_demo.py b/demo/bottomup_demo.py index c6778c637f..3d6fee7a03 100644 --- a/demo/bottomup_demo.py +++ b/demo/bottomup_demo.py @@ -11,6 +11,7 @@ import numpy as np from mmpose.apis import inference_bottomup, init_model +from mmpose.registry import VISUALIZERS from mmpose.structures import split_instances diff --git a/demo/topdown_demo_with_mmdet.py b/demo/topdown_demo_with_mmdet.py index 442c4e812c..38f4e92e4e 100644 --- a/demo/topdown_demo_with_mmdet.py +++ b/demo/topdown_demo_with_mmdet.py @@ -13,6 +13,7 @@ from mmpose.apis import inference_topdown from mmpose.apis import init_model as init_pose_estimator from mmpose.evaluation.functional import nms +from mmpose.registry import VISUALIZERS from mmpose.structures import merge_data_samples, split_instances from mmpose.utils import adapt_mmdet_pipeline diff --git a/mmpose/apis/inference_3d.py b/mmpose/apis/inference_3d.py index f89c33c9ea..5fbc934adc 100644 --- a/mmpose/apis/inference_3d.py +++ b/mmpose/apis/inference_3d.py @@ -223,7 +223,7 @@ def inference_pose_lifter_model(model, keypoints_2d = pose_seq.pred_instances.keypoints keypoints_2d = np.squeeze( - keypoints_2d) if keypoints_2d.ndim == 4 else keypoints_2d + keypoints_2d, axis=0) if keypoints_2d.ndim == 4 else keypoints_2d T, K, C = keypoints_2d.shape diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py index 8eccd10bf9..764a85dee2 100644 --- a/mmpose/visualization/local_visualizer_3d.py +++ b/mmpose/visualization/local_visualizer_3d.py @@ -54,10 +54,12 @@ def __init__(self, line_width: Union[int, float] = 1, radius: Union[int, float] = 3, show_keypoint_weight: bool = False, + backend: str = 'opencv', alpha: float = 0.8): super().__init__(name, image, vis_backends, save_dir, bbox_color, kpt_color, link_color, text_color, skeleton, - line_width, radius, show_keypoint_weight, alpha) + line_width, radius, show_keypoint_weight, backend, + alpha) def _draw_3d_data_samples( self,
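
A note on the final `np.squeeze` fix above: without an explicit `axis`, NumPy
drops every size-1 dimension, so a single-frame (`T=1`) sequence would lose
its temporal axis along with the instance axis, and the later
`T, K, C = keypoints_2d.shape` unpacking would fail. A minimal standalone
reproduction (plain NumPy, independent of MMPose):

```python
import numpy as np

# Shape (num_instances=1, T=1, K=17, C=2), as produced for a
# single-person, single-frame 2D keypoint sequence.
keypoints_2d = np.zeros((1, 1, 17, 2))

print(np.squeeze(keypoints_2d).shape)          # (17, 2): both size-1 axes
                                               # are dropped, T is lost
print(np.squeeze(keypoints_2d, axis=0).shape)  # (1, 17, 2): keeps T=1
```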