diff --git a/movement/filtering.py b/movement/filtering.py index c03ae7bd..451830e8 100644 --- a/movement/filtering.py +++ b/movement/filtering.py @@ -3,9 +3,12 @@ import logging from datetime import datetime from functools import wraps -from typing import Union +from typing import Optional, Union import xarray as xr +from scipy import signal + +from movement.logging import log_error def log_to_attrs(func): @@ -46,7 +49,7 @@ def report_nan_values(ds: xr.Dataset, ds_label: str = "dataset"): Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + Dataset containing position, confidence scores, and metadata. ds_label : str Label to identify the dataset in the report. Default is "dataset". @@ -84,7 +87,7 @@ def interpolate_over_time( Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + Dataset containing position, confidence scores, and metadata. method : str String indicating which method to use for interpolation. Default is ``linear``. See documentation for @@ -128,7 +131,7 @@ def filter_by_confidence( Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + Dataset containing position, confidence scores, and metadata. threshold : float The confidence threshold below which datapoints are filtered. A default value of ``0.6`` is used. See notes for more information. @@ -141,7 +144,7 @@ def filter_by_confidence( ds_thresholded : xarray.Dataset The provided dataset (ds), where points with a confidence value below the user-defined threshold have been converted - to NaNs + to NaNs. Notes ----- @@ -164,3 +167,160 @@ def filter_by_confidence( report_nan_values(ds_thresholded, "filtered dataset") return ds_thresholded + + +@log_to_attrs +def median_filter( + ds: xr.Dataset, + window_length: int, + min_periods: Optional[int] = None, + print_report: bool = True, +) -> xr.Dataset: + """Smooth pose tracks by applying a median filter over time. + + Parameters + ---------- + ds : xarray.Dataset + Dataset containing position, confidence scores, and metadata. + window_length : int + The size of the filter window. Window length is interpreted + as being in the input dataset's time unit, which can be inspected + with ``ds.time_unit``. + min_periods : int + Minimum number of observations in window required to have a value + (otherwise result is NaN). The default, None, is equivalent to + setting ``min_periods`` equal to the size of the window. + This argument is directly passed to the ``min_periods`` parameter of + ``xarray.DataArray.rolling``. + print_report : bool + Whether to print a report on the number of NaNs in the dataset + before and after filtering. Default is ``True``. + + Returns + ------- + ds_smoothed : xarray.Dataset + The provided dataset (ds), where pose tracks have been smoothed + using a median filter with the provided parameters. + + Notes + ----- + By default, whenever one or more NaNs are present in the filter window, + a NaN is returned to the output array. As a result, any + stretch of NaNs present in the input dataset will be propagated + proportionally to the size of the window in frames (specifically, by + ``floor(window_length/2)``). To control this behaviour, the + ``min_periods`` option can be used to specify the minimum number of + non-NaN values required in the window to compute a result. For example, + setting ``min_periods=1`` will result in the filter returning NaNs + only when all values in the window are NaN, since 1 non-NaN value + is sufficient to compute the median. + + """ + ds_smoothed = ds.copy() + + # Express window length (and its half) in frames + if ds.time_unit == "seconds": + window_length = int(window_length * ds.fps) + + half_window = window_length // 2 + + ds_smoothed.update( + { + "position": ds.position.pad( # Pad the edges to avoid NaNs + time=half_window, mode="reflect" + ) + .rolling( # Take rolling windows across time + time=window_length, center=True, min_periods=min_periods + ) + .median( # Compute the median of each window + skipna=True + ) + .isel( # Remove the padded edges + time=slice(half_window, -half_window) + ) + } + ) + + if print_report: + report_nan_values(ds, "input dataset") + report_nan_values(ds_smoothed, "filtered dataset") + + return ds_smoothed + + +@log_to_attrs +def savgol_filter( + ds: xr.Dataset, + window_length: int, + polyorder: int = 2, + print_report: bool = True, + **kwargs, +) -> xr.Dataset: + """Smooth pose tracks by applying a Savitzky-Golay filter over time. + + Parameters + ---------- + ds : xarray.Dataset + Dataset containing position, confidence scores, and metadata. + window_length : int + The size of the filter window. Window length is interpreted + as being in the input dataset's time unit, which can be inspected + with ``ds.time_unit``. + polyorder : int + The order of the polynomial used to fit the samples. Must be + less than ``window_length``. By default, a ``polyorder`` of + 2 is used. + print_report : bool + Whether to print a report on the number of NaNs in the dataset + before and after filtering. Default is ``True``. + **kwargs : dict + Additional keyword arguments are passed to scipy.signal.savgol_filter. + Note that the ``axis`` keyword argument may not be overridden. + + + Returns + ------- + ds_smoothed : xarray.Dataset + The provided dataset (ds), where pose tracks have been smoothed + using a Savitzky-Golay filter with the provided parameters. + + Notes + ----- + Uses the ``scipy.signal.savgol_filter`` function to apply a Savitzky-Golay + filter to the input dataset's ``position`` variable. + See the scipy documentation for more information on that function. + Whenever one or more NaNs are present in a filter window of the + input dataset, a NaN is returned to the output array. As a result, any + stretch of NaNs present in the input dataset will be propagated + proportionally to the size of the window in frames (specifically, by + ``floor(window_length/2)``). Note that, unlike + ``movement.filtering.median_filter()``, there is no ``min_periods`` + option to control this behaviour. + + """ + if "axis" in kwargs: + raise log_error( + ValueError, "The 'axis' argument may not be overridden." + ) + + ds_smoothed = ds.copy() + + if ds.time_unit == "seconds": + window_length = int(window_length * ds.fps) + + position_smoothed = signal.savgol_filter( + ds.position, + window_length, + polyorder, + axis=0, + **kwargs, + ) + position_smoothed_da = ds.position.copy(data=position_smoothed) + + ds_smoothed.update({"position": position_smoothed_da}) + + if print_report: + report_nan_values(ds, "input dataset") + report_nan_values(ds_smoothed, "filtered dataset") + + return ds_smoothed diff --git a/tests/conftest.py b/tests/conftest.py index 84d0f545..b51a8ed6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -326,3 +326,48 @@ def invalid_poses_dataset(request): def kinematic_property(request): """Return a kinematic property.""" return request.param + + +class Helpers: + """Generic helper methods for ``movement`` testing modules.""" + + @staticmethod + def count_nans(ds): + """Count NaNs in the x coordinate timeseries of the first keypoint + of the first individual in the dataset. + """ + n_nans = np.count_nonzero( + np.isnan( + ds.position.isel(individuals=0, keypoints=0, space=0).values + ) + ) + return n_nans + + @staticmethod + def count_nan_repeats(ds): + """Count the number of continuous stretches of NaNs in the + x coordinate timeseries of the first keypoint of the first individual + in the dataset. + """ + x = ds.position.isel(individuals=0, keypoints=0, space=0).values + repeats = [] + running_count = 1 + for i in range(len(x)): + if i != len(x) - 1: + if np.isnan(x[i]) and np.isnan(x[i + 1]): + running_count += 1 + elif np.isnan(x[i]): + repeats.append(running_count) + running_count = 1 + else: + running_count = 1 + elif np.isnan(x[i]): + repeats.append(running_count) + running_count = 1 + return len(repeats) + + +@pytest.fixture +def helpers(): + """Return an instance of the ``Helpers`` class.""" + return Helpers diff --git a/tests/test_integration/test_filtering.py b/tests/test_integration/test_filtering.py new file mode 100644 index 00000000..1f55f9ca --- /dev/null +++ b/tests/test_integration/test_filtering.py @@ -0,0 +1,72 @@ +import pytest + +from movement.filtering import ( + filter_by_confidence, + interpolate_over_time, + median_filter, + savgol_filter, +) +from movement.io import load_poses +from movement.sample_data import fetch_dataset_paths + + +@pytest.fixture(scope="module") +def sample_dataset(): + """Return a single-animal sample dataset, with time unit in frames. + This allows us to better control the expected number of NaNs in the tests. + """ + ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")[ + "poses" + ] + return load_poses.from_dlc_file(ds_path, fps=None) + + +@pytest.mark.parametrize("window_length", [3, 5, 6, 13]) +def test_nan_propagation_through_filters( + sample_dataset, window_length, helpers +): + """Tests how NaNs are propagated when passing a dataset through multiple + filters sequentially. For the ``median_filter`` and ``savgol_filter``, + we expect the number of NaNs to increase at most by the filter's window + length minus one (``window_length - 1``) multiplied by the number of + continuous stretches of NaNs present in the input dataset. + """ + # Introduce nans via filter_by_confidence + ds_with_nans = filter_by_confidence(sample_dataset, threshold=0.6) + nans_after_confilt = helpers.count_nans(ds_with_nans) + nan_repeats_after_confilt = helpers.count_nan_repeats(ds_with_nans) + assert nans_after_confilt == 2555, ( + f"Unexpected number of NaNs in filtered dataset: " + f"expected: 2555, got: {nans_after_confilt}" + ) + + # Apply median filter and check that + # it doesn't introduce too many or too few NaNs + ds_medfilt = median_filter(ds_with_nans, window_length) + nans_after_medfilt = helpers.count_nans(ds_medfilt) + nan_repeats_after_medfilt = helpers.count_nan_repeats(ds_medfilt) + max_nans_increase = (window_length - 1) * nan_repeats_after_confilt + assert ( + nans_after_medfilt <= nans_after_confilt + max_nans_increase + ), "Median filter introduced more NaNs than expected." + assert ( + nans_after_medfilt >= nans_after_confilt + ), "Median filter mysteriously removed NaNs." + + # Apply savgol filter and check that + # it doesn't introduce too many or too few NaNs + ds_savgol = savgol_filter( + ds_medfilt, window_length, polyorder=2, print_report=True + ) + nans_after_savgol = helpers.count_nans(ds_savgol) + max_nans_increase = (window_length - 1) * nan_repeats_after_medfilt + assert ( + nans_after_savgol <= nans_after_medfilt + max_nans_increase + ), "Savgol filter introduced more NaNs than expected." + assert ( + nans_after_savgol >= nans_after_medfilt + ), "Savgol filter mysteriously removed NaNs." + + # Apply interpolate_over_time (without max_gap) to eliminate all NaNs + ds_interpolated = interpolate_over_time(ds_savgol, print_report=True) + assert helpers.count_nans(ds_interpolated) == 0 diff --git a/tests/test_unit/test_filtering.py b/tests/test_unit/test_filtering.py index b171a330..a3500b27 100644 --- a/tests/test_unit/test_filtering.py +++ b/tests/test_unit/test_filtering.py @@ -1,4 +1,3 @@ -import numpy as np import pytest import xarray as xr @@ -6,13 +5,15 @@ filter_by_confidence, interpolate_over_time, log_to_attrs, + median_filter, + savgol_filter, ) from movement.sample_data import fetch_dataset @pytest.fixture(scope="module") def sample_dataset(): - """Return a single-individual sample dataset.""" + """Return a single-animal sample dataset, with time unit in seconds.""" return fetch_dataset("DLC_single-mouse_EPM.predictions.h5") @@ -35,7 +36,7 @@ def fake_func(ds, arg, kwarg=None): ) -def test_interpolate_over_time(sample_dataset): +def test_interpolate_over_time(sample_dataset, helpers): """Test the ``interpolate_over_time`` function. Check that the number of nans is decreased after running this function @@ -44,37 +45,121 @@ def test_interpolate_over_time(sample_dataset): ds_filtered = filter_by_confidence(sample_dataset) ds_interpolated = interpolate_over_time(ds_filtered) - def count_nans(ds): - n_nans = np.count_nonzero( - np.isnan( - ds.position.sel( - individuals="individual_0", keypoints="snout" - ).values[:, 0] - ) - ) - return n_nans - - assert count_nans(ds_interpolated) < count_nans(ds_filtered) + assert helpers.count_nans(ds_interpolated) < helpers.count_nans( + ds_filtered + ) -def test_filter_by_confidence(sample_dataset, caplog): - """Tests for the ``filter_by_confidence`` function. +def test_filter_by_confidence(sample_dataset, caplog, helpers): + """Tests for the ``filter_by_confidence()`` function. Checks that the function filters the expected amount of values from a known dataset, and tests that this value is logged correctly. """ - ds_filtered = filter_by_confidence(sample_dataset) + ds_filtered = filter_by_confidence(sample_dataset, threshold=0.6) assert isinstance(ds_filtered, xr.Dataset) - n_nans = np.count_nonzero( - np.isnan( - ds_filtered.position.sel( - individuals="individual_0", keypoints="snout" - ).values[:, 0] - ) - ) + n_nans = helpers.count_nans(ds_filtered) assert n_nans == 2555 # Check that diagnostics are being logged correctly assert f"snout: {n_nans}/{ds_filtered.time.values.shape[0]}" in caplog.text + + +@pytest.mark.parametrize("window_size", [0.2, 1, 4, 12]) +def test_median_filter(sample_dataset, window_size): + """Tests for the ``median_filter()`` function. Checks that + the function successfully receives the input data and + returns a different xr.Dataset with the correct dimensions. + """ + ds_smoothed = median_filter(sample_dataset, window_size) + + # Test whether filter received and returned correct data + assert isinstance(ds_smoothed, xr.Dataset) and ~( + ds_smoothed == sample_dataset + ) + assert ds_smoothed.position.shape == sample_dataset.position.shape + + +def test_median_filter_with_nans(valid_poses_dataset_with_nan, helpers): + """Test nan behavior of the ``median_filter()`` function. The + ``valid_poses_dataset_with_nan`` dataset (fixture defined in conftest.py) + contains NaN values in all keypoints of the first individual at times + 3, 7, and 8 (0-indexed, 10 total timepoints). + The median filter should propagate NaNs within the windows of the filter, + but it should not introduce any NaNs for the second individual. + """ + ds_smoothed = median_filter(valid_poses_dataset_with_nan, 3) + # There should be NaNs at 7 timepoints for the first individual + # all except for timepoints 0, 1 and 5 + assert helpers.count_nans(ds_smoothed) == 7 + assert ( + ~ds_smoothed.position.isel(individuals=0, time=[0, 1, 5]) + .isnull() + .any() + ) + # The second individual should not contain any NaNs + assert ~ds_smoothed.position.sel(individuals="ind2").isnull().any() + + +@pytest.mark.parametrize("window_length", [0.2, 1, 4, 12]) +@pytest.mark.parametrize("polyorder", [1, 2, 3]) +def test_savgol_filter(sample_dataset, window_length, polyorder): + """Tests for the ``savgol_filter()`` function. + Checks that the function successfully receives the input + data and returns a different xr.Dataset with the correct + dimensions. + """ + ds_smoothed = savgol_filter( + sample_dataset, window_length, polyorder=polyorder + ) + + # Test whether filter received and returned correct data + assert isinstance(ds_smoothed, xr.Dataset) and ~( + ds_smoothed == sample_dataset + ) + assert ds_smoothed.position.shape == sample_dataset.position.shape + + +@pytest.mark.parametrize( + "override_kwargs", + [ + {"mode": "nearest"}, + {"axis": 1}, + {"mode": "nearest", "axis": 1}, + ], +) +def test_savgol_filter_kwargs_override(sample_dataset, override_kwargs): + """Further tests for the ``savgol_filter()`` function. + Checks that the function raises a ValueError when the ``axis`` keyword + argument is overridden, as this is not allowed. Overriding other keyword + arguments (e.g. ``mode``) should not raise an error. + """ + if "axis" in override_kwargs: + with pytest.raises(ValueError): + savgol_filter(sample_dataset, 5, **override_kwargs) + else: + ds_smoothed = savgol_filter(sample_dataset, 5, **override_kwargs) + assert isinstance(ds_smoothed, xr.Dataset) + + +def test_savgol_filter_with_nans(valid_poses_dataset_with_nan, helpers): + """Test nan behavior of the ``savgol_filter()`` function. The + ``valid_poses_dataset_with_nan`` dataset (fixture defined in conftest.py) + contains NaN values in all keypoints of the first individual at times + 3, 7, and 8 (0-indexed, 10 total timepoints). + The Savitzky-Golay filter should propagate NaNs within the windows of + the filter, but it should not introduce any NaNs for the second individual. + """ + ds_smoothed = savgol_filter(valid_poses_dataset_with_nan, 3, polyorder=2) + # There should be NaNs at 7 timepoints for the first individual + # all except for timepoints 0, 1 and 5 + assert helpers.count_nans(ds_smoothed) == 7 + assert ( + ~ds_smoothed.position.isel(individuals=0, time=[0, 1, 5]) + .isnull() + .any() + ) + # The second individual should not contain any NaNs + assert ~ds_smoothed.position.sel(individuals="ind2").isnull().any()