Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
added requirements.txt and startet using pip only
removed version from dgs.__init__ file and use importlib to obtain it

Signed-off-by: Martin <[email protected]>
  • Loading branch information
bmmtstb committed Feb 2, 2024
1 parent a60032f commit f9c9575
Show file tree
Hide file tree
Showing 34 changed files with 143 additions and 44 deletions.
8 changes: 0 additions & 8 deletions dgs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
"""
Tracking via Dynamically Gated Similarities
TODO Provide more package information
"""

__version__ = "0.0.2"
__author__ = "Martin Steinborn"
__homepage__ = "https://bmmtstb.github.io/dynamically-gated-similarities/"
__description__ = 'Code for Paper "Tracking with Dynamically Gated Similarities"'
__url__ = "https://github.com/bmmtstb/dynamically-gated-similarities"
3 changes: 2 additions & 1 deletion dgs/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
These values are used, iff the given config does not set own values.
"""

import torch
from easydict import EasyDict

Expand All @@ -12,7 +13,7 @@
# General #
# ####### #

# cfg.name = "DEFAULT" # shouldn't be set, to force user to give it a name
cfg.name = "DEFAULT"
cfg.print_prio = "normal"
cfg.working_memory_size = 30

Expand Down
1 change: 0 additions & 1 deletion dgs/models/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ class EngineModule(BaseModule):
lr_sched: list[optim.lr_scheduler.LRScheduler]
"""The learning-rate sheduler(s) can be changed by setting ``engine.lr_scheduler = [..., ...]``."""

@torch.enable_grad
def __init__(
self,
config: Config,
Expand Down
3 changes: 3 additions & 0 deletions dgs/models/engine/visual_sim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class VisualSimilarityEngine(EngineModule):
val_dl: TorchDataLoader
"""The torch DataLoader containing the validation (query) data."""

# The heart of the project might get a little larger...
# pylint: disable=too-many-arguments,too-many-locals

def __init__(
self,
config: Config,
Expand Down
1 change: 1 addition & 0 deletions dgs/models/loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Functions to load and manage torch loss functions.
"""

from typing import Type, Union

from torch import nn
Expand Down
6 changes: 4 additions & 2 deletions dgs/models/metric.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""
Methods for handling the computation of distances and other metrics.
"""

import warnings
from typing import Type, Union

import torch
from torch import nn
from torch.linalg import vector_norm

from dgs.utils.types import Metric

Expand Down Expand Up @@ -71,8 +73,8 @@ def compute_cmc(
def custom_cosine_similarity(input1: torch.Tensor, input2: torch.Tensor, dim: int, eps: float) -> torch.Tensor:
"""See https://github.com/pytorch/pytorch/issues/104564#issuecomment-1625348908"""
# get normalization value
t1_div = torch.linalg.vector_norm(input1, dim=dim, keepdims=True)
t2_div = torch.linalg.vector_norm(input2, dim=dim, keepdims=True)
t1_div = vector_norm(input1, dim=dim, keepdim=True) # pylint: disable=not-callable
t2_div = vector_norm(input2, dim=dim, keepdim=True) # pylint: disable=not-callable

t1_div = t1_div.clone()
t2_div = t2_div.clone()
Expand Down
26 changes: 13 additions & 13 deletions dgs/models/module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Base model class as lowest building block for dynamic modules
"""

import inspect
from abc import ABC, abstractmethod
from functools import wraps
Expand All @@ -17,7 +18,7 @@
module_validations: Validations = {
"name": ["str", ("longer", 2)],
"print_prio": [("in", PRINT_PRIORITY)],
"device": ["str", ("or", (("in", ["cuda", "cpu"]), ("instance", torch.device)))],
"device": [("or", (("in", ["cuda", "cpu"]), ("type", torch.device)))],
"gpus": ["optional", lambda gpus: isinstance(gpus, list) and all(isinstance(gpu, int) for gpu in gpus)],
"num_workers": ["optional", "int", ("gte", 0)],
"sp": [("instance", bool)],
Expand Down Expand Up @@ -97,7 +98,7 @@ class BaseModule(ABC):
"""

@enable_keyboard_interrupt
def __init__(self, config: Config, path: NodePath, validate_base: bool = False):
def __init__(self, config: Config, path: NodePath):
self.config: Config = config
self.params: Config = get_sub_config(config, path)
self._path: NodePath = path
Expand All @@ -106,16 +107,13 @@ def __init__(self, config: Config, path: NodePath, validate_base: bool = False):
if not self.config["gpus"]:
self.config["gpus"] = [-1]
elif isinstance(self.config["gpus"], str):
self.config["gpus"] = (
[int(i) for i in self.config["gpus"].split(",")] if torch.cuda.device_count() >= 1 else [-1]
)
self.config["gpus"] = [int(i) for i in self.config["gpus"].split(",")]

# set default value of num_workers
if not self.config["num_workers"]:
self.config["num_workers"] = 0

# validate config when calling BaseModule class and flag is True
if validate_base:
self.validate_params(module_validations, "config")
self.validate_params(module_validations, "config")

def validate_params(self, validations: Validations, attrib_name: str = "params") -> None:
"""Given per key validations, validate this module's parameters.
Expand Down Expand Up @@ -173,7 +171,9 @@ def validate_params(self, validations: Validations, attrib_name: str = "params")
if param_name not in getattr(self, attrib_name):
if "optional" in list_of_validations:
continue # value is optional and does not exist, skip validation
raise InvalidParameterException(f"{param_name} is expected to be in module {self.__class__.__name__}")
raise InvalidParameterException(
f"'{param_name}' is expected to be in module '{self.__class__.__name__}'"
)

# it is now safe to get the value
value = getattr(self, attrib_name)[param_name]
Expand Down Expand Up @@ -202,13 +202,13 @@ def validate_params(self, validations: Validations, attrib_name: str = "params")
if validate_value(value=value, data=data, validation=validation_name):
continue
raise InvalidParameterException(
f"In module {self.__class__.__name__}, parameter {param_name} is not valid. "
f"Value is {value} and is expected to have validation(s) {list_of_validations}."
f"In module '{self.__class__.__name__}', parameter '{param_name}' is not valid. "
f"Value is '{value}' and is expected to have validation(s) '{list_of_validations}'."
)
# no other case was true
raise ValidationException(
f"Validation is expected to be callable or tuple, but is {type(validation)}. "
f"Current module: {self.__class__.__name__}, Parameter: {param_name}"
f"Validation is expected to be callable or tuple, but is '{type(validation)}'. "
f"Current module: '{self.__class__.__name__}', Parameter: '{param_name}'"
)

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions dgs/models/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Functions to load and manage torch optimizers.
"""

from typing import Type, Union

from torch import optim
Expand Down
1 change: 1 addition & 0 deletions dgs/models/pose_warping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Module to warp a given pose, pose-state -
or more generally to predict the next pose of a person given previous time steps.
"""

from typing import Type

from dgs.utils.exceptions import InvalidParameterException
Expand Down
1 change: 1 addition & 0 deletions dgs/models/pose_warping/kalman.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Implementation if kalman filter for basic pose warping
"""

import torch

from dgs.models.pose_warping.pose_warping import PoseWarpingModule
Expand Down
1 change: 1 addition & 0 deletions dgs/models/pose_warping/pose_warping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Helpers and models for warping the pose-state of a track into the next time frame.
"""

from abc import abstractmethod

import torch
Expand Down
1 change: 1 addition & 0 deletions dgs/models/similarity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Modules for handling similarity functions or other models that return similarity scores between two (or more) inputs.
"""

from typing import Type

from dgs.utils.exceptions import InvalidParameterException
Expand Down
1 change: 1 addition & 0 deletions dgs/models/similarity/combined.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Models that combine the results of two or more similarity matrices.
"""

from abc import abstractmethod

import torch
Expand Down
1 change: 1 addition & 0 deletions dgs/models/similarity/pose_similarity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Modules for computing the similarity between two poses.
"""

import torch

from dgs.models.similarity.similarity import SimilarityModule
Expand Down
1 change: 1 addition & 0 deletions dgs/models/similarity/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Similarity functions compute a similarity score "likeness" between two equally sized inputs.
"""

from typing import Callable

import torch
Expand Down
1 change: 1 addition & 0 deletions dgs/models/states.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
definitions and helpers for pose-state(s)
"""

from collections import UserDict
from typing import Union

Expand Down
1 change: 1 addition & 0 deletions dgs/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Contains functions for validating configuration and parameter of modules.
"""

from copy import deepcopy
from typing import Union

Expand Down
1 change: 1 addition & 0 deletions dgs/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Predefined constants that will not change and might be used at different places.
"""

import os

import torch
Expand Down
1 change: 1 addition & 0 deletions dgs/utils/files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Contains helper functions for loading and interacting with files and paths.
"""

import json
import os

Expand Down
1 change: 1 addition & 0 deletions dgs/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RGB Images in cv2 have a shape of ``[h x w x C]`` and the channels are in order GBR.
Grayscale Images in cv2 have a shape of ``[h x w]``.
"""

from typing import Iterable, Union

import torch
Expand Down
1 change: 1 addition & 0 deletions dgs/utils/timer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Models, functions and helpers for timing operations.
"""

import time
from collections import UserList
from datetime import timedelta
Expand Down
1 change: 1 addition & 0 deletions dgs/utils/torchtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Tools for handling recurring torch tasks. Mostly taken from the `torchreid package
<https://kaiyangzhou.github.io/deep-person-reid/_modules/torchreid/utils/torchtools.html#load_pretrained_weights>`_
"""

import os
import pickle
import shutil
Expand Down
1 change: 1 addition & 0 deletions dgs/utils/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
definition of regularly used types
"""

from typing import Callable, Union

import torch
Expand Down
2 changes: 2 additions & 0 deletions dgs/utils/validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Utilities for validating recurring data types.
"""

import os
from collections.abc import Iterable, Sized
from typing import Union
Expand Down Expand Up @@ -30,6 +31,7 @@
"callable": (lambda x, _: callable(x)),
"instance": isinstance, # alias
"isinstance": isinstance,
"type": (lambda x, d: isinstance(d, type) and isinstance(x, d)),
"iterable": (lambda x, _: isinstance(x, Iterable)),
"sized": (lambda x, _: isinstance(x, Sized)),
# number
Expand Down
1 change: 1 addition & 0 deletions dgs/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Matplotlib uses a different order for the images: `[B x H x W x C]`.
At least, the channel for matplotlib is RGB too.
"""

from typing import Union

import numpy as np
Expand Down
10 changes: 6 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import os
import sys
from importlib.metadata import PackageNotFoundError, version as ilib_version

sys.path.insert(0, os.path.abspath(".."))

Expand All @@ -23,10 +24,11 @@
copyright = "2023, Martin Steinborn"
author = "Martin Steinborn"

version_file = "../dgs/__init__.py"
with open(version_file, "r") as f:
exec(compile(f.read(), version_file, "exec"))
__version__ = locals()["__version__"]
try:
__version__ = ilib_version("dynamically_gated_similarities")
except PackageNotFoundError:
__version__ = "0.0.0"

# The short X.Y version
version = __version__[: __version__.find(".", __version__.find(".") + 1)]
# The full version, including alpha/beta/rc tags
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dynamically_gated_similarities"
version = "0.0.2"
version = "0.0.3"
authors = [
{ name = "Martin Steinborn", email = "[email protected]" },
]
Expand Down
Loading

0 comments on commit f9c9575

Please sign in to comment.