fixed some more remnant issues of the last commit
Signed-off-by: Martin <[email protected]>
bmmtstb committed Oct 27, 2024
1 parent 19b9a49 commit ca0057a
Showing 3 changed files with 10 additions and 11 deletions.
6 changes: 3 additions & 3 deletions dgs/models/combine/dynamic.py
@@ -65,11 +65,11 @@ def forward(
         r"""The forward call of this module combines an arbitrary number of similarity matrices
         using an importance weight :math:`\alpha`.
-        :math:`\alpha_i` describes how important :math:`s_i` is.
+        :math:`\alpha_i` describes how important the similarity :math:`s_i` is.
         The sum of all :math:`\alpha_i` should be 1 by definition given the last layer is a softmax layer.
-        :math:`\alpha` is computed using this class' neural network and the given ``alpha_input`` tensor.
+        :math:`\alpha` is computed using the respective :class:`BaseAlphaModule` and the given :class:`State`.
-        All tensors should be on the same device and all :math:`s_i` should have the same shape.
+        All tensors should be on the same device and should have the same shape.

         Args:
             tensors: A tuple of tensors describing similarities between the detections and tracks.
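
For reference, a minimal, hypothetical sketch of the weighted combination this docstring describes, assuming S similarity matrices of shape [N, T] and softmax-normalized weights; the tensor names and shapes are illustrative and not the module's actual API.

    import torch

    # hypothetical inputs: S = 2 similarity matrices, each [N detections x T tracks]
    s1 = torch.rand(4, 3)
    s2 = torch.rand(4, 3)
    similarities = torch.stack([s1, s2])  # [S, N, T]

    # hypothetical alpha head output: one weight per similarity, normalized by a softmax
    logits = torch.tensor([0.2, 1.3])
    alpha = torch.softmax(logits, dim=0)  # the alpha_i sum to 1

    # weighted combination: sum_i alpha_i * s_i
    combined = (alpha.view(-1, 1, 1) * similarities).sum(dim=0)  # [N, T]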
11 changes: 5 additions & 6 deletions dgs/models/engine/dgs_engine.py
@@ -237,19 +237,18 @@ def _track_step(self, detections: State, frame_idx: int, name: str, timers: Diff
 
         if len(track_states) == 0 and N > 0:
             # No Tracks yet - every detection will be a new track!
-            # Make sure to compute the embeddings for every detection, to ensure correct behavior of collate later on
+            # Make sure to compute the embeddings for every detection,
+            # this is done to ensure correct behavior of the collate function later on.
             time_sim_start = time.time()
-            _ = self.model.forward(ds=detections, target=detections, alpha_inputs=self.get_data(detections))
+            _ = self.model.forward(ds=detections, target=detections, s=detections)
             timers.add(name="similarity", prev_time=time_sim_start)
-            # There are no tracks yet, therefore every detection is a new state!
+            # There are no tracks yet, therefore, every detection is a new state!
             time_match_start = time.time()
             new_states += detections.split()
             timers.add(name="match", prev_time=time_match_start)
         elif N > 0:
             time_sim_start = time.time()
-            similarity = self.model.forward(
-                ds=detections, target=collate_states(track_states), alpha_inputs=self.get_data(detections)
-            )
+            similarity = self.model.forward(ds=detections, target=collate_states(track_states), s=detections)
             timers.add(name="similarity", prev_time=time_sim_start)
 
             # Solve Linear sum Assignment Problem (LAP/LSA).
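
The "Solve Linear sum Assignment Problem (LAP/LSA)" step matches detections to existing tracks based on the combined similarity matrix. A minimal sketch of such a step using SciPy's linear_sum_assignment; the engine's actual solver, thresholds, and data layout are assumptions here and may differ.

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # hypothetical combined similarity matrix: rows = detections, columns = tracks
    similarity = np.array([
        [0.9, 0.1, 0.2],
        [0.3, 0.8, 0.1],
    ])

    # maximize the total similarity of the matching (LSA is usually stated as a minimization)
    det_idx, track_idx = linear_sum_assignment(similarity, maximize=True)

    for d, t in zip(det_idx, track_idx):
        print(f"detection {d} -> track {t} (similarity {similarity[d, t]:.2f})")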
4 changes: 2 additions & 2 deletions tests/models/combine/test__combine.py
@@ -296,7 +296,7 @@ def test_state_not_set(self):
             self.model.forward(*self.dummy_t)
         self.assertTrue("The state should be given" in str(e.exception), msg=e.exception)
 
-    def test_runtime_error_alpha_input_device(self):
+    def test_runtime_error_state_device(self):
         if t.cuda.is_available():
             # tensor based
             with self.assertRaises(RuntimeError) as e:
@@ -305,7 +305,7 @@ def test_runtime_error_alpha_input_device(self):
                 ) # Different devices
             self.assertIn("s should be on the same device as tensors", str(e.exception))
 
-    def test_value_error_on_nof_alpha_inputs(self):
+    def test_value_error_on_batch_mismatch(self):
         # Mismatch in number of alpha inputs against number of alpha models
         with self.assertRaises(ValueError) as e:
             self.model.forward(*self.dummy_t, s=self.dummy_state_batched)
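
The renamed tests cover the error handling hinted at above: a RuntimeError when the given state lives on a different device than the similarity tensors, and a ValueError on a batch mismatch. A stand-alone sketch of that kind of validation, with hypothetical names and messages, not the project's implementation:

    import torch

    def check_inputs(tensors: tuple[torch.Tensor, ...], s_batch_size: int, s_device: torch.device) -> None:
        """Hypothetical validation mirroring the errors the tests expect."""
        if any(t.device != s_device for t in tensors):
            raise RuntimeError("s should be on the same device as tensors")
        if any(t.shape[0] != s_batch_size for t in tensors):
            raise ValueError("batch size of s does not match the similarity matrices")

    # passes silently for matching device and batch size
    check_inputs((torch.rand(2, 3), torch.rand(2, 3)), s_batch_size=2, s_device=torch.device("cpu"))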
