Implement Median and Savitzky-Golay Filters #163
Conversation
This PR was a bit tricky and some of the defaults currently implemented are a bit arbitrary. Here's some of the reasoning behind a few of the decisions I've made thus far, as well as some things I'd like a bit of feedback on:

Median Filter

In issue #55, @niksirbi suggested three potential implementations for the median filter. Of these, I've gone with `scipy.ndimage.median_filter()`, as it is the most flexible. That flexibility, however, also means that we must now decide how many of its parameters to expose to users, and what defaults to use for them, irrespective of whether they are user-facing or not. Here, I'm specifically worried about the `mode`, `cval`, and `origin` arguments.

Savitzky-Golay Filter

While our choice of implementations was more constrained for the Savitzky-Golay filter, I am still trying to figure out what some sensible default values might be, specifically for the `window_length` and `polyorder` arguments.

Diagnostics

I assume it would be useful to report a number of diagnostics after each of these functions has been run, but I'm not quite sure where to start here. Are there some standard/straightforward diagnostics that I could implement (perhaps something showing local variability along the timeseries)? Knowing this will also help design tests.

Separate vs. Single Smoothing Function

I've now implemented the different smoothing methods as separate functions, but alternatively, we could consider wrapping these all in a single, generic `smooth()` function.
Thanks for the work and the comprehensive writeup @b-peri 🤗.
Hey Brandon, thanks for doing the research on these!
Here are my opinions on the points you raised.
- I agree with choosing `scipy.ndimage.median_filter()`, it's the most flexible approach. But FYI, I don't think xarray's `rolling()` operations fix the origin at the end. Based on their docs, passing `rolling(center=True)` should put the origin at the window center.
- Regarding which parameter values to choose for `mode`, `cval`, and `origin`, I would opt for keeping their default arguments. We don't have to explicitly expose those. Instead, we can add `**kwargs` to the function definition, which will allow users to override these arguments (i.e. pass them to the underlying scipy function). That said, we might want to stop users from overriding the `axis` argument, because that would mean smoothing along non-time dimensions, which doesn't make sense for these data. See if you can catch this case and throw an error (a sketch of this pattern is shown after this list).
- I don't have any experience or thoughts on the choice of `polyorder` and `window_length` for the SG filter. I'd encourage you to play with our sample data and do some visual checks by plotting the tracks before and after. You can also research what parameters people have reported in animal behaviour papers (if you are able to find ones with thorough method sections).
- Regarding diagnostics, you are right that it's unclear what to report. Ultimately there is no substitute for visual checks in these cases, so we'd have to implement some diagnostic plots (e.g. overlaying tracks before/after filtering). I have some ideas on this which we can discuss tomorrow, but ultimately this should be left for a future PR.
- Regarding having a single `smooth()` function to rule them all, I think we need something like that, but I hope that we can come up with a more elegant solution than dumping it all in one function with many arguments. In real life, people might want to chain multiple filters, to build a pipeline that could look something like: filter_by_confidence -> median_filter -> SG_filter -> interpolate_over_time. We could think about defining a `Pipeline` class, to which we can add a number of steps, and then execute them all via `Pipeline.run`. This approach is also nice for batching analyses across many datasets and good for reproducibility. In any case, let's discuss further and leave this for a future PR as well.
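A minimal sketch of that kwargs-forwarding-plus-guard pattern (the function name `median_smooth` and the exact set of guarded arguments are illustrative assumptions, not movement's final API):

import xarray as xr
from scipy import ndimage


def median_smooth(da: xr.DataArray, window_length: int = 3, **kwargs) -> xr.DataArray:
    """Median-filter ``da`` along its first (time) dimension.

    Extra kwargs (e.g. ``mode``, ``cval``, ``origin``) are forwarded to
    ``scipy.ndimage.median_filter``; arguments that would change which
    dimension is smoothed are rejected.
    """
    for banned in ("size", "footprint", "axes"):
        if banned in kwargs:
            raise ValueError(f"The '{banned}' argument may not be overridden.")
    # Smooth only along the first (time) axis
    size = (window_length,) + (1,) * (da.ndim - 1)
    return da.copy(data=ndimage.median_filter(da.values, size=size, **kwargs))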
On the topic of tests: I don't think we need anything fancy. Since we trust scipy to test the functions we are relying on, we should just ensure that the inputs are correctly passed to those functions. One thing I'm worried about is NaNs. If there are already NaNs in the tracks, some smoothing functions may cause them to increase (by setting the entire window to NaN, if 1 is present). In any case, we should find out the exact behaviour through tests. For example, you can define some toy arrays with NaNs at the edges or in the middle (e.g. `[nan, 0, 1, 3, 1]`, `[3, 0, 1, 3, 1, nan]`, `[0, 2, nan, nan, nan, 2, 0]`, etc.) and check how they are affected by the filters. The sort of thing I did in the filtering example (in the docs), but with the new filters. A quick way to probe this is sketched below.
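For example, something along these lines (an exploratory snippet, not part of the proposed test suite) would reveal the exact NaN behaviour of the underlying scipy functions:

import numpy as np
from scipy import ndimage, signal

toy_arrays = [
    [np.nan, 0, 1, 3, 1],
    [3, 0, 1, 3, 1, np.nan],
    [0, 2, np.nan, np.nan, np.nan, 2, 0],
]
for arr in toy_arrays:
    x = np.asarray(arr, dtype=float)
    print("input: ", x)
    print("median:", ndimage.median_filter(x, size=3))
    print("savgol:", signal.savgol_filter(x, window_length=3, polyorder=1))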
Another issue that came to mind is what unit to use for window/gap lengths: time (seconds) or frames. We should at least be consistent and clearly state which we are using.
Some additional thoughts on this, based on today's meeting:
Force-pushed from 9c62eba to a70c414.
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #163      +/-   ##
==========================================
+ Coverage   99.66%   99.68%   +0.02%
==========================================
  Files          10       11       +1
  Lines         591      637      +46
==========================================
+ Hits          589      635      +46
  Misses          2        2

View full report in Codecov by Sentry.
Thanks very much for the extensive feedback @niksirbi! I've now implemented most of your points. Following on from our discussions here and on Zulip, I just had the following notes:
Hey @b-peri, nice work.
I took my time to extensively review this, and I basically agree with most of the choices you've made. My major quibbles are:
- I'm bothered by the fact that the median filter introduces NaNs at the edges, especially because this also affects keypoint time-series that had no NaNs to begin with. I found a way around that, such that the median filter will now behave identically to the savgol filter (propagate NaNs within windows, but not at the edges). See below for my proposed code.
- If we wish to further reduce the propagation of NaNs by the median filter, we can use the `min_periods` argument of the `rolling` method. The default is `None`, meaning that if at least 1 value within a window is NaN, the result is also NaN. If we set `min_periods` to e.g. 2, two non-NaN values within a window will be sufficient to produce a valid output (i.e. the median of these 2). I propose exposing that argument, with the default being `None` (to achieve consistent behaviour with savgol), but allowing users to override it to ameliorate NaN propagation. A short demo of this behaviour follows the list.
- On the topic of NaN propagation, we make claims in the notes, but we have to back them up with rigorous tests. I propose introducing a new integration test that will verify how NaNs are propagated through the filtering functions (see below).
- The unit tests you wrote were a bit too minimal, I think. I've left specific comments on how to easily parametrise them to expand their scope and how to add some additional unit tests.
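As a quick toy illustration of the `min_periods` behaviour described above (standalone sketch, independent of movement):

import numpy as np
import xarray as xr

da = xr.DataArray([0.0, 1.0, np.nan, 3.0, 4.0], dims="time")

# Default (min_periods=None): a window only yields a value if all its
# entries are valid, so here every window is either incomplete (edges)
# or contains the NaN -> output is all-NaN
print(da.rolling(time=3, center=True).median().values)
# -> [nan nan nan nan nan]

# min_periods=1: one valid value per window suffices
print(da.rolling(time=3, center=True, min_periods=1).median().values)
# -> [0.5 0.5 2.  3.5 3.5]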
My proposed modified filtering functions

Median filter

This relies on `xarray.DataArray.pad` to extend the edges of the array (with reflection), to avoid the introduction of NaNs at the edges. It also exposes the `min_periods` and `print_report` arguments.
# Note: this snippet is intended for movement's filtering module, where
# ``log_to_attrs`` and ``report_nan_values`` are defined.
from typing import Optional

import xarray as xr


@log_to_attrs
def median_filter(
    ds: xr.Dataset,
    window_length: int,
    min_periods: Optional[int] = None,
    print_report: bool = True,
) -> xr.Dataset:
    """Smooths pose tracks by applying a median filter over time.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing position, confidence scores, and metadata.
    window_length : int
        The size of the filter window. Window length is interpreted
        as being in the input dataset's time unit, which can be inspected
        with ``ds.time_unit``.
    min_periods : int
        Minimum number of observations in the window required to have
        a value (otherwise the result is NaN). The default, None, is
        equivalent to setting ``min_periods`` equal to the size of the
        window. This argument is directly passed to the ``min_periods``
        parameter of ``xarray.DataArray.rolling``.
    print_report : bool
        Whether to print a report on the number of NaNs in the dataset
        before and after filtering. Default is ``True``.

    Returns
    -------
    ds_smoothed : xarray.Dataset
        The provided dataset (ds), where pose tracks have been smoothed
        using a median filter with the provided parameters.

    Notes
    -----
    By default, whenever one or more NaNs are present in the filter window,
    a NaN is returned to the output array. As a result, any
    stretch of NaNs present in the input dataset will be propagated
    proportionally to the size of the window in frames (specifically, by
    ``floor(window_length/2)``). To control this behaviour, the
    ``min_periods`` option can be used to specify the minimum number of
    non-NaN values required in the window to compute a result. For example,
    setting ``min_periods=1`` will result in the filter returning NaNs
    only when all values in the window are NaN, since 1 non-NaN value
    is sufficient to compute the median.
    """
    ds_smoothed = ds.copy()
    # Express window length (and its half) in frames
    if ds.time_unit == "seconds":
        window_length = int(window_length * ds.fps)
    half_window = window_length // 2
    ds_smoothed.update(
        {
            "position": ds.position.pad(  # pad the edges to avoid NaNs
                time=half_window, mode="reflect"
            )
            .rolling(  # take rolling windows across time
                time=window_length, center=True, min_periods=min_periods
            )
            .median(  # compute the median of each window
                skipna=True
            )
            .isel(  # remove the padded edges
                time=slice(half_window, -half_window)
            )
        }
    )
    if print_report:
        report_nan_values(ds, "input dataset")
        report_nan_values(ds_smoothed, "filtered dataset")
    return ds_smoothed
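To see why the padding step matters, here's a toy illustration of the pad -> rolling -> isel pattern used above (a standalone sketch, independent of movement's code):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(8, dtype=float), dims="time")
half_window, window = 2, 5

# Without padding: incomplete windows at the edges produce NaNs
plain = da.rolling(time=window, center=True).median()

# With reflective padding: edge windows are complete, so no NaNs
# are introduced; the padded samples are then sliced away again
padded = (
    da.pad(time=half_window, mode="reflect")
    .rolling(time=window, center=True)
    .median()
    .isel(time=slice(half_window, -half_window))
)
print(plain.values)   # NaNs at both edges
print(padded.values)  # no NaNs, same length as the input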
Savgol filter

It exposes the `print_report` argument.
# As above, this is meant for movement's filtering module, where
# ``log_to_attrs``, ``log_error`` and ``report_nan_values`` are available.
import xarray as xr
from scipy import signal


@log_to_attrs
def savgol_filter(
    ds: xr.Dataset,
    window_length: int,
    polyorder: int = 2,
    print_report: bool = True,
    **kwargs,
) -> xr.Dataset:
    """Smooths pose tracks by applying a Savitzky-Golay filter over time.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing position, confidence scores, and metadata.
    window_length : int
        The size of the filter window. Window length is interpreted
        as being in the input dataset's time unit, which can be inspected
        with ``ds.time_unit``.
    polyorder : int
        The order of the polynomial used to fit the samples. Must be
        less than ``window_length``. By default, a ``polyorder`` of
        2 is used.
    print_report : bool
        Whether to print a report on the number of NaNs in the dataset
        before and after filtering. Default is ``True``.
    **kwargs : dict
        Additional keyword arguments are passed to
        ``scipy.signal.savgol_filter``. Note that the ``axis`` keyword
        argument may not be overridden.

    Returns
    -------
    ds_smoothed : xarray.Dataset
        The provided dataset (ds), where pose tracks have been smoothed
        using a Savitzky-Golay filter with the provided parameters.

    Notes
    -----
    Uses the ``scipy.signal.savgol_filter`` function to apply a
    Savitzky-Golay filter to the input dataset's ``position`` variable.
    See the scipy documentation for more information on that function.
    Whenever one or more NaNs are present in a filter window of the
    input dataset, a NaN is returned to the output array. As a result, any
    stretch of NaNs present in the input dataset will be propagated
    proportionally to the size of the window in frames (specifically, by
    ``floor(window_length/2)``). Note that, unlike
    ``movement.filtering.median_filter()``, there is no ``min_periods``
    option to control this behaviour.
    """
    if "axis" in kwargs:
        raise log_error(
            ValueError, "The 'axis' argument may not be overridden."
        )
    ds_smoothed = ds.copy()
    if ds.time_unit == "seconds":
        window_length = int(window_length * ds.fps)
    position_smoothed = signal.savgol_filter(
        ds.position,
        window_length,
        polyorder,
        axis=0,
        **kwargs,
    )
    position_smoothed_da = ds.position.copy(data=position_smoothed)
    ds_smoothed.update({"position": position_smoothed_da})
    if print_report:
        report_nan_values(ds, "input dataset")
        report_nan_values(ds_smoothed, "filtered dataset")
    return ds_smoothed
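For illustration, here's how the two proposed functions might be chained (a hypothetical sketch reusing the sample-data loading pattern from the tests below; the window values are arbitrary examples, not recommendations):

from movement.io import load_poses
from movement.sample_data import fetch_dataset_paths

ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")["poses"]
ds = load_poses.from_dlc_file(ds_path, fps=None)  # time unit: frames

# Median filter first (allowing windows with >= 2 valid values),
# then a Savitzky-Golay filter with a forwarded scipy kwarg
ds_med = median_filter(ds, window_length=5, min_periods=2)
ds_smooth = savgol_filter(ds_med, window_length=7, polyorder=2, mode="nearest")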
My proposed additional tests

Integration test for the filtering module

This can go in `test_integration/test_filtering.py`.
import numpy as np
import pytest

from movement.filtering import (
    filter_by_confidence,
    interpolate_over_time,
    median_filter,
    savgol_filter,
)
from movement.io import load_poses
from movement.sample_data import fetch_dataset_paths


@pytest.fixture(scope="module")
def sample_dataset():
    """Return a single-animal sample dataset, with time unit in frames.

    This allows us to better control the expected number of NaNs in the
    tests.
    """
    ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")[
        "poses"
    ]
    return load_poses.from_dlc_file(ds_path, fps=None)


def _count_nans(ds):
    """Count NaNs in the x coordinate timeseries of the first keypoint of
    the first individual in the dataset.
    """
    n_nans = np.count_nonzero(
        np.isnan(
            ds.position.isel(individuals=0, keypoints=0, space=0).values
        )
    )
    return n_nans


@pytest.mark.parametrize("window_length", [3, 5, 6, 13])
def test_nan_propagation_through_filters(sample_dataset, window_length):
    """Tests how NaNs are propagated when passing a dataset through multiple
    filters sequentially. For the ``median_filter`` and ``savgol_filter``, we
    expect the number of NaNs to increase at most by half the filter's
    window length (``floor(window_length/2)``) multiplied by the number of
    existing NaNs in the input dataset.
    """
    half_window = window_length // 2
    # Introduce NaNs via filter_by_confidence
    ds_with_nans = filter_by_confidence(sample_dataset, threshold=0.6)
    nans_after_confilt = _count_nans(ds_with_nans)
    assert nans_after_confilt == 2555, (
        f"Unexpected number of NaNs in filtered dataset: "
        f"expected: 2555, got: {nans_after_confilt}"
    )
    # Apply median filter and check that
    # it doesn't introduce too many or too few NaNs
    ds_medfilt = median_filter(ds_with_nans, window_length)
    nans_after_medfilt = _count_nans(ds_medfilt)
    max_nans_increase = half_window * nans_after_confilt
    assert nans_after_medfilt <= nans_after_confilt + max_nans_increase, (
        "Median filter introduced more NaNs than expected."
    )
    assert nans_after_medfilt >= nans_after_confilt, (
        "Median filter mysteriously removed NaNs."
    )
    # Apply savgol filter and check that
    # it doesn't introduce too many or too few NaNs
    ds_savgol = savgol_filter(
        ds_medfilt, window_length, polyorder=2, print_report=True
    )
    nans_after_savgol = _count_nans(ds_savgol)
    max_nans_increase = half_window * nans_after_medfilt
    assert nans_after_savgol <= nans_after_medfilt + max_nans_increase, (
        "Savgol filter introduced more NaNs than expected."
    )
    assert nans_after_savgol >= nans_after_medfilt, (
        "Savgol filter mysteriously removed NaNs."
    )
    # Apply interpolate_over_time (without max_gap) to eliminate all NaNs
    ds_interpolated = interpolate_over_time(ds_savgol, print_report=True)
    assert _count_nans(ds_interpolated) == 0
New unit tests

These go in `test_unit/test_filtering.py`, in addition to the modifications I proposed to the existing unit tests.
# (relies on the imports and fixtures already present in the test module)
def test_median_filter_with_nans(valid_poses_dataset_with_nan):
    """Test NaN behaviour of the ``median_filter()`` function. The
    ``valid_poses_dataset_with_nan`` dataset (fixture defined in conftest.py)
    contains NaN values in all keypoints of the first individual at times
    3, 7, and 8 (0-indexed, 10 total timepoints).
    The median filter should propagate NaNs within the windows of the filter,
    but it should not introduce any NaNs for the second individual.
    """
    ds_smoothed = median_filter(valid_poses_dataset_with_nan, 3)
    # There should be NaNs at 7 timepoints for the first individual,
    # i.e. all except timepoints 0, 1 and 5
    assert count_nans(ds_smoothed) == 7
    assert ~(
        ds_smoothed.position.isel(individuals=0, time=[0, 1, 5])
        .isnull()
        .any()
    )
    # The second individual should not contain any NaNs
    assert ~ds_smoothed.position.sel(individuals="ind2").isnull().any()


@pytest.mark.parametrize(
    "override_kwargs",
    [
        {"mode": "nearest"},
        {"axis": 1},
        {"mode": "nearest", "axis": 1},
    ],
)
def test_savgol_filter_kwargs_override(sample_dataset, override_kwargs):
    """Further tests for the ``savgol_filter()`` function.
    Checks that the function raises a ValueError when the ``axis`` keyword
    argument is overridden, as this is not allowed. Overriding other keyword
    arguments (e.g. ``mode``) should not raise an error.
    """
    if "axis" in override_kwargs:
        with pytest.raises(ValueError):
            savgol_filter(sample_dataset, 5, **override_kwargs)
    else:
        ds_smoothed = savgol_filter(sample_dataset, 5, **override_kwargs)
        assert isinstance(ds_smoothed, xr.Dataset)
I realised that my proposed changes would essentially duplicate the `count_nans` function across integration and unit tests. I propose using this more general form of the function:
def count_nans(ds):
    """Count NaNs in the x coordinate timeseries of the first keypoint of
    the first individual in the dataset.
    """
    n_nans = np.count_nonzero(
        np.isnan(
            ds.position.isel(individuals=0, keypoints=0, space=0).values
        )
    )
    return n_nans
and try to define it only once, either in `conftest.py` or by somehow importing it from one test module to the other. If that's not possible, don't sweat it and leave it as is.
I also noticed that the existing `test_filter_by_confidence` function doesn't make use of `count_nans`, although it very well could.
Hey @niksirbi, thanks again for the very in-depth feedback! I've implemented pretty much all of your suggestions as-is. Some points of note are the following:

Count NaNs

I've solved the duplication issue with the following `Helpers` class, which is made available to the test modules via a fixture:

class Helpers:
    """Generic helper methods for ``movement`` testing modules."""

    @staticmethod
    def count_nans(ds):
        """Count NaNs in the x coordinate timeseries of the first keypoint
        of the first individual in the dataset.
        """
        n_nans = np.count_nonzero(
            np.isnan(
                ds.position.isel(individuals=0, keypoints=0, space=0).values
            )
        )
        return n_nans


@pytest.fixture
def helpers():
    """Return an instance of the ``Helpers`` class."""
    return Helpers

Integration Tests

Whenever NaNs occur in continuous stretches (as they often do in periods of occlusion or poor visibility of the keypoint), each whole block of NaNs is essentially only propagated once, by ``floor(window_length/2)`` at each of its edges, rather than once per NaN it contains. To deal with this, I've therefore tweaked the integration test's expected maximum NaN increase accordingly. I've also added a new, generic `count_nan_repeats()` function:

def count_nan_repeats(ds):
"""Count the number of NaN repeats in the x coordinate timeseries
of the first keypoint of the first individual in the dataset.
"""
x = ds.position.isel(individuals=0, keypoints=0, space=0).values
repeats = []
running_count = 1
for i in range(len(x)):
if i != len(x) - 1:
if np.isnan(x[i]) and np.isnan(x[i + 1]):
running_count += 1
elif np.isnan(x[i]):
repeats.append(running_count)
running_count = 1
else:
running_count = 1
elif np.isnan(x[i]):
repeats.append(running_count)
running_count = 1
return len(repeats) This should be all for now! Thanks again! |
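For instance, given the `count_nan_repeats` function above, a quick illustrative check might look like this (toy sketch, not part of the PR; it only assumes numpy and xarray):

import numpy as np
import xarray as xr

# Build a toy dataset shaped like movement's position array,
# (time, individuals, keypoints, space), with two NaN stretches
x = [0.0, np.nan, np.nan, 3.0, np.nan, 5.0, 6.0]
position = xr.DataArray(
    np.array(x).reshape(-1, 1, 1, 1).repeat(2, axis=3),
    dims=("time", "individuals", "keypoints", "space"),
)
ds = xr.Dataset({"position": position})
print(count_nan_repeats(ds))  # two continuous NaN stretches -> 2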
Thanks for getting this over the finish line @b-peri! I like the "Helpers" and the way you made the integration test stricter, it makes sense.
I just did a minor tweak of the `count_nan_repeats` function, adding the `@staticmethod` decorator and rewording the docstring (for the sake of my future self's understanding):

    @staticmethod
    def count_nan_repeats(ds):
        """Count the number of continuous stretches of NaNs in the
        x coordinate timeseries of the first keypoint of the first
        individual in the dataset.
        """
I'll merge this today.
Edited 08/05/2024:

What is this PR

This PR introduces two new smoothing functions to the `filtering` module:

- `median_filter(ds, window_length)`: Smooths pose tracks in the input dataset by applying a median filter along the time dimension. Window length must be specified by the user and is interpreted as being in the input dataset's time unit (usually seconds). The filter window is centered over the filter origin.
- `savgol_filter(ds, window_length, polyorder, **kwargs)`: Smooths pose tracks over time using a Savitzky-Golay filter. Again, window length must be specified by the user and is interpreted as being in the input dataset's time unit. The order of the polynomial used to fit the samples can optionally be specified by the user; if omitted, a default value of 2 is used. Additional keyword arguments (`**kwargs`) are passed to `scipy.signal.savgol_filter()` directly, but note that the `axis` kwarg may not be overridden.
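By way of illustration only (the file path and fps below are placeholder values), usage with a seconds-based time unit might look like:

from movement.filtering import median_filter, savgol_filter
from movement.io import load_poses

# Hypothetical input file; with fps set, window_length is in seconds
ds = load_poses.from_dlc_file("path/to/predictions.h5", fps=30)

ds_med = median_filter(ds, window_length=1)  # 1-second median window
ds_sg = savgol_filter(ds, window_length=1, polyorder=2)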
References
Closes #55, closes #139
Checklist: