From 4164f6a0de8b83458466f1e0e7ce17f93f24ee80 Mon Sep 17 00:00:00 2001 From: Zhongang Cai <62529255+caizhongang@users.noreply.github.com> Date: Wed, 15 Jun 2022 20:22:42 +0800 Subject: [PATCH] Fix missing arguments in SMPLify pipeline (#198) * Fix missing arguments in SMPLify pipeline * Minor fix * Fix cv2 attributes --- .../torch3d_renderer/smpl_renderer.py | 8 +- mmhuman3d/data/data_structures/smc_reader.py | 4 +- mmhuman3d/models/registrants/smplify.py | 5 ++ mmhuman3d/models/registrants/smplifyx.py | 75 ++++++++++++------- 4 files changed, 61 insertions(+), 31 deletions(-) diff --git a/mmhuman3d/core/visualization/renderer/torch3d_renderer/smpl_renderer.py b/mmhuman3d/core/visualization/renderer/torch3d_renderer/smpl_renderer.py index 4e1999d4..bc68d8cb 100644 --- a/mmhuman3d/core/visualization/renderer/torch3d_renderer/smpl_renderer.py +++ b/mmhuman3d/core/visualization/renderer/torch3d_renderer/smpl_renderer.py @@ -236,7 +236,7 @@ def forward( verts_rgba=joints_rgb_padded.to(self.device), cameras=cameras) - pointcloud_rgb, = pointcloud_images[..., :3] + pointcloud_rgb = pointcloud_images[..., :3] pointcloud_bgr = rgb2bgr(pointcloud_rgb) pointcloud_mask = (pointcloud_images[..., 3:] > 0) * 1.0 output_images = output_images * ( @@ -265,7 +265,11 @@ def forward( # return if self.return_tensor: - rendered_map = rendered_tensor + + if images is not None: + rendered_map = torch.tensor(output_images) + else: + rendered_map = rendered_tensor if self.final_resolution != self.resolution: rendered_map = interpolate( diff --git a/mmhuman3d/data/data_structures/smc_reader.py b/mmhuman3d/data/data_structures/smc_reader.py index 121e3810..537d06d5 100644 --- a/mmhuman3d/data/data_structures/smc_reader.py +++ b/mmhuman3d/data/data_structures/smc_reader.py @@ -504,7 +504,7 @@ def get_iphone_color(self, frame = self.__read_color_from_bytes__( self.smc['iPhone'][str(iphone_id)]['Color'][str(i)][()]) if vertical: - frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE) frames.append(frame) return np.stack(frames, axis=0) @@ -549,7 +549,7 @@ def get_iphone_depth(self, for i in tqdm.tqdm(frame_list, disable=disable_tqdm): frame = self.smc['iPhone'][str(iphone_id)]['Depth'][str(i)][()] if vertical: - frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE) frames.append(frame) return np.stack(frames, axis=0) diff --git a/mmhuman3d/models/registrants/smplify.py b/mmhuman3d/models/registrants/smplify.py index b27faf64..5e173856 100644 --- a/mmhuman3d/models/registrants/smplify.py +++ b/mmhuman3d/models/registrants/smplify.py @@ -326,8 +326,11 @@ def _optimize_stage(self, joint_prior_weight: weight of joint prior loss smooth_loss_weight: weight of smooth loss pose_prior_weight: weight of pose prior loss + pose_reg_weight: weight of pose regularization loss + limb_length_weight: weight of limb length loss joint_weights: per joint weight of shape (K, ) num_iter: number of iterations + ftol: early stop tolerance for relative change in loss Returns: None @@ -428,6 +431,8 @@ def evaluate( joint_prior_weight: weight of joint prior loss smooth_loss_weight: weight of smooth loss pose_prior_weight: weight of pose prior loss + pose_reg_weight: weight of pose regularization loss + limb_length_weight: weight of limb length loss joint_weights: per joint weight of shape (K, ) return_verts: whether to return vertices return_joints: whether to return joints diff --git a/mmhuman3d/models/registrants/smplifyx.py b/mmhuman3d/models/registrants/smplifyx.py index ea276af2..6ff21372 100644 --- a/mmhuman3d/models/registrants/smplifyx.py +++ b/mmhuman3d/models/registrants/smplifyx.py @@ -173,7 +173,10 @@ def _optimize_stage(self, joint_prior_weight: float = None, smooth_loss_weight: float = None, pose_prior_weight: float = None, + pose_reg_weight: float = None, + limb_length_weight: float = None, joint_weights: dict = {}, + ftol: float = 1e-4, num_iter: int = 1) -> None: """Optimize a stage of body model parameters according to configuration. @@ -208,8 +211,11 @@ def _optimize_stage(self, joint_prior_weight: weight of joint prior loss smooth_loss_weight: weight of smooth loss pose_prior_weight: weight of pose prior loss + pose_reg_weight: weight of pose regularization loss + limb_length_weight: weight of limb length loss joint_weights: per joint weight of shape (K, ) num_iter: number of iterations + ftol: early stop tolerance for relative change in loss Returns: None @@ -229,6 +235,7 @@ def _optimize_stage(self, optimizer = build_optimizer(parameters, self.optimizer) + pre_loss = None for iter_idx in range(num_iter): def closure(): @@ -259,41 +266,52 @@ def closure(): shape_prior_weight=shape_prior_weight, smooth_loss_weight=smooth_loss_weight, pose_prior_weight=pose_prior_weight, + pose_reg_weight=pose_reg_weight, + limb_length_weight=limb_length_weight, joint_weights=joint_weights) loss = loss_dict['total_loss'] loss.backward() return loss - optimizer.step(closure) + loss = optimizer.step(closure) + if iter_idx > 0 and pre_loss is not None and ftol > 0: + loss_rel_change = self._compute_relative_change( + pre_loss, loss.item()) + if loss_rel_change < ftol: + print(f'[ftol={ftol}] Early stop at {iter_idx} iter!') + break + pre_loss = loss.item() def evaluate( self, - betas=None, - body_pose=None, - global_orient=None, - transl=None, - left_hand_pose=None, - right_hand_pose=None, - expression=None, - jaw_pose=None, - leye_pose=None, - reye_pose=None, - keypoints2d=None, - keypoints2d_conf=None, - keypoints2d_weight=None, - keypoints3d=None, - keypoints3d_conf=None, - keypoints3d_weight=None, - shape_prior_weight=None, - joint_prior_weight=None, - smooth_loss_weight=None, - pose_prior_weight=None, - joint_weights={}, - return_verts=False, - return_full_pose=False, - return_joints=False, - reduction_override=None, + betas: torch.Tensor = None, + body_pose: torch.Tensor = None, + global_orient: torch.Tensor = None, + transl: torch.Tensor = None, + left_hand_pose: torch.Tensor = None, + right_hand_pose: torch.Tensor = None, + expression: torch.Tensor = None, + jaw_pose: torch.Tensor = None, + leye_pose: torch.Tensor = None, + reye_pose: torch.Tensor = None, + keypoints2d: torch.Tensor = None, + keypoints2d_conf: torch.Tensor = None, + keypoints2d_weight: float = None, + keypoints3d: torch.Tensor = None, + keypoints3d_conf: torch.Tensor = None, + keypoints3d_weight: float = None, + shape_prior_weight: float = None, + joint_prior_weight: float = None, + smooth_loss_weight: float = None, + pose_prior_weight: float = None, + pose_reg_weight: float = None, + limb_length_weight: float = None, + joint_weights: dict = {}, + return_verts: bool = False, + return_full_pose: bool = False, + return_joints: bool = False, + reduction_override: str = None, ): """Evaluate fitted parameters through loss computation. This function serves two purposes: 1) internally, for loss backpropagation 2) @@ -327,6 +345,8 @@ def evaluate( joint_prior_weight: weight of joint prior loss smooth_loss_weight: weight of smooth loss pose_prior_weight: weight of pose prior loss + pose_reg_weight: weight of pose regularization loss + limb_length_weight: weight of limb length loss joint_weights: per joint weight of shape (K, ) return_verts: whether to return vertices return_joints: whether to return joints @@ -370,6 +390,8 @@ def evaluate( shape_prior_weight=shape_prior_weight, smooth_loss_weight=smooth_loss_weight, pose_prior_weight=pose_prior_weight, + pose_reg_weight=pose_reg_weight, + limb_length_weight=limb_length_weight, joint_weights=joint_weights, reduction_override=reduction_override, body_pose=body_pose, @@ -449,7 +471,6 @@ def _get_weight(self, Returns: weight: per keypoint weight tensor of shape (K) """ - num_keypoint = self.body_model.num_joints if use_shoulder_hip_only: