Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement distance_transform_edt and the DistanceTransformEDT transform #6981

Merged
merged 26 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
48a2873
Implement distance_transform_edt and the DistanceTransformEDT transform
matt3o Sep 13, 2023
cbac352
Fix code style
matt3o Sep 13, 2023
150d59b
Add test_distance_transform_edt to min_tests.py
matt3o Sep 13, 2023
2349f89
Add DistanceTransformEDTd
matt3o Sep 13, 2023
e3b3846
Update docs
matt3o Sep 13, 2023
62cbdce
Fix test
matt3o Sep 14, 2023
f32841d
Add typing
matt3o Sep 14, 2023
9a88bae
Fix typing for sampling argument
matt3o Sep 14, 2023
c4ed6c5
Fix typing return value
matt3o Sep 14, 2023
f76957e
Merge branch 'dev' into gpu_edt_transform
wyli Sep 14, 2023
6eccc4a
Update docs to match the code
matt3o Sep 15, 2023
4e26689
CuPy now allows 4D channel-wise input
matt3o Sep 19, 2023
2e97953
Merge branch 'dev' into gpu_edt_transform
matt3o Sep 19, 2023
7943ffd
fixes format
wyli Sep 19, 2023
101cc62
Update monai/transforms/post/array.py
matt3o Sep 23, 2023
24814c4
Apply suggestions from code review
matt3o Sep 23, 2023
e98b10a
Add test for 4D input
matt3o Sep 20, 2023
6d19ae4
Remove force_scipy flag
matt3o Sep 24, 2023
1f60c62
Remove force_scipy flag
matt3o Sep 24, 2023
3cccb54
Rework distance_transform_edt to include more parameters
matt3o Sep 26, 2023
582efb4
Code styling
matt3o Sep 26, 2023
2bb17e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
aa31e31
DCO Remediation Commit for Matthias Hadlich <[email protected]>
matt3o Sep 26, 2023
94db8ea
Final fixes
matt3o Sep 27, 2023
d0075c5
Fix typing
matt3o Sep 27, 2023
d4da2af
Merge branch 'dev' into gpu_edt_transform
wyli Sep 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,12 @@ Post-processing
:members:
:special-members: __call__

`DistanceTransformEDT`
"""""""""""""""""""""""""""""""
.. autoclass:: DistanceTransformEDT
:members:
:special-members: __call__

`RemoveSmallObjects`
""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjects.png
Expand Down Expand Up @@ -1622,6 +1628,12 @@ Post-processing (Dict)
:members:
:special-members: __call__

`DistanceTransformEDTd`
""""""""""""""""""""""""""""""""
.. autoclass:: DistanceTransformEDTd
:members:
:special-members: __call__

`RemoveSmallObjectsd`
"""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjectsd.png
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@
from .post.array import (
Activations,
AsDiscrete,
DistanceTransformEDT,
FillHoles,
Invert,
KeepLargestConnectedComponent,
Expand All @@ -295,6 +296,9 @@
AsDiscreteD,
AsDiscreted,
AsDiscreteDict,
DistanceTransformEDTd,
DistanceTransformEDTD,
DistanceTransformEDTDict,
Ensembled,
EnsembleD,
EnsembleDict,
Expand Down
38 changes: 38 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import (
convert_applied_interp_mode,
distance_transform_edt,
fill_holes,
get_largest_connected_component_mask,
get_unique_labels,
Expand All @@ -53,6 +54,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"DistanceTransformEDT",
]


Expand Down Expand Up @@ -936,3 +938,39 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]

return grads


class DistanceTransformEDT(Transform):
"""
Applies the Euclidean distance transform on the input.
Either GPU based with CuPy / cuCIM or CPU based with scipy.
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.

Note that the results of the libraries can differ, so stick to one if possible.
For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.

.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
"""

backend = [TransformBackends.NUMPY, TransformBackends.CUPY]

def __init__(self, sampling: None | float | list[float] = None) -> None:
super().__init__()
self.sampling = sampling

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: Input image on which the distance transform shall be run.
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.

Returns:
An array with the same shape and data type as img
"""
return distance_transform_edt(img=img, sampling=self.sampling) # type: ignore
50 changes: 50 additions & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from monai.transforms.post.array import (
Activations,
AsDiscrete,
DistanceTransformEDT,
FillHoles,
KeepLargestConnectedComponent,
LabelFilter,
Expand Down Expand Up @@ -91,6 +92,9 @@
"VoteEnsembleD",
"VoteEnsembleDict",
"VoteEnsembled",
"DistanceTransformEDTd",
"DistanceTransformEDTD",
"DistanceTransformEDTDict",
]

DEFAULT_POST_FIX = PostFix.meta()
Expand Down Expand Up @@ -855,6 +859,51 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class DistanceTransformEDTd(MapTransform):
"""
Applies the Euclidean distance transform on the input.
Either GPU based with CuPy / cuCIM or CPU based with scipy.
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.

Note that the results of the libraries can differ, so stick to one if possible.
For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.


Note on the input shape:
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.

Args:
keys: keys of the corresponding items to be transformed.
allow_missing_keys: don't raise exception if key is missing.
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.

.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt


"""

backend = DistanceTransformEDT.backend

def __init__(
self, keys: KeysCollection, allow_missing_keys: bool = False, sampling: None | float | list[float] = None
) -> None:
super().__init__(keys, allow_missing_keys)
self.sampling = sampling
self.distance_transform = DistanceTransformEDT(sampling=self.sampling)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.distance_transform(img=d[key])

return d


ActivationsD = ActivationsDict = Activationsd
AsDiscreteD = AsDiscreteDict = AsDiscreted
FillHolesD = FillHolesDict = FillHolesd
Expand All @@ -869,3 +918,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
EnsembleD = EnsembleDict = Ensembled
SobelGradientsD = SobelGradientsDict = SobelGradientsd
DistanceTransformEDTD = DistanceTransformEDTDict = DistanceTransformEDTd
148 changes: 146 additions & 2 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@
pytorch_after,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor
from monai.utils.type_conversion import (
convert_data_type,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
convert_to_tensor,
)

measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
morphology, has_morphology = optional_import("skimage.morphology")
ndimage, _ = optional_import("scipy.ndimage")
ndimage, has_ndimage = optional_import("scipy.ndimage")
cp, has_cp = optional_import("cupy")
cp_ndarray, _ = optional_import("cupy", name="ndarray")
exposure, has_skimage = optional_import("skimage.exposure")
Expand Down Expand Up @@ -124,6 +130,7 @@
"reset_ops_id",
"resolves_modes",
"has_status_keys",
"distance_transform_edt",
]


Expand Down Expand Up @@ -2051,5 +2058,142 @@ def has_status_keys(data: torch.Tensor, status_key: Any, default_message: str =
return True, None


def distance_transform_edt(
img: NdarrayOrTensor,
sampling: None | float | list[float] = None,
return_distances: bool = True,
return_indices: bool = False,
distances: NdarrayOrTensor | None = None,
indices: NdarrayOrTensor | None = None,
*,
block_params: tuple[int, int, int] | None = None,
matt3o marked this conversation as resolved.
Show resolved Hide resolved
float64_distances: bool = False,
) -> None | NdarrayOrTensor | tuple[NdarrayOrTensor, NdarrayOrTensor]:
"""
Euclidean distance transform, either GPU based with CuPy / cuCIM or CPU based with scipy.
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.
Note that the results of the libraries can differ, so stick to one if possible.
For details, check out the `SciPy`_ and `cuCIM`_ documentation.
.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
Args:
img: Input image on which the distance transform shall be run.
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
return_distances: Whether to calculate the distance transform.
return_indices: Whether to calculate the feature transform.
distances: An output array to store the calculated distance transform, instead of returning it.
`return_distances` must be True.
indices: An output array to store the calculated feature transform, instead of returning it. `return_indicies` must be True.
block_params: This parameter is specific to cuCIM and does not exist in SciPy. For details, look into `cuCIM`_.
float64_distances: This parameter is specific to cuCIM and does not exist in SciPy.
If True, use double precision in the distance computation (to match SciPy behavior).
Otherwise, single precision will be used for efficiency.
Returns:
distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied.
It will have the same shape as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True,
otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64.
indices: The calculated feature transform. It has an image-shaped array for each dimension of the image.
Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64.
"""
distance_transform_edt, has_cucim = optional_import(
"cucim.core.operations.morphology", name="distance_transform_edt"
)
use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda"

if not return_distances and not return_indices:
raise RuntimeError("Neither return_distances nor return_indices True")

if not (img.ndim >= 3 and img.ndim <= 4):
raise RuntimeError("Wrong input dimensionality. Use (num_channels, H, W [,D])")

distances_original, indices_original = distances, indices
distances, indices = None, None
matt3o marked this conversation as resolved.
Show resolved Hide resolved
if use_cp:
distances_, indices_ = None, None
if return_distances:
dtype = torch.float64 if float64_distances else torch.float32
if distances is None:
distances = torch.zeros_like(img, dtype=dtype) # type: ignore
else:
if not isinstance(distances, torch.Tensor) and distances.device != img.device:
raise TypeError("distances must be a torch.Tensor on the same device as img")
if not distances.dtype == dtype:
raise TypeError("distances must be a torch.Tensor of dtype float32 or float64")
distances_ = convert_to_cupy(distances)
if return_indices:
dtype = torch.int32
if indices is None:
indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore
else:
if not isinstance(indices, torch.Tensor) and indices.device != img.device:
raise TypeError("indices must be a torch.Tensor on the same device as img")
if not indices.dtype == dtype:
raise TypeError("indices must be a torch.Tensor of dtype int32")
indices_ = convert_to_cupy(indices)
img_ = convert_to_cupy(img)
for channel_idx in range(img_.shape[0]):
distance_transform_edt(
img_[channel_idx],
sampling=sampling,
return_distances=return_distances,
return_indices=return_indices,
distances=distances_[channel_idx] if distances_ is not None else None,
indices=indices_[channel_idx] if indices_ is not None else None,
block_params=block_params,
float64_distances=float64_distances,
)
else:
if not has_ndimage:
raise RuntimeError("scipy.ndimage required if cupy is not available")
img_ = convert_to_numpy(img)
if return_distances:
if distances is None:
distances = np.zeros_like(img_, dtype=np.float64)
else:
if not isinstance(distances, np.ndarray):
raise TypeError("distances must be a numpy.ndarray")
if not distances.dtype == np.float64:
raise TypeError("distances must be a numpy.ndarray of dtype float64")
if return_indices:
if indices is None:
indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32)
else:
if not isinstance(indices, np.ndarray):
raise TypeError("indices must be a numpy.ndarray")
if not indices.dtype == np.int32:
raise TypeError("indices must be a numpy.ndarray of dtype int32")

for channel_idx in range(img_.shape[0]):
ndimage.distance_transform_edt(
img_[channel_idx],
sampling=sampling,
return_distances=return_distances,
return_indices=return_indices,
distances=distances[channel_idx] if distances is not None else None,
indices=indices[channel_idx] if indices is not None else None,
)

r_vals = []
if return_distances and distances_original is None:
r_vals.append(distances)
if return_indices and indices_original is None:
r_vals.append(indices)
if not r_vals:
return None
if len(r_vals) == 1:
return r_vals[0]
return tuple(r_vals) # type: ignore


if __name__ == "__main__":
print_transform_backends()
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def run_testsuit():
"test_deepgrow_transforms",
"test_detect_envelope",
"test_dints_network",
"test_distance_transform_edt",
"test_efficientnet",
"test_ensemble_evaluator",
"test_ensure_channel_first",
Expand Down
Loading