Skip to content

Commit

Permalink
Add get_trial_lengths and extend get_trial_length to handle dataframes (
Browse files Browse the repository at this point in the history
  • Loading branch information
bagibence authored Sep 23, 2024
1 parent 51e1aa9 commit 41e21b3
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyaldata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,6 @@
get_string_fields,
get_time_varying_fields,
get_trial_length,
get_trial_lengths,
remove_suffix,
)
62 changes: 61 additions & 1 deletion pyaldata/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import warnings
from typing import Union

import numpy as np
import pandas as pd
Expand All @@ -11,6 +12,7 @@
"get_string_fields",
"get_time_varying_fields",
"get_trial_length",
"get_trial_lengths",
"remove_suffix",
]

Expand Down Expand Up @@ -217,7 +219,7 @@ def get_string_fields(trial_data: pd.DataFrame) -> list[str]:
]


def get_trial_length(trial: pd.Series, ref_field: str = None) -> int:
def _get_trial_length_trial(trial: pd.Series, ref_field: str = None) -> int:
"""
Get the number of time points in the trial
Expand Down Expand Up @@ -249,3 +251,61 @@ def get_trial_length(trial: pd.Series, ref_field: str = None) -> int:
ref_field = spike_rate_fields[0]

return np.size(trial[ref_field], axis=0)


def get_trial_length(
trial_or_df: Union[pd.Series, pd.DataFrame], ref_field: str = None
) -> int:
"""
Get the number of time points in a trial, or the number of timepoints in a dataframe
where all trials are the same length.
Parameters
----------
trial_or_df :
Trial or dataframe to check
ref_field : str, optional
time-varying field to use for identifying the length
if not given, the first field that ends with "spikes" is used
Returns
-------
length : int
Raises
------
ValueError
If not all the trials in the dataframe have the same length.
TypeError
If `trial_or_df` is not a pandas Series or DataFrame.
"""
if isinstance(trial_or_df, pd.Series):
return _get_trial_length_trial(trial_or_df, ref_field)
elif isinstance(trial_or_df, pd.DataFrame):
unique_trial_lengths = np.unique(get_trial_lengths(trial_or_df, ref_field))

if len(unique_trial_lengths) != 1:
raise ValueError("All trials must have the same length.")

return unique_trial_lengths[0]
else:
raise TypeError("trial_or_df must be a pandas Series or DataFrame.")


def get_trial_lengths(trial_data: pd.DataFrame, ref_field: str = None) -> np.ndarray:
"""
Get the number of time points in all trials.
Parameters
----------
trial_data : pd.DataFrame
DataFrame to check.
ref_field : str, optional
time-varying field to use for identifying the length
if not given, the first field that ends with "spikes" is used
Returns
-------
numpy array with the length of each trial
"""
return trial_data.apply(lambda trial: get_trial_length(trial, ref_field), axis=1).values
46 changes: 44 additions & 2 deletions tests/test_get_trial_length.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import pandas as pd
import pytest

from pyaldata.utils import get_trial_length
from pyaldata.utils import get_trial_length, get_trial_lengths

from .test_determine_ref_field import T, _generate_mock_data
from .test_determine_ref_field import N, T, _generate_mock_data


def test_single_field():
Expand All @@ -22,3 +24,43 @@ def test_inconsistent_lengths():

with pytest.raises(ValueError):
get_trial_length(df.iloc[0])


def test_get_trial_lengths_all_the_same():
df = _generate_mock_data()

expected = T * np.ones(df.shape[0])

assert np.all(get_trial_lengths(df) == expected)


def test_get_trial_lengths_random_lengths():
trial_lengths = np.random.randint(0, T, size=N)

data = {}
data["pmd_spikes"] = [np.random.normal(size=(l, 100)) for l in trial_lengths]
df = pd.DataFrame(data)

assert np.all(get_trial_lengths(df) == trial_lengths)


def test_get_trial_length_df_happy():
df = _generate_mock_data(correct_rates=True, correct_spikes=True)

assert get_trial_length(df) == T


def test_get_trial_length_df_different_lengths():
trial_lengths = np.random.randint(0, T, size=N)

data = {}
data["pmd_spikes"] = [np.random.normal(size=(l, 100)) for l in trial_lengths]
df = pd.DataFrame(data)

with pytest.raises(ValueError, match="All trials must have the same length."):
get_trial_length(df)


def test_get_trial_length_df_wrong_type():
with pytest.raises(TypeError):
get_trial_length([1, 2, 3])

0 comments on commit 41e21b3

Please sign in to comment.