Skip to content

Commit

Permalink
paf detector
Browse files Browse the repository at this point in the history
  • Loading branch information
luminxu committed May 12, 2021
1 parent a63b148 commit eddc2dc
Show file tree
Hide file tree
Showing 4 changed files with 660 additions and 14 deletions.
26 changes: 19 additions & 7 deletions mmpose/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .bottom_up_eval import (aggregate_results, get_group_preds,
get_multi_stage_outputs)
from .bottom_up_eval import (aggregate_results, aggregate_results_paf,
get_group_preds, get_multi_stage_outputs,
get_multi_stage_outputs_paf)
from .eval_hooks import DistEvalHook, EvalHook
from .mesh_eval import compute_similarity_transform
from .pose3d_eval import keypoint_mpjpe
Expand All @@ -8,9 +9,20 @@
pose_pck_accuracy, post_dark_udp)

__all__ = [
'EvalHook', 'DistEvalHook', 'pose_pck_accuracy', 'keypoints_from_heatmaps',
'keypoints_from_regression', 'keypoint_pck_accuracy', 'keypoint_auc',
'keypoint_epe', 'get_group_preds', 'get_multi_stage_outputs',
'aggregate_results', 'compute_similarity_transform', 'post_dark_udp',
'keypoint_mpjpe'
'EvalHook',
'DistEvalHook',
'pose_pck_accuracy',
'keypoints_from_heatmaps',
'keypoints_from_regression',
'keypoint_pck_accuracy',
'keypoint_auc',
'keypoint_epe',
'get_group_preds',
'get_multi_stage_outputs',
'get_multi_stage_outputs_paf',
'aggregate_results',
'aggregate_results_paf',
'compute_similarity_transform',
'post_dark_udp',
'keypoint_mpjpe',
]
210 changes: 210 additions & 0 deletions mmpose/core/evaluation/bottom_up_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,156 @@ def get_multi_stage_outputs(outputs,
return outputs, heatmaps, tags


def get_multi_stage_outputs_paf(outputs,
outputs_flip,
with_heatmaps,
with_pafs,
flip_index=None,
flip_index_paf=None,
project2image=True,
size_projected=None,
align_corners=False):
"""Inference the model to get multi-stage outputs (heatmaps & pafs), and
resize them to base sizes.
Args:
outputs (list(torch.Tensor)): Outputs of network
outputs_flip (list(torch.Tensor)): Flip outputs of network
with_heatmaps (list[bool]): Option to output
heatmaps for different stages.
with_pafs (list[bool]): Option to output
pafs for different stages.
flip_index (list[int]): Keypoint flip index.
flip_index_paf (list[int]): PAF flip index.
project2image (bool): Option to resize to base scale.
size_projected ([w, h]): Base size of heatmaps.
align_corners (bool): Align corners when performing interpolation.
Returns:
tuple: A tuple containing multi-stage outputs.
- outputs (list(torch.Tensor)): List of simple outputs and
flip outputs.
- heatmaps (torch.Tensor): Multi-stage heatmaps that are resized to
the base size.
- pafs (torch.Tensor): Multi-stage pafs that are resized to
the base size.
"""

heatmaps = []
pafs = []

flip_test = outputs_flip is not None

# aggregate heatmaps from different stages
heatmaps_avg = 0
num_heatmaps = 0

for i, output in enumerate(outputs['heatmaps']):
if i != len(outputs['heatmaps']) - 1:
output = torch.nn.functional.interpolate(
output,
size=(outputs['heatmaps'][-1].size(2),
outputs['heatmaps'][-1].size(3)),
mode='bilinear',
align_corners=align_corners)

if with_heatmaps[i]:
heatmaps_avg += output
num_heatmaps += 1

if num_heatmaps > 0:
heatmaps.append(heatmaps_avg / num_heatmaps)

# aggregate pafs from different stages
pafs_avg = 0
num_pafs = 0

for i, output in enumerate(outputs['pafs']):
if i != len(outputs['pafs']) - 1:
output = torch.nn.functional.interpolate(
output,
size=(outputs['pafs'][-1].size(2),
outputs['pafs'][-1].size(3)),
mode='bilinear',
align_corners=align_corners)

if with_pafs[i]:
pafs_avg += output
num_pafs += 1

if num_pafs > 0:
pafs.append(pafs_avg / num_pafs)

if flip_test:
if flip_index:
# perform flip testing for heatmaps
heatmaps_avg = 0
num_heatmaps = 0

for i, output in enumerate(outputs_flip['heatmaps']):
if i != len(outputs_flip['heatmaps']) - 1:
output = torch.nn.functional.interpolate(
output,
size=(outputs_flip['heatmaps'][-1].size(2),
outputs_flip['heatmaps'][-1].size(3)),
mode='bilinear',
align_corners=align_corners)
output = torch.flip(output, [3])
outputs['heatmaps'].append(output)

if with_heatmaps[i]:
heatmaps_avg += output[:, flip_index, :, :]
num_heatmaps += 1

heatmaps.append(heatmaps_avg / num_heatmaps)

if flip_index_paf:
# perform flip testing for pafs
pafs_avg = 0
num_pafs = 0

for i, output in enumerate(outputs_flip['pafs']):
if i != len(outputs_flip['pafs']) - 1:
output = torch.nn.functional.interpolate(
output,
size=(outputs_flip['pafs'][-1].size(2),
outputs_flip['pafs'][-1].size(3)),
mode='bilinear',
align_corners=align_corners)
output = torch.flip(output, [3])
outputs['pafs'].append(output)

if with_pafs[i]:
pafs_avg[:, ::2, :, :] -= output[:,
flip_index_paf[::2], :, :]
pafs_avg[:,
1::2, :, :] += output[:,
flip_index_paf[1::2], :, :]
num_pafs += 1

pafs.append(pafs_avg / num_pafs)

if project2image and size_projected:
heatmaps = [
torch.nn.functional.interpolate(
hms,
size=(size_projected[1], size_projected[0]),
mode='bilinear',
align_corners=align_corners) for hms in heatmaps
]

pafs = [
torch.nn.functional.interpolate(
pms,
size=(size_projected[1], size_projected[0]),
mode='bilinear',
align_corners=align_corners) for pms in pafs
]

return outputs, heatmaps, pafs


def aggregate_results(scale,
aggregated_heatmaps,
tags_list,
Expand Down Expand Up @@ -184,6 +334,66 @@ def aggregate_results(scale,
return aggregated_heatmaps, tags_list


def aggregate_results_paf(aggregated_heatmaps,
aggregated_pafs,
heatmaps,
pafs,
project2image,
flip_test,
align_corners=False):
"""Aggregate multi-scale outputs.
Note:
batch size: N
keypoints num : K
paf maps num: P
heatmap width: W
heatmap height: H
Args:
aggregated_heatmaps (torch.Tensor | None): Aggregated heatmaps.
aggregated_pafs (torch.Tensor | None): Aggregated pafs.
heatmaps (List(torch.Tensor[NxKxWxH])): A batch of heatmaps.
pafs (List(torch.Tensor[NxPxWxH])): A batch of paf maps.
project2image (bool): Option to resize to base scale.
flip_test (bool): Option to use flip test.
align_corners (bool): Align corners when performing interpolation.
Return:
tuple: a tuple containing aggregated results.
- aggregated_heatmaps (torch.Tensor): Heatmaps with multi scale.
- aggregated_pafs (torch.Tensor): PAF maps of multi scale.
"""

heatmaps_avg = (heatmaps[0] +
heatmaps[1]) / 2.0 if flip_test else heatmaps[0]
if aggregated_heatmaps is None:
aggregated_heatmaps = heatmaps_avg
elif project2image:
aggregated_heatmaps += heatmaps_avg
else:
aggregated_heatmaps += torch.nn.functional.interpolate(
heatmaps_avg,
size=(aggregated_heatmaps.size(2), aggregated_heatmaps.size(3)),
mode='bilinear',
align_corners=align_corners)

pafs_avg = (pafs[0] + pafs[1]) / 2.0 if flip_test else pafs[0]
if aggregated_pafs is None:
aggregated_pafs = pafs_avg
elif project2image:
aggregated_pafs += pafs_avg
else:
aggregated_pafs += torch.nn.functional.interpolate(
pafs_avg,
size=(aggregated_pafs.size(2), aggregated_pafs.size(3)),
mode='bilinear',
align_corners=align_corners)

return aggregated_heatmaps, aggregated_pafs


def get_group_preds(grouped_joints,
center,
scale,
Expand Down
20 changes: 13 additions & 7 deletions mmpose/datasets/datasets/bottom_up/bottom_up_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ def __init__(self,
0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15
]

# joint index starts from 1
self.ann_info['skeleton'] = [[16, 14], [14, 12], [17, 15], [15, 13],
[12, 13], [6, 12], [7, 13], [6, 7],
[6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
[1, 2], [1, 3], [2, 4], [3, 5], [4, 6],
[5, 7]]

self.ann_info['flip_index_paf'] = [
4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 12, 13, 10, 11, 14, 15, 18, 19, 16,
17, 22, 23, 20, 21, 24, 25, 28, 29, 26, 27, 32, 33, 30, 31, 36, 37,
34, 35
]

self.ann_info['use_different_joint_weights'] = False
self.ann_info['joint_weights'] = np.array(
[
Expand All @@ -68,13 +81,6 @@ def __init__(self,
],
dtype=np.float32).reshape((self.ann_info['num_joints'], 1))

# joint index starts from 1
self.ann_info['skeleton'] = [[16, 14], [14, 12], [17, 15], [15, 13],
[12, 13], [6, 12], [7, 13], [6, 7],
[6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
[1, 2], [1, 3], [2, 4], [3, 5], [4, 6],
[5, 7]]

# 'https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/'
# 'pycocotools/cocoeval.py#L523'
self.sigmas = np.array([
Expand Down
Loading

0 comments on commit eddc2dc

Please sign in to comment.