From f1b0be4cf449d9150660d0dd3cd6610cc45bcdd2 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 28 Sep 2023 20:54:54 -0700 Subject: [PATCH] Move test code to official tests --- sleap/info/metrics.py | 43 ---------------------------- tests/info/test_metrics.py | 58 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 43 deletions(-) create mode 100644 tests/info/test_metrics.py diff --git a/sleap/info/metrics.py b/sleap/info/metrics.py index d5e1a7e2d..5bec077e4 100644 --- a/sleap/info/metrics.py +++ b/sleap/info/metrics.py @@ -238,46 +238,3 @@ def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int: def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int: """Given an array of distances, returns number which are not <= threshold.""" return dist_array.shape[0] - point_match_count(dist_array, thresh) - - -if __name__ == "__main__": - - labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json") - labels_pr = Labels.load_json( - "tests/data/json_format_v2/centered_pair_predictions.json" - ) - - # OPTION 1 - - # Match each ground truth instance node to the closest corresponding node - # from any predicted instance in the same frame. - - nodewise_matching_func = match_instance_lists_nodewise - - # OPTION 2 - - # Match each ground truth instance to a distinct predicted instance: - # We want to maximize the number of "matching" points between instances, - # where "match" means the points are within some threshold distance. - # Note that each sorted list will be as long as the shorted input list. - - instwise_matching_func = lambda gt_list, pr_list: match_instance_lists( - gt_list, pr_list, point_nonmatch_count - ) - - # PICK THE FUNCTION - - inst_matching_func = nodewise_matching_func - # inst_matching_func = instwise_matching_func - - # Calculate distances - frame_idxs, D, points_gt, points_pr = matched_instance_distances( - labels_gt, labels_pr, inst_matching_func - ) - - # Show mean difference for each node - node_names = labels_gt.skeletons[0].node_names - - for node_idx, node_name in enumerate(node_names): - mean_d = np.nanmean(D[..., node_idx]) - print(f"{node_name}\t\t{mean_d}") diff --git a/tests/info/test_metrics.py b/tests/info/test_metrics.py new file mode 100644 index 000000000..0ec8933f6 --- /dev/null +++ b/tests/info/test_metrics.py @@ -0,0 +1,58 @@ +import numpy as np + +from sleap import Labels +from sleap.info.metrics import ( + match_instance_lists_nodewise, + matched_instance_distances, +) + + +def test_matched_instance_distances(centered_pair_labels, centered_pair_predictions): + labels_gt = centered_pair_labels + labels_pr = centered_pair_predictions + + # Match each ground truth instance node to the closest corresponding node + # from any predicted instance in the same frame. + + inst_matching_func = match_instance_lists_nodewise + + # Calculate distances + frame_idxs, D, points_gt, points_pr = matched_instance_distances( + labels_gt, labels_pr, inst_matching_func + ) + + # Show mean difference for each node + node_names = labels_gt.skeletons[0].node_names + + for node_idx, node_name in enumerate(node_names): + mean_d = np.nanmean(D[..., node_idx]) + # print(f"{node_name}\t\t{mean_d}") + + """Expected values (instance-wise matching): + + head 0.872426920709296 + neck 0.8016280746914615 + thorax 0.8602021363390538 + abdomen 1.01012200038258 + wingL 1.1297727023475939 + wingR 1.0869857897008424 + forelegL1 0.780584225081443 + forelegL2 1.170805798894702 + forelegL3 1.1020486509389473 + forelegR1 0.9014698776116817 + forelegR2 0.9448001033112047 + forelegR3 1.308385214215777 + midlegL1 0.9095691623265347 + midlegL2 1.2203595627907582 + midlegL3 0.9813843358470163 + midlegR1 0.9871017182813739 + midlegR2 1.0209829335569256 + midlegR3 1.0990681234096988 + hindlegL1 1.0005335192834348 + hindlegL2 1.273539518539708 + hindlegL3 1.1752245985832817 + hindlegR1 1.1402833959265248 + hindlegR2 1.3143221301212737 + hindlegR3 1.0441458592503365""" + assert mean_d > 0.77 + assert mean_d < 1.32