diff --git a/configs/body/3d_kpt_sview_rgb_img/end2end/h36m/coarse2fine_h36m.py b/configs/body/3d_kpt_sview_rgb_img/end2end/h36m/coarse2fine_h36m.py
new file mode 100644
index 0000000000..35c70aeef2
--- /dev/null
+++ b/configs/body/3d_kpt_sview_rgb_img/end2end/h36m/coarse2fine_h36m.py
@@ -0,0 +1,194 @@
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=10)
+evaluation = dict(interval=10, metric=['mpjpe', 'p-mpjpe'], save_best='MPJPE')
+
+# optimizer settings
+optimizer = dict(
+    type='Adam',
+    lr=1e-3,
+)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+    policy='step',
+    by_epoch=False,
+    step=100000,
+    gamma=0.96,
+)
+
+total_epochs = 200
+
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+
+channel_cfg = dict(
+    num_output_channels=17,
+    dataset_joints=17,
+    dataset_channel=[
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+    ],
+    inference_channel=[
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+    ])
+
+# model settings
+model = dict(
+    type='PoseLifter',
+    pretrained=None,
+    backbone=dict(
+        type='HourglassNet',
+        downsample_times=2,
+        num_stacks=5,
+        feat_channel=[1, 2, 4, 8, 64]),
+    keypoint_head=dict(
+        type='TemporalRegressionHead',
+        in_channels=1024,
+        num_joints=16,  # do not predict root joint
+        loss_keypoint=dict(type='MSELoss')),
+    train_cfg=dict(),
+    test_cfg=dict(restore_global_position=True))
+
+# data settings
+data_root = 'data/h36m'
+data_cfg = dict(
+    image_size=[256, 256],
+    heatmap_size=[64, 64, [1, 2, 4, 8, 64]],
+    heatmap3d_depth_bound=0.5,
+    num_joints=17,
+    seq_len=1,
+    seq_frame_interval=1,
+    causal=True,
+    joint_2d_src='gt',
+    need_camera_param=True,
+    camera_param_file=f'{data_root}/annotation_body3d/cameras.pkl',
+)
+
+# 3D joint normalization parameters
+# From file: '{data_root}/annotation_body3d/fps50/joint3d_rel_stats.pkl'
+joint_3d_normalize_param = dict(
+    mean=[[-2.55652589e-04, -7.11960570e-03, -9.81433052e-04],
+          [-5.65463051e-03, 3.19636009e-01, 7.19329269e-02],
+          [-1.01705840e-02, 6.91147892e-01, 1.55352986e-01],
+          [2.55651315e-04, 7.11954606e-03, 9.81423866e-04],
+          [-5.09729780e-03, 3.27040413e-01, 7.22258095e-02],
+          [-9.99656606e-03, 7.08277383e-01, 1.58016408e-01],
+          [2.90583676e-03, -2.11363307e-01, -4.74210915e-02],
+          [5.67537804e-03, -4.35088906e-01, -9.76974016e-02],
+          [5.93884964e-03, -4.91891970e-01, -1.10666618e-01],
+          [7.37352083e-03, -5.83948619e-01, -1.31171400e-01],
+          [5.41920653e-03, -3.83931702e-01, -8.68145417e-02],
+          [2.95964662e-03, -1.87567488e-01, -4.34536934e-02],
+          [1.26585822e-03, -1.20170579e-01, -2.82526049e-02],
+          [4.67186639e-03, -3.83644089e-01, -8.55125784e-02],
+          [1.67648571e-03, -1.97007177e-01, -4.31368364e-02],
+          [8.70569015e-04, -1.68664569e-01, -3.73902498e-02]],
+    std=[[0.11072244, 0.02238818, 0.07246294],
+         [0.15856311, 0.18933832, 0.20880479],
+         [0.19179935, 0.24320062, 0.24756193],
+         [0.11072181, 0.02238805, 0.07246253],
+         [0.15880454, 0.19977188, 0.2147063],
+         [0.18001944, 0.25052739, 0.24853247],
+         [0.05210694, 0.05211406, 0.06908241],
+         [0.09515367, 0.10133032, 0.12899733],
+         [0.11742458, 0.12648469, 0.16465091],
+         [0.12360297, 0.13085539, 0.16433336],
+         [0.14602232, 0.09707956, 0.13952731],
+         [0.24347532, 0.12982249, 0.20230181],
+         [0.2446877, 0.21501816, 0.23938235],
+         [0.13876084, 0.1008926, 0.1424411],
+         [0.23687529, 0.14491219, 0.20980829],
+         [0.24400695, 0.23975028, 0.25520584]])
+
+train_pipeline = [
+    # dict(
+    #     type='GetRootCenteredPose',
+    #     item='target',
+    #     visible_item='target_visible',
+    #     root_index=0,
+    #     root_name='root_position',
+    #     remove_root=True),
+    # dict(type='LoadImageFromFile'),
+    # dict(type='ToTensor'),
+    # dict(
+    #     type='NormalizeTensor',
+    #     mean=[0.485, 0.456, 0.406],
+    #     std=[0.229, 0.224, 0.225]),
+    dict(type='LoadImageFromFile'),
+    dict(type='TopDownRandomFlip', flip_prob=0.5),
+    dict(
+        type='TopDownHalfBodyTransform',
+        num_joints_half_body=8,
+        prob_half_body=0.3),
+    dict(
+        type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
+    dict(type='TopDownAffine'),
+    dict(type='ToTensor'),
+    dict(
+        type='NormalizeTensor',
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225]),
+    dict(
+        type='GetRootCenteredPose',
+        item='target',
+        visible_item='target_visible',
+        root_index=0,
+        root_name='root_position',
+        remove_root=True),
+    # dict(
+    #     type='CameraProjection',
+    #     item='target',
+    #     mode='camera_to_pixel',
+    #     output_name='joint_2d_pixel',
+    # ),
+    dict(
+        type='Generate3DHeatmapTarget_h36m',
+        sigma=2.5,
+        max_bound=1,
+    ),
+    dict(
+        type='Collect',
+        keys=[('img', 'input'), 'target'],
+        meta_name='metas',
+        meta_keys=[
+            'target_image_path',
+            'flip_pairs',
+            'root_position',
+            'root_position_index',
+        ])
+]
+
+val_pipeline = train_pipeline
+test_pipeline = val_pipeline
+
+data = dict(
+    samples_per_gpu=3,
+    workers_per_gpu=0,
+    val_dataloader=dict(samples_per_gpu=64),
+    test_dataloader=dict(samples_per_gpu=64),
+    train=dict(
+        type='Body3DH36MDataset_E2E',
+        ann_file=f'{data_root}/annotation_body3d/fps50/h36m_train.npz',
+        img_prefix=f'{data_root}/images/',
+        data_cfg=data_cfg,
+        pipeline=train_pipeline),
+    val=dict(
+        type='Body3DH36MDataset',
+        ann_file=f'{data_root}/annotation_body3d/fps50/h36m_test.npz',
+        img_prefix=f'{data_root}/images/',
+        data_cfg=data_cfg,
+        pipeline=val_pipeline),
+    test=dict(
+        type='Body3DH36MDataset',
+        ann_file=f'{data_root}/annotation_body3d/fps50/h36m_test.npz',
+        img_prefix=f'{data_root}/images/',
+        data_cfg=data_cfg,
+        pipeline=test_pipeline),
+)
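The config above is meant to be consumed through the usual mmcv/mmpose entry points. A minimal smoke-test sketch (not part of the patch), assuming this branch is installed and the Human3.6M annotations are prepared under data/h36m:

# Sketch only: load the new coarse-to-fine config and build the training set.
from mmcv import Config

from mmpose.datasets import build_dataset

cfg = Config.fromfile(
    'configs/body/3d_kpt_sview_rgb_img/end2end/h36m/coarse2fine_h36m.py')

# cfg.data.train points at the Body3DH36MDataset_E2E registered in this patch,
# and its pipeline ends with Generate3DHeatmapTarget_h36m, so each sample
# carries one 3D heatmap per depth resolution in
# heatmap_size = [64, 64, [1, 2, 4, 8, 64]].
train_dataset = build_dataset(cfg.data.train)
print(len(train_dataset))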
diff --git a/mmpose/datasets/datasets/body3d/__init__.py b/mmpose/datasets/datasets/body3d/__init__.py
index 3402755f6b..f029076d11 100644
--- a/mmpose/datasets/datasets/body3d/__init__.py
+++ b/mmpose/datasets/datasets/body3d/__init__.py
@@ -1,8 +1,11 @@
 from .body3d_h36m_dataset import Body3DH36MDataset
+from .body3d_h36m_end2end_dataset import Body3DH36MDataset_E2E
 from .body3d_mpi_inf_3dhp_dataset import Body3DMpiInf3dhpDataset
 from .body3d_semi_supervision_dataset import Body3DSemiSupervisionDataset
 
 __all__ = [
-    'Body3DH36MDataset', 'Body3DSemiSupervisionDataset',
-    'Body3DMpiInf3dhpDataset'
+    'Body3DH36MDataset',
+    'Body3DSemiSupervisionDataset',
+    'Body3DMpiInf3dhpDataset',
+    'Body3DH36MDataset_E2E',
 ]
diff --git a/mmpose/datasets/datasets/body3d/body3d_h36m_end2end_dataset.py b/mmpose/datasets/datasets/body3d/body3d_h36m_end2end_dataset.py
new file mode 100644
index 0000000000..2107b23ef7
--- /dev/null
+++ b/mmpose/datasets/datasets/body3d/body3d_h36m_end2end_dataset.py
@@ -0,0 +1,334 @@
+import os.path as osp
+from collections import OrderedDict, defaultdict
+
+import mmcv
+import numpy as np
+
+from mmpose.core.evaluation import keypoint_mpjpe
+from ...builder import DATASETS
+from .body3d_base_dataset import Body3DBaseDataset
+
+
+@DATASETS.register_module()
+class Body3DH36MDataset_E2E(Body3DBaseDataset):
+    """Human3.6M dataset for 3D human pose estimation.
+
+    `Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human
+    Sensing in Natural Environments' TPAMI`2014
+    More details can be found in the `paper
+    <http://vision.imar.ro/human3.6m/pami-h36m.pdf>`__.
+
+    Human3.6M keypoint indexes::
+        0: 'root (pelvis)',
+        1: 'right_hip',
+        2: 'right_knee',
+        3: 'right_foot',
+        4: 'left_hip',
+        5: 'left_knee',
+        6: 'left_foot',
+        7: 'spine',
+        8: 'thorax',
+        9: 'neck_base',
+        10: 'head',
+        11: 'left_shoulder',
+        12: 'left_elbow',
+        13: 'left_wrist',
+        14: 'right_shoulder',
+        15: 'right_elbow',
+        16: 'right_wrist'
+
+
+    Args:
+        ann_file (str): Path to the annotation file.
+        img_prefix (str): Path to a directory where images are held.
+            Default: None.
+        data_cfg (dict): config
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        test_mode (bool): Store True when building test or
+            validation dataset. Default: False.
+    """
+
+    JOINT_NAMES = [
+        'Root', 'RHip', 'RKnee', 'RFoot', 'LHip', 'LKnee', 'LFoot', 'Spine',
+        'Thorax', 'NeckBase', 'Head', 'LShoulder', 'LElbow', 'LWrist',
+        'RShoulder', 'RElbow', 'RWrist'
+    ]
+
+    # 2D joint source options:
+    # "gt": from the annotation file
+    # "detection": from a detection result file of 2D keypoint
+    # "pipeline": will be generated by the pipeline
+    SUPPORTED_JOINT_2D_SRC = {'gt', 'detection', 'pipeline'}
+
+    # metric
+    ALLOWED_METRICS = {'mpjpe', 'p-mpjpe', 'n-mpjpe'}
+
+    def load_config(self, data_cfg):
+        super().load_config(data_cfg)
+        # h36m specific attributes
+        self.joint_2d_src = data_cfg.get('joint_2d_src', 'gt')
+        if self.joint_2d_src not in self.SUPPORTED_JOINT_2D_SRC:
+            raise ValueError(
+                f'Unsupported joint_2d_src "{self.joint_2d_src}". '
+                f'Supported options are {self.SUPPORTED_JOINT_2D_SRC}')
+
+        self.joint_2d_det_file = data_cfg.get('joint_2d_det_file', None)
+
+        self.need_camera_param = data_cfg.get('need_camera_param', False)
+        if self.need_camera_param:
+            assert 'camera_param_file' in data_cfg
+            self.camera_param = self._load_camera_param(
+                data_cfg['camera_param_file'])
+
+        # h36m specific annotation info
+        ann_info = {}
+        ann_info['flip_pairs'] = [[1, 4], [2, 5], [3, 6], [11, 14], [12, 15],
+                                  [13, 16]]
+        ann_info['upper_body_ids'] = (0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
+        ann_info['lower_body_ids'] = (1, 2, 3, 4, 5, 6)
+        ann_info['use_different_joint_weights'] = False
+        ann_info['image_size'] = np.array(data_cfg['image_size'])
+        ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
+        ann_info['heatmap3d_depth_bound'] = data_cfg['heatmap3d_depth_bound']
+
+        # action filter
+        actions = data_cfg.get('actions', '_all_')
+        self.actions = set(
+            actions if isinstance(actions, (list, tuple)) else [actions])
+
+        # subject filter
+        subjects = data_cfg.get('subjects', '_all_')
+        self.subjects = set(
+            subjects if isinstance(subjects, (list, tuple)) else [subjects])
+
+        self.ann_info.update(ann_info)
+
+    def load_annotations(self):
+        data_info = super().load_annotations()
+
+        # get 2D joints
+        if self.joint_2d_src == 'gt':
+            data_info['joints_2d'] = data_info['joints_2d']
+        elif self.joint_2d_src == 'detection':
+            data_info['joints_2d'] = self._load_joint_2d_detection(
+                self.joint_2d_det_file)
+            assert data_info['joints_2d'].shape[0] == data_info[
+                'joints_3d'].shape[0]
+            assert data_info['joints_2d'].shape[2] == 3
+        elif self.joint_2d_src == 'pipeline':
+            # joint_2d will be generated in the pipeline
+            pass
+        else:
+            raise NotImplementedError(
+                f'Unhandled joint_2d_src option {self.joint_2d_src}')
+
+        # get a part set
+        data_info['imgnames'] = data_info['imgnames'][0:5000]
+        data_info['joints_3d'] = data_info['joints_3d'][0:5000]
+        data_info['joints_2d'] = data_info['joints_2d'][0:5000]
+        data_info['scales'] = data_info['scales'][0:5000]
+        data_info['centers'] = data_info['centers'][0:5000]
+
+        return data_info
+
+    def prepare_data(self, idx):
+        results = super().prepare_data(idx)
+        results['image_file'] = self.img_prefix + results['target_image_path']
+        results['scale'] = np.squeeze(results['scales'])
+        results['center'] = np.squeeze(results['centers'])
+        results['rotation'] = 0
+        results['joints_3d'] = np.pad(results['input_2d'][0], ((0, 0), (0, 1)),
+                                      'constant')
+        results['joints_3d_visible'] = np.tile(results['input_2d_visible'][0],
+                                               3)
+        return results
+
+    @staticmethod
+    def _parse_h36m_imgname(imgname):
+        """Parse imgname to get information of subject, action and camera.
+
+        A typical h36m image filename is like:
+        S1_Directions_1.54138969_000001.jpg
+        """
+        subj, rest = osp.basename(imgname).split('_', 1)
+        action, rest = rest.split('.', 1)
+        camera, rest = rest.split('_', 1)
+
+        return subj, action, camera
+
+    def build_sample_indices(self):
+        """Split original videos into sequences and build frame indices.
+
+        This method overrides the default one in the base class.
+        """
+
+        # Group frames into videos. Assume that self.data_info is
+        # chronological.
+        video_frames = defaultdict(list)
+        for idx, imgname in enumerate(self.data_info['imgnames']):
+            subj, action, camera = self._parse_h36m_imgname(imgname)
+
+            if '_all_' not in self.actions and action not in self.actions:
+                continue
+
+            if '_all_' not in self.subjects and subj not in self.subjects:
+                continue
+
+            video_frames[(subj, action, camera)].append(idx)
+
+        # build sample indices
+        sample_indices = []
+        _len = (self.seq_len - 1) * self.seq_frame_interval + 1
+        _step = self.seq_frame_interval
+        for _, _indices in sorted(video_frames.items()):
+            n_frame = len(_indices)
+
+            if self.temporal_padding:
+                # Pad the sequence so that every frame in the sequence will be
+                # predicted.
+                if self.causal:
+                    frames_left = self.seq_len - 1
+                    frames_right = 0
+                else:
+                    frames_left = (self.seq_len - 1) // 2
+                    frames_right = frames_left
+                for i in range(n_frame):
+                    pad_left = max(0, frames_left - i // _step)
+                    pad_right = max(0,
+                                    frames_right - (n_frame - 1 - i) // _step)
+                    start = max(i % _step, i - frames_left * _step)
+                    end = min(n_frame - (n_frame - 1 - i) % _step,
+                              i + frames_right * _step + 1)
+                    sample_indices.append([_indices[0]] * pad_left +
+                                          _indices[start:end:_step] +
+                                          [_indices[-1]] * pad_right)
+            else:
+                seqs_from_video = [
+                    _indices[i:(i + _len):_step]
+                    for i in range(0, n_frame - _len + 1)
+                ]
+                sample_indices.extend(seqs_from_video)
+
+        # reduce dataset size if self.subset < 1
+        assert 0 < self.subset <= 1
+        subset_size = int(len(sample_indices) * self.subset)
+        start = np.random.randint(0, len(sample_indices) - subset_size + 1)
+        end = start + subset_size
+
+        return sample_indices[start:end]
+
+    def _load_joint_2d_detection(self, det_file):
+        """Load 2D joint detection results from file."""
+        joints_2d = np.load(det_file).astype(np.float32)
+
+        return joints_2d
+
+    def evaluate(self,
+                 outputs,
+                 res_folder,
+                 metric='mpjpe',
+                 logger=None,
+                 **kwargs):
+        metrics = metric if isinstance(metric, list) else [metric]
+        for _metric in metrics:
+            if _metric not in self.ALLOWED_METRICS:
+                raise ValueError(
+                    f'Unsupported metric "{_metric}" for Human3.6M dataset. '
+                    f'Supported metrics are {self.ALLOWED_METRICS}')
+
+        res_file = osp.join(res_folder, 'result_keypoints.json')
+        kpts = []
+        for output in outputs:
+            preds = output['preds']
+            image_paths = output['target_image_paths']
+            batch_size = len(image_paths)
+            for i in range(batch_size):
+                target_id = self.name2id[image_paths[i]]
+                kpts.append({
+                    'keypoints': preds[i],
+                    'target_id': target_id,
+                })
+
+        mmcv.dump(kpts, res_file)
+
+        name_value_tuples = []
+        for _metric in metrics:
+            if _metric == 'mpjpe':
+                _nv_tuples = self._report_mpjpe(kpts)
+            elif _metric == 'p-mpjpe':
+                _nv_tuples = self._report_mpjpe(kpts, mode='p-mpjpe')
+            elif _metric == 'n-mpjpe':
+                _nv_tuples = self._report_mpjpe(kpts, mode='n-mpjpe')
+            else:
+                raise NotImplementedError
+            name_value_tuples.extend(_nv_tuples)
+
+        return OrderedDict(name_value_tuples)
+
+    def _report_mpjpe(self, keypoint_results, mode='mpjpe'):
+        """Calculate mean per joint position error (MPJPE) or its variants like
+        P-MPJPE or N-MPJPE.
+
+        Args:
+            keypoint_results (list): Keypoint predictions. See
+                'Body3DH36MDataset.evaluate' for details.
+            mode (str): Specify mpjpe variants. Supported options are:
+                - ``'mpjpe'``: Standard MPJPE.
+                - ``'p-mpjpe'``: MPJPE after aligning prediction to groundtruth
+                    via a rigid transformation (scale, rotation and
+                    translation).
+                - ``'n-mpjpe'``: MPJPE after aligning prediction to groundtruth
+                    in scale only.
+        """
+
+        preds = []
+        gts = []
+        masks = []
+        action_category_indices = defaultdict(list)
+        for idx, result in enumerate(keypoint_results):
+            pred = result['keypoints']
+            target_id = result['target_id']
+            gt, gt_visible = np.split(
+                self.data_info['joints_3d'][target_id], [3], axis=-1)
+            preds.append(pred)
+            gts.append(gt)
+            masks.append(gt_visible)
+
+            action = self._parse_h36m_imgname(
+                self.data_info['imgnames'][target_id])[1]
+            action_category = action.split('_')[0]
+            action_category_indices[action_category].append(idx)
+
+        preds = np.stack(preds)
+        gts = np.stack(gts)
+        masks = np.stack(masks).squeeze(-1) > 0
+
+        err_name = mode.upper()
+        if mode == 'mpjpe':
+            alignment = 'none'
+        elif mode == 'p-mpjpe':
+            alignment = 'procrustes'
+        elif mode == 'n-mpjpe':
+            alignment = 'scale'
+        else:
+            raise ValueError(f'Invalid mode: {mode}')
+
+        error = keypoint_mpjpe(preds, gts, masks, alignment)
+        name_value_tuples = [(err_name, error)]
+
+        for action_category, indices in action_category_indices.items():
+            _error = keypoint_mpjpe(preds[indices], gts[indices],
+                                    masks[indices])
+            name_value_tuples.append((f'{err_name}_{action_category}', _error))
+
+        return name_value_tuples
+
+    def _load_camera_param(self, camera_param_file):
+        """Load camera parameters from file."""
+        return mmcv.load(camera_param_file)
+
+    def get_camera_param(self, imgname):
+        """Get camera parameters of a frame by its image name."""
+        assert hasattr(self, 'camera_param')
+        subj, _, camera = self._parse_h36m_imgname(imgname)
+        return self.camera_param[(subj, camera)]
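The evaluate()/_report_mpjpe() path above reduces to calls into mmpose's existing keypoint_mpjpe with one alignment mode per metric. A toy illustration on random data (not part of the patch) of how the three reported variants differ:

# Toy example: 'none' -> MPJPE, 'procrustes' -> P-MPJPE, 'scale' -> N-MPJPE.
import numpy as np

from mmpose.core.evaluation import keypoint_mpjpe

rng = np.random.default_rng(0)
gt = rng.normal(size=(8, 17, 3))  # 8 samples, 17 joints, xyz
pred = gt + 0.01 * rng.normal(size=gt.shape)  # slightly perturbed predictions
mask = np.ones((8, 17), dtype=bool)  # all joints treated as visible

for alignment in ('none', 'procrustes', 'scale'):
    print(alignment, float(keypoint_mpjpe(pred, gt, mask, alignment)))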
diff --git a/mmpose/datasets/pipelines/pose3d_transform.py b/mmpose/datasets/pipelines/pose3d_transform.py
index 143b5e90d7..b28c903449 100644
--- a/mmpose/datasets/pipelines/pose3d_transform.py
+++ b/mmpose/datasets/pipelines/pose3d_transform.py
@@ -539,3 +539,98 @@ def __call__(self, results):
         results['target'] = target
         results['target_weight'] = target_weight
         return results
+
+
+@PIPELINES.register_module()
+class Generate3DHeatmapTarget_h36m:
+    """Generate the target 3d heatmap.
+
+    Required keys: 'joints_3d', 'target', 'target_visible' and 'ann_info'.
+    Modified keys: 'target' and 'target_weight'.
+
+    Args:
+        sigma (float): Sigma of heatmap gaussian.
+        joint_indices (list): Indices of joints used for heatmap generation.
+            If None (default) is given, all joints will be used.
+        max_bound (float): The maximal value of heatmap.
+    """
+
+    def __init__(self, sigma=2, joint_indices=None, max_bound=1.0):
+        self.sigma = sigma
+        self.joint_indices = joint_indices
+        self.max_bound = max_bound
+
+    def __call__(self, results):
+        """Generate the target heatmap."""
+        joints_3d = np.concatenate(
+            (results['joints_3d'][1:, :2], results['target'][:, 2].reshape(
+                16, 1)),
+            axis=1)
+        joints_3d_visible = results['target_visible']
+        cfg = results['ann_info']
+        image_size = cfg['image_size']
+        W, H, D = cfg['heatmap_size']
+        heatmap3d_depth_bound = cfg['heatmap3d_depth_bound']
+        joint_weights = cfg['joint_weights']
+        use_different_joint_weights = cfg['use_different_joint_weights']
+
+        # select the joints used for target generation
+        if self.joint_indices is not None:
+            joints_3d = joints_3d[self.joint_indices, ...]
+            joints_3d_visible = joints_3d_visible[self.joint_indices, ...]
+            joint_weights = joint_weights[self.joint_indices, ...]
+        num_joints = joints_3d.shape[0]
+
+        results['target'] = []
+        for i in range(len(D)):
+            # get the joint location in heatmap coordinates
+            mu_x = joints_3d[:, 0] * W / image_size[0]
+            mu_y = joints_3d[:, 1] * H / image_size[1]
+            mu_z = (joints_3d[:, 2] / heatmap3d_depth_bound + 0.5) * D[i]
+
+            target = np.zeros([num_joints, D[i], H, W], dtype=np.float32)
+
+            target_weight = joints_3d_visible[:, 0].astype(np.float32)
+            target_weight = target_weight * (mu_z >= 0) * (mu_z < D[i])
+            if use_different_joint_weights:
+                target_weight = target_weight * joint_weights
+            target_weight = target_weight[:, None]
+
+            # only compute the voxel value near the joints location
+            tmp_size = 3 * self.sigma
+
+            # get neighboring voxels coordinates
+            x = y = z = np.arange(
+                2 * tmp_size + 1, dtype=np.float32) - tmp_size
+            zz, yy, xx = np.meshgrid(z, y, x)
+            xx = xx[None, ...].astype(np.float32)
+            yy = yy[None, ...].astype(np.float32)
+            zz = zz[None, ...].astype(np.float32)
+            mu_x = mu_x[..., None, None, None]
+            mu_y = mu_y[..., None, None, None]
+            mu_z = mu_z[..., None, None, None]
+            xx, yy, zz = xx + mu_x, yy + mu_y, zz + mu_z
+
+            # round the coordinates
+            xx = xx.round().clip(0, W - 1)
+            yy = yy.round().clip(0, H - 1)
+            zz = zz.round().clip(0, D[i] - 1)
+
+            # compute the target value near joints
+            local_target = \
+                np.exp(-((xx - mu_x)**2 + (yy - mu_y)**2 + (zz - mu_z)**2) /
+                       (2 * self.sigma**2))
+
+            # put the local target value to the full target heatmap
+            local_size = xx.shape[1]
+            idx_joints = np.tile(
+                np.arange(num_joints)[:, None, None, None],
+                [1, local_size, local_size, local_size])
+            idx = np.stack([idx_joints, zz, yy, xx],
+                           axis=-1).astype(np.int64).reshape(-1, 4)
+            target[idx[:, 0], idx[:, 1], idx[:, 2],
+                   idx[:, 3]] = local_target.reshape(-1)
+            target = target * self.max_bound
+            results['target'].append(target)
+        results['target_weight'] = target_weight
+        return results
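At its core, Generate3DHeatmapTarget_h36m evaluates a Gaussian only on a small cube of voxels around each joint and scatters that cube into the full (K, D, H, W) volume by integer indexing, once per depth resolution. A self-contained NumPy sketch of that step for a single resolution (function and variable names are illustrative, not part of the patch):

import numpy as np


def gaussian_heatmap3d(joints, D, H, W, sigma=2.5, max_bound=1.0):
    """joints: (K, 3) array of (x, y, z) already in heatmap coordinates."""
    num_joints = joints.shape[0]
    target = np.zeros((num_joints, D, H, W), dtype=np.float32)

    # only evaluate voxels within 3*sigma of each joint
    tmp_size = int(3 * sigma)
    offsets = np.arange(2 * tmp_size + 1, dtype=np.float32) - tmp_size
    zz, yy, xx = np.meshgrid(offsets, offsets, offsets, indexing='ij')

    mu = joints.astype(np.float32)
    # absolute voxel coordinates of the local cube, rounded and clipped
    xs = (xx[None] + mu[:, 0, None, None, None]).round().clip(0, W - 1)
    ys = (yy[None] + mu[:, 1, None, None, None]).round().clip(0, H - 1)
    zs = (zz[None] + mu[:, 2, None, None, None]).round().clip(0, D - 1)

    # Gaussian value of every local voxel w.r.t. its joint centre
    gauss = np.exp(-((xs - mu[:, 0, None, None, None])**2 +
                     (ys - mu[:, 1, None, None, None])**2 +
                     (zs - mu[:, 2, None, None, None])**2) / (2 * sigma**2))

    # scatter the local cubes into the full volume by integer indexing
    ki = np.broadcast_to(
        np.arange(num_joints)[:, None, None, None], xs.shape)
    target[ki, zs.astype(np.int64), ys.astype(np.int64),
           xs.astype(np.int64)] = gauss.astype(np.float32)
    return target * max_bound


# one joint in the middle of an 8 x 64 x 64 volume
heatmap = gaussian_heatmap3d(np.array([[32.0, 32.0, 4.0]]), D=8, H=64, W=64)
print(heatmap.shape, heatmap.max())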
diff --git a/tests/test_datasets/test_body3d_dataset.py b/tests/test_datasets/test_body3d_dataset.py
index ffe0cc3dde..fa03bbe86c 100644
--- a/tests/test_datasets/test_body3d_dataset.py
+++ b/tests/test_datasets/test_body3d_dataset.py
@@ -255,3 +255,55 @@ def test_body3d_mpi_inf_3dhp_dataset():
     np.testing.assert_almost_equal(infos['P-3DPCK'], 100.)
     np.testing.assert_almost_equal(infos['3DAUC'], 30 / 31 * 100)
     np.testing.assert_almost_equal(infos['P-3DAUC'], 30 / 31 * 100)
+
+
+def test_body3d_h36m_dataset_E2E():
+    # Test Human3.6M dataset using end-to-end
+    dataset = 'Body3DH36MDataset_E2E'
+    dataset_class = DATASETS.get(dataset)
+
+    # test single-frame input
+    data_cfg = dict(
+        image_size=[256, 256],
+        heatmap_size=[64, 64, [1, 2, 4, 8, 64]],
+        heatmap3d_depth_bound=0.5,
+        num_joints=17,
+        seq_len=1,
+        seq_frame_interval=1,
+        joint_2d_src='pipeline',
+        joint_2d_det_file=None,
+        causal=False,
+        need_camera_param=True,
+        camera_param_file='tests/data/h36m/cameras.pkl')
+
+    _ = dataset_class(
+        ann_file='tests/data/h36m/test_h36m_body3d.npz',
+        img_prefix='tests/data/h36m',
+        data_cfg=data_cfg,
+        pipeline=[],
+        test_mode=False)
+
+    custom_dataset = dataset_class(
+        ann_file='tests/data/h36m/test_h36m_body3d.npz',
+        img_prefix='tests/data/h36m',
+        data_cfg=data_cfg,
+        pipeline=[],
+        test_mode=True)
+
+    assert custom_dataset.test_mode is True
+    _ = custom_dataset[0]
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        outputs = []
+        for result in custom_dataset:
+            outputs.append({
+                'preds': result['target'][None, ...],
+                'target_image_paths': [result['target_image_path']],
+            })
+
+        metrics = ['mpjpe', 'p-mpjpe', 'n-mpjpe']
+        infos = custom_dataset.evaluate(outputs, tmpdir, metrics)
+
+        np.testing.assert_almost_equal(infos['MPJPE'], 0.0)
+        np.testing.assert_almost_equal(infos['P-MPJPE'], 0.0)
+        np.testing.assert_almost_equal(infos['N-MPJPE'], 0.0)