diff --git a/ultrack/tracks/_test/test_tracks_stats.py b/ultrack/tracks/_test/test_tracks_stats.py index 0c43195..9a4903a 100644 --- a/ultrack/tracks/_test/test_tracks_stats.py +++ b/ultrack/tracks/_test/test_tracks_stats.py @@ -80,6 +80,32 @@ def test_tracks_df_movement(): pd.testing.assert_frame_equal(result, expected) +def test_tracks_df_movement_2d(): + # Sample test data + df = pd.DataFrame( + { + "track_id": [1, 1, 2, 2], + "t": [1, 2, 1, 2], + "y": [1, 2, 1, 2], + "x": [2, 3, 2, 2], + } + ) + + # Call the function + result = tracks_df_movement(df) + + # Expected result + expected = pd.DataFrame( + { + "y": [0.0, 1.0, 0.0, 1.0], + "x": [0.0, 1.0, 0.0, 0.0], + } + ) + + # Assert that the result matches the expected dataframe + pd.testing.assert_frame_equal(result, expected) + + def test_tracks_profile_matrix_one_track_one_attribute(): tracks_df = pd.DataFrame( {"track_id": [1, 1, 1], "t": [0, 1, 2], "attribute_1": [10, 20, 30]} diff --git a/ultrack/tracks/stats.py b/ultrack/tracks/stats.py index d66a415..379edfe 100644 --- a/ultrack/tracks/stats.py +++ b/ultrack/tracks/stats.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Optional import numpy as np import pandas as pd @@ -13,7 +13,7 @@ def tracks_df_movement( tracks_df: pd.DataFrame, lag: int = 1, - cols: tuple[str, ...] = ("z", "y", "x"), + cols: Optional[tuple[str, ...]] = None, ) -> pd.DataFrame: """ Compute the displacement for track data across given time lags. @@ -33,7 +33,8 @@ def tracks_df_movement( Number of periods to compute the difference over. Default is 1. cols : tuple[str, ...], optional - Columns to compute the displacement for. Default is ("z", "y", "x"). + Columns to compute the displacement for. If not provided, it will try to + find any of ["z", "y", "x"] columns in the dataframe and use them. Returns ------- @@ -62,7 +63,13 @@ def tracks_df_movement( tracks_df.sort_values(by=["track_id", "t"], inplace=True) - cols = list(cols) + if cols is None: + cols = [] + for c in ["z", "y", "x"]: + if c in tracks_df.columns: + cols.append(c) + else: + cols = list(cols) out = tracks_df.groupby("track_id", as_index=False)[cols].diff(periods=lag) out.fillna(0, inplace=True)