Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sampling kwarg for distance_transform_edt (take pixel/voxel sizes into account) #407

Merged
merged 11 commits into from
Nov 28, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from ._pba_2d import _pba_2d
from ._pba_3d import _pba_3d

# TODO: support sampling distances
# support the distances and indices output arguments
# support chamfer, chessboard and l1/manhattan distances too?
# TODO: support the distances and indices output arguments
# support chamfer/chessboard and taxicab/manhattan distances too?


def distance_transform_edt(image, sampling=None, return_distances=True,
Expand Down Expand Up @@ -34,14 +33,17 @@ def distance_transform_edt(image, sampling=None, return_distances=True,
Whether to calculate the distance transform.
return_indices : bool, optional
Whether to calculate the feature transform.
distances : float32 cupy.ndarray, optional
distances : cupy.ndarray, optional
An output array to store the calculated distance transform, instead of
returning it. `return_distances` must be True. It must be the same
shape as `image`.
indices : int32 cupy.ndarray, optional
returning it. `return_distances` must be ``True``. It must be the same
shape as `image`. Should have dtype ``cp.float32`` if
`float64_distances` is ``False``, otherwise it should be
``cp.float64``.
indices : cupy.ndarray, optional
An output array to store the calculated feature transform, instead of
returning it. `return_indicies` must be True. Its shape must be
`(image.ndim,) + image.shape`.
returning it. `return_indicies` must be ``True``. Its shape must be
``(image.ndim,) + image.shape``. Its dtype must be a signed or unsigned
integer type of at least 16-bits in 2D or 32-bits in 3D.

Other Parameters
----------------
Expand All @@ -57,14 +59,16 @@ def distance_transform_edt(image, sampling=None, return_distances=True,

Returns
-------
distances : float64 ndarray, optional
distances : cupy.ndarray, optional
The calculated distance transform. Returned only when
`return_distances` is True and `distances` is not supplied. It will
have the same shape as `image`.
indices : int32 ndarray, optional
`return_distances` is ``True`` and `distances` is not supplied. It will
have the same shape as `image`. Will have dtype `cp.float64` if
`float64_distances` is ``True``, otherwise it will have dtype
``cp.float32``.
indices : ndarray, optional
The calculated feature transform. It has an image-shaped array for each
dimension of the image. See example below. Returned only when
`return_indices` is True and `indices` is not supplied.
`return_indices` is ``True`` and `indices` is not supplied.

Notes
-----
Expand Down Expand Up @@ -139,24 +143,14 @@ def distance_transform_edt(image, sampling=None, return_distances=True,
[0, 0, 3, 3, 4]]])

"""
if distances is not None:
raise NotImplementedError(
"preallocated distances image is not supported"
)
if indices is not None:
raise NotImplementedError(
"preallocated indices image is not supported"
)
scalar_sampling = None
if sampling is not None:
sampling = np.unique(np.atleast_1d(sampling))
if len(sampling) == 1:
scalar_sampling = float(sampling)
unique_sampling = np.unique(np.atleast_1d(sampling))
if len(unique_sampling) == 1:
# In the isotropic case, can use the kernels without sample scaling
# and just adjust the final distance accordingly.
scalar_sampling = float(unique_sampling)
sampling = None
else:
raise NotImplementedError(
"non-uniform values in sampling is not currently supported"
)

if image.ndim == 3:
pba_func = _pba_3d
Expand All @@ -171,11 +165,16 @@ def distance_transform_edt(image, sampling=None, return_distances=True,
sampling=sampling,
return_distances=return_distances,
return_indices=return_indices,
block_params=block_params
block_params=block_params,
distances=distances,
indices=indices,
)

if return_distances and scalar_sampling is not None:
vals = (vals[0] * scalar_sampling,) + vals[1:]
# inplace multiply in case distance != None
vals = list(vals)
vals[0] *= scalar_sampling
vals = tuple(vals)

if len(vals) == 1:
vals = vals[0]
Expand Down
Loading