Skip to content

Commit

Permalink
Implement TorchIO transforms wrapper analogous to TorchVision transfo… (
Browse files Browse the repository at this point in the history
#7579)

…rms wrapper and test case

Fixes #7499  .

### Description
As discussed in the issue, this PR implements a wrapper class for
TorchIO transforms, analogous to the TorchVision transforms wrapper.

The test cases just check that transforms are callable and that after
applying a transform, the result is different from the inputs.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Fabian Klopfer <[email protected]>
Signed-off-by: Fabian Klopfer <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Fabian Klopfer <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
4 people authored Nov 28, 2024
1 parent 20372f0 commit 44e249d
Show file tree
Hide file tree
Showing 12 changed files with 397 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ tests/testing_data/nrrd_example.nrrd
# clang format tool
.clang-format-bin/

# ctags
tags

# VSCode
.vscode/
*.zip
Expand Down
24 changes: 24 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,18 @@ Utility
:members:
:special-members: __call__

`TorchIO`
"""""""""
.. autoclass:: TorchIO
:members:
:special-members: __call__

`RandTorchIO`
"""""""""""""
.. autoclass:: RandTorchIO
:members:
:special-members: __call__

`MapLabelValue`
"""""""""""""""
.. autoclass:: MapLabelValue
Expand Down Expand Up @@ -2253,6 +2265,18 @@ Utility (Dict)
:members:
:special-members: __call__

`TorchIOd`
""""""""""
.. autoclass:: TorchIOd
:members:
:special-members: __call__

`RandTorchIOd`
""""""""""""""
.. autoclass:: RandTorchIOd
:members:
:special-members: __call__

`MapLabelValued`
""""""""""""""""
.. autoclass:: MapLabelValued
Expand Down
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ channels:
dependencies:
- numpy>=1.24,<2.0
- pytorch>=1.9
- torchio
- torchvision
- pytorch-cuda>=11.6
- pip
Expand Down
9 changes: 9 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,8 @@
RandIdentity,
RandImageFilter,
RandLambda,
RandTorchIO,
RandTorchVision,
RemoveRepeatedChannel,
RepeatChannel,
SimulateDelay,
Expand All @@ -540,6 +542,7 @@
ToDevice,
ToNumpy,
ToPIL,
TorchIO,
TorchVision,
ToTensor,
Transpose,
Expand Down Expand Up @@ -620,6 +623,9 @@
RandLambdad,
RandLambdaD,
RandLambdaDict,
RandTorchIOd,
RandTorchIOD,
RandTorchIODict,
RandTorchVisiond,
RandTorchVisionD,
RandTorchVisionDict,
Expand Down Expand Up @@ -653,6 +659,9 @@
ToPILd,
ToPILD,
ToPILDict,
TorchIOd,
TorchIOD,
TorchIODict,
TorchVisiond,
TorchVisionD,
TorchVisionDict,
Expand Down
109 changes: 103 additions & 6 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import sys
import time
import warnings
from collections.abc import Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence
from copy import deepcopy
from functools import partial
from typing import Any, Callable
from typing import Any, Callable, Union

import numpy as np
import torch
Expand Down Expand Up @@ -99,11 +99,14 @@
"ConvertToMultiChannelBasedOnBratsClasses",
"AddExtremePointsChannel",
"TorchVision",
"TorchIO",
"MapLabelValue",
"IntensityStats",
"ToDevice",
"CuCIM",
"RandCuCIM",
"RandTorchIO",
"RandTorchVision",
"ToCupy",
"ImageFilter",
"RandImageFilter",
Expand Down Expand Up @@ -1136,12 +1139,44 @@ def __call__(
return concatenate((img, points_image), axis=0)


class TorchVision:
class TorchVision(Transform):
"""
This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.
This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args.
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchVision package.
args: parameters for the TorchVision transform.
kwargs: parameters for the TorchVision transform.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: PyTorch Tensor data for the TorchVision transform.
"""
img_t, *_ = convert_data_type(img, torch.Tensor)

out = self.trans(img_t)
out, *_ = convert_to_dst_type(src=out, dst=img)
return out


class RandTorchVision(Transform, RandomizableTrait):
"""
This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args.
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
"""

backend = [TransformBackends.TORCH]
Expand Down Expand Up @@ -1172,6 +1207,68 @@ def __call__(self, img: NdarrayOrTensor):
return out


class TorchIO(Transform):
"""
This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args.
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchIO package.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
"""
Args:
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
or dict containing 4D tensors as values
"""
return self.trans(img)


class RandTorchIO(Transform, RandomizableTrait):
"""
This is a wrapper for TorchIO randomized transforms based on the specified transform name and args.
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
Use this wrapper for all TorchIO transform inheriting from RandomTransform:
https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchIO package.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
"""
Args:
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
or dict containing 4D tensors as values
"""
return self.trans(img)


class MapLabelValue:
"""
Utility to map label values to another set of values.
Expand Down
67 changes: 67 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
ToDevice,
ToNumpy,
ToPIL,
TorchIO,
TorchVision,
ToTensor,
Transpose,
Expand Down Expand Up @@ -136,6 +137,9 @@
"RandLambdaD",
"RandLambdaDict",
"RandLambdad",
"RandTorchIOd",
"RandTorchIOD",
"RandTorchIODict",
"RandTorchVisionD",
"RandTorchVisionDict",
"RandTorchVisiond",
Expand Down Expand Up @@ -172,6 +176,9 @@
"ToTensorD",
"ToTensorDict",
"ToTensord",
"TorchIOD",
"TorchIODict",
"TorchIOd",
"TorchVisionD",
"TorchVisionDict",
"TorchVisiond",
Expand Down Expand Up @@ -1445,6 +1452,64 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class TorchIOd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms.
For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`.
"""

backend = TorchIO.backend

def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
name: The transform name in TorchIO package.
allow_missing_keys: don't raise exception if key is missing.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
"""
super().__init__(keys, allow_missing_keys)
self.name = name
kwargs["include"] = self.keys

self.trans = TorchIO(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
return dict(self.trans(data))


class RandTorchIOd(MapTransform, RandomizableTrait):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms.
For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`.
"""

backend = TorchIO.backend

def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
name: The transform name in TorchIO package.
allow_missing_keys: don't raise exception if key is missing.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
"""
super().__init__(keys, allow_missing_keys)
self.name = name
kwargs["include"] = self.keys

self.trans = TorchIO(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
return dict(self.trans(data))


class MapLabelValued(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
Expand Down Expand Up @@ -1871,8 +1936,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
ConvertToMultiChannelBasedOnBratsClassesd
)
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchIOD = TorchIODict = TorchIOd
TorchVisionD = TorchVisionDict = TorchVisiond
RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
RandTorchIOD = RandTorchIODict = RandTorchIOd
RandLambdaD = RandLambdaDict = RandLambdad
MapLabelValueD = MapLabelValueDict = MapLabelValued
IntensityStatsD = IntensityStatsDict = IntensityStatsd
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows"
types-setuptools
mypy>=1.5.0, <1.12.0
ninja
torchio
torchvision
psutil
cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ all =
tensorboard
gdown>=4.7.3
pytorch-ignite==0.4.11
torchio
torchvision
itk>=5.2
tqdm>=4.47.0
Expand Down Expand Up @@ -102,6 +103,8 @@ gdown =
gdown>=4.7.3
ignite =
pytorch-ignite==0.4.11
torchio =
torchio
torchvision =
torchvision
itk =
Expand Down
Loading

0 comments on commit 44e249d

Please sign in to comment.