From 1bcca03b8c519912fd8314c666ec44e671b3414a Mon Sep 17 00:00:00 2001 From: getzze Date: Tue, 23 Jul 2024 15:10:12 +0100 Subject: [PATCH] fix max_tracking --- sleap/nn/tracking.py | 49 ++++++++++++++++++++--------- tests/nn/test_tracker_components.py | 20 ++++++++++-- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index db3184844..1ecc90862 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -392,6 +392,7 @@ def get_ref_instances( def get_candidates( self, track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], + max_tracking: bool, t: int, img: np.ndarray, *args, @@ -405,7 +406,7 @@ def get_candidates( tracks = [] for track, matched_items in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) <= self.max_tracks: tracks.append(track) for matched_item in matched_items: ref_t, ref_img = ( @@ -467,6 +468,7 @@ class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): def get_candidates( self, track_matching_queue_dict: Dict, + max_tracking: bool, *args, **kwargs, ) -> List[InstanceType]: @@ -474,7 +476,7 @@ def get_candidates( candidate_instances = [] tracks = [] for track, matched_instances in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) <= self.max_tracks: tracks.append(track) for ref_instance in matched_instances: if ref_instance.instance_t.n_visible_points >= self.min_points: @@ -600,8 +602,15 @@ def _init_matching_queue(self): """Factory for instantiating default matching queue with specified size.""" return deque(maxlen=self.track_window) + @property + def has_max_tracking(self) -> bool: + return isinstance( + self.candidate_maker, + (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker), + ) + def reset_candidates(self): - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict: self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) else: @@ -612,14 +621,15 @@ def unique_tracks_in_queue(self) -> List[Track]: """Returns the unique tracks in the matching queue.""" unique_tracks = set() - for match_item in self.track_matching_queue: - for instance in match_item.instances_t: - unique_tracks.add(instance.track) - - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict.keys(): unique_tracks.add(track) + else: + for match_item in self.track_matching_queue: + for instance in match_item.instances_t: + unique_tracks.add(instance.track) + return list(unique_tracks) @property @@ -648,7 +658,7 @@ def track( # Infer timestep if not provided. if t is None: - if self.max_tracking: + if self.has_max_tracking: if len(self.track_matching_queue_dict) > 0: # Default to last timestep + 1 if available. @@ -686,10 +696,10 @@ def track( self.pre_cull_function(untracked_instances) # Build a pool of matchable candidate instances. - if self.max_tracking: + if self.has_max_tracking: candidate_instances = self.candidate_maker.get_candidates( track_matching_queue_dict=self.track_matching_queue_dict, - max_tracks=self.max_tracks, + max_tracking=self.max_tracking, t=t, img=img, ) @@ -723,13 +733,13 @@ def track( ) # Add the tracked instances to the dictionary of matched instances. - if self.max_tracking: + if self.has_max_tracking: for tracked_instance in tracked_instances: if tracked_instance.track in self.track_matching_queue_dict: self.track_matching_queue_dict[tracked_instance.track].append( MatchedFrameInstance(t, tracked_instance, img) ) - elif len(self.track_matching_queue_dict) < self.max_tracks: + elif not self.max_tracking or len(self.track_matching_queue_dict) < self.max_tracks: self.track_matching_queue_dict[tracked_instance.track] = deque( maxlen=self.track_window ) @@ -775,7 +785,8 @@ def spawn_for_untracked_instances( # Skip if we've reached the maximum number of tracks. if ( - self.max_tracking + self.has_max_tracking + and self.max_tracking and len(self.track_matching_queue_dict) >= self.max_tracks ): break @@ -846,6 +857,11 @@ def make_tracker_by_name( oks_normalization: str = "all", **kwargs, ) -> BaseTracker: + # Parse max_tracking arguments, only True if max_tracks is not None and > 0 + max_tracking = max_tracking if max_tracks else False + if max_tracking and tracker in ("simple", "flow"): + # Force a candidate maker of 'maxtracks' type + tracker += "maxtracks" if tracker.lower() == "none": candidate_maker = None @@ -944,7 +960,10 @@ def get_by_name_factory_options(cls): option = dict(name="max_tracking", default=False) option["type"] = bool - option["help"] = "If true then the tracker will cap the max number of tracks." + option["help"] = ( + "If true then the tracker will cap the max number of tracks. " + "Falls back to false if `max_tracks` is not defined or 0." + ) options.append(option) option = dict(name="max_tracks", default=None) diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 94a9747a5..5786945fb 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -17,6 +17,8 @@ def tracker_by_name(frames=None, **kwargs): t = Tracker.make_tracker_by_name(**kwargs) + print(kwargs) + print(t.candidate_maker) if frames is None: t.track([]) t.final_pass([]) @@ -38,9 +40,22 @@ def tracker_by_name(frames=None, **kwargs): @pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("count", [0, 2]) -def test_tracker_by_name(tracker, similarity, match, count): +def test_tracker_by_name( + centered_pair_predictions_sorted, + tracker, + similarity, + match, + count, +): + # This is slow, so limit to 5 time points + frames = centered_pair_predictions_sorted[:5] + tracker_by_name( - tracker=tracker, similarity=similarity, match=match, clean_instance_count=count + frames=frames, + tracker=tracker, + similarity=similarity, + match=match, + max_tracks=count, ) @@ -65,6 +80,7 @@ def test_oks_tracker_by_name( matching="greedy", oks_score_weighting=oks_score_weighting, oks_normalization=oks_normalization, + max_tracks=2, )