From 0ee18a727957f938b07d045e946df03e93f66ff3 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Thu, 25 Apr 2024 17:40:28 +0800 Subject: [PATCH] Fixes #6704 Signed-off-by: ytl0623 --- monai/transforms/utils.py | 59 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 560dbac3466..8a45d8fdebd 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -108,6 +108,7 @@ "in_bounds", "is_empty", "is_positive", + "map_and_generate_sampling_centers", "map_binary_to_indices", "map_classes_to_indices", "map_spatial_axes", @@ -368,6 +369,64 @@ def check_non_lazy_pending_ops( warnings.warn(msg) +def map_and_generate_sampling_centers( + label: NdarrayOrTensor, + spatial_size: Sequence[int] | int, + num_samples: int, + label_spatial_shape: Sequence[int], + indices: Sequence[NdarrayOrTensor], + num_classes: int | None = None, + image: NdarrayOrTensor | None = None, + image_threshold: float = 0.0, + max_samples_per_class: int | None = None, + ratios: list[float | int] | None = None, + rand_state: np.random.RandomState | None = None, + allow_smaller: bool = False, + warn: bool = True, +) -> list[NdarrayOrTensor]: + """ + Combine "map_classes_to_indices" and "generate_label_classes_crop_centers" functions, return crop center coordinates. + + Args: + label: use the label data to get the indices of every class. + spatial_size: spatial size of the ROIs to be sampled. + num_samples: total sample centers to be generated. + label_spatial_shape: spatial shape of the original label data to unravel selected centers. + indices: sequence of pre-computed foreground indices of every class in 1 dimension. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image: if image is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). + image_threshold: if enabled `image`, use ``image > image_threshold`` to + determine the valid image content area and select class indices only in this area. + max_samples_per_class: maximum length of indices in each class to reduce memory consumption. + Default is None, no subsampling. + ratios: ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). + warn: if `True` prints a warning if a class is not present in the label. + + """ + indices_ = indices if indices is None else indices + if indices_ is None: + if label is None: + raise ValueError("label must not be None.") + indices_ = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class) + _shape = None + if label is not None: + _shape = label.peek_pending_shape() if isinstance(label, monai.data.MetaTensor) else label.shape[1:] + elif image is not None: + _shape = image.peek_pending_shape() if isinstance(image, monai.data.MetaTensor) else image.shape[1:] + if _shape is None: + raise ValueError("label or image must be provided to infer the output spatial shape.") + centers = generate_label_classes_crop_centers( + spatial_size, num_samples, _shape, indices_, ratios, rand_state, allow_smaller, warn + ) + return list(centers) + + def map_binary_to_indices( label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0 ) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: