diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ab9adb6a99..ef1da2d855 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -671,6 +671,7 @@ in_bounds, is_empty, is_positive, + map_and_generate_sampling_centers, map_binary_to_indices, map_classes_to_indices, map_spatial_axes, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 560dbac346..d8461d927b 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -108,6 +108,7 @@ "in_bounds", "is_empty", "is_positive", + "map_and_generate_sampling_centers", "map_binary_to_indices", "map_classes_to_indices", "map_spatial_axes", @@ -368,6 +369,70 @@ def check_non_lazy_pending_ops( warnings.warn(msg) +def map_and_generate_sampling_centers( + label: NdarrayOrTensor, + spatial_size: Sequence[int] | int, + num_samples: int, + label_spatial_shape: Sequence[int] | None = None, + num_classes: int | None = None, + image: NdarrayOrTensor | None = None, + image_threshold: float = 0.0, + max_samples_per_class: int | None = None, + ratios: list[float | int] | None = None, + rand_state: np.random.RandomState | None = None, + allow_smaller: bool = False, + warn: bool = True, +) -> tuple[tuple]: + """ + Combine "map_classes_to_indices" and "generate_label_classes_crop_centers" functions, return crop center coordinates. + This calls `map_classes_to_indices` to get indices from `label`, gets the shape from `label_spatial_shape` + is given otherwise from the labels, calls `generate_label_classes_crop_centers`, and returns its results. + + Args: + label: use the label data to get the indices of every class. + spatial_size: spatial size of the ROIs to be sampled. + num_samples: total sample centers to be generated. + label_spatial_shape: spatial shape of the original label data to unravel selected centers. + indices: sequence of pre-computed foreground indices of every class in 1 dimension. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image: if image is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). + image_threshold: if enabled `image`, use ``image > image_threshold`` to + determine the valid image content area and select class indices only in this area. + max_samples_per_class: maximum length of indices in each class to reduce memory consumption. + Default is None, no subsampling. + ratios: ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). + warn: if `True` prints a warning if a class is not present in the label. + Returns: + Tuple of crop centres + """ + if label is None: + raise ValueError("label must not be None.") + indices = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class) + + if label_spatial_shape is not None: + _shape = label_spatial_shape + elif isinstance(label, monai.data.MetaTensor): + _shape = label.peek_pending_shape() + else: + _shape = label.shape[1:] + + if _shape is None: + raise ValueError( + "label_spatial_shape or label with a known shape must be provided to infer the output spatial shape." + ) + centers = generate_label_classes_crop_centers( + spatial_size, num_samples, _shape, indices, ratios, rand_state, allow_smaller, warn + ) + + return ensure_tuple(centers) + + def map_binary_to_indices( label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0 ) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: diff --git a/tests/test_map_and_generate_sampling_centers.py b/tests/test_map_and_generate_sampling_centers.py new file mode 100644 index 0000000000..ff74f974b9 --- /dev/null +++ b/tests/test_map_and_generate_sampling_centers.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from copy import deepcopy + +import numpy as np +from parameterized import parameterized + +from monai.transforms import map_and_generate_sampling_centers +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASE_1 = [ + # test Argmax data + { + "label": (np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "spatial_size": [2, 2, 2], + "num_samples": 2, + "label_spatial_shape": [3, 3, 3], + "num_classes": 3, + "image": None, + "ratios": [0, 1, 2], + "image_threshold": 0.0, + }, + tuple, + 2, + 3, +] + +TEST_CASE_2 = [ + { + "label": ( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "spatial_size": [2, 2, 2], + "num_samples": 1, + "ratios": None, + "label_spatial_shape": [3, 3, 3], + "image": None, + "image_threshold": 0.0, + }, + tuple, + 1, + 3, +] + + +class TestMapAndGenerateSamplingCenters(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_map_and_generate_sampling_centers(self, input_data, expected_type, expected_count, expected_shape): + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + input_data["label"] = p(input_data["label"]) + set_determinism(0) + result = map_and_generate_sampling_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + for x, y in zip(result[0], result[-1]): + assert_allclose(x, y, type_test=False) + + +if __name__ == "__main__": + unittest.main()