Skip to content

Commit

Permalink
Move test code to official tests
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Sep 29, 2023
1 parent c072a34 commit f1b0be4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 43 deletions.
43 changes: 0 additions & 43 deletions sleap/info/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,46 +238,3 @@ def point_match_count(dist_array: np.ndarray, thresh: float = 5) -> int:
def point_nonmatch_count(dist_array: np.ndarray, thresh: float = 5) -> int:
"""Given an array of distances, returns number which are not <= threshold."""
return dist_array.shape[0] - point_match_count(dist_array, thresh)


if __name__ == "__main__":

labels_gt = Labels.load_json("tests/data/json_format_v1/centered_pair.json")
labels_pr = Labels.load_json(
"tests/data/json_format_v2/centered_pair_predictions.json"
)

# OPTION 1

# Match each ground truth instance node to the closest corresponding node
# from any predicted instance in the same frame.

nodewise_matching_func = match_instance_lists_nodewise

# OPTION 2

# Match each ground truth instance to a distinct predicted instance:
# We want to maximize the number of "matching" points between instances,
# where "match" means the points are within some threshold distance.
# Note that each sorted list will be as long as the shorted input list.

instwise_matching_func = lambda gt_list, pr_list: match_instance_lists(
gt_list, pr_list, point_nonmatch_count
)

# PICK THE FUNCTION

inst_matching_func = nodewise_matching_func
# inst_matching_func = instwise_matching_func

# Calculate distances
frame_idxs, D, points_gt, points_pr = matched_instance_distances(
labels_gt, labels_pr, inst_matching_func
)

# Show mean difference for each node
node_names = labels_gt.skeletons[0].node_names

for node_idx, node_name in enumerate(node_names):
mean_d = np.nanmean(D[..., node_idx])
print(f"{node_name}\t\t{mean_d}")
58 changes: 58 additions & 0 deletions tests/info/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np

from sleap import Labels
from sleap.info.metrics import (
match_instance_lists_nodewise,
matched_instance_distances,
)


def test_matched_instance_distances(centered_pair_labels, centered_pair_predictions):
labels_gt = centered_pair_labels
labels_pr = centered_pair_predictions

# Match each ground truth instance node to the closest corresponding node
# from any predicted instance in the same frame.

inst_matching_func = match_instance_lists_nodewise

# Calculate distances
frame_idxs, D, points_gt, points_pr = matched_instance_distances(
labels_gt, labels_pr, inst_matching_func
)

# Show mean difference for each node
node_names = labels_gt.skeletons[0].node_names

for node_idx, node_name in enumerate(node_names):
mean_d = np.nanmean(D[..., node_idx])
# print(f"{node_name}\t\t{mean_d}")

"""Expected values (instance-wise matching):
head 0.872426920709296
neck 0.8016280746914615
thorax 0.8602021363390538
abdomen 1.01012200038258
wingL 1.1297727023475939
wingR 1.0869857897008424
forelegL1 0.780584225081443
forelegL2 1.170805798894702
forelegL3 1.1020486509389473
forelegR1 0.9014698776116817
forelegR2 0.9448001033112047
forelegR3 1.308385214215777
midlegL1 0.9095691623265347
midlegL2 1.2203595627907582
midlegL3 0.9813843358470163
midlegR1 0.9871017182813739
midlegR2 1.0209829335569256
midlegR3 1.0990681234096988
hindlegL1 1.0005335192834348
hindlegL2 1.273539518539708
hindlegL3 1.1752245985832817
hindlegR1 1.1402833959265248
hindlegR2 1.3143221301212737
hindlegR3 1.0441458592503365"""
assert mean_d > 0.77
assert mean_d < 1.32

0 comments on commit f1b0be4

Please sign in to comment.