diff --git a/tests/test_post_processing/test_nms.py b/tests/test_post_processing/test_nms.py index 86da74f270..13d793d239 100644 --- a/tests/test_post_processing/test_nms.py +++ b/tests/test_post_processing/test_nms.py @@ -29,6 +29,45 @@ def test_soft_oks_nms(): keep = oks_nms([kpts[i] for i in range(len(kpts))], oks_thr) assert (keep == np.array([0, 2])).all() + kpts_with_score_joints = [] + kpts_with_score_joints.append({ + 'keypoints': + np.tile(np.array([10, 10, 0.9]), [17, 1]), + 'area': + 100, + 'score': + np.tile(np.array([0.9]), 17) + }) + kpts_with_score_joints.append({ + 'keypoints': + np.tile(np.array([10, 10, 0.9]), [17, 1]), + 'area': + 100, + 'score': + np.tile(np.array([0.4]), 17) + }) + kpts_with_score_joints.append({ + 'keypoints': + np.tile(np.array([100, 100, 0.9]), [17, 1]), + 'area': + 100, + 'score': + np.tile(np.array([0.7]), 17) + }) + keep = soft_oks_nms([ + kpts_with_score_joints[i] for i in range(len(kpts_with_score_joints)) + ], + oks_thr, + score_per_joint=True) + assert (keep == np.array([0, 2, 1])).all() + + keep = oks_nms([ + kpts_with_score_joints[i] for i in range(len(kpts_with_score_joints)) + ], + oks_thr, + score_per_joint=True) + assert (keep == np.array([0, 2])).all() + def test_func_nms(): result = nms(np.array([[0, 0, 10, 10, 0.9], [0, 0, 10, 8, 0.8]]), 0.5)