Skip to content

Commit

Permalink
fix max_tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Jul 23, 2024
1 parent b4374b3 commit a2954a7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
49 changes: 34 additions & 15 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def get_ref_instances(
def get_candidates(
self,
track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
max_tracking: bool,
t: int,
img: np.ndarray,
*args,
Expand All @@ -405,7 +406,7 @@ def get_candidates(
tracks = []

for track, matched_items in track_matching_queue_dict.items():
if len(tracks) <= self.max_tracks:
if not max_tracking or len(tracks) <= self.max_tracks:
tracks.append(track)
for matched_item in matched_items:
ref_t, ref_img = (
Expand Down Expand Up @@ -467,14 +468,15 @@ class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker):
def get_candidates(
self,
track_matching_queue_dict: Dict,
max_tracking: bool,
*args,
**kwargs,
) -> List[InstanceType]:
# Create set of matchable candidate instances from each track.
candidate_instances = []
tracks = []
for track, matched_instances in track_matching_queue_dict.items():
if len(tracks) <= self.max_tracks:
if not max_tracking or len(tracks) <= self.max_tracks:
tracks.append(track)
for ref_instance in matched_instances:
if ref_instance.instance_t.n_visible_points >= self.min_points:
Expand Down Expand Up @@ -600,8 +602,15 @@ def _init_matching_queue(self):
"""Factory for instantiating default matching queue with specified size."""
return deque(maxlen=self.track_window)

@property
def has_max_tracking(self) -> bool:
return isinstance(
self.candidate_maker,
(SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker),
)

def reset_candidates(self):
if self.max_tracking:
if self.has_max_tracking:
for track in self.track_matching_queue_dict:
self.track_matching_queue_dict[track] = deque(maxlen=self.track_window)
else:
Expand All @@ -612,14 +621,15 @@ def unique_tracks_in_queue(self) -> List[Track]:
"""Returns the unique tracks in the matching queue."""

unique_tracks = set()
for match_item in self.track_matching_queue:
for instance in match_item.instances_t:
unique_tracks.add(instance.track)

if self.max_tracking:
if self.has_max_tracking:
for track in self.track_matching_queue_dict.keys():
unique_tracks.add(track)

else:
for match_item in self.track_matching_queue:
for instance in match_item.instances_t:
unique_tracks.add(instance.track)

return list(unique_tracks)

@property
Expand Down Expand Up @@ -648,7 +658,7 @@ def track(

# Infer timestep if not provided.
if t is None:
if self.max_tracking:
if self.has_max_tracking:
if len(self.track_matching_queue_dict) > 0:

# Default to last timestep + 1 if available.
Expand Down Expand Up @@ -686,10 +696,10 @@ def track(
self.pre_cull_function(untracked_instances)

# Build a pool of matchable candidate instances.
if self.max_tracking:
if self.has_max_tracking:
candidate_instances = self.candidate_maker.get_candidates(
track_matching_queue_dict=self.track_matching_queue_dict,
max_tracks=self.max_tracks,
max_tracking=self.max_tracking,
t=t,
img=img,
)
Expand Down Expand Up @@ -723,13 +733,13 @@ def track(
)

# Add the tracked instances to the dictionary of matched instances.
if self.max_tracking:
if self.has_max_tracking:
for tracked_instance in tracked_instances:
if tracked_instance.track in self.track_matching_queue_dict:
self.track_matching_queue_dict[tracked_instance.track].append(
MatchedFrameInstance(t, tracked_instance, img)
)
elif len(self.track_matching_queue_dict) < self.max_tracks:
elif not self.max_tracking or len(self.track_matching_queue_dict) < self.max_tracks:
self.track_matching_queue_dict[tracked_instance.track] = deque(
maxlen=self.track_window
)
Expand Down Expand Up @@ -775,7 +785,8 @@ def spawn_for_untracked_instances(

# Skip if we've reached the maximum number of tracks.
if (
self.max_tracking
self.has_max_tracking
and self.max_tracking
and len(self.track_matching_queue_dict) >= self.max_tracks
):
break
Expand Down Expand Up @@ -846,6 +857,11 @@ def make_tracker_by_name(
oks_normalization: str = "all",
**kwargs,
) -> BaseTracker:
# Parse max_tracking arguments, only True if max_tracks is not None and > 0
max_tracking = max_tracking if max_tracks else False
if max_tracking and tracker in ("simple", "flow"):
# Force a candidate maker of 'maxtracks' type
tracker += "maxtracks"

if tracker.lower() == "none":
candidate_maker = None
Expand Down Expand Up @@ -944,7 +960,10 @@ def get_by_name_factory_options(cls):

option = dict(name="max_tracking", default=False)
option["type"] = bool
option["help"] = "If true then the tracker will cap the max number of tracks."
option["help"] = (
"If true then the tracker will cap the max number of tracks. "
"Falls back to false if `max_tracks` is not defined or 0."
)
options.append(option)

option = dict(name="max_tracks", default=None)
Expand Down
20 changes: 18 additions & 2 deletions tests/nn/test_tracker_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

def tracker_by_name(frames=None, **kwargs):
t = Tracker.make_tracker_by_name(**kwargs)
print(kwargs)
print(t.candidate_maker)
if frames is None:
t.track([])
t.final_pass([])
Expand All @@ -38,9 +40,22 @@ def tracker_by_name(frames=None, **kwargs):
@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"])
@pytest.mark.parametrize("match", ["greedy", "hungarian"])
@pytest.mark.parametrize("count", [0, 2])
def test_tracker_by_name(tracker, similarity, match, count):
def test_tracker_by_name(
centered_pair_predictions_sorted,
tracker,
similarity,
match,
count,
):
# This is slow, so limit to 5 time points
frames = centered_pair_predictions_sorted[:5]

tracker_by_name(
tracker=tracker, similarity=similarity, match=match, clean_instance_count=count
frames=frames,
tracker=tracker,
similarity=similarity,
match=match,
max_tracks=count,
)


Expand All @@ -65,6 +80,7 @@ def test_oks_tracker_by_name(
matching="greedy",
oks_score_weighting=oks_score_weighting,
oks_normalization=oks_normalization,
max_tracks=2,
)


Expand Down

0 comments on commit a2954a7

Please sign in to comment.