Skip to content

Commit

Permalink
[Fix] Fix bugs in 3d human pose demo (#2413)
Browse files Browse the repository at this point in the history
  • Loading branch information
LareinaM authored May 31, 2023
1 parent b3dec97 commit 2260fa2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
2 changes: 1 addition & 1 deletion demo/body3d_pose_lifter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def main():
pose_results_2d = extract_pose_sequence(
pose_est_results_list,
frame_idx=frame_idx,
causal=pose_lifter.causal,
causal=pose_lift_dataset.get('causal', False),
seq_len=pose_lift_dataset.get('seq_len', 1),
step=pose_lift_dataset.get('seq_step', 1))

Expand Down
27 changes: 15 additions & 12 deletions mmpose/apis/inference_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def inference_pose_lifter_model(model,
init_default_scope(model.cfg.get('default_scope', 'mmpose'))
pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

causal = model.causal
causal = model.cfg.test_dataloader.dataset.get('causal', False)
target_idx = -1 if causal else len(pose_results_2d) // 2

dataset_info = model.dataset_meta
Expand All @@ -193,18 +193,21 @@ def inference_pose_lifter_model(model,
bbox_scale = None

for i, pose_res in enumerate(pose_results_2d):
keypoints = []
for j, data_sample in enumerate(pose_res):
keypoint = np.squeeze(data_sample.pred_instances.keypoints, axis=0)
if norm_pose_2d:
bbox = np.squeeze(data_sample.pred_instances.bboxes)
center = np.array([[(bbox[0] + bbox[2]) / 2,
(bbox[1] + bbox[3]) / 2]])
scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
keypoints.append((keypoint[:, :2] - center) / scale *
bbox_scale + bbox_center)
else:
keypoints.append(keypoint[:, :2])
kpts = data_sample.pred_instances.keypoints
bboxes = data_sample.pred_instances.bboxes
keypoints = []
for k in range(len(kpts)):
kpt = kpts[k]
if norm_pose_2d:
bbox = bboxes[k]
center = np.array([[(bbox[0] + bbox[2]) / 2,
(bbox[1] + bbox[3]) / 2]])
scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
keypoints.append((kpt[:, :2] - center) / scale *
bbox_scale + bbox_center)
else:
keypoints.append(kpt[:, :2])
pose_results_2d[i][j].pred_instances.keypoints = np.array(
keypoints)

Expand Down
30 changes: 14 additions & 16 deletions mmpose/visualization/local_visualizer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _draw_3d_data_samples(

plt.ioff()
fig = plt.figure(
figsize=(vis_width * 0.01, vis_height * num_instances * 0.01))
figsize=(vis_width * num_instances * 0.01, vis_height * 0.01))

def _draw_3d_instances_kpts(keypoints,
scores,
Expand All @@ -135,31 +135,30 @@ def _draw_3d_instances_kpts(keypoints,
for idx, (kpts, score, visible) in enumerate(
zip(keypoints, scores, keypoints_visible)):

valid = score >= kpt_thr
valid = np.logical_and(score >= kpt_thr,
np.any(~np.isnan(kpts), axis=-1))

ax = fig.add_subplot(
1, num_fig, fig_idx * (idx + 1), projection='3d')
ax.view_init(elev=axis_elev, azim=axis_azimuth)

x_c = np.mean(kpts[valid, 0]) if valid.any() else 0
y_c = np.mean(kpts[valid, 1]) if valid.any() else 0

ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2])
ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2])
ax.set_zlim3d([0, axis_limit])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_aspect('auto')
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.scatter([0], [0], [0], marker='o', color='red')
if title:
ax.set_title(f'{title} ({idx})')
ax.dist = axis_dist

ax.scatter([0], [0], [0], marker='o', color='red')
x_c = np.mean(kpts[valid, 0]) if valid.any() else 0
y_c = np.mean(kpts[valid, 1]) if valid.any() else 0

ax.set_xlim3d([x_c - axis_limit / 2, x_c + axis_limit / 2])
ax.set_ylim3d([y_c - axis_limit / 2, y_c + axis_limit / 2])

kpts = np.array(kpts, copy=False)

Expand Down Expand Up @@ -208,9 +207,6 @@ def _draw_3d_instances_kpts(keypoints,
ax.plot(
xs_3d, ys_3d, zs_3d, color=_color, zdir='z')

if title:
ax.set_title(f'{title} ({idx})')

if 'keypoints' in pred_instances:
keypoints = pred_instances.get('keypoints',
pred_instances.keypoints)
Expand Down Expand Up @@ -254,7 +250,9 @@ def _draw_3d_instances_kpts(keypoints,
if not pred_img_data.any():
pred_img_data = np.full((vis_height, vis_width, 3), 255)
else:
pred_img_data = pred_img_data.reshape(vis_height, vis_width, -1)
pred_img_data = pred_img_data.reshape(vis_height,
vis_width * num_instances,
-1)

plt.close(fig)

Expand Down

0 comments on commit 2260fa2

Please sign in to comment.