From 44e249d7d492d858199acfca1c948faa5aa33763 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Thu, 28 Nov 2024 08:35:29 +0100 Subject: [PATCH] =?UTF-8?q?Implement=20TorchIO=20transforms=20wrapper=20an?= =?UTF-8?q?alogous=20to=20TorchVision=20transfo=E2=80=A6=20(#7579)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rms wrapper and test case Fixes #7499 . ### Description As discussed in the issue, this PR implements a wrapper class for TorchIO transforms, analogous to the TorchVision transforms wrapper. The test cases just check that transforms are callable and that after applying a transform, the result is different from the inputs. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Fabian Klopfer Signed-off-by: Fabian Klopfer Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Fabian Klopfer Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- .gitignore | 3 + docs/source/transforms.rst | 24 ++++++ environment-dev.yml | 1 + monai/transforms/__init__.py | 9 ++ monai/transforms/utility/array.py | 109 +++++++++++++++++++++++-- monai/transforms/utility/dictionary.py | 67 +++++++++++++++ requirements-dev.txt | 1 + setup.cfg | 3 + tests/test_rand_torchio.py | 54 ++++++++++++ tests/test_rand_torchiod.py | 44 ++++++++++ tests/test_torchio.py | 41 ++++++++++ tests/test_torchiod.py | 47 +++++++++++ 12 files changed, 397 insertions(+), 6 deletions(-) create mode 100644 tests/test_rand_torchio.py create mode 100644 tests/test_rand_torchiod.py create mode 100644 tests/test_torchio.py create mode 100644 tests/test_torchiod.py diff --git a/.gitignore b/.gitignore index 437677d2bb..76c6ab0d12 100644 --- a/.gitignore +++ b/.gitignore @@ -149,6 +149,9 @@ tests/testing_data/nrrd_example.nrrd # clang format tool .clang-format-bin/ +# ctags +tags + # VSCode .vscode/ *.zip diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 41bb4ae79a..d2585daf63 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1180,6 +1180,18 @@ Utility :members: :special-members: __call__ +`TorchIO` +""""""""" +.. autoclass:: TorchIO + :members: + :special-members: __call__ + +`RandTorchIO` +""""""""""""" +.. autoclass:: RandTorchIO + :members: + :special-members: __call__ + `MapLabelValue` """"""""""""""" .. autoclass:: MapLabelValue @@ -2253,6 +2265,18 @@ Utility (Dict) :members: :special-members: __call__ +`TorchIOd` +"""""""""" +.. autoclass:: TorchIOd + :members: + :special-members: __call__ + +`RandTorchIOd` +"""""""""""""" +.. autoclass:: RandTorchIOd + :members: + :special-members: __call__ + `MapLabelValued` """""""""""""""" .. autoclass:: MapLabelValued diff --git a/environment-dev.yml b/environment-dev.yml index a4651ec7e4..4a1723e8a5 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,6 +7,7 @@ channels: dependencies: - numpy>=1.24,<2.0 - pytorch>=1.9 + - torchio - torchvision - pytorch-cuda>=11.6 - pip diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2cdd965c91..d15042181b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -531,6 +531,8 @@ RandIdentity, RandImageFilter, RandLambda, + RandTorchIO, + RandTorchVision, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -540,6 +542,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, @@ -620,6 +623,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RandTorchIOd, + RandTorchIOD, + RandTorchIODict, RandTorchVisiond, RandTorchVisionD, RandTorchVisionDict, @@ -653,6 +659,9 @@ ToPILd, ToPILD, ToPILDict, + TorchIOd, + TorchIOD, + TorchIODict, TorchVisiond, TorchVisionD, TorchVisionDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 1b3c59afdb..84422a9ee5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -18,10 +18,10 @@ import sys import time import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable +from typing import Any, Callable, Union import numpy as np import torch @@ -99,11 +99,14 @@ "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", + "TorchIO", "MapLabelValue", "IntensityStats", "ToDevice", "CuCIM", "RandCuCIM", + "RandTorchIO", + "RandTorchVision", "ToCupy", "ImageFilter", "RandImageFilter", @@ -1136,12 +1139,44 @@ def __call__( return concatenate((img, points_image), axis=0) -class TorchVision: +class TorchVision(Transform): """ - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. - As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input - data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. + This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + """ + img_t, *_ = convert_data_type(img, torch.Tensor) + + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out + + +class RandTorchVision(Transform, RandomizableTrait): + """ + This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. """ backend = [TransformBackends.TORCH] @@ -1172,6 +1207,68 @@ def __call__(self, img: NdarrayOrTensor): return out +class TorchIO(Transform): + """ + This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + + """ + return self.trans(img) + + +class RandTorchIO(Transform, RandomizableTrait): + """ + This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + Use this wrapper for all TorchIO transform inheriting from RandomTransform: + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + + """ + return self.trans(img) + + class MapLabelValue: """ Utility to map label values to another set of values. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 65c721e48e..7dd2397a74 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -60,6 +60,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, @@ -136,6 +137,9 @@ "RandLambdaD", "RandLambdaDict", "RandLambdad", + "RandTorchIOd", + "RandTorchIOD", + "RandTorchIODict", "RandTorchVisionD", "RandTorchVisionDict", "RandTorchVisiond", @@ -172,6 +176,9 @@ "ToTensorD", "ToTensorDict", "ToTensord", + "TorchIOD", + "TorchIODict", + "TorchIOd", "TorchVisionD", "TorchVisionDict", "TorchVisiond", @@ -1445,6 +1452,64 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class TorchIOd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms. + For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`. + """ + + backend = TorchIO.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) + + +class RandTorchIOd(MapTransform, RandomizableTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms. + For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`. + """ + + backend = TorchIO.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) + + class MapLabelValued(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. @@ -1871,8 +1936,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch ConvertToMultiChannelBasedOnBratsClassesd ) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld +TorchIOD = TorchIODict = TorchIOd TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond +RandTorchIOD = RandTorchIODict = RandTorchIOd RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd diff --git a/requirements-dev.txt b/requirements-dev.txt index 72654d3534..bffe304df4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows" types-setuptools mypy>=1.5.0, <1.12.0 ninja +torchio torchvision psutil cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10" diff --git a/setup.cfg b/setup.cfg index 694dc969d9..ecfd717aff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,7 @@ all = tensorboard gdown>=4.7.3 pytorch-ignite==0.4.11 + torchio torchvision itk>=5.2 tqdm>=4.47.0 @@ -102,6 +103,8 @@ gdown = gdown>=4.7.3 ignite = pytorch-ignite==0.4.11 +torchio = + torchio torchvision = torchvision itk = diff --git a/tests/test_rand_torchio.py b/tests/test_rand_torchio.py new file mode 100644 index 0000000000..ab212d4a11 --- /dev/null +++ b/tests/test_rand_torchio.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIO +from monai.utils import optional_import, set_determinism + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [ + [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], + [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], + [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], + [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_torchiod.py b/tests/test_rand_torchiod.py new file mode 100644 index 0000000000..52bcf7c576 --- /dev/null +++ b/tests/test_rand_torchiod.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIOd +from monai.utils import optional_import, set_determinism +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TEST_PARAMS = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIOd(unittest.TestCase): + + @parameterized.expand(TEST_PARAMS) + def test_random_transform(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIOd(**input_param)(input_data) + self.assertFalse(np.allclose(input_data["img1"], result["img1"], atol=1e-6, rtol=1e-6)) + assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchio.py b/tests/test_torchio.py new file mode 100644 index 0000000000..d2d598ca4c --- /dev/null +++ b/tests/test_torchio.py @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import TorchIO +from monai.utils import optional_import + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)]] + + +@skipUnless(has_torchio, "Requires torchio") +class TestTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + result = TorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchiod.py b/tests/test_torchiod.py new file mode 100644 index 0000000000..892287461c --- /dev/null +++ b/tests/test_torchiod.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.transforms import TorchIOd +from monai.utils import optional_import +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TEST_PARAMS = [ + [ + {"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)}, + {"img": TEST_TENSOR}, + ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42, + ] +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestTorchIOd(unittest.TestCase): + + @parameterized.expand(TEST_PARAMS) + def test_value(self, input_param, input_data, expected_value): + result = TorchIOd(**input_param)(input_data) + assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main()