Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support GPU tensor for GridPatch and GridPatchDataset #6246

Merged
merged 10 commits into from
Mar 29, 2023
43 changes: 28 additions & 15 deletions monai/data/grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@

from __future__ import annotations

from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
from collections.abc import Generator, Callable, Hashable, Iterable, Mapping, Sequence
from copy import deepcopy

import numpy as np

from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayTensor
from monai.data.dataset import Dataset
from monai.data.iterable_dataset import IterableDataset
from monai.data.utils import iter_patch
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple, first, look_up_option
from monai.utils import NumpyPadMode, ensure_tuple, first

__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]

Expand All @@ -34,17 +35,25 @@ class PatchIter:
"""

def __init__(
self, patch_size: Sequence[int], start_pos: Sequence[int] = (), mode: str = NumpyPadMode.WRAP, **pad_opts: dict
self,
patch_size: Sequence[int],
start_pos: Sequence[int] = (),
mode: str | None = NumpyPadMode.WRAP,
**pad_opts: dict,
):
"""

Args:
patch_size: size of patches to generate slices for, 0/None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function.
If None, no wrapping is performed. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
requires pytorch >= 1.10 for best compatibility.
pad_opts: other arguments for the `np.pad` function.
note that `np.pad` treats channel dimension as the first dimension.

Expand All @@ -58,10 +67,10 @@ def __init__(
"""
self.patch_size = (None,) + tuple(patch_size) # expand to have the channel dim
self.start_pos = ensure_tuple(start_pos)
self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
self.mode = mode
self.pad_opts = pad_opts

def __call__(self, array: np.ndarray):
def __call__(self, array: NdarrayTensor) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]:
"""
Args:
array: the image to generate patches from.
Expand Down Expand Up @@ -89,10 +98,14 @@ class PatchIterd:
keys: keys of the corresponding items to iterate patches.
patch_size: size of patches to generate slices for, 0/None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function.
If None, no wrapping is performed. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
requires pytorch >= 1.10 for best compatibility.
pad_opts: other arguments for the `np.pad` function.
note that `np.pad` treats channel dimension as the first dimension.

Expand All @@ -107,13 +120,13 @@ def __init__(
keys: KeysCollection,
patch_size: Sequence[int],
start_pos: Sequence[int] = (),
mode: str = NumpyPadMode.WRAP,
mode: str | None = NumpyPadMode.WRAP,
**pad_opts,
):
self.keys = ensure_tuple(keys)
self.patch_iter = PatchIter(patch_size=patch_size, start_pos=start_pos, mode=mode, **pad_opts)

def __call__(self, data: Mapping[Hashable, np.ndarray]):
def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Generator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]:
d = dict(data)
original_spatial_shape = d[first(self.keys)].shape[1:]

Expand Down
22 changes: 16 additions & 6 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,14 @@ def iter_patch_position(


def iter_patch(
arr: np.ndarray,
arr: NdarrayTensor,
patch_size: Sequence[int] | int = 0,
start_pos: Sequence[int] = (),
overlap: Sequence[float] | float = 0.0,
copy_back: bool = True,
mode: str | None = NumpyPadMode.WRAP,
**pad_opts: dict,
):
) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]:
"""
Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr`
but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative
Expand All @@ -268,9 +268,16 @@ def iter_patch(
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes
mode: One of the listed string values in ``monai.utils.NumpyPadMode`` or ``monai.utils.PytorchPadMode``,
or a user supplied function. If None, no wrapping is performed. Defaults to ``"wrap"``.
pad_opts: padding options, see `numpy.pad`
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function.
If None, no wrapping is performed. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
requires pytorch >= 1.10 for best compatibility.
pad_opts: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.

Yields:
Patches of array data from `arr` which are views into a padded array which can be modified, if `copy_back` is
Expand All @@ -285,6 +292,9 @@ def iter_patch(
Nth_dim_start, Nth_dim_end]]

"""

from monai.transforms.croppad.functional import pad_nd # needs to be here to avoid circular import

# ensure patchSize and startPos are the right length
patch_size_ = get_valid_patch_size(arr.shape, patch_size)
start_pos = ensure_tuple_size(start_pos, arr.ndim)
Expand All @@ -296,7 +306,7 @@ def iter_patch(
_overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)] # overlap if v else 0.0
# pad image by maximum values needed to ensure patches are taken from inside an image
if padded:
arrpad = np.pad(arr, tuple((p, p) for p in _pad_size), look_up_option(mode, NumpyPadMode).value, **pad_opts)
arrpad = pad_nd(arr, to_pad=tuple((p, p) for p in _pad_size), mode=mode, **pad_opts)
# choose a start position in the padded image
start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size))

Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __call__( # type: ignore[override]
kwargs_.update(kwargs)

img_t = convert_to_tensor(data=img, track_meta=get_track_meta())
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_)
return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, **kwargs_)

def inverse(self, data: MetaTensor) -> MetaTensor:
transform = self.pop_transform(data)
Expand Down
46 changes: 25 additions & 21 deletions monai/transforms/croppad/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from torch.nn.functional import pad as pad_pt

from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
Expand Down Expand Up @@ -49,7 +50,7 @@ def _convert_pt_pad_mode(padding_mode):
return PytorchPadMode.REPLICATE # "nearest", "border", and others


def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor:
if isinstance(img, torch.Tensor):
if img.is_cuda:
warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.")
Expand All @@ -59,28 +60,30 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw
mode = convert_pad_mode(dst=img_np, mode=mode).value
if mode == "constant" and "value" in kwargs:
kwargs["constant_values"] = kwargs.pop("value")
out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) # type: ignore
if isinstance(img, MetaTensor):
out = convert_to_dst_type(out, dst=img)[0]
return out
img_np = np.pad(img_np, pad_width, mode=mode, **kwargs)
return convert_to_dst_type(img_np, dst=img)[0]


def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
mode = convert_pad_mode(dst=img, mode=mode).value
def _pt_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor:
img_pt = torch.as_tensor(img)
mode = convert_pad_mode(dst=img_pt, mode=mode).value
if mode == "constant" and "constant_values" in kwargs:
_kwargs = kwargs.copy()
_kwargs["value"] = _kwargs.pop("constant_values")
else:
_kwargs = kwargs
pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1]
# torch.pad expects `[B, C, H, W, [D]]` shape
return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0)
img_pt = pad_pt(img_pt.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0)
return convert_to_dst_type(img_pt, dst=img)[0]


def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs):
def pad_nd(img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str=PytorchPadMode.CONSTANT, **kwargs) -> NdarrayTensor:
"""
PyTorch/Numpy pad ``img`` with integers ``to_pad`` amounts. Depending on the ``mode`` and input dtype,
a suitable backend will be used automatically.
Pad `img` for a given an amount of padding in each dimension.

`torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
in which case `np.pad` will be used.

Args:
img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
Expand All @@ -90,20 +93,18 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
"""
if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
mode = convert_pad_mode(dst=img, mode=mode).value
try:
_pad = (
_np_pad
if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8}
else _pt_pad
)
_pad = _np_pad
if (mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and
img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}):
_pad = _pt_pad
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
except (ValueError, TypeError, RuntimeError) as err:
if isinstance(err, NotImplementedError) or any(
Expand Down Expand Up @@ -148,23 +149,26 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,


def pad_func(
img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs
img: torch.Tensor, to_pad: tuple[tuple[int, int]], transform_info: dict, mode: str=PytorchPadMode.CONSTANT, **kwargs
wyli marked this conversation as resolved.
Show resolved Hide resolved
) -> torch.Tensor:
"""
Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).

`torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
in which case `np.pad` will be used.

Args:
img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
note that it including channel dimension.
transform_info: a dictionary with the relevant information pertaining to an applied transform.
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
transform_info: a dictionary with the relevant information pertaining to an applied transform.
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
"""
Expand Down
Loading