Implement distance_transform_edt and the DistanceTransformEDT transform #6981

Merged: 26 commits merged into Project-MONAI:dev on Sep 27, 2023

Conversation

@matt3o (Contributor) commented Sep 13, 2023

Related to #6845, this commit adds an EDT distance transform to MONAI.
Most importantly, it enables GPU-based distance transforms, which lead to a huge speedup.
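For illustration, a minimal usage sketch (hedged: the import path and arguments follow the names discussed in this PR, and the channel-first mask shape is an assumption):

    # hedged sketch: channel-first binary mask; CPU path via scipy.ndimage,
    # GPU path via CuPy/cuCIM when the input lives on a CUDA device
    import torch
    from monai.transforms.utils import distance_transform_edt

    mask = (torch.rand(1, 64, 64, 64) > 0.5).float()
    edt = distance_transform_edt(mask)  # CPU backend
    if torch.cuda.is_available():
        edt_gpu = distance_transform_edt(mask.cuda(), sampling=[1.0, 1.0, 1.0])  # GPU backend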

Description

A few sentences describing the changes proposed in this pull request.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • In-line docstrings updated.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • Documentation updated, tested make html command in the docs/ folder.

@matt3o matt3o force-pushed the gpu_edt_transform branch 2 times, most recently from 880189e to 3d93616 on September 13, 2023 15:31
Signed-off-by: Matthias Hadlich <[email protected]>
@matt3o (Contributor, Author) commented Sep 13, 2023

If this PR is accepted, I would change the following functions to use the new code:
compute_multi_instance_mask() in monai/apps/pathology
get_surface_distance() in monai/metrics. It also offers cdt, which I can't change since there is no GPU implementation available.

That should speed up at least the SurfaceDice a lot; currently it is rather slow on big volumes.

Please give feedback if that would be fine as well, thanks!

@wyli (Contributor) commented Sep 13, 2023

looks great, please add the unit test file names to this list to skip them in the minimal testing environments

exclude_cases = [ # these cases use external dependencies
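For reference, a sketch of what that list could look like with the new test added (the test-file name below is an assumption, not the PR's actual file name):

    # tests/min_tests.py -- tests excluded from the minimal (no optional deps) environment
    exclude_cases = [  # these cases use external dependencies
        ...,
        "test_distance_transform_edt",  # assumed name of the new unit test file
    ]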

Signed-off-by: Matthias Hadlich <[email protected]>
@matt3o (Contributor, Author) commented Sep 13, 2023

I just added the additional DistanceTransformEDTd transform. Also, @wyli, shall I implement the two changes I mentioned above, or shall we keep them separate and I create a second PR?

@wyli (Contributor) commented Sep 13, 2023

Please keep them separate; modifying an existing algorithm will require more checks to prevent regressions...

@matt3o (Contributor, Author) commented Sep 14, 2023

@wyli How can I fix the optional import error?

monai.utils.module.OptionalImportError: import cupy (No module named 'cupy').

@wyli (Contributor) commented Sep 14, 2023

yes, I tried running it locally; the parameterized.expand decorator should be put before the skipping decorators:

    @parameterized.expand(SAMPLING_TEST_CASES)
    @skip_if_no_cuda
    @unittest.skipUnless(HAS_CUPY, "CuPy is required.")
    @unittest.skipUnless(momorphology, "cuCIM transforms are required.")

also, test_sampling might need some skip decorators as well.
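For illustration, a minimal test-module sketch following that ordering (the optional-import handling, helper names and test cases below are assumptions, not the PR's actual test file):

    import unittest

    from parameterized import parameterized

    from monai.utils import optional_import
    from tests.utils import skip_if_no_cuda

    _, HAS_CUPY = optional_import("cupy")
    _, momorphology = optional_import("cucim.core.operations.morphology")  # bound to the bool success flag here (an assumption)

    SAMPLING_TEST_CASES = [[[1.0, 1.0, 1.0]]]  # placeholder cases

    class TestDistanceTransformEDT(unittest.TestCase):
        # expand listed first (outermost), skips underneath: every generated
        # case keeps the skip markers, so minimal/CPU-only runs skip cleanly
        @parameterized.expand(SAMPLING_TEST_CASES)
        @skip_if_no_cuda
        @unittest.skipUnless(HAS_CUPY, "CuPy is required.")
        @unittest.skipUnless(momorphology, "cuCIM transforms are required.")
        def test_sampling(self, sampling):
            ...  # GPU/cuCIM-specific sampling checks go here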

Signed-off-by: Matthias Hadlich <[email protected]>
Signed-off-by: Matthias Hadlich <[email protected]>
Signed-off-by: Matthias Hadlich <[email protected]>
@wyli (Contributor) commented Sep 14, 2023

/build

@john-zielke-snkeos (Contributor) commented:

Hey, I just noticed this PR, and we seem to be working on similar functionality regarding the Euclidean distance transform. So before any merging, we should make sure that this PR and #7008 don't duplicate functionality and that the changes form a consistent implementation.

@wyli (Contributor) commented Sep 19, 2023

it seems #7008 could be refactored to use the utility function defined in this PR... I'd suggest that @john-zielke-snkeos also help review this PR and that we merge this one first, what do you think?

@wyli wyli requested a review from ericspod September 19, 2023 15:37
matt3o and others added 3 commits September 23, 2023 15:11
Co-authored-by: YunLiu <[email protected]>
Signed-off-by: Matthias Hadlich <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Signed-off-by: Matthias Hadlich <[email protected]>
Signed-off-by: Matthias Hadlich <[email protected]>
@matt3o (Contributor, Author) commented Sep 23, 2023

Hopefully final questions:
Both scipy and cupy share the common parameters return_distances=True, return_indices=False, distances=None, indices=None, as @john-zielke-snkeos mentioned. Do we want to expose them too?
At least return_distances and return_indices might make sense for the utils.py function; distances and indices might speed up code, but I have not used them so far. The last two I would probably leave out for now, since they can be added later if necessary. return_distances and return_indices, however, do change the API, so we should decide on that now.
The transforms I would leave as they are until the indices are actually required.

Those two flags do complicate the code a lot, though, and I'm not sure how to avoid duplication:

    distance_transform_edt, has_cucim = optional_import(
        "cucim.core.operations.morphology", name="distance_transform_edt"
    )
    use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device("cpu")

    if not return_distances and not return_indices:
        raise RuntimeError("Neither return_distances nor return_indices True")

    distances, indices = [], []
    if use_cp:
        img_ = convert_to_cupy(img)
        for channel in img_:
            outputs = distance_transform_edt(channel, sampling=sampling, return_distances=return_distances, return_indices=return_indices)

            if return_distances and return_indices:
                distances.append(outputs[0])
                indices.append(outputs[1])
            elif return_distances:
                distances.append(outputs)
            elif return_indices:
                indices.append(outputs)
    else:
        if not has_ndimage:
            raise RuntimeError("scipy.ndimage required if cupy is not available")
        img_ = convert_to_numpy(img)
        for channel in img_:
            outputs = ndimage.distance_transform_edt(channel, sampling=sampling, return_distances=return_distances, return_indices=return_indices)

            if return_distances and return_indices:
                distances.append(outputs[0])
                indices.append(outputs[1])
            elif return_distances:
                distances.append(outputs)
            elif return_indices:
                indices.append(outputs)

    d_out, i_out = None, None
    if use_cp:
        if return_distances:
            d_out = cp.stack(distances)
        if return_indices:
            i_out = cp.stack(indices)
    else:
        if return_distances:
            d_out = np.stack(distances)
        if return_indices:
            i_out = np.stack(indices)

    return d_out, i_out

Signed-off-by: Matthias Hadlich <[email protected]>
Signed-off-by: Matthias Hadlich <[email protected]>
@KumoLiu (Contributor) commented Sep 25, 2023

Hi @matt3o, both are ok for me.
You can simplify the code like this:

lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack

Thanks!
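Applied to the stacking step in matt3o's snippet above, the suggestion could look like this (a sketch reusing use_cp, cp, np, return_distances, return_indices, distances and indices from that snippet):

    # pick the stack function once so both backends share the same epilogue
    lib_stack = cp.stack if use_cp else np.stack
    d_out = lib_stack(distances) if return_distances else None
    i_out = lib_stack(indices) if return_indices else None
    return d_out, i_out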

@john-zielke-snkeos (Contributor) commented:

How about something like this:

def distance_transform_edt(
    img: NdarrayOrTensor,
    sampling: None | float | list[float] = None,
    return_distances: bool = True,
    return_indices: bool = False,
    distances=None,
    indices=None,
    *,
    block_params=None,
    float64_distances=False,
) -> NdarrayOrTensor:
    """
    Euclidean distance transform, either GPU-based with CuPy / cuCIM
    or CPU-based with scipy.ndimage.
    The GPU path is used when CuPy and cuCIM are available and the input is a
    CUDA tensor; otherwise the scipy.ndimage implementation is used.

    Note that the CPU runtime can vary considerably depending on the input size.

    Args:
        ...
    """
    distance_transform_edt, has_cucim = optional_import(
        "cucim.core.operations.morphology", name="distance_transform_edt"
    )
    use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device("cpu")

    if not return_distances and not return_indices:
        raise RuntimeError("Neither return_distances nor return_indices True")
    distances_original, indices_original = distances, indices
    distances, indices = None, None
    if use_cp:
        distances_, indices_ = None, None
        if return_distances:
            dtype = torch.float64 if float64_distances else torch.float32
            if distances is None:
                distances = torch.zeros_like(img, dtype=dtype)
            else:
                if not isinstance(distances, torch.Tensor) or distances.device != img.device:
                    raise TypeError("distances must be a torch.Tensor on the same device as img")
                if not distances.dtype == dtype:
                    raise TypeError("distances must be a torch.Tensor of dtype float32 or float64")
            distances_ = convert_to_cupy(distances)
        if return_indices:
            dtype = torch.int32
            if indices is None:
                indices = torch.zeros((img.shape[0], img.dim() - 1) + img.shape[1:], dtype=dtype)  # per-channel index maps
            else:
                if not isinstance(indices, torch.Tensor) or indices.device != img.device:
                    raise TypeError("indices must be a torch.Tensor on the same device as img")
                if not indices.dtype == dtype:
                    raise TypeError("indices must be a torch.Tensor of dtype int32")
            indices_ = convert_to_cupy(indices)
        img_ = convert_to_cupy(img)
        for channel_idx in range(img_.shape[0]):
            distance_transform_edt(
                img_[channel_idx],
                sampling=sampling,
                return_distances=return_distances,
                return_indices=return_indices,
                distances=distances_[channel_idx] if distances_ is not None else None,
                indices=indices_[channel_idx] if indices_ is not None else None,
                block_params=block_params,
                float64_distances=float64_distances,
            )
    else:
        if not has_ndimage:
            raise RuntimeError("scipy.ndimage required if cupy is not available")
        img_ = convert_to_numpy(img)
        if return_distances:
            if distances is None:
                distances = np.zeros_like(img_, dtype=np.float64)
            else:
                if not isinstance(distances, np.ndarray):
                    raise TypeError("distances must be a numpy.ndarray")
                if not distances.dtype == np.float64:
                    raise TypeError("distances must be a numpy.ndarray of dtype float64")
        if return_indices:
            if indices is None:
                indices = np.zeros((img_.shape[0], img_.ndim - 1) + img_.shape[1:], dtype=np.int32)  # per-channel index maps
            else:
                if not isinstance(indices, np.ndarray):
                    raise TypeError("indices must be a numpy.ndarray")
                if not indices.dtype == np.int32:
                    raise TypeError("indices must be a numpy.ndarray of dtype int32")

        for channel_idx in range(img_.shape[0]):
            ndimage.distance_transform_edt(
                img_[channel_idx],
                sampling=sampling,
                return_distances=return_distances,
                return_indices=return_indices,
                distances=distances[channel_idx] if distances is not None else None,
                indices=indices[channel_idx] if indices is not None else None,
            )

    r_vals = []
    if return_distances and distances_original is None:
        r_vals.append(distances)
    if return_indices and indices_original is None:
        r_vals.append(indices)
    if not r_vals:
        return None
    if len(r_vals) == 1:
        return r_vals[0]
    return tuple(r_vals)
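For clarity, a hedged usage sketch of the proposed signature (img is a channel-first array or tensor, as in the rest of the discussion):

    dist = distance_transform_edt(img)                            # distances only (default)
    dist, idx = distance_transform_edt(img, return_indices=True)  # distances plus nearest-feature indices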

@matt3o (Contributor, Author) commented Sep 26, 2023

Looking good @john-zielke-snkeos, thanks! My only concern: does MONAI want to expose parameters that only cupy supports? Those flags don't do anything for scipy. Is it important for your use case to have block_params and float64_distances?

@john-zielke-snkeos (Contributor) commented:

I think that's fine. We should mention that these are cuCIM-specific, just as the cuCIM docs do, and since they are simply ignored in the scipy case, exposing them is harmless. With proper docs and the arguments being keyword-only, normal users will not really stumble upon them, while users who need the options (especially the precision one) still have the possibility to use them.

@john-zielke-snkeos (Contributor) commented:

@matt3o @wyli What I just wondered is whether it wouldn't be better to check for use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda" instead of use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device("cpu"). Otherwise AMD GPUs and other accelerators like the AWS one would be routed to cuCIM and fail. And if you agree, should we also change this in https://github.com/Project-MONAI/MONAI/blob/e5f933781b1a5f4eb7d278e06389d531bff38dca/monai/transforms/utils.py#L1031C22-L1031C22 ?

@wyli (Contributor) commented Sep 26, 2023

(quoting @john-zielke-snkeos's suggestion above to check img.device.type == "cuda" instead of img.device != torch.device("cpu"))

thanks, I agree, we should be more specific in the if condition.

@matt3o (Contributor, Author) commented Sep 26, 2023

Sure, I'll match that right now. I will have to change the check to
use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == torch.device("cuda").type
however; otherwise it won't work with devices like "cuda:0".
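A quick illustration of the device checks under discussion (plain PyTorch; constructing the device objects does not require a GPU):

    import torch

    d = torch.device("cuda:0")
    print(d == torch.device("cuda"))  # False: equality also compares the device index
    print(d.type == "cuda")           # True: compares only the backend type
    print(d != torch.device("cpu"))   # True for any non-CPU device (e.g. "mps"),
                                      # which is why this broader check could route
                                      # non-CUDA accelerators to the cuCIM path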

pre-commit-ci bot and others added 2 commits September 26, 2023 18:37
I, Matthias Hadlich <[email protected]>, hereby add my Signed-off-by to this commit: 101cc62
I, Matthias Hadlich <[email protected]>, hereby add my Signed-off-by to this commit: 24814c4

Signed-off-by: Matthias Hadlich <[email protected]>
@wyli (Contributor) left a review comment

thank you, it looks good to me. Please help address the type annotation issues https://github.com/Project-MONAI/MONAI/actions/runs/6316993047/job/17169932102?pr=6981 and the docstring issues https://github.com/Project-MONAI/MONAI/actions/runs/6316993047/job/17169930462?pr=6981, then I'll trigger more tests and merge the PR.

(Several inline review comments on monai/transforms/utils.py, now resolved.)
Signed-off-by: Matthias Hadlich <[email protected]>
@wyli (Contributor) commented Sep 27, 2023

/build

@wyli wyli marked this pull request as ready for review September 27, 2023 15:27
@wyli (Contributor) commented Sep 27, 2023

many thanks @matt3o for the PR, and @john-zielke-snkeos @KumoLiu for the detailed comments. All premerge tests passed, and I'm merging this PR to include it in the upcoming MONAI v1.3; please feel free to create follow-ups.

@wyli wyli merged commit 84566d1 into Project-MONAI:dev Sep 27, 2023
26 checks passed
@john-zielke-snkeos (Contributor) commented:

That's great news, thank you very much for the support! I'll adjust my pull request ASAP; it would be nice if we could include the faster metrics in 1.3 as well.
