diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index dd6b5e31936..a0cd38e9d3b 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -297,15 +297,23 @@ class RandSpatialCrop(Randomizable): roi_size: if `random_size` is True, it specifies the minimum crop region. if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128] If its components have non-positive values, the corresponding size of input image will be used. + max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size` + can specify the max crop region size. if None, defaults to the input image size. + if its components have non-positive values, the corresponding size of input image will be used. random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. """ def __init__( - self, roi_size: Union[Sequence[int], int], random_center: bool = True, random_size: bool = True + self, + roi_size: Union[Sequence[int], int], + max_roi_size: Optional[Union[Sequence[int], int]] = None, + random_center: bool = True, + random_size: bool = True, ) -> None: self.roi_size = roi_size + self.max_roi_size = max_roi_size self.random_center = random_center self.random_size = random_size self._size: Optional[Sequence[int]] = None @@ -314,7 +322,10 @@ def __init__( def randomize(self, img_size: Sequence[int]) -> None: self._size = fall_back_tuple(self.roi_size, img_size) if self.random_size: - self._size = tuple((self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size)))) + max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size) + if any([i > j for i, j in zip(self._size, max_size)]): + raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.") + self._size = tuple((self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size)))) if self.random_center: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) @@ -341,9 +352,13 @@ class RandSpatialCropSamples(Randomizable): It will return a list of cropped images. Args: - roi_size: if `random_size` is True, the spatial size of the minimum crop region. - if `random_size` is False, specify the expected ROI size to crop. e.g. [224, 224, 128] + roi_size: if `random_size` is True, it specifies the minimum crop region. + if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128] + If its components have non-positive values, the corresponding size of input image will be used. num_samples: number of samples (crop regions) to take in the returned list. + max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size` + can specify the max crop region size. if None, defaults to the input image size. + if its components have non-positive values, the corresponding size of input image will be used. random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. @@ -357,13 +372,14 @@ def __init__( self, roi_size: Union[Sequence[int], int], num_samples: int, + max_roi_size: Optional[Union[Sequence[int], int]] = None, random_center: bool = True, random_size: bool = True, ) -> None: if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples - self.cropper = RandSpatialCrop(roi_size, random_center, random_size) + self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 90ba2d601ac..2fad9991d25 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -405,6 +405,9 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): roi_size: if `random_size` is True, it specifies the minimum crop region. if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128] If its components have non-positive values, the corresponding size of input image will be used. + max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size` + can specify the max crop region size. if None, defaults to the input image size. + if its components have non-positive values, the corresponding size of input image will be used. random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. @@ -415,12 +418,14 @@ def __init__( self, keys: KeysCollection, roi_size: Union[Sequence[int], int], + max_roi_size: Optional[Union[Sequence[int], int]] = None, random_center: bool = True, random_size: bool = True, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) self.roi_size = roi_size + self.max_roi_size = max_roi_size self.random_center = random_center self.random_size = random_size self._slices: Optional[Tuple[slice, ...]] = None @@ -429,7 +434,10 @@ def __init__( def randomize(self, img_size: Sequence[int]) -> None: self._size = fall_back_tuple(self.roi_size, img_size) if self.random_size: - self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))] + max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size) + if any([i > j for i, j in zip(self._size, max_size)]): + raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.") + self._size = [self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))] if self.random_center: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) @@ -494,9 +502,13 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform): Args: keys: keys of the corresponding items to be transformed. See also: monai.transforms.MapTransform - roi_size: if `random_size` is True, the spatial size of the minimum crop region. - if `random_size` is False, specify the expected ROI size to crop. e.g. [224, 224, 128] + roi_size: if `random_size` is True, it specifies the minimum crop region. + if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128] + If its components have non-positive values, the corresponding size of input image will be used. num_samples: number of samples (crop regions) to take in the returned list. + max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size` + can specify the max crop region size. if None, defaults to the input image size. + if its components have non-positive values, the corresponding size of input image will be used. random_center: crop at random position as center or the image center. random_size: crop with random size or specific size ROI. The actual size is sampled from `randint(roi_size, img_size)`. @@ -515,6 +527,7 @@ def __init__( keys: KeysCollection, roi_size: Union[Sequence[int], int], num_samples: int, + max_roi_size: Optional[Union[Sequence[int], int]] = None, random_center: bool = True, random_size: bool = True, meta_key_postfix: str = "meta_dict", @@ -524,7 +537,7 @@ def __init__( if num_samples < 1: raise ValueError(f"num_samples must be positive, got {num_samples}.") self.num_samples = num_samples - self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size, allow_missing_keys) + self.cropper = RandSpatialCropd(keys, roi_size, max_roi_size, random_center, random_size, allow_missing_keys) self.meta_key_postfix = meta_key_postfix def set_random_state( diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 8865cd7adad..69385a8b249 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -158,9 +158,9 @@ ) ) -TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], True, False))) +TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], None, True, False))) -TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) +TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False))) TESTS.append( ( diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 3dadbe3d92e..4f6e9a25fda 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -34,7 +34,6 @@ def test_shape(self): output = [] n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): - print(item) output.append("".join(item)) expected = ["vwx", "yzh", "ell", "owo", "rld"] self.assertEqual(output, expected) diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 7ee3db11312..01e057e5891 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -35,6 +35,18 @@ np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), ] +TEST_CASE_4 = [ + {"roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, + np.random.randint(0, 2, size=[1, 4, 5, 6]), + (1, 4, 4, 3), +] + +TEST_CASE_5 = [ + {"roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, + np.random.randint(0, 2, size=[1, 4, 5, 6]), + (1, 3, 4, 3), +] + class TestRandSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) @@ -49,6 +61,13 @@ def test_value(self, input_param, input_data): roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) + def test_random_shape(self, input_param, input_data, expected_shape): + cropper = RandSpatialCrop(**input_param) + cropper.set_random_state(seed=123) + result = cropper(input_data) + self.assertTupleEqual(result.shape, expected_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_spatial_cropd.py b/tests/test_rand_spatial_cropd.py index 2e6a2747fbd..610c1974aa1 100644 --- a/tests/test_rand_spatial_cropd.py +++ b/tests/test_rand_spatial_cropd.py @@ -39,6 +39,18 @@ {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, ] +TEST_CASE_4 = [ + {"keys": "img", "roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True}, + {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, + (1, 4, 4, 3), +] + +TEST_CASE_5 = [ + {"keys": "img", "roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True}, + {"img": np.random.randint(0, 2, size=[1, 4, 5, 6])}, + (1, 3, 4, 3), +] + class TestRandSpatialCropd(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) @@ -53,6 +65,13 @@ def test_value(self, input_param, input_data): roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) + def test_random_shape(self, input_param, input_data, expected_shape): + cropper = RandSpatialCropd(**input_param) + cropper.set_random_state(seed=123) + result = cropper(input_data) + self.assertTupleEqual(result["img"].shape, expected_shape) + if __name__ == "__main__": unittest.main()