
Commit c1b888c
add paf detector & head & loss
luminxu committed May 18, 2021
1 parent eddc2dc commit c1b888c
Showing 9 changed files with 612 additions and 36 deletions.
5 changes: 3 additions & 2 deletions mmpose/core/evaluation/bottom_up_eval.py
@@ -133,8 +133,9 @@ def get_multi_stage_outputs_paf(outputs,
resize them to base sizes.
Args:
-outputs (list(torch.Tensor)): Outputs of network
-outputs_flip (list(torch.Tensor)): Flip outputs of network
+outputs (dict): Outputs of network, including heatmaps and pafs.
+outputs_flip (dict): Flip outputs of network, including
+    heatmaps and pafs.
with_heatmaps (list[bool]): Option to output
heatmaps for different stages.
with_pafs (list[bool]): Option to output
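The updated docstring above now describes `outputs` and `outputs_flip` as dicts rather than plain tensor lists. A minimal sketch of what such a dict could look like, assuming the keys 'heatmaps' and 'pafs' and per-stage tensor lists; the actual keys and shapes are defined by the PAF head, which is not shown in this excerpt:

import torch

# Hypothetical layout only; the real keys and shapes come from the PAF head.
N, K, P, H, W = 1, 17, 38, 128, 128   # batch, keypoints, paf channels, height, width
outputs = {
    'heatmaps': [torch.zeros(N, K, H, W)],  # one tensor per output stage
    'pafs': [torch.zeros(N, P, H, W)],      # part affinity fields per stage
}
outputs_flip = None  # or a dict with the same structure when flip testing is enabled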
269 changes: 269 additions & 0 deletions mmpose/core/post_processing/group.py
@@ -404,3 +404,272 @@ def parse(self, heatmaps, tags, adjust=True, refine=True):
ans = [ans]

return ans, scores


class PAFParser:
"""The paf parser for post processing."""

def __init__(self, cfg):
self.params = _Params(cfg)
self.tag_per_joint = cfg['tag_per_joint']
self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1,
cfg['nms_padding'])
self.use_udp = cfg.get('use_udp', False)

# def nms(self, heatmaps):
# """Non-Maximum Suppression for heatmaps.
#
# Args:
# heatmap(torch.Tensor): Heatmaps before nms.
#
# Returns:
# torch.Tensor: Heatmaps after nms.
# """
#
# maxm = self.pool(heatmaps)
# maxm = torch.eq(maxm, heatmaps).float()
# heatmaps = heatmaps * maxm
#
# return heatmaps
#
# def match(self, tag_k, loc_k, val_k):
# """Group keypoints to human poses in a batch.
#
# Args:
# tag_k (np.ndarray[NxKxMxL]): tag corresponding to the
# top k values of feature map per keypoint.
# loc_k (np.ndarray[NxKxMx2]): top k locations of the
# feature maps for keypoint.
# val_k (np.ndarray[NxKxM]): top k value of the
# feature maps per keypoint.
#
# Returns:
# list
# """
#
# def _match(x):
# return _match_by_tag(x, self.params)
#
# return list(map(_match, zip(tag_k, loc_k, val_k)))
#
# def top_k(self, heatmaps, tags):
# """Find top_k values in an image.
#
# Note:
# batch size: N
# number of keypoints: K
# heatmap height: H
# heatmap width: W
# max number of people: M
# dim of tags: L
# If use flip testing, L=2; else L=1.
#
# Args:
# heatmaps (torch.Tensor[NxKxHxW])
# tags (torch.Tensor[NxKxHxWxL])
#
# Return:
# dict: A dict containing top_k values.
#
# - tag_k (np.ndarray[NxKxMxL]):
# tag corresponding to the top k values of
# feature map per keypoint.
# - loc_k (np.ndarray[NxKxMx2]):
# top k location of feature map per keypoint.
# - val_k (np.ndarray[NxKxM]):
# top k value of feature map per keypoint.
# """
# heatmaps = self.nms(heatmaps)
# N, K, H, W = heatmaps.size()
# heatmaps = heatmaps.view(N, K, -1)
# val_k, ind = heatmaps.topk(self.params.max_num_people, dim=2)
#
# tags = tags.view(tags.size(0), tags.size(1), W * H, -1)
# if not self.tag_per_joint:
# tags = tags.expand(-1, self.params.num_joints, -1, -1)
#
# tag_k = torch.stack(
# [torch.gather(tags[..., i], 2, ind) for i in
# range(tags.size(3))],
# dim=3)
#
# x = ind % W
# y = ind // W
#
# ind_k = torch.stack((x, y), dim=3)
#
# ans = {
# 'tag_k': tag_k.cpu().numpy(),
# 'loc_k': ind_k.cpu().numpy(),
# 'val_k': val_k.cpu().numpy()
# }
#
# return ans
#
# @staticmethod
# def adjust(ans, heatmaps):
# """Adjust the coordinates for better accuracy.
#
# Note:
# batch size: N
# number of keypoints: K
# heatmap height: H
# heatmap width: W
#
# Args:
# ans (list(np.ndarray)): Keypoint predictions.
# heatmaps (torch.Tensor[NxKxHxW]): Heatmaps.
# """
# _, _, H, W = heatmaps.shape
# for batch_id, people in enumerate(ans):
# for people_id, people_i in enumerate(people):
# for joint_id, joint in enumerate(people_i):
# if joint[2] > 0:
# x, y = joint[0:2]
# xx, yy = int(x), int(y)
# tmp = heatmaps[batch_id][joint_id]
# if tmp[min(H - 1, yy + 1), xx] > tmp[max(0, yy - 1),
# xx]:
# y += 0.25
# else:
# y -= 0.25
#
# if tmp[yy, min(W - 1, xx + 1)] > tmp[yy,
# max(0, xx - 1)]:
# x += 0.25
# else:
# x -= 0.25
# ans[batch_id][people_id, joint_id,
# 0:2] = (x + 0.5, y + 0.5)
# return ans
#
# @staticmethod
# def refine(heatmap, tag, keypoints, use_udp=False):
# """Given initial keypoint predictions, we identify missing joints.
#
# Note:
# number of keypoints: K
# heatmap height: H
# heatmap width: W
# dim of tags: L
# If use flip testing, L=2; else L=1.
#
# Args:
# heatmap: np.ndarray(K, H, W).
# tag: np.ndarray(K, H, W) | np.ndarray(K, H, W, L)
# keypoints: np.ndarray of size (K, 3 + L)
# last dim is (x, y, score, tag).
# use_udp (bool): Unbiased data processing.
#
# Returns:
# np.ndarray: The refined keypoints.
# """
#
# K, H, W = heatmap.shape
# if len(tag.shape) == 3:
# tag = tag[..., None]
#
# tags = []
# for i in range(K):
# if keypoints[i, 2] > 0:
# # save tag value of detected keypoint
# x, y = keypoints[i][:2].astype(int)
# x = np.clip(x, 0, W - 1)
# y = np.clip(y, 0, H - 1)
# tags.append(tag[i, y, x])
#
# # mean tag of current detected people
# prev_tag = np.mean(tags, axis=0)
# ans = []
#
# for _heatmap, _tag in zip(heatmap, tag):
# # distance of all tag values with mean tag of
# # current detected people
# distance_tag = (((_tag -
# prev_tag[None, None, :])**2).sum(axis=2)**0.5)
# norm_heatmap = _heatmap - np.round(distance_tag)
#
# # find maximum position
# y, x = np.unravel_index(np.argmax(norm_heatmap), _heatmap.shape)
# xx = x.copy()
# yy = y.copy()
# # detection score at maximum position
# val = _heatmap[y, x]
# if not use_udp:
# # offset by 0.5
# x += 0.5
# y += 0.5
#
# # add a quarter offset
# if _heatmap[yy, min(W - 1, xx + 1)] > _heatmap[yy, max(0, xx - 1)]:
# x += 0.25
# else:
# x -= 0.25
#
# if _heatmap[min(H - 1, yy + 1), xx] > _heatmap[max(0, yy - 1), xx]:
# y += 0.25
# else:
# y -= 0.25
#
# ans.append((x, y, val))
# ans = np.array(ans)
#
# if ans is not None:
# for i in range(K):
# # add keypoint if it is not detected
# if ans[i, 2] > 0 and keypoints[i, 2] == 0:
# keypoints[i, :3] = ans[i, :3]
#
# return keypoints

def parse(self, heatmaps, pafs, adjust=True, refine=True):
"""Group keypoints into poses given heatmap and paf.
Note:
batch size: N
number of keypoints: K
number of paf maps: P
heatmap height: H
heatmap width: W
Args:
heatmaps (torch.Tensor[NxKxHxW]): model output heatmaps.
pafs (torch.Tensor[NxPxHxW]): model output pafs.
Returns:
tuple: A tuple containing keypoint grouping results.
- ans (list(np.ndarray)): Pose results.
- scores (list): Score of people.
"""

assert 0, 'The post-processing of paf has not been completed.'
# ans = self.match(**self.top_k(heatmaps, pafs))
#
# if adjust:
# if self.use_udp:
# for i in range(len(ans)):
# if ans[i].shape[0] > 0:
# ans[i][..., :2] = post_dark_udp(
# ans[i][..., :2].copy(), heatmaps[i:i + 1, :])
# else:
# ans = self.adjust(ans, heatmaps)
#
# scores = [i[:, 2].mean() for i in ans[0]]
#
# if refine:
# ans = ans[0]
# # for every detected person
# for i in range(len(ans)):
# heatmap_numpy = heatmaps[0].cpu().numpy()
# tag_numpy = tags[0].cpu().numpy()
# if not self.tag_per_joint:
# tag_numpy = np.tile(tag_numpy,
# (self.params.num_joints, 1, 1, 1))
# ans[i] = self.refine(
# heatmap_numpy, tag_numpy, ans[i], use_udp=self.use_udp)
# ans = [ans]
#
# return ans, scores
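For orientation, a rough sketch of a test_cfg dict that would satisfy the new PAFParser constructor. The values below are placeholders, and _Params(cfg) may expect further keys (e.g. detection or tag thresholds) that are not visible in this diff:

from mmpose.core.post_processing.group import PAFParser

# Placeholder values for illustration; not taken from any actual mmpose config.
test_cfg = dict(
    num_joints=17,        # consumed through _Params(cfg)
    max_num_people=30,    # consumed through _Params(cfg)
    tag_per_joint=True,   # whether a tag map is predicted per joint
    nms_kernel=5,         # kernel size for the max-pool NMS
    nms_padding=2,        # padding so the pooled map keeps its spatial size
    use_udp=False,        # unbiased data processing
)
parser = PAFParser(test_cfg)  # parse() itself is still a stub at this commit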
33 changes: 15 additions & 18 deletions mmpose/models/detectors/paf.py
@@ -10,7 +10,7 @@

from mmpose.core.evaluation import (aggregate_results_paf, get_group_preds,
get_multi_stage_outputs_paf)
-from mmpose.core.post_processing.group import HeatmapParser
+from mmpose.core.post_processing.group import PAFParser
from .. import builder
from ..registry import POSENETS
from .base import BasePose
@@ -56,7 +56,7 @@ def __init__(self,
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.use_udp = test_cfg.get('use_udp', False)
-self.parser = HeatmapParser(self.test_cfg)
+self.parser = PAFParser(self.test_cfg)
self.init_weights(pretrained=pretrained)

@property
@@ -74,7 +74,6 @@ def forward(self,
img=None,
targets=None,
masks=None,
-joints=None,
img_metas=None,
return_loss=True,
return_heatmap=False,
@@ -92,11 +91,10 @@ def forward(self,
max_num_people: M
Args:
img(torch.Tensor[NxCximgHximgW]): Input image.
-targets(List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps.
+targets (list(list)): List of heatmaps and pafs, each of which
+    contains multi-scale targets.
masks(List(torch.Tensor[NxHxW])): Masks of multi-scale target
-    heatmaps
-joints(List(torch.Tensor[NxMxKx2])): Joints of multi-scale target
-    heatmaps for ae loss
+    heatmaps.
img_metas(dict):Information about val&test
By default this includes:
- "image_file": image path
@@ -118,12 +116,11 @@ def forward(self,
"""

if return_loss:
-return self.forward_train(img, targets, masks, joints, img_metas,
-                          **kwargs)
+return self.forward_train(img, targets, masks, img_metas, **kwargs)
return self.forward_test(
img, img_metas, return_heatmap=return_heatmap, **kwargs)

-def forward_train(self, img, targets, masks, joints, img_metas, **kwargs):
+def forward_train(self, img, targets, masks, img_metas, **kwargs):
"""Forward the bottom-up model and calculate the loss.
Note:
@@ -138,11 +135,10 @@ def forward_train(self, img, targets, masks, joints, img_metas, **kwargs):
Args:
img(torch.Tensor[NxCximgHximgW]): Input image.
-targets(List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps.
+targets (list(list)): List of heatmaps and pafs, each of which
+    contains multi-scale targets.
masks(List(torch.Tensor[NxHxW])): Masks of multi-scale target
-    heatmaps
-joints(List(torch.Tensor[NxMxKx2])): Joints of multi-scale target
-    heatmaps for ae loss
+    heatmaps.
img_metas(dict):Information about val&test
By default this includes:
- "image_file": image path
@@ -166,7 +162,7 @@ def forward_train(self, img, targets, masks, joints, img_metas, **kwargs):
losses = dict()
if self.with_keypoint:
keypoint_losses = self.keypoint_head.get_loss(
-    output, targets, masks, joints)
+    output, targets, masks)
losses.update(keypoint_losses)

return losses
@@ -240,7 +236,7 @@ def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
outputs,
outputs_flipped,
self.test_cfg['with_heatmaps'],
-self.test_cfg['with_ae'],
+self.test_cfg['with_pafs'],
img_metas['flip_index'],
img_metas['flip_index_paf'],
self.test_cfg['project2image'],
@@ -259,10 +255,11 @@ def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
# average heatmaps of different scales
aggregated_heatmaps = aggregated_heatmaps / float(
len(test_scale_factor))
-tags = torch.cat(pafs, dim=4)
+aggregated_pafs = aggregated_pafs / float(len(test_scale_factor))

# perform grouping
-grouped, scores = self.parser.parse(aggregated_heatmaps, tags,
+grouped, scores = self.parser.parse(aggregated_heatmaps,
+                                    aggregated_pafs,
self.test_cfg['adjust'],
self.test_cfg['refine'])

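Since this commit drops the `joints` argument and feeds paired heatmap/paf targets into forward_train(), here is a rough shape sketch under the docstring's conventions. The nesting of `targets` and the channel counts are assumptions for illustration; the exact layout is set by the PAF head and dataset pipeline, which are not part of this excerpt:

import torch

# Rough shape sketch only; exact nesting and channel counts are assumptions.
N, K, P, H, W = 2, 17, 38, 128, 128
targets = [
    [torch.zeros(N, K, H, W)],   # multi-scale heatmap targets (one scale shown)
    [torch.zeros(N, P, H, W)],   # multi-scale paf targets
]
masks = [torch.ones(N, H, W)]    # one mask per scale; `joints` is no longer passed
# losses = model.forward_train(img, targets, masks, img_metas)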
