Skip to content

Commit

Permalink
SavitzkyGolay simple layer checks for PyTorch > 1.5.1 if padding mode…
Browse files Browse the repository at this point in the history
… not zeros, and version dependent tests added.

Signed-off-by: Christian Baker <[email protected]>
  • Loading branch information
crnbaker committed Jan 13, 2021
1 parent 204c02e commit cda44a4
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 209 deletions.
18 changes: 12 additions & 6 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"LLTM",
"Reshape",
"separable_filtering",
"SavitskyGolayFilter",
"SavitzkyGolayFilter",
"HilbertTransform",
"ChannelPad",
]
Expand Down Expand Up @@ -174,12 +174,18 @@ def separable_filtering(
x: the input image. must have shape (batch, channels, H[, W, ...]).
kernels: kernel along each spatial dimension.
could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels.
mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.
mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See
torch.nn.Conv1d() for more information.
Raises:
TypeError: When ``x`` is not a ``torch.Tensor``.
"""

pytorch_version_pre_1_5_1 = tuple(int(x[0]) for x in torch.__version__.split(".")[:3]) < (1, 5, 1)
if (mode != "zeros") and pytorch_version_pre_1_5_1:
raise InvalidPyTorchVersionError("1.5.1", f"Padding mode '{mode}'")

if not torch.is_tensor(x):
raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

Expand Down Expand Up @@ -219,9 +225,9 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
return _conv(x, spatial_dims - 1)


class SavitskyGolayFilter(nn.Module):
class SavitzkyGolayFilter(nn.Module):
"""
Convolve a Tensor along a particular axis with a Savitsky-Golay kernel.
Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.
Args:
window_length: Length of the filter window, must be a positive odd integer.
Expand All @@ -247,7 +253,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and
have a device type of ``'cpu'``.
Returns:
torch.Tensor: ``x`` filtered by Savitsky-Golay kernel with window length ``self.window_length`` using
torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using
polynomials of order ``self.order``, along axis specified in ``self.axis``.
"""

Expand Down
298 changes: 162 additions & 136 deletions tests/test_savitzky_golay_filter.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,162 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.networks.layers import SavitskyGolayFilter
from tests.utils import skip_if_no_cuda

# Zero-padding trivial tests

TEST_CASE_SINGLE_VALUE = [
{"window_length": 3, "order": 1},
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
# output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed)
1e-15, # absolute tolerance
]

TEST_CASE_1D = [
{"window_length": 3, "order": 1},
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
torch.Tensor([2 / 3, 1.0, 2 / 3])
.unsqueeze(0)
.unsqueeze(0), # Expected output: zero padded, so linear interpolation
# over length-3 windows will result in output of [2/3, 1, 2/3].
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_2 = [
{"window_length": 3, "order": 1}, # along default axis (2, first spatial dim)
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_3 = [
{"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim)
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

# Replicated-padding trivial tests

TEST_CASE_SINGLE_VALUE_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"},
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
# output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed)
1e-15, # absolute tolerance
]

TEST_CASE_1D_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"},
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation
# over length-3 windows will result in output of [2/3, 1, 2/3].
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_2_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim)
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_3_REP = [
{"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim)
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

# Sine smoothing

TEST_CASE_SINE_SMOOTH = [
{"window_length": 3, "order": 1},
# Sine wave with period equal to savgol window length (windowed to reduce edge effects).
torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0),
# Should be smoothed out to zeros
torch.zeros(100).unsqueeze(0).unsqueeze(0),
# tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input
2e-2, # absolute tolerance
]


class TestSavitskyGolayCPU(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE,
TEST_CASE_1D,
TEST_CASE_2D_AXIS_2,
TEST_CASE_2D_AXIS_3,
TEST_CASE_SINGLE_VALUE_REP,
TEST_CASE_1D_REP,
TEST_CASE_2D_AXIS_2_REP,
TEST_CASE_2D_AXIS_3_REP,
TEST_CASE_SINE_SMOOTH,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitskyGolayFilter(**arguments)(image)
np.testing.assert_allclose(result, expected_data, atol=atol)


@skip_if_no_cuda
class TestSavitskyGolayGPU(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE,
TEST_CASE_1D,
TEST_CASE_2D_AXIS_2,
TEST_CASE_2D_AXIS_3,
TEST_CASE_SINGLE_VALUE_REP,
TEST_CASE_1D_REP,
TEST_CASE_2D_AXIS_2_REP,
TEST_CASE_2D_AXIS_3_REP,
TEST_CASE_SINE_SMOOTH,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitskyGolayFilter(**arguments)(image.to(device="cuda"))
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol)
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.networks.layers import SavitzkyGolayFilter
from monai.utils import InvalidPyTorchVersionError
from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, skip_if_no_cuda

# Zero-padding trivial tests

TEST_CASE_SINGLE_VALUE = [
{"window_length": 3, "order": 1},
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
# output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed)
1e-15, # absolute tolerance
]

TEST_CASE_1D = [
{"window_length": 3, "order": 1},
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
torch.Tensor([2 / 3, 1.0, 2 / 3])
.unsqueeze(0)
.unsqueeze(0), # Expected output: zero padded, so linear interpolation
# over length-3 windows will result in output of [2/3, 1, 2/3].
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_2 = [
{"window_length": 3, "order": 1}, # along default axis (2, first spatial dim)
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_3 = [
{"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim)
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

# Replicated-padding trivial tests

TEST_CASE_SINGLE_VALUE_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"},
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
# output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed)
1e-15, # absolute tolerance
]

TEST_CASE_1D_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"},
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation
# over length-3 windows will result in output of [2/3, 1, 2/3].
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_2_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim)
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_3_REP = [
{"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim)
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

# Sine smoothing

TEST_CASE_SINE_SMOOTH = [
{"window_length": 3, "order": 1},
# Sine wave with period equal to savgol window length (windowed to reduce edge effects).
torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0),
# Should be smoothed out to zeros
torch.zeros(100).unsqueeze(0).unsqueeze(0),
# tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input
2e-2, # absolute tolerance
]


class TestSavitzkyGolayCPU(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE,
TEST_CASE_1D,
TEST_CASE_2D_AXIS_2,
TEST_CASE_2D_AXIS_3,
TEST_CASE_SINE_SMOOTH,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image)
np.testing.assert_allclose(result, expected_data, atol=atol)


@SkipIfBeforePyTorchVersion((1, 5, 1))
class TestSavitzkyGolayCPUREP(unittest.TestCase):
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image)
np.testing.assert_allclose(result, expected_data, atol=atol)


@skip_if_no_cuda
class TestSavitzkyGolayGPU(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE,
TEST_CASE_1D,
TEST_CASE_2D_AXIS_2,
TEST_CASE_2D_AXIS_3,
TEST_CASE_SINE_SMOOTH,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda"))
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol)


@skip_if_no_cuda
@SkipIfBeforePyTorchVersion((1, 5, 1))
class TestSavitzkyGolayGPUREP(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE_REP,
TEST_CASE_1D_REP,
TEST_CASE_2D_AXIS_2_REP,
TEST_CASE_2D_AXIS_3_REP,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda"))
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol)


@SkipIfAtLeastPyTorchVersion((1, 5, 1))
class TestSavitzkyGolayInvalidPyTorch(unittest.TestCase):
def test_invalid_pytorch_error(self):
with self.assertRaisesRegex(InvalidPyTorchVersionError, "version"):
SavitzkyGolayFilter(3, 1, mode="replicate")(torch.ones((1, 1, 10, 10)))
Loading

0 comments on commit cda44a4

Please sign in to comment.