Added RandSimulateLowResolution(d) array and dictionary transforms and corresponding unit tests #6806

Merged
13 changes: 13 additions & 0 deletions docs/source/transforms.rst
@@ -925,6 +925,12 @@ Spatial
    :members:
    :special-members: __call__

`RandSimulateLowResolution`
"""""""""""""""""""""""""""
.. autoclass:: RandSimulateLowResolution
    :members:
    :special-members: __call__


Smooth Field
^^^^^^^^^^^^
@@ -1886,6 +1892,13 @@ Spatial (Dict)
    :members:
    :special-members: __call__

`RandSimulateLowResolutiond`
""""""""""""""""""""""""""""
.. autoclass:: RandSimulateLowResolutiond
    :members:
    :special-members: __call__


Smooth Field (Dict)
^^^^^^^^^^^^^^^^^^^

4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
@@ -381,6 +381,7 @@
    RandGridPatch,
    RandRotate,
    RandRotate90,
    RandSimulateLowResolution,
    RandZoom,
    Resample,
    ResampleToMatch,
@@ -437,6 +438,9 @@
    RandRotated,
    RandRotateD,
    RandRotateDict,
    RandSimulateLowResolutiond,
    RandSimulateLowResolutionD,
    RandSimulateLowResolutionDict,
    RandZoomd,
    RandZoomD,
    RandZoomDict,
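With these exports in place, the new transforms should be importable directly from `monai.transforms`, including the usual `D`/`Dict` aliases for the dictionary version. A quick sanity-check sketch, assuming a MONAI build that contains this PR:

```python
# Quick import check for the new transforms (requires a MONAI version that includes this PR).
from monai.transforms import (
    RandSimulateLowResolution,
    RandSimulateLowResolutiond,
    RandSimulateLowResolutionD,     # alias of RandSimulateLowResolutiond
    RandSimulateLowResolutionDict,  # alias of RandSimulateLowResolutiond
)

print(RandSimulateLowResolutionD is RandSimulateLowResolutiond)  # True
```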
95 changes: 94 additions & 1 deletion monai/transforms/spatial/array.py
@@ -25,7 +25,7 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
@@ -111,6 +111,7 @@
"RandAffine",
"Rand2DElastic",
"Rand3DElastic",
"RandSimulateLowResolution",
]

RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]]
@@ -3456,3 +3457,95 @@ def __call__(self, array: NdarrayOrTensor, randomize: bool = True):
        if randomize:
            self.randomize(array)
        return super().__call__(array)


class RandSimulateLowResolution(RandomizableTransform):
    """
    Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
    (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23).
    First, the array/tensor is resampled at a lower resolution as determined by the zoom_factor, which is uniformly
    sampled from the `zoom_range`. Then, the array/tensor is resampled back to its original resolution.
    """

    backend = Affine.backend

    def __init__(
        self,
        prob: float = 0.1,
        downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
        upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
        zoom_range: Sequence[float] = (0.5, 1.0),
        align_corners=False,
        device: torch.device | None = None,
    ) -> None:
        """
        Args:
            prob: probability of performing this augmentation.
            downsample_mode: interpolation mode for the downsampling operation.
            upsample_mode: interpolation mode for the upsampling operation.
            zoom_range: range from which the random zoom factor for the downsampling and upsampling operations is
                sampled. It determines the shape of the downsampled tensor.
            align_corners: only has an effect when downsample_mode or upsample_mode is 'linear', 'bilinear',
                'bicubic' or 'trilinear'. Default: False.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            device: device on which the tensor will be allocated.

        """
        RandomizableTransform.__init__(self, prob)

        self.downsample_mode = downsample_mode
        self.upsample_mode = upsample_mode
        self.zoom_range = zoom_range
        self.align_corners = align_corners
        self.device = device
        self.zoom_factor = 1.0

    def randomize(self, data: Any | None = None) -> None:
        super().randomize(None)
        self.zoom_factor = self.R.uniform(self.zoom_range[0], self.zoom_range[1])
        if not self._do_transform:
            return None

    def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
        """
        Args:
            img: shape must be (num_channels, H, W[, D]).
            randomize: whether to execute the `randomize()` function first, defaults to True.
        """
        if randomize:
            self.randomize()

        if self._do_transform:
            input_shape = img.shape[1:]
            target_shape = np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_)

            resize_tfm_downsample = Resize(
                spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False
            )

            resize_tfm_upsample = Resize(
                spatial_size=input_shape,
                size_mode="all",
                mode=self.upsample_mode,
                anti_aliasing=False,
                align_corners=self.align_corners,
            )
            # temporarily disable metadata tracking, since we do not want to invert the two Resize functions during
            # post-processing
            original_track_meta_value = get_track_meta()
            set_track_meta(False)

            img_downsampled = resize_tfm_downsample(img)
            img_upsampled = resize_tfm_upsample(img_downsampled)

            # reset metadata tracking to its original value
            set_track_meta(original_track_meta_value)

            # copy metadata from the original image to the down-and-upsampled image
            img_upsampled = MetaTensor(img_upsampled)
            img_upsampled.copy_meta_from(img)

            return img_upsampled

        else:
            return img
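For reference, a minimal usage sketch of the new array transform on a random channel-first 3D volume; the shapes, seed, and zoom range below are chosen for illustration only, and the exact output values depend on the sampled zoom factor:

```python
import torch
from monai.transforms import RandSimulateLowResolution

# channel-first 3D volume: (num_channels, H, W, D)
img = torch.rand(1, 64, 64, 48)

# downsample with nearest-neighbour interpolation, upsample trilinearly,
# using a zoom factor drawn uniformly from [0.5, 1.0)
sim_lowres = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 1.0))
sim_lowres.set_random_state(seed=42)

out = sim_lowres(img)
print(out.shape)  # spatial shape is preserved: torch.Size([1, 64, 64, 48])
```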
93 changes: 93 additions & 0 deletions monai/transforms/spatial/dictionary.py
@@ -45,6 +45,7 @@
    RandGridDistortion,
    RandGridPatch,
    RandRotate,
    RandSimulateLowResolution,
    RandZoom,
    ResampleToMatch,
    Resize,
@@ -140,6 +141,9 @@
"RandGridPatchd",
"RandGridPatchD",
"RandGridPatchDict",
"RandSimulateLowResolutiond",
"RandSimulateLowResolutionD",
"RandSimulateLowResolutionDict",
]


@@ -2518,6 +2522,94 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        return d


class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.RandSimulateLowResolution`.
    Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
    (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23).
    First, the array/tensor is resampled at a lower resolution as determined by the zoom_factor, which is uniformly
    sampled from the `zoom_range`. Then, the array/tensor is resampled back to its original resolution.
    """

    backend = RandAffine.backend

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.1,
        downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
        upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
        zoom_range=(0.5, 1.0),
        align_corners=False,
        allow_missing_keys: bool = False,
        device: torch.device | None = None,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            prob: probability of performing this augmentation.
            downsample_mode: interpolation mode for the downsampling operation.
            upsample_mode: interpolation mode for the upsampling operation.
            zoom_range: range from which the random zoom factor for the downsampling and upsampling operations is
                sampled. It determines the shape of the downsampled tensor.
            align_corners: only has an effect when downsample_mode or upsample_mode is 'linear', 'bilinear',
                'bicubic' or 'trilinear'. Default: False.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            allow_missing_keys: don't raise an exception if a key is missing.
            device: device on which the tensor will be allocated.

        See also:
            - :py:class:`monai.transforms.compose.MapTransform`

        """
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)

        self.downsample_mode = downsample_mode
        self.upsample_mode = upsample_mode
        self.zoom_range = zoom_range
        self.align_corners = align_corners
        self.device = device

        self.sim_lowres_tfm = RandSimulateLowResolution(
            prob=1.0,  # probability is handled by the dictionary class
            downsample_mode=self.downsample_mode,
            upsample_mode=self.upsample_mode,
            zoom_range=self.zoom_range,
            align_corners=self.align_corners,
            device=self.device,
        )

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> RandSimulateLowResolutiond:
        super().set_random_state(seed, state)
        # propagate the random state to the wrapped array transform so the sampled zoom factor is reproducible
        self.sim_lowres_tfm.set_random_state(seed, state)
        return self

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        """
        Args:
            data: a dictionary containing the tensor-like data to be transformed. The ``keys`` specified
                in this dictionary must refer to tensor-like arrays that are channel-first and have at most
                three spatial dimensions.
        """
        d = dict(data)
        first_key: Hashable = self.first_key(d)
        if first_key == ():
            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
            return out

        self.randomize(None)

        for key in self.key_iterator(d):
            # do the transform
            if self._do_transform:
                d[key] = self.sim_lowres_tfm(d[key])  # type: ignore
            else:
                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
        return d


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
@@ -2541,3 +2633,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
GridSplitD = GridSplitDict = GridSplitd
GridPatchD = GridPatchDict = GridPatchd
RandGridPatchD = RandGridPatchDict = RandGridPatchd
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
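A corresponding sketch for the dictionary version inside a Compose pipeline; the keys, shapes, and probability below are illustrative, not taken from the PR:

```python
import torch
from monai.transforms import Compose, RandSimulateLowResolutiond

data = {"image": torch.rand(1, 64, 64, 48), "label": torch.randint(0, 2, (1, 64, 64, 48))}

# simulate low resolution on the image only; labels are usually left untouched
transform = Compose([RandSimulateLowResolutiond(keys=["image"], prob=0.25, zoom_range=(0.5, 1.0))])
transform.set_random_state(seed=0)

out = transform(data)
print(out["image"].shape, out["label"].shape)  # spatial shapes are unchanged
```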
83 changes: 83 additions & 0 deletions tests/test_rand_simulate_low_resolution.py
@@ -0,0 +1,83 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import RandSimulateLowResolution
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS = []
for p in TEST_NDARRAYS:
    TESTS.append(
        [
            dict(prob=1.0, zoom_range=(0.8, 0.81)),
            p(
                np.array(
                    [
                        [
                            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],
                            [[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]],
                            [[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]],
                            [[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63]],
                        ]
                    ]
                )
            ),
            np.array(
                [
                    [
                        [
                            [0.0000, 0.6250, 1.3750, 2.0000],
                            [2.5000, 3.1250, 3.8750, 4.5000],
                            [5.5000, 6.1250, 6.8750, 7.5000],
                            [8.0000, 8.6250, 9.3750, 10.0000],
                        ],
                        [
                            [10.0000, 10.6250, 11.3750, 12.0000],
                            [12.5000, 13.1250, 13.8750, 14.5000],
                            [15.5000, 16.1250, 16.8750, 17.5000],
                            [18.0000, 18.6250, 19.3750, 20.0000],
                        ],
                        [
                            [22.0000, 22.6250, 23.3750, 24.0000],
                            [24.5000, 25.1250, 25.8750, 26.5000],
                            [27.5000, 28.1250, 28.8750, 29.5000],
                            [30.0000, 30.6250, 31.3750, 32.0000],
                        ],
                        [
                            [32.0000, 32.6250, 33.3750, 34.0000],
                            [34.5000, 35.1250, 35.8750, 36.5000],
                            [37.5000, 38.1250, 38.8750, 39.5000],
                            [40.0000, 40.6250, 41.3750, 42.0000],
                        ],
                    ]
                ]
            ),
        ]
    )


class TestRandSimulateLowResolution(unittest.TestCase):
    @parameterized.expand(TESTS)
    def test_value(self, arguments, image, expected_data):
        randsimlowres = RandSimulateLowResolution(**arguments)
        randsimlowres.set_random_state(seed=0)
        result = randsimlowres(image)
        assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor")


if __name__ == "__main__":
    unittest.main()
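As a cross-check of what the test exercises, the transform's core behaviour can be reproduced manually with two Resize calls for a fixed zoom factor. A sketch only: the zoom factor and input below are illustrative, while the seeded test above draws its factor from (0.8, 0.81), so the exact values differ slightly.

```python
import numpy as np
import torch
from monai.transforms import Resize

img = torch.arange(64, dtype=torch.float32).reshape(1, 4, 4, 4)

zoom_factor = 0.8  # illustrative fixed value instead of a random draw
input_shape = img.shape[1:]
target_shape = np.round(np.array(input_shape) * zoom_factor).astype(np.int_)  # -> (3, 3, 3)

down = Resize(spatial_size=target_shape, size_mode="all", mode="nearest", anti_aliasing=False)
up = Resize(spatial_size=input_shape, size_mode="all", mode="trilinear", anti_aliasing=False)

img_lowres = up(down(img))
print(img_lowres.shape)  # torch.Size([1, 4, 4, 4]): original shape, blurred by the round trip
```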