diff --git a/demo/body3d_pose_lifter_demo.py b/demo/body3d_pose_lifter_demo.py index 8b834d08ca..02e3014f21 100644 --- a/demo/body3d_pose_lifter_demo.py +++ b/demo/body3d_pose_lifter_demo.py @@ -386,7 +386,7 @@ def main(): pose_results_2d = extract_pose_sequence( pose_est_results_list, frame_idx=frame_idx, - causal=pose_lifter.causal, + causal=pose_lift_dataset.get('causal', False), seq_len=pose_lift_dataset.get('seq_len', 1), step=pose_lift_dataset.get('seq_step', 1)) diff --git a/mmpose/apis/inference_3d.py b/mmpose/apis/inference_3d.py index 2ab81b20a4..f89c33c9ea 100644 --- a/mmpose/apis/inference_3d.py +++ b/mmpose/apis/inference_3d.py @@ -180,7 +180,7 @@ def inference_pose_lifter_model(model, init_default_scope(model.cfg.get('default_scope', 'mmpose')) pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline) - causal = model.causal + causal = model.cfg.test_dataloader.dataset.get('causal', False) target_idx = -1 if causal else len(pose_results_2d) // 2 dataset_info = model.dataset_meta @@ -193,18 +193,21 @@ def inference_pose_lifter_model(model, bbox_scale = None for i, pose_res in enumerate(pose_results_2d): - keypoints = [] for j, data_sample in enumerate(pose_res): - keypoint = np.squeeze(data_sample.pred_instances.keypoints, axis=0) - if norm_pose_2d: - bbox = np.squeeze(data_sample.pred_instances.bboxes) - center = np.array([[(bbox[0] + bbox[2]) / 2, - (bbox[1] + bbox[3]) / 2]]) - scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) - keypoints.append((keypoint[:, :2] - center) / scale * - bbox_scale + bbox_center) - else: - keypoints.append(keypoint[:, :2]) + kpts = data_sample.pred_instances.keypoints + bboxes = data_sample.pred_instances.bboxes + keypoints = [] + for k in range(len(kpts)): + kpt = kpts[k] + if norm_pose_2d: + bbox = bboxes[k] + center = np.array([[(bbox[0] + bbox[2]) / 2, + (bbox[1] + bbox[3]) / 2]]) + scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + keypoints.append((kpt[:, :2] - center) / scale * + bbox_scale + bbox_center) + else: + keypoints.append(kpt[:, :2]) pose_results_2d[i][j].pred_instances.keypoints = np.array( keypoints) diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py index 3a0cfc1cb3..8eccd10bf9 100644 --- a/mmpose/visualization/local_visualizer_3d.py +++ b/mmpose/visualization/local_visualizer_3d.py @@ -124,7 +124,7 @@ def _draw_3d_data_samples( plt.ioff() fig = plt.figure( - figsize=(vis_width * 0.01, vis_height * num_instances * 0.01)) + figsize=(vis_width * num_instances * 0.01, vis_height * 0.01)) def _draw_3d_instances_kpts(keypoints, scores, @@ -135,21 +135,13 @@ def _draw_3d_instances_kpts(keypoints, for idx, (kpts, score, visible) in enumerate( zip(keypoints, scores, keypoints_visible)): - valid = score >= kpt_thr + valid = np.logical_and(score >= kpt_thr, + np.any(~np.isnan(kpts), axis=-1)) ax = fig.add_subplot( 1, num_fig, fig_idx * (idx + 1), projection='3d') ax.view_init(elev=axis_elev, azim=axis_azimuth) - - x_c = np.mean(kpts[valid, 0]) if valid.any() else 0 - y_c = np.mean(kpts[valid, 1]) if valid.any() else 0 - - ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2]) - ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2]) ax.set_zlim3d([0, axis_limit]) - ax.set_xlabel('x') - ax.set_ylabel('y') - ax.set_zlabel('z') ax.set_aspect('auto') ax.set_xticks([]) ax.set_yticks([]) @@ -157,9 +149,16 @@ def _draw_3d_instances_kpts(keypoints, ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_zticklabels([]) + ax.scatter([0], [0], [0], marker='o', color='red') + if title: + ax.set_title(f'{title} ({idx})') ax.dist = axis_dist - ax.scatter([0], [0], [0], marker='o', color='red') + x_c = np.mean(kpts[valid, 0]) if valid.any() else 0 + y_c = np.mean(kpts[valid, 1]) if valid.any() else 0 + + ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2]) + ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2]) kpts = np.array(kpts, copy=False) @@ -208,9 +207,6 @@ def _draw_3d_instances_kpts(keypoints, ax.plot( xs_3d, ys_3d, zs_3d, color=_color, zdir='z') - if title: - ax.set_title(f'{title} ({idx})') - if 'keypoints' in pred_instances: keypoints = pred_instances.get('keypoints', pred_instances.keypoints) @@ -254,7 +250,9 @@ def _draw_3d_instances_kpts(keypoints, if not pred_img_data.any(): pred_img_data = np.full((vis_height, vis_width, 3), 255) else: - pred_img_data = pred_img_data.reshape(vis_height, vis_width, -1) + pred_img_data = pred_img_data.reshape(vis_height, + vis_width * num_instances, + -1) plt.close(fig)