diff --git a/mmpose/structures/utils.py b/mmpose/structures/utils.py index 882cda8603..616b139c54 100644 --- a/mmpose/structures/utils.py +++ b/mmpose/structures/utils.py @@ -50,8 +50,7 @@ def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample: 0].pred_fields: reverted_heatmaps = [ revert_heatmap(data_sample.pred_fields.heatmaps, - data_sample.gt_instances.bbox_centers, - data_sample.gt_instances.bbox_scales, + data_sample.input_center, data_sample.input_scale, data_sample.ori_shape) for data_sample in data_samples ] @@ -65,8 +64,7 @@ def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample: 0].gt_fields: reverted_heatmaps = [ revert_heatmap(data_sample.gt_fields.heatmaps, - data_sample.gt_instances.bbox_centers, - data_sample.gt_instances.bbox_scales, + data_sample.input_center, data_sample.input_scale, data_sample.ori_shape) for data_sample in data_samples ] @@ -79,13 +77,13 @@ def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample: return merged -def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape): +def revert_heatmap(heatmap, input_center, input_scale, img_shape): """Revert predicted heatmap on the original image. Args: heatmap (np.ndarray or torch.tensor): predicted heatmap. - bbox_center (np.ndarray): bounding box center coordinate. - bbox_scale (np.ndarray): bounding box scale. + input_center (np.ndarray): bounding box center coordinate. + input_scale (np.ndarray): bounding box scale. img_shape (tuple or list): size of original image. """ if torch.is_tensor(heatmap): @@ -99,8 +97,8 @@ def revert_heatmap(heatmap, bbox_center, bbox_scale, img_shape): hm_h, hm_w = heatmap.shape[:2] img_h, img_w = img_shape warp_mat = get_warp_matrix( - bbox_center.reshape((2, )), - bbox_scale.reshape((2, )), + input_center.reshape((2, )), + input_scale.reshape((2, )), rot=0, output_size=(hm_w, hm_h), inv=True)