Skip to content

Commit

Permalink
add device in HoVerNetNuclearTypePostProcessing and `HoVerNetInst…
Browse files Browse the repository at this point in the history
…anceMapPostProcessing` (#6333)

Fixes # .

### Description

Since some operations in post-processing of HoVerNet will convert data
to numpy. And most of the time we need to calculate the metric from the
model output and label which should both be on CUDA.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
  • Loading branch information
KumoLiu authored Apr 11, 2023
1 parent d14e8f0 commit fa7411a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
17 changes: 13 additions & 4 deletions monai/apps/pathology/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
from monai.utils import TransformBackends, convert_to_numpy, optional_import
from monai.utils.misc import ensure_tuple_rep
from monai.utils.type_conversion import convert_to_dst_type
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

label, _ = optional_import("scipy.ndimage.measurements", name="label")
disk, _ = optional_import("skimage.morphology", name="disk")
Expand Down Expand Up @@ -671,6 +671,7 @@ class HoVerNetInstanceMapPostProcessing(Transform):
min_num_points: minimum number of points to be considered as a contour. Defaults to 3.
contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
If not provided, the level is set to `(max(image) + min(image)) / 2`.
device: target device to put the output Tensor data.
"""

def __init__(
Expand All @@ -686,9 +687,10 @@ def __init__(
watershed_connectivity: int | None = 1,
min_num_points: int = 3,
contour_level: float | None = None,
device: str | torch.device | None = None,
) -> None:
super().__init__()

self.device = device
self.generate_watershed_mask = GenerateWatershedMask(
activation=activation, threshold=mask_threshold, min_object_size=min_object_size
)
Expand Down Expand Up @@ -742,7 +744,7 @@ def __call__( # type: ignore
"centroid": instance_centroid,
"contour": instance_contour,
}

instance_map = convert_to_tensor(instance_map, device=self.device)
return instance_info, instance_map


Expand All @@ -758,13 +760,19 @@ class HoVerNetNuclearTypePostProcessing(Transform):
threshold: an optional float value to threshold to binarize probability map.
If not provided, defaults to 0.5 when activation is not "softmax", otherwise None.
return_type_map: whether to calculate and return pixel-level type map.
device: target device to put the output Tensor data.
"""

def __init__(
self, activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True
self,
activation: str | Callable = "softmax",
threshold: float | None = None,
return_type_map: bool = True,
device: str | torch.device | None = None,
) -> None:
super().__init__()
self.device = device
self.return_type_map = return_type_map
self.generate_instance_type = GenerateInstanceType()

Expand Down Expand Up @@ -824,5 +832,6 @@ def __call__( # type: ignore
# update instance type map
if type_map is not None:
type_map[instance_map == inst_id] = instance_type
type_map = convert_to_tensor(type_map, device=self.device)

return instance_info, type_map
9 changes: 7 additions & 2 deletions monai/apps/pathology/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from collections.abc import Callable, Hashable, Mapping

import numpy as np
import torch

from monai.apps.pathology.transforms.post.array import (
GenerateDistanceMap,
Expand Down Expand Up @@ -488,6 +489,7 @@ class HoVerNetInstanceMapPostProcessingd(Transform):
min_num_points: minimum number of points to be considered as a contour. Defaults to 3.
contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
If not provided, the level is set to `(max(image) + min(image)) / 2`.
device: target device to put the output Tensor data.
"""

def __init__(
Expand All @@ -507,6 +509,7 @@ def __init__(
watershed_connectivity: int | None = 1,
min_num_points: int = 3,
contour_level: float | None = None,
device: str | torch.device | None = None,
) -> None:
super().__init__()
self.instance_map_post_process = HoVerNetInstanceMapPostProcessing(
Expand All @@ -521,6 +524,7 @@ def __init__(
watershed_connectivity=watershed_connectivity,
min_num_points=min_num_points,
contour_level=contour_level,
device=device,
)
self.nuclear_prediction_key = nuclear_prediction_key
self.hover_map_key = hover_map_key
Expand Down Expand Up @@ -553,7 +557,7 @@ class HoVerNetNuclearTypePostProcessingd(Transform):
Defaults to `"instance_info"`.
instance_map_key: the key where instance map is stored. Defaults to `"instance_map"`.
type_map_key: the output key where type map is written. Defaults to `"type_map"`.
device: target device to put the output Tensor data.
"""

Expand All @@ -566,10 +570,11 @@ def __init__(
activation: str | Callable = "softmax",
threshold: float | None = None,
return_type_map: bool = True,
device: str | torch.device | None = None,
) -> None:
super().__init__()
self.type_post_process = HoVerNetNuclearTypePostProcessing(
activation=activation, threshold=threshold, return_type_map=return_type_map
activation=activation, threshold=threshold, return_type_map=return_type_map, device=device
)
self.type_prediction_key = type_prediction_key
self.instance_info_key = instance_info_key
Expand Down

0 comments on commit fa7411a

Please sign in to comment.