diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b931b88e97..b8ada53b39 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -420,6 +420,8 @@ def randomize(self, spatial_size) -> None: # type: ignore # see issue #495 def __call__(self, data): d = dict(data) spatial_size = self.rand_2d_elastic.spatial_size + if np.any([sz <= 1 for sz in spatial_size]): + spatial_size = data[self.keys[0]].shape[1:] self.randomize(spatial_size) if self.rand_2d_elastic.do_transform: @@ -508,6 +510,8 @@ def randomize(self, grid_size) -> None: # type: ignore # see issue #495 def __call__(self, data): d = dict(data) spatial_size = self.rand_3d_elastic.spatial_size + if np.any([sz <= 1 for sz in spatial_size]): + spatial_size = data[self.keys[0]].shape[1:] self.randomize(spatial_size) grid = create_grid(spatial_size) if self.rand_3d_elastic.do_transform: diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index 0dc7ce4aa4..f981e27e65 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -31,6 +31,19 @@ {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, np.ones((3, 2, 2)), ], + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.3, 0.3), + "prob": 0.0, + "as_tensor_output": False, + "device": None, + "spatial_size": (-1,), + }, + {"img": torch.ones((1, 2, 2)), "seg": torch.ones((1, 2, 2))}, + np.array([[[0.25, 0.25], [0.25, 0.25]]]), + ], [ { "keys": ("img", "seg"), diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 43bc297277..3c2c676747 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -31,6 +31,19 @@ {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, np.ones((2, 2, 2, 2)), ], + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "as_tensor_output": False, + "device": None, + "spatial_size": (-1,), + }, + {"img": torch.ones((1, 2, 2, 2)), "seg": torch.ones((1, 2, 2, 2))}, + np.array([[[[0.125, 0.125], [0.125, 0.125]], [[0.125, 0.125], [0.125, 0.125]]]]), + ], [ { "keys": ("img", "seg"),