Skip to content

Commit

Permalink
Removed usages of TrackStatistics
Browse files Browse the repository at this point in the history
Signed-off-by: Martin <[email protected]>
  • Loading branch information
bmmtstb committed May 12, 2024
1 parent ff625ff commit c1a2fb6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 44 deletions.
9 changes: 4 additions & 5 deletions dgs/models/engine/dgs_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dgs.utils.config import DEF_VAL
from dgs.utils.state import collate_states, State
from dgs.utils.torchtools import close_all_layers
from dgs.utils.track import Tracks, TrackStatistics
from dgs.utils.track import Tracks
from dgs.utils.types import Config, Validations
from dgs.utils.utils import torch_to_numpy

Expand Down Expand Up @@ -111,7 +111,7 @@ def get_data(self, ds: State) -> any:
def get_target(self, ds: State) -> any:
return ds["class_id"].long()

def _track_step(self, detections: State) -> tuple[TrackStatistics, dict[str, float]]:
def _track_step(self, detections: State) -> dict[str, float]:
"""Run one step of tracking."""
N: int = len(detections)
T: int = len(self.tracks)
Expand Down Expand Up @@ -161,15 +161,14 @@ def _track_step(self, detections: State) -> tuple[TrackStatistics, dict[str, flo

# update tracks
time_track_update_start = time.time()
ts: TrackStatistics
_, ts = self.tracks.add(tracks=updated_tracks, new=new_states)
self.tracks.add(tracks=updated_tracks, new=new_states)

batch_times["track"] = time.time() - time_track_update_start
batch_times["batch"] = time.time() - time_batch_start
if N > 0:
batch_times["indiv"] = batch_times["batch"] / N

return ts, batch_times
return batch_times

def test(self) -> dict[str, any]:
"""Test the DGS Tracker"""
Expand Down
24 changes: 7 additions & 17 deletions dgs/utils/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,15 +493,14 @@ def is_removed(self, tid: TrackID) -> bool:
"""Return whether the given Track-ID has been removed."""
return tid not in self.data and tid in self.removed

def add(self, tracks: dict[TrackID, State], new: list[State]) -> tuple[list[TrackID], TrackStatistics]:
def add(self, tracks: dict[TrackID, State], new: list[State]) -> list[TrackID]:
"""Given tracks with existing Track-IDs update those and create new Tracks for States without Track-IDs.
Additionally,
mark Track-IDs that are not in either of the inputs as unseen and therefore as inactive for one more step.
Returns:
The Track-IDs of the new_tracks in the same order as provided.
"""
stats = TrackStatistics()

inactive_ids = self.ids - set(int(k) for k in tracks.keys())

Expand All @@ -510,18 +509,17 @@ def add(self, tracks: dict[TrackID, State], new: list[State]) -> tuple[list[Trac

# add the new state to the new tracks
for tid, new_state in zip(new_tids, new):
self._update_track(tid=tid, add_state=new_state, stats=stats)
stats.new.append(tid)
self._update_track(tid=tid, add_state=new_state)

# add state to Track and remove track from inactive if present
for tid, new_state in tracks.items():
self._update_track(tid=tid, add_state=new_state, stats=stats)
self._update_track(tid=tid, add_state=new_state)

self._handle_inactive(tids=inactive_ids, stats=stats)
self._handle_inactive(tids=inactive_ids)

# step to the next frame
self._next_frame()
return new_tids, stats
return new_tids

def _next_frame(self) -> None:
self._curr_frame += 1
Expand Down Expand Up @@ -574,7 +572,7 @@ def reactivate_track(self, tid: TrackID) -> None:

# todo should the states of the track be removed / cleared ?

def _update_track(self, tid: TrackID, add_state: State, stats: TrackStatistics) -> None:
def _update_track(self, tid: TrackID, add_state: State) -> None:
"""Use the track-ID to update a track given an additional :class:`State` for the :class:`Track`.
Will additionally remove the tid from the inactive Tracks.
Expand All @@ -586,20 +584,16 @@ def _update_track(self, tid: TrackID, add_state: State, stats: TrackStatistics)
raise KeyError(f"Track-ID {tid} neither present in the current or previously removed Tracks.")
# reactivate previously removed track
self.reactivate_track(tid)
stats.reactivated.append(tid)
elif tid in self.inactive:
# update inactive
self.inactive.pop(tid)
stats.found.append(tid)
else:
stats.still_active.append(tid)

# append state to track
self.data[tid].append(state=add_state)
# add track id to state
self.data[tid][-1]["pred_tid"] = torch.tensor(tid, dtype=torch.long, device=add_state.device).flatten()

def _handle_inactive(self, tids: set[TrackID], stats: TrackStatistics) -> None:
def _handle_inactive(self, tids: set[TrackID]) -> None:
"""Given the Track-IDs of the Tracks that haven't been seen this step, update the inactivity tracker.
Create the counter for inactive Track-IDs and update existing counters.
Additionally, remove tracks that have been inactive for too long.
Expand All @@ -610,13 +604,9 @@ def _handle_inactive(self, tids: set[TrackID], stats: TrackStatistics) -> None:

if self.inactive[tid] >= self.inactivity_threshold:
self.remove_tid(tid)
stats.removed.append(tid)
else:
stats.still_inactive.append(tid)
else:
self.inactive[tid] = 1
self.data[tid].set_inactive()
stats.lost.append(tid)

def _get_next_id(self) -> TrackID:
"""Get the next free track-ID."""
Expand Down
38 changes: 16 additions & 22 deletions tests/utils/track/test__tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dgs.utils.config import DEF_VAL
from dgs.utils.state import State
from dgs.utils.track import Track, Tracks, TrackStatistics, TrackStatus
from dgs.utils.track import Track, Tracks, TrackStatus
from helper import test_multiple_devices


Expand Down Expand Up @@ -41,14 +41,14 @@ def _track_w_params(
EMPTY_TRACKS: Tracks = Tracks(N=MAX_LENGTH, thresh=THRESH)

ONE_TRACKS: Tracks = Tracks(N=MAX_LENGTH, thresh=THRESH)
OT_O_ID = ONE_TRACKS.add(tracks={}, new=[DUMMY_STATE.copy()])[0][0]
OT_O_ID = ONE_TRACKS.add(tracks={}, new=[DUMMY_STATE.copy()])[0]

MULTI_TRACKS: Tracks = Tracks(N=MAX_LENGTH, thresh=THRESH)
MT_ACT_ID = MULTI_TRACKS.add(tracks={}, new=[DUMMY_STATE.copy()])[0][0]
MT_DEL_ID = MULTI_TRACKS.add(tracks={MT_ACT_ID: DUMMY_STATE.copy()}, new=[DUMMY_STATE])[0][0]
MT_ACT_ID = MULTI_TRACKS.add(tracks={}, new=[DUMMY_STATE.copy()])[0]
MT_DEL_ID = MULTI_TRACKS.add(tracks={MT_ACT_ID: DUMMY_STATE.copy()}, new=[DUMMY_STATE])[0]
for _ in range(MAX_LENGTH - 1):
MULTI_TRACKS[MT_DEL_ID].append(DUMMY_STATE)
MT_INA_IDS, _ = MULTI_TRACKS.add(tracks={MT_ACT_ID: DUMMY_STATE.copy()}, new=DUMMY_STATES.copy())
MT_INA_IDS = MULTI_TRACKS.add(tracks={MT_ACT_ID: DUMMY_STATE.copy()}, new=DUMMY_STATES.copy())
MULTI_TRACKS.add(tracks={MT_ACT_ID: DUMMY_STATE.copy()}, new=[])
# now MT_F_IDS is removed and MT_M_IDS are all inactive !

Expand Down Expand Up @@ -187,7 +187,7 @@ def test_remove_tid(self):
self.assertEqual(t2.ids_removed, set([MT_DEL_ID] + MT_INA_IDS))

ts_own = Tracks(N=MAX_LENGTH, thresh=THRESH)
own_tid = ts_own.add({}, [DUMMY_STATE.copy()])[0][0]
own_tid = ts_own.add({}, [DUMMY_STATE.copy()])[0]
t_own = ts_own[own_tid]
self.assertEqual(t_own.status, TrackStatus.Active)
ts_own.remove_tid(own_tid)
Expand Down Expand Up @@ -226,21 +226,18 @@ def test_to(self, device):

def test_handle_inactive(self):
t = Tracks(N=MAX_LENGTH, thresh=2)
ts = TrackStatistics()
tid = t.add({}, new=[DUMMY_STATE.copy()])[0][0]
tid = t.add({}, new=[DUMMY_STATE.copy()])[0]
self.assertEqual(t.ids_active, {tid})
self.assertEqual(t.ids_inactive, set())

t._handle_inactive({tid}, stats=ts)
t._handle_inactive({tid})
self.assertEqual(t.ids_active, set())
self.assertEqual(t.ids_inactive, {tid})
self.assertEqual(t.inactive[tid], 1)
self.assertEqual(ts.nof_lost, 1)

t._handle_inactive({tid}, stats=ts)
t._handle_inactive({tid})
self.assertEqual(t.ids_active, set())
self.assertEqual(t.ids_inactive, set())
self.assertEqual(ts.nof_removed, 1)

def test_add_empty_tracks(self):
t0 = EMPTY_TRACKS.copy()
Expand All @@ -260,14 +257,13 @@ def test_add_empty_tracks(self):

def test_update_track(self):
tracks = ONE_TRACKS.copy()
ts = TrackStatistics()

self.assertEqual(len(tracks[OT_O_ID]), 1)
tracks._update_track(OT_O_ID, DUMMY_STATE, stats=ts)
tracks._update_track(OT_O_ID, DUMMY_STATE)
self.assertEqual(len(tracks[OT_O_ID]), 2)

with self.assertRaises(KeyError) as e:
tracks._update_track(100, DUMMY_STATE, stats=TrackStatistics())
tracks._update_track(100, DUMMY_STATE)
self.assertTrue(
"Track-ID 100 neither present in the current or previously removed Tracks" in str(e.exception),
msg=e.exception,
Expand All @@ -283,22 +279,20 @@ def test_update_track(self):

# update inactive
for tid in MT_INA_IDS:
multi_tracks._update_track(tid, DUMMY_STATE, stats=ts)
multi_tracks._update_track(tid, DUMMY_STATE)
self.assertEqual(len(multi_tracks[tid]), 2)
self.assertEqual(multi_tracks[tid][-1]["pred_tid"], tid)
self.assertEqual(multi_tracks.ids_inactive, set())
self.assertEqual(multi_tracks.ids_removed, {MT_DEL_ID})
self.assertEqual(multi_tracks.ids_active, set([MT_ACT_ID] + MT_INA_IDS))
self.assertEqual(ts.nof_found, len(MT_INA_IDS))

# update removed
multi_tracks._update_track(MT_DEL_ID, DUMMY_STATE, stats=ts)
multi_tracks._update_track(MT_DEL_ID, DUMMY_STATE)
self.assertEqual(multi_tracks.ids_inactive, set())
self.assertEqual(multi_tracks.ids_removed, set())
self.assertEqual(multi_tracks.ids_active, set([MT_ACT_ID, MT_DEL_ID] + MT_INA_IDS))
self.assertEqual(len(multi_tracks[MT_DEL_ID]), MAX_LENGTH)
self.assertEqual(multi_tracks.nof_removed, 0)
self.assertEqual(ts.nof_reactivated, 1)
self.assertEqual(multi_tracks[MT_DEL_ID][-1]["pred_tid"], MT_DEL_ID)

def test_get_item(self):
Expand All @@ -323,13 +317,13 @@ def test_reset_deleted(self):
def test_add(self):
t = Tracks(N=MAX_LENGTH, thresh=2)

first_tid = t.add(tracks={}, new=[DUMMY_STATE.copy()])[0][0]
first_tid = t.add(tracks={}, new=[DUMMY_STATE.copy()])[0]
self.assertEqual(len(t), 1)
self.assertEqual(t.ids_active, {first_tid})
self.assertEqual(t.ids_inactive, set())
self.assertTrue(t[first_tid].id == first_tid)

second_tid = t.add(tracks={0: DUMMY_STATE.copy()}, new=[DUMMY_STATE.copy()])[0][0]
second_tid = t.add(tracks={0: DUMMY_STATE.copy()}, new=[DUMMY_STATE.copy()])[0]
self.assertEqual(len(t), 2)
self.assertEqual(t.ids_active, {first_tid, second_tid})
self.assertEqual(t.ids_inactive, set())
Expand Down Expand Up @@ -360,7 +354,7 @@ def test_add(self):

def test_add_keep_inactive(self):
t = Tracks(N=MAX_LENGTH, thresh=5)
tid = t.add(tracks={}, new=[DUMMY_STATE.copy()])[0][0]
tid = t.add(tracks={}, new=[DUMMY_STATE.copy()])[0]

t.add(tracks={}, new=[])
self.assertEqual(t.ids_inactive, {tid})
Expand Down

0 comments on commit c1a2fb6

Please sign in to comment.