Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix SMCReader #247

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions configs/hmr/resnet50_hmr_pw3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,51 @@
ann_file='cmu_mosh.npz')),
test=dict(
type=dataset_type,

body_model=dict(
type='GenderedSMPL',
type='SMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
joints_regressor='data/body_models/J_regressor_h36m.npy'),
dataset_name='pw3d',
dataset_name='humman',
convention='coco_wholebody',
data_prefix='data',
pipeline=test_pipeline,
ann_file='pw3d_test.npz'),
ann_file='humman_test_kinect_ds10_smpl.npz'),

# body_model=dict(
# type='SMPL',
# keypoint_src='h36m',
# keypoint_dst='h36m',
# model_path='data/body_models/smpl',
# joints_regressor='data/body_models/J_regressor_h36m.npy'),
# dataset_name='humman',
# convention='coco_wholebody',
# data_prefix='data',
# pipeline=test_pipeline,
# ann_file='humman_test_iphone_ds10_smpl.npz'),

# body_model=dict(
# type='GenderedSMPL',
# keypoint_src='h36m',
# keypoint_dst='h36m',
# model_path='data/body_models/smpl',
# joints_regressor='data/body_models/J_regressor_h36m.npy'),
# dataset_name='pw3d',
# data_prefix='data',
# pipeline=test_pipeline,
# ann_file='pw3d_test.npz'),

# body_model=dict(
# type='SMPL',
# keypoint_src='h36m',
# keypoint_dst='h36m',
# model_path='data/body_models/smpl',
# joints_regressor='data/body_models/J_regressor_h36m.npy'),
# dataset_name='h36m',
# convention='h36m', # convert keypoints to h36m
# data_prefix='data',
# pipeline=test_pipeline,
# ann_file='h36m_valid_protocol2.npz'),
)
49 changes: 37 additions & 12 deletions mmhuman3d/data/data_converters/humman.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
self.device = torch.device(
'cuda') if torch.cuda.is_available() else torch.device('cpu')

# Body model used for keypoint computation
self.smpl = build_body_model(
dict(
type='SMPL',
Expand All @@ -61,6 +62,18 @@ def __init__(self, *args, **kwargs):
extra_joints_regressor='data/body_models/J_regressor_extra.npy'
)).to(self.device)

# Body model used for pelvis computation in SMCReader
self.smpl_smc = build_body_model(
dict(
type='SMPL',
gender='neutral',
num_betas=10,
keypoint_src='smpl_45',
keypoint_dst='smpl_45',
model_path='data/body_models/smpl',
batch_size=1,
)).to(self.device)

def _derive_keypoints(self, global_orient, body_pose, betas, transl,
focal_length, image_size, camera_center):
"""Get SMPL-derived keypoints."""
Expand Down Expand Up @@ -241,6 +254,12 @@ def convert_by_mode(self, dataset_path: str, out_path: str,

ann_paths = sorted(glob.glob(os.path.join(dataset_path, '*.smc')))

# temp action
if mode != 'test': return
view = 10
# body_part = 'lower_limb'
# with open(os.path.join(dataset_path, f'{body_part}.txt'), 'r') as f:
# split = set(f.read().splitlines())
with open(os.path.join(dataset_path, f'{mode}.txt'), 'r') as f:
split = set(f.read().splitlines())

Expand All @@ -250,7 +269,7 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
continue

try:
smc_reader = SMCReader(ann_path)
smc_reader = SMCReader(ann_path, body_model=self.smpl_smc)
except OSError:
print(f'Unable to load {ann_path}.')
continue
Expand All @@ -268,6 +287,9 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
[('iPhone', i) for i in range(num_iphone)]
assert len(device_list) == num_kinect + num_iphone

# temp
device_list = [('Kinect', view)]

for device, device_id in device_list:
assert device in {
'Kinect', 'iPhone'
Expand Down Expand Up @@ -399,17 +421,20 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
kinect_keypoints3d_smpl_, kinect_keypoints2d_humman_,
kinect_keypoints3d_humman_)

file_name = f'humman_{mode}_kinect_ds{self.downsample_ratio}_smpl.npz'
# temp
# file_name = f'humman_{mode}_kinect_ds{self.downsample_ratio}_smpl.npz'
file_name = f'humman_{mode}_kinect_ds10_view{view}.npz'
out_file = os.path.join(out_path, file_name)
kinect_human_data.dump(out_file)

# make iphone human data
iphone_human_data = self._make_human_data(
iphone_smpl, iphone_image_path_, iphone_image_id_,
iphone_bbox_xywh_, iphone_keypoints2d_smpl_,
iphone_keypoints3d_smpl_, iphone_keypoints2d_humman_,
iphone_keypoints3d_humman_)

file_name = f'humman_{mode}_iphone_ds{self.downsample_ratio}_smpl.npz'
out_file = os.path.join(out_path, file_name)
iphone_human_data.dump(out_file)
# temp
# # make iphone human data
# iphone_human_data = self._make_human_data(
# iphone_smpl, iphone_image_path_, iphone_image_id_,
# iphone_bbox_xywh_, iphone_keypoints2d_smpl_,
# iphone_keypoints3d_smpl_, iphone_keypoints2d_humman_,
# iphone_keypoints3d_humman_)
#
# file_name = f'humman_{mode}_iphone_ds{self.downsample_ratio}_smpl.npz'
# out_file = os.path.join(out_path, file_name)
# iphone_human_data.dump(out_file)
110 changes: 76 additions & 34 deletions mmhuman3d/data/data_structures/smc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
import cv2
import h5py
import numpy as np
import torch
import tqdm

from mmhuman3d.utils.transforms import aa_to_rotmat, rotmat_to_aa
from mmhuman3d.models.body_models.utils import batch_transform_to_camera_frame
from mmhuman3d.models.builder import build_body_model


class SMCReader:

def __init__(self, file_path):
def __init__(self, file_path, body_model=None):
"""Read SenseMocapFile endswith ".smc".

Args:
file_path (str):
Path to an SMC file.
body_model (nn.Module or dict):
Only needed for SMPL transformation to device frame
if nn.Module: a body_model instance
if dict: a body_model config
"""
self.smc = h5py.File(file_path, 'r')
self.__calibration_dict__ = None
Expand Down Expand Up @@ -47,6 +53,26 @@ def __init__(self, file_path):
self.smpl_num_frames = self.smc['SMPL'].attrs['num_frame']
self.smpl_created_time = self.smc['SMPL'].attrs['created_time']

# initialize body model
if isinstance(body_model, torch.nn.Module):
self.body_model = body_model
elif isinstance(body_model, dict):
self.body_model = build_body_model(body_model)
else:
# in most cases, SMCReader is instantiated for image reading
# only. Hence, it is wasteful to initialize a body model until
# really needed in get_smpl()
self.body_model = None
self.default_body_model_config = dict(
type='SMPL',
gender='neutral',
num_betas=10,
keypoint_src='smpl_45',
keypoint_dst='smpl_45',
model_path='data/body_models/smpl',
batch_size=1,
)

def get_kinect_color_extrinsics(self, kinect_id, homogeneous=True):
"""Get extrinsics(cam2world) of a kinect RGB camera by kinect id.

Expand Down Expand Up @@ -837,12 +863,6 @@ def get_keypoints3d(self,
if device_id is not None:
assert device_id >= 0

kps3d_dict = self.smc['Keypoints3D']

# keypoints3d are in world coordinate system
keypoints3d_world = kps3d_dict['keypoints3d'][...]
keypoints3d_mask = kps3d_dict['keypoints3d_mask'][...]

if frame_id is None:
frame_list = range(self.get_keypoints_num_frames())
elif isinstance(frame_id, list):
Expand All @@ -854,7 +874,12 @@ def get_keypoints3d(self,
else:
raise TypeError('frame_id should be int, list or None.')

kps3d_dict = self.smc['Keypoints3D']

# keypoints3d are in world coordinate system
keypoints3d_world = kps3d_dict['keypoints3d'][...]
keypoints3d_world = keypoints3d_world[frame_list, ...]
keypoints3d_mask = kps3d_dict['keypoints3d_mask'][...]

# return keypoints3d in world coordinate system
if device is None:
Expand Down Expand Up @@ -923,12 +948,21 @@ def get_smpl(self,
body_pose = smpl_dict['body_pose'][...]
transl = smpl_dict['transl'][...]
betas = smpl_dict['betas'][...]
if frame_id is not None:
if isinstance(frame_id, int):
frame_id = [frame_id]
body_pose = body_pose[frame_id, ...]
global_orient = global_orient[frame_id, ...]
transl = transl[frame_id, ...]

if frame_id is None:
frame_list = range(self.get_smpl_num_frames())
elif isinstance(frame_id, list):
frame_list = frame_id
elif isinstance(frame_id, int):
assert frame_id < self.get_keypoints_num_frames(),\
'Index out of range...'
frame_list = [frame_id]
else:
raise TypeError('frame_id should be int, list or None.')

body_pose = body_pose[frame_list, ...]
global_orient = global_orient[frame_list, ...]
transl = transl[frame_list, ...]

# return SMPL parameters in world coordinate system
if device is None:
Expand All @@ -942,36 +976,44 @@ def get_smpl(self,

# return SMPL parameters in device coordinate system
else:

if self.body_model is None:
self.body_model = \
build_body_model(self.default_body_model_config)
torch_device = self.body_model.global_orient.device

assert device in {
'Kinect', 'iPhone'
}, f'Undefined device: {device}, should be "Kinect" or "iPhone"'
assert device_id >= 0

if device == 'Kinect':
cam2world = self.get_kinect_color_extrinsics(
T_cam2world = self.get_kinect_color_extrinsics(
kinect_id=device_id, homogeneous=True)
else:
cam2world = self.get_iphone_extrinsics(
T_cam2world = self.get_iphone_extrinsics(
iphone_id=device_id, vertical=vertical)

num_frames = global_orient.shape[0]

T_smpl2world = np.repeat(
np.eye(4).reshape(1, 4, 4), num_frames, axis=0)
assert T_smpl2world.shape == (num_frames, 4, 4)

T_smpl2world[:, :3, :3] = aa_to_rotmat(global_orient)
T_smpl2world[:, :3, 3] = transl
T_world2cam = np.linalg.inv(T_cam2world)

T_world2cam = np.linalg.inv(cam2world)
T_world2cam = np.repeat(
T_world2cam.reshape(1, 4, 4), num_frames, axis=0)
assert T_world2cam.shape == (num_frames, 4, 4)
output = self.body_model(
global_orient=torch.tensor(global_orient, device=torch_device),
body_pose=torch.tensor(body_pose, device=torch_device),
transl=torch.tensor(transl, device=torch_device),
betas=torch.tensor(betas, device=torch_device))
joints = output['joints'].detach().cpu().numpy()
pelvis = joints[:, 0, :]

T_smpl2cam = T_world2cam @ T_smpl2world

global_orient = rotmat_to_aa(T_smpl2cam[:, :3, :3])
transl = T_smpl2world[:, :3, 3]
new_global_orient, new_transl = batch_transform_to_camera_frame(
global_orient=global_orient,
transl=transl,
pelvis=pelvis,
extrinsic=T_world2cam)

smpl_dict = dict(
global_orient=global_orient,
global_orient=new_global_orient,
body_pose=body_pose,
transl=transl,
transl=new_transl,
betas=betas)

return smpl_dict
4 changes: 2 additions & 2 deletions mmhuman3d/data/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import mmcv
import numpy as np

from mmhuman3d.data.data_structures import SMCReader
import mmhuman3d.data.data_structures as data_structures
from ..builder import PIPELINES


Expand Down Expand Up @@ -50,7 +50,7 @@ def __call__(self, results):
assert 'image_id' in results, 'Load image from .smc, ' \
'but image_id is not provided.'
device, device_id, frame_id = results['image_id']
smc_reader = SMCReader(filename)
smc_reader = data_structures.SMCReader(filename)
img = smc_reader.get_color(
device, device_id, frame_id, disable_tqdm=True)
img = img.squeeze() # (1, H, W, 3) -> (H, W, 3)
Expand Down
6 changes: 5 additions & 1 deletion mmhuman3d/models/body_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@

from .smpl import SMPL, GenderedSMPL, HybrIKSMPL
from .smplx import SMPLX
from .utils import batch_transform_to_camera_frame, transform_to_camera_frame

__all__ = ['SMPL', 'GenderedSMPL', 'HybrIKSMPL', 'SMPLX']
__all__ = [
'SMPL', 'GenderedSMPL', 'HybrIKSMPL', 'SMPLX', 'transform_to_camera_frame',
'batch_transform_to_camera_frame'
]
Loading