From 0ecf06e3580f141f6ab44645768a0d6d8ba48383 Mon Sep 17 00:00:00 2001 From: zengwang430521 Date: Thu, 19 Aug 2021 18:22:49 +0800 Subject: [PATCH] Unify SMPL-like models to mesh models (#830) * Add unified SMPL-like model interface and builder --- .../hmr/mixed/res50_mixed_224x224.py | 1 + mmpose/models/__init__.py | 12 +- mmpose/models/builder.py | 6 + mmpose/models/detectors/mesh.py | 39 ++-- mmpose/models/heads/hmr_head.py | 3 +- mmpose/models/utils/__init__.py | 3 + mmpose/models/utils/smpl.py | 183 ++++++++++++++++++ tests/test_external_model/test_smpl.py | 70 +++++++ tests/test_model/test_mesh_forward.py | 9 +- 9 files changed, 289 insertions(+), 37 deletions(-) create mode 100644 mmpose/models/utils/smpl.py create mode 100644 tests/test_external_model/test_smpl.py diff --git a/configs/body/3d_mesh_sview_rgb_img/hmr/mixed/res50_mixed_224x224.py b/configs/body/3d_mesh_sview_rgb_img/hmr/mixed/res50_mixed_224x224.py index f3a64f26ba..56f3e95aa5 100644 --- a/configs/body/3d_mesh_sview_rgb_img/hmr/mixed/res50_mixed_224x224.py +++ b/configs/body/3d_mesh_sview_rgb_img/hmr/mixed/res50_mixed_224x224.py @@ -37,6 +37,7 @@ ), disc=dict(), smpl=dict( + type='SMPL', smpl_path='models/smpl', joints_regressor='models/smpl/joints_regressor_cmr.npy'), train_cfg=dict(disc_step=1), diff --git a/mmpose/models/__init__.py b/mmpose/models/__init__.py index 00eed6ae08..63e7f89b32 100644 --- a/mmpose/models/__init__.py +++ b/mmpose/models/__init__.py @@ -1,13 +1,15 @@ from .backbones import * # noqa -from .builder import (BACKBONES, HEADS, LOSSES, NECKS, POSENETS, - build_backbone, build_head, build_loss, build_neck, - build_posenet) +from .builder import (BACKBONES, HEADS, LOSSES, MESH_MODELS, NECKS, POSENETS, + build_backbone, build_head, build_loss, build_mesh_model, + build_neck, build_posenet) from .detectors import * # noqa from .heads import * # noqa from .losses import * # noqa from .necks import * # noqa +from .utils import * # noqa __all__ = [ - 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'POSENETS', 'build_backbone', - 'build_head', 'build_loss', 'build_posenet', 'build_neck' + 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'POSENETS', 'MESH_MODELS', + 'build_backbone', 'build_head', 'build_loss', 'build_posenet', + 'build_neck', 'build_mesh_model' ] diff --git a/mmpose/models/builder.py b/mmpose/models/builder.py index 3dfec4983f..78db853301 100644 --- a/mmpose/models/builder.py +++ b/mmpose/models/builder.py @@ -8,6 +8,7 @@ HEADS = MODELS LOSSES = MODELS POSENETS = MODELS +MESH_MODELS = MODELS def build_backbone(cfg): @@ -33,3 +34,8 @@ def build_loss(cfg): def build_posenet(cfg): """Build posenet.""" return POSENETS.build(cfg) + + +def build_mesh_model(cfg): + """Build mesh model.""" + return MESH_MODELS.build(cfg) diff --git a/mmpose/models/detectors/mesh.py b/mmpose/models/detectors/mesh.py index ef9dd517be..16a7892447 100644 --- a/mmpose/models/detectors/mesh.py +++ b/mmpose/models/detectors/mesh.py @@ -9,12 +9,6 @@ from ..builder import POSENETS from .base import BasePose -try: - from smplx import SMPL - has_smpl = True -except (ImportError, ModuleNotFoundError): - has_smpl = False - def set_requires_grad(nets, requires_grad=False): """Set requies_grad for all the networks. @@ -61,22 +55,11 @@ def __init__(self, pretrained=None): super().__init__() - assert has_smpl, 'Please install smplx to use SMPL.' - self.backbone = builder.build_backbone(backbone) self.mesh_head = builder.build_head(mesh_head) self.generator = torch.nn.Sequential(self.backbone, self.mesh_head) - self.smpl = SMPL( - model_path=smpl['smpl_path'], - create_betas=False, - create_global_orient=False, - create_body_pose=False, - create_transl=False) - - joints_regressor = torch.tensor( - np.load(smpl['joints_regressor']), dtype=torch.float).unsqueeze(0) - self.register_buffer('joints_regressor', joints_regressor) + self.smpl = builder.build_mesh_model(smpl) self.with_gan = disc is not None and loss_gan is not None if self.with_gan: @@ -161,16 +144,17 @@ def train_step(self, data_batch, optimizer, **kwargs): pred_out = self.smpl( betas=pred_beta, body_pose=pred_pose[:, 1:], - global_orient=pred_pose[:, :1], - pose2rot=False) - pred_vertices = pred_out.vertices - pred_joints_3d = self.get_3d_joints_from_mesh(pred_vertices) + global_orient=pred_pose[:, :1]) + pred_vertices, pred_joints_3d = pred_out['vertices'], pred_out[ + 'joints'] + gt_beta = data_batch['beta'] gt_pose = data_batch['pose'] gt_vertices = self.smpl( betas=gt_beta, body_pose=gt_pose[:, 3:], - global_orient=gt_pose[:, :3]).vertices + global_orient=gt_pose[:, :3])['vertices'] + pred = dict( pose=pred_pose, beta=pred_beta, @@ -257,10 +241,9 @@ def forward_test(self, pred_out = self.smpl( betas=pred_beta, body_pose=pred_pose[:, 1:], - global_orient=pred_pose[:, :1], - pose2rot=False) - pred_vertices = pred_out.vertices - pred_joints_3d = self.get_3d_joints_from_mesh(pred_vertices) + global_orient=pred_pose[:, :1]) + pred_vertices, pred_joints_3d = pred_out['vertices'], pred_out[ + 'joints'] all_preds = {} all_preds['keypoints_3d'] = pred_joints_3d.detach().cpu().numpy() @@ -271,7 +254,7 @@ def forward_test(self, if return_vertices: all_preds['vertices'] = pred_vertices.detach().cpu().numpy() if return_faces: - all_preds['faces'] = self.smpl.faces + all_preds['faces'] = self.smpl.get_faces() all_boxes = [] image_path = [] diff --git a/mmpose/models/heads/hmr_head.py b/mmpose/models/heads/hmr_head.py index fdf3dbf1af..9395234075 100644 --- a/mmpose/models/heads/hmr_head.py +++ b/mmpose/models/heads/hmr_head.py @@ -14,8 +14,7 @@ class HMRMeshHead(nn.Module): Args: in_channels (int): Number of input channels - in_res (int): The resolution of input feature map. - smpl_mean_parameters (str): The file name of the mean SMPL parameters + smpl_mean_params (str): The file name of the mean SMPL parameters n_iter (int): The iterations of estimating delta parameters """ diff --git a/mmpose/models/utils/__init__.py b/mmpose/models/utils/__init__.py index e69de29bb2..70e174c50c 100644 --- a/mmpose/models/utils/__init__.py +++ b/mmpose/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .smpl import SMPL + +__all__ = ['SMPL'] diff --git a/mmpose/models/utils/smpl.py b/mmpose/models/utils/smpl.py new file mode 100644 index 0000000000..e03ed96c7f --- /dev/null +++ b/mmpose/models/utils/smpl.py @@ -0,0 +1,183 @@ +import numpy as np +import torch +import torch.nn as nn + +from ..builder import MESH_MODELS + +try: + from smplx import SMPL as SMPL_ + has_smpl = True +except (ImportError, ModuleNotFoundError): + has_smpl = False + + +@MESH_MODELS.register_module() +class SMPL(nn.Module): + """SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned + multi-person linear model''. This module is based on the smplx project + (https://github.com/vchoutas/smplx). + + Args: + smpl_path (str): The path to the folder where the model weights are + stored. + joints_regressor (str): The path to the file where the joints + regressor weight are stored. + """ + + def __init__(self, smpl_path, joints_regressor): + super().__init__() + + assert has_smpl, 'Please install smplx to use SMPL.' + + self.smpl_neutral = SMPL_( + model_path=smpl_path, + create_global_orient=False, + create_body_pose=False, + create_transl=False, + gender='neutral') + + self.smpl_male = SMPL_( + model_path=smpl_path, + create_betas=False, + create_global_orient=False, + create_body_pose=False, + create_transl=False, + gender='male') + + self.smpl_female = SMPL_( + model_path=smpl_path, + create_betas=False, + create_global_orient=False, + create_body_pose=False, + create_transl=False, + gender='female') + + joints_regressor = torch.tensor( + np.load(joints_regressor), dtype=torch.float)[None, ...] + self.register_buffer('joints_regressor', joints_regressor) + + self.num_verts = self.smpl_neutral.get_num_verts() + self.num_joints = self.joints_regressor.shape[1] + + def smpl_forward(self, model, **kwargs): + """Apply a specific SMPL model with given model parameters. + + Note: + B: batch size + V: number of vertices + K: number of joints + + Returns: + outputs (dict): Dict with mesh vertices and joints. + - vertices: Tensor([B, V, 3]), mesh vertices + - joints: Tensor([B, K, 3]), 3d joints regressed + from mesh vertices. + """ + + betas = kwargs['betas'] + batch_size = betas.shape[0] + device = betas.device + output = {} + if batch_size == 0: + output['vertices'] = betas.new_zeros([0, self.num_verts, 3]) + output['joints'] = betas.new_zeros([0, self.num_joints, 3]) + else: + smpl_out = model(**kwargs) + output['vertices'] = smpl_out.vertices + output['joints'] = torch.matmul( + self.joints_regressor.to(device), output['vertices']) + return output + + def get_faces(self): + """Return mesh faces. + + Note: + F: number of faces + + Returns: + faces: np.ndarray([F, 3]), mesh faces + """ + return self.smpl_neutral.faces + + def forward(self, + betas, + body_pose, + global_orient, + transl=None, + gender=None): + """Forward function. + + Note: + B: batch size + J: number of controllable joints of model, for smpl model J=23 + K: number of joints + + Args: + betas: Tensor([B, 10]), human body shape parameters of SMPL model. + body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose + parameters of SMPL model. It should be axis-angle vector + ([B, J*3]) or rotation matrix ([B, J, 3, 3)]. + global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation + of human body. It should be axis-angle vector ([B, 3]) or + rotation matrix ([B, 1, 3, 3)]. + transl: Tensor([B, 3]), global translation of human body. + gender: Tensor([B]), gender parameters of human body. -1 for + neutral, 0 for male , 1 for female. + + Returns: + outputs (dict): Dict with mesh vertices and joints. + - vertices: Tensor([B, V, 3]), mesh vertices + - joints: Tensor([B, K, 3]), 3d joints regressed from + mesh vertices. + """ + + batch_size = betas.shape[0] + pose2rot = True if body_pose.dim() == 2 else False + if batch_size > 0 and gender is not None: + output = { + 'vertices': betas.new_zeros([batch_size, self.num_verts, 3]), + 'joints': betas.new_zeros([batch_size, self.num_joints, 3]) + } + + mask = gender < 0 + _out = self.smpl_forward( + self.smpl_neutral, + betas=betas[mask], + body_pose=body_pose[mask], + global_orient=global_orient[mask], + transl=transl[mask] if transl is not None else None, + pose2rot=pose2rot) + output['vertices'][mask] = _out['vertices'] + output['joints'][mask] = _out['joints'] + + mask = gender == 0 + _out = self.smpl_forward( + self.smpl_male, + betas=betas[mask], + body_pose=body_pose[mask], + global_orient=global_orient[mask], + transl=transl[mask] if transl is not None else None, + pose2rot=pose2rot) + output['vertices'][mask] = _out['vertices'] + output['joints'][mask] = _out['joints'] + + mask = gender == 1 + _out = self.smpl_forward( + self.smpl_male, + betas=betas[mask], + body_pose=body_pose[mask], + global_orient=global_orient[mask], + transl=transl[mask] if transl is not None else None, + pose2rot=pose2rot) + output['vertices'][mask] = _out['vertices'] + output['joints'][mask] = _out['joints'] + else: + return self.smpl_forward( + self.smpl_neutral, + betas=betas, + body_pose=body_pose, + global_orient=global_orient, + transl=transl, + pose2rot=pose2rot) + + return output diff --git a/tests/test_external_model/test_smpl.py b/tests/test_external_model/test_smpl.py new file mode 100644 index 0000000000..2c8f385b77 --- /dev/null +++ b/tests/test_external_model/test_smpl.py @@ -0,0 +1,70 @@ +import numpy as np +import torch +from tests.test_model.test_mesh_forward import generate_smpl_weight_file + +from mmpose.models.utils import SMPL + + +def test_smpl(): + """Test smpl model.""" + + # generate weight file for SMPL model. + generate_smpl_weight_file('tests/data/smpl') + + # build smpl model + smpl_cfg = dict( + smpl_path='tests/data/smpl', + joints_regressor='tests/data/smpl/test_joint_regressor.npy') + smpl = SMPL(**smpl_cfg) + + # test get face function + faces = smpl.get_faces() + assert isinstance(faces, np.ndarray) + + betas = torch.zeros(3, 10) + body_pose = torch.zeros(3, 23 * 3) + global_orient = torch.zeros(3, 3) + transl = torch.zeros(3, 3) + gender = torch.LongTensor([-1, 0, 1]) + + # test forward with body_pose and global_orient in axis-angle format + smpl_out = smpl( + betas=betas, body_pose=body_pose, global_orient=global_orient) + assert isinstance(smpl_out, dict) + assert smpl_out['vertices'].shape == torch.Size([3, 6890, 3]) + assert smpl_out['joints'].shape == torch.Size([3, 24, 3]) + + # test forward with body_pose and global_orient in rotation matrix format + body_pose = torch.eye(3).repeat([3, 23, 1, 1]) + global_orient = torch.eye(3).repeat([3, 1, 1, 1]) + _ = smpl(betas=betas, body_pose=body_pose, global_orient=global_orient) + + # test forward with translation + _ = smpl( + betas=betas, + body_pose=body_pose, + global_orient=global_orient, + transl=transl) + + # test forward with gender + _ = smpl( + betas=betas, + body_pose=body_pose, + global_orient=global_orient, + transl=transl, + gender=gender) + + # test forward when all samples in the same gender + gender = torch.LongTensor([0, 0, 0]) + _ = smpl( + betas=betas, + body_pose=body_pose, + global_orient=global_orient, + transl=transl, + gender=gender) + + # test forward when batch size = 0 + _ = smpl( + betas=torch.zeros(0, 10), + body_pose=torch.zeros(0, 23 * 3), + global_orient=torch.zeros(0, 3)) diff --git a/tests/test_model/test_mesh_forward.py b/tests/test_model/test_mesh_forward.py index 00e0ec3f82..4563bad9dd 100644 --- a/tests/test_model/test_mesh_forward.py +++ b/tests/test_model/test_mesh_forward.py @@ -19,7 +19,6 @@ def generate_smpl_weight_file(output_dir): joint_regressor_file = os.path.join(output_dir, 'test_joint_regressor.npy') np.save(joint_regressor_file, np.zeros([24, 6890])) - test_model_file = os.path.join(output_dir, 'SMPL_NEUTRAL.pkl') test_data = {} test_data['f'] = np.zeros([1, 3], dtype=np.int32) test_data['J_regressor'] = csc_matrix(np.zeros([24, 6890])) @@ -29,7 +28,12 @@ def generate_smpl_weight_file(output_dir): test_data['posedirs'] = np.zeros([6890, 3, 207]) test_data['v_template'] = np.zeros([6890, 3]) test_data['shapedirs'] = np.zeros([6890, 3, 10]) - with open(test_model_file, 'wb') as out_file: + + with open(os.path.join(output_dir, 'SMPL_NEUTRAL.pkl'), 'wb') as out_file: + pickle.dump(test_data, out_file) + with open(os.path.join(output_dir, 'SMPL_MALE.pkl'), 'wb') as out_file: + pickle.dump(test_data, out_file) + with open(os.path.join(output_dir, 'SMPL_FEMALE.pkl'), 'wb') as out_file: pickle.dump(test_data, out_file) return @@ -51,6 +55,7 @@ def test_parametric_mesh_forward(): ), disc=None, smpl=dict( + type='SMPL', smpl_path='tests/data/smpl', joints_regressor='tests/data/smpl/test_joint_regressor.npy'), train_cfg=dict(disc_step=1),