6222 adaptive resampling mode based on backends (#6429)
Fixes #6222

### Description
Non-breaking change extending the currently supported mode + padding mode combinations.
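
In rough terms, instead of validating `mode`/`padding_mode` against a single enum up front, the transforms now pass the pair to a shared `resolves_modes` helper, which picks a resampling backend (torch `grid_sample` vs. scipy/ndimage spline orders) and translates the values for that backend. Below is a minimal, self-contained sketch of that resolution logic, adapted from the `Resample.resolve_modes` method removed further down in this diff; it is illustrative only, the actual helper covers more cases and returns additional values.

```python
def resolve_modes_sketch(interp_mode, padding_mode, use_compiled=False):
    """Pick a backend for a mode/padding-mode pair and translate the values.

    Spline orders (0-5, scipy/ndimage convention) go to the numpy backend,
    everything else to torch grid_sample; the compiled path uses integer codes.
    """
    if str(interp_mode) in {"0", "1", "2", "3", "4", "5"}:
        # e.g. mode=3 -> cubic spline interpolation via scipy.ndimage
        return "numpy", int(interp_mode), padding_mode
    if not use_compiled:
        # plain torch.nn.functional.grid_sample names
        if str(interp_mode).lower().endswith("linear"):
            interp_mode = "bilinear"
        return "torch", interp_mode, padding_mode
    # compiled extension path: integer mode / padding codes
    interp_mode = {"bilinear": 1, "bicubic": 3}.get(interp_mode, interp_mode)
    padding_mode = 1 if padding_mode == "reflection" else padding_mode
    return "torch", interp_mode, padding_mode


print(resolve_modes_sketch(3, "reflect"))         # ('numpy', 3, 'reflect')
print(resolve_modes_sketch("bilinear", "zeros"))  # ('torch', 'bilinear', 'zeros')
```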

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
wyli authored Apr 26, 2023
1 parent 7ea082e commit 5991b83
Showing 8 changed files with 170 additions and 61 deletions.
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
@@ -660,6 +660,7 @@
rescale_instance_array,
reset_ops_id,
resize_center,
resolves_modes,
sync_meta_info,
weighted_patch_samples,
zero_margins,
73 changes: 23 additions & 50 deletions monai/transforms/spatial/array.py
@@ -15,7 +15,6 @@

from __future__ import annotations

import functools
import warnings
from collections.abc import Callable
from copy import deepcopy
@@ -54,16 +53,15 @@
create_shear,
create_translate,
map_spatial_axes,
resolves_modes,
scale_affine,
)
from monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis
from monai.utils import (
GridSampleMode,
GridSamplePadMode,
InterpolateMode,
NdimageMode,
NumpyPadMode,
SplineMode,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
@@ -695,7 +693,7 @@ def __init__(
) -> None:
self.size_mode = look_up_option(size_mode, ["all", "longest"])
self.spatial_size = spatial_size
self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
self.mode = mode
self.align_corners = align_corners
self.anti_aliasing = anti_aliasing
self.anti_aliasing_sigma = anti_aliasing_sigma
@@ -759,7 +757,7 @@ def __call__(
scale = self.spatial_size / max(img_size)
sp_size = tuple(int(round(s * scale)) for s in img_size)

_mode = look_up_option(self.mode if mode is None else mode, InterpolateMode)
_mode = self.mode if mode is None else mode
_align_corners = self.align_corners if align_corners is None else align_corners
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
return resize( # type: ignore
@@ -831,8 +829,8 @@ def __init__(
) -> None:
self.angle = angle
self.keep_size = keep_size
self.mode: str = look_up_option(mode, GridSampleMode)
self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
self.mode: str = mode
self.padding_mode: str = padding_mode
self.align_corners = align_corners
self.dtype = dtype

@@ -867,8 +865,8 @@ def __call__(
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
_mode = look_up_option(mode or self.mode, GridSampleMode)
_padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode)
_mode = mode or self.mode
_padding_mode = padding_mode or self.padding_mode
_align_corners = self.align_corners if align_corners is None else align_corners
im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
output_shape = im_shape if self.keep_size else None
@@ -888,10 +886,11 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat))

_, _m, _p, _ = resolves_modes(mode, padding_mode)
xform = AffineTransform(
normalized=False,
mode=mode,
padding_mode=padding_mode,
mode=_m,
padding_mode=_p,
align_corners=False if align_corners == TraceKeys.NONE else align_corners,
reverse_indexing=True,
)
@@ -953,7 +952,7 @@ def __init__(
**kwargs,
) -> None:
self.zoom = zoom
self.mode: InterpolateMode = InterpolateMode(mode)
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
self.dtype = dtype
@@ -991,7 +990,7 @@ def __call__(
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
_zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim
_mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value
_mode = self.mode if mode is None else mode
_padding_mode = padding_mode or self.padding_mode
_align_corners = self.align_corners if align_corners is None else align_corners
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
@@ -1181,8 +1180,8 @@ def __init__(
self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]]))

self.keep_size = keep_size
self.mode: str = look_up_option(mode, GridSampleMode)
self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
self.mode: str = mode
self.padding_mode: str = padding_mode
self.align_corners = align_corners
self.dtype = dtype

@@ -1231,8 +1230,8 @@ def __call__(
rotator = Rotate(
angle=self.x if ndim == 2 else (self.x, self.y, self.z),
keep_size=self.keep_size,
mode=look_up_option(mode or self.mode, GridSampleMode),
padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode),
mode=mode or self.mode,
padding_mode=padding_mode or self.padding_mode,
align_corners=self.align_corners if align_corners is None else align_corners,
dtype=dtype or self.dtype or img.dtype,
)
@@ -1406,7 +1405,7 @@ def __init__(
raise ValueError(
f"min_zoom and max_zoom must have same length, got {len(self.min_zoom)} and {len(self.max_zoom)}."
)
self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
self.dtype = dtype
@@ -1467,7 +1466,7 @@ def __call__(
xform = Zoom(
self._zoom,
keep_size=self.keep_size,
mode=look_up_option(mode or self.mode, InterpolateMode),
mode=mode or self.mode,
padding_mode=padding_mode or self.padding_mode,
align_corners=self.align_corners if align_corners is None else align_corners,
dtype=dtype or self.dtype,
@@ -1815,35 +1814,6 @@ def __init__(
self.align_corners = align_corners
self.dtype = dtype

@staticmethod
@functools.lru_cache(None)
def resolve_modes(interp_mode, padding_mode):
"""compute the backend and the corresponding mode for the given interpolation mode and padding mode."""
_interp_mode = None
_padding_mode = None
if look_up_option(str(interp_mode), SplineMode, default=None) is not None:
backend = TransformBackends.NUMPY
else:
backend = TransformBackends.TORCH

if (not USE_COMPILED) and (backend == TransformBackends.TORCH):
if str(interp_mode).lower().endswith("linear"):
_interp_mode = GridSampleMode("bilinear")
_interp_mode = GridSampleMode(interp_mode)
_padding_mode = GridSamplePadMode(padding_mode)
elif USE_COMPILED and backend == TransformBackends.TORCH: # compiled is using torch backend param name
_padding_mode = 1 if padding_mode == "reflection" else padding_mode # type: ignore
if interp_mode == "bicubic":
_interp_mode = 3 # type: ignore
elif interp_mode == "bilinear":
_interp_mode = 1 # type: ignore
else:
_interp_mode = GridSampleMode(interp_mode)
else: # TransformBackends.NUMPY
_interp_mode = int(interp_mode) # type: ignore
_padding_mode = look_up_option(padding_mode, NdimageMode)
return backend, _interp_mode, _padding_mode

def __call__(
self,
img: torch.Tensor,
@@ -1894,8 +1864,11 @@ def __call__(
_align_corners = self.align_corners if align_corners is None else align_corners
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)
sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3)
backend, _interp_mode, _padding_mode = Resample.resolve_modes(
self.mode if mode is None else mode, self.padding_mode if padding_mode is None else padding_mode
backend, _interp_mode, _padding_mode, _ = resolves_modes(
self.mode if mode is None else mode,
self.padding_mode if padding_mode is None else padding_mode,
backend=None,
use_compiled=USE_COMPILED,
)

if USE_COMPILED or backend == TransformBackends.NUMPY:
14 changes: 9 additions & 5 deletions monai/transforms/spatial/functional.py
@@ -32,7 +32,7 @@
from monai.transforms.croppad.array import ResizeWithPadOrCrop
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, scale_affine
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import (
LazyAttr,
@@ -172,8 +172,9 @@ def spatial_resample(
with affine_xform.trace_transform(False):
img = affine_xform(img, mode=mode, padding_mode=padding_mode)
else:
_, _m, _p, _ = resolves_modes(mode, padding_mode)
affine_xform = AffineTransform( # type: ignore
normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True
normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
)
img = affine_xform(img.unsqueeze(0), theta=xform.to(img), spatial_size=spatial_size).squeeze(0) # type: ignore
if additional_dims:
@@ -331,8 +332,9 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing,
anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1)
anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma)
img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False)
_, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_.shape) - 1)
resized = torch.nn.functional.interpolate(
input=img_.unsqueeze(0), size=out_size, mode=mode, align_corners=align_corners
input=img_.unsqueeze(0), size=out_size, mode=_m, align_corners=align_corners
)
out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32)
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
@@ -396,8 +398,9 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t
out = _maybe_new_metatensor(img)
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
_, _m, _p, _ = resolves_modes(mode, padding_mode)
xform = AffineTransform(
normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True
normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
)
img_t = out.to(dtype)
transform_t, *_ = convert_to_dst_type(transform, img_t)
@@ -468,11 +471,12 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype,
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
img_t = out.to(dtype)
_, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1)
zoomed: NdarrayOrTensor = torch.nn.functional.interpolate(
recompute_scale_factor=True,
input=img_t.unsqueeze(0),
scale_factor=list(scale_factor),
mode=mode,
mode=_m,
align_corners=align_corners,
).squeeze(0)
out, *_ = convert_to_dst_type(zoomed, dst=out, dtype=torch.float32)
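
For reference, a minimal usage sketch of the new helper, based on the call sites visible in the hunks above. It assumes `resolves_modes` is importable from `monai.transforms`, as the `__init__.py` change suggests; the fourth return value is ignored at every call site shown in this diff, so it is ignored here too.

```python
from monai.transforms import resolves_modes

# A torch grid-sample style pair should stay on the torch backend.
backend, interp_mode, pad_mode, _ = resolves_modes("bilinear", "border")
print(backend, interp_mode, pad_mode)

# A spline order (scipy/ndimage convention) should resolve to the numpy backend.
backend, interp_mode, pad_mode, _ = resolves_modes(3, "reflect")
print(backend, interp_mode, pad_mode)
```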