diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml
index c730fa9c4..d130b9cb9 100644
--- a/sleap/config/pipeline_form.yaml
+++ b/sleap/config/pipeline_form.yaml
@@ -52,7 +52,7 @@ training:
This pipeline uses two models: a "centroid" model to
locate and crop around each animal in the frame, and a
"centered-instance confidence map" model for predicted node locations
- for each individual animal predicted by the centroid model.'
+ for each individual animal predicted by the centroid model.'
- label: Max Instances
name: max_instances
type: optional_int
@@ -217,7 +217,7 @@ training:
- name: controller_port
label: Controller Port
type: int
- default: 9000
+ default: 9000
range: 1024,65535
- name: publish_port
@@ -388,7 +388,7 @@ inference:
tracking-only:
- name: batch_size
- label: Batch Size
+ label: Batch Size
type: int
default: 4
range: 1,512
@@ -439,7 +439,7 @@ inference:
label: Similarity Method
type: list
default: instance
- options: instance,centroid,iou
+ options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
@@ -478,6 +478,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
+ - type: text
+ text: 'Object keypoint similarity options:
+      Only used when the object keypoint similarity method is selected.'
+ - name: tracking.oks_errors
+      label: Keypoint errors in pixels
+      help: 'Standard error, in pixels, of the distance for each keypoint.
+      If the list is empty, defaults to 1. If a singleton list, each keypoint
+      has the same error. Otherwise, the length should equal the number of
+      keypoints in the skeleton.'
+ type: string
+ default:
+ - name: tracking.oks_score_weighting
+ label: Use prediction score for weighting
+ help: 'Use prediction scores to weight the similarity of each keypoint'
+ type: bool
+ default: false
- type: text
text: 'Post-tracker data cleaning:'
- name: tracking.post_connect_single_breaks
@@ -521,8 +537,8 @@ inference:
- name: tracking.similarity
label: Similarity Method
type: list
- default: iou
- options: instance,centroid,iou
+ default: instance
+ options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
@@ -557,6 +573,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
+ - type: text
+ text: 'Object keypoint similarity options:
+      Only used when the object keypoint similarity method is selected.'
+ - name: tracking.oks_errors
+      label: Keypoint errors in pixels
+      help: 'Standard error, in pixels, of the distance for each keypoint.
+      If the list is empty, defaults to 1. If a singleton list, each keypoint
+      has the same error. Otherwise, the length should equal the number of
+      keypoints in the skeleton.'
+ type: string
+ default:
+ - name: tracking.oks_score_weighting
+ label: Use prediction score for weighting
+ help: 'Use prediction scores to weight the similarity of each keypoint'
+ type: bool
+ default: false
- type: text
text: 'Post-tracker data cleaning:'
- name: tracking.post_connect_single_breaks
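
The form fields above surface as "tracking.*" flags on the sleap-track command
line. A minimal sketch of the argument list the GUI runner assembles from these
values (values are illustrative; the bool-to-int and space-to-underscore
conversions are shown in the runners.py change below):

    # Hypothetical sleap-track argument list built from the form values above.
    cli_args = [
        "--tracking.tracker", "simple",
        "--tracking.similarity", "object_keypoint",  # "object keypoint", underscored
        "--tracking.oks_errors", "2,2,4",            # per-keypoint standard errors (px)
        "--tracking.oks_score_weighting", "1",       # bool serialized as int
        "--tracking.oks_normalization", "all",
    ]
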
diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py
index 7569607a0..d0bb1f3ba 100644
--- a/sleap/gui/learning/runners.py
+++ b/sleap/gui/learning/runners.py
@@ -260,12 +260,20 @@ def make_predict_cli_call(
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
+ "tracking.oks_score_weighting",
)
for key in bool_items_as_ints:
if key in self.inference_params:
self.inference_params[key] = int(self.inference_params[key])
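+        # The GUI lists the method as "object keypoint" (with a space), while the
+        # CLI and tracker registry expect "object_keypoint", so map spaces to
+        # underscores before building the command line.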
+ remove_spaces_items = ("tracking.similarity",)
+
+ for key in remove_spaces_items:
+ if key in self.inference_params:
+ value = self.inference_params[key]
+ self.inference_params[key] = value.replace(" ", "_")
+
for key, val in self.inference_params.items():
if not key.startswith(("_", "outputs.", "model.", "data.")):
cli_args.extend((f"--{key}", str(val)))
diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py
index 10b2953b7..b2f35b21f 100644
--- a/sleap/nn/tracker/components.py
+++ b/sleap/nn/tracker/components.py
@@ -14,7 +14,8 @@
"""
import operator
from collections import defaultdict
-from typing import List, Tuple, Optional, TypeVar, Callable
+import logging
+from typing import List, Tuple, Union, Optional, TypeVar, Callable
import attr
import numpy as np
@@ -23,6 +24,8 @@
from sleap import PredictedInstance, Instance, Track
from sleap.nn import utils
+logger = logging.getLogger(__name__)
+
InstanceType = TypeVar("InstanceType", Instance, PredictedInstance)
@@ -40,6 +43,95 @@ def instance_similarity(
return similarity
+def factory_object_keypoint_similarity(
+ keypoint_errors: Optional[Union[List, int, float]] = None,
+ score_weighting: bool = False,
+ normalization_keypoints: str = "all",
+) -> Callable:
+ """Factory for similarity function based on object keypoints.
+
+ Args:
+ keypoint_errors: The standard error of the distance between the predicted
+ keypoint and the true value, in pixels.
+ If None or empty list, defaults to 1.
+ If a scalar or singleton list, every keypoint has the same error.
+            If a list, it defines the error for each keypoint; its length should
+            equal the number of keypoints in the skeleton.
+        score_weighting: If True, weight each keypoint's contribution by the
+            prediction scores of the reference and query `PredictedPoint`s.
+            If False, all keypoints contribute equally.
+        normalization_keypoints: Determines how the similarity score is
+            normalized. One of ["all", "ref", "union"]. If "all", the score is
+            normalized by the total number of reference keypoints. If "ref", by
+            the number of visible reference keypoints. If "union", by the number
+            of keypoints visible in both the query and reference instances.
+            Default is "all".
+
+ Returns:
+ Callable that returns object keypoint similarity between two `Instance`s.
+
+ """
+ keypoint_errors = 1 if keypoint_errors is None else keypoint_errors
+ with np.errstate(divide="ignore"):
+ kp_precision = 1 / (2 * np.array(keypoint_errors) ** 2)
+
+ def object_keypoint_similarity(
+ ref_instance: InstanceType, query_instance: InstanceType
+ ) -> float:
+ nonlocal kp_precision
+ # Keypoints
+ ref_points = ref_instance.points_array
+ query_points = query_instance.points_array
+ # Keypoint scores
+ if score_weighting:
+ ref_scores = getattr(ref_instance, "scores", np.ones(len(ref_points)))
+ query_scores = getattr(query_instance, "scores", np.ones(len(query_points)))
+ else:
+ ref_scores = 1
+ query_scores = 1
+        # Number of keypoints for normalization
+ if normalization_keypoints in ("ref", "union"):
+ ref_visible = ~(np.isnan(ref_points).any(axis=1))
+ if normalization_keypoints == "ref":
+ max_n_keypoints = np.sum(ref_visible)
+ elif normalization_keypoints == "union":
+ query_visible = ~(np.isnan(query_points).any(axis=1))
+ max_n_keypoints = np.sum(np.logical_and(ref_visible, query_visible))
+ else: # if normalization_keypoints == "all":
+ max_n_keypoints = len(ref_points)
+ if max_n_keypoints == 0:
+ return 0
+
+        # Make sure the size of kp_precision matches the number of keypoints
+ if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size:
+ # Correct kp_precision size to fit number of points
+ n_points = ref_points.size // 2
+ mess = (
+ "keypoint_errors array should have the same size as the number of "
+ f"keypoints in the instance: {kp_precision.size} != {n_points}"
+ )
+
+ if kp_precision.size > n_points:
+ kp_precision = kp_precision[:n_points]
+ mess += "\nTruncating keypoint_errors array."
+
+ else: # elif kp_precision.size < n_points:
+ pad = n_points - kp_precision.size
+ kp_precision = np.pad(kp_precision, (0, pad), "edge")
+ mess += "\nPadding keypoint_errors array by repeating the last value."
+ logger.warning(mess)
+
+ # Compute distances
+ dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision
+
+ similarity = (
+ np.nansum(ref_scores * query_scores * np.exp(-dists)) / max_n_keypoints
+ )
+
+ return similarity
+
+ return object_keypoint_similarity
+
+
def centroid_distance(
ref_instance: InstanceType, query_instance: InstanceType, cache: dict = dict()
) -> float:
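
For reference, a self-contained sketch of the similarity the closure above
computes, with made-up coordinates and the default "all" normalization (the NaN
keypoint in the reference drops out via nansum):

    import numpy as np

    ref = np.array([[10.0, 10.0], [20.0, 20.0], [np.nan, np.nan]])
    query = np.array([[11.0, 10.0], [20.0, 22.0], [30.0, 30.0]])
    errors = np.array([2.0, 2.0, 2.0])  # per-keypoint standard error in pixels
    precision = 1 / (2 * errors**2)

    dists = np.sum((query - ref) ** 2, axis=1) * precision
    similarity = np.nansum(np.exp(-dists)) / len(ref)  # "all": total keypoint count
    # ~0.50 here; approaches 1.0 as all keypoints align exactly
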
diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py
index 9865b7db5..db3184844 100644
--- a/sleap/nn/tracking.py
+++ b/sleap/nn/tracking.py
@@ -10,6 +10,7 @@
from sleap import Track, LabeledFrame, Skeleton
from sleap.nn.tracker.components import (
+ factory_object_keypoint_similarity,
instance_similarity,
centroid_distance,
instance_iou,
@@ -492,6 +493,7 @@ def get_candidates(
instance=instance_similarity,
centroid=centroid_distance,
iou=instance_iou,
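+    # Placeholder entry: make_tracker_by_name swaps in
+    # factory_object_keypoint_similarity() when this method is selected.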
+ object_keypoint=instance_similarity,
)
match_policies = dict(
@@ -838,6 +840,10 @@ def make_tracker_by_name(
# Max tracking options
max_tracks: Optional[int] = None,
max_tracking: bool = False,
+ # Object keypoint similarity options
+ oks_errors: Optional[list] = None,
+ oks_score_weighting: bool = False,
+ oks_normalization: str = "all",
**kwargs,
) -> BaseTracker:
@@ -858,7 +864,14 @@ def make_tracker_by_name(
raise ValueError(f"{match} is not a valid tracker matching function.")
candidate_maker = tracker_policies[tracker](min_points=min_match_points)
- similarity_function = similarity_policies[similarity]
+ if similarity == "object_keypoint":
+ similarity_function = factory_object_keypoint_similarity(
+ keypoint_errors=oks_errors,
+ score_weighting=oks_score_weighting,
+ normalization_keypoints=oks_normalization,
+ )
+ else:
+ similarity_function = similarity_policies[similarity]
matching_function = match_policies[match]
if tracker == "flow":
@@ -1054,6 +1067,42 @@ def int_list_func(s):
] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used."
options.append(option)
+ def float_list_func(s):
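+        # e.g. float_list_func("1, 2.5, 4") -> [1.0, 2.5, 4.0]; "" -> None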
+ return [float(x.strip()) for x in s.split(",")] if s else None
+
+ option = dict(name="oks_errors", default="1")
+ option["type"] = float_list_func
+ option["help"] = (
+ "For Object Keypoint similarity: the standard error of the distance "
+ "between the predicted keypoint and the true value, in pixels.\n"
+ "If None or empty list, defaults to 1. If a scalar or singleton list, "
+ "every keypoint has the same error. If a list, defines the error for each "
+ "keypoint, the length should be equal to the number of keypoints in the "
+ "skeleton."
+ )
+ options.append(option)
+
+ option = dict(name="oks_score_weighting", default="0")
+ option["type"] = int
+ option["help"] = (
+ "For Object Keypoint similarity: if 0 (default), only the distance between the reference "
+ "and query keypoint is used to compute the similarity. If 1, each distance is weighted "
+ "by the prediction scores of the reference and query keypoint."
+ )
+ options.append(option)
+
+ option = dict(name="oks_normalization", default="all")
+ option["type"] = str
+ option["options"] = ["all", "ref", "union"]
+ option["help"] = (
+ "For Object Keypoint similarity: Determine how to normalize similarity score. "
+ "If 'all', similarity score is normalized by number of reference points. "
+ "If 'ref', similarity score is normalized by number of visible reference points. "
+ "If 'union', similarity score is normalized by number of points both visible "
+ "in query and reference instance."
+ )
+ options.append(option)
+
return options
@classmethod
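
Putting the pieces together, a tracker that uses the new similarity can be built
programmatically. A minimal sketch assuming the keyword arguments wired in above
(parameter values are illustrative):

    from sleap.nn.tracking import Tracker

    tracker = Tracker.make_tracker_by_name(
        tracker="simple",
        similarity="object_keypoint",
        match="greedy",
        oks_errors=[2.0],           # singleton list: 2 px error for every keypoint
        oks_score_weighting=True,   # weight each term by prediction scores
        oks_normalization="union",  # count keypoints visible in both instances
    )
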
diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py
index 801fcc092..ec5dfbc29 100644
--- a/tests/fixtures/datasets.py
+++ b/tests/fixtures/datasets.py
@@ -41,6 +41,13 @@ def centered_pair_predictions():
return Labels.load_file(TEST_JSON_PREDICTIONS)
+@pytest.fixture
+def centered_pair_predictions_sorted(centered_pair_predictions):
+ labels: Labels = centered_pair_predictions
+ labels.labeled_frames.sort(key=lambda lf: lf.frame_idx)
+ return labels
+
+
@pytest.fixture
def min_labels():
return Labels.load_file(TEST_JSON_MIN_LABELS)
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index f99f136ab..98f5fbcec 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -1373,7 +1373,7 @@ def test_retracking(
# Create sleap-track command
cmd = (
f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 "
- "--cpu"
+ "--tracking.similarity object_keypoint --cpu"
)
if tracker_method == "flow":
cmd += " --tracking.save_shifted_instances 1"
@@ -1393,6 +1393,8 @@ def test_retracking(
parser = _make_cli_parser()
args, _ = parser.parse_known_args(args=args)
tracker = _make_tracker_from_cli(args)
+ # Additional check for similarity method
+ assert tracker.similarity_function.__name__ == "object_keypoint_similarity"
output_path = f"{slp_path}.{tracker.get_name()}.slp"
# Assert tracked predictions file exists
@@ -1747,9 +1749,9 @@ def test_sleap_track_invalid_input(
sleap_track(args=args)
-def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
+def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir):
"""Test flow tracker instances are pruned."""
- labels: Labels = centered_pair_predictions
+ labels: Labels = centered_pair_predictions_sorted
track_window = 5
# Setup tracker
@@ -1759,7 +1761,7 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker)
# Run tracking
- frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)
+ frames = labels.labeled_frames
    # Run tracking on subset of frames using pseudo-implementation of
# sleap.nn.tracking.run_tracker
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py
index f861241ee..94a9747a5 100644
--- a/tests/nn/test_tracker_components.py
+++ b/tests/nn/test_tracker_components.py
@@ -9,11 +9,29 @@
FrameMatches,
greedy_matching,
)
+from sleap.io.dataset import Labels
from sleap.instance import PredictedInstance
from sleap.skeleton import Skeleton
+def tracker_by_name(frames=None, **kwargs):
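+    """Build a tracker via make_tracker_by_name and run track/final_pass on frames."""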
+ t = Tracker.make_tracker_by_name(**kwargs)
+ if frames is None:
+ t.track([])
+ t.final_pass([])
+ return
+
+ for lf in frames:
+ # Clear the tracks
+ for inst in lf.instances:
+ inst.track = None
+
+ track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
+ t.track(**track_args)
+ t.final_pass(frames)
+
+
@pytest.mark.parametrize(
"tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
)
@@ -21,11 +39,33 @@
@pytest.mark.parametrize("match", ["greedy", "hungarian"])
@pytest.mark.parametrize("count", [0, 2])
def test_tracker_by_name(tracker, similarity, match, count):
- t = Tracker.make_tracker_by_name(
- "flow", "instance", "greedy", clean_instance_count=2
+ tracker_by_name(
+ tracker=tracker, similarity=similarity, match=match, clean_instance_count=count
+ )
+
+
+@pytest.mark.parametrize(
+ "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
+)
+@pytest.mark.parametrize("oks_score_weighting", ["True", "False"])
+@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"])
+def test_oks_tracker_by_name(
+ centered_pair_predictions_sorted,
+ tracker,
+ oks_score_weighting,
+ oks_normalization,
+):
+ # This is slow, so limit to 5 time points
+ frames = centered_pair_predictions_sorted[:5]
+
+ tracker_by_name(
+ frames=frames,
+ tracker=tracker,
+ similarity="object_keypoint",
+ matching="greedy",
+ oks_score_weighting=oks_score_weighting,
+ oks_normalization=oks_normalization,
)
- t.track([])
- t.final_pass([])
def test_cull_instances(centered_pair_predictions):