Skip to content

Commit

Permalink
Fix tracks_df_movement on 2D data (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-theodoro authored Jul 22, 2024
1 parent a21ab8b commit abd2a33
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
26 changes: 26 additions & 0 deletions ultrack/tracks/_test/test_tracks_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ def test_tracks_df_movement():
pd.testing.assert_frame_equal(result, expected)


def test_tracks_df_movement_2d():
# Sample test data
df = pd.DataFrame(
{
"track_id": [1, 1, 2, 2],
"t": [1, 2, 1, 2],
"y": [1, 2, 1, 2],
"x": [2, 3, 2, 2],
}
)

# Call the function
result = tracks_df_movement(df)

# Expected result
expected = pd.DataFrame(
{
"y": [0.0, 1.0, 0.0, 1.0],
"x": [0.0, 1.0, 0.0, 0.0],
}
)

# Assert that the result matches the expected dataframe
pd.testing.assert_frame_equal(result, expected)


def test_tracks_profile_matrix_one_track_one_attribute():
tracks_df = pd.DataFrame(
{"track_id": [1, 1, 1], "t": [0, 1, 2], "attribute_1": [10, 20, 30]}
Expand Down
15 changes: 11 additions & 4 deletions ultrack/tracks/stats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List
from typing import List, Optional

import numpy as np
import pandas as pd
Expand All @@ -13,7 +13,7 @@
def tracks_df_movement(
tracks_df: pd.DataFrame,
lag: int = 1,
cols: tuple[str, ...] = ("z", "y", "x"),
cols: Optional[tuple[str, ...]] = None,
) -> pd.DataFrame:
"""
Compute the displacement for track data across given time lags.
Expand All @@ -33,7 +33,8 @@ def tracks_df_movement(
Number of periods to compute the difference over. Default is 1.
cols : tuple[str, ...], optional
Columns to compute the displacement for. Default is ("z", "y", "x").
Columns to compute the displacement for. If not provided, it will try to
find any of ["z", "y", "x"] columns in the dataframe and use them.
Returns
-------
Expand Down Expand Up @@ -62,7 +63,13 @@ def tracks_df_movement(

tracks_df.sort_values(by=["track_id", "t"], inplace=True)

cols = list(cols)
if cols is None:
cols = []
for c in ["z", "y", "x"]:
if c in tracks_df.columns:
cols.append(c)
else:
cols = list(cols)

out = tracks_df.groupby("track_id", as_index=False)[cols].diff(periods=lag)
out.fillna(0, inplace=True)
Expand Down

0 comments on commit abd2a33

Please sign in to comment.