Skip to content

Commit

Permalink
Unify SMPL-like models to mesh models (#830)
Browse files Browse the repository at this point in the history
* Add unified SMPL-like model interface and builder
  • Loading branch information
zengwang430521 authored Aug 19, 2021
1 parent ce68b4c commit 0ecf06e
Show file tree
Hide file tree
Showing 9 changed files with 289 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
),
disc=dict(),
smpl=dict(
type='SMPL',
smpl_path='models/smpl',
joints_regressor='models/smpl/joints_regressor_cmr.npy'),
train_cfg=dict(disc_step=1),
Expand Down
12 changes: 7 additions & 5 deletions mmpose/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .backbones import * # noqa
from .builder import (BACKBONES, HEADS, LOSSES, NECKS, POSENETS,
build_backbone, build_head, build_loss, build_neck,
build_posenet)
from .builder import (BACKBONES, HEADS, LOSSES, MESH_MODELS, NECKS, POSENETS,
build_backbone, build_head, build_loss, build_mesh_model,
build_neck, build_posenet)
from .detectors import * # noqa
from .heads import * # noqa
from .losses import * # noqa
from .necks import * # noqa
from .utils import * # noqa

__all__ = [
'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'POSENETS', 'build_backbone',
'build_head', 'build_loss', 'build_posenet', 'build_neck'
'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'POSENETS', 'MESH_MODELS',
'build_backbone', 'build_head', 'build_loss', 'build_posenet',
'build_neck', 'build_mesh_model'
]
6 changes: 6 additions & 0 deletions mmpose/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
HEADS = MODELS
LOSSES = MODELS
POSENETS = MODELS
MESH_MODELS = MODELS


def build_backbone(cfg):
Expand All @@ -33,3 +34,8 @@ def build_loss(cfg):
def build_posenet(cfg):
"""Build posenet."""
return POSENETS.build(cfg)


def build_mesh_model(cfg):
"""Build mesh model."""
return MESH_MODELS.build(cfg)
39 changes: 11 additions & 28 deletions mmpose/models/detectors/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@
from ..builder import POSENETS
from .base import BasePose

try:
from smplx import SMPL
has_smpl = True
except (ImportError, ModuleNotFoundError):
has_smpl = False


def set_requires_grad(nets, requires_grad=False):
"""Set requies_grad for all the networks.
Expand Down Expand Up @@ -61,22 +55,11 @@ def __init__(self,
pretrained=None):
super().__init__()

assert has_smpl, 'Please install smplx to use SMPL.'

self.backbone = builder.build_backbone(backbone)
self.mesh_head = builder.build_head(mesh_head)
self.generator = torch.nn.Sequential(self.backbone, self.mesh_head)

self.smpl = SMPL(
model_path=smpl['smpl_path'],
create_betas=False,
create_global_orient=False,
create_body_pose=False,
create_transl=False)

joints_regressor = torch.tensor(
np.load(smpl['joints_regressor']), dtype=torch.float).unsqueeze(0)
self.register_buffer('joints_regressor', joints_regressor)
self.smpl = builder.build_mesh_model(smpl)

self.with_gan = disc is not None and loss_gan is not None
if self.with_gan:
Expand Down Expand Up @@ -161,16 +144,17 @@ def train_step(self, data_batch, optimizer, **kwargs):
pred_out = self.smpl(
betas=pred_beta,
body_pose=pred_pose[:, 1:],
global_orient=pred_pose[:, :1],
pose2rot=False)
pred_vertices = pred_out.vertices
pred_joints_3d = self.get_3d_joints_from_mesh(pred_vertices)
global_orient=pred_pose[:, :1])
pred_vertices, pred_joints_3d = pred_out['vertices'], pred_out[
'joints']

gt_beta = data_batch['beta']
gt_pose = data_batch['pose']
gt_vertices = self.smpl(
betas=gt_beta,
body_pose=gt_pose[:, 3:],
global_orient=gt_pose[:, :3]).vertices
global_orient=gt_pose[:, :3])['vertices']

pred = dict(
pose=pred_pose,
beta=pred_beta,
Expand Down Expand Up @@ -257,10 +241,9 @@ def forward_test(self,
pred_out = self.smpl(
betas=pred_beta,
body_pose=pred_pose[:, 1:],
global_orient=pred_pose[:, :1],
pose2rot=False)
pred_vertices = pred_out.vertices
pred_joints_3d = self.get_3d_joints_from_mesh(pred_vertices)
global_orient=pred_pose[:, :1])
pred_vertices, pred_joints_3d = pred_out['vertices'], pred_out[
'joints']

all_preds = {}
all_preds['keypoints_3d'] = pred_joints_3d.detach().cpu().numpy()
Expand All @@ -271,7 +254,7 @@ def forward_test(self,
if return_vertices:
all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
if return_faces:
all_preds['faces'] = self.smpl.faces
all_preds['faces'] = self.smpl.get_faces()

all_boxes = []
image_path = []
Expand Down
3 changes: 1 addition & 2 deletions mmpose/models/heads/hmr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ class HMRMeshHead(nn.Module):
Args:
in_channels (int): Number of input channels
in_res (int): The resolution of input feature map.
smpl_mean_parameters (str): The file name of the mean SMPL parameters
smpl_mean_params (str): The file name of the mean SMPL parameters
n_iter (int): The iterations of estimating delta parameters
"""

Expand Down
3 changes: 3 additions & 0 deletions mmpose/models/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .smpl import SMPL

__all__ = ['SMPL']
183 changes: 183 additions & 0 deletions mmpose/models/utils/smpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import numpy as np
import torch
import torch.nn as nn

from ..builder import MESH_MODELS

try:
from smplx import SMPL as SMPL_
has_smpl = True
except (ImportError, ModuleNotFoundError):
has_smpl = False


@MESH_MODELS.register_module()
class SMPL(nn.Module):
"""SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned
multi-person linear model''. This module is based on the smplx project
(https://github.com/vchoutas/smplx).
Args:
smpl_path (str): The path to the folder where the model weights are
stored.
joints_regressor (str): The path to the file where the joints
regressor weight are stored.
"""

def __init__(self, smpl_path, joints_regressor):
super().__init__()

assert has_smpl, 'Please install smplx to use SMPL.'

self.smpl_neutral = SMPL_(
model_path=smpl_path,
create_global_orient=False,
create_body_pose=False,
create_transl=False,
gender='neutral')

self.smpl_male = SMPL_(
model_path=smpl_path,
create_betas=False,
create_global_orient=False,
create_body_pose=False,
create_transl=False,
gender='male')

self.smpl_female = SMPL_(
model_path=smpl_path,
create_betas=False,
create_global_orient=False,
create_body_pose=False,
create_transl=False,
gender='female')

joints_regressor = torch.tensor(
np.load(joints_regressor), dtype=torch.float)[None, ...]
self.register_buffer('joints_regressor', joints_regressor)

self.num_verts = self.smpl_neutral.get_num_verts()
self.num_joints = self.joints_regressor.shape[1]

def smpl_forward(self, model, **kwargs):
"""Apply a specific SMPL model with given model parameters.
Note:
B: batch size
V: number of vertices
K: number of joints
Returns:
outputs (dict): Dict with mesh vertices and joints.
- vertices: Tensor([B, V, 3]), mesh vertices
- joints: Tensor([B, K, 3]), 3d joints regressed
from mesh vertices.
"""

betas = kwargs['betas']
batch_size = betas.shape[0]
device = betas.device
output = {}
if batch_size == 0:
output['vertices'] = betas.new_zeros([0, self.num_verts, 3])
output['joints'] = betas.new_zeros([0, self.num_joints, 3])
else:
smpl_out = model(**kwargs)
output['vertices'] = smpl_out.vertices
output['joints'] = torch.matmul(
self.joints_regressor.to(device), output['vertices'])
return output

def get_faces(self):
"""Return mesh faces.
Note:
F: number of faces
Returns:
faces: np.ndarray([F, 3]), mesh faces
"""
return self.smpl_neutral.faces

def forward(self,
betas,
body_pose,
global_orient,
transl=None,
gender=None):
"""Forward function.
Note:
B: batch size
J: number of controllable joints of model, for smpl model J=23
K: number of joints
Args:
betas: Tensor([B, 10]), human body shape parameters of SMPL model.
body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
parameters of SMPL model. It should be axis-angle vector
([B, J*3]) or rotation matrix ([B, J, 3, 3)].
global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
of human body. It should be axis-angle vector ([B, 3]) or
rotation matrix ([B, 1, 3, 3)].
transl: Tensor([B, 3]), global translation of human body.
gender: Tensor([B]), gender parameters of human body. -1 for
neutral, 0 for male , 1 for female.
Returns:
outputs (dict): Dict with mesh vertices and joints.
- vertices: Tensor([B, V, 3]), mesh vertices
- joints: Tensor([B, K, 3]), 3d joints regressed from
mesh vertices.
"""

batch_size = betas.shape[0]
pose2rot = True if body_pose.dim() == 2 else False
if batch_size > 0 and gender is not None:
output = {
'vertices': betas.new_zeros([batch_size, self.num_verts, 3]),
'joints': betas.new_zeros([batch_size, self.num_joints, 3])
}

mask = gender < 0
_out = self.smpl_forward(
self.smpl_neutral,
betas=betas[mask],
body_pose=body_pose[mask],
global_orient=global_orient[mask],
transl=transl[mask] if transl is not None else None,
pose2rot=pose2rot)
output['vertices'][mask] = _out['vertices']
output['joints'][mask] = _out['joints']

mask = gender == 0
_out = self.smpl_forward(
self.smpl_male,
betas=betas[mask],
body_pose=body_pose[mask],
global_orient=global_orient[mask],
transl=transl[mask] if transl is not None else None,
pose2rot=pose2rot)
output['vertices'][mask] = _out['vertices']
output['joints'][mask] = _out['joints']

mask = gender == 1
_out = self.smpl_forward(
self.smpl_male,
betas=betas[mask],
body_pose=body_pose[mask],
global_orient=global_orient[mask],
transl=transl[mask] if transl is not None else None,
pose2rot=pose2rot)
output['vertices'][mask] = _out['vertices']
output['joints'][mask] = _out['joints']
else:
return self.smpl_forward(
self.smpl_neutral,
betas=betas,
body_pose=body_pose,
global_orient=global_orient,
transl=transl,
pose2rot=pose2rot)

return output
70 changes: 70 additions & 0 deletions tests/test_external_model/test_smpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
import torch
from tests.test_model.test_mesh_forward import generate_smpl_weight_file

from mmpose.models.utils import SMPL


def test_smpl():
"""Test smpl model."""

# generate weight file for SMPL model.
generate_smpl_weight_file('tests/data/smpl')

# build smpl model
smpl_cfg = dict(
smpl_path='tests/data/smpl',
joints_regressor='tests/data/smpl/test_joint_regressor.npy')
smpl = SMPL(**smpl_cfg)

# test get face function
faces = smpl.get_faces()
assert isinstance(faces, np.ndarray)

betas = torch.zeros(3, 10)
body_pose = torch.zeros(3, 23 * 3)
global_orient = torch.zeros(3, 3)
transl = torch.zeros(3, 3)
gender = torch.LongTensor([-1, 0, 1])

# test forward with body_pose and global_orient in axis-angle format
smpl_out = smpl(
betas=betas, body_pose=body_pose, global_orient=global_orient)
assert isinstance(smpl_out, dict)
assert smpl_out['vertices'].shape == torch.Size([3, 6890, 3])
assert smpl_out['joints'].shape == torch.Size([3, 24, 3])

# test forward with body_pose and global_orient in rotation matrix format
body_pose = torch.eye(3).repeat([3, 23, 1, 1])
global_orient = torch.eye(3).repeat([3, 1, 1, 1])
_ = smpl(betas=betas, body_pose=body_pose, global_orient=global_orient)

# test forward with translation
_ = smpl(
betas=betas,
body_pose=body_pose,
global_orient=global_orient,
transl=transl)

# test forward with gender
_ = smpl(
betas=betas,
body_pose=body_pose,
global_orient=global_orient,
transl=transl,
gender=gender)

# test forward when all samples in the same gender
gender = torch.LongTensor([0, 0, 0])
_ = smpl(
betas=betas,
body_pose=body_pose,
global_orient=global_orient,
transl=transl,
gender=gender)

# test forward when batch size = 0
_ = smpl(
betas=torch.zeros(0, 10),
body_pose=torch.zeros(0, 23 * 3),
global_orient=torch.zeros(0, 3))
Loading

0 comments on commit 0ecf06e

Please sign in to comment.