diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index e257a9d5218..4de4a9ea529 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -39,7 +39,7 @@ "LLTM", "Reshape", "separable_filtering", - "SavitskyGolayFilter", + "SavitzkyGolayFilter", "HilbertTransform", "ChannelPad", ] @@ -174,12 +174,18 @@ def separable_filtering( x: the input image. must have shape (batch, channels, H[, W, ...]). kernels: kernel along each spatial dimension. could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels. - mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or - ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See + torch.nn.Conv1d() for more information. Raises: TypeError: When ``x`` is not a ``torch.Tensor``. """ + + pytorch_version_pre_1_5_1 = tuple(int(x[0]) for x in torch.__version__.split(".")[:3]) < (1, 5, 1) + if (mode != "zeros") and pytorch_version_pre_1_5_1: + raise InvalidPyTorchVersionError("1.5.1", f"Padding mode '{mode}'") + if not torch.is_tensor(x): raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") @@ -219,9 +225,9 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: return _conv(x, spatial_dims - 1) -class SavitskyGolayFilter(nn.Module): +class SavitzkyGolayFilter(nn.Module): """ - Convolve a Tensor along a particular axis with a Savitsky-Golay kernel. + Convolve a Tensor along a particular axis with a Savitzky-Golay kernel. Args: window_length: Length of the filter window, must be a positive odd integer. @@ -247,7 +253,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and have a device type of ``'cpu'``. Returns: - torch.Tensor: ``x`` filtered by Savitsky-Golay kernel with window length ``self.window_length`` using + torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using polynomials of order ``self.order``, along axis specified in ``self.axis``. """ diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py index 625adc33f7c..f546b823fb6 100644 --- a/tests/test_savitzky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -1,136 +1,162 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -import torch -from parameterized import parameterized - -from monai.networks.layers import SavitskyGolayFilter -from tests.utils import skip_if_no_cuda - -# Zero-padding trivial tests - -TEST_CASE_SINGLE_VALUE = [ - {"window_length": 3, "order": 1}, - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value - torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 - # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_1D = [ - {"window_length": 3, "order": 1}, - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data - torch.Tensor([2 / 3, 1.0, 2 / 3]) - .unsqueeze(0) - .unsqueeze(0), # Expected output: zero padded, so linear interpolation - # over length-3 windows will result in output of [2/3, 1, 2/3]. - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2 = [ - {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) - torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_3 = [ - {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) - torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -# Replicated-padding trivial tests - -TEST_CASE_SINGLE_VALUE_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 - # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_1D_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation - # over length-3 windows will result in output of [2/3, 1, 2/3]. - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) - torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_3_REP = [ - {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) - torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -# Sine smoothing - -TEST_CASE_SINE_SMOOTH = [ - {"window_length": 3, "order": 1}, - # Sine wave with period equal to savgol window length (windowed to reduce edge effects). - torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), - # Should be smoothed out to zeros - torch.zeros(100).unsqueeze(0).unsqueeze(0), - # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input - 2e-2, # absolute tolerance -] - - -class TestSavitskyGolayCPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINGLE_VALUE_REP, - TEST_CASE_1D_REP, - TEST_CASE_2D_AXIS_2_REP, - TEST_CASE_2D_AXIS_3_REP, - TEST_CASE_SINE_SMOOTH, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitskyGolayFilter(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) - - -@skip_if_no_cuda -class TestSavitskyGolayGPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINGLE_VALUE_REP, - TEST_CASE_1D_REP, - TEST_CASE_2D_AXIS_2_REP, - TEST_CASE_2D_AXIS_3_REP, - TEST_CASE_SINE_SMOOTH, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitskyGolayFilter(**arguments)(image.to(device="cuda")) - np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import SavitzkyGolayFilter +from monai.utils import InvalidPyTorchVersionError +from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, skip_if_no_cuda + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([2 / 3, 1.0, 2 / 3]) + .unsqueeze(0) + .unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3 = [ + {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3_REP = [ + {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), + # Should be smoothed out to zeros + torch.zeros(100).unsqueeze(0).unsqueeze(0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolayCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@SkipIfBeforePyTorchVersion((1, 5, 1)) +class TestSavitzkyGolayCPUREP(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + + +@skip_if_no_cuda +@SkipIfBeforePyTorchVersion((1, 5, 1)) +class TestSavitzkyGolayGPUREP(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE_REP, + TEST_CASE_1D_REP, + TEST_CASE_2D_AXIS_2_REP, + TEST_CASE_2D_AXIS_3_REP, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + + +@SkipIfAtLeastPyTorchVersion((1, 5, 1)) +class TestSavitzkyGolayInvalidPyTorch(unittest.TestCase): + def test_invalid_pytorch_error(self): + with self.assertRaisesRegex(InvalidPyTorchVersionError, "version"): + SavitzkyGolayFilter(3, 1, mode="replicate")(torch.ones((1, 1, 10, 10))) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 91c999aea95..307af681d25 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -1,65 +1,80 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.transforms import SavitskyGolaySmooth - -# Zero-padding trivial tests - -TEST_CASE_SINGLE_VALUE = [ - {"window_length": 3, "order": 1}, - np.expand_dims(np.array([1.0]), 0), # Input data: Single value - np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 - # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2 = [ - {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) - np.expand_dims(np.ones((2, 3)), 0), - np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), - 1e-15, # absolute tolerance -] - -# Replicated-padding trivial tests - -TEST_CASE_SINGLE_VALUE_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - np.expand_dims(np.array([1.0]), 0), # Input data: Single value - np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 - # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -# Sine smoothing - -TEST_CASE_SINE_SMOOTH = [ - {"window_length": 3, "order": 1}, - # Sine wave with period equal to savgol window length (windowed to reduce edge effects). - np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), - # Should be smoothed out to zeros - np.expand_dims(np.zeros(100), 0), - # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input - 2e-2, # absolute tolerance -] - - -class TestSavitskyGolaySmooth(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_SINE_SMOOTH] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitskyGolaySmooth(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import SavitzkyGolaySmooth +from monai.utils import InvalidPyTorchVersionError +from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolaySmooth(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@SkipIfBeforePyTorchVersion((1, 5, 1)) +class TestSavitzkyGolaySmoothREP(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@SkipIfAtLeastPyTorchVersion((1, 5, 1)) +class TestSavitzkyGolayInvalidPyTorch(unittest.TestCase): + def test_invalid_pytorch_error(self): + with self.assertRaisesRegex(InvalidPyTorchVersionError, "version"): + SavitzkyGolaySmooth(3, 1, mode="replicate")(np.ones((1, 1, 10, 10))) diff --git a/tests/utils.py b/tests/utils.py index 157eef8affe..372216425c0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -110,8 +110,12 @@ def __init__(self, pytorch_version_tuple): self.min_version = pytorch_version_tuple if has_pkg_res: self.version_too_old = ver(torch.__version__) < ver(".".join(map(str, self.min_version))) - else: + elif len(self.min_version) == 2: self.version_too_old = get_torch_version_tuple() < self.min_version + elif len(self.min_version) == 3: + self.version_too_old = tuple(int(x) for x in torch.__version__.split(".")[:3]) < self.min_version + else: + raise ValueError("Invalid PyTorch version tuple. Must be length 2 or 3 tuple.") def __call__(self, obj): return unittest.skipIf( @@ -127,8 +131,12 @@ def __init__(self, pytorch_version_tuple): self.max_version = pytorch_version_tuple if has_pkg_res: self.version_too_new = ver(torch.__version__) >= ver(".".join(map(str, self.max_version))) - else: + elif len(self.max_version) == 2: self.version_too_new = get_torch_version_tuple() >= self.max_version + elif len(self.max_version) == 3: + self.version_too_old = tuple(int(x[0]) for x in torch.__version__.split(".")[:3]) < self.max_version + else: + raise ValueError("Invalid PyTorch version tuple. Must be length 2 or 3 tuple.") def __call__(self, obj): return unittest.skipIf(