diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 560dbac346..699ef87745 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -108,6 +108,7 @@ "in_bounds", "is_empty", "is_positive", + "map_and_generate_sampling_centers", "map_binary_to_indices", "map_classes_to_indices", "map_spatial_axes", @@ -368,6 +369,64 @@ def check_non_lazy_pending_ops( warnings.warn(msg) +def map_and_generate_sampling_centers( + label: NdarrayOrTensor, + spatial_size: Sequence[int] | int, + num_samples: int, + label_spatial_shape: Sequence[int], + indices: Sequence[NdarrayOrTensor], + num_classes: int | None = None, + image: NdarrayOrTensor | None = None, + image_threshold: float = 0.0, + max_samples_per_class: int | None = None, + ratios: list[float | int] | None = None, + rand_state: np.random.RandomState | None = None, + allow_smaller: bool = False, + warn: bool = True, +) -> tuple[tuple]: + """ + Combine "map_classes_to_indices" and "generate_label_classes_crop_centers" functions, return crop center coordinates. + + Args: + label: use the label data to get the indices of every class. + spatial_size: spatial size of the ROIs to be sampled. + num_samples: total sample centers to be generated. + label_spatial_shape: spatial shape of the original label data to unravel selected centers. + indices: sequence of pre-computed foreground indices of every class in 1 dimension. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image: if image is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). + image_threshold: if enabled `image`, use ``image > image_threshold`` to + determine the valid image content area and select class indices only in this area. + max_samples_per_class: maximum length of indices in each class to reduce memory consumption. + Default is None, no subsampling. + ratios: ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). + warn: if `True` prints a warning if a class is not present in the label. + + """ + indices_ = indices if indices is None else indices + if indices_ is None: + if label is None: + raise ValueError("label must not be None.") + indices_ = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class) + _shape = None + if label is not None: + _shape = label.peek_pending_shape() if isinstance(label, monai.data.MetaTensor) else label.shape[1:] + elif image is not None: + _shape = image.peek_pending_shape() if isinstance(image, monai.data.MetaTensor) else image.shape[1:] + if _shape is None: + raise ValueError("label or image must be provided to infer the output spatial shape.") + centers = generate_label_classes_crop_centers( + spatial_size, num_samples, _shape, indices_, ratios, rand_state, allow_smaller, warn + ) + return ensure_tuple(centers) + + def map_binary_to_indices( label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0 ) -> tuple[NdarrayOrTensor, NdarrayOrTensor]: