Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Median and Savitzky-Golay Filters #163

Merged
merged 15 commits into from
May 14, 2024
Merged
170 changes: 165 additions & 5 deletions movement/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import logging
from datetime import datetime
from functools import wraps
from typing import Union
from typing import Optional, Union

import xarray as xr
from scipy import signal

from movement.logging import log_error


def log_to_attrs(func):
Expand Down Expand Up @@ -46,7 +49,7 @@ def report_nan_values(ds: xr.Dataset, ds_label: str = "dataset"):
Parameters
----------
ds : xarray.Dataset
Dataset containing pose tracks, confidence scores, and metadata.
Dataset containing position, confidence scores, and metadata.
ds_label : str
Label to identify the dataset in the report. Default is "dataset".

Expand Down Expand Up @@ -84,7 +87,7 @@ def interpolate_over_time(
Parameters
----------
ds : xarray.Dataset
Dataset containing pose tracks, confidence scores, and metadata.
Dataset containing position, confidence scores, and metadata.
method : str
String indicating which method to use for interpolation.
Default is ``linear``. See documentation for
Expand Down Expand Up @@ -128,7 +131,7 @@ def filter_by_confidence(
Parameters
----------
ds : xarray.Dataset
Dataset containing pose tracks, confidence scores, and metadata.
Dataset containing position, confidence scores, and metadata.
threshold : float
The confidence threshold below which datapoints are filtered.
A default value of ``0.6`` is used. See notes for more information.
Expand All @@ -141,7 +144,7 @@ def filter_by_confidence(
ds_thresholded : xarray.Dataset
The provided dataset (ds), where points with a confidence
value below the user-defined threshold have been converted
to NaNs
to NaNs.

Notes
-----
Expand All @@ -164,3 +167,160 @@ def filter_by_confidence(
report_nan_values(ds_thresholded, "filtered dataset")

return ds_thresholded


@log_to_attrs
def median_filter(
    ds: xr.Dataset,
    window_length: int,
    min_periods: Optional[int] = None,
    print_report: bool = True,
) -> xr.Dataset:
    """Smooth pose tracks by applying a median filter over time.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing position, confidence scores, and metadata.
    window_length : int
        The size of the filter window. Window length is interpreted
        as being in the input dataset's time unit, which can be inspected
        with ``ds.time_unit``.
    min_periods : int
        Minimum number of observations in window required to have a value
        (otherwise result is NaN). The default, None, is equivalent to
        setting ``min_periods`` equal to the size of the window.
        This argument is directly passed to the ``min_periods`` parameter of
        ``xarray.DataArray.rolling``.
    print_report : bool
        Whether to print a report on the number of NaNs in the dataset
        before and after filtering. Default is ``True``.

    Returns
    -------
    ds_smoothed : xarray.Dataset
        The provided dataset (ds), where pose tracks have been smoothed
        using a median filter with the provided parameters.

    Notes
    -----
    By default, whenever one or more NaNs are present in the filter window,
    a NaN is returned to the output array. As a result, any
    stretch of NaNs present in the input dataset will be propagated
    proportionally to the size of the window in frames (specifically, by
    ``floor(window_length/2)``). To control this behaviour, the
    ``min_periods`` option can be used to specify the minimum number of
    non-NaN values required in the window to compute a result. For example,
    setting ``min_periods=1`` will result in the filter returning NaNs
    only when all values in the window are NaN, since 1 non-NaN value
    is sufficient to compute the median.

    """
    ds_smoothed = ds.copy()

    # Express window length (and its half) in frames
    if ds.time_unit == "seconds":
        window_length = int(window_length * ds.fps)

    half_window = window_length // 2
    n_frames = ds.position.sizes["time"]

    ds_smoothed.update(
        {
            "position": ds.position.pad(  # Pad the edges to avoid NaNs
                time=half_window, mode="reflect"
            )
            .rolling(  # Take rolling windows across time
                time=window_length, center=True, min_periods=min_periods
            )
            .median(  # Compute the median of each window
                skipna=True
            )
            .isel(  # Remove the padded edges
                # An explicit stop index is used instead of
                # ``-half_window``: with a one-frame window the half
                # window is 0 and ``slice(0, -0)`` would select nothing.
                time=slice(half_window, half_window + n_frames)
            )
        }
    )

    if print_report:
        report_nan_values(ds, "input dataset")
        report_nan_values(ds_smoothed, "filtered dataset")

    return ds_smoothed


@log_to_attrs
def savgol_filter(
    ds: xr.Dataset,
    window_length: int,
    polyorder: int = 2,
    print_report: bool = True,
    **kwargs,
) -> xr.Dataset:
    """Smooth pose tracks by applying a Savitzky-Golay filter over time.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing position, confidence scores, and metadata.
    window_length : int
        The size of the filter window. Window length is interpreted
        as being in the input dataset's time unit, which can be inspected
        with ``ds.time_unit``.
    polyorder : int
        The order of the polynomial used to fit the samples. Must be
        less than ``window_length``. By default, a ``polyorder`` of
        2 is used.
    print_report : bool
        Whether to print a report on the number of NaNs in the dataset
        before and after filtering. Default is ``True``.
    **kwargs : dict
        Additional keyword arguments are passed to
        ``scipy.signal.savgol_filter``.
        Note that the ``axis`` keyword argument may not be overridden.

    Returns
    -------
    ds_smoothed : xarray.Dataset
        The provided dataset (ds), where pose tracks have been smoothed
        using a Savitzky-Golay filter with the provided parameters.

    Raises
    ------
    ValueError
        If ``axis`` is supplied among the keyword arguments.

    Notes
    -----
    Uses the ``scipy.signal.savgol_filter`` function to apply a Savitzky-Golay
    filter to the input dataset's ``position`` variable.
    See the scipy documentation for more information on that function.
    Whenever one or more NaNs are present in a filter window of the
    input dataset, a NaN is returned to the output array. As a result, any
    stretch of NaNs present in the input dataset will be propagated
    proportionally to the size of the window in frames (specifically, by
    ``floor(window_length/2)``). Note that, unlike
    ``movement.filtering.median_filter()``, there is no ``min_periods``
    option to control this behaviour.

    """
    if "axis" in kwargs:
        raise log_error(
            ValueError, "The 'axis' argument may not be overridden."
        )

    ds_smoothed = ds.copy()

    # Express window length in frames
    if ds.time_unit == "seconds":
        window_length = int(window_length * ds.fps)

    # Resolve the time axis by name rather than assuming it is axis 0,
    # so the filter always runs along time (as ``median_filter`` does).
    time_axis = ds.position.get_axis_num("time")

    position_smoothed = signal.savgol_filter(
        ds.position,
        window_length,
        polyorder,
        axis=time_axis,
        **kwargs,
    )
    # Re-wrap the raw array with the original dims, coords and attrs
    position_smoothed_da = ds.position.copy(data=position_smoothed)

    ds_smoothed.update({"position": position_smoothed_da})

    if print_report:
        report_nan_values(ds, "input dataset")
        report_nan_values(ds_smoothed, "filtered dataset")

    return ds_smoothed
45 changes: 45 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,48 @@ def invalid_poses_dataset(request):
def kinematic_property(request):
"""Return a kinematic property."""
return request.param


class Helpers:
    """Generic helper methods for ``movement`` testing modules."""

    @staticmethod
    def _first_x_timeseries(ds):
        """Return the x coordinate timeseries of the first keypoint of the
        first individual in the dataset, as a plain array.
        """
        return ds.position.isel(individuals=0, keypoints=0, space=0).values

    @staticmethod
    def count_nans(ds):
        """Count NaNs in the x coordinate timeseries of the first keypoint
        of the first individual in the dataset.
        """
        x = Helpers._first_x_timeseries(ds)
        return int(np.count_nonzero(np.isnan(x)))

    @staticmethod
    def count_nan_repeats(ds):
        """Count the number of continuous stretches of NaNs in the
        x coordinate timeseries of the first keypoint of the first individual
        in the dataset.
        """
        x = Helpers._first_x_timeseries(ds)
        is_nan = np.isnan(x).astype(np.int8)
        # Every stretch has exactly one start: a 0 -> 1 step in the NaN
        # mask. Prepending 0 makes a stretch at index 0 count as well.
        stretch_starts = np.diff(is_nan, prepend=0) == 1
        return int(np.count_nonzero(stretch_starts))


@pytest.fixture
def helpers():
    """Return the ``Helpers`` class so tests can call its static methods."""
    return Helpers
72 changes: 72 additions & 0 deletions tests/test_integration/test_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest

from movement.filtering import (
filter_by_confidence,
interpolate_over_time,
median_filter,
savgol_filter,
)
from movement.io import load_poses
from movement.sample_data import fetch_dataset_paths


@pytest.fixture(scope="module")
def sample_dataset():
    """Load the single-mouse EPM sample poses, with time unit in frames.

    Keeping time in frames makes it easier to control the expected number
    of NaNs in the tests.
    """
    sample_paths = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")
    poses_path = sample_paths["poses"]
    return load_poses.from_dlc_file(poses_path, fps=None)


@pytest.mark.parametrize("window_length", [3, 5, 6, 13])
def test_nan_propagation_through_filters(
    sample_dataset, window_length, helpers
):
    """Check NaN propagation when chaining several filters on one dataset.

    For both ``median_filter`` and ``savgol_filter`` the NaN count may grow
    by at most ``window_length - 1`` per continuous stretch of NaNs in the
    filter's input, and it must never shrink.
    """
    # Step 1: confidence thresholding is what introduces the NaNs
    ds_confilt = filter_by_confidence(sample_dataset, threshold=0.6)
    n_nans_confilt = helpers.count_nans(ds_confilt)
    n_stretches_confilt = helpers.count_nan_repeats(ds_confilt)
    assert n_nans_confilt == 2555, (
        f"Unexpected number of NaNs in filtered dataset: "
        f"expected: 2555, got: {n_nans_confilt}"
    )

    # Step 2: the median filter must stay within the allowed NaN bounds
    ds_median = median_filter(ds_confilt, window_length)
    n_nans_median = helpers.count_nans(ds_median)
    n_stretches_median = helpers.count_nan_repeats(ds_median)
    upper_bound = n_nans_confilt + (window_length - 1) * n_stretches_confilt
    assert (
        n_nans_median <= upper_bound
    ), "Median filter introduced more NaNs than expected."
    assert (
        n_nans_median >= n_nans_confilt
    ), "Median filter mysteriously removed NaNs."

    # Step 3: same bounds for the Savitzky-Golay filter, now measured
    # relative to the median-filtered dataset
    ds_sg = savgol_filter(
        ds_median, window_length, polyorder=2, print_report=True
    )
    n_nans_sg = helpers.count_nans(ds_sg)
    upper_bound = n_nans_median + (window_length - 1) * n_stretches_median
    assert (
        n_nans_sg <= upper_bound
    ), "Savgol filter introduced more NaNs than expected."
    assert (
        n_nans_sg >= n_nans_median
    ), "Savgol filter mysteriously removed NaNs."

    # Step 4: interpolation without max_gap should eliminate all NaNs
    ds_filled = interpolate_over_time(ds_sg, print_report=True)
    assert helpers.count_nans(ds_filled) == 0
Loading
Loading