Skip to content

Commit

Permalink
2177 Add max_roi_size to RandSpatialCrop (#2178)
Browse files Browse the repository at this point in the history
* [DLMED] add max_roi_size

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add max_roi_size

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] optimize logic

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>
  • Loading branch information
Nic-Ma authored and wyli committed May 27, 2021
1 parent 7997e68 commit aefa54d
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 12 deletions.
26 changes: 21 additions & 5 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,23 @@ class RandSpatialCrop(Randomizable):
roi_size: if `random_size` is True, it specifies the minimum crop region.
if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128]
If its components have non-positive values, the corresponding size of input image will be used.
max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`
can specify the max crop region size. if None, defaults to the input image size.
if its components have non-positive values, the corresponding size of input image will be used.
random_center: crop at random position as center or the image center.
random_size: crop with random size or specific size ROI.
The actual size is sampled from `randint(roi_size, img_size)`.
"""

def __init__(
self, roi_size: Union[Sequence[int], int], random_center: bool = True, random_size: bool = True
self,
roi_size: Union[Sequence[int], int],
max_roi_size: Optional[Union[Sequence[int], int]] = None,
random_center: bool = True,
random_size: bool = True,
) -> None:
self.roi_size = roi_size
self.max_roi_size = max_roi_size
self.random_center = random_center
self.random_size = random_size
self._size: Optional[Sequence[int]] = None
Expand All @@ -314,7 +322,10 @@ def __init__(
def randomize(self, img_size: Sequence[int]) -> None:
self._size = fall_back_tuple(self.roi_size, img_size)
if self.random_size:
self._size = tuple((self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))))
max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size)
if any([i > j for i, j in zip(self._size, max_size)]):
raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.")
self._size = tuple((self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))))
if self.random_center:
valid_size = get_valid_patch_size(img_size, self._size)
self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R)
Expand All @@ -341,9 +352,13 @@ class RandSpatialCropSamples(Randomizable):
It will return a list of cropped images.
Args:
roi_size: if `random_size` is True, the spatial size of the minimum crop region.
if `random_size` is False, specify the expected ROI size to crop. e.g. [224, 224, 128]
roi_size: if `random_size` is True, it specifies the minimum crop region.
if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128]
If its components have non-positive values, the corresponding size of input image will be used.
num_samples: number of samples (crop regions) to take in the returned list.
max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`
can specify the max crop region size. if None, defaults to the input image size.
if its components have non-positive values, the corresponding size of input image will be used.
random_center: crop at random position as center or the image center.
random_size: crop with random size or specific size ROI.
The actual size is sampled from `randint(roi_size, img_size)`.
Expand All @@ -357,13 +372,14 @@ def __init__(
self,
roi_size: Union[Sequence[int], int],
num_samples: int,
max_roi_size: Optional[Union[Sequence[int], int]] = None,
random_center: bool = True,
random_size: bool = True,
) -> None:
if num_samples < 1:
raise ValueError(f"num_samples must be positive, got {num_samples}.")
self.num_samples = num_samples
self.cropper = RandSpatialCrop(roi_size, random_center, random_size)
self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size)

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
Expand Down
21 changes: 17 additions & 4 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform):
roi_size: if `random_size` is True, it specifies the minimum crop region.
if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128]
If its components have non-positive values, the corresponding size of input image will be used.
max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`
can specify the max crop region size. if None, defaults to the input image size.
if its components have non-positive values, the corresponding size of input image will be used.
random_center: crop at random position as center or the image center.
random_size: crop with random size or specific size ROI.
The actual size is sampled from `randint(roi_size, img_size)`.
Expand All @@ -415,12 +418,14 @@ def __init__(
self,
keys: KeysCollection,
roi_size: Union[Sequence[int], int],
max_roi_size: Optional[Union[Sequence[int], int]] = None,
random_center: bool = True,
random_size: bool = True,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
self.roi_size = roi_size
self.max_roi_size = max_roi_size
self.random_center = random_center
self.random_size = random_size
self._slices: Optional[Tuple[slice, ...]] = None
Expand All @@ -429,7 +434,10 @@ def __init__(
def randomize(self, img_size: Sequence[int]) -> None:
self._size = fall_back_tuple(self.roi_size, img_size)
if self.random_size:
self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))]
max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size)
if any([i > j for i, j in zip(self._size, max_size)]):
raise ValueError(f"min ROI size: {self._size} is bigger than max ROI size: {max_size}.")
self._size = [self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size))]
if self.random_center:
valid_size = get_valid_patch_size(img_size, self._size)
self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R)
Expand Down Expand Up @@ -494,9 +502,13 @@ class RandSpatialCropSamplesd(Randomizable, MapTransform):
Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
roi_size: if `random_size` is True, the spatial size of the minimum crop region.
if `random_size` is False, specify the expected ROI size to crop. e.g. [224, 224, 128]
roi_size: if `random_size` is True, it specifies the minimum crop region.
if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128]
If its components have non-positive values, the corresponding size of input image will be used.
num_samples: number of samples (crop regions) to take in the returned list.
max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`
can specify the max crop region size. if None, defaults to the input image size.
if its components have non-positive values, the corresponding size of input image will be used.
random_center: crop at random position as center or the image center.
random_size: crop with random size or specific size ROI.
The actual size is sampled from `randint(roi_size, img_size)`.
Expand All @@ -515,6 +527,7 @@ def __init__(
keys: KeysCollection,
roi_size: Union[Sequence[int], int],
num_samples: int,
max_roi_size: Optional[Union[Sequence[int], int]] = None,
random_center: bool = True,
random_size: bool = True,
meta_key_postfix: str = "meta_dict",
Expand All @@ -524,7 +537,7 @@ def __init__(
if num_samples < 1:
raise ValueError(f"num_samples must be positive, got {num_samples}.")
self.num_samples = num_samples
self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size, allow_missing_keys)
self.cropper = RandSpatialCropd(keys, roi_size, max_roi_size, random_center, random_size, allow_missing_keys)
self.meta_key_postfix = meta_key_postfix

def set_random_state(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@
)
)

TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], True, False)))
TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], None, True, False)))

TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], False, False)))
TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False)))

TESTS.append(
(
Expand Down
1 change: 0 additions & 1 deletion tests/test_patch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def test_shape(self):
output = []
n_workers = 0 if sys.platform == "win32" else 2
for item in DataLoader(result, batch_size=3, num_workers=n_workers):
print(item)
output.append("".join(item))
expected = ["vwx", "yzh", "ell", "owo", "rld"]
self.assertEqual(output, expected)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_rand_spatial_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@
np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),
]

TEST_CASE_4 = [
{"roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True},
np.random.randint(0, 2, size=[1, 4, 5, 6]),
(1, 4, 4, 3),
]

TEST_CASE_5 = [
{"roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True},
np.random.randint(0, 2, size=[1, 4, 5, 6]),
(1, 3, 4, 3),
]


class TestRandSpatialCrop(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
Expand All @@ -49,6 +61,13 @@ def test_value(self, input_param, input_data):
roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size]
np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]])

@parameterized.expand([TEST_CASE_4, TEST_CASE_5])
def test_random_shape(self, input_param, input_data, expected_shape):
cropper = RandSpatialCrop(**input_param)
cropper.set_random_state(seed=123)
result = cropper(input_data)
self.assertTupleEqual(result.shape, expected_shape)


if __name__ == "__main__":
unittest.main()
19 changes: 19 additions & 0 deletions tests/test_rand_spatial_cropd.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@
{"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])},
]

TEST_CASE_4 = [
{"keys": "img", "roi_size": [3, 3, 3], "max_roi_size": [5, -1, 4], "random_center": True, "random_size": True},
{"img": np.random.randint(0, 2, size=[1, 4, 5, 6])},
(1, 4, 4, 3),
]

TEST_CASE_5 = [
{"keys": "img", "roi_size": 3, "max_roi_size": 4, "random_center": True, "random_size": True},
{"img": np.random.randint(0, 2, size=[1, 4, 5, 6])},
(1, 3, 4, 3),
]


class TestRandSpatialCropd(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
Expand All @@ -53,6 +65,13 @@ def test_value(self, input_param, input_data):
roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size]
np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]])

@parameterized.expand([TEST_CASE_4, TEST_CASE_5])
def test_random_shape(self, input_param, input_data, expected_shape):
cropper = RandSpatialCropd(**input_param)
cropper.set_random_state(seed=123)
result = cropper(input_data)
self.assertTupleEqual(result["img"].shape, expected_shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit aefa54d

Please sign in to comment.