diff --git a/pyaldata/__init__.py b/pyaldata/__init__.py index f5fb178..3773fb2 100644 --- a/pyaldata/__init__.py +++ b/pyaldata/__init__.py @@ -71,5 +71,6 @@ get_string_fields, get_time_varying_fields, get_trial_length, + get_trial_lengths, remove_suffix, ) diff --git a/pyaldata/utils.py b/pyaldata/utils.py index 09ea7b9..32499f2 100644 --- a/pyaldata/utils.py +++ b/pyaldata/utils.py @@ -1,5 +1,6 @@ import functools import warnings +from typing import Union import numpy as np import pandas as pd @@ -11,6 +12,7 @@ "get_string_fields", "get_time_varying_fields", "get_trial_length", + "get_trial_lengths", "remove_suffix", ] @@ -217,7 +219,7 @@ def get_string_fields(trial_data: pd.DataFrame) -> list[str]: ] -def get_trial_length(trial: pd.Series, ref_field: str = None) -> int: +def _get_trial_length_trial(trial: pd.Series, ref_field: str = None) -> int: """ Get the number of time points in the trial @@ -249,3 +251,61 @@ def get_trial_length(trial: pd.Series, ref_field: str = None) -> int: ref_field = spike_rate_fields[0] return np.size(trial[ref_field], axis=0) + + +def get_trial_length( + trial_or_df: Union[pd.Series, pd.DataFrame], ref_field: str = None +) -> int: + """ + Get the number of time points in a trial, or the number of timepoints in a dataframe + where all trials are the same length. + + Parameters + ---------- + trial_or_df : + Trial or dataframe to check + ref_field : str, optional + time-varying field to use for identifying the length + if not given, the first field that ends with "spikes" is used + + Returns + ------- + length : int + + Raises + ------ + ValueError + If not all the trials in the dataframe have the same length. + TypeError + If `trial_or_df` is not a pandas Series or DataFrame. + """ + if isinstance(trial_or_df, pd.Series): + return _get_trial_length_trial(trial_or_df, ref_field) + elif isinstance(trial_or_df, pd.DataFrame): + unique_trial_lengths = np.unique(get_trial_lengths(trial_or_df, ref_field)) + + if len(unique_trial_lengths) != 1: + raise ValueError("All trials must have the same length.") + + return unique_trial_lengths[0] + else: + raise TypeError("trial_or_df must be a pandas Series or DataFrame.") + + +def get_trial_lengths(trial_data: pd.DataFrame, ref_field: str = None) -> np.ndarray: + """ + Get the number of time points in all trials. + + Parameters + ---------- + trial_data : pd.DataFrame + DataFrame to check. + ref_field : str, optional + time-varying field to use for identifying the length + if not given, the first field that ends with "spikes" is used + + Returns + ------- + numpy array with the length of each trial + """ + return trial_data.apply(lambda trial: get_trial_length(trial, ref_field), axis=1).values diff --git a/tests/test_get_trial_length.py b/tests/test_get_trial_length.py index 6d8f2f7..784a27e 100644 --- a/tests/test_get_trial_length.py +++ b/tests/test_get_trial_length.py @@ -1,8 +1,10 @@ +import numpy as np +import pandas as pd import pytest -from pyaldata.utils import get_trial_length +from pyaldata.utils import get_trial_length, get_trial_lengths -from .test_determine_ref_field import T, _generate_mock_data +from .test_determine_ref_field import N, T, _generate_mock_data def test_single_field(): @@ -22,3 +24,43 @@ def test_inconsistent_lengths(): with pytest.raises(ValueError): get_trial_length(df.iloc[0]) + + +def test_get_trial_lengths_all_the_same(): + df = _generate_mock_data() + + expected = T * np.ones(df.shape[0]) + + assert np.all(get_trial_lengths(df) == expected) + + +def test_get_trial_lengths_random_lengths(): + trial_lengths = np.random.randint(0, T, size=N) + + data = {} + data["pmd_spikes"] = [np.random.normal(size=(l, 100)) for l in trial_lengths] + df = pd.DataFrame(data) + + assert np.all(get_trial_lengths(df) == trial_lengths) + + +def test_get_trial_length_df_happy(): + df = _generate_mock_data(correct_rates=True, correct_spikes=True) + + assert get_trial_length(df) == T + + +def test_get_trial_length_df_different_lengths(): + trial_lengths = np.random.randint(0, T, size=N) + + data = {} + data["pmd_spikes"] = [np.random.normal(size=(l, 100)) for l in trial_lengths] + df = pd.DataFrame(data) + + with pytest.raises(ValueError, match="All trials must have the same length."): + get_trial_length(df) + + +def test_get_trial_length_df_wrong_type(): + with pytest.raises(TypeError): + get_trial_length([1, 2, 3])