diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py index a88ebac701..cdbd8feb0c 100644 --- a/mmpose/codecs/__init__.py +++ b/mmpose/codecs/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .associative_embedding import AssociativeEmbedding from .decoupled_heatmap import DecoupledHeatmap +from .image_pose_lifting import ImagePoseLifting from .integral_regression_label import IntegralRegressionLabel from .megvii_heatmap import MegviiHeatmap from .msra_heatmap import MSRAHeatmap @@ -8,9 +9,10 @@ from .simcc_label import SimCCLabel from .spr import SPR from .udp_heatmap import UDPHeatmap +from .video_pose_lifting import VideoPoseLifting __all__ = [ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel', 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR', - 'DecoupledHeatmap' + 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting' ] diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py new file mode 100644 index 0000000000..93530cf15d --- /dev/null +++ b/mmpose/codecs/image_pose_lifting.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class ImagePoseLifting(BaseKeypointCodec): + r"""Generate keypoint coordinates for pose lifter. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - target dimension: C + + Args: + num_keypoints (int): The number of keypoints in the dataset. + root_index (int): Root keypoint index in the pose. + remove_root (bool): If true, remove the root keypoint from the pose. + Default: ``False``. + save_index (bool): If true, store the root position separated from the + original pose. Default: ``False``. + keypoints_mean (np.ndarray, optional): Mean values of keypoints + coordinates in shape (K, D). + keypoints_std (np.ndarray, optional): Std values of keypoints + coordinates in shape (K, D). + target_mean (np.ndarray, optional): Mean values of target coordinates + in shape (K, C). + target_std (np.ndarray, optional): Std values of target coordinates + in shape (K, C). + """ + + auxiliary_encode_keys = {'target', 'target_visible'} + + def __init__(self, + num_keypoints: int, + root_index: int, + remove_root: bool = False, + save_index: bool = False, + keypoints_mean: Optional[np.ndarray] = None, + keypoints_std: Optional[np.ndarray] = None, + target_mean: Optional[np.ndarray] = None, + target_std: Optional[np.ndarray] = None): + super().__init__() + + self.num_keypoints = num_keypoints + self.root_index = root_index + self.remove_root = remove_root + self.save_index = save_index + if keypoints_mean is not None and keypoints_std is not None: + assert keypoints_mean.shape == keypoints_std.shape + if target_mean is not None and target_std is not None: + assert target_mean.shape == target_std.shape + self.keypoints_mean = keypoints_mean + self.keypoints_std = keypoints_std + self.target_mean = target_mean + self.target_std = target_std + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + target: Optional[np.ndarray] = None, + target_visible: Optional[np.ndarray] = None) -> dict: + """Encoding keypoints from input image space to normalized space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D). + keypoints_visible (np.ndarray, optional): Keypoint visibilities in + shape (N, K). + target (np.ndarray, optional): Target coordinate in shape (K, C). + target_visible (np.ndarray, optional): Target coordinate in shape + (K, ). + + Returns: + encoded (dict): Contains the following items: + + - keypoint_labels (np.ndarray): The processed keypoints in + shape (K * D, N) where D is 2 for 2d coordinates. + - target_label: The processed target coordinate in shape (K, C) + or (K-1, C). + - target_weights (np.ndarray): The target weights in shape + (K, ) or (K-1, ). + - trajectory_weights (np.ndarray): The trajectory weights in + shape (K, ). + - target_root (np.ndarray): The root coordinate of target in + shape (C, ). + + In addition, there are some optional items it may contain: + + - target_root_removed (bool): Indicate whether the root of + target is removed. Added if ``self.remove_root`` is ``True``. + - target_root_index (int): An integer indicating the index of + root. Added if ``self.remove_root`` and ``self.save_index`` + are ``True``. + """ + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if target is None: + target = keypoints[0] + + # set initial value for `target_weights` and `trajectory_weights` + if target_visible is None: + target_visible = np.ones(target.shape[:-1], dtype=np.float32) + target_weights = target_visible + trajectory_weights = (1 / target[:, 2]) + else: + valid = target_visible > 0.5 + target_weights = np.where(valid, 1., 0.).astype(np.float32) + trajectory_weights = target_weights + + encoded = dict() + + # Zero-center the target pose around a given root keypoint + assert target.ndim >= 2 and target.shape[-2] > self.root_index, \ + f'Got invalid joint shape {target.shape}' + + root = target[..., self.root_index, :] + target_label = target - root + + if self.remove_root: + target_label = np.delete(target_label, self.root_index, axis=-2) + assert target_weights.ndim in {1, 2} + axis_to_remove = -2 if target_weights.ndim == 2 else -1 + target_weights = np.delete( + target_weights, self.root_index, axis=axis_to_remove) + # Add a flag to avoid latter transforms that rely on the root + # joint or the original joint index + encoded['target_root_removed'] = True + + # Save the root index which is necessary to restore the global pose + if self.save_index: + encoded['target_root_index'] = self.root_index + + # Normalize the 2D keypoint coordinate with mean and std + keypoint_labels = keypoints.copy() + if self.keypoints_mean is not None and self.keypoints_std is not None: + keypoints_shape = keypoints.shape + assert self.keypoints_mean.shape == keypoints_shape[1:] + + keypoint_labels = (keypoint_labels - + self.keypoints_mean) / self.keypoints_std + if self.target_mean is not None and self.target_std is not None: + target_shape = target_label.shape + assert self.target_mean.shape == target_shape + + target_label = (target_label - self.target_mean) / self.target_std + + # Generate reshaped keypoint coordinates + assert keypoint_labels.ndim in {2, 3} + if keypoint_labels.ndim == 2: + keypoint_labels = keypoint_labels[None, ...] + + N = keypoint_labels.shape[0] + keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) + + encoded['keypoint_labels'] = keypoint_labels + encoded['target_label'] = target_label + encoded['target_weights'] = target_weights + encoded['trajectory_weights'] = trajectory_weights + encoded['target_root'] = root + + return encoded + + def decode(self, + encoded: np.ndarray, + restore_global_position: bool = False, + target_root: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from normalized space to input image + space. + + Args: + encoded (np.ndarray): Coordinates in shape (N, K, C). + restore_global_position (bool): Whether to restore global position. + Default: ``False``. + target_root (np.ndarray, optional): The target root coordinate. + Default: ``None``. + + Returns: + keypoints (np.ndarray): Decoded coordinates in shape (N, K, C). + scores (np.ndarray): The keypoint scores in shape (N, K). + """ + keypoints = encoded.copy() + + if self.target_mean is not None and self.target_std is not None: + assert self.target_mean.shape == keypoints.shape[1:] + keypoints = keypoints * self.target_std + self.target_mean + + if restore_global_position: + assert target_root is not None + keypoints = keypoints + np.expand_dims(target_root, axis=0) + if self.remove_root: + keypoints = np.insert( + keypoints, self.root_index, target_root, axis=1) + scores = np.ones(keypoints.shape[:-1], dtype=np.float32) + + return keypoints, scores diff --git a/mmpose/codecs/regression_label.py b/mmpose/codecs/regression_label.py index 9ae385d2d9..f79195beb4 100644 --- a/mmpose/codecs/regression_label.py +++ b/mmpose/codecs/regression_label.py @@ -78,7 +78,7 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: Returns: tuple: - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D) - - socres (np.ndarray): The keypoint scores in shape (N, K). + - scores (np.ndarray): The keypoint scores in shape (N, K). It usually represents the confidence of the keypoint prediction """ diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py new file mode 100644 index 0000000000..fbb6ad429c --- /dev/null +++ b/mmpose/codecs/video_pose_lifting.py @@ -0,0 +1,201 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from copy import deepcopy +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class VideoPoseLifting(BaseKeypointCodec): + r"""Generate keypoint coordinates for pose lifter. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - target dimension: C + + Args: + num_keypoints (int): The number of keypoints in the dataset. + zero_center: Whether to zero-center the target around root. Default: + ``True``. + root_index (int): Root keypoint index in the pose. Default: 0. + remove_root (bool): If true, remove the root keypoint from the pose. + Default: ``False``. + save_index (bool): If true, store the root position separated from the + original pose. Default: ``False``. + normalize_camera (bool): Whether to normalize camera intrinsics. + Default: ``False``. + """ + + auxiliary_encode_keys = {'target', 'target_visible', 'camera_param'} + + def __init__(self, + num_keypoints: int, + zero_center: bool = True, + root_index: int = 0, + remove_root: bool = False, + save_index: bool = False, + normalize_camera: bool = False): + super().__init__() + + self.num_keypoints = num_keypoints + self.zero_center = zero_center + self.root_index = root_index + self.remove_root = remove_root + self.save_index = save_index + self.normalize_camera = normalize_camera + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + target: Optional[np.ndarray] = None, + target_visible: Optional[np.ndarray] = None, + camera_param: Optional[dict] = None) -> dict: + """Encoding keypoints from input image space to normalized space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D). + keypoints_visible (np.ndarray, optional): Keypoint visibilities in + shape (N, K). + target (np.ndarray, optional): Target coordinate in shape (K, C). + target_visible (np.ndarray, optional): Target coordinate in shape + (K, ). + camera_param (dict, optional): The camera parameter dictionary. + + Returns: + encoded (dict): Contains the following items: + + - keypoint_labels (np.ndarray): The processed keypoints in + shape (K * D, N) where D is 2 for 2d coordinates. + - target_label: The processed target coordinate in shape (K, C) + or (K-1, C). + - target_weights (np.ndarray): The target weights in shape + (K, ) or (K-1, ). + - trajectory_weights (np.ndarray): The trajectory weights in + shape (K, ). + + In addition, there are some optional items it may contain: + + - target_root (np.ndarray): The root coordinate of target in + shape (C, ). Exists if ``self.zero_center`` is ``True``. + - target_root_removed (bool): Indicate whether the root of + target is removed. Exists if ``self.remove_root`` is + ``True``. + - target_root_index (int): An integer indicating the index of + root. Exists if ``self.remove_root`` and ``self.save_index`` + are ``True``. + - camera_param (dict): The updated camera parameter dictionary. + Exists if ``self.normalize_camera`` is ``True``. + """ + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if target is None: + target = keypoints[0] + + # set initial value for `target_weights` and `trajectory_weights` + if target_visible is None: + target_visible = np.ones(target.shape[:-1], dtype=np.float32) + target_weights = target_visible + trajectory_weights = (1 / target[:, 2]) + else: + valid = target_visible > 0.5 + target_weights = np.where(valid, 1., 0.).astype(np.float32) + trajectory_weights = target_weights + + if camera_param is None: + camera_param = dict() + + encoded = dict() + + target_label = target.copy() + # Zero-center the target pose around a given root keypoint + if self.zero_center: + assert target.ndim >= 2 and target.shape[-2] > self.root_index, \ + f'Got invalid joint shape {target.shape}' + + root = target[..., self.root_index, :] + target_label = target_label - root + encoded['target_root'] = root + + if self.remove_root: + target_label = np.delete( + target_label, self.root_index, axis=-2) + assert target_weights.ndim in {1, 2} + axis_to_remove = -2 if target_weights.ndim == 2 else -1 + target_weights = np.delete( + target_weights, self.root_index, axis=axis_to_remove) + # Add a flag to avoid latter transforms that rely on the root + # joint or the original joint index + encoded['target_root_removed'] = True + + # Save the root index for restoring the global pose + if self.save_index: + encoded['target_root_index'] = self.root_index + + # Normalize the 2D keypoint coordinate with image width and height + _camera_param = deepcopy(camera_param) + assert 'w' in _camera_param and 'h' in _camera_param + center = np.array([0.5 * _camera_param['w'], 0.5 * _camera_param['h']], + dtype=np.float32) + scale = np.array(0.5 * _camera_param['w'], dtype=np.float32) + + keypoint_labels = (keypoints - center) / scale + + assert keypoint_labels.ndim in {2, 3} + if keypoint_labels.ndim == 2: + keypoint_labels = keypoint_labels[None, ...] + + if self.normalize_camera: + assert 'f' in _camera_param and 'c' in _camera_param + _camera_param['f'] = _camera_param['f'] / scale + _camera_param['c'] = (_camera_param['c'] - center[:, None]) / scale + encoded['camera_param'] = _camera_param + + # Generate reshaped keypoint coordinates + N = keypoint_labels.shape[0] + keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) + + encoded['keypoint_labels'] = keypoint_labels + encoded['target_label'] = target_label + encoded['target_weights'] = target_weights + encoded['trajectory_weights'] = trajectory_weights + + return encoded + + def decode(self, + encoded: np.ndarray, + restore_global_position: bool = False, + target_root: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from normalized space to input image + space. + + Args: + encoded (np.ndarray): Coordinates in shape (1, K, C). + restore_global_position (bool): Whether to restore global position. + Default: ``False``. + target_root (np.ndarray, optional): The target root coordinate. + Default: ``None``. + + Returns: + keypoints (np.ndarray): Decoded coordinates in shape (1, K, C). + scores (np.ndarray): The keypoint scores in shape (1, K). + """ + keypoints = encoded.copy() + + if restore_global_position: + assert target_root is not None + keypoints = keypoints + np.expand_dims(target_root, axis=0) + if self.remove_root: + keypoints = np.insert( + keypoints, self.root_index, target_root, axis=1) + scores = np.ones(keypoints.shape[:-1], dtype=np.float32) + + return keypoints, scores diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py index 61dae74b8c..7ccbf7dac2 100644 --- a/mmpose/datasets/transforms/__init__.py +++ b/mmpose/datasets/transforms/__init__.py @@ -8,6 +8,7 @@ from .converting import KeypointConverter from .formatting import PackPoseInputs from .loading import LoadImage +from .pose3d_transforms import RandomFlipAroundRoot from .topdown_transforms import TopdownAffine __all__ = [ @@ -15,5 +16,5 @@ 'RandomHalfBody', 'TopdownAffine', 'Albumentation', 'PhotometricDistortion', 'PackPoseInputs', 'LoadImage', 'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize', - 'GenerateTarget', 'KeypointConverter' + 'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot' ] diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index dd9ad522f2..14be378e19 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -89,6 +89,8 @@ class PackPoseInputs(BaseTransform): 'bbox_score': 'bbox_scores', 'keypoints': 'keypoints', 'keypoints_visible': 'keypoints_visible', + 'target': 'target', + 'target_visible': 'target_visible', } # items in `label_mapping_table` will be packed into @@ -137,10 +139,12 @@ def transform(self, results: dict) -> dict: - 'data_samples' (obj:`PoseDataSample`): The annotation info of the sample. """ - # Pack image(s) + # Pack image(s) or 2d keypoints if 'img' in results: img = results['img'] - img_tensor = image_to_tensor(img) + inputs_tensor = image_to_tensor(img) + elif 'keypoints_3d' in results and 'keypoints' in results: + inputs_tensor = results['keypoints'] data_sample = PoseDataSample() @@ -202,7 +206,7 @@ def transform(self, results: dict) -> dict: data_sample.set_metainfo(img_meta) packed_results = dict() - packed_results['inputs'] = img_tensor + packed_results['inputs'] = inputs_tensor packed_results['data_samples'] = data_sample return packed_results diff --git a/mmpose/datasets/transforms/pose3d_transforms.py b/mmpose/datasets/transforms/pose3d_transforms.py new file mode 100644 index 0000000000..4f86c247b9 --- /dev/null +++ b/mmpose/datasets/transforms/pose3d_transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Dict + +import numpy as np +from mmcv.transforms import BaseTransform + +from mmpose.registry import TRANSFORMS +from mmpose.structures.keypoint import flip_keypoints_custom_center + + +@TRANSFORMS.register_module() +class RandomFlipAroundRoot(BaseTransform): + """Data augmentation with random horizontal joint flip around a root joint. + + Args: + keypoints_flip_cfg (dict): Configurations of the + ``flip_keypoints_custom_center`` function for ``keypoints``. Please + refer to the docstring of the ``flip_keypoints_custom_center`` + function for more details. + target_flip_cfg (dict): Configurations of the + ``flip_keypoints_custom_center`` function for ``target``. Please + refer to the docstring of the ``flip_keypoints_custom_center`` + function for more details. + flip_prob (float): Probability of flip. Default: 0.5. + flip_camera (bool): Whether to flip horizontal distortion coefficients. + Default: ``False``. + + Required keys: + keypoints + target + + Modified keys: + (keypoints, keypoints_visible, target, target_visible, camera_param) + """ + + def __init__(self, + keypoints_flip_cfg, + target_flip_cfg, + flip_prob=0.5, + flip_camera=False): + self.keypoints_flip_cfg = keypoints_flip_cfg + self.target_flip_cfg = target_flip_cfg + self.flip_prob = flip_prob + self.flip_camera = flip_camera + + def transform(self, results: Dict) -> dict: + """The transform function of :class:`ZeroCenterPose`. + + See ``transform()`` method of :class:`BaseTransform` for details. + + Args: + results (dict): The result dict + + Returns: + dict: The result dict. + """ + + keypoints = results['keypoints'] + if 'keypoints_visible' in results: + keypoints_visible = results['keypoints_visible'] + else: + keypoints_visible = np.ones(keypoints.shape[:-1], dtype=np.float32) + target = results['target'] + if 'target_visible' in results: + target_visible = results['target_visible'] + else: + target_visible = np.ones(target.shape[:-1], dtype=np.float32) + + if np.random.rand() <= self.flip_prob: + if 'flip_indices' not in results: + flip_indices = list(range(self.num_keypoints)) + else: + flip_indices = results['flip_indices'] + + # flip joint coordinates + keypoints, keypoints_visible = flip_keypoints_custom_center( + keypoints, keypoints_visible, flip_indices, + **self.keypoints_flip_cfg) + target, target_visible = flip_keypoints_custom_center( + target, target_visible, flip_indices, **self.target_flip_cfg) + + results['keypoints'] = keypoints + results['keypoints_visible'] = keypoints_visible + results['target'] = target + results['target_visible'] = target_visible + + # flip horizontal distortion coefficients + if self.flip_camera: + assert 'camera_param' in results, \ + 'Camera parameters are missing.' + _camera_param = deepcopy(results['camera_param']) + + assert 'c' in _camera_param + _camera_param['c'][0] *= -1 + + if 'p' in _camera_param: + _camera_param['p'][0] *= -1 + + results['camera_param'].update(_camera_param) + + return results diff --git a/mmpose/structures/keypoint/__init__.py b/mmpose/structures/keypoint/__init__.py index b8d5a24c7a..12ee96cf7c 100644 --- a/mmpose/structures/keypoint/__init__.py +++ b/mmpose/structures/keypoint/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .transforms import flip_keypoints +from .transforms import flip_keypoints, flip_keypoints_custom_center -__all__ = ['flip_keypoints'] +__all__ = ['flip_keypoints', 'flip_keypoints_custom_center'] diff --git a/mmpose/structures/keypoint/transforms.py b/mmpose/structures/keypoint/transforms.py index 99adaa1306..b50da4f8fe 100644 --- a/mmpose/structures/keypoint/transforms.py +++ b/mmpose/structures/keypoint/transforms.py @@ -62,3 +62,60 @@ def flip_keypoints(keypoints: np.ndarray, keypoints = [w, h] - keypoints - 1 return keypoints, keypoints_visible + + +def flip_keypoints_custom_center(keypoints: np.ndarray, + keypoints_visible: np.ndarray, + flip_indices: List[int], + center_mode: str = 'static', + center_x: float = 0.5, + center_index: int = 0): + """Flip human joints horizontally. + + Note: + - num_keypoint: K + - dimension: D + + Args: + keypoints (np.ndarray([..., K, D])): Coordinates of keypoints. + keypoints_visible (np.ndarray([..., K])): Visibility item of keypoints. + flip_indices (list[int]): The indices to flip the keypoints. + center_mode (str): The mode to set the center location on the x-axis + to flip around. Options are: + + - static: use a static x value (see center_x also) + - root: use a root joint (see center_index also) + + Defaults: ``'static'``. + center_x (float): Set the x-axis location of the flip center. Only used + when ``center_mode`` is ``'static'``. Defaults: 0.5. + center_index (int): Set the index of the root joint, whose x location + will be used as the flip center. Only used when ``center_mode`` is + ``'root'``. Defaults: 0. + + Returns: + np.ndarray([..., K, C]): Flipped joints. + """ + + assert keypoints.ndim >= 2, f'Invalid pose shape {keypoints.shape}' + + allowed_center_mode = {'static', 'root'} + assert center_mode in allowed_center_mode, 'Get invalid center_mode ' \ + f'{center_mode}, allowed choices are {allowed_center_mode}' + + if center_mode == 'static': + x_c = center_x + elif center_mode == 'root': + assert keypoints.shape[-2] > center_index + x_c = keypoints[..., center_index, 0] + + keypoints_flipped = keypoints.copy() + keypoints_visible_flipped = keypoints_visible.copy() + # Swap left-right parts + for left, right in enumerate(flip_indices): + keypoints_flipped[..., left, :] = keypoints[..., right, :] + keypoints_visible_flipped[..., left] = keypoints_visible[..., right] + + # Flip horizontally + keypoints_flipped[..., 0] = x_c * 2 - keypoints_flipped[..., 0] + return keypoints_flipped, keypoints_visible_flipped diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py new file mode 100644 index 0000000000..c54bf12d1e --- /dev/null +++ b/tests/test_codecs/test_image_pose_lifting.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np + +from mmpose.codecs import ImagePoseLifting +from mmpose.registry import KEYPOINT_CODECS + + +class TestImagePoseLifting(TestCase): + + def setUp(self) -> None: + keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256] + keypoints = np.round(keypoints).astype(np.float32) + keypoints_visible = np.random.randint(2, size=(1, 17)) + target = (0.1 + 0.8 * np.random.rand(17, 3)) + target_visible = np.random.randint(2, size=(17, )) + encoded_wo_sigma = np.random.rand(1, 17, 3) + + self.keypoints_mean = np.random.rand(17, 2).astype(np.float32) + self.keypoints_std = np.random.rand(17, 2).astype(np.float32) + 1e-6 + self.target_mean = np.random.rand(17, 3).astype(np.float32) + self.target_std = np.random.rand(17, 3).astype(np.float32) + 1e-6 + + self.data = dict( + keypoints=keypoints, + keypoints_visible=keypoints_visible, + target=target, + target_visible=target_visible, + encoded_wo_sigma=encoded_wo_sigma) + + def build_pose_lifting_label(self, **kwargs): + cfg = dict(type='ImagePoseLifting', num_keypoints=17, root_index=0) + cfg.update(kwargs) + return KEYPOINT_CODECS.build(cfg) + + def test_build(self): + codec = self.build_pose_lifting_label() + self.assertIsInstance(codec, ImagePoseLifting) + + def test_encode(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + target = self.data['target'] + target_visible = self.data['target_visible'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (17, 3)) + self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test removing root + codec = self.build_pose_lifting_label( + remove_root=True, save_index=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + self.assertTrue('target_root_removed' in encoded + and 'target_root_index' in encoded) + self.assertEqual(encoded['target_weights'].shape, (16, )) + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (16, 3)) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test normalization + codec = self.build_pose_lifting_label( + keypoints_mean=self.keypoints_mean, + keypoints_std=self.keypoints_std, + target_mean=self.target_mean, + target_std=self.target_std) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (17, 3)) + + def test_decode(self): + target = self.data['target'] + encoded_wo_sigma = self.data['encoded_wo_sigma'] + + codec = self.build_pose_lifting_label() + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 17, 3)) + self.assertEqual(scores.shape, (1, 17)) + + codec = self.build_pose_lifting_label(remove_root=True) + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 18, 3)) + self.assertEqual(scores.shape, (1, 18)) + + def test_cicular_verification(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + target = self.data['target'] + target_visible = self.data['target_visible'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + + # test removing root + codec = self.build_pose_lifting_label(remove_root=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + + # test normalization + codec = self.build_pose_lifting_label( + keypoints_mean=self.keypoints_mean, + keypoints_std=self.keypoints_std, + target_mean=self.target_mean, + target_std=self.target_std) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py new file mode 100644 index 0000000000..da404360bb --- /dev/null +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from unittest import TestCase + +import numpy as np +from mmengine.fileio import load + +from mmpose.codecs import VideoPoseLifting +from mmpose.registry import KEYPOINT_CODECS + + +class TestVideoPoseLifting(TestCase): + + def get_camera_param(self, imgname, camera_param) -> dict: + """Get camera parameters of a frame by its image name.""" + subj, rest = osp.basename(imgname).split('_', 1) + action, rest = rest.split('.', 1) + camera, rest = rest.split('_', 1) + return camera_param[(subj, camera)] + + def build_pose_lifting_label(self, **kwargs): + cfg = dict(type='VideoPoseLifting', num_keypoints=17) + cfg.update(kwargs) + return KEYPOINT_CODECS.build(cfg) + + def setUp(self) -> None: + keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256] + keypoints = np.round(keypoints).astype(np.float32) + keypoints_visible = np.random.randint(2, size=(1, 17)) + target = (0.1 + 0.8 * np.random.rand(17, 3)) + target_visible = np.random.randint(2, size=(17, )) + encoded_wo_sigma = np.random.rand(1, 17, 3) + + camera_param = load('tests/data/h36m/cameras.pkl') + camera_param = self.get_camera_param( + 'S1/S1_Directions_1.54138969/S1_Directions_1.54138969_000001.jpg', + camera_param) + + self.data = dict( + keypoints=keypoints, + keypoints_visible=keypoints_visible, + target=target, + target_visible=target_visible, + camera_param=camera_param, + encoded_wo_sigma=encoded_wo_sigma) + + def test_build(self): + codec = self.build_pose_lifting_label() + self.assertIsInstance(codec, VideoPoseLifting) + + def test_encode(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + target = self.data['target'] + target_visible = self.data['target_visible'] + camera_param = self.data['camera_param'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (17, 3)) + self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test not zero-centering + codec = self.build_pose_lifting_label(zero_center=False) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (17, 3)) + self.assertEqual(encoded['target_weights'].shape, (17, )) + self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + + # test removing root + codec = self.build_pose_lifting_label( + remove_root=True, save_index=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertTrue('target_root_removed' in encoded + and 'target_root_index' in encoded) + self.assertEqual(encoded['target_weights'].shape, (16, )) + self.assertEqual(encoded['keypoint_labels'].shape, (17 * 2, 1)) + self.assertEqual(encoded['target_label'].shape, (16, 3)) + self.assertEqual(encoded['target_root'].shape, (3, )) + + # test normalizing camera + codec = self.build_pose_lifting_label(normalize_camera=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + self.assertTrue('camera_param' in encoded) + scale = np.array(0.5 * camera_param['w'], dtype=np.float32) + self.assertTrue( + np.allclose( + camera_param['f'] / scale, + encoded['camera_param']['f'], + atol=4.)) + + def test_decode(self): + target = self.data['target'] + encoded_wo_sigma = self.data['encoded_wo_sigma'] + + codec = self.build_pose_lifting_label() + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 17, 3)) + self.assertEqual(scores.shape, (1, 17)) + + codec = self.build_pose_lifting_label(remove_root=True) + + decoded, scores = codec.decode( + encoded_wo_sigma, + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 18, 3)) + self.assertEqual(scores.shape, (1, 18)) + + def test_cicular_verification(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + target = self.data['target'] + target_visible = self.data['target_visible'] + camera_param = self.data['camera_param'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) + + # test removing root + codec = self.build_pose_lifting_label(remove_root=True) + encoded = codec.encode(keypoints, keypoints_visible, target, + target_visible, camera_param) + + _keypoints, _ = codec.decode( + np.expand_dims(encoded['target_label'], axis=0), + restore_global_position=True, + target_root=target[..., 0, :]) + + self.assertTrue( + np.allclose(np.expand_dims(target, axis=0), _keypoints, atol=5.)) diff --git a/tests/test_datasets/test_transforms/test_pose3d_transforms.py b/tests/test_datasets/test_transforms/test_pose3d_transforms.py new file mode 100644 index 0000000000..16118db272 --- /dev/null +++ b/tests/test_datasets/test_transforms/test_pose3d_transforms.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from copy import deepcopy +from unittest import TestCase + +import numpy as np +from mmengine.fileio import load + +from mmpose.datasets.transforms import RandomFlipAroundRoot + + +def get_h36m_sample(): + + def _parse_h36m_imgname(imgname): + """Parse imgname to get information of subject, action and camera. + + A typical h36m image filename is like: + S1_Directions_1.54138969_000001.jpg + """ + subj, rest = osp.basename(imgname).split('_', 1) + action, rest = rest.split('.', 1) + camera, rest = rest.split('_', 1) + return subj, action, camera + + ann_flle = 'tests/data/h36m/test_h36m_body3d.npz' + camera_param_file = 'tests/data/h36m/cameras.pkl' + + data = np.load(ann_flle) + cameras = load(camera_param_file) + + imgnames = data['imgname'] + keypoints = data['part'].astype(np.float32) + keypoints_3d = data['S'].astype(np.float32) + centers = data['center'].astype(np.float32) + scales = data['scale'].astype(np.float32) + + idx = 0 + target_idx = 0 + + data_info = { + 'keypoints': keypoints[idx, :, :2].reshape(1, -1, 2), + 'keypoints_visible': keypoints[idx, :, 2].reshape(1, -1), + 'keypoints_3d': keypoints_3d[idx, :, :3].reshape(1, -1, 3), + 'keypoints_3d_visible': keypoints_3d[idx, :, 3].reshape(1, -1), + 'scale': scales[idx], + 'center': centers[idx].astype(np.float32).reshape(1, -1), + 'id': idx, + 'img_ids': [idx], + 'img_paths': [imgnames[idx]], + 'category_id': 1, + 'iscrowd': 0, + 'sample_idx': idx, + 'target': keypoints_3d[target_idx, :, :3], + 'target_visible': keypoints_3d[target_idx, :, 3], + 'target_img_path': osp.join('tests/data/h36m', imgnames[target_idx]), + } + + # add camera parameters + subj, _, camera = _parse_h36m_imgname(imgnames[idx]) + data_info['camera_param'] = cameras[(subj, camera)] + + # add ann_info + ann_info = {} + ann_info['num_keypoints'] = 17 + ann_info['dataset_keypoint_weights'] = np.full(17, 1.0, dtype=np.float32) + ann_info['flip_pairs'] = [[1, 4], [2, 5], [3, 6], [11, 14], [12, 15], + [13, 16]] + ann_info['skeleton_links'] = [] + ann_info['upper_body_ids'] = (0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + ann_info['lower_body_ids'] = (1, 2, 3, 4, 5, 6) + ann_info['flip_indices'] = [ + 0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13 + ] + + data_info.update(ann_info) + + return data_info + + +class TestRandomFlipAroundRoot(TestCase): + + def setUp(self): + self.data_info = get_h36m_sample() + self.keypoints_flip_cfg = dict(center_mode='static', center_x=0.) + self.target_flip_cfg = dict(center_mode='root', center_index=0) + + def test_init(self): + _ = RandomFlipAroundRoot( + self.keypoints_flip_cfg, + self.target_flip_cfg, + flip_prob=0.5, + flip_camera=False) + + def test_transform(self): + kpts1 = self.data_info['keypoints'] + kpts_vis1 = self.data_info['keypoints_visible'] + tar1 = self.data_info['target'] + tar_vis1 = self.data_info['target_visible'] + + transform = RandomFlipAroundRoot( + self.keypoints_flip_cfg, self.target_flip_cfg, flip_prob=1) + results = deepcopy(self.data_info) + results = transform(results) + + kpts2 = results['keypoints'] + kpts_vis2 = results['keypoints_visible'] + tar2 = results['target'] + tar_vis2 = results['target_visible'] + + self.assertEqual(kpts_vis2.shape, (1, 17)) + self.assertEqual(tar_vis2.shape, (17, )) + self.assertEqual(kpts2.shape, (1, 17, 2)) + self.assertEqual(tar2.shape, (17, 3)) + + flip_indices = [ + 0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13 + ] + for left, right in enumerate(flip_indices): + self.assertTrue( + np.allclose(-kpts1[0][left][:1], kpts2[0][right][:1], atol=4.)) + self.assertTrue( + np.allclose(kpts1[0][left][1:], kpts2[0][right][1:], atol=4.)) + self.assertTrue( + np.allclose(tar1[left][1:], tar2[right][1:], atol=4.)) + + self.assertTrue( + np.allclose(kpts_vis1[0][left], kpts_vis2[0][right], atol=4.)) + self.assertTrue( + np.allclose(tar_vis1[left], tar_vis2[right], atol=4.)) + + # test camera flipping + transform = RandomFlipAroundRoot( + self.keypoints_flip_cfg, + self.target_flip_cfg, + flip_prob=1, + flip_camera=True) + results = deepcopy(self.data_info) + results = transform(results) + + camera2 = results['camera_param'] + self.assertTrue( + np.allclose( + -self.data_info['camera_param']['c'][0], + camera2['c'][0], + atol=4.)) + self.assertTrue( + np.allclose( + -self.data_info['camera_param']['p'][0], + camera2['p'][0], + atol=4.))