Skip to content

Commit

Permalink
Fixes #6704
Browse files Browse the repository at this point in the history
Signed-off-by: ytl0623 <[email protected]>
  • Loading branch information
ytl0623 committed Apr 25, 2024
1 parent 8c709de commit a78c5a6
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
"in_bounds",
"is_empty",
"is_positive",
"map_and_generate_sampling_centers",
"map_binary_to_indices",
"map_classes_to_indices",
"map_spatial_axes",
Expand Down Expand Up @@ -368,6 +369,64 @@ def check_non_lazy_pending_ops(
warnings.warn(msg)


def map_and_generate_sampling_centers(
label: NdarrayOrTensor,
spatial_size: Sequence[int] | int,
num_samples: int,
label_spatial_shape: Sequence[int],
indices: Sequence[NdarrayOrTensor],
num_classes: int | None = None,
image: NdarrayOrTensor | None = None,
image_threshold: float = 0.0,
max_samples_per_class: int | None = None,
ratios: list[float | int] | None = None,
rand_state: np.random.RandomState | None = None,
allow_smaller: bool = False,
warn: bool = True,
) -> tuple[tuple]:
"""
Combine "map_classes_to_indices" and "generate_label_classes_crop_centers" functions, return crop center coordinates.
Args:
label: use the label data to get the indices of every class.
spatial_size: spatial size of the ROIs to be sampled.
num_samples: total sample centers to be generated.
label_spatial_shape: spatial shape of the original label data to unravel selected centers.
indices: sequence of pre-computed foreground indices of every class in 1 dimension.
num_classes: number of classes for argmax label, not necessary for One-Hot label.
image: if image is not None, only return the indices of every class that are within the valid
region of the image (``image > image_threshold``).
image_threshold: if enabled `image`, use ``image > image_threshold`` to
determine the valid image content area and select class indices only in this area.
max_samples_per_class: maximum length of indices in each class to reduce memory consumption.
Default is None, no subsampling.
ratios: ratios of every class in the label to generate crop centers, including background class.
if None, every class will have the same ratio to generate crop centers.
rand_state: numpy randomState object to align with other modules.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
match the cropped size (i.e., no cropping in that dimension).
warn: if `True` prints a warning if a class is not present in the label.
"""
indices_ = indices if indices is None else indices
if indices_ is None:
if label is None:
raise ValueError("label must not be None.")
indices_ = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class)
_shape = None
if label is not None:
_shape = label.peek_pending_shape() if isinstance(label, monai.data.MetaTensor) else label.shape[1:]
elif image is not None:
_shape = image.peek_pending_shape() if isinstance(image, monai.data.MetaTensor) else image.shape[1:]
if _shape is None:
raise ValueError("label or image must be provided to infer the output spatial shape.")
centers = generate_label_classes_crop_centers(
spatial_size, num_samples, _shape, indices_, ratios, rand_state, allow_smaller, warn
)
return ensure_tuple(centers)


def map_binary_to_indices(
label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0
) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:
Expand Down

0 comments on commit a78c5a6

Please sign in to comment.