Skip to content

Commit

Permalink
Move Kalman changes to new branch
Browse files Browse the repository at this point in the history
Signed-off-by: Martin <[email protected]>
  • Loading branch information
bmmtstb committed Jan 24, 2024
1 parent 58c873f commit a81894a
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 17 deletions.
4 changes: 2 additions & 2 deletions dgs/models/pose_warping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from typing import Type

from dgs.utils.exceptions import InvalidParameterException
from .kalman import KalmanFilterWarpingModel
from .kalman import KalmanFilterWarpingModule
from .pose_warping import PoseWarpingModule


def get_pose_warping(name: str) -> Type[PoseWarpingModule]:
"""Given the name of one pose-warping module, return an instance."""
if name == "Kalman":
return KalmanFilterWarpingModel
return KalmanFilterWarpingModule
raise InvalidParameterException(f"Unknown pose warping module with name: {name}.")
86 changes: 82 additions & 4 deletions dgs/models/pose_warping/kalman.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,92 @@
"""
Implementation if kalman filter for basic pose warping
"""

import torch
from filterpy.common import Q_discrete_white_noise
from filterpy.kalman import KalmanFilter

from dgs.models.pose_warping.pose_warping import PoseWarpingModule
from dgs.models.states import PoseState
from dgs.models.states import PoseState, PoseStates
from dgs.utils.types import Config, FilePath, NodePath, Validations

KFWM_validations: Validations = {
"dim_x": ["dict", ("longer", 0)],
"dim_z": ["dict", ("longer", 0)],
"measures": ["optional", ("instance", "list")],
}


class KalmanFilterWarpingModule(PoseWarpingModule):
r"""Kalman Filter for pose and box warping using `torch_kalman <https://github.com/strongio/torch-kalman>`_ package.
Module Name
-----------
KalmanFilterWarping, KalmanFilterWarpingModule, or KFWM
Description
-----------
A basic Kalman filter using the `filterpy <https://filterpy.readthedocs.io/en/latest/index.html>`_ package.
Given the current state, predict the next one.
Will indirectly compute velocities and variances.
Params
------
dim_x: (dict[str, int])
For every measure, the number of state variables for the Kalman filter.
For example, if you are tracking the (x-y)-position of a person with 17 key-points, dim_x would be
:math:`2 \cdot 17 = 34`.
This is used to set the default size of P, Q, and u
dim_z: (dict[str, int])
Number of measurement inputs.
For example, if the measurement provides you with bbox-position as (x,y,w,h), dim_z would be 4.
measures: (list[str], default=["pose", "box"])
A list of measurement names to compute the Kalman Filter prediction from.
The variables will be extracted from a given DataSample object using `__getitem__(name)`.
"""

class KalmanFilterWarpingModel(PoseWarpingModule):
"""Kalman Filter for pose warping"""
model: dict[str, KalmanFilter]
measures: list[str]

def forward(self, pose: torch.Tensor, jcs: torch.Tensor, bbox: torch.Tensor) -> PoseState:
def __init__(self, config: Config, path: NodePath) -> None:
""""""
super().__init__(config, path)
self.validate_params(validations=KFWM_validations)
self.measures = self.params.get("measures", ["pose", "box"])
# create a basic KF for every measurement
for m in self.measures:
self.model[m] = KalmanFilter(
dim_x=self.params["dim_x"], dim_z=self.params["dim_z"], dim_u=self.params.get("dim_u", 0)
)

def forward(self, ps: PoseStates) -> PoseStates:
"""Given the current pose state, use the kalman filter to predict the next state."""
curr_state: PoseState = ps.get_states()
for m in self.measures:
prediction = self.model[m].predict(curr_state[str(m)])
return ...

def forward_pred(self) -> ...:
"""Get `torch_kalman` internal prediction"""

def load_weights(self, path: FilePath) -> None:
"""..."""
self.model.Q = Q_discrete_white_noise(dim=2, dt=0.1, var=0.13)
raise NotImplementedError

def train(self, inp: torch.Tensor, epochs: int = 8) -> None:
"""Train kalman filter prediction.
References:
See
Args:
inp:
epochs:
Returns:
"""
13 changes: 4 additions & 9 deletions dgs/models/pose_warping/pose_warping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
"""
from abc import abstractmethod

import torch

from dgs.models.module import BaseModule
from dgs.models.states import PoseState
from dgs.models.states import PoseStates


class PoseWarpingModule(BaseModule):
Expand All @@ -16,21 +14,18 @@ class PoseWarpingModule(BaseModule):
The goal of pose warping is to predict the next PoseState given information about the last (few) states.
"""

def __call__(self, *args, **kwargs) -> PoseState: # pragma: no cover
def __call__(self, *args, **kwargs) -> PoseStates: # pragma: no cover
"""see self.forward()"""
return self.forward(*args, **kwargs)

@abstractmethod
def forward(self, pose: torch.Tensor, jcs: torch.Tensor, bbox: torch.Tensor) -> PoseState:
def forward(self, ps: PoseStates) -> PoseStates:
"""
Args:
pose: History of poses per track as `torch.Tensor` of shape ``[EP x J x 2]``.
jcs: History of JCS per track as `torch.Tensor` of shape ``[EP x J x 1]``.
bbox: History of bboxes per track as `torch.Tensor` of shape ``[EP x 4]``.
ps: History of poses as `PoseState` object.
Returns:
The next pose state.
"""

raise NotImplementedError
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
autodoc_mock_imports = [
"alphapose",
"detector",
"filterpy",
"halpecocotools",
"opencv-python",
# "matplotlib",
Expand Down
2 changes: 2 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@ dependencies:
- -e ./dependencies/AlphaPose_Fork
# install PoseTrack21 evaluation toolkit
- -e ./dependencies/PoseTrack21/eval/posetrack21/
# numpy only Kalman Filter -> filterpy
- filterpy
6 changes: 4 additions & 2 deletions tests/requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ black
Cython
coverage[toml]
easydict
filterpy
gdown
imagesize
matplotlib
Expand All @@ -22,8 +23,9 @@ torchreid==0.2.5
tqdm
visdom

# -r ../dependencies/torchreid/requirements.txt
# -e ./dependencies/torchreid
# other git dependencies

# git+https://github.com/strongio/torch-kalman.git#egg=torch_kalman

# local dependencies

Expand Down

0 comments on commit a81894a

Please sign in to comment.