Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Visualization problems in 3d demo #2594

Merged
merged 1 commit into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions demo/body3d_pose_lifter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def parse_args():
'--save-predictions',
action='store_true',
default=False,
help='whether to save predicted results')
help='Whether to save predicted results')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
Expand Down Expand Up @@ -124,7 +124,14 @@ def parse_args():
'--use-multi-frames',
action='store_true',
default=False,
help='whether to use multi frames for inference in the 2D pose'
help='Whether to use multi frames for inference in the 2D pose'
'detection stage. Default: False.')
parser.add_argument(
'--online',
action='store_true',
default=False,
help='Inference mode. If set to True, can not use future frame'
'information when using multi frames for inference in the 2D pose'
'detection stage. Default: False.')

args = parser.parse_args()
Expand Down Expand Up @@ -405,6 +412,10 @@ def main():
'Only "PoseLifter" model is supported for the 2nd stage ' \
'(2D-to-3D lifting)'

if args.use_multi_frames:
assert 'frame_indices_test' in pose_estimator.cfg.data.test.data_cfg
indices = pose_estimator.cfg.data.test.data_cfg['frame_indices_test']

pose_lifter.cfg.visualizer.radius = args.radius
pose_lifter.cfg.visualizer.line_width = args.thickness
pose_lifter.cfg.visualizer.det_kpt_color = det_kpt_color
Expand Down
30 changes: 17 additions & 13 deletions mmpose/apis/inference_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,11 @@ def collate_pose_sequence(pose_results_2d,
pose_sequences = []
for idx in range(N):
pose_seq = PoseDataSample()
gt_instances = InstanceData()
pred_instances = InstanceData()

for k in pose_results_2d[target_frame][idx].gt_instances.keys():
gt_instances.set_field(
pose_results_2d[target_frame][idx].gt_instances[k], k)
for k in pose_results_2d[target_frame][idx].pred_instances.keys():
if k != 'keypoints':
pred_instances.set_field(
pose_results_2d[target_frame][idx].pred_instances[k], k)
gt_instances = pose_results_2d[target_frame][idx].gt_instances.clone()
pred_instances = pose_results_2d[target_frame][
idx].pred_instances.clone()
pose_seq.pred_instances = pred_instances
pose_seq.gt_instances = gt_instances

Expand Down Expand Up @@ -228,7 +223,7 @@ def collate_pose_sequence(pose_results_2d,
# replicate the right most frame
keypoints[:, frame_idx + 1:] = keypoints[:, frame_idx]
break
pose_seq.pred_instances.keypoints = keypoints
pose_seq.pred_instances.set_field(keypoints, 'keypoints')
pose_sequences.append(pose_seq)

return pose_sequences
Expand Down Expand Up @@ -276,8 +271,15 @@ def inference_pose_lifter_model(model,
bbox_center = None
bbox_scale = None

pose_results_2d_copy = []
for i, pose_res in enumerate(pose_results_2d):
pose_res_copy = []
for j, data_sample in enumerate(pose_res):
data_sample_copy = PoseDataSample()
data_sample_copy.gt_instances = data_sample.gt_instances.clone()
data_sample_copy.pred_instances = data_sample.pred_instances.clone(
)
data_sample_copy.track_id = data_sample.track_id
kpts = data_sample.pred_instances.keypoints
bboxes = data_sample.pred_instances.bboxes
keypoints = []
Expand All @@ -292,11 +294,13 @@ def inference_pose_lifter_model(model,
bbox_scale + bbox_center)
else:
keypoints.append(kpt[:, :2])
pose_results_2d[i][j].pred_instances.keypoints = np.array(
keypoints)
data_sample_copy.pred_instances.set_field(
np.array(keypoints), 'keypoints')
pose_res_copy.append(data_sample_copy)
pose_results_2d_copy.append(pose_res_copy)

pose_sequences_2d = collate_pose_sequence(pose_results_2d, with_track_id,
target_idx)
pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy,
with_track_id, target_idx)

if not pose_sequences_2d:
return []
Expand Down