Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6222 adapative resampling mode based on backends #6429

Merged
merged 8 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@
rescale_instance_array,
reset_ops_id,
resize_center,
resolves_modes,
sync_meta_info,
weighted_patch_samples,
zero_margins,
Expand Down
73 changes: 23 additions & 50 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from __future__ import annotations

import functools
import warnings
from collections.abc import Callable
from copy import deepcopy
Expand Down Expand Up @@ -54,16 +53,15 @@
create_shear,
create_translate,
map_spatial_axes,
resolves_modes,
scale_affine,
)
from monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis
from monai.utils import (
GridSampleMode,
GridSamplePadMode,
InterpolateMode,
NdimageMode,
NumpyPadMode,
SplineMode,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
Expand Down Expand Up @@ -695,7 +693,7 @@ def __init__(
) -> None:
self.size_mode = look_up_option(size_mode, ["all", "longest"])
self.spatial_size = spatial_size
self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
self.mode = mode
self.align_corners = align_corners
self.anti_aliasing = anti_aliasing
self.anti_aliasing_sigma = anti_aliasing_sigma
Expand Down Expand Up @@ -759,7 +757,7 @@ def __call__(
scale = self.spatial_size / max(img_size)
sp_size = tuple(int(round(s * scale)) for s in img_size)

_mode = look_up_option(self.mode if mode is None else mode, InterpolateMode)
_mode = self.mode if mode is None else mode
_align_corners = self.align_corners if align_corners is None else align_corners
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
return resize( # type: ignore
Expand Down Expand Up @@ -831,8 +829,8 @@ def __init__(
) -> None:
self.angle = angle
self.keep_size = keep_size
self.mode: str = look_up_option(mode, GridSampleMode)
self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
self.mode: str = mode
self.padding_mode: str = padding_mode
self.align_corners = align_corners
self.dtype = dtype

Expand Down Expand Up @@ -867,8 +865,8 @@ def __call__(
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
_mode = look_up_option(mode or self.mode, GridSampleMode)
_padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode)
_mode = mode or self.mode
_padding_mode = padding_mode or self.padding_mode
_align_corners = self.align_corners if align_corners is None else align_corners
im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
output_shape = im_shape if self.keep_size else None
Expand All @@ -888,10 +886,11 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat))

_, _m, _p, _ = resolves_modes(mode, padding_mode)
xform = AffineTransform(
normalized=False,
mode=mode,
padding_mode=padding_mode,
mode=_m,
padding_mode=_p,
align_corners=False if align_corners == TraceKeys.NONE else align_corners,
reverse_indexing=True,
)
Expand Down Expand Up @@ -953,7 +952,7 @@ def __init__(
**kwargs,
) -> None:
self.zoom = zoom
self.mode: InterpolateMode = InterpolateMode(mode)
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
self.dtype = dtype
Expand Down Expand Up @@ -991,7 +990,7 @@ def __call__(
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
_zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim
_mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value
_mode = self.mode if mode is None else mode
_padding_mode = padding_mode or self.padding_mode
_align_corners = self.align_corners if align_corners is None else align_corners
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
Expand Down Expand Up @@ -1181,8 +1180,8 @@ def __init__(
self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]]))

self.keep_size = keep_size
self.mode: str = look_up_option(mode, GridSampleMode)
self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
self.mode: str = mode
self.padding_mode: str = padding_mode
self.align_corners = align_corners
self.dtype = dtype

Expand Down Expand Up @@ -1231,8 +1230,8 @@ def __call__(
rotator = Rotate(
angle=self.x if ndim == 2 else (self.x, self.y, self.z),
keep_size=self.keep_size,
mode=look_up_option(mode or self.mode, GridSampleMode),
padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode),
mode=mode or self.mode,
padding_mode=padding_mode or self.padding_mode,
align_corners=self.align_corners if align_corners is None else align_corners,
dtype=dtype or self.dtype or img.dtype,
)
Expand Down Expand Up @@ -1406,7 +1405,7 @@ def __init__(
raise ValueError(
f"min_zoom and max_zoom must have same length, got {len(self.min_zoom)} and {len(self.max_zoom)}."
)
self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
self.dtype = dtype
Expand Down Expand Up @@ -1467,7 +1466,7 @@ def __call__(
xform = Zoom(
self._zoom,
keep_size=self.keep_size,
mode=look_up_option(mode or self.mode, InterpolateMode),
mode=mode or self.mode,
padding_mode=padding_mode or self.padding_mode,
align_corners=self.align_corners if align_corners is None else align_corners,
dtype=dtype or self.dtype,
Expand Down Expand Up @@ -1815,35 +1814,6 @@ def __init__(
self.align_corners = align_corners
self.dtype = dtype

@staticmethod
@functools.lru_cache(None)
def resolve_modes(interp_mode, padding_mode):
"""compute the backend and the corresponding mode for the given interpolation mode and padding mode."""
_interp_mode = None
_padding_mode = None
if look_up_option(str(interp_mode), SplineMode, default=None) is not None:
backend = TransformBackends.NUMPY
else:
backend = TransformBackends.TORCH

if (not USE_COMPILED) and (backend == TransformBackends.TORCH):
if str(interp_mode).lower().endswith("linear"):
_interp_mode = GridSampleMode("bilinear")
_interp_mode = GridSampleMode(interp_mode)
_padding_mode = GridSamplePadMode(padding_mode)
elif USE_COMPILED and backend == TransformBackends.TORCH: # compiled is using torch backend param name
_padding_mode = 1 if padding_mode == "reflection" else padding_mode # type: ignore
if interp_mode == "bicubic":
_interp_mode = 3 # type: ignore
elif interp_mode == "bilinear":
_interp_mode = 1 # type: ignore
else:
_interp_mode = GridSampleMode(interp_mode)
else: # TransformBackends.NUMPY
_interp_mode = int(interp_mode) # type: ignore
_padding_mode = look_up_option(padding_mode, NdimageMode)
return backend, _interp_mode, _padding_mode

def __call__(
self,
img: torch.Tensor,
Expand Down Expand Up @@ -1894,8 +1864,11 @@ def __call__(
_align_corners = self.align_corners if align_corners is None else align_corners
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)
sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3)
backend, _interp_mode, _padding_mode = Resample.resolve_modes(
self.mode if mode is None else mode, self.padding_mode if padding_mode is None else padding_mode
backend, _interp_mode, _padding_mode, _ = resolves_modes(
self.mode if mode is None else mode,
self.padding_mode if padding_mode is None else padding_mode,
backend=None,
use_compiled=USE_COMPILED,
)

if USE_COMPILED or backend == TransformBackends.NUMPY:
Expand Down
14 changes: 9 additions & 5 deletions monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from monai.transforms.croppad.array import ResizeWithPadOrCrop
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, scale_affine
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import (
LazyAttr,
Expand Down Expand Up @@ -172,8 +172,9 @@ def spatial_resample(
with affine_xform.trace_transform(False):
img = affine_xform(img, mode=mode, padding_mode=padding_mode)
else:
_, _m, _p, _ = resolves_modes(mode, padding_mode)
affine_xform = AffineTransform( # type: ignore
normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True
normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
)
img = affine_xform(img.unsqueeze(0), theta=xform.to(img), spatial_size=spatial_size).squeeze(0) # type: ignore
if additional_dims:
Expand Down Expand Up @@ -331,8 +332,9 @@ def resize(img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing,
anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1)
anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma)
img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False)
_, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_.shape) - 1)
resized = torch.nn.functional.interpolate(
input=img_.unsqueeze(0), size=out_size, mode=mode, align_corners=align_corners
input=img_.unsqueeze(0), size=out_size, mode=_m, align_corners=align_corners
)
out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32)
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
Expand Down Expand Up @@ -396,8 +398,9 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, t
out = _maybe_new_metatensor(img)
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
_, _m, _p, _ = resolves_modes(mode, padding_mode)
xform = AffineTransform(
normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True
normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True
)
img_t = out.to(dtype)
transform_t, *_ = convert_to_dst_type(transform, img_t)
Expand Down Expand Up @@ -468,11 +471,12 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype,
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
img_t = out.to(dtype)
_, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1)
zoomed: NdarrayOrTensor = torch.nn.functional.interpolate(
recompute_scale_factor=True,
input=img_t.unsqueeze(0),
scale_factor=list(scale_factor),
mode=mode,
mode=_m,
align_corners=align_corners,
).squeeze(0)
out, *_ = convert_to_dst_type(zoomed, dst=out, dtype=torch.float32)
Expand Down
Loading