Skip to content

Commit

Permalink
Implement distance_transform_edt and the DistanceTransformEDT transfo…
Browse files Browse the repository at this point in the history
…rm (#6981)

Related to #6845, this
commits adds an EDT distance transform to MONAI.
Most importantly this enables GPU based distance transforms which lead
to a huge speedup.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] In-line docstrings updated.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Matthias Hadlich <[email protected]>
  • Loading branch information
matt3o authored Sep 27, 2023
1 parent a29ab04 commit 84566d1
Show file tree
Hide file tree
Showing 7 changed files with 453 additions and 2 deletions.
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,12 @@ Post-processing
:members:
:special-members: __call__

`DistanceTransformEDT`
"""""""""""""""""""""""""""""""
.. autoclass:: DistanceTransformEDT
:members:
:special-members: __call__

`RemoveSmallObjects`
""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjects.png
Expand Down Expand Up @@ -1622,6 +1628,12 @@ Post-processing (Dict)
:members:
:special-members: __call__

`DistanceTransformEDTd`
""""""""""""""""""""""""""""""""
.. autoclass:: DistanceTransformEDTd
:members:
:special-members: __call__

`RemoveSmallObjectsd`
"""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjectsd.png
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@
from .post.array import (
Activations,
AsDiscrete,
DistanceTransformEDT,
FillHoles,
Invert,
KeepLargestConnectedComponent,
Expand All @@ -295,6 +296,9 @@
AsDiscreteD,
AsDiscreted,
AsDiscreteDict,
DistanceTransformEDTd,
DistanceTransformEDTD,
DistanceTransformEDTDict,
Ensembled,
EnsembleD,
EnsembleDict,
Expand Down
38 changes: 38 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import (
convert_applied_interp_mode,
distance_transform_edt,
fill_holes,
get_largest_connected_component_mask,
get_unique_labels,
Expand All @@ -53,6 +54,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"DistanceTransformEDT",
]


Expand Down Expand Up @@ -936,3 +938,39 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]

return grads


class DistanceTransformEDT(Transform):
"""
Applies the Euclidean distance transform on the input.
Either GPU based with CuPy / cuCIM or CPU based with scipy.
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.
Note that the results of the libraries can differ, so stick to one if possible.
For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.
.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
"""

backend = [TransformBackends.NUMPY, TransformBackends.CUPY]

def __init__(self, sampling: None | float | list[float] = None) -> None:
super().__init__()
self.sampling = sampling

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: Input image on which the distance transform shall be run.
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
Returns:
An array with the same shape and data type as img
"""
return distance_transform_edt(img=img, sampling=self.sampling) # type: ignore
50 changes: 50 additions & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from monai.transforms.post.array import (
Activations,
AsDiscrete,
DistanceTransformEDT,
FillHoles,
KeepLargestConnectedComponent,
LabelFilter,
Expand Down Expand Up @@ -91,6 +92,9 @@
"VoteEnsembleD",
"VoteEnsembleDict",
"VoteEnsembled",
"DistanceTransformEDTd",
"DistanceTransformEDTD",
"DistanceTransformEDTDict",
]

DEFAULT_POST_FIX = PostFix.meta()
Expand Down Expand Up @@ -855,6 +859,51 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class DistanceTransformEDTd(MapTransform):
"""
Applies the Euclidean distance transform on the input.
Either GPU based with CuPy / cuCIM or CPU based with scipy.
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.
Note that the results of the libraries can differ, so stick to one if possible.
For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.
Note on the input shape:
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
Args:
keys: keys of the corresponding items to be transformed.
allow_missing_keys: don't raise exception if key is missing.
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
"""

backend = DistanceTransformEDT.backend

def __init__(
self, keys: KeysCollection, allow_missing_keys: bool = False, sampling: None | float | list[float] = None
) -> None:
super().__init__(keys, allow_missing_keys)
self.sampling = sampling
self.distance_transform = DistanceTransformEDT(sampling=self.sampling)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.distance_transform(img=d[key])

return d


ActivationsD = ActivationsDict = Activationsd
AsDiscreteD = AsDiscreteDict = AsDiscreted
FillHolesD = FillHolesDict = FillHolesd
Expand All @@ -869,3 +918,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
EnsembleD = EnsembleDict = Ensembled
SobelGradientsD = SobelGradientsDict = SobelGradientsd
DistanceTransformEDTD = DistanceTransformEDTDict = DistanceTransformEDTd
148 changes: 146 additions & 2 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@
pytorch_after,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor
from monai.utils.type_conversion import (
convert_data_type,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
convert_to_tensor,
)

measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
morphology, has_morphology = optional_import("skimage.morphology")
ndimage, _ = optional_import("scipy.ndimage")
ndimage, has_ndimage = optional_import("scipy.ndimage")
cp, has_cp = optional_import("cupy")
cp_ndarray, _ = optional_import("cupy", name="ndarray")
exposure, has_skimage = optional_import("skimage.exposure")
Expand Down Expand Up @@ -124,6 +130,7 @@
"reset_ops_id",
"resolves_modes",
"has_status_keys",
"distance_transform_edt",
]


Expand Down Expand Up @@ -2051,5 +2058,142 @@ def has_status_keys(data: torch.Tensor, status_key: Any, default_message: str =
return True, None


def distance_transform_edt(
img: NdarrayOrTensor,
sampling: None | float | list[float] = None,
return_distances: bool = True,
return_indices: bool = False,
distances: NdarrayOrTensor | None = None,
indices: NdarrayOrTensor | None = None,
*,
block_params: tuple[int, int, int] | None = None,
float64_distances: bool = False,
) -> None | NdarrayOrTensor | tuple[NdarrayOrTensor, NdarrayOrTensor]:
"""
Euclidean distance transform, either GPU based with CuPy / cuCIM or CPU based with scipy.
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.
Note that the results of the libraries can differ, so stick to one if possible.
For details, check out the `SciPy`_ and `cuCIM`_ documentation.
.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
Args:
img: Input image on which the distance transform shall be run.
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
return_distances: Whether to calculate the distance transform.
return_indices: Whether to calculate the feature transform.
distances: An output array to store the calculated distance transform, instead of returning it.
`return_distances` must be True.
indices: An output array to store the calculated feature transform, instead of returning it. `return_indicies` must be True.
block_params: This parameter is specific to cuCIM and does not exist in SciPy. For details, look into `cuCIM`_.
float64_distances: This parameter is specific to cuCIM and does not exist in SciPy.
If True, use double precision in the distance computation (to match SciPy behavior).
Otherwise, single precision will be used for efficiency.
Returns:
distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied.
It will have the same shape as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True,
otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64.
indices: The calculated feature transform. It has an image-shaped array for each dimension of the image.
Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64.
"""
distance_transform_edt, has_cucim = optional_import(
"cucim.core.operations.morphology", name="distance_transform_edt"
)
use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda"

if not return_distances and not return_indices:
raise RuntimeError("Neither return_distances nor return_indices True")

if not (img.ndim >= 3 and img.ndim <= 4):
raise RuntimeError("Wrong input dimensionality. Use (num_channels, H, W [,D])")

distances_original, indices_original = distances, indices
distances, indices = None, None
if use_cp:
distances_, indices_ = None, None
if return_distances:
dtype = torch.float64 if float64_distances else torch.float32
if distances is None:
distances = torch.zeros_like(img, dtype=dtype) # type: ignore
else:
if not isinstance(distances, torch.Tensor) and distances.device != img.device:
raise TypeError("distances must be a torch.Tensor on the same device as img")
if not distances.dtype == dtype:
raise TypeError("distances must be a torch.Tensor of dtype float32 or float64")
distances_ = convert_to_cupy(distances)
if return_indices:
dtype = torch.int32
if indices is None:
indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore
else:
if not isinstance(indices, torch.Tensor) and indices.device != img.device:
raise TypeError("indices must be a torch.Tensor on the same device as img")
if not indices.dtype == dtype:
raise TypeError("indices must be a torch.Tensor of dtype int32")
indices_ = convert_to_cupy(indices)
img_ = convert_to_cupy(img)
for channel_idx in range(img_.shape[0]):
distance_transform_edt(
img_[channel_idx],
sampling=sampling,
return_distances=return_distances,
return_indices=return_indices,
distances=distances_[channel_idx] if distances_ is not None else None,
indices=indices_[channel_idx] if indices_ is not None else None,
block_params=block_params,
float64_distances=float64_distances,
)
else:
if not has_ndimage:
raise RuntimeError("scipy.ndimage required if cupy is not available")
img_ = convert_to_numpy(img)
if return_distances:
if distances is None:
distances = np.zeros_like(img_, dtype=np.float64)
else:
if not isinstance(distances, np.ndarray):
raise TypeError("distances must be a numpy.ndarray")
if not distances.dtype == np.float64:
raise TypeError("distances must be a numpy.ndarray of dtype float64")
if return_indices:
if indices is None:
indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32)
else:
if not isinstance(indices, np.ndarray):
raise TypeError("indices must be a numpy.ndarray")
if not indices.dtype == np.int32:
raise TypeError("indices must be a numpy.ndarray of dtype int32")

for channel_idx in range(img_.shape[0]):
ndimage.distance_transform_edt(
img_[channel_idx],
sampling=sampling,
return_distances=return_distances,
return_indices=return_indices,
distances=distances[channel_idx] if distances is not None else None,
indices=indices[channel_idx] if indices is not None else None,
)

r_vals = []
if return_distances and distances_original is None:
r_vals.append(distances)
if return_indices and indices_original is None:
r_vals.append(indices)
if not r_vals:
return None
if len(r_vals) == 1:
return r_vals[0]
return tuple(r_vals) # type: ignore


if __name__ == "__main__":
print_transform_backends()
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def run_testsuit():
"test_deepgrow_transforms",
"test_detect_envelope",
"test_dints_network",
"test_distance_transform_edt",
"test_efficientnet",
"test_ensemble_evaluator",
"test_ensure_channel_first",
Expand Down
Loading

0 comments on commit 84566d1

Please sign in to comment.