Skip to content

Commit

Permalink
(Speed up TopDown Inference) modified inference_top_down_model, make …
Browse files Browse the repository at this point in the history
…model able to run on batches of bounding box (#560)

* modified inference_top_down_model to make model-batch runnable

* formattig code by pre-commit

* Fix bug when bbox_thr make empty bbox

* resolve comments

* resolve comments

Co-authored-by: jinsheng <[email protected]>
  • Loading branch information
namirinz and jin-s13 authored Apr 12, 2021
1 parent e5975bf commit cd74bf1
Showing 1 changed file with 81 additions and 65 deletions.
146 changes: 81 additions & 65 deletions mmpose/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _xyxy2xywh(bbox_xyxy):
bbox_xywh = bbox_xyxy.copy()
bbox_xywh[:, 2] = bbox_xywh[:, 2] - bbox_xywh[:, 0] + 1
bbox_xywh[:, 3] = bbox_xywh[:, 3] - bbox_xywh[:, 1] + 1

return bbox_xywh


Expand All @@ -73,6 +74,7 @@ def _xywh2xyxy(bbox_xywh):
bbox_xyxy = bbox_xywh.copy()
bbox_xyxy[:, 2] = bbox_xyxy[:, 2] + bbox_xyxy[:, 0] - 1
bbox_xyxy[:, 3] = bbox_xyxy[:, 3] + bbox_xyxy[:, 1] - 1

return bbox_xyxy


Expand Down Expand Up @@ -141,7 +143,7 @@ def __call__(self, results):

def _inference_single_pose_model(model,
img_or_path,
bbox,
bboxes,
dataset,
return_heatmap=False):
"""Inference a single bbox.
Expand All @@ -151,8 +153,9 @@ def _inference_single_pose_model(model,
Args:
model (nn.Module): The loaded pose model.
img_or_path (str | np.ndarray): Image filename or loaded image.
bbox (list | np.ndarray): Bounding boxes (with scores),
shaped (4, ) or (5, ). (left, top, width, height, [score])
bboxes (list | np.ndarray): All bounding boxes (with scores),
shaped (N, 4) or (N, 5). (left, top, width, height, [score])
where N is number of bounding boxes.
dataset (str): Dataset name.
outputs (list[str] | tuple[str]): Names of layers whose output is
to be returned, default: None
Expand All @@ -171,8 +174,7 @@ def _inference_single_pose_model(model,
] + cfg.test_pipeline[1:]
test_pipeline = Compose(test_pipeline)

assert len(bbox) in [4, 5]
center, scale = _box2cs(cfg, bbox)
assert len(bboxes[0]) in [4, 5]

flip_pairs = None
if dataset in ('TopDownCocoDataset', 'TopDownOCHumanDataset',
Expand Down Expand Up @@ -258,48 +260,58 @@ def _inference_single_pose_model(model,
else:
raise NotImplementedError()

# prepare data
data = {
'img_or_path':
img_or_path,
'center':
center,
'scale':
scale,
'bbox_score':
bbox[4] if len(bbox) == 5 else 1,
'dataset':
dataset,
'joints_3d':
np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
'joints_3d_visible':
np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
'rotation':
0,
'ann_info': {
'image_size': cfg.data_cfg['image_size'],
'num_joints': cfg.data_cfg['num_joints'],
'flip_pairs': flip_pairs
batch_data = []
for bbox in bboxes:
center, scale = _box2cs(cfg, bbox)

# prepare data
data = {
'img_or_path':
img_or_path,
'center':
center,
'scale':
scale,
'bbox_score':
bbox[4] if len(bbox) == 5 else 1,
'bbox_id':
0, # need to be assigned if batch_size > 1
'dataset':
dataset,
'joints_3d':
np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
'joints_3d_visible':
np.zeros((cfg.data_cfg.num_joints, 3), dtype=np.float32),
'rotation':
0,
'ann_info': {
'image_size': cfg.data_cfg['image_size'],
'num_joints': cfg.data_cfg['num_joints'],
'flip_pairs': flip_pairs
}
}
}
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
data = test_pipeline(data)
batch_data.append(data)

batch_data = collate(batch_data, samples_per_gpu=1)

if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
else:
# just get the actual data from DataContainer
data['img_metas'] = data['img_metas'].data[0]
# scatter not work so just move image to cuda device
batch_data['img'] = batch_data['img'].to(device)
# get all img_metas of each bounding box
batch_data['img_metas'] = [
img_metas[0] for img_metas in batch_data['img_metas'].data
]

# forward the model
with torch.no_grad():
result = model(
img=data['img'],
img_metas=data['img_metas'],
img=batch_data['img'],
img_metas=batch_data['img_metas'],
return_loss=False,
return_heatmap=return_heatmap)

return result['preds'][0], result['output_heatmap']
return result['preds'], result['output_heatmap']


def inference_top_down_pose_model(model,
Expand Down Expand Up @@ -350,38 +362,42 @@ def inference_top_down_pose_model(model,
pose_results = []
returned_outputs = []

with OutputHook(model, outputs=outputs, as_tensor=False) as h:
for person_result in person_results:
if format == 'xyxy':
bbox_xyxy = np.expand_dims(np.array(person_result['bbox']), 0)
bbox_xywh = _xyxy2xywh(bbox_xyxy)
else:
bbox_xywh = np.expand_dims(np.array(person_result['bbox']), 0)
bbox_xyxy = _xywh2xyxy(bbox_xywh)

if bbox_thr is not None:
assert bbox_xywh.shape[1] == 5
if bbox_xywh[0, 4] < bbox_thr:
continue

pose, heatmap = _inference_single_pose_model(
model,
img_or_path,
bbox_xywh[0],
dataset,
return_heatmap=return_heatmap)
# Change for-loop preprocess each bbox to preprocess all bboxes at once.
bboxes = np.array([box['bbox'] for box in person_results])

# Select bboxes by score threshold
if bbox_thr is not None:
assert bboxes.shape[1] == 5
bboxes = bboxes[bboxes[:, 4] > bbox_thr]

if return_heatmap:
h.layer_outputs['heatmap'] = heatmap
if format == 'xyxy':
bboxes_xyxy = bboxes
bboxes_xywh = _xyxy2xywh(bboxes)
else:
# format is already 'xywh'
bboxes_xywh = bboxes
bboxes_xyxy = _xywh2xyxy(bboxes)

# if bbox_thr remove all bounding box
if len(bboxes_xywh) == 0:
return [], []

returned_outputs.append(h.layer_outputs)
with OutputHook(model, outputs=outputs, as_tensor=False) as h:
# pose is results['pred'] # N x 17x 3
pose, heatmap = _inference_single_pose_model(
model,
img_or_path,
bboxes_xywh,
dataset,
return_heatmap=return_heatmap)

person_result['keypoints'] = pose
if return_heatmap:
h.layer_outputs['heatmap'] = heatmap

if format == 'xywh':
person_result['bbox'] = bbox_xyxy[0]
returned_outputs.append(h.layer_outputs)

pose_results.append(person_result)
for i in range(len(pose)):
pose_results.append({'keypoints': pose[i], 'bbox': bboxes_xyxy[i]})

return pose_results, returned_outputs

Expand Down

0 comments on commit cd74bf1

Please sign in to comment.