
modify DynamicAlphaCombine to use the new alpha modules
- modified tests accordingly

Signed-off-by: Martin <[email protected]>
bmmtstb committed Oct 27, 2024
1 parent 71576f1 commit 19b9a49
Showing 3 changed files with 117 additions and 131 deletions.
49 changes: 22 additions & 27 deletions dgs/models/combine/dynamic.py
@@ -2,13 +2,12 @@
 Implementation of modules that use dynamic weights to combine multiple similarities.
 """
 
-from typing import Union
-
 import torch as t
 from torch import nn
 
 from dgs.models.combine.combine import CombineSimilaritiesModule
 from dgs.models.loader import module_loader
+from dgs.utils.state import State
 from dgs.utils.torchtools import configure_torch_module
 from dgs.utils.types import Config, NodePath, Validations
 
@@ -60,7 +59,7 @@ def __init__(self, config: Config, path: NodePath) -> None:
     def forward(
         self,
         *tensors: t.Tensor,
-        alpha_inputs: Union[t.Tensor, list[t.Tensor], tuple[t.Tensor, ...]] = None,
+        s: State = None,
         **_kwargs,
     ) -> t.Tensor:
         r"""The forward call of this module combines an arbitrary number of similarity matrices
@@ -73,17 +72,14 @@ def forward(
         All tensors should be on the same device and all :math:`s_i` should have the same shape.
 
         Args:
-            tensors: A tuple of (Float-)Tensors.
+            tensors: A tuple of tensors describing similarities between the detections and tracks.
                 All ``S`` similarity matrices of this iterable should have values in range ``[0,1]``,
                 be of the same shape ``[D x T]``, and be on the same device.
                 If ``tensors`` is a single tensor, it should have the shape ``[S x D x T]``.
                 ``S`` can be any number of similarity matrices greater than 0,
                 even though only values greater than 1 really make sense.
-            alpha_inputs: An iterable of tensors or a single tensor that are all on the same device as ``tensors``.
-                If ``alpha_inputs`` is a single tensor, it should have the shape ``[S x D x sim_size x ...]``.
-                But because the inputs for different similarity matrices can have different shapes,
-                the most common use case is to have a list of ``S`` tensors.
-                Where every tensor has values in range ``[0, 1]`` and is of shape ``[D x sim_size x ...]``.
+            s: A :class:`State` containing the batched input data for the alpha models.
+                The state should be on the same device as ``tensors``.
 
         Returns:
             torch.Tensor: The weighted similarity matrix as tensor of shape ``[D x T]``.
@@ -108,32 +104,31 @@ def forward(
             raise ValueError(
                 f"There should be as many alpha models {len(self.alpha_models)} as tensors {len(tensors)}."
             )
-        # force tensors to be shape [S x D x T]
-        if len(tensors) == 1:
-            tensors = tensors[0]
-        else:
-            tensors = t.stack(tensors, dim=-3)
-
-        if isinstance(tensors, t.Tensor) and tensors.ndim != 3:
-            raise ValueError(f"Expected a 3D tensor, but got a tensor with shape {tensors.shape}")
+        tensors = t.stack(tensors, dim=-3)  # [S x D x T]
+
+        if tensors.ndim != 3:
+            raise ValueError(f"Expected a 3D tensor [S x D x T], but got a tensor with shape {tensors.shape}")
 
-        # validate alpha inputs
-        if alpha_inputs is None:
-            raise ValueError("Alpha inputs should be given.")
-        if not isinstance(alpha_inputs, t.Tensor) and not isinstance(alpha_inputs, (tuple, list)):
-            raise TypeError("alpha_inputs should be a tensor or an iterable of (float) tensors.")
-        if any(not isinstance(ai, t.Tensor) for ai in alpha_inputs):
-            raise TypeError("All alpha inputs should be tensors.")
-        if alpha_inputs[0].device != tensors.device or any(ai.device != alpha_inputs[0].device for ai in alpha_inputs):
-            raise RuntimeError("All alpha inputs should be on the same device.")
-        if len(self.alpha_models) != len(alpha_inputs):
-            raise ValueError(
-                f"There should be as many alpha models {len(self.alpha_models)} as alpha inputs {len(alpha_inputs)}."
-            )
+        if s is None:
+            raise ValueError("The state should be given.")
+        if not isinstance(s, State):
+            raise TypeError(f"s should be a State. Got: {s}")
+        if s.device != tensors.device:
+            raise RuntimeError("s should be on the same device as tensors.")
+        if tensors.size(-2) != s.B:
+            raise ValueError(f"The states batch size ({s.B}) should equal D of tensors ({tensors.size(-2)}).")
 
         # [D x S] with softmax over S dimension
         alpha = nn.functional.softmax(
-            t.cat([self.alpha_models[i](a_i) for i, a_i in enumerate(alpha_inputs)], dim=1), dim=-1
+            t.cat([self.alpha_models[i].forward(s) for i in range(len(tensors))], dim=1), dim=-1
         )
 
-        # [S x D ( x 1)] hadamard [S x D x T] -> [S x D x T] -> sum over all S [D x T]
+        # [S x D x 1] hadamard [S x D x T] -> [S x D x T] -> sum over all S => [D x T]
         s = t.mul(alpha.T.unsqueeze(-1), tensors).sum(dim=0)
 
         return s
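For reference, the combination step above can be reproduced in isolation. The sketch below mirrors the shape algebra of the new forward pass: each alpha model maps the State to one weight column per similarity matrix, the concatenated columns are softmax-normalized over the S dimension, and the stacked similarity matrices are reduced by a weighted sum. The example sizes S, D, T and the random inputs are made up for illustration, and `raw_alpha` stands in for the concatenated alpha-model outputs.

    import torch as t
    from torch import nn

    S, D, T = 3, 4, 5  # example sizes: similarity matrices, detections, tracks

    tensors = t.rand(S, D, T)  # stacked similarity matrices with values in [0, 1]
    raw_alpha = t.rand(D, S)   # stand-in for t.cat([alpha_model_i(state) for ...], dim=1)

    # [D x S] with softmax over the S dimension, so each detection's weights sum to 1
    alpha = nn.functional.softmax(raw_alpha, dim=-1)

    # [S x D x 1] hadamard [S x D x T] -> [S x D x T] -> sum over S => [D x T]
    combined = t.mul(alpha.T.unsqueeze(-1), tensors).sum(dim=0)

    assert combined.shape == (D, T)

In the module itself, the same computation is triggered by calling forward with the similarity matrices and a State, roughly `module(sim_1, sim_2, s=state)`; the exact call site lives in the tests modified by this commit.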
2 changes: 2 additions & 0 deletions dgs/utils/nn.py
@@ -59,6 +59,8 @@ def fc_linear(
     Returns:
         A sequential model containing ``N-1`` fully-connected layers.
     """
+    # pylint: disable=too-many-branches
+
     L = len(hidden_layers)
     # validate bias
     if isinstance(bias, bool):
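The body of `fc_linear` is mostly collapsed above; only the new pylint directive is visible. As a rough illustration of what the docstring describes (``N-1`` fully-connected layers built from a list of ``N`` layer sizes), a simplified builder could look like the sketch below. This is an assumed stand-in for illustration, not the actual implementation in dgs/utils/nn.py, whose signature and bias/activation handling may differ.

    from torch import nn

    def fc_linear_sketch(hidden_layers: list[int], bias: bool = True) -> nn.Sequential:
        """Chain N layer sizes into N-1 Linear layers (simplified stand-in for fc_linear)."""
        return nn.Sequential(
            *(
                nn.Linear(hidden_layers[i], hidden_layers[i + 1], bias=bias)
                for i in range(len(hidden_layers) - 1)
            )
        )

    model = fc_linear_sketch([128, 64, 1])  # two layers: 128 -> 64 and 64 -> 1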
