From cc31a7c689f91dcc70055a74ad7bd7c20183d3fc Mon Sep 17 00:00:00 2001 From: LareinaM Date: Thu, 3 Aug 2023 16:10:50 +0800 Subject: [PATCH] fix problems --- demo/body3d_pose_lifter_demo.py | 15 +++++++++++++-- mmpose/apis/inference_3d.py | 30 +++++++++++++++++------------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/demo/body3d_pose_lifter_demo.py b/demo/body3d_pose_lifter_demo.py index d04fca9f3b..3c36d3a88b 100644 --- a/demo/body3d_pose_lifter_demo.py +++ b/demo/body3d_pose_lifter_demo.py @@ -90,7 +90,7 @@ def parse_args(): '--save-predictions', action='store_true', default=False, - help='whether to save predicted results') + help='Whether to save predicted results') parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( @@ -124,7 +124,14 @@ def parse_args(): '--use-multi-frames', action='store_true', default=False, - help='whether to use multi frames for inference in the 2D pose' + help='Whether to use multi frames for inference in the 2D pose' + 'detection stage. Default: False.') + parser.add_argument( + '--online', + action='store_true', + default=False, + help='Inference mode. If set to True, can not use future frame' + 'information when using multi frames for inference in the 2D pose' 'detection stage. Default: False.') args = parser.parse_args() @@ -405,6 +412,10 @@ def main(): 'Only "PoseLifter" model is supported for the 2nd stage ' \ '(2D-to-3D lifting)' + if args.use_multi_frames: + assert 'frame_indices_test' in pose_estimator.cfg.data.test.data_cfg + indices = pose_estimator.cfg.data.test.data_cfg['frame_indices_test'] + pose_lifter.cfg.visualizer.radius = args.radius pose_lifter.cfg.visualizer.line_width = args.thickness pose_lifter.cfg.visualizer.det_kpt_color = det_kpt_color diff --git a/mmpose/apis/inference_3d.py b/mmpose/apis/inference_3d.py index d4b9623b86..303cfd0713 100644 --- a/mmpose/apis/inference_3d.py +++ b/mmpose/apis/inference_3d.py @@ -181,16 +181,11 @@ def collate_pose_sequence(pose_results_2d, pose_sequences = [] for idx in range(N): pose_seq = PoseDataSample() - gt_instances = InstanceData() pred_instances = InstanceData() - for k in pose_results_2d[target_frame][idx].gt_instances.keys(): - gt_instances.set_field( - pose_results_2d[target_frame][idx].gt_instances[k], k) - for k in pose_results_2d[target_frame][idx].pred_instances.keys(): - if k != 'keypoints': - pred_instances.set_field( - pose_results_2d[target_frame][idx].pred_instances[k], k) + gt_instances = pose_results_2d[target_frame][idx].gt_instances.clone() + pred_instances = pose_results_2d[target_frame][ + idx].pred_instances.clone() pose_seq.pred_instances = pred_instances pose_seq.gt_instances = gt_instances @@ -228,7 +223,7 @@ def collate_pose_sequence(pose_results_2d, # replicate the right most frame keypoints[:, frame_idx + 1:] = keypoints[:, frame_idx] break - pose_seq.pred_instances.keypoints = keypoints + pose_seq.pred_instances.set_field(keypoints, 'keypoints') pose_sequences.append(pose_seq) return pose_sequences @@ -276,8 +271,15 @@ def inference_pose_lifter_model(model, bbox_center = None bbox_scale = None + pose_results_2d_copy = [] for i, pose_res in enumerate(pose_results_2d): + pose_res_copy = [] for j, data_sample in enumerate(pose_res): + data_sample_copy = PoseDataSample() + data_sample_copy.gt_instances = data_sample.gt_instances.clone() + data_sample_copy.pred_instances = data_sample.pred_instances.clone( + ) + data_sample_copy.track_id = data_sample.track_id kpts = data_sample.pred_instances.keypoints bboxes = data_sample.pred_instances.bboxes keypoints = [] @@ -292,11 +294,13 @@ def inference_pose_lifter_model(model, bbox_scale + bbox_center) else: keypoints.append(kpt[:, :2]) - pose_results_2d[i][j].pred_instances.keypoints = np.array( - keypoints) + data_sample_copy.pred_instances.set_field( + np.array(keypoints), 'keypoints') + pose_res_copy.append(data_sample_copy) + pose_results_2d_copy.append(pose_res_copy) - pose_sequences_2d = collate_pose_sequence(pose_results_2d, with_track_id, - target_idx) + pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy, + with_track_id, target_idx) if not pose_sequences_2d: return []