Skip to content

Commit

Permalink
Add utils for vista3d (#7999)
Browse files Browse the repository at this point in the history
This PR is a part of #7987 

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent 6be7b13 commit f848002
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 9 deletions.
8 changes: 0 additions & 8 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,3 @@ FastMRIReader

.. autoclass:: monai.apps.nnunet.nnUNetV2Runner
:members:

`Generative AI`
---------------

`MAISI Utilities`
~~~~~~~~~~~~~~~~~
.. automodule:: monai.apps.generation.maisi.utils.morphological_ops
:members:
3 changes: 3 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,9 @@ Utilities
.. automodule:: monai.transforms.utils_pytorch_numpy_unification
:members:

.. automodule:: monai.transforms.utils_morphological_ops
:members:

By Categories
-------------
.. toctree::
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@
weighted_patch_samples,
zero_margins,
)
from .utils_morphological_ops import dilate, erode
from .utils_pytorch_numpy_unification import (
allclose,
any_np_pt,
Expand Down
183 changes: 183 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import torch
from torch import Tensor

import monai
from monai.config import DtypeLike, IndexSelection
Expand All @@ -30,6 +31,7 @@
from monai.networks.utils import meshgrid_ij
from monai.transforms.compose import Compose
from monai.transforms.transform import MapTransform, Transform, apply_transform
from monai.transforms.utils_morphological_ops import erode
from monai.transforms.utils_pytorch_numpy_unification import (
any_np_pt,
ascontiguousarray,
Expand Down Expand Up @@ -65,6 +67,8 @@
min_version,
optional_import,
pytorch_after,
unsqueeze_left,
unsqueeze_right,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import (
Expand Down Expand Up @@ -103,6 +107,8 @@
"generate_spatial_bounding_box",
"get_extreme_points",
"get_largest_connected_component_mask",
"get_largest_connected_component_mask_point",
"convert_points_to_disc",
"remove_small_objects",
"img_bounds",
"in_bounds",
Expand Down Expand Up @@ -1172,6 +1178,183 @@ def get_largest_connected_component_mask(
return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]


def get_largest_connected_component_mask_point(
img_pos: NdarrayTensor,
img_neg: NdarrayTensor,
point_coords: NdarrayTensor,
point_labels: NdarrayTensor,
pos_val: Sequence[int] = (1, 3),
neg_val: Sequence[int] = (0, 2),
margins: int = 3,
) -> NdarrayTensor:
"""
Gets the connected component of img_pos and img_neg that include the positive points and
negative points separately. The function is used for combining automatic results with interactive
results in VISTA3D.
Args:
img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image.
img_neg: same format as img_pos but corresponds to negative points.
pos_val: positive point label values.
neg_val: negative point label values.
point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.
point_labels: the label of each point, shape [B, N].
"""

cucim_skimage, has_cucim = optional_import("cucim.skimage")

use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu")
if use_cp:
img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore
img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore
label = cucim_skimage.measure.label
lib = cp
else:
if not has_measure:
raise RuntimeError("skimage.measure required.")
img_pos_, *_ = convert_data_type(img_pos, np.ndarray)
img_neg_, *_ = convert_data_type(img_neg, np.ndarray)
# for skimage.measure.label, the input must be bool type
if img_pos_.dtype != bool or img_neg_.dtype != bool:
raise ValueError("img_pos and img_neg must be bool type.")
label = measure.label
lib = np

features_pos, _ = label(img_pos_, connectivity=3, return_num=True)
features_neg, _ = label(img_neg_, connectivity=3, return_num=True)

outs = np.zeros_like(img_pos_)
for bs in range(point_coords.shape[0]):
for i, p in enumerate(point_coords[bs]):
if point_labels[bs, i] in pos_val:
features = features_pos
elif point_labels[bs, i] in neg_val:
features = features_neg
else:
# if -1 padding point, skip
continue
for margin in range(margins):
if isinstance(p, np.ndarray):
x, y, z = np.round(p).astype(int).tolist()
else:
x, y, z = p.float().round().int().tolist()
l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3])
t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2])
f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1])
if (features[bs, 0, l:r, t:d, f:b] > 0).any():
index = features[bs, 0, l:r, t:d, f:b].max()
outs[[bs]] += lib.isin(features[[bs]], index)
break
outs[outs > 1] = 1
return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]


def convert_points_to_disc(
image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
):
"""
Convert a 3D point coordinates into image mask. The returned mask has the same spatial
size as `image_size` while the batch dimension is the same as 'point' batch dimension.
The point is converted to a mask ball with radius defined by `radius`. The output
contains two channels each for negative (first channel) and positive points.
Args:
image_size: The output size of the converted mask. It should be a 3D tuple.
point: [B, N, 3], 3D point coordinates.
point_label: [B, N], 0 or 2 means negative points, 1 or 3 means postive points.
radius: disc ball radius size.
disc: If true, use regular disc, other use gaussian.
"""
masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device)
_array = [
torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
]
coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0])
# [1, 3, h, w, d] -> [b, 2, 3, h, w, d]
coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)
coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)
for b, n in np.ndindex(*point.shape[:2]):
point_bn = unsqueeze_right(point[b, n], 6)
if point_label[b, n] > -1:
channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1
pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2)
if disc:
masks[b, channel] += pow_diff.sum(0) < radius**2
else:
masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2))
return masks


def sample_points_from_label(
labels: Tensor,
label_set: Sequence[int],
max_ppoint: int = 1,
max_npoint: int = 0,
device: torch.device | str | None = "cpu",
use_center: bool = False,
):
"""Sample points from labels.
Args:
labels: [1, 1, H, W, D]
label_set: local index, must match values in labels.
max_ppoint: maximum positive point samples.
max_npoint: maximum negative point samples.
device: returned tensor device.
use_center: whether to sample points from center.
Returns:
point: point coordinates of [B, N, 3]. B equals to the length of label_set.
point_label: [B, N], always 0 for negative, 1 for positive.
"""
if not labels.shape[0] == 1:
raise ValueError("labels must have batch size 1.")

if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

labels = labels[0, 0]
unique_labels = labels.unique().cpu().numpy().tolist()
_point = []
_point_label = []
for id in label_set:
if id in unique_labels:
plabels = labels == int(id)
nlabels = ~plabels
_plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0])
plabelpoints = torch.nonzero(_plabels).to(device)
if len(plabelpoints) == 0:
plabelpoints = torch.nonzero(plabels).to(device)
nlabelpoints = torch.nonzero(nlabels).to(device)
num_p = min(len(plabelpoints), max_ppoint)
num_n = min(len(nlabelpoints), max_npoint)
pad = max_ppoint + max_npoint - num_p - num_n
if use_center:
pmean = plabelpoints.float().mean(0)
pdis = ((plabelpoints - pmean) ** 2).sum(-1)
_, sorted_indices_tensor = torch.sort(pdis)
sorted_indices = sorted_indices_tensor.cpu().tolist()
else:
sorted_indices = list(range(len(plabelpoints)))
random.shuffle(sorted_indices)
_point.append(
torch.stack(
[plabelpoints[sorted_indices[i]] for i in range(num_p)]
+ random.choices(nlabelpoints, k=num_n)
+ [torch.tensor([0, 0, 0], device=device)] * pad
)
)
_point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device))
else:
# pad the background labels
_point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device))
_point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1)
point = torch.stack(_point)
point_label = torch.stack(_point_label)

return point, point_label


def remove_small_objects(
img: NdarrayTensor,
min_size: int = 64,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from monai.config import NdarrayOrTensor
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep

__all__ = ["erode", "dilate"]


def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def run_testsuit():
"test_zarr_avg_merger",
"test_perceptual_loss",
"test_ultrasound_confidence_map_transform",
"test_vista3d_utils",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_morphological_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from parameterized import parameterized

from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t
from monai.transforms.utils_morphological_ops import dilate, erode, get_morphological_filter_result_t
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS_SHAPE = []
Expand Down
Loading

0 comments on commit f848002

Please sign in to comment.