Skip to content

Commit

Permalink
Work on engine.py and VisualEmbeddingEngine
Browse files Browse the repository at this point in the history
different engines for different purposes
used own cmc function for cmc computation
used multiclass_auprc from torcheval for map
cleaned-up params of the Engine Class
in posetrack21.py renamed "path" to "json_path"
in dataset.py added params and set default values.
saving and printing results

Signed-off-by: Martin <[email protected]>
  • Loading branch information
bmmtstb committed Jan 24, 2024
1 parent eaadd85 commit 58c873f
Show file tree
Hide file tree
Showing 19 changed files with 537 additions and 188 deletions.
27 changes: 22 additions & 5 deletions configs/train_pose.yaml → configs/train_pose_embedding.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
name: "Train Pose-Embeddings"
description: "Train the embeddings for a pose-based embedding generator using the dgs module."
description: "Train the embeddings for a pose-based embedding generator
using the embedding generators of the dgs module."

device: "cuda"
print_prio: "normal"
Expand All @@ -8,16 +9,32 @@ is_training: on
train:
batch_size: 32
epochs: 1
loss: "dummy"
metric: "dummy"
loss: "TorchreidCrossEntropyLoss"
loss_kwargs:
num_classes: 5474
optimizer: "Adam"

test:
batch_size: 128
metric: "CosineDistance"

# Modules
# #### #
# DATA #
# #### #

dataset:
dataset_train:
module_name: "PoseTrack21"
dataset_path: "./data/PoseTrack21/"
json_path: "./posetrack_data/train/000001_bonn_train.json"

dataset_test:
module_name: "PoseTrack21"
dataset_path: "./data/PoseTrack21/"
json_path: "./posetrack_data/train/000001_bonn_train.json"

# ####### #
# MODULES #
# ####### #

pose_embedding_generator:
module_name: "LinearPBEG" # see `dgs.models.embedding_generator`
Expand Down
6 changes: 2 additions & 4 deletions dgs/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
cfg.train.metric = "dummy"
cfg.train.optimizer = "Adam"
cfg.train.log_dir = "./results/"
cfg.train.evaluations = ["embeddings"] # fixme

# ####### #
# Testing #
Expand All @@ -49,16 +50,13 @@
cfg.dataset.path = "./posetrack_data_fast/val/"
# cfg.dataset.path = "./posetrack_data_fast/val/000342_mpii_test.json"

cfg.dataset.crop_mode = "zero-pad"
cfg.dataset.crop_size = (256, 256) #

# ################ #
# Visual Embedding #
# ################ #
cfg.visual_embedding_generator = EasyDict()
cfg.visual_embedding_generator.module_name = "torchreid" # module name
cfg.visual_embedding_generator.model_name = "osnet_ain_x1_0" # torchreid model name (if applicable)
cfg.visual_embedding_generator.embedding_size = 128
cfg.visual_embedding_generator.embedding_size = 512
cfg.visual_embedding_generator.weights = (
"./weights/osnet_ain_x1_0_msmt17_256x128_amsgrad_ep50_lr0.0015_coslr_b64_fb10_softmax_labsmth_flip_jitter.pth"
)
Expand Down
37 changes: 34 additions & 3 deletions dgs/models/dataset/alphapose.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from dgs.models.dataset.dataset import BaseDataset
from dgs.models.states import DataSample
from dgs.utils.files import read_json
from dgs.utils.types import Config, NodePath, Validations
from dgs.utils.types import Config, ImgShape, NodePath, Validations

ap_load_validations: Validations = {"path": ["str", "file exists in project", ("endswith", ".json")]}

Expand All @@ -55,11 +55,17 @@ def __init__(self, config: Config, path: NodePath) -> None:
else:
raise NotImplementedError(f"JSON file {self.params['path']} does not contain known instances.")

canvas_sizes: set[ImgShape] = set()

for detection in json:
path = self.get_path_in_dataset(detection["image_id"])
detection["full_img_path"] = tuple([path])
# imagesize.get() output = (w,h) and our own format = (h, w)
detection["canvas_size"] = imagesize.get(path)[::-1]
canvas_sizes.add(imagesize.get(path)[::-1])

if len(canvas_sizes) > 1:
raise ValueError(f"Expected all images to have the same shape, but found {canvas_sizes}")
self.canvas_size: ImgShape = canvas_sizes.pop()

def arbitrary_to_ds(self, a) -> DataSample:
"""Here `a` is one dict of the AP-JSON containing image_id, category_id, keypoints, score, box, and idx."""
Expand All @@ -71,11 +77,36 @@ def arbitrary_to_ds(self, a) -> DataSample:

return DataSample(
filepath=a["full_img_path"],
bbox=tv_tensors.BoundingBoxes(a["bboxes"], format="XYWH", canvas_size=a["canvas_size"]),
bbox=tv_tensors.BoundingBoxes(a["bboxes"], format="XYWH", canvas_size=self.canvas_size),
keypoints=keypoints,
person_id=a["idx"],
# additional values which are not required
image_id=a["image_id"],
joint_weight=visibility,
person_score=a["score"], # fixme divide by 6 for COCO, by 1 for MPII...?
)

def __getitems__(self, indices: list[int]) -> DataSample:
def stack_key(key: str) -> torch.Tensor:
return torch.stack([torch.tensor(self.data[i][key], device=self.device) for i in indices])

keypoints, visibility = (
torch.tensor(
torch.stack([torch.tensor(self.data[i]["keypoints"]).reshape((-1, 3)) for i in indices]),
)
.to(device=self.device, dtype=torch.float32)
.split([2, 1], dim=-1)
)
ds = DataSample(
validate=False,
filepath=tuple(self.data[i]["full_img_path"] for i in indices),
bbox=tv_tensors.BoundingBoxes(stack_key("bboxes"), format="XYWH", canvas_size=self.canvas_size),
keypoints=keypoints,
person_id=stack_key("idx").int(),
# additional values which are not required
joint_weight=visibility,
image_id=stack_key("image_id").int(),
)
# make sure to get image crop for batch
self.get_image_crop(ds)
return ds
31 changes: 25 additions & 6 deletions dgs/models/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from dgs.utils.types import Config, FilePath, NodePath, Validations # pylint: disable=unused-import

base_dataset_validations: Validations = {
"crop_mode": ["str", ("in", CustomToAspect.modes)],
"crop_size": [("instance", tuple), ("len", 2), lambda x: x[0] > 0 and x[1] > 0],
"dataset_path": ["str", "folder exists in project"],
"crop_mode": ["optional", "str", ("in", CustomToAspect.modes)],
"crop_size": ["optional", ("instance", tuple), ("len", 2), lambda x: x[0] > 0 and x[1] > 0],
"dataset_path": ["str", ("or", (("folder exists in project",), ("folder exists",)))],
}


Expand Down Expand Up @@ -124,6 +124,20 @@ class BaseDataset(BaseModule, TorchDataset):
The other option is to have batches with slightly different sizes.
The DataLoader loads a fixed batch of images, the Dataset computes the resulting detections and returns those.
Params
------
dataset_path (FilePath):
Path to the directory of the dataset.
The value has to either be a local project path, or a valid absolute path.
crop_mode (str, optional):
The mode for image cropping used when calling :func:``self.get_image_crop``.
Value has to be in CustomToAspect.modes.
Default "zero-pad".
crop_size (tuple[int, int], optional):
The size, the resized image should have.
Default (256, 256).
"""

data: list
Expand Down Expand Up @@ -153,10 +167,15 @@ def __getitem__(self, idx: int) -> DataSample:
The pre-computed backbone outputs.
"""
sample: DataSample = self.arbitrary_to_ds(self.data[idx]).to(self.device)
if "image_crop" not in sample or "local_coordinates" not in sample:
if "image_crop" not in sample:
self.get_image_crop(sample)
return sample

@abstractmethod
def __getitems__(self, indices: list[int]) -> DataSample:
"""Given a list of indices, return a single DataSample object containing them all."""
raise NotImplementedError

@abstractmethod
def arbitrary_to_ds(self, a) -> DataSample:
"""Given a single arbitrary data sample, convert it to a DataSample object."""
Expand All @@ -175,8 +194,8 @@ def get_image_crop(self, ds: DataSample) -> None:
"image": ds.image,
"box": ds.bbox,
"keypoints": ds.keypoints,
"output_size": self.params["crop_size"],
"mode": self.params["crop_mode"],
"output_size": self.params.get("crop_size", (256, 256)),
"mode": self.params.get("crop_mode", "zero-pad"),
}
new_sample = self.transform_crop_resize()(structured_input)
ds.image_crop = new_sample["image"]
Expand Down
33 changes: 27 additions & 6 deletions dgs/models/dataset/posetrack21.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Do not allow import of 'PoseTrack21' base dataset
__all__ = ["validate_pt21_json", "get_pose_track_21", "PoseTrack21JSON", "PoseTrack21Torchreid"]

pt21_json_validations: Validations = {"path": [None]}
pt21_json_validations: Validations = {"json_path": []}


def validate_pt21_json(json: dict) -> None:
Expand Down Expand Up @@ -209,7 +209,7 @@ def extract_all_bboxes(
)


def get_pose_track_21(config: Config, path: NodePath) -> TorchDataset:
def get_pose_track_21(config: Config, path: NodePath) -> Union[BaseDataset, TorchDataset]:
"""Load PoseTrack JSON files.
The path parameter can be one of the following:
Expand Down Expand Up @@ -301,12 +301,31 @@ class PoseTrack21(BaseDataset):
def __init__(self, config: Config, path: NodePath) -> None:
super().__init__(config=config, path=path)

def __getitems__(self, indices: list[int]) -> DataSample:
raise NotImplementedError

def arbitrary_to_ds(self, a) -> DataSample:
raise NotImplementedError


class PoseTrack21JSON(BaseDataset):
"""Load a single precomputed json file."""
"""Load a single precomputed json file from the |PT21| dataset.
Params
------
json_path (FilePath):
The path to the json file, either from within the ``dataset_path`` directory, or as absolute path.
Important Inherited Params
--------------------------
dataset_path (FilePath):
Path to the directory of the dataset.
The value has to either be a local project path, or a valid absolute path.
"""

def __init__(self, config: Config, path: NodePath, json_path: FilePath = None) -> None:
super().__init__(config=config, path=path)
Expand All @@ -315,10 +334,13 @@ def __init__(self, config: Config, path: NodePath, json_path: FilePath = None) -

# validate and get the path to the json
if json_path is None:
json_path: FilePath = self.get_path_in_dataset(self.params["path"])
json_path: FilePath = self.get_path_in_dataset(self.params["json_path"])
else:
if self.print("debug"):
print(f"Used given json_path '{json_path}' instead of self.params['path'] '{self.params['path']}'")
print(
f"Used given json_path '{json_path}' "
f"instead of self.params['json_path'] '{self.params['json_path']}'"
)

# validate and get json data
json: dict[str, list[dict[str, any]]] = read_json(json_path)
Expand All @@ -331,7 +353,6 @@ def __init__(self, config: Config, path: NodePath, json_path: FilePath = None) -
}

# imagesize.get() output = (w,h) and our own format = (h, w)

self.img_shape: ImgShape = imagesize.get(list(self.map_img_id_path.values())[0])[::-1]

if any(imagesize.get(path)[::-1] != self.img_shape for img_id, path in self.map_img_id_path.items()):
Expand Down
18 changes: 10 additions & 8 deletions dgs/models/embedding_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from .pose_based import KeyPointConvolutionPBEG, LinearPBEG
from .torchreid import TorchreidModel

EMBEDDING_GENERATORS: dict[str, Type[EmbeddingGeneratorModule]] = {
"torchreid": TorchreidModel,
"LinearPBEG": LinearPBEG,
"KeyPointConvolutionPBEG": KeyPointConvolutionPBEG,
}


def get_embedding_generator(name: str) -> Type[EmbeddingGeneratorModule]:
"""Given the name of one dataset, return an instance."""
if name == "torchreid":
return TorchreidModel
if name == "LinearPBEG":
return LinearPBEG
if name == "KeyPointConvolutionPBEG":
return KeyPointConvolutionPBEG
raise InvalidParameterException(f"Unknown embedding generator with name: {name}.")
"""Given the name of one dataset, return the type."""
if name not in EMBEDDING_GENERATORS:
raise InvalidParameterException(f"Unknown embedding generator with name: {name}.")
return EMBEDDING_GENERATORS[name]
15 changes: 11 additions & 4 deletions dgs/models/embedding_generator/embedding_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,33 @@ class EmbeddingGeneratorModule(BaseModule):
Params
------
embedding_size: (int)
embedding_size (int):
The size of the embedding.
It does not necessarily have to match other embedding sizes.
This size does not necessarily have to match other embedding sizes.
nof_classes (int):
The number of classes in the dataset.
Used during training to predict the id.
"""

embedding_size: int
"""The size of the embedding. It Does not necessarily have to match the size of other (different) embeddings."""

nof_classes: int
"""The number of classes in the dataset / embedding."""

def __init__(self, config: Config, path: NodePath):
super().__init__(config, path)
self.validate_params(embedding_validations)

self.embedding_size = self.params["embedding_size"]
self.nof_classes = self.params["nof_classes"]

def __call__(self, *args, **kwargs) -> torch.Tensor: # pragma: no cover
def __call__(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
"""see self.forward()"""
return self.forward(*args, **kwargs)

@abstractmethod
def forward(self, *args, **kwargs) -> torch.Tensor:
def forward(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
"""
Predict next outputs using this Re-ID model.
Expand Down
Loading

0 comments on commit 58c873f

Please sign in to comment.