Skip to content

Commit

Permalink
fix codec
Browse files Browse the repository at this point in the history
  • Loading branch information
LareinaM committed Jun 30, 2023
1 parent 522d4f9 commit 0b538b8
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions mmpose/codecs/image_pose_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ def __init__(self,
self.save_index = save_index
self.reshape_keypoints = reshape_keypoints
self.concat_vis = concat_vis
if keypoints_mean is not None and keypoints_std is not None:
if keypoints_mean is not None:
keypoints_mean = np.array(keypoints_mean).reshape(
1, num_keypoints, -1)
keypoints_std = np.array(keypoints_std).reshape(
1, num_keypoints, -1)
assert keypoints_std is not None
assert keypoints_mean.shape == keypoints_std.shape
if target_mean is not None and target_std is not None:
if target_mean is not None:
target_dim = num_keypoints - 1 if remove_root else num_keypoints
target_mean = np.array(target_mean).reshape(1, target_dim, -1)
target_std = np.array(target_std).reshape(1, target_dim, -1)
assert target_std is not None
assert target_mean.shape == target_std.shape
self.keypoints_mean = keypoints_mean
self.keypoints_std = keypoints_std
Expand Down Expand Up @@ -160,18 +169,17 @@ def encode(self,

# Normalize the 2D keypoint coordinate with mean and std
keypoint_labels = keypoints.copy()
if self.keypoints_mean is not None and self.keypoints_std is not None:
keypoints_shape = keypoints.shape
assert self.keypoints_mean.shape == keypoints_shape[1:]
if self.keypoints_mean is not None:
assert self.keypoints_mean.shape[1:] == keypoints.shape[1:]
encoded['keypoints_mean'] = self.keypoints_mean.copy()
encoded['keypoints_std'] = self.keypoints_std.copy()

keypoint_labels = (keypoint_labels -
self.keypoints_mean) / self.keypoints_std
if self.target_mean is not None and self.target_std is not None:
assert self.target_mean.ndim in {2, 3}
if self.target_mean.ndim == 2:
self.target_mean = self.target_mean[None, :]
target_shape = lifting_target_label.shape
assert self.target_mean.shape == target_shape
if self.target_mean is not None:
assert self.target_mean.shape == lifting_target_label.shape
encoded['target_mean'] = self.target_mean.copy()
encoded['target_std'] = self.target_std.copy()

lifting_target_label = (lifting_target_label -
self.target_mean) / self.target_std
Expand Down

0 comments on commit 0b538b8

Please sign in to comment.