add custom near-k-accuracy metric
- renamed torch to t
- renamed tv_tensors to tvte

Signed-off-by: Martin <[email protected]>
bmmtstb committed Jul 18, 2024
1 parent 08cc6b1 commit 9b6ed09
Showing 45 changed files with 610 additions and 549 deletions.
3 changes: 3 additions & 0 deletions dgs/default_values.yaml
@@ -114,6 +114,9 @@ engine:
sim:
metric_kwargs: {}
image_key: "image_crop"
+acc_k_train: [ 1, 5, 10, 20, 50 ]
+acc_k_test: [ 1, 5, 10, 20, 50 ]


similarity:
torchreid:
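The new `acc_k_train` / `acc_k_test` lists set the cutoff values k at which the commit's near-k accuracy is evaluated; the metric implementation itself is not among the files shown here. Below is a minimal, hypothetical sketch of a rank-based accuracy at several k values, assuming a `(num_queries, num_gallery)` similarity matrix and integer ID labels — the function name and signature are illustrative, not the actual dgs API.

```python
# Hypothetical sketch only; the metric added by this commit is not shown in the excerpt above.
import torch as t

def rank_k_accuracy(
    similarity: t.Tensor,   # (num_queries, num_gallery) pairwise similarity scores
    query_ids: t.Tensor,    # (num_queries,) integer identity labels
    gallery_ids: t.Tensor,  # (num_gallery,) integer identity labels
    ks: tuple[int, ...] = (1, 5, 10, 20, 50),
) -> dict[int, float]:
    """For each k: fraction of queries whose correct ID is among the k most similar gallery entries."""
    order = similarity.argsort(dim=1, descending=True)      # gallery indices sorted by similarity
    matches = gallery_ids[order] == query_ids.unsqueeze(1)  # True where the ranked ID matches the query ID
    return {k: matches[:, :k].any(dim=1).float().mean().item() for k in ks}

# toy usage with random data
scores = rank_k_accuracy(t.rand(8, 100), t.randint(0, 10, (8,)), t.randint(0, 10, (100,)))
print(scores)
```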
28 changes: 14 additions & 14 deletions dgs/models/combine/combine.py
@@ -6,7 +6,7 @@

from abc import abstractmethod

-import torch
+import torch as t
from torch import nn

from dgs.models.module import BaseModule
@@ -40,7 +40,7 @@ def __call__(self, *args, **kwargs) -> any: # pragma: no cover
return self.forward(*args, **kwargs)

@abstractmethod
-def forward(self, *args, **kwargs) -> torch.Tensor:
+def forward(self, *args, **kwargs) -> t.Tensor:
raise NotImplementedError


@@ -59,7 +59,7 @@ class DynamicallyGatedSimilarities(CombineSimilaritiesModule):
It is possible that :math:`S_1` and :math:`S_2` have different shapes in at least one dimension.
"""

-def forward(self, *tensors, alpha: torch.Tensor = torch.tensor([0.5, 0.5]), **_kwargs) -> torch.Tensor:
+def forward(self, *tensors, alpha: t.Tensor = t.tensor([0.5, 0.5]), **_kwargs) -> t.Tensor:
"""The forward call of this module combines two weight matrices given a third importance weight :math:`\alpha`.
:math:`\alpha` describes how important s1 is, while :math:`(1- \alpha)` does the same for s2.
@@ -79,11 +79,11 @@ def forward(self, *tensors, alpha: torch.Tensor = torch.tensor([0.5, 0.5]), **_k
"""
if len(tensors) != 2:
raise ValueError(f"There should be exactly two matrices in the tensors argument, got {len(tensors)}")
-if any(not isinstance(t, torch.Tensor) for t in tensors):
+if any(not isinstance(tensor, t.Tensor) for tensor in tensors):
raise TypeError("All matrices should be torch (float) tensors.")
s1, s2 = tensors
-if (a_max := torch.max(alpha)) > 1.0 or torch.min(alpha) < 0.0:
-raise ValueError(f"alpha should lie in the range [0,1], but got [{torch.min(alpha)}, {a_max}]")
+if (a_max := t.max(alpha)) > 1.0 or t.min(alpha) < 0.0:
+raise ValueError(f"alpha should lie in the range [0,1], but got [{t.min(alpha)}, {a_max}]")

if alpha.ndim > 2:
alpha.squeeze_()
@@ -104,7 +104,7 @@ def forward(self, *tensors, alpha: torch.Tensor = torch.tensor([0.5, 0.5]), **_k
f"the first dimension has to equal the first dimension of s1 and s2 but got {alpha.shape}."
)

-return alpha * s1 + (torch.ones_like(alpha) - alpha) * s2
+return alpha * s1 + (t.ones_like(alpha) - alpha) * s2


@configure_torch_module
@@ -126,14 +126,14 @@ def __init__(self, config: Config, path: NodePath):

self.validate_params(static_alpha_validation)

-alpha = torch.tensor(self.params["alpha"], dtype=self.precision).reshape(-1)
+alpha = t.tensor(self.params["alpha"], dtype=self.precision).reshape(-1)
self.register_buffer("alpha_const", alpha)
self.len_alpha: int = len(alpha)

-if not torch.allclose(a_sum := torch.sum(torch.abs(alpha)), torch.tensor(1.0)): # pragma: no cover # redundant
+if not t.allclose(a_sum := t.sum(t.abs(alpha)), t.tensor(1.0)): # pragma: no cover # redundant
raise ValueError(f"alpha should sum to 1.0, but got {a_sum:.8f}")

-def forward(self, *tensors, **_kwargs) -> torch.Tensor:
+def forward(self, *tensors, **_kwargs) -> t.Tensor:
"""Given alpha from the configuration file and args of the same length,
multiply each alpha with each matrix and compute the sum.
@@ -154,24 +154,24 @@ def forward(self, *tensors, **_kwargs) -> torch.Tensor:
f"Unknown type for tensors, expected tuple of torch.Tensor but got {type(tensors)}"
)

-if any(not isinstance(t, torch.Tensor) for t in tensors):
+if any(not isinstance(tensor, t.Tensor) for tensor in tensors):
raise TypeError("All the values in args should be tensors.")

-if len(tensors) > 1 and any(t.shape != tensors[0].shape for t in tensors):
+if len(tensors) > 1 and any(tensor.shape != tensors[0].shape for tensor in tensors):
raise ValueError("The shapes of every tensor should match.")

if len(tensors) == 1 and self.len_alpha != 1:
# given a single already stacked tensor or a single valued alpha
tensors = tensors[0]
else:
-tensors = torch.stack(tensors)
+tensors = t.stack(tensors)

if self.len_alpha != 1 and len(tensors) != self.len_alpha:
raise ValueError(
f"The length of the tensors {len(tensors)} should equal the length of alpha {self.len_alpha}"
)

-return torch.tensordot(self.alpha_const, tensors.float(), dims=1)
+return t.tensordot(self.alpha_const, tensors.float(), dims=1)

def terminate(self) -> None: # pragma: no cover
del self.alpha, self.alpha_const, self.len_alpha
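For orientation, both combine modules above reduce to weighted sums of equally shaped similarity matrices: `DynamicallyGatedSimilarities` computes `alpha * s1 + (1 - alpha) * s2` with a per-call `alpha`, and `StaticAlphaCombine` contracts a fixed `alpha` vector against the stacked inputs via `tensordot`. A standalone sketch of the same arithmetic, outside the module classes:

```python
import torch as t

s1 = t.rand(4, 6)      # e.g. visual similarity, shape (num detections, num tracks)
s2 = t.rand(4, 6)      # e.g. pose or IoU similarity with the same shape
alpha = t.tensor(0.7)  # weight of s1; (1 - alpha) weights s2

dynamic = alpha * s1 + (t.ones_like(alpha) - alpha) * s2               # DynamicallyGatedSimilarities-style
static = t.tensordot(t.tensor([0.7, 0.3]), t.stack([s1, s2]), dims=1)  # StaticAlphaCombine-style
assert t.allclose(dynamic, static)  # both give 0.7 * s1 + 0.3 * s2
```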
6 changes: 3 additions & 3 deletions dgs/models/dataset/MOT.py
@@ -5,7 +5,7 @@
import re
from glob import glob

-import torch
+import torch as t
import torchvision.tv_tensors as tvte

from dgs.models.dataset.dataset import ImageDataset
@@ -178,15 +178,15 @@ def load_MOT_file(
continue

bboxes = tvte.BoundingBoxes(
-[anno[2:6] for anno in annos], format="XYWH", canvas_size=img_shape, dtype=torch.float32, device=device
+[anno[2:6] for anno in annos], format="XYWH", canvas_size=img_shape, dtype=t.float32, device=device
)
crop_paths = tuple(os.path.join(base_crop_path, f"{frame_id}_{anno[1]}{crop_info['imExt']}") for anno in annos)
states.append(
State(
bbox=bboxes,
filepath=file_paths,
crop_path=crop_paths,
-person_id=torch.tensor([anno[1] for anno in annos], device=device, dtype=torch.long),
+person_id=t.tensor([anno[1] for anno in annos], device=device, dtype=t.long),
frame_id=[frame_id] * len(annos),
validate=False,
)
22 changes: 10 additions & 12 deletions dgs/models/dataset/alphapose.py
@@ -29,8 +29,8 @@
"""

import imagesize
-import torch
-from torchvision import tv_tensors
+import torch as t
+from torchvision import tv_tensors as tvte

from dgs.models.dataset.dataset import BBoxDataset
from dgs.utils.files import read_json
@@ -70,14 +70,12 @@ def __init__(self, config: Config, path: NodePath) -> None:
def arbitrary_to_ds(self, a, idx: int) -> State:
"""Here `a` is one dict of the AP-JSON containing image_id, category_id, keypoints, score, box, and idx."""
keypoints, visibility = (
torch.tensor(a["keypoints"], dtype=torch.float32, device=self.device)
.reshape((1, -1, 3))
.split([2, 1], dim=-1)
t.tensor(a["keypoints"], dtype=t.float32, device=self.device).reshape((1, -1, 3)).split([2, 1], dim=-1)
)

return State(
filepath=a["full_img_path"],
-bbox=tv_tensors.BoundingBoxes(a["bboxes"], format="XYWH", canvas_size=self.canvas_size),
+bbox=tvte.BoundingBoxes(a["bboxes"], format="XYWH", canvas_size=self.canvas_size),
keypoints=keypoints,
person_id=a["idx"],
# additional values which are not required
@@ -87,20 +85,20 @@ def arbitrary_to_ds(self, a, idx: int) -> State:
)

def __getitems__(self, indices: list[int]) -> State:
-def stack_key(key: str) -> torch.Tensor:
-return torch.stack([torch.tensor(self.data[i][key], device=self.device) for i in indices])
+def stack_key(key: str) -> t.Tensor:
+return t.stack([t.tensor(self.data[i][key], device=self.device) for i in indices])

keypoints, visibility = (
-torch.tensor(
-torch.stack([torch.tensor(self.data[i]["keypoints"]).reshape((-1, 3)) for i in indices]),
+t.tensor(
+t.stack([t.tensor(self.data[i]["keypoints"]).reshape((-1, 3)) for i in indices]),
)
-.to(device=self.device, dtype=torch.float32)
+.to(device=self.device, dtype=t.float32)
.split([2, 1], dim=-1)
)
ds = State(
validate=False,
filepath=tuple(self.data[i]["full_img_path"] for i in indices),
-bbox=tv_tensors.BoundingBoxes(stack_key("bboxes"), format="XYWH", canvas_size=self.canvas_size),
+bbox=tvte.BoundingBoxes(stack_key("bboxes"), format="XYWH", canvas_size=self.canvas_size),
keypoints=keypoints,
person_id=stack_key("idx").int(),
# additional values which are not required
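The reshape-and-split pattern above turns AlphaPose's flat keypoint list into separate coordinate and visibility tensors. A standalone sketch with made-up values:

```python
import torch as t

# 3 joints, each stored flat as (x, y, score) -- made-up numbers
flat = [10.0, 20.0, 0.9, 30.0, 40.0, 0.8, 50.0, 60.0, 0.3]
keypoints, visibility = (
    t.tensor(flat, dtype=t.float32).reshape((1, -1, 3)).split([2, 1], dim=-1)
)
print(keypoints.shape)   # torch.Size([1, 3, 2]) -> (x, y) per joint
print(visibility.shape)  # torch.Size([1, 3, 1]) -> score/visibility per joint
```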
28 changes: 13 additions & 15 deletions dgs/models/dataset/dataset.py
@@ -7,11 +7,11 @@
from abc import ABC, abstractmethod
from typing import Union

-import torch
+import torch as t
import torchvision
import torchvision.transforms.v2 as tvt
from torch.utils.data import Dataset as TorchDataset
-from torchvision import tv_tensors
+from torchvision import tv_tensors as tvte
from torchvision.io import VideoReader
from torchvision.transforms.v2.functional import to_dtype

@@ -220,8 +220,8 @@ def get_image_crops(self, ds: State) -> None:

# State has length zero and image and local key points are just placeholders
if len(ds) == 0:
-ds.image_crop = tv_tensors.Image(torch.empty((0, 3, 1, 1)), device=ds.device)
-ds.keypoints_local = torch.empty(
+ds.image_crop = tvte.Image(t.empty((0, 3, 1, 1)), device=ds.device)
+ds.keypoints_local = t.empty(
(0, ds.J if "keypoints" in ds else 1, ds.joint_dim if "keypoints" in ds else 2), device=ds.device
)
return
@@ -239,7 +239,7 @@ def get_image_crops(self, ds: State) -> None:
for i in range(len(ds))
)
ds.load_image_crop()
-ds.keypoints_local = torch.stack([torch.load(fp.replace(".jpg", ".pt")) for fp in ds.crop_path]).to(
+ds.keypoints_local = t.stack([t.load(fp.replace(".jpg", ".pt")) for fp in ds.crop_path]).to(
device=self.device
)
return
@@ -255,30 +255,28 @@ def get_image_crops(self, ds: State) -> None:
force_reshape=True,
mode=self.params.get("image_mode", DEF_VAL["images"]["image_mode"]),
output_size=self.params.get("image_size", DEF_VAL["images"]["image_size"]),
-dtype=torch.uint8,
+dtype=t.uint8,
device=ds.device,
)
else:
-ds.image = load_image(ds.filepath, device=ds.device, dtype=torch.uint8)
+ds.image = load_image(ds.filepath, device=ds.device, dtype=t.uint8)

structured_input = {
"images": ds.image,
"box": ds.bbox,
"keypoints": (
ds.keypoints
if "keypoints" in ds
-else torch.zeros((ds.bbox.size(0), 1, 2), device=self.device, dtype=torch.float32)
+else t.zeros((ds.bbox.size(0), 1, 2), device=self.device, dtype=t.float32)
),
"output_size": self.params.get("crop_size", DEF_VAL["images"]["crop_size"]),
"mode": self.params.get("crop_mode", DEF_VAL["images"]["crop_mode"]),
}
new_state = self.transform_crop_resize()(structured_input)

-ds.image_crop = tv_tensors.Image(
-to_dtype(new_state["image"].to(device=self.device), dtype=torch.uint8, scale=True)
-)
+ds.image_crop = tvte.Image(to_dtype(new_state["image"].to(device=self.device), dtype=t.uint8, scale=True))
if "keypoints" in ds:
-ds.keypoints_local = new_state["keypoints"].to(dtype=torch.float32, device=self.device)
+ds.keypoints_local = new_state["keypoints"].to(dtype=t.float32, device=self.device)
assert "joint_weights" in ds.data, "visibility should be given"

@staticmethod
@@ -304,7 +302,7 @@ def transform_resize_image() -> tvt.Compose:
CustomResize(),
tvt.ClampBoundingBoxes(), # make sure to keep bboxes in their canvas_size
# tvt.SanitizeBoundingBoxes(), # clean up bboxes if available
-tvt.ToDtype({tv_tensors.Image: torch.float32, "others": None}, scale=True),
+tvt.ToDtype({tvte.Image: t.float32, "others": None}, scale=True),
]
)

@@ -339,7 +337,7 @@ def transform_crop_resize() -> tvt.Compose:
"""
return tvt.Compose(
[
-tvt.ConvertBoundingBoxFormat(format=tv_tensors.BoundingBoxFormat.XYWH),
+tvt.ConvertBoundingBoxFormat(format=tvte.BoundingBoxFormat.XYWH),
tvt.ClampBoundingBoxes(), # make sure the bboxes are clamped to start with
# tvt.SanitizeBoundingBoxes(), # clean up bboxes
CustomCropResize(), # crop the image at the four corners specified in bboxes
@@ -505,7 +503,7 @@ def __getitem__(self, idx: int) -> Union[State, list[State]]:
"""
# don't call .to(self.device), the DS should be created on the correct device!
self.data.seek(time_s=float(idx) / self.fps)
-frame = tv_tensors.Image(next(self.data)["data"], device=self.device)
+frame = tvte.Image(next(self.data)["data"], device=self.device)
s: State = self.arbitrary_to_ds(a=frame, idx=idx)
return s

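`transform_resize_image` and `transform_crop_resize` above build torchvision transforms-v2 pipelines that act on tv_tensors, so dtype and bounding-box operations are applied type-aware. A minimal self-contained example of that pattern, leaving out the dgs-specific `CustomResize` / `CustomCropResize` steps:

```python
import torch as t
import torchvision.transforms.v2 as tvt
from torchvision import tv_tensors as tvte

img = tvte.Image(t.randint(0, 256, (3, 480, 640), dtype=t.uint8))
box = tvte.BoundingBoxes([[10, 20, 100, 200]], format="XYWH", canvas_size=(480, 640))

pipeline = tvt.Compose(
    [
        tvt.ConvertBoundingBoxFormat(format=tvte.BoundingBoxFormat.XYXY),
        tvt.ClampBoundingBoxes(),  # keep boxes inside their canvas
        tvt.ToDtype({tvte.Image: t.float32, "others": None}, scale=True),
    ]
)
out_img, out_box = pipeline(img, box)  # tv_tensor types survive the pipeline
print(out_img.dtype, out_box.format)   # torch.float32, BoundingBoxFormat.XYXY
```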
18 changes: 9 additions & 9 deletions dgs/models/dataset/torchreid_pose_dataset.py
@@ -7,7 +7,7 @@
import warnings
from typing import Callable, Type, Union

-import torch
+import torch as t
import torchvision.transforms.v2 as tvt
from torch.utils.data import DataLoader as TorchDataLoader, Dataset as TorchDataset

@@ -43,7 +43,7 @@ class TorchreidPoseDataset(TorchreidDataset):

def __getitem__(self, index: int) -> dict[str, any]:
pose_path, pid, camid, dsetid = self.data[index]
-pose = torch.load(pose_path)
+pose = t.load(pose_path)
return {"img": pose, "pid": pid, "camid": camid, "dsetid": dsetid}

def show_summary(self) -> None:
@@ -222,7 +222,7 @@ def load_test(self) -> (dict[str, dict[str, any]], dict[str, dict[str, any]]):
root=self.root, mode="gallery", transform=self.transform_te, **self.params
)
# build gallery loader
-test_loader[dataset]["gallery"] = torch.utils.data.DataLoader(
+test_loader[dataset]["gallery"] = t.utils.data.DataLoader(
gallery_set,
batch_size=self.params["batch_size_test"],
shuffle=False,
@@ -292,13 +292,13 @@ def build_transforms(
ValueError: If ``transforms`` is an invalid object or contains invalid transform names.
"""

-def random_move(x: torch.Tensor) -> torch.Tensor:
+def random_move(x: t.Tensor) -> t.Tensor:
"""Move a torch tensor by a little bit in random directions using a normal distribution ~N(0,1)."""
-return x + torch.randn_like(x)
+return x + t.randn_like(x)

-def random_resize(x: torch.Tensor) -> torch.Tensor:
+def random_resize(x: t.Tensor) -> t.Tensor:
"""Resize the torch tensor by a little bit, up and down. Ranges from 0.95 to 1.05."""
-return x * torch.tensor([1.0]).uniform_(0.95, 1.05)
+return x * t.tensor([1.0]).uniform_(0.95, 1.05)

if transforms is None:
transforms = []
@@ -308,7 +308,7 @@ def random_resize(x: torch.Tensor) -> torch.Tensor:
if not isinstance(transforms, list):
raise ValueError(f"Transforms must be a list of strings, but found to be {type(transforms)}")

-train_transforms = [tvt.ToTensor(), tvt.ToDtype(dtype=torch.float32)]
+train_transforms = [tvt.ToTensor(), tvt.ToDtype(dtype=t.float32)]

for transform in transforms:
if transform == "random_flip":
@@ -340,6 +340,6 @@ def random_resize(x: torch.Tensor) -> torch.Tensor:
else:
raise ValueError(f"Unknown transform: {transform}")

-test_transforms = [tvt.ToTensor(), tvt.ToDtype(dtype=torch.float32)]
+test_transforms = [tvt.ToTensor(), tvt.ToDtype(dtype=t.float32)]

return tvt.Compose(train_transforms), tvt.Compose(test_transforms)
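The two pose augmentations above behave differently: `random_move` jitters every joint coordinate independently, while `random_resize` scales the whole pose by one factor, so proportions are preserved. A standalone usage sketch with a made-up pose tensor:

```python
import torch as t

def random_move(x: t.Tensor) -> t.Tensor:
    """Jitter every coordinate independently with ~N(0, 1) noise."""
    return x + t.randn_like(x)

def random_resize(x: t.Tensor) -> t.Tensor:
    """Scale the whole pose by one factor drawn uniformly from [0.95, 1.05]."""
    return x * t.tensor([1.0]).uniform_(0.95, 1.05)

pose = t.rand(17, 2) * 100  # made-up pose: 17 joints with (x, y) pixel coordinates
print(random_resize(random_move(pose)).shape)  # torch.Size([17, 2])
```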
10 changes: 5 additions & 5 deletions dgs/models/dgs/dgs.py
@@ -2,7 +2,7 @@
Base class for a torch module that contains the heart of the dynamically gated similarity tracker.
"""

-import torch
+import torch as t
from torch import nn

from dgs.models.combine import CombineSimilaritiesModule, get_combine_module
@@ -90,7 +90,7 @@ def __init__(self, config: Config, path: NodePath):
def __call__(self, *args, **kwargs) -> any: # pragma: no cover
return self.forward(*args, **kwargs)

-def forward(self, ds: State, target: State) -> torch.Tensor:
+def forward(self, ds: State, target: State) -> t.Tensor:
"""Given a State containing the current detections and a target, compute the similarity between every pair.
Returns:
@@ -102,12 +102,12 @@ def forward(self, ds: State, target: State) -> torch.Tensor:
results = [self.similarity_softmax(m(ds, target)) for m in self.sim_mods]

# combine and possibly compute softmax []
-combined: torch.Tensor = self.combined_softmax(self.combine(*results))
+combined: t.Tensor = self.combined_softmax(self.combine(*results))

# add a number of columns for the empty / new tracks equal to the length of the input
# every input should be allowed to get assigned to a new track
-new_track = torch.zeros((nof_det, nof_det), dtype=self.precision, device=self.device)
-return torch.cat([combined, new_track], dim=-1)
+new_track = t.zeros((nof_det, nof_det), dtype=self.precision, device=self.device)
+return t.cat([combined, new_track], dim=-1)

def terminate(self) -> None:
"""Terminate the DGS module and delete the torch modules."""
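The `forward` above appends an all-zero `(N, N)` block to the combined `(N, T)` similarity matrix so that each of the N detections can also be assigned to a fresh track. A shape-only sketch with made-up sizes:

```python
import torch as t

nof_det, nof_tracks = 4, 7
combined = t.softmax(t.rand(nof_det, nof_tracks), dim=-1)  # combined similarity per detection/track pair
new_track = t.zeros((nof_det, nof_det))                    # one extra "new track" column per detection
assignment = t.cat([combined, new_track], dim=-1)
print(assignment.shape)  # torch.Size([4, 11]) -> nof_tracks + nof_det columns
```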