Add SomeOf transform composer #6143

Merged: 9 commits, Mar 24, 2023
5 changes: 5 additions & 0 deletions docs/source/transforms.rst
@@ -88,6 +88,11 @@ Generic Interfaces
.. autoclass:: RandomOrder
:members:

`SomeOf`
^^^^^^^^^^^^^
.. autoclass:: SomeOf
:members:

Functionals
-----------

2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
@@ -12,7 +12,7 @@
from __future__ import annotations

from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs
from .compose import Compose, OneOf, RandomOrder
from .compose import Compose, OneOf, RandomOrder, SomeOf
from .croppad.array import (
BorderPad,
BoundingRect,
146 changes: 143 additions & 3 deletions monai/transforms/compose.py
@@ -34,12 +34,11 @@
Transform,
apply_transform,
)
from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed
from monai.utils.misc import to_tuple_of_dictionaries
from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed, to_tuple_of_dictionaries

logger = get_logger(__name__)

__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides"]
__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides", "SomeOf"]


def evaluate_with_overrides(
@@ -521,3 +520,144 @@ def inverse(self, data):
self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
)
return data


class SomeOf(Compose):
"""
``SomeOf`` samples a different sequence of transforms to apply each time it is called.

It can be configured to sample a fixed or varying number of transforms each time it is called. Samples are drawn
uniformly, or from user-supplied transform weights. When the number of transforms varies per call, it is sampled
uniformly from a user-supplied range.

Args:
transforms: list of callables.
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
Defaults to `True`.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
Defaults to `False`.
log_stats: whether to log detailed information about the data and the applied transform when an error occurs:
for NumPy arrays and PyTorch tensors, log the data shape and value range;
for other metadata, log the values directly. Defaults to `False`.
num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of
transforms to sample at each iteration. If an int is given, the lower and upper bounds are set equal.
None sets it to `len(transforms)`. Defaults to `None`.
replace: whether to sample with replacement. Defaults to `False`.
weights: weights to use for sampling transforms; will be normalized to sum to 1. Defaults to `None` (uniform).
"""

def __init__(
self,
transforms: Sequence[Callable] | Callable | None = None,
map_items: bool = True,
unpack_items: bool = False,
log_stats: bool = False,
*,
num_transforms: int | tuple[int, int] | None = None,
replace: bool = False,
weights: list[int] | None = None,
) -> None:
super().__init__(transforms, map_items, unpack_items, log_stats)
self.min_num_transforms, self.max_num_transforms = self._ensure_valid_num_transforms(num_transforms)
self.replace = replace
self.weights = self._normalize_probabilities(weights)

def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int] | None) -> tuple:
if (
not isinstance(num_transforms, tuple)
and not isinstance(num_transforms, list)
and not isinstance(num_transforms, int)
and num_transforms is not None
):
raise ValueError(
f"Expected num_transforms to be of type int, list, tuple or None, but it's {type(num_transforms)}"
)

if num_transforms is None:
result = [len(self.transforms), len(self.transforms)]
elif isinstance(num_transforms, int):
n = min(len(self.transforms), num_transforms)
result = [n, n]
else:
if len(num_transforms) != 2:
raise ValueError(f"Expected len(num_transforms)=2, but it was {len(num_transforms)}")
if not isinstance(num_transforms[0], int) or not isinstance(num_transforms[1], int):
raise ValueError(
f"Expected (int,int), but received ({type(num_transforms[0])}, {type(num_transforms[1])})"
)

result = [num_transforms[0], num_transforms[1]]

if result[0] < 0 or result[1] > len(self.transforms):
raise ValueError(f"num_transforms={num_transforms} are out of the bounds [0, {len(self.transforms)}].")

return ensure_tuple(result)

# Modified from OneOf
def _normalize_probabilities(self, weights):
if weights is None or len(self.transforms) == 0:
return None

weights = np.array(weights)

n_weights = len(weights)
if n_weights != len(self.transforms):
raise ValueError(f"Expected len(weights)={len(self.transforms)}, got: {n_weights}.")

if np.any(weights < 0):
raise ValueError(f"Probabilities must be greater than or equal to zero, got {weights}.")

if np.all(weights == 0):
raise ValueError(f"At least one probability must be greater than zero, got {weights}.")

weights = weights / weights.sum()

return ensure_tuple(list(weights))

def __call__(self, data):
if len(self.transforms) == 0:
return data

sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)
applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist()
for i in applied_order:
data = apply_transform(self.transforms[i], data, self.map_items, self.unpack_items, self.log_stats)

if isinstance(data, monai.data.MetaTensor):
self.push_transform(data, extra_info={"applied_order": applied_order})
elif isinstance(data, Mapping):
for key in data: # the dictionary must not change size during iteration
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
self.push_transform(data, key, extra_info={"applied_order": applied_order})

return data

# From RandomOrder
def inverse(self, data):
if len(self.transforms) == 0:
return data

applied_order = None
if isinstance(data, monai.data.MetaTensor):
applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["applied_order"]
elif isinstance(data, Mapping):
for key in data:
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"]
else:
raise RuntimeError(
f"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}."
)
if applied_order is None:
# no invertible transforms have been applied
return data

# loop backwards over transforms
for o in reversed(applied_order):
transform = self.transforms[o]
if isinstance(transform, InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
)

return data
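
For reference, here is a minimal usage sketch of the new composer. It is illustrative only and not part of this PR's diff; the choice of `RandRotate90`, `RandFlip` and `RandGaussianNoise`, the parameter values, and the seed are placeholder assumptions.

```python
# Illustrative sketch only (not part of this diff): how SomeOf might be used once merged.
import torch

from monai.transforms import RandFlip, RandGaussianNoise, RandRotate90, SomeOf

# Apply 1 or 2 of the three transforms per call, sampling RandRotate90
# twice as often as the others (weights are normalized internally).
transform = SomeOf(
    [RandRotate90(prob=1.0), RandFlip(prob=1.0, spatial_axis=0), RandGaussianNoise(prob=1.0)],
    num_transforms=(1, 2),
    weights=[2, 1, 1],
)
transform.set_random_state(seed=0)  # reproducible sampling

img = torch.rand(1, 32, 32)  # (channel, H, W)
out = transform(img)  # a different subset of transforms is applied on each call

# The applied order is recorded in the output's metadata, so any invertible
# transforms that were applied (RandRotate90, RandFlip) can be undone in
# reverse order; RandGaussianNoise has no inverse and is skipped.
restored = transform.inverse(out)
```

Passing `num_transforms=(1, 2)` makes the composer draw the per-call sample size uniformly from that range, matching the behaviour described in the class docstring; an int would fix the sample size, and `None` would use all transforms.
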
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -88,6 +88,7 @@
star_zip_with,
str2bool,
str2list,
to_tuple_of_dictionaries,
zip_with,
)
from .module import (
1 change: 1 addition & 0 deletions monai/utils/misc.py
@@ -41,6 +41,7 @@
"ensure_tuple",
"ensure_tuple_size",
"ensure_tuple_rep",
"to_tuple_of_dictionaries",
"fall_back_tuple",
"is_scalar_tensor",
"is_scalar",