diff --git a/mmpose/apis/inference.py b/mmpose/apis/inference.py index 0ec61f713c..3a622b5801 100644 --- a/mmpose/apis/inference.py +++ b/mmpose/apis/inference.py @@ -525,6 +525,7 @@ def inference_bottom_up_pose_model(model, cfg = model.cfg device = next(model.parameters()).device + score_per_joint = cfg.model.test_cfg.get('score_per_joint', False) # build the data pipeline channel_order = cfg.test_pipeline[0].get('channel_order', 'rgb') @@ -576,7 +577,11 @@ def inference_bottom_up_pose_model(model, }) # pose nms - keep = oks_nms(pose_results, pose_nms_thr, sigmas) + keep = oks_nms( + pose_results, + pose_nms_thr, + sigmas, + score_per_joint=score_per_joint) pose_results = [pose_results[_keep] for _keep in keep] return pose_results, returned_outputs diff --git a/mmpose/core/post_processing/group.py b/mmpose/core/post_processing/group.py index df50a3cd5b..6235dbc111 100644 --- a/mmpose/core/post_processing/group.py +++ b/mmpose/core/post_processing/group.py @@ -150,6 +150,7 @@ def __init__(self, cfg): self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1, cfg['nms_padding']) self.use_udp = cfg.get('use_udp', False) + self.score_per_joint = cfg.get('score_per_joint', False) def nms(self, heatmaps): """Non-Maximum Suppression for heatmaps. @@ -375,7 +376,7 @@ def parse(self, heatmaps, tags, adjust=True, refine=True): tuple: A tuple containing keypoint grouping results. - results (list(np.ndarray)): Pose results. - - scores (list): Score of people. + - scores (list/list(np.ndarray)): Score of people. """ results = self.match(**self.top_k(heatmaps, tags)) @@ -388,7 +389,10 @@ def parse(self, heatmaps, tags, adjust=True, refine=True): else: results = self.adjust(results, heatmaps) - scores = [i[:, 2].mean() for i in results[0]] + if self.score_per_joint: + scores = [i[:, 2] for i in results[0]] + else: + scores = [i[:, 2].mean() for i in results[0]] if refine: results = results[0] diff --git a/mmpose/core/post_processing/nms.py b/mmpose/core/post_processing/nms.py index 9e19e4845e..86a0ab35e0 100644 --- a/mmpose/core/post_processing/nms.py +++ b/mmpose/core/post_processing/nms.py @@ -86,7 +86,7 @@ def oks_iou(g, d, a_g, a_d, sigmas=None, vis_thr=None): return ious -def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None): +def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None, score_per_joint=False): """OKS NMS implementations. Args: @@ -94,6 +94,7 @@ def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None): thr: Retain overlap < thr. sigmas: standard deviation of keypoint labelling. vis_thr: threshold of the keypoint visibility. + score_per_joint: the input scores (in kpts_db) are per joint scores Returns: np.ndarray: indexes to keep. @@ -101,7 +102,11 @@ def oks_nms(kpts_db, thr, sigmas=None, vis_thr=None): if len(kpts_db) == 0: return [] - scores = np.array([k['score'] for k in kpts_db]) + if score_per_joint: + scores = np.array([k['score'].mean() for k in kpts_db]) + else: + scores = np.array([k['score'] for k in kpts_db]) + kpts = np.array([k['keypoints'].flatten() for k in kpts_db]) areas = np.array([k['area'] for k in kpts_db]) @@ -147,7 +152,12 @@ def _rescore(overlap, scores, thr, type='gaussian'): return scores -def soft_oks_nms(kpts_db, thr, max_dets=20, sigmas=None, vis_thr=None): +def soft_oks_nms(kpts_db, + thr, + max_dets=20, + sigmas=None, + vis_thr=None, + score_per_joint=False): """Soft OKS NMS implementations. Args: @@ -155,6 +165,7 @@ def soft_oks_nms(kpts_db, thr, max_dets=20, sigmas=None, vis_thr=None): thr: retain oks overlap < thr. max_dets: max number of detections to keep. sigmas: Keypoint labelling uncertainty. + score_per_joint: the input scores (in kpts_db) are per joint scores Returns: np.ndarray: indexes to keep. @@ -162,7 +173,11 @@ def soft_oks_nms(kpts_db, thr, max_dets=20, sigmas=None, vis_thr=None): if len(kpts_db) == 0: return [] - scores = np.array([k['score'] for k in kpts_db]) + if score_per_joint: + scores = np.array([k['score'].mean() for k in kpts_db]) + else: + scores = np.array([k['score'] for k in kpts_db]) + kpts = np.array([k['keypoints'].flatten() for k in kpts_db]) areas = np.array([k['area'] for k in kpts_db]) diff --git a/tests/test_post_processing/test_group.py b/tests/test_post_processing/test_group.py index aea0be195e..2ec66efc3a 100644 --- a/tests/test_post_processing/test_group.py +++ b/tests/test_post_processing/test_group.py @@ -32,9 +32,41 @@ def test_group(): fake_tag[0, 8, 6, 6] = 0.9 grouped, scores = parser.parse(fake_heatmap, fake_tag, True, True) assert grouped[0][0, 0, 0] == 10.25 + assert abs(scores[0] - 0.2) < 0.001 cfg['tag_per_joint'] = False parser = HeatmapParser(cfg) grouped, scores = parser.parse(fake_heatmap, fake_tag, False, False) assert grouped[0][0, 0, 0] == 10. grouped, scores = parser.parse(fake_heatmap, fake_tag, False, True) assert grouped[0][0, 0, 0] == 10. + + +def test_group_score_per_joint(): + cfg = {} + cfg['num_joints'] = 17 + cfg['detection_threshold'] = 0.1 + cfg['tag_threshold'] = 1 + cfg['use_detection_val'] = True + cfg['ignore_too_much'] = False + cfg['nms_kernel'] = 5 + cfg['nms_padding'] = 2 + cfg['tag_per_joint'] = True + cfg['max_num_people'] = 1 + cfg['score_per_joint'] = True + parser = HeatmapParser(cfg) + fake_heatmap = torch.zeros(1, 1, 5, 5) + fake_heatmap[0, 0, 3, 3] = 1 + fake_heatmap[0, 0, 3, 2] = 0.8 + assert parser.nms(fake_heatmap)[0, 0, 3, 2] == 0 + fake_heatmap = torch.zeros(1, 17, 32, 32) + fake_tag = torch.zeros(1, 17, 32, 32, 1) + fake_heatmap[0, 0, 10, 10] = 0.8 + fake_heatmap[0, 1, 12, 12] = 0.9 + fake_heatmap[0, 4, 8, 8] = 0.8 + fake_heatmap[0, 8, 6, 6] = 0.9 + fake_tag[0, 0, 10, 10] = 0.8 + fake_tag[0, 1, 12, 12] = 0.9 + fake_tag[0, 4, 8, 8] = 0.8 + fake_tag[0, 8, 6, 6] = 0.9 + grouped, scores = parser.parse(fake_heatmap, fake_tag, True, True) + assert len(scores[0]) == 17 diff --git a/tests/test_post_processing/test_nms.py b/tests/test_post_processing/test_nms.py index 86da74f270..13d793d239 100644 --- a/tests/test_post_processing/test_nms.py +++ b/tests/test_post_processing/test_nms.py @@ -29,6 +29,45 @@ def test_soft_oks_nms(): keep = oks_nms([kpts[i] for i in range(len(kpts))], oks_thr) assert (keep == np.array([0, 2])).all() + kpts_with_score_joints = [] + kpts_with_score_joints.append({ + 'keypoints': + np.tile(np.array([10, 10, 0.9]), [17, 1]), + 'area': + 100, + 'score': + np.tile(np.array([0.9]), 17) + }) + kpts_with_score_joints.append({ + 'keypoints': + np.tile(np.array([10, 10, 0.9]), [17, 1]), + 'area': + 100, + 'score': + np.tile(np.array([0.4]), 17) + }) + kpts_with_score_joints.append({ + 'keypoints': + np.tile(np.array([100, 100, 0.9]), [17, 1]), + 'area': + 100, + 'score': + np.tile(np.array([0.7]), 17) + }) + keep = soft_oks_nms([ + kpts_with_score_joints[i] for i in range(len(kpts_with_score_joints)) + ], + oks_thr, + score_per_joint=True) + assert (keep == np.array([0, 2, 1])).all() + + keep = oks_nms([ + kpts_with_score_joints[i] for i in range(len(kpts_with_score_joints)) + ], + oks_thr, + score_per_joint=True) + assert (keep == np.array([0, 2])).all() + def test_func_nms(): result = nms(np.array([[0, 0, 10, 10, 0.9], [0, 0, 10, 8, 0.8]]), 0.5)