From b4374b31f084d24915ce8413f020a60cabe704f6 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 12 Oct 2022 14:57:28 +0100 Subject: [PATCH] Add object keypoint similarity method --- sleap/config/pipeline_form.yaml | 44 ++++++++++++-- sleap/gui/learning/runners.py | 8 +++ sleap/nn/tracker/components.py | 94 ++++++++++++++++++++++++++++- sleap/nn/tracking.py | 51 +++++++++++++++- tests/fixtures/datasets.py | 7 +++ tests/nn/test_inference.py | 10 +-- tests/nn/test_tracker_components.py | 48 +++++++++++++-- 7 files changed, 246 insertions(+), 16 deletions(-) diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index c730fa9c4..d130b9cb9 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -52,7 +52,7 @@ training: This pipeline uses two models: a "centroid" model to locate and crop around each animal in the frame, and a "centered-instance confidence map" model for predicted node locations - for each individual animal predicted by the centroid model.' + for each individual animal predicted by the centroid model.' - label: Max Instances name: max_instances type: optional_int @@ -217,7 +217,7 @@ training: - name: controller_port label: Controller Port type: int - default: 9000 + default: 9000 range: 1024,65535 - name: publish_port @@ -388,7 +388,7 @@ inference: tracking-only: - name: batch_size - label: Batch Size + label: Batch Size type: int default: 4 range: 1,512 @@ -439,7 +439,7 @@ inference: label: Similarity Method type: list default: instance - options: instance,centroid,iou + options: "instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -478,6 +478,22 @@ inference: label: Nodes to use for Tracking type: string default: 0,1,2 + - type: text + text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.' + - name: tracking.oks_errors + label: Keypoints errors in pixels + help: 'Standard error in pixels of the distance for each keypoint. + If the list is empty, defaults to 1. If singleton list, each keypoint has + the same error. Otherwise, the length should be the same as the number of + keypoints in the skeleton.' + type: string + default: + - name: tracking.oks_score_weighting + label: Use prediction score for weighting + help: 'Use prediction scores to weight the similarity of each keypoint' + type: bool + default: false - type: text text: 'Post-tracker data cleaning:' - name: tracking.post_connect_single_breaks @@ -521,8 +537,8 @@ inference: - name: tracking.similarity label: Similarity Method type: list - default: iou - options: instance,centroid,iou + default: instance + options: "instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -557,6 +573,22 @@ inference: label: Nodes to use for Tracking type: string default: 0,1,2 + - type: text + text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.' + - name: tracking.oks_errors + label: Keypoints errors in pixels + help: 'Standard error in pixels of the distance for each keypoint. + If the list is empty, defaults to 1. If singleton list, each keypoint has + the same error. Otherwise, the length should be the same as the number of + keypoints in the skeleton.' + type: string + default: + - name: tracking.oks_score_weighting + label: Use prediction score for weighting + help: 'Use prediction scores to weight the similarity of each keypoint' + type: bool + default: false - type: text text: 'Post-tracker data cleaning:' - name: tracking.post_connect_single_breaks diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 7569607a0..d0bb1f3ba 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -260,12 +260,20 @@ def make_predict_cli_call( "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", + "tracking.oks_score_weighting", ) for key in bool_items_as_ints: if key in self.inference_params: self.inference_params[key] = int(self.inference_params[key]) + remove_spaces_items = ("tracking.similarity",) + + for key in remove_spaces_items: + if key in self.inference_params: + value = self.inference_params[key] + self.inference_params[key] = value.replace(" ", "_") + for key, val in self.inference_params.items(): if not key.startswith(("_", "outputs.", "model.", "data.")): cli_args.extend((f"--{key}", str(val))) diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index 10b2953b7..b2f35b21f 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -14,7 +14,8 @@ """ import operator from collections import defaultdict -from typing import List, Tuple, Optional, TypeVar, Callable +import logging +from typing import List, Tuple, Union, Optional, TypeVar, Callable import attr import numpy as np @@ -23,6 +24,8 @@ from sleap import PredictedInstance, Instance, Track from sleap.nn import utils +logger = logging.getLogger(__name__) + InstanceType = TypeVar("InstanceType", Instance, PredictedInstance) @@ -40,6 +43,95 @@ def instance_similarity( return similarity +def factory_object_keypoint_similarity( + keypoint_errors: Optional[Union[List, int, float]] = None, + score_weighting: bool = False, + normalization_keypoints: str = "all", +) -> Callable: + """Factory for similarity function based on object keypoints. + + Args: + keypoint_errors: The standard error of the distance between the predicted + keypoint and the true value, in pixels. + If None or empty list, defaults to 1. + If a scalar or singleton list, every keypoint has the same error. + If a list, defines the error for each keypoint, the length should be equal + to the number of keypoints in the skeleton. + score_weighting: If True, use `score` of `PredictedPoint` to weigh + `keypoint_errors`. If False, do not add a weight to `keypoint_errors`. + normalization_keypoints: Determine how to normalize similarity score. One of + ["all", "ref", "union"]. If "all", similarity score is normalized by number + of reference points. If "ref", similarity score is normalized by number of + visible reference points. If "union", similarity score is normalized by + number of points both visible in query and reference instance. + Default is "all". + + Returns: + Callable that returns object keypoint similarity between two `Instance`s. + + """ + keypoint_errors = 1 if keypoint_errors is None else keypoint_errors + with np.errstate(divide="ignore"): + kp_precision = 1 / (2 * np.array(keypoint_errors) ** 2) + + def object_keypoint_similarity( + ref_instance: InstanceType, query_instance: InstanceType + ) -> float: + nonlocal kp_precision + # Keypoints + ref_points = ref_instance.points_array + query_points = query_instance.points_array + # Keypoint scores + if score_weighting: + ref_scores = getattr(ref_instance, "scores", np.ones(len(ref_points))) + query_scores = getattr(query_instance, "scores", np.ones(len(query_points))) + else: + ref_scores = 1 + query_scores = 1 + # Number of keypoint for normalization + if normalization_keypoints in ("ref", "union"): + ref_visible = ~(np.isnan(ref_points).any(axis=1)) + if normalization_keypoints == "ref": + max_n_keypoints = np.sum(ref_visible) + elif normalization_keypoints == "union": + query_visible = ~(np.isnan(query_points).any(axis=1)) + max_n_keypoints = np.sum(np.logical_and(ref_visible, query_visible)) + else: # if normalization_keypoints == "all": + max_n_keypoints = len(ref_points) + if max_n_keypoints == 0: + return 0 + + # Make sure the sizes of kp_precision and n_points match + if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size: + # Correct kp_precision size to fit number of points + n_points = ref_points.size // 2 + mess = ( + "keypoint_errors array should have the same size as the number of " + f"keypoints in the instance: {kp_precision.size} != {n_points}" + ) + + if kp_precision.size > n_points: + kp_precision = kp_precision[:n_points] + mess += "\nTruncating keypoint_errors array." + + else: # elif kp_precision.size < n_points: + pad = n_points - kp_precision.size + kp_precision = np.pad(kp_precision, (0, pad), "edge") + mess += "\nPadding keypoint_errors array by repeating the last value." + logger.warning(mess) + + # Compute distances + dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision + + similarity = ( + np.nansum(ref_scores * query_scores * np.exp(-dists)) / max_n_keypoints + ) + + return similarity + + return object_keypoint_similarity + + def centroid_distance( ref_instance: InstanceType, query_instance: InstanceType, cache: dict = dict() ) -> float: diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 9865b7db5..db3184844 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -10,6 +10,7 @@ from sleap import Track, LabeledFrame, Skeleton from sleap.nn.tracker.components import ( + factory_object_keypoint_similarity, instance_similarity, centroid_distance, instance_iou, @@ -492,6 +493,7 @@ def get_candidates( instance=instance_similarity, centroid=centroid_distance, iou=instance_iou, + object_keypoint=instance_similarity, ) match_policies = dict( @@ -838,6 +840,10 @@ def make_tracker_by_name( # Max tracking options max_tracks: Optional[int] = None, max_tracking: bool = False, + # Object keypoint similarity options + oks_errors: Optional[list] = None, + oks_score_weighting: bool = False, + oks_normalization: str = "all", **kwargs, ) -> BaseTracker: @@ -858,7 +864,14 @@ def make_tracker_by_name( raise ValueError(f"{match} is not a valid tracker matching function.") candidate_maker = tracker_policies[tracker](min_points=min_match_points) - similarity_function = similarity_policies[similarity] + if similarity == "object_keypoint": + similarity_function = factory_object_keypoint_similarity( + keypoint_errors=oks_errors, + score_weighting=oks_score_weighting, + normalization_keypoints=oks_normalization, + ) + else: + similarity_function = similarity_policies[similarity] matching_function = match_policies[match] if tracker == "flow": @@ -1054,6 +1067,42 @@ def int_list_func(s): ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." options.append(option) + def float_list_func(s): + return [float(x.strip()) for x in s.split(",")] if s else None + + option = dict(name="oks_errors", default="1") + option["type"] = float_list_func + option["help"] = ( + "For Object Keypoint similarity: the standard error of the distance " + "between the predicted keypoint and the true value, in pixels.\n" + "If None or empty list, defaults to 1. If a scalar or singleton list, " + "every keypoint has the same error. If a list, defines the error for each " + "keypoint, the length should be equal to the number of keypoints in the " + "skeleton." + ) + options.append(option) + + option = dict(name="oks_score_weighting", default="0") + option["type"] = int + option["help"] = ( + "For Object Keypoint similarity: if 0 (default), only the distance between the reference " + "and query keypoint is used to compute the similarity. If 1, each distance is weighted " + "by the prediction scores of the reference and query keypoint." + ) + options.append(option) + + option = dict(name="oks_normalization", default="all") + option["type"] = str + option["options"] = ["all", "ref", "union"] + option["help"] = ( + "For Object Keypoint similarity: Determine how to normalize similarity score. " + "If 'all', similarity score is normalized by number of reference points. " + "If 'ref', similarity score is normalized by number of visible reference points. " + "If 'union', similarity score is normalized by number of points both visible " + "in query and reference instance." + ) + options.append(option) + return options @classmethod diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 801fcc092..ec5dfbc29 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -41,6 +41,13 @@ def centered_pair_predictions(): return Labels.load_file(TEST_JSON_PREDICTIONS) +@pytest.fixture +def centered_pair_predictions_sorted(centered_pair_predictions): + labels: Labels = centered_pair_predictions + labels.labeled_frames.sort(key=lambda lf: lf.frame_idx) + return labels + + @pytest.fixture def min_labels(): return Labels.load_file(TEST_JSON_MIN_LABELS) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index f99f136ab..98f5fbcec 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1373,7 +1373,7 @@ def test_retracking( # Create sleap-track command cmd = ( f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 " - "--cpu" + "--tracking.similarity object_keypoint --cpu" ) if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" @@ -1393,6 +1393,8 @@ def test_retracking( parser = _make_cli_parser() args, _ = parser.parse_known_args(args=args) tracker = _make_tracker_from_cli(args) + # Additional check for similarity method + assert tracker.similarity_function.__name__ == "object_keypoint_similarity" output_path = f"{slp_path}.{tracker.get_name()}.slp" # Assert tracked predictions file exists @@ -1747,9 +1749,9 @@ def test_sleap_track_invalid_input( sleap_track(args=args) -def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): +def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir): """Test flow tracker instances are pruned.""" - labels: Labels = centered_pair_predictions + labels: Labels = centered_pair_predictions_sorted track_window = 5 # Setup tracker @@ -1759,7 +1761,7 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker) # Run tracking - frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) + frames = labels.labeled_frames # Run tracking on subset of frames using psuedo-implementation of # sleap.nn.tracking.run_tracker diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index f861241ee..94a9747a5 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -9,11 +9,29 @@ FrameMatches, greedy_matching, ) +from sleap.io.dataset import Labels from sleap.instance import PredictedInstance from sleap.skeleton import Skeleton +def tracker_by_name(frames=None, **kwargs): + t = Tracker.make_tracker_by_name(**kwargs) + if frames is None: + t.track([]) + t.final_pass([]) + return + + for lf in frames: + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + t.track(**track_args) + t.final_pass(frames) + + @pytest.mark.parametrize( "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] ) @@ -21,11 +39,33 @@ @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("count", [0, 2]) def test_tracker_by_name(tracker, similarity, match, count): - t = Tracker.make_tracker_by_name( - "flow", "instance", "greedy", clean_instance_count=2 + tracker_by_name( + tracker=tracker, similarity=similarity, match=match, clean_instance_count=count + ) + + +@pytest.mark.parametrize( + "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] +) +@pytest.mark.parametrize("oks_score_weighting", ["True", "False"]) +@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"]) +def test_oks_tracker_by_name( + centered_pair_predictions_sorted, + tracker, + oks_score_weighting, + oks_normalization, +): + # This is slow, so limit to 5 time points + frames = centered_pair_predictions_sorted[:5] + + tracker_by_name( + frames=frames, + tracker=tracker, + similarity="object_keypoint", + matching="greedy", + oks_score_weighting=oks_score_weighting, + oks_normalization=oks_normalization, ) - t.track([]) - t.final_pass([]) def test_cull_instances(centered_pair_predictions):