From cae979a426cc003a282ce756afc788757d6d08e7 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Thu, 15 Aug 2024 11:19:25 -0400 Subject: [PATCH 01/32] Add vista3d inferers Signed-off-by: heyufan1995 --- monai/apps/vista3d/inferer.py | 199 +++++++++++++++++++++++++++++++++ monai/inferers/utils.py | 2 + monai/networks/nets/vista3d.py | 20 ++++ 3 files changed, 221 insertions(+) create mode 100644 monai/apps/vista3d/inferer.py diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py new file mode 100644 index 0000000000..d7ffd6e0e8 --- /dev/null +++ b/monai/apps/vista3d/inferer.py @@ -0,0 +1,199 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import Any + +import copy +import torch + +from monai.data.meta_tensor import MetaTensor +from monai.utils import ( + optional_import +) + +tqdm, _ = optional_import("tqdm", name="tqdm") + +__all__ = ["point_based_window_inferer"] + +def point_based_window_inferer( + inputs: torch.Tensor | MetaTensor, + roi_size: Sequence[int] | int, + predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], + point_coords: torch.Tensor | None = None, + point_labels: torch.Tensor | None = None, + class_vector: torch.Tensor | None = None, + prompt_class: torch.Tensor | None = None, + prev_mask: torch.Tensor | None = None, + point_start: int = 0, + **kwargs: Any, +): + """Point based window inferer, crop a patch centered at the point, and perform inference. Different patches are combined with gaussian weighted weights. + Args: + inputs: input image to be processed (assuming NCHW[D]) + roi_size: the spatial window size for inferences. + When its components have None or non-positives, the corresponding inputs dimension will be used. + if the components of the `roi_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + sw_batch_size: the batch size to run window slices. + predictor: partial(infer_wrapper, model). infer_wrapper transpose the model output. The model output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D] + point_coords: [B, N, 3] + point_labels: [B, N] + class_vector: [B] + prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID! + Returns: + stitched_output: [1, B, H, W, D]. The value is before sigmoid. + Notice: The function only supports SINGLE OBJECT INFERENCE with B=1. + """ + assert point_coords.shape[0] == 1, "Only supports single object point click" + image, pad = pad_previous_mask(copy.deepcopy(inputs), roi_size) + point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to( + point_coords.device + ) + prev_mask = ( + pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] + if prev_mask is not None + else None + ) + stitched_output = None + center_only = True + for p in point_coords[0][point_start:]: + lx_, rx_ = get_window_idx( + p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5 + ) + ly_, ry_ = get_window_idx( + p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5 + ) + lz_, rz_ = get_window_idx( + p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5 + ) + for i in range(len(lx_)): + for j in range(len(ly_)): + for k in range(len(lz_)): + lx, rx, ly, ry, lz, rz = ( + lx_[i], + rx_[i], + ly_[j], + ry_[j], + lz_[k], + rz_[k], + ) + unravel_slice = [ + slice(None), + slice(None), + slice(int(lx), int(rx)), + slice(int(ly), int(ry)), + slice(int(lz), int(rz)), + ] + batch_image = image[unravel_slice] + # ball = get_gaussian_ball(batch_image.shape[-3:]) + output = predictor( + batch_image, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + patch_coords=unravel_slice, + prev_mask=prev_mask, + **kwargs, + ) + if stitched_output is None: + stitched_output = torch.zeros( + [ + 1, + output.shape[1], + image.shape[-3], + image.shape[-2], + image.shape[-1], + ], + device="cpu", + ) + stitched_mask = torch.zeros( + [ + 1, + output.shape[1], + image.shape[-3], + image.shape[-2], + image.shape[-1], + ], + device="cpu", + ) + stitched_output[unravel_slice] += output.to("cpu") + stitched_mask[unravel_slice] = 1 + # if stitched_mask is 0, then NaN value + stitched_output = stitched_output / stitched_mask + # revert padding + stitched_output = stitched_output[ + :, + :, + pad[4] : image.shape[-3] - pad[5], + pad[2] : image.shape[-2] - pad[3], + pad[0] : image.shape[-1] - pad[1], + ] + stitched_mask = stitched_mask[ + :, + :, + pad[4] : image.shape[-3] - pad[5], + pad[2] : image.shape[-2] - pad[3], + pad[0] : image.shape[-1] - pad[1], + ] + if prev_mask is not None: + prev_mask = prev_mask[ + :, + :, + pad[4] : image.shape[-3] - pad[5], + pad[2] : image.shape[-2] - pad[3], + pad[0] : image.shape[-1] - pad[1], + ] + prev_mask = prev_mask.to("cpu") + # for un-calculated place, use previous mask + stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1] + if not hasattr(stitched_output, "meta"): + stitched_output = MetaTensor( + stitched_output, affine=inputs.meta["affine"], meta=inputs.meta + ) + return stitched_output + +def get_window_idx_c(p, roi, s): + if p - roi // 2 < 0: + left, right = 0, roi + elif p + roi // 2 > s: + left, right = s - roi, s + else: + left, right = int(p) - roi // 2, int(p) + roi // 2 + return left, right + + +def get_window_idx(p, roi, s, center_only=True, margin=5): + left, right = get_window_idx_c(p, roi, s) + if center_only: + return [left], [right] + left_most = max(0, p - roi + margin) + right_most = min(s, p + roi - margin) + left = [left_most, right_most - roi, left] + right = [left_most + roi, right_most, right] + return left, right + + +def pad_previous_mask(inputs, roi_size, padvalue=0): + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + if any(pad_size): + inputs = torch.nn.functional.pad( + inputs, pad=pad_size, mode="constant", value=padvalue + ) + return inputs, pad_size \ No newline at end of file diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index a080284e7c..5b02e62758 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -16,6 +16,7 @@ from typing import Any, Iterable import numpy as np +import copy import torch import torch.nn.functional as F @@ -300,6 +301,7 @@ def sliding_window_inference( # remove padding if image_size smaller than roi_size if any(pad_size): + kwargs.update({'pad_size': pad_size}) for ss, output_i in enumerate(output_image_list): zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] final_slicing: list[slice] = [] diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index fe7f93d493..d7b7cd15c4 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -77,6 +77,25 @@ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: self.point_freeze = False self.NINF_VALUE = -9999 self.PINF_VALUE = 9999 + + def update_slidingwindow_padding(self, pad_size: list | None, labels: torch.Tensor | None, prev_mask: torch.Tensor | None, point_coords: torch.Tensor | None): + """ Image has been padded by sliding window inferer. The related padding need to be performed outside of slidingwindow inferer. + Args: + pad_size: padding size passed from sliding window inferer. + labels: image label ground truth. + prev_mask: previous segmentation mask. + point_coords: point click coordinates. + """ + if pad_size is None: + return labels, prev_mask, point_coords + if labels is not None: + labels = F.pad(labels, pad=pad_size, mode='constant', val=0) + if prev_mask is not None: + prev_mask = F.pad(prev_mask, pad=pad_size, mode='constant', val=0) + if point_coords is not None: + point_coords = point_coords + torch.tensor([pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device) + return labels, prev_mask, point_coords + def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: """Get number of foreground classes based on class and point prompt.""" @@ -348,6 +367,7 @@ def forward( val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. """ + labels, prev_mask, point_coords = self.update_slidingwindow_padding(kwargs.get('pad_size', None), labels, prev_mask, point_coords) image_size = input_images.shape[-3:] device = input_images.device if point_coords is None and class_vector is None: From daf7b459e3f83d9a61de17a22d69f796ea87cfec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 15:21:18 +0000 Subject: [PATCH 02/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/inferer.py | 2 +- monai/inferers/utils.py | 1 - monai/networks/nets/vista3d.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index d7ffd6e0e8..6416830850 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -196,4 +196,4 @@ def pad_previous_mask(inputs, roi_size, padvalue=0): inputs = torch.nn.functional.pad( inputs, pad=pad_size, mode="constant", value=padvalue ) - return inputs, pad_size \ No newline at end of file + return inputs, pad_size diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 5b02e62758..8fe8129f0d 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -16,7 +16,6 @@ from typing import Any, Iterable import numpy as np -import copy import torch import torch.nn.functional as F diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index d7b7cd15c4..8ea75c0b8a 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -77,7 +77,7 @@ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: self.point_freeze = False self.NINF_VALUE = -9999 self.PINF_VALUE = 9999 - + def update_slidingwindow_padding(self, pad_size: list | None, labels: torch.Tensor | None, prev_mask: torch.Tensor | None, point_coords: torch.Tensor | None): """ Image has been padded by sliding window inferer. The related padding need to be performed outside of slidingwindow inferer. Args: From ee3ceaad311b9f2ff7d6915e1fb4a4c21cea98d8 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 16 Aug 2024 16:58:22 +0800 Subject: [PATCH 03/32] fix format issues Signed-off-by: Yiheng Wang --- docs/source/apps.rst | 4 + .../maisi/utils => vista3d}/__init__.py | 0 monai/apps/vista3d/inferer.py | 102 ++++++------------ monai/inferers/utils.py | 2 +- monai/networks/nets/vista3d.py | 28 +++-- 5 files changed, 55 insertions(+), 81 deletions(-) rename monai/apps/{generation/maisi/utils => vista3d}/__init__.py (100%) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 7fa7b9e9ff..da4c70177b 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -248,6 +248,10 @@ FastMRIReader ~~~~~~~~~~~~~ .. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj +`Vista3d` +--------- +.. autofunction:: monai.apps.vista3d.inferer.point_based_window_inferer + `Auto3DSeg` ----------- .. automodule:: monai.apps.auto3dseg diff --git a/monai/apps/generation/maisi/utils/__init__.py b/monai/apps/vista3d/__init__.py similarity index 100% rename from monai/apps/generation/maisi/utils/__init__.py rename to monai/apps/vista3d/__init__.py diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 6416830850..538fd64889 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -11,34 +11,36 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +import copy +from collections.abc import Sequence from typing import Any -import copy import torch from monai.data.meta_tensor import MetaTensor -from monai.utils import ( - optional_import -) +from monai.utils import optional_import tqdm, _ = optional_import("tqdm", name="tqdm") __all__ = ["point_based_window_inferer"] + def point_based_window_inferer( inputs: torch.Tensor | MetaTensor, - roi_size: Sequence[int] | int, - predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], - point_coords: torch.Tensor | None = None, - point_labels: torch.Tensor | None = None, + roi_size: Sequence[int], + predictor: torch.nn.Module, + point_coords: torch.Tensor, + point_labels: torch.Tensor, class_vector: torch.Tensor | None = None, prompt_class: torch.Tensor | None = None, prev_mask: torch.Tensor | None = None, point_start: int = 0, **kwargs: Any, -): - """Point based window inferer, crop a patch centered at the point, and perform inference. Different patches are combined with gaussian weighted weights. +) -> torch.Tensor: + """ + Point based window inferer, crop a patch centered at the point, and perform inference. + Different patches are combined with gaussian weighted weights. + Args: inputs: input image to be processed (assuming NCHW[D]) roi_size: the spatial window size for inferences. @@ -47,48 +49,30 @@ def point_based_window_inferer( corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. sw_batch_size: the batch size to run window slices. - predictor: partial(infer_wrapper, model). infer_wrapper transpose the model output. The model output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D] + predictor: partial(infer_wrapper, model). infer_wrapper transpose the model output. + The model output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. point_coords: [B, N, 3] point_labels: [B, N] class_vector: [B] - prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID! + prev_mask: [1, B, H, W, D]. The value is before sigmoid. Returns: stitched_output: [1, B, H, W, D]. The value is before sigmoid. Notice: The function only supports SINGLE OBJECT INFERENCE with B=1. """ assert point_coords.shape[0] == 1, "Only supports single object point click" image, pad = pad_previous_mask(copy.deepcopy(inputs), roi_size) - point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to( - point_coords.device - ) - prev_mask = ( - pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] - if prev_mask is not None - else None - ) + point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device) + prev_mask = pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None stitched_output = None center_only = True for p in point_coords[0][point_start:]: - lx_, rx_ = get_window_idx( - p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5 - ) - ly_, ry_ = get_window_idx( - p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5 - ) - lz_, rz_ = get_window_idx( - p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5 - ) + lx_, rx_ = get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5) + ly_, ry_ = get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5) + lz_, rz_ = get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5) for i in range(len(lx_)): for j in range(len(ly_)): for k in range(len(lz_)): - lx, rx, ly, ry, lz, rz = ( - lx_[i], - rx_[i], - ly_[j], - ry_[j], - lz_[k], - rz_[k], - ) + lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k]) unravel_slice = [ slice(None), slice(None), @@ -97,7 +81,6 @@ def point_based_window_inferer( slice(int(lz), int(rz)), ] batch_image = image[unravel_slice] - # ball = get_gaussian_ball(batch_image.shape[-3:]) output = predictor( batch_image, point_coords=point_coords, @@ -110,24 +93,10 @@ def point_based_window_inferer( ) if stitched_output is None: stitched_output = torch.zeros( - [ - 1, - output.shape[1], - image.shape[-3], - image.shape[-2], - image.shape[-1], - ], - device="cpu", + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" ) stitched_mask = torch.zeros( - [ - 1, - output.shape[1], - image.shape[-3], - image.shape[-2], - image.shape[-1], - ], - device="cpu", + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" ) stitched_output[unravel_slice] += output.to("cpu") stitched_mask[unravel_slice] = 1 @@ -135,18 +104,10 @@ def point_based_window_inferer( stitched_output = stitched_output / stitched_mask # revert padding stitched_output = stitched_output[ - :, - :, - pad[4] : image.shape[-3] - pad[5], - pad[2] : image.shape[-2] - pad[3], - pad[0] : image.shape[-1] - pad[1], + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] ] stitched_mask = stitched_mask[ - :, - :, - pad[4] : image.shape[-3] - pad[5], - pad[2] : image.shape[-2] - pad[3], - pad[0] : image.shape[-1] - pad[1], + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] ] if prev_mask is not None: prev_mask = prev_mask[ @@ -156,15 +117,14 @@ def point_based_window_inferer( pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1], ] - prev_mask = prev_mask.to("cpu") + prev_mask = prev_mask.to("cpu") # type: ignore # for un-calculated place, use previous mask stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1] if not hasattr(stitched_output, "meta"): - stitched_output = MetaTensor( - stitched_output, affine=inputs.meta["affine"], meta=inputs.meta - ) + stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta) # type: ignore return stitched_output + def get_window_idx_c(p, roi, s): if p - roi // 2 < 0: left, right = 0, roi @@ -193,7 +153,5 @@ def pad_previous_mask(inputs, roi_size, padvalue=0): half = diff // 2 pad_size.extend([half, diff - half]) if any(pad_size): - inputs = torch.nn.functional.pad( - inputs, pad=pad_size, mode="constant", value=padvalue - ) + inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) return inputs, pad_size diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 8fe8129f0d..bd99765348 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -300,7 +300,7 @@ def sliding_window_inference( # remove padding if image_size smaller than roi_size if any(pad_size): - kwargs.update({'pad_size': pad_size}) + kwargs.update({"pad_size": pad_size}) for ss, output_i in enumerate(output_image_list): zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] final_slicing: list[slice] = [] diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 8ea75c0b8a..3b9b2a10fa 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -78,8 +78,17 @@ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: self.NINF_VALUE = -9999 self.PINF_VALUE = 9999 - def update_slidingwindow_padding(self, pad_size: list | None, labels: torch.Tensor | None, prev_mask: torch.Tensor | None, point_coords: torch.Tensor | None): - """ Image has been padded by sliding window inferer. The related padding need to be performed outside of slidingwindow inferer. + def update_slidingwindow_padding( + self, + pad_size: list | None, + labels: torch.Tensor | None, + prev_mask: torch.Tensor | None, + point_coords: torch.Tensor | None, + ): + """ + Image has been padded by sliding window inferer. + The related padding need to be performed outside of slidingwindow inferer. + Args: pad_size: padding size passed from sliding window inferer. labels: image label ground truth. @@ -89,14 +98,15 @@ def update_slidingwindow_padding(self, pad_size: list | None, labels: torch.Tens if pad_size is None: return labels, prev_mask, point_coords if labels is not None: - labels = F.pad(labels, pad=pad_size, mode='constant', val=0) + labels = F.pad(labels, pad=pad_size, mode="constant", value=0) if prev_mask is not None: - prev_mask = F.pad(prev_mask, pad=pad_size, mode='constant', val=0) + prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0) if point_coords is not None: - point_coords = point_coords + torch.tensor([pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device) + point_coords = point_coords + torch.tensor( + [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device + ) return labels, prev_mask, point_coords - def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: """Get number of foreground classes based on class and point prompt.""" if class_vector is None: @@ -348,7 +358,7 @@ def forward( point_coords: [B, N, 3] point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class. 2/3 means negative/postive ponits for special supported class like tumor. - class_vector: [B, 1], the global class index + class_vector: [B, 1], the global class index. prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if the points are for zero-shot or supported class. When class_vector and point_coords are both provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] @@ -367,7 +377,9 @@ def forward( val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. """ - labels, prev_mask, point_coords = self.update_slidingwindow_padding(kwargs.get('pad_size', None), labels, prev_mask, point_coords) + labels, prev_mask, point_coords = self.update_slidingwindow_padding( + kwargs.get("pad_size", None), labels, prev_mask, point_coords + ) image_size = input_images.shape[-3:] device = input_images.device if point_coords is None and class_vector is None: From 4a069d77f75bcef0080db0aa47adb1a890d5fb3f Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 16 Aug 2024 17:42:29 +0800 Subject: [PATCH 04/32] add tests Signed-off-by: Yiheng Wang --- monai/apps/vista3d/inferer.py | 40 ++++++++++------- tests/test_point_based_window_inferer.py | 56 ++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 16 deletions(-) create mode 100644 tests/test_point_based_window_inferer.py diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 538fd64889..7702b167f7 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -33,7 +33,7 @@ def point_based_window_inferer( point_labels: torch.Tensor, class_vector: torch.Tensor | None = None, prompt_class: torch.Tensor | None = None, - prev_mask: torch.Tensor | None = None, + prev_mask: torch.Tensor | MetaTensor | None = None, point_start: int = 0, **kwargs: Any, ) -> torch.Tensor: @@ -59,16 +59,17 @@ def point_based_window_inferer( stitched_output: [1, B, H, W, D]. The value is before sigmoid. Notice: The function only supports SINGLE OBJECT INFERENCE with B=1. """ - assert point_coords.shape[0] == 1, "Only supports single object point click" - image, pad = pad_previous_mask(copy.deepcopy(inputs), roi_size) + if not point_coords.shape[0] == 1: + raise ValueError("Only supports single object point click.") + image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size) point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device) - prev_mask = pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None + prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None stitched_output = None center_only = True for p in point_coords[0][point_start:]: - lx_, rx_ = get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5) - ly_, ry_ = get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5) - lz_, rz_ = get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5) + lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5) + ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5) + lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5) for i in range(len(lx_)): for j in range(len(ly_)): for k in range(len(lz_)): @@ -120,12 +121,15 @@ def point_based_window_inferer( prev_mask = prev_mask.to("cpu") # type: ignore # for un-calculated place, use previous mask stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1] + if isinstance(inputs, torch.Tensor): + inputs = MetaTensor(inputs) if not hasattr(stitched_output, "meta"): - stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta) # type: ignore + stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta) return stitched_output -def get_window_idx_c(p, roi, s): +def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]: + """Helper function to get the window index.""" if p - roi // 2 < 0: left, right = 0, roi elif p + roi // 2 > s: @@ -135,23 +139,27 @@ def get_window_idx_c(p, roi, s): return left, right -def get_window_idx(p, roi, s, center_only=True, margin=5): - left, right = get_window_idx_c(p, roi, s) +def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]: + """Get the window index.""" + left, right = _get_window_idx_c(p, roi, s) if center_only: return [left], [right] left_most = max(0, p - roi + margin) right_most = min(s, p + roi - margin) - left = [left_most, right_most - roi, left] - right = [left_most + roi, right_most, right] - return left, right + left_list = [left_most, right_most - roi, left] + right_list = [left_most + roi, right_most, right] + return left_list, right_list -def pad_previous_mask(inputs, roi_size, padvalue=0): +def _pad_previous_mask( + inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0 +) -> tuple[torch.Tensor | MetaTensor, list[int]]: + """Helper function to pad inputs.""" pad_size = [] for k in range(len(inputs.shape) - 1, 1, -1): diff = max(roi_size[k - 2] - inputs.shape[k], 0) half = diff // 2 pad_size.extend([half, diff - half]) if any(pad_size): - inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) + inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) # type: ignore return inputs, pad_size diff --git a/tests/test_point_based_window_inferer.py b/tests/test_point_based_window_inferer.py new file mode 100644 index 0000000000..5ffef37319 --- /dev/null +++ b/tests/test_point_based_window_inferer.py @@ -0,0 +1,56 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.inferer import point_based_window_inferer +from monai.networks import eval_mode +from monai.networks.nets.vista3d import vista3d132 +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick + +device = "cuda" if torch.cuda.is_available() else "cpu" + +_, has_tqdm = optional_import("tqdm") + +TEST_CASES = [ + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, + ] +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick +class TestPointBasedWindowInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vista3d(self, vista3d_params, inputs_shape, inferer_params): + vista3d = vista3d132(**vista3d_params).to(device) + with eval_mode(vista3d): + inferer_params["predictor"] = vista3d + inferer_params["inputs"] = torch.randn(*inputs_shape).to(device) + stitched_output = point_based_window_inferer(**inferer_params) + self.assertEqual(stitched_output.shape, inputs_shape) + + +if __name__ == "__main__": + unittest.main() From 519154e2d7de803528491066d1e5a1970da60cb6 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 16 Aug 2024 20:03:06 +0800 Subject: [PATCH 05/32] add vista3d transforms Signed-off-by: Yiheng Wang --- monai/apps/vista3d/transforms.py | 156 +++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 monai/apps/vista3d/transforms.py diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py new file mode 100644 index 0000000000..92b4e5756c --- /dev/null +++ b/monai/apps/vista3d/transforms.py @@ -0,0 +1,156 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence + +import numpy as np + +from monai.config import DtypeLike, KeysCollection +from monai.transforms import MapLabelValue +from monai.transforms.transform import MapTransform +from monai.utils import look_up_option + + +def _get_name_to_index_mapping(labels_dict: dict | None) -> dict: + """get the label name to index mapping""" + name_to_index_mapping = {} + if labels_dict is not None: + name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()} + return name_to_index_mapping + + +def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None: + """convert the label name to index""" + if label_prompt is not None and isinstance(label_prompt, list): + converted_label_prompt = [] + # for new class, add to the mapping + for l in label_prompt: + if isinstance(l, str) and not l.isdigit(): + if l.lower() not in name_to_index_mapping: + name_to_index_mapping[l.lower()] = len(name_to_index_mapping) + for l in label_prompt: + if isinstance(l, (int, str)): + converted_label_prompt.append( + name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l) + ) + else: + converted_label_prompt.append(l) + return converted_label_prompt + return label_prompt + + +class VistaPreTransform(MapTransform): + def __init__( + self, + keys: KeysCollection, + allow_missing_keys: bool = False, + special_index: Sequence[int] = (25, 26, 27, 28, 29, 117), + labels_dict: dict | None = None, + subclass: dict | None = None, + ) -> None: + """ + Pre-transform for Vista3d. + + Args: + keys: keys of the corresponding items to be transformed. + dataset_transforms: a dictionary specifies the transform for corresponding dataset: + key: dataset name, value: list of data transforms. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + special_index: the class index that need to be handled differently. + """ + super().__init__(keys, allow_missing_keys) + self.special_index = special_index + self.subclass = subclass + self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict) + + def __call__(self, data): + label_prompt = data.get("label_prompt", None) + point_labels = data.get("point_labels", None) + # convert the label name to index if needed + label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt) + try: + # The evaluator will check prompt. The invalid prompt will be skipped here and captured by evaluator. + if self.subclass is not None and label_prompt is not None: + _label_prompt = [] + subclass_keys = list(map(int, self.subclass.keys())) + for i in range(len(label_prompt)): + if label_prompt[i] in subclass_keys: + _label_prompt.extend(self.subclass[str(label_prompt[i])]) + else: + _label_prompt.append(label_prompt[i]) + data["label_prompt"] = _label_prompt + + if label_prompt is not None and point_labels is not None: + if label_prompt[0] in self.special_index: + point_labels = np.array(point_labels) + point_labels[point_labels == 0] = 2 + point_labels[point_labels == 1] = 3 + point_labels = point_labels.tolist() + data["point_labels"] = point_labels + except Exception: + pass + + return data + + +class RelabelD(MapTransform): + def __init__( + self, + keys: KeysCollection, + label_mappings: dict[str, list[tuple[int, int]]], + dtype: DtypeLike = np.int16, + dataset_key: str = "dataset_name", + allow_missing_keys: bool = False, + ) -> None: + """ + Remap the voxel labels in the input data dictionary based on the specified mapping. + + This list of local -> global label mappings will be applied to each input `data[keys]`. + if `data[dataset_key]` is not in `label_mappings`, label_mappings['default']` will be used. + if `label_mappings[data[dataset_key]]` is None, no relabeling will be performed. + + Args: + keys: keys of the corresponding items to be transformed. + label_mappings: a dictionary specifies how local dataset class indices are mapped to the + global class indices, format: + key: dataset name. + value: list of (local label, global label) pairs. This list of local -> global label mappings + will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`, + label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None, + no relabeling will be performed. + set `label_mappings={}` to completely skip this transform. + dtype: convert the output data to dtype, default to float32. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.mappers = {} + self.dataset_key = dataset_key + for name, mapping in label_mappings.items(): + self.mappers[name] = MapLabelValue( + orig_labels=[int(pair[0]) for pair in mapping], + target_labels=[int(pair[1]) for pair in mapping], + dtype=dtype, + ) + + def __call__(self, data): + d = dict(data) + dataset_name = d.get(self.dataset_key, "default") + _m = look_up_option(dataset_name, self.mappers, default=None) + if _m is None: + return d + for key in self.key_iterator(d): + d[key] = _m(d[key]) + return d From 4b579e13943288c3e1e0bc00cf87b4137e4f5ba7 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 16 Aug 2024 20:47:32 +0800 Subject: [PATCH 06/32] update inputs doc string Signed-off-by: Yiheng Wang --- monai/apps/vista3d/inferer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 7702b167f7..cbd3777b80 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -42,7 +42,7 @@ def point_based_window_inferer( Different patches are combined with gaussian weighted weights. Args: - inputs: input image to be processed (assuming NCHW[D]) + inputs: [1CHWD], input image to be processed. roi_size: the spatial window size for inferences. When its components have None or non-positives, the corresponding inputs dimension will be used. if the components of the `roi_size` are non-positive values, the transform will use the From 42d29bedbc5a2ced443c361a6fa318263f2a01c6 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Fri, 16 Aug 2024 19:51:51 -0400 Subject: [PATCH 07/32] Add transforms Signed-off-by: heyufan1995 --- monai/apps/vista3d/transforms.py | 50 +++++++++++ monai/networks/nets/vista3d.py | 9 +- monai/transforms/utils.py | 52 +++++++++++- tests/test_vista3d_transforms.py | 138 +++++++++++++++++++++++++++++++ tests/test_vista3d_utils.py | 6 +- 5 files changed, 245 insertions(+), 10 deletions(-) create mode 100644 tests/test_vista3d_transforms.py diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 92b4e5756c..0b7eed03e5 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -14,10 +14,15 @@ from typing import Sequence import numpy as np +import torch + +from collections.abc import Hashable, Mapping from monai.config import DtypeLike, KeysCollection +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.transforms import MapLabelValue from monai.transforms.transform import MapTransform +from monai.transforms.utils import keep_components_with_positive_points from monai.utils import look_up_option @@ -103,6 +108,51 @@ def __call__(self, data): return data +class VistaPostTransform(MapTransform): + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + dataset_transforms: a dictionary specifies the transform for corresponding dataset: + key: dataset name, value: list of data transforms. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + for keys in self.keys: + if keys in data: + pred = data[keys] + object_num = pred.shape[0] + device = pred.device + if data.get("label_prompt", None) is None and data.get("points", None) is not None: + pred = keep_components_with_positive_points( + pred.unsqueeze(0), + point_coords=data.get("points").to(device), + point_labels=data.get("point_labels").to(device), + )[0] + pred[pred < 0] = 0.0 + # if it's multichannel, perform argmax + if object_num > 1: + # concate background channel. Make sure user did not provide 0 as prompt. + is_bk = torch.all(pred <= 0, dim=0, keepdim=True) + pred = pred.argmax(0).unsqueeze(0).float() + 1.0 + pred[is_bk] = 0.0 + else: + # AsDiscrete will remove NaN + # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred) + pred[pred > 0] = 1.0 + if "label_prompt" in data and data["label_prompt"] is not None: + pred += 0.5 # inplace mapping to avoid cloning pred + for i in range(1, object_num + 1): + frac = i + 0.5 + pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype) + pred[pred == 0.5] = 0.0 + data[keys] = pred + return data + class RelabelD(MapTransform): def __init__( diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 3b9b2a10fa..e6ba506806 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -23,7 +23,7 @@ from monai.networks.blocks import MLPBlock, UnetrBasicBlock from monai.networks.nets import SegResNetDS2 from monai.transforms.utils import convert_points_to_disc -from monai.transforms.utils import get_largest_connected_component_mask_point as lcc +from monai.transforms.utils import keep_merge_components_with_points as lcc from monai.transforms.utils import sample_points_from_label from monai.utils import optional_import, unsqueeze_left, unsqueeze_right @@ -346,6 +346,7 @@ def forward( prev_mask: torch.Tensor | None = None, radius: int | None = None, val_point_sampler: Callable | None = None, + transpose: bool = False, **kwargs, ): """ @@ -375,7 +376,8 @@ def forward( radius: single float value controling the gaussian blur when combining point and auto results. The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. - + transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from + sliding window inferer/point inferer. """ labels, prev_mask, point_coords = self.update_slidingwindow_padding( kwargs.get("pad_size", None), labels, prev_mask, point_coords @@ -456,9 +458,10 @@ def forward( point_labels, # type: ignore mapping_index, ) - if kwargs.get("keep_cache", False) and class_vector is None: self.image_embeddings = out.detach() + if transpose: + logits = logits.transpose(1, 0) return logits diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 363fce91be..f235055e30 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -107,7 +107,8 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", - "get_largest_connected_component_mask_point", + "keep_merge_components_with_points", + "keep_components_with_positive_points", "convert_points_to_disc", "remove_small_objects", "img_bounds", @@ -1178,7 +1179,7 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] -def get_largest_connected_component_mask_point( +def keep_merge_components_with_points( img_pos: NdarrayTensor, img_neg: NdarrayTensor, point_coords: NdarrayTensor, @@ -1188,8 +1189,8 @@ def get_largest_connected_component_mask_point( margins: int = 3, ) -> NdarrayTensor: """ - Gets the connected component of img_pos and img_neg that include the positive points and - negative points separately. The function is used for combining automatic results with interactive + Keep connected regions of img_pos and img_neg that include the positive points and + negative points separately. The function is used for merging automatic results with interactive results in VISTA3D. Args: @@ -1199,6 +1200,7 @@ def get_largest_connected_component_mask_point( neg_val: negative point label values. point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points. point_labels: the label of each point, shape [B, N]. + margins: include points outside of the region but within the margin. """ cucim_skimage, has_cucim = optional_import("cucim.skimage") @@ -1248,6 +1250,48 @@ def get_largest_connected_component_mask_point( outs[outs > 1] = 1 return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] +def keep_components_with_positive_points( + img: NdarrayTensor, + point_coords: NdarrayTensor, + point_labels: NdarrayTensor) -> NdarrayTensor: + """ + Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove + regions without positive points. + Args: + img: [1, B, H, W, D]. Output prediction from VISTA3D. Value is before sigmoid and contain NaN value. + point_coords: [B, N, 3]. Point click coordinates + point_labels: [B, N]. Point click labels. + """ + outs = torch.zeros_like(img) + for c in range(len(point_coords)): + if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()): + # skip if no positive points. + continue + coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist() + not_nan_mask = ~torch.isnan(img[0, c]) + img_ = torch.nan_to_num(img[0, c] > 0, 0) + img_, *_ = convert_data_type(img_, np.ndarray) + label = measure.label + features = label(img_, connectivity=3) + pos_mask = torch.from_numpy(img_).to(img.device) > 0 + # if num features less than max desired, nothing to do. + features = torch.from_numpy(features).to(img.device) + # generate a map with all pos points + idx = [] + for p in coords: + idx.append(features[round(p[0]), round(p[1]), round(p[2])].item()) + idx = list(set(idx)) + for i in idx: + if i == 0: + continue + outs[0, c] += features == i + outs = outs > 0 + # find negative mean value + fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean() + img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in + return img + + def convert_points_to_disc( image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py new file mode 100644 index 0000000000..0b42d0c128 --- /dev/null +++ b/tests/test_vista3d_transforms.py @@ -0,0 +1,138 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.utils import ( + convert_points_to_disc, + keep_merge_components_with_points, + sample_points_from_label, +) +from monai.apps.vista3d import ( + VistaPreTransform, + VistaPostTransform, + RelabelD +) +from monai.utils import min_version +from monai.utils.module import optional_import +from tests.utils import skip_if_no_cuda, skip_if_quick + +cp, has_cp = optional_import("cupy") +cucim_skimage, has_cucim = optional_import("cucim.skimage") +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +TESTS_SAMPLE_POINTS_FROM_LABEL = [] +for use_center in [True, False]: + labels = torch.zeros(1, 1, 32, 32, 32) + labels[0, 0, 5:10, 5:10, 5:10] = 1 + labels[0, 0, 10:15, 10:15, 10:15] = 3 + labels[0, 0, 20:25, 20:25, 20:25] = 5 + TESTS_SAMPLE_POINTS_FROM_LABEL.append( + [{"labels": labels, "label_set": (1, 3, 5), "use_center": use_center}, (3, 1, 3), (3, 1)] + ) + +TEST_CONVERT_POINTS_TO_DISC = [] +for radius in [1, 2]: + for disc in [True, False]: + image_size = (32, 32, 32) + point = torch.randn(3, 1, 3) + point_label = torch.randint(0, 4, (3, 1)) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + +TEST_LCC_MASK_POINT_TORCH = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 128, 32, 32) + TEST_LCC_MASK_POINT_TORCH.append( + [ + { + "img_pos": torch.randint(0, 2, shape, dtype=torch.bool), + "img_neg": torch.randint(0, 2, shape, dtype=torch.bool), + "point_coords": torch.randint(0, 10, (bs, num_points, 3)), + "point_labels": torch.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + +TEST_LCC_MASK_POINT_NP = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 32, 32, 64) + TEST_LCC_MASK_POINT_NP.append( + [ + { + "img_pos": np.random.randint(0, 2, shape, dtype=bool), + "img_neg": np.random.randint(0, 2, shape, dtype=bool), + "point_coords": np.random.randint(0, 5, (bs, num_points, 3)), + "point_labels": np.random.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + + +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") +class TestSamplePointsFromLabel(unittest.TestCase): + + @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL) + def test_shape(self, input_data, expected_point_shape, expected_point_label_shape): + point, point_label = sample_points_from_label(**input_data) + self.assertEqual(point.shape, expected_point_shape) + self.assertEqual(point_label.shape, expected_point_label_shape) + + +class TestConvertPointsToDisc(unittest.TestCase): + + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC) + def test_shape(self, input_data, expected_shape): + result = convert_points_to_disc(**input_data) + self.assertEqual(result.shape, expected_shape) + + +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") +class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): + + @skip_if_quick + @skip_if_no_cuda + @skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_TORCH) + def test_cp_shape(self, input_data, shape): + for key in input_data: + input_data[key] = input_data[key].to(device) + mask = keep_merge_components_with_points(**input_data) + self.assertEqual(mask.shape, shape) + + @skipUnless(has_measure, "skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_NP) + def test_np_shape(self, input_data, shape): + mask = keep_merge_components_with_points(**input_data) + self.assertEqual(mask.shape, shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index a940854d88..93d5757682 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -20,7 +20,7 @@ from monai.transforms.utils import ( convert_points_to_disc, - get_largest_connected_component_mask_point, + keep_merge_components_with_points, sample_points_from_label, ) from monai.utils import min_version @@ -119,13 +119,13 @@ class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): def test_cp_shape(self, input_data, shape): for key in input_data: input_data[key] = input_data[key].to(device) - mask = get_largest_connected_component_mask_point(**input_data) + mask = keep_merge_components_with_points(**input_data) self.assertEqual(mask.shape, shape) @skipUnless(has_measure, "skimage required") @parameterized.expand(TEST_LCC_MASK_POINT_NP) def test_np_shape(self, input_data, shape): - mask = get_largest_connected_component_mask_point(**input_data) + mask = keep_merge_components_with_points(**input_data) self.assertEqual(mask.shape, shape) From 4144b75621f8299016b33a6c1d1483167e69c319 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Aug 2024 23:52:18 +0000 Subject: [PATCH 08/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/transforms.py | 4 ++-- monai/networks/nets/vista3d.py | 2 +- monai/transforms/utils.py | 6 +++--- tests/test_vista3d_transforms.py | 5 ----- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 0b7eed03e5..d9f51cd6c2 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -19,7 +19,7 @@ from collections.abc import Hashable, Mapping from monai.config import DtypeLike, KeysCollection -from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.transforms import MapLabelValue from monai.transforms.transform import MapTransform from monai.transforms.utils import keep_components_with_positive_points @@ -152,7 +152,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N pred[pred == 0.5] = 0.0 data[keys] = pred return data - + class RelabelD(MapTransform): def __init__( diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index e6ba506806..9148e36542 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -377,7 +377,7 @@ def forward( The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from - sliding window inferer/point inferer. + sliding window inferer/point inferer. """ labels, prev_mask, point_coords = self.update_slidingwindow_padding( kwargs.get("pad_size", None), labels, prev_mask, point_coords diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index f235055e30..1f9988310f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1251,8 +1251,8 @@ def keep_merge_components_with_points( return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] def keep_components_with_positive_points( - img: NdarrayTensor, - point_coords: NdarrayTensor, + img: NdarrayTensor, + point_coords: NdarrayTensor, point_labels: NdarrayTensor) -> NdarrayTensor: """ Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove @@ -1265,7 +1265,7 @@ def keep_components_with_positive_points( outs = torch.zeros_like(img) for c in range(len(point_coords)): if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()): - # skip if no positive points. + # skip if no positive points. continue coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist() not_nan_mask = ~torch.isnan(img[0, c]) diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index 0b42d0c128..93d5757682 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -23,11 +23,6 @@ keep_merge_components_with_points, sample_points_from_label, ) -from monai.apps.vista3d import ( - VistaPreTransform, - VistaPostTransform, - RelabelD -) from monai.utils import min_version from monai.utils.module import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick From 656f212fe9f3bb8b17fb1dd85c26f6ee5f03db8f Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Sun, 18 Aug 2024 13:01:01 -0400 Subject: [PATCH 09/32] Add test Signed-off-by: heyufan1995 --- monai/apps/vista3d/transforms.py | 17 ++-- tests/test_vista3d_transforms.py | 161 +++++++++++++------------------ 2 files changed, 74 insertions(+), 104 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index d9f51cd6c2..91ad7d83da 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -64,15 +64,17 @@ def __init__( subclass: dict | None = None, ) -> None: """ - Pre-transform for Vista3d. - + Pre-transform for Vista3d. It performs two functionalities: + 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), + convert point labels from 0,1 to 2,3. + 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. + e.g. "lung" label is converted to ["left lung", "right lung"] Args: - keys: keys of the corresponding items to be transformed. - dataset_transforms: a dictionary specifies the transform for corresponding dataset: - key: dataset name, value: list of data transforms. - dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. allow_missing_keys: don't raise exception if key is missing. - special_index: the class index that need to be handled differently. + special_index: the class index that need to be handled differently. If label_prompt is within special index, + the point label will be converted from 0,1 to 2, 3 for negative/positive points. + subclass: if label_prompt is in subclass keys, the label_prompt will be converted to the subclasses defined in the dict. """ super().__init__(keys, allow_missing_keys) self.special_index = special_index @@ -95,7 +97,6 @@ def __call__(self, data): else: _label_prompt.append(label_prompt[i]) data["label_prompt"] = _label_prompt - if label_prompt is not None and point_labels is not None: if label_prompt[0] in self.special_index: point_labels = np.array(point_labels) diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index 93d5757682..e30a07b062 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -18,10 +18,10 @@ import torch from parameterized import parameterized -from monai.transforms.utils import ( - convert_points_to_disc, - keep_merge_components_with_points, - sample_points_from_label, +from monai.apps.vista3d.transforms import ( + VistaPostTransform, + VistaPreTransform, + RelabelD ) from monai.utils import min_version from monai.utils.module import optional_import @@ -34,99 +34,68 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -TESTS_SAMPLE_POINTS_FROM_LABEL = [] -for use_center in [True, False]: - labels = torch.zeros(1, 1, 32, 32, 32) - labels[0, 0, 5:10, 5:10, 5:10] = 1 - labels[0, 0, 10:15, 10:15, 10:15] = 3 - labels[0, 0, 20:25, 20:25, 20:25] = 5 - TESTS_SAMPLE_POINTS_FROM_LABEL.append( - [{"labels": labels, "label_set": (1, 3, 5), "use_center": use_center}, (3, 1, 3), (3, 1)] - ) - -TEST_CONVERT_POINTS_TO_DISC = [] -for radius in [1, 2]: - for disc in [True, False]: - image_size = (32, 32, 32) - point = torch.randn(3, 1, 3) - point_label = torch.randint(0, 4, (3, 1)) - expected_shape = (point.shape[0], 2, *image_size) - TEST_CONVERT_POINTS_TO_DISC.append( - [ - {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, - expected_shape, - ] +TEST_VISTA_PRETRANSFORM = [ + [ + {"label_prompt":[1], "points": [[0,0,0]], "point_labels": [1]}, + {"label_prompt":[1], "points": [[0,0,0]], "point_labels": [3]}, + ], + [ + {"label_prompt":[2], "points": [[0,0,0]], "point_labels": [0]}, + {"label_prompt":[2], "points": [[0,0,0]], "point_labels": [2]}, + ], + [ + {"label_prompt":[3], "points": [[0,0,0]], "point_labels": [0]}, + {"label_prompt":[4,5], "points": [[0,0,0]], "point_labels": [0]}, + ], + [ + {"label_prompt":[6], "points": [[0,0,0]], "point_labels": [0]}, + {"label_prompt":[7,8], "points": [[0,0,0]], "point_labels": [0]}, + ] +] + + +pred1 = torch.zeros([2,64,64,64]) +pred1[0,:10,:10,:10] = 1 +pred1[1,20:30,20:30,20:30] = 1 +output1 = torch.zeros([1,64,64,64]) +output1[:,:10,:10,:10] = 2 +output1[:,20:30,20:30,20:30] = 3 + +pred2 = torch.zeros([1,64,64,64]) +pred2[:,:10,:10,:10] = 1 +pred2[:,20:30,20:30,20:30] = 1 +output2 = torch.zeros([1,64,64,64]) +output2[:,20:30,20:30,20:30] = 1 + +TEST_VISTA_POSTTRANSFORM = [ + [ + {"pred":pred1, "label_prompt":torch.tensor([2,3])}, + output1 + ], + [ + {"pred":pred2, "points": torch.tensor([[25,25,25]]), "point_labels": torch.tensor([1])}, + output2 + ] +] + + +class TestVistaPreTransform(unittest.TestCase): + @parameterized.expand(TEST_VISTA_PRETRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPreTransform( + keys="image", + subclass = {"3": [4, 5], "6": [7,8]}, + special_index = [1, 2] ) - -TEST_LCC_MASK_POINT_TORCH = [] -for bs in [1, 2]: - for num_points in [1, 3]: - shape = (bs, 1, 128, 32, 32) - TEST_LCC_MASK_POINT_TORCH.append( - [ - { - "img_pos": torch.randint(0, 2, shape, dtype=torch.bool), - "img_neg": torch.randint(0, 2, shape, dtype=torch.bool), - "point_coords": torch.randint(0, 10, (bs, num_points, 3)), - "point_labels": torch.randint(0, 4, (bs, num_points)), - }, - shape, - ] - ) - -TEST_LCC_MASK_POINT_NP = [] -for bs in [1, 2]: - for num_points in [1, 3]: - shape = (bs, 1, 32, 32, 64) - TEST_LCC_MASK_POINT_NP.append( - [ - { - "img_pos": np.random.randint(0, 2, shape, dtype=bool), - "img_neg": np.random.randint(0, 2, shape, dtype=bool), - "point_coords": np.random.randint(0, 5, (bs, num_points, 3)), - "point_labels": np.random.randint(0, 4, (bs, num_points)), - }, - shape, - ] - ) - - -@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") -class TestSamplePointsFromLabel(unittest.TestCase): - - @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL) - def test_shape(self, input_data, expected_point_shape, expected_point_label_shape): - point, point_label = sample_points_from_label(**input_data) - self.assertEqual(point.shape, expected_point_shape) - self.assertEqual(point_label.shape, expected_point_label_shape) - - -class TestConvertPointsToDisc(unittest.TestCase): - - @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC) - def test_shape(self, input_data, expected_shape): - result = convert_points_to_disc(**input_data) - self.assertEqual(result.shape, expected_shape) - - -@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") -class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): - - @skip_if_quick - @skip_if_no_cuda - @skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required") - @parameterized.expand(TEST_LCC_MASK_POINT_TORCH) - def test_cp_shape(self, input_data, shape): - for key in input_data: - input_data[key] = input_data[key].to(device) - mask = keep_merge_components_with_points(**input_data) - self.assertEqual(mask.shape, shape) - - @skipUnless(has_measure, "skimage required") - @parameterized.expand(TEST_LCC_MASK_POINT_NP) - def test_np_shape(self, input_data, shape): - mask = keep_merge_components_with_points(**input_data) - self.assertEqual(mask.shape, shape) + result = transform(input_data) + self.assertEqual(result, expected) + +class TestVistaPostTransform(unittest.TestCase): + @parameterized.expand(TEST_VISTA_POSTTRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPostTransform(keys="pred") + result = transform(input_data) + self.assertEqual((result['pred'] == expected).all(), True) if __name__ == "__main__": From 8b647fde5befad3dc03365847983e3cc52693d2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Aug 2024 17:01:29 +0000 Subject: [PATCH 10/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/transforms.py | 12 ++++++------ tests/test_vista3d_transforms.py | 10 +++------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 91ad7d83da..5522b9c300 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -64,15 +64,15 @@ def __init__( subclass: dict | None = None, ) -> None: """ - Pre-transform for Vista3d. It performs two functionalities: - 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), - convert point labels from 0,1 to 2,3. + Pre-transform for Vista3d. It performs two functionalities: + 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), + convert point labels from 0,1 to 2,3. 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. - e.g. "lung" label is converted to ["left lung", "right lung"] + e.g. "lung" label is converted to ["left lung", "right lung"] Args: - keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. + keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. allow_missing_keys: don't raise exception if key is missing. - special_index: the class index that need to be handled differently. If label_prompt is within special index, + special_index: the class index that need to be handled differently. If label_prompt is within special index, the point label will be converted from 0,1 to 2, 3 for negative/positive points. subclass: if label_prompt is in subclass keys, the label_prompt will be converted to the subclasses defined in the dict. """ diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index e30a07b062..b4fef1f406 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -12,20 +12,16 @@ from __future__ import annotations import unittest -from unittest.case import skipUnless -import numpy as np import torch from parameterized import parameterized from monai.apps.vista3d.transforms import ( VistaPostTransform, - VistaPreTransform, - RelabelD + VistaPreTransform ) from monai.utils import min_version from monai.utils.module import optional_import -from tests.utils import skip_if_no_cuda, skip_if_quick cp, has_cp = optional_import("cupy") cucim_skimage, has_cucim = optional_import("cucim.skimage") @@ -74,8 +70,8 @@ ], [ {"pred":pred2, "points": torch.tensor([[25,25,25]]), "point_labels": torch.tensor([1])}, - output2 - ] + output2 + ] ] From 757ea61a2bcf6a52560212991e388a16724c3fbf Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Sun, 18 Aug 2024 14:12:56 -0400 Subject: [PATCH 11/32] Add more test Signed-off-by: heyufan1995 --- monai/apps/vista3d/inferer.py | 5 ++++- tests/test_point_based_window_inferer.py | 21 +++++++++++++++++++++ tests/test_vista3d_transforms.py | 3 ++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index cbd3777b80..78eeafa951 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -53,7 +53,10 @@ def point_based_window_inferer( The model output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. point_coords: [B, N, 3] point_labels: [B, N] - class_vector: [B] + class_vector: [B]. Used for class-head automatic segmentation. Can be None value. + prompt_class: [B]. The same as class_vector representing the point class and inform point head about + supported class or zeroshot, not used for automatic segmentation. If None, point head is default + to supported class segmentation. prev_mask: [1, B, H, W, D]. The value is before sigmoid. Returns: stitched_output: [1, B, H, W, D]. The value is before sigmoid. diff --git a/tests/test_point_based_window_inferer.py b/tests/test_point_based_window_inferer.py index 5ffef37319..a730337052 100644 --- a/tests/test_point_based_window_inferer.py +++ b/tests/test_point_based_window_inferer.py @@ -35,6 +35,27 @@ "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), "point_labels": torch.tensor([[1, 0]], device=device), }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device) + }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device), + "point_start": 1 + }, ] ] diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index b4fef1f406..ed9bf3e3dc 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -57,7 +57,8 @@ output1[:,:10,:10,:10] = 2 output1[:,20:30,20:30,20:30] = 3 -pred2 = torch.zeros([1,64,64,64]) +# -1 is needed since pred should be before sigmoid. +pred2 = torch.zeros([1,64,64,64]) - 1 pred2[:,:10,:10,:10] = 1 pred2[:,20:30,20:30,20:30] = 1 output2 = torch.zeros([1,64,64,64]) From f44e9df69509093ce4ce3ea967b6ba230b24fa5d Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 19 Aug 2024 15:26:57 +0800 Subject: [PATCH 12/32] fix issues Signed-off-by: Yiheng Wang --- docs/source/apps.rst | 9 ++- monai/apps/vista3d/transforms.py | 19 ++++--- monai/transforms/utils.py | 11 ++-- tests/min_tests.py | 1 + tests/test_point_based_window_inferer.py | 6 +- tests/test_vista3d_transforms.py | 71 +++++++++++------------- tests/test_vista3d_utils.py | 8 +-- 7 files changed, 64 insertions(+), 61 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index da4c70177b..734dc5517e 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -250,7 +250,14 @@ FastMRIReader `Vista3d` --------- -.. autofunction:: monai.apps.vista3d.inferer.point_based_window_inferer +.. automodule:: monai.apps.vista3d.inferer +.. autofunction:: point_based_window_inferer + +.. automodule:: monai.apps.vista3d.transforms +.. autoclass:: VistaPreTransform + :members: +.. autoclass:: VistaPostTransform + :members: `Auto3DSeg` ----------- diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 5522b9c300..13306b3bce 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -16,10 +16,7 @@ import numpy as np import torch -from collections.abc import Hashable, Mapping - from monai.config import DtypeLike, KeysCollection -from monai.config.type_definitions import NdarrayOrTensor from monai.transforms import MapLabelValue from monai.transforms.transform import MapTransform from monai.transforms.utils import keep_components_with_positive_points @@ -64,11 +61,14 @@ def __init__( subclass: dict | None = None, ) -> None: """ - Pre-transform for Vista3d. It performs two functionalities: - 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), - convert point labels from 0,1 to 2,3. + Pre-transform for Vista3d. + + It performs two functionalities: + 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), + convert point labels from 0,1 to 2,3. 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. - e.g. "lung" label is converted to ["left lung", "right lung"] + e.g. "lung" label is converted to ["left lung", "right lung"] + Args: keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. allow_missing_keys: don't raise exception if key is missing. @@ -109,9 +109,12 @@ def __call__(self, data): return data + class VistaPostTransform(MapTransform): def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ + Post-transform for Vista3d. + Args: keys: keys of the corresponding items to be transformed. dataset_transforms: a dictionary specifies the transform for corresponding dataset: @@ -122,7 +125,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No """ super().__init__(keys, allow_missing_keys) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + def __call__(self, data): for keys in self.keys: if keys in data: pred = data[keys] diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1f9988310f..305460ec9e 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1250,10 +1250,10 @@ def keep_merge_components_with_points( outs[outs > 1] = 1 return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] + def keep_components_with_positive_points( - img: NdarrayTensor, - point_coords: NdarrayTensor, - point_labels: NdarrayTensor) -> NdarrayTensor: + img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor +) -> torch.Tensor: """ Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove regions without positive points. @@ -1262,6 +1262,8 @@ def keep_components_with_positive_points( point_coords: [B, N, 3]. Point click coordinates point_labels: [B, N]. Point click labels. """ + if not has_measure: + raise RuntimeError("skimage.measure required.") outs = torch.zeros_like(img) for c in range(len(point_coords)): if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()): @@ -1270,7 +1272,7 @@ def keep_components_with_positive_points( coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist() not_nan_mask = ~torch.isnan(img[0, c]) img_ = torch.nan_to_num(img[0, c] > 0, 0) - img_, *_ = convert_data_type(img_, np.ndarray) + img_, *_ = convert_data_type(img_, np.ndarray) # type: ignore label = measure.label features = label(img_, connectivity=3) pos_mask = torch.from_numpy(img_).to(img.device) > 0 @@ -1292,7 +1294,6 @@ def keep_components_with_positive_points( return img - def convert_points_to_disc( image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False ): diff --git a/tests/min_tests.py b/tests/min_tests.py index 479c4c8dc2..f80d06f5d3 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -210,6 +210,7 @@ def run_testsuit(): "test_perceptual_loss", "test_ultrasound_confidence_map_transform", "test_vista3d_utils", + "test_vista3d_transforms", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_point_based_window_inferer.py b/tests/test_point_based_window_inferer.py index a730337052..1b293288c4 100644 --- a/tests/test_point_based_window_inferer.py +++ b/tests/test_point_based_window_inferer.py @@ -43,7 +43,7 @@ "roi_size": [32, 32, 32], "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), "point_labels": torch.tensor([[1, 0]], device=device), - "class_vector": torch.tensor([1], device=device) + "class_vector": torch.tensor([1], device=device), }, ], [ @@ -54,9 +54,9 @@ "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), "point_labels": torch.tensor([[1, 0]], device=device), "class_vector": torch.tensor([1], device=device), - "point_start": 1 + "point_start": 1, }, - ] + ], ] diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index ed9bf3e3dc..38ae64c341 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -12,19 +12,15 @@ from __future__ import annotations import unittest +from unittest.case import skipUnless import torch from parameterized import parameterized -from monai.apps.vista3d.transforms import ( - VistaPostTransform, - VistaPreTransform -) +from monai.apps.vista3d.transforms import VistaPostTransform, VistaPreTransform from monai.utils import min_version from monai.utils.module import optional_import -cp, has_cp = optional_import("cupy") -cucim_skimage, has_cucim = optional_import("cucim.skimage") measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -32,67 +28,66 @@ TEST_VISTA_PRETRANSFORM = [ [ - {"label_prompt":[1], "points": [[0,0,0]], "point_labels": [1]}, - {"label_prompt":[1], "points": [[0,0,0]], "point_labels": [3]}, + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [1]}, + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [3]}, ], [ - {"label_prompt":[2], "points": [[0,0,0]], "point_labels": [0]}, - {"label_prompt":[2], "points": [[0,0,0]], "point_labels": [2]}, + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [2]}, ], [ - {"label_prompt":[3], "points": [[0,0,0]], "point_labels": [0]}, - {"label_prompt":[4,5], "points": [[0,0,0]], "point_labels": [0]}, + {"label_prompt": [3], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [4, 5], "points": [[0, 0, 0]], "point_labels": [0]}, ], [ - {"label_prompt":[6], "points": [[0,0,0]], "point_labels": [0]}, - {"label_prompt":[7,8], "points": [[0,0,0]], "point_labels": [0]}, - ] + {"label_prompt": [6], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [7, 8], "points": [[0, 0, 0]], "point_labels": [0]}, + ], ] -pred1 = torch.zeros([2,64,64,64]) -pred1[0,:10,:10,:10] = 1 -pred1[1,20:30,20:30,20:30] = 1 -output1 = torch.zeros([1,64,64,64]) -output1[:,:10,:10,:10] = 2 -output1[:,20:30,20:30,20:30] = 3 +pred1 = torch.zeros([2, 64, 64, 64]) +pred1[0, :10, :10, :10] = 1 +pred1[1, 20:30, 20:30, 20:30] = 1 +output1 = torch.zeros([1, 64, 64, 64]) +output1[:, :10, :10, :10] = 2 +output1[:, 20:30, 20:30, 20:30] = 3 # -1 is needed since pred should be before sigmoid. -pred2 = torch.zeros([1,64,64,64]) - 1 -pred2[:,:10,:10,:10] = 1 -pred2[:,20:30,20:30,20:30] = 1 -output2 = torch.zeros([1,64,64,64]) -output2[:,20:30,20:30,20:30] = 1 +pred2 = torch.zeros([1, 64, 64, 64]) - 1 +pred2[:, :10, :10, :10] = 1 +pred2[:, 20:30, 20:30, 20:30] = 1 +output2 = torch.zeros([1, 64, 64, 64]) +output2[:, 20:30, 20:30, 20:30] = 1 TEST_VISTA_POSTTRANSFORM = [ + [{"pred": pred1, "label_prompt": torch.tensor([2, 3]).to(device)}, output1], [ - {"pred":pred1, "label_prompt":torch.tensor([2,3])}, - output1 + { + "pred": pred2, + "points": torch.tensor([[25, 25, 25]]).to(device), + "point_labels": torch.tensor([1]).to(device), + }, + output2, ], - [ - {"pred":pred2, "points": torch.tensor([[25,25,25]]), "point_labels": torch.tensor([1])}, - output2 - ] ] class TestVistaPreTransform(unittest.TestCase): @parameterized.expand(TEST_VISTA_PRETRANSFORM) def test_result(self, input_data, expected): - transform = VistaPreTransform( - keys="image", - subclass = {"3": [4, 5], "6": [7,8]}, - special_index = [1, 2] - ) + transform = VistaPreTransform(keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2]) result = transform(input_data) self.assertEqual(result, expected) + +@skipUnless(has_measure, "skimage.measure required") class TestVistaPostTransform(unittest.TestCase): @parameterized.expand(TEST_VISTA_POSTTRANSFORM) def test_result(self, input_data, expected): transform = VistaPostTransform(keys="pred") result = transform(input_data) - self.assertEqual((result['pred'] == expected).all(), True) + self.assertEqual((result["pred"] == expected).all(), True) if __name__ == "__main__": diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index 93d5757682..6fe61682fb 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -18,11 +18,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils import ( - convert_points_to_disc, - keep_merge_components_with_points, - sample_points_from_label, -) +from monai.transforms.utils import convert_points_to_disc, keep_merge_components_with_points, sample_points_from_label from monai.utils import min_version from monai.utils.module import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick @@ -110,7 +106,7 @@ def test_shape(self, input_data, expected_shape): @skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") -class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): +class TestKeepMergeComponentsWithPoints(unittest.TestCase): @skip_if_quick @skip_if_no_cuda From 81a09841daf6359d974948f0d87c8d9e6d12847c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 07:27:22 +0000 Subject: [PATCH 13/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 13306b3bce..d890997b2a 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -62,9 +62,9 @@ def __init__( ) -> None: """ Pre-transform for Vista3d. - + It performs two functionalities: - 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), + 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), convert point labels from 0,1 to 2,3. 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. e.g. "lung" label is converted to ["left lung", "right lung"] From 255a96e9ee02b4b6e1ed752e9ead2ae038f79168 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Mon, 19 Aug 2024 10:04:11 -0400 Subject: [PATCH 14/32] Change docstring Signed-off-by: heyufan1995 --- monai/apps/vista3d/inferer.py | 27 +++++++++++++++++---------- monai/apps/vista3d/transforms.py | 10 ++++++---- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 78eeafa951..6dd47a2cea 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -35,11 +35,14 @@ def point_based_window_inferer( prompt_class: torch.Tensor | None = None, prev_mask: torch.Tensor | MetaTensor | None = None, point_start: int = 0, + center_only: bool = True, + margin: int = 5, **kwargs: Any, ) -> torch.Tensor: """ - Point based window inferer, crop a patch centered at the point, and perform inference. - Different patches are combined with gaussian weighted weights. + Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. + The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by + patch inference and average output stitching, and finally returns the segmented mask. Args: inputs: [1CHWD], input image to be processed. @@ -49,15 +52,20 @@ def point_based_window_inferer( corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. sw_batch_size: the batch size to run window slices. - predictor: partial(infer_wrapper, model). infer_wrapper transpose the model output. - The model output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. - point_coords: [B, N, 3] - point_labels: [B, N] + predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. + Add transpose=True in kwargs for vista3d. + point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points. + point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes. + 2/3 means negative/positive points for special supported classes (e.g. tumor, vessel). class_vector: [B]. Used for class-head automatic segmentation. Can be None value. prompt_class: [B]. The same as class_vector representing the point class and inform point head about supported class or zeroshot, not used for automatic segmentation. If None, point head is default to supported class segmentation. prev_mask: [1, B, H, W, D]. The value is before sigmoid. + point_start: only use points starting from this number. All points before this number is used to generate + prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask. + center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point. + margin: if center_only is false, this value is the distance between point to the patch boundary. Returns: stitched_output: [1, B, H, W, D]. The value is before sigmoid. Notice: The function only supports SINGLE OBJECT INFERENCE with B=1. @@ -68,11 +76,10 @@ def point_based_window_inferer( point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device) prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None stitched_output = None - center_only = True for p in point_coords[0][point_start:]: - lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5) - ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5) - lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5) + lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin) + ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin) + lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin) for i in range(len(lx_)): for j in range(len(ly_)): for k in range(len(lz_)): diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index d890997b2a..3b89bc9872 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -15,6 +15,7 @@ import numpy as np import torch +import warnings from monai.config import DtypeLike, KeysCollection from monai.transforms import MapLabelValue @@ -22,6 +23,7 @@ from monai.transforms.utils import keep_components_with_positive_points from monai.utils import look_up_option +__all__ = ["VistaPreTransform", "VistaPostTransformd", "RelabelD"] def _get_name_to_index_mapping(labels_dict: dict | None) -> dict: """get the label name to index mapping""" @@ -65,7 +67,7 @@ def __init__( It performs two functionalities: 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), - convert point labels from 0,1 to 2,3. + convert point labels from 0 (negative), 1 (positive) to special 2 (negative),3 (positive). 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. e.g. "lung" label is converted to ["left lung", "right lung"] @@ -73,7 +75,7 @@ def __init__( keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. allow_missing_keys: don't raise exception if key is missing. special_index: the class index that need to be handled differently. If label_prompt is within special index, - the point label will be converted from 0,1 to 2, 3 for negative/positive points. + the point label will be converted from 0, 1 to 2, 3 for negative/positive points. subclass: if label_prompt is in subclass keys, the label_prompt will be converted to the subclasses defined in the dict. """ super().__init__(keys, allow_missing_keys) @@ -105,12 +107,12 @@ def __call__(self, data): point_labels = point_labels.tolist() data["point_labels"] = point_labels except Exception: - pass + warnings.warn("VistaPreTransform failed to transform label prompt or point labels.") return data -class VistaPostTransform(MapTransform): +class VistaPostTransformd(MapTransform): def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ Post-transform for Vista3d. From c9979d7391ee5cc57343ecaf89c64408fd2e565d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 14:04:57 +0000 Subject: [PATCH 15/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/inferer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 6dd47a2cea..8f4aa8bd53 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -40,8 +40,8 @@ def point_based_window_inferer( **kwargs: Any, ) -> torch.Tensor: """ - Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. - The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by + Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. + The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by patch inference and average output stitching, and finally returns the segmented mask. Args: @@ -53,7 +53,7 @@ def point_based_window_inferer( to `(32, 64)` if the second spatial dimension size of img is `64`. sw_batch_size: the batch size to run window slices. predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. - Add transpose=True in kwargs for vista3d. + Add transpose=True in kwargs for vista3d. point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points. point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes. 2/3 means negative/positive points for special supported classes (e.g. tumor, vessel). From 72aa299ad386e8a46da9a81b3f6413b92b7c02de Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 20 Aug 2024 09:56:06 +0800 Subject: [PATCH 16/32] resolve comments Signed-off-by: Yiheng Wang --- docs/source/apps.rst | 6 ++++-- monai/apps/vista3d/transforms.py | 11 ++++++----- tests/test_vista3d_transforms.py | 10 +++++----- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 734dc5517e..9a6bf09c4b 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -254,9 +254,11 @@ FastMRIReader .. autofunction:: point_based_window_inferer .. automodule:: monai.apps.vista3d.transforms -.. autoclass:: VistaPreTransform +.. autoclass:: VistaPreTransformd :members: -.. autoclass:: VistaPostTransform +.. autoclass:: VistaPostTransformd + :members: +.. autoclass:: Relabeld :members: `Auto3DSeg` diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 3b89bc9872..3987e4ccf2 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -11,11 +11,11 @@ from __future__ import annotations +import warnings from typing import Sequence import numpy as np import torch -import warnings from monai.config import DtypeLike, KeysCollection from monai.transforms import MapLabelValue @@ -23,7 +23,8 @@ from monai.transforms.utils import keep_components_with_positive_points from monai.utils import look_up_option -__all__ = ["VistaPreTransform", "VistaPostTransformd", "RelabelD"] +__all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"] + def _get_name_to_index_mapping(labels_dict: dict | None) -> dict: """get the label name to index mapping""" @@ -53,7 +54,7 @@ def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | Non return label_prompt -class VistaPreTransform(MapTransform): +class VistaPreTransformd(MapTransform): def __init__( self, keys: KeysCollection, @@ -107,7 +108,7 @@ def __call__(self, data): point_labels = point_labels.tolist() data["point_labels"] = point_labels except Exception: - warnings.warn("VistaPreTransform failed to transform label prompt or point labels.") + warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.") return data @@ -160,7 +161,7 @@ def __call__(self, data): return data -class RelabelD(MapTransform): +class Relabeld(MapTransform): def __init__( self, keys: KeysCollection, diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index 38ae64c341..4a647a8d1e 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.apps.vista3d.transforms import VistaPostTransform, VistaPreTransform +from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd from monai.utils import min_version from monai.utils.module import optional_import @@ -73,19 +73,19 @@ ] -class TestVistaPreTransform(unittest.TestCase): +class TestVistaPreTransformd(unittest.TestCase): @parameterized.expand(TEST_VISTA_PRETRANSFORM) def test_result(self, input_data, expected): - transform = VistaPreTransform(keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2]) + transform = VistaPreTransformd(keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2]) result = transform(input_data) self.assertEqual(result, expected) @skipUnless(has_measure, "skimage.measure required") -class TestVistaPostTransform(unittest.TestCase): +class TestVistaPostTransformd(unittest.TestCase): @parameterized.expand(TEST_VISTA_POSTTRANSFORM) def test_result(self, input_data, expected): - transform = VistaPostTransform(keys="pred") + transform = VistaPostTransformd(keys="pred") result = transform(input_data) self.assertEqual((result["pred"] == expected).all(), True) From fe29e863ae26e3886470bbc710fdf2f5c9e17ed8 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 20 Aug 2024 13:52:35 +0800 Subject: [PATCH 17/32] fix doc issue Signed-off-by: Yiheng Wang --- monai/apps/vista3d/transforms.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 3987e4ccf2..846c4bc35d 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -180,13 +180,11 @@ def __init__( Args: keys: keys of the corresponding items to be transformed. label_mappings: a dictionary specifies how local dataset class indices are mapped to the - global class indices, format: - key: dataset name. - value: list of (local label, global label) pairs. This list of local -> global label mappings - will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`, - label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None, - no relabeling will be performed. - set `label_mappings={}` to completely skip this transform. + global class indices. The dictionary keys are dataset names and the values are lists of + list of (local label, global label) pairs. This list of local -> global label mappings + will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`, + label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None, + no relabeling will be performed. Please set `label_mappings={}` to completely skip this transform. dtype: convert the output data to dtype, default to float32. dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". allow_missing_keys: don't raise exception if key is missing. From b1e1822c71cd6aad0a5b78cfb7e6c8bd1eaff7fe Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 20 Aug 2024 10:57:58 -0400 Subject: [PATCH 18/32] Address docstring issue --- monai/apps/vista3d/inferer.py | 6 +++--- monai/apps/vista3d/transforms.py | 14 +++++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 8f4aa8bd53..b9db0818bc 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -40,8 +40,8 @@ def point_based_window_inferer( **kwargs: Any, ) -> torch.Tensor: """ - Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. - The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by + Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. + The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by patch inference and average output stitching, and finally returns the segmented mask. Args: @@ -61,7 +61,7 @@ def point_based_window_inferer( prompt_class: [B]. The same as class_vector representing the point class and inform point head about supported class or zeroshot, not used for automatic segmentation. If None, point head is default to supported class segmentation. - prev_mask: [1, B, H, W, D]. The value is before sigmoid. + prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks. point_start: only use points starting from this number. All points before this number is used to generate prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask. center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point. diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 846c4bc35d..f9157033ad 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -70,8 +70,9 @@ def __init__( 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), convert point labels from 0 (negative), 1 (positive) to special 2 (negative),3 (positive). 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. - e.g. "lung" label is converted to ["left lung", "right lung"] - + e.g. "lung" label is converted to ["left lung", "right lung"]. + The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, where each element is a int values of length [B, N]. + Args: keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. allow_missing_keys: don't raise exception if key is missing. @@ -108,6 +109,9 @@ def __call__(self, data): point_labels = point_labels.tolist() data["point_labels"] = point_labels except Exception: + # There is specific requirements for `label_prompt` and `point_labels`. + # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None. + # Those formatting errors should be captured later. warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.") return data @@ -116,7 +120,11 @@ def __call__(self, data): class VistaPostTransformd(MapTransform): def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: """ - Post-transform for Vista3d. + Post-transform for Vista3d. It converts the model output logits into final segmentation masks. + If `label_prompt` is None, the output will be thresholded to be sequential indexes [0,1,2,...], + else the indexes will be [0, label_prompt[0], label_prompt[1], ...]. + If `label_prompt` is None while `points` are provided, the model will perform postprocess to remove + regions that does not contain positive points. Args: keys: keys of the corresponding items to be transformed. From 879f1f82482491b10bdf2a45c92953c3fae6fb64 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:59:06 +0000 Subject: [PATCH 19/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/inferer.py | 4 ++-- monai/apps/vista3d/transforms.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index b9db0818bc..6c3372550c 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -40,8 +40,8 @@ def point_based_window_inferer( **kwargs: Any, ) -> torch.Tensor: """ - Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. - The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by + Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. + The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by patch inference and average output stitching, and finally returns the segmented mask. Args: diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index f9157033ad..77202f4718 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -72,7 +72,7 @@ def __init__( 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. e.g. "lung" label is converted to ["left lung", "right lung"]. The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, where each element is a int values of length [B, N]. - + Args: keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. allow_missing_keys: don't raise exception if key is missing. @@ -109,9 +109,9 @@ def __call__(self, data): point_labels = point_labels.tolist() data["point_labels"] = point_labels except Exception: - # There is specific requirements for `label_prompt` and `point_labels`. - # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None. - # Those formatting errors should be captured later. + # There is specific requirements for `label_prompt` and `point_labels`. + # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None. + # Those formatting errors should be captured later. warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.") return data @@ -124,7 +124,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No If `label_prompt` is None, the output will be thresholded to be sequential indexes [0,1,2,...], else the indexes will be [0, label_prompt[0], label_prompt[1], ...]. If `label_prompt` is None while `points` are provided, the model will perform postprocess to remove - regions that does not contain positive points. + regions that does not contain positive points. Args: keys: keys of the corresponding items to be transformed. From cb7446b5db21dee259cd7c5b522ef2d9bf521965 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 20 Aug 2024 11:09:59 -0400 Subject: [PATCH 20/32] Update docstring Signed-off-by: heyufan1995 --- monai/apps/vista3d/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index f9157033ad..c5d471ee70 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -65,20 +65,20 @@ def __init__( ) -> None: """ Pre-transform for Vista3d. - It performs two functionalities: 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), convert point labels from 0 (negative), 1 (positive) to special 2 (negative),3 (positive). 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. e.g. "lung" label is converted to ["left lung", "right lung"]. The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, where each element is a int values of length [B, N]. - + Args: keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. allow_missing_keys: don't raise exception if key is missing. special_index: the class index that need to be handled differently. If label_prompt is within special index, the point label will be converted from 0, 1 to 2, 3 for negative/positive points. subclass: if label_prompt is in subclass keys, the label_prompt will be converted to the subclasses defined in the dict. + """ super().__init__(keys, allow_missing_keys) self.special_index = special_index From 144b753adaeb2e1efec115b3a93e92c8c9b6c26e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:12:36 +0000 Subject: [PATCH 21/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 1c85a1a6fa..23155157ec 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -78,7 +78,7 @@ def __init__( special_index: the class index that need to be handled differently. If label_prompt is within special index, the point label will be converted from 0, 1 to 2, 3 for negative/positive points. subclass: if label_prompt is in subclass keys, the label_prompt will be converted to the subclasses defined in the dict. - + """ super().__init__(keys, allow_missing_keys) self.special_index = special_index From 7127c109720dee8d55180233ee9995534931fed2 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 20 Aug 2024 11:56:22 -0400 Subject: [PATCH 22/32] Update docstring Signed-off-by: heyufan1995 --- monai/apps/vista3d/transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 1c85a1a6fa..fb147f3ebc 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -70,7 +70,8 @@ def __init__( convert point labels from 0 (negative), 1 (positive) to special 2 (negative),3 (positive). 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. e.g. "lung" label is converted to ["left lung", "right lung"]. - The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, where each element is a int values of length [B, N]. + The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, + where each element is a int values of length [B, N]. Args: keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. From 2e21d89091e79a78547f8f0f5eab5ed867dc5f5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:58:18 +0000 Subject: [PATCH 23/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index adae3d94c9..169d9efa59 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -70,7 +70,7 @@ def __init__( convert point labels from 0 (negative), 1 (positive) to special 2 (negative),3 (positive). 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. e.g. "lung" label is converted to ["left lung", "right lung"]. - The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, + The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, where each element is a int values of length [B, N]. Args: From c0e20b57f036334bcb7869f323d383768ff83141 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Wed, 21 Aug 2024 11:30:58 +0800 Subject: [PATCH 24/32] fix doc issue Signed-off-by: Yiheng Wang --- monai/apps/vista3d/transforms.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 169d9efa59..972c94d39d 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -65,21 +65,23 @@ def __init__( ) -> None: """ Pre-transform for Vista3d. + It performs two functionalities: - 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), - convert point labels from 0 (negative), 1 (positive) to special 2 (negative),3 (positive). - 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. - e.g. "lung" label is converted to ["left lung", "right lung"]. - The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, - where each element is a int values of length [B, N]. + + 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), + convert point labels from 0 (negative), 1 (positive) to special 2 (negative), 3 (positive). + + 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. + e.g. "lung" label is converted to ["left lung", "right lung"]. + + The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, + where each element is an int value of length [B, N]. Args: - keys: keys of the corresponding items to be transformed. Not used by the transform but kept here for formatting. + keys: keys of the corresponding items to be transformed. + special_index: the index that defines the special class. + subclass: a dictionary that maps a label prompt to its subclasses. allow_missing_keys: don't raise exception if key is missing. - special_index: the class index that need to be handled differently. If label_prompt is within special index, - the point label will be converted from 0, 1 to 2, 3 for negative/positive points. - subclass: if label_prompt is in subclass keys, the label_prompt will be converted to the subclasses defined in the dict. - """ super().__init__(keys, allow_missing_keys) self.special_index = special_index From c8e1e4450ff2a6ce47c0e8b9dc26810aca63a0ee Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Wed, 21 Aug 2024 17:00:01 -0400 Subject: [PATCH 25/32] Add generate_prompt_pairs Signed-off-by: heyufan1995 --- monai/apps/vista3d/workflow.py | 164 +++++++++++++++++++++++++++++++ tests/test_vista3d_transforms.py | 32 ++++++ 2 files changed, 196 insertions(+) create mode 100644 monai/apps/vista3d/workflow.py diff --git a/monai/apps/vista3d/workflow.py b/monai/apps/vista3d/workflow.py new file mode 100644 index 0000000000..688e79b7a4 --- /dev/null +++ b/monai/apps/vista3d/workflow.py @@ -0,0 +1,164 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random + +import monai +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from collections.abc import Callable, Sequence + +__all__ = ["generate_prompt_pairs"] + +ENABLE_SPECIAL = True +SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128) +MERGE_LIST = { + 1: [25, 26], # hepatic tumor and vessel merge into liver + 4: [24], # pancreatic tumor merge into pancreas + 132: [57], # overlap with trachea merge into airway +} + + +def get_point_label(id): + # [B, N] + if id in SPECIAL_INDEX and ENABLE_SPECIAL: + return 2, 3 + else: + return 0, 1 + +def generate_prompt_pairs( + labels: Tensor, + label_set: Sequence[int] | None = None, + max_prompt: int | None = None, + max_foreprompt: int | None = None, + max_backprompt: int = 1, + max_point: int = 20, + include_background: bool = False, + drop_label_prob: float = 0.2, + drop_point_prob: float = 0.2, + point_sampler: Callable | None = None +): + """ Sample training pairs for VISTA3D training. + Args: + labels: [1, 1, H, W, D], ground truth labels. + label_set: the label list for the specific dataset. + max_prompt: int, max number of total prompt, including foreground and background. + max_foreprompt: int, max number of prompt from foreground. + max_backprompt: int, max number of prompt from background. + max_point: maximum number of points for each object. + include_background: if include label=0 into training prompt. May casue issue in partial label + trainig. + drop_label_prob: probablity to drop label prompt. + drop_point_prob: probablity to drop point prompt. + point_sampler: sampler to augment masks with supervoxel. + Returns: + label_prompt: [B, 1]. The classes used for training automatic segmentation + point: [B, N, 3]. The corresponding points for each class. Note that background label prompt + requires matching point as well ([0,0,0] is used). + point_label: [B, N]. The corresponding point labels for each point (negative or positive). + -1 is used for padding the background label prompt and will be ignored. + prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. + label_prompt can be None, and prompt_class is used to identify point classess. + """ + # class label number + assert labels.shape[0] == 1, "only support batch size 1" + labels = labels[0, 0] + device = labels.device + unique_labels = labels.unique().cpu().numpy().tolist() + if include_background: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set))) + else: + unique_labels = list( + set(unique_labels) - (set(unique_labels) - set(label_set)) - set([0]) + ) + background_labels = list(set(label_set) - set(unique_labels)) + # during training, balance background and foreground prompts + if max_backprompt is not None: + if len(background_labels) > max_backprompt: + random.shuffle(background_labels) + background_labels = background_labels[:max_backprompt] + + if max_foreprompt is not None: + if len(unique_labels) > max_foreprompt: + random.shuffle(unique_labels) + unique_labels = unique_labels[:max_foreprompt] + + if max_prompt is not None: + if len(unique_labels) + len(background_labels) > max_prompt: + if len(unique_labels) > max_prompt: + unique_labels = random.sample(unique_labels, max_prompt) + background_labels = [] + else: + background_labels = random.sample( + background_labels, max_prompt - len(unique_labels) + ) + _point = [] + _point_label = [] + # if use regular sampling + if point_sampler is None: + num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) + num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) + for id in unique_labels: + neg_id, pos_id = get_point_label(id) + plabels = labels == int(id) + nlabels = ~plabels + plabelpoints = torch.nonzero(plabels) + nlabelpoints = torch.nonzero(nlabels) + # final sampled positive points + num_pa = min(len(plabelpoints), num_p) + # final sampled negative points + num_na = min(len(nlabelpoints), num_n) + _point.append( + torch.stack( + random.choices(plabelpoints, k=num_pa) + random.choices(nlabelpoints, k=num_na) + + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na) + ) + ) + _point_label.append( + torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na) + ).to(device) + ) + for id in background_labels: + # pad the background labels + _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 + _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point + else: + _point, _point_label = point_sampler(unique_labels, Np=max_point, Nn=0) + for id in background_labels: + # pad the background labels + _point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0 + _point_label.append( + torch.zeros(len(_point_label[0])).to(device) - 1 + ) # -1 not a point + if len(unique_labels) == 0 and len(background_labels) == 0: + # if max_backprompt is 0 and len(unique_labels), there is no effective prompt and the iteration must + # be skipped. Handle this in trainer. + label_prompt, point, point_label, prompt_class = None, None, None, None + else: + label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long() + point = torch.stack(_point) + point_label = torch.stack(_point_label) + prompt_class = copy.deepcopy(label_prompt) + if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0: + label_prompt = None + # If label prompt is dropped, there is no need to pad with points with label -1. + pad = len(background_labels) + point = point[: len(point) - pad] + point_label = point_label[: len(point_label) - pad] + prompt_class = prompt_class[: len(prompt_class) - pad] + else: + if random.uniform(0, 1) < drop_point_prob: + point = None + point_label = None + return label_prompt, point, point_label, prompt_class diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index 4a647a8d1e..1763b76654 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -18,6 +18,8 @@ from parameterized import parameterized from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd +from monai.apps.vista3d.workflow import generate_prompt_pairs + from monai.utils import min_version from monai.utils.module import optional_import @@ -72,6 +74,28 @@ ], ] +label = torch.zeros([1,1,64,64,64]) +label[:, :, :10, :10, :10] = 1 +label[:, :, 20:30, 20:30, 20:30] = 2 +label[:, :, 30:40, 30:40, 30:40] = 3 +label1 = torch.zeros([1,1,64,64,64]) +TEST_VISTA_GENERATEPROMPT = [ + [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, + "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":0}, [4, 4, 4, 4] + ], + [{"labels": label, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, + "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":1}, [2, None, None, 2] + ], + [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, + "max_backprompt":1, "drop_label_prob":1, "drop_point_prob":0}, [None, 3, 3, 3] + ], + [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, + "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":1}, [1, None, None, 1] + ], + [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, + "max_backprompt":0, "drop_label_prob":0, "drop_point_prob":1}, [None, None, None, None] + ] +] class TestVistaPreTransformd(unittest.TestCase): @parameterized.expand(TEST_VISTA_PRETRANSFORM) @@ -90,5 +114,13 @@ def test_result(self, input_data, expected): self.assertEqual((result["pred"] == expected).all(), True) +class TestGeneratePrompt(unittest.TestCase): + @parameterized.expand(TEST_VISTA_GENERATEPROMPT) + def test_result(self, input_data, expected): + output = generate_prompt_pairs(**input_data) + result = [i.shape[0] if i is not None else None for i in output] + self.assertEqual(result, expected) + + if __name__ == "__main__": unittest.main() From 729d235a6281f0b9eb7379092422786446f512d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 21:00:29 +0000 Subject: [PATCH 26/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/workflow.py | 16 +++++++--------- tests/test_vista3d_transforms.py | 10 +++++----- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/monai/apps/vista3d/workflow.py b/monai/apps/vista3d/workflow.py index 688e79b7a4..3552a592e9 100644 --- a/monai/apps/vista3d/workflow.py +++ b/monai/apps/vista3d/workflow.py @@ -12,10 +12,8 @@ import copy import random -import monai import numpy as np import torch -import torch.nn.functional as F from torch import Tensor from collections.abc import Callable, Sequence @@ -36,7 +34,7 @@ def get_point_label(id): return 2, 3 else: return 0, 1 - + def generate_prompt_pairs( labels: Tensor, label_set: Sequence[int] | None = None, @@ -49,7 +47,7 @@ def generate_prompt_pairs( drop_point_prob: float = 0.2, point_sampler: Callable | None = None ): - """ Sample training pairs for VISTA3D training. + """ Sample training pairs for VISTA3D training. Args: labels: [1, 1, H, W, D], ground truth labels. label_set: the label list for the specific dataset. @@ -66,9 +64,9 @@ def generate_prompt_pairs( label_prompt: [B, 1]. The classes used for training automatic segmentation point: [B, N, 3]. The corresponding points for each class. Note that background label prompt requires matching point as well ([0,0,0] is used). - point_label: [B, N]. The corresponding point labels for each point (negative or positive). - -1 is used for padding the background label prompt and will be ignored. - prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. + point_label: [B, N]. The corresponding point labels for each point (negative or positive). + -1 is used for padding the background label prompt and will be ignored. + prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. label_prompt can be None, and prompt_class is used to identify point classess. """ # class label number @@ -80,7 +78,7 @@ def generate_prompt_pairs( unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set))) else: unique_labels = list( - set(unique_labels) - (set(unique_labels) - set(label_set)) - set([0]) + set(unique_labels) - (set(unique_labels) - set(label_set)) - {0} ) background_labels = list(set(label_set) - set(unique_labels)) # during training, balance background and foreground prompts @@ -143,7 +141,7 @@ def generate_prompt_pairs( ) # -1 not a point if len(unique_labels) == 0 and len(background_labels) == 0: # if max_backprompt is 0 and len(unique_labels), there is no effective prompt and the iteration must - # be skipped. Handle this in trainer. + # be skipped. Handle this in trainer. label_prompt, point, point_label, prompt_class = None, None, None, None else: label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long() diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index 1763b76654..1972220c63 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -80,19 +80,19 @@ label[:, :, 30:40, 30:40, 30:40] = 3 label1 = torch.zeros([1,1,64,64,64]) TEST_VISTA_GENERATEPROMPT = [ - [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, + [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":0}, [4, 4, 4, 4] ], - [{"labels": label, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, + [{"labels": label, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":1}, [2, None, None, 2] ], - [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, + [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, "max_backprompt":1, "drop_label_prob":1, "drop_point_prob":0}, [None, 3, 3, 3] ], - [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, + [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":1}, [1, None, None, 1] ], - [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, + [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, "max_backprompt":0, "drop_label_prob":0, "drop_point_prob":1}, [None, None, None, None] ] ] From f1686bcc387c056f97c7dd5b94a0d0e0eefccc5b Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 22 Aug 2024 15:48:05 +0800 Subject: [PATCH 27/32] fix issues Signed-off-by: Yiheng Wang --- docs/source/apps.rst | 3 + .../apps/vista3d/{workflow.py => sampler.py} | 91 ++++++++-------- tests/test_vista3d_sampler.py | 100 ++++++++++++++++++ tests/test_vista3d_transforms.py | 32 ------ 4 files changed, 149 insertions(+), 77 deletions(-) rename monai/apps/vista3d/{workflow.py => sampler.py} (71%) create mode 100644 tests/test_vista3d_sampler.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 9a6bf09c4b..cc4cea8c1e 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -261,6 +261,9 @@ FastMRIReader .. autoclass:: Relabeld :members: +.. automodule:: monai.apps.vista3d.sampler +.. autofunction:: sample_prompt_pairs + `Auto3DSeg` ----------- .. automodule:: monai.apps.auto3dseg diff --git a/monai/apps/vista3d/workflow.py b/monai/apps/vista3d/sampler.py similarity index 71% rename from monai/apps/vista3d/workflow.py rename to monai/apps/vista3d/sampler.py index 3552a592e9..2a78e6ea1b 100644 --- a/monai/apps/vista3d/workflow.py +++ b/monai/apps/vista3d/sampler.py @@ -9,15 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import copy import random +from collections.abc import Callable, Sequence import numpy as np import torch from torch import Tensor -from collections.abc import Callable, Sequence -__all__ = ["generate_prompt_pairs"] +__all__ = ["sample_prompt_pairs"] ENABLE_SPECIAL = True SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128) @@ -28,16 +30,16 @@ } -def get_point_label(id): - # [B, N] +def _get_point_label(id: int) -> tuple[int, int]: if id in SPECIAL_INDEX and ENABLE_SPECIAL: return 2, 3 else: return 0, 1 -def generate_prompt_pairs( + +def sample_prompt_pairs( labels: Tensor, - label_set: Sequence[int] | None = None, + label_set: Sequence[int], max_prompt: int | None = None, max_foreprompt: int | None = None, max_backprompt: int = 1, @@ -45,9 +47,11 @@ def generate_prompt_pairs( include_background: bool = False, drop_label_prob: float = 0.2, drop_point_prob: float = 0.2, - point_sampler: Callable | None = None -): - """ Sample training pairs for VISTA3D training. + point_sampler: Callable | None = None, +) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: + """ + Sample training pairs for VISTA3D training. + Args: labels: [1, 1, H, W, D], ground truth labels. label_set: the label list for the specific dataset. @@ -55,31 +59,30 @@ def generate_prompt_pairs( max_foreprompt: int, max number of prompt from foreground. max_backprompt: int, max number of prompt from background. max_point: maximum number of points for each object. - include_background: if include label=0 into training prompt. May casue issue in partial label - trainig. - drop_label_prob: probablity to drop label prompt. - drop_point_prob: probablity to drop point prompt. + include_background: if include label=0 into training prompt. May cause issue in partial label training. + drop_label_prob: probability to drop label prompt. + drop_point_prob: probability to drop point prompt. point_sampler: sampler to augment masks with supervoxel. + Returns: - label_prompt: [B, 1]. The classes used for training automatic segmentation - point: [B, N, 3]. The corresponding points for each class. Note that background label prompt - requires matching point as well ([0,0,0] is used). + label_prompt: [B, 1]. The classes used for training automatic segmentation. + point: [B, N, 3]. The corresponding points for each class. + Note that background label prompt requires matching point as well ([0,0,0] is used). point_label: [B, N]. The corresponding point labels for each point (negative or positive). - -1 is used for padding the background label prompt and will be ignored. + -1 is used for padding the background label prompt and will be ignored. prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. - label_prompt can be None, and prompt_class is used to identify point classess. + label_prompt can be None, and prompt_class is used to identify point classes. """ # class label number - assert labels.shape[0] == 1, "only support batch size 1" + if not labels.shape[0] == 1: + raise ValueError("only support batch size 1") labels = labels[0, 0] device = labels.device unique_labels = labels.unique().cpu().numpy().tolist() if include_background: unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set))) else: - unique_labels = list( - set(unique_labels) - (set(unique_labels) - set(label_set)) - {0} - ) + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0}) background_labels = list(set(label_set) - set(unique_labels)) # during training, balance background and foreground prompts if max_backprompt is not None: @@ -98,9 +101,7 @@ def generate_prompt_pairs( unique_labels = random.sample(unique_labels, max_prompt) background_labels = [] else: - background_labels = random.sample( - background_labels, max_prompt - len(unique_labels) - ) + background_labels = random.sample(background_labels, max_prompt - len(unique_labels)) _point = [] _point_label = [] # if use regular sampling @@ -108,7 +109,7 @@ def generate_prompt_pairs( num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) for id in unique_labels: - neg_id, pos_id = get_point_label(id) + neg_id, pos_id = _get_point_label(id) plabels = labels == int(id) nlabels = ~plabels plabelpoints = torch.nonzero(plabels) @@ -119,26 +120,26 @@ def generate_prompt_pairs( num_na = min(len(nlabelpoints), num_n) _point.append( torch.stack( - random.choices(plabelpoints, k=num_pa) + random.choices(nlabelpoints, k=num_na) + random.choices(plabelpoints, k=num_pa) + + random.choices(nlabelpoints, k=num_na) + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na) ) ) _point_label.append( - torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na) - ).to(device) + torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to( + device + ) ) - for id in background_labels: + for _ in background_labels: # pad the background labels _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point else: _point, _point_label = point_sampler(unique_labels, Np=max_point, Nn=0) - for id in background_labels: + for _ in background_labels: # pad the background labels _point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0 - _point_label.append( - torch.zeros(len(_point_label[0])).to(device) - 1 - ) # -1 not a point + _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1) # -1 not a point if len(unique_labels) == 0 and len(background_labels) == 0: # if max_backprompt is 0 and len(unique_labels), there is no effective prompt and the iteration must # be skipped. Handle this in trainer. @@ -148,15 +149,15 @@ def generate_prompt_pairs( point = torch.stack(_point) point_label = torch.stack(_point_label) prompt_class = copy.deepcopy(label_prompt) - if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0: - label_prompt = None - # If label prompt is dropped, there is no need to pad with points with label -1. - pad = len(background_labels) - point = point[: len(point) - pad] - point_label = point_label[: len(point_label) - pad] - prompt_class = prompt_class[: len(prompt_class) - pad] - else: - if random.uniform(0, 1) < drop_point_prob: - point = None - point_label = None + if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0: + label_prompt = None + # If label prompt is dropped, there is no need to pad with points with label -1. + pad = len(background_labels) + point = point[: len(point) - pad] # type: ignore + point_label = point_label[: len(point_label) - pad] + prompt_class = prompt_class[: len(prompt_class) - pad] + else: + if random.uniform(0, 1) < drop_point_prob: + point = None + point_label = None return label_prompt, point, point_label, prompt_class diff --git a/tests/test_vista3d_sampler.py b/tests/test_vista3d_sampler.py new file mode 100644 index 0000000000..6945d250d2 --- /dev/null +++ b/tests/test_vista3d_sampler.py @@ -0,0 +1,100 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.sampler import sample_prompt_pairs + +label = torch.zeros([1, 1, 64, 64, 64]) +label[:, :, :10, :10, :10] = 1 +label[:, :, 20:30, 20:30, 20:30] = 2 +label[:, :, 30:40, 30:40, 30:40] = 3 +label1 = torch.zeros([1, 1, 64, 64, 64]) + +TEST_VISTA_SAMPLE_PROMPT = [ + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 0, + }, + [4, 4, 4, 4], + ], + [ + { + "labels": label, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [2, None, None, 2], + ], + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 1, + "drop_point_prob": 0, + }, + [None, 3, 3, 3], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [1, None, None, 1], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 0, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [None, None, None, None], + ], +] + + +class TestGeneratePrompt(unittest.TestCase): + @parameterized.expand(TEST_VISTA_SAMPLE_PROMPT) + def test_result(self, input_data, expected): + output = sample_prompt_pairs(**input_data) + result = [i.shape[0] if i is not None else None for i in output] + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index 1972220c63..4a647a8d1e 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -18,8 +18,6 @@ from parameterized import parameterized from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd -from monai.apps.vista3d.workflow import generate_prompt_pairs - from monai.utils import min_version from monai.utils.module import optional_import @@ -74,28 +72,6 @@ ], ] -label = torch.zeros([1,1,64,64,64]) -label[:, :, :10, :10, :10] = 1 -label[:, :, 20:30, 20:30, 20:30] = 2 -label[:, :, 30:40, 30:40, 30:40] = 3 -label1 = torch.zeros([1,1,64,64,64]) -TEST_VISTA_GENERATEPROMPT = [ - [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, - "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":0}, [4, 4, 4, 4] - ], - [{"labels": label, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, - "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":1}, [2, None, None, 2] - ], - [{"labels": label, "label_set": [0,1,2,3,4], "max_prompt":5, "max_foreprompt": 4, - "max_backprompt":1, "drop_label_prob":1, "drop_point_prob":0}, [None, 3, 3, 3] - ], - [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, - "max_backprompt":1, "drop_label_prob":0, "drop_point_prob":1}, [1, None, None, 1] - ], - [{"labels": label1, "label_set": [0,1], "max_prompt":5, "max_foreprompt": 4, - "max_backprompt":0, "drop_label_prob":0, "drop_point_prob":1}, [None, None, None, None] - ] -] class TestVistaPreTransformd(unittest.TestCase): @parameterized.expand(TEST_VISTA_PRETRANSFORM) @@ -114,13 +90,5 @@ def test_result(self, input_data, expected): self.assertEqual((result["pred"] == expected).all(), True) -class TestGeneratePrompt(unittest.TestCase): - @parameterized.expand(TEST_VISTA_GENERATEPROMPT) - def test_result(self, input_data, expected): - output = generate_prompt_pairs(**input_data) - result = [i.shape[0] if i is not None else None for i in output] - self.assertEqual(result, expected) - - if __name__ == "__main__": unittest.main() From 8c1d7b6de90c99bf8bc46805b986c5d9fdd9b12b Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 22 Aug 2024 16:40:08 +0800 Subject: [PATCH 28/32] update kwargs Signed-off-by: Yiheng Wang --- monai/apps/vista3d/sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py index 2a78e6ea1b..0648772793 100644 --- a/monai/apps/vista3d/sampler.py +++ b/monai/apps/vista3d/sampler.py @@ -14,6 +14,7 @@ import copy import random from collections.abc import Callable, Sequence +from typing import Any import numpy as np import torch @@ -48,6 +49,7 @@ def sample_prompt_pairs( drop_label_prob: float = 0.2, drop_point_prob: float = 0.2, point_sampler: Callable | None = None, + **point_sampler_kwargs: Any, ) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: """ Sample training pairs for VISTA3D training. @@ -63,6 +65,7 @@ def sample_prompt_pairs( drop_label_prob: probability to drop label prompt. drop_point_prob: probability to drop point prompt. point_sampler: sampler to augment masks with supervoxel. + point_sampler_kwargs: arguments for point_sampler. Returns: label_prompt: [B, 1]. The classes used for training automatic segmentation. @@ -135,7 +138,7 @@ def sample_prompt_pairs( _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point else: - _point, _point_label = point_sampler(unique_labels, Np=max_point, Nn=0) + _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs) for _ in background_labels: # pad the background labels _point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0 From 9363fa3da81bbd97562f3aff0b59b127affc5824 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Thu, 22 Aug 2024 16:13:23 -0400 Subject: [PATCH 29/32] Fix bug in convert point to disc and add more doc Signed-off-by: heyufan1995 --- monai/apps/vista3d/sampler.py | 10 ++++++++-- monai/apps/vista3d/transforms.py | 2 ++ monai/transforms/utils.py | 2 +- tests/test_vista3d_utils.py | 33 ++++++++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py index 0648772793..60ab3145ce 100644 --- a/monai/apps/vista3d/sampler.py +++ b/monai/apps/vista3d/sampler.py @@ -56,12 +56,18 @@ def sample_prompt_pairs( Args: labels: [1, 1, H, W, D], ground truth labels. - label_set: the label list for the specific dataset. + label_set: the label list for the specific dataset. Note if 0 is included in label_set, + it will be added into automatic branch training. Recommend removing 0 from label_set + for multi-partially-labeled-dataset training, and adding 0 for finetuning specific dataset. + The reason is region with 0 in one partially labeled dataset may contain foregrounds in + another dataset. max_prompt: int, max number of total prompt, including foreground and background. max_foreprompt: int, max number of prompt from foreground. max_backprompt: int, max number of prompt from background. max_point: maximum number of points for each object. - include_background: if include label=0 into training prompt. May cause issue in partial label training. + include_background: if include 0 into training prompt. If included, background 0 is treated + the same as foreground. Always be False for multi-partial-dataset training. If needed, + can be true for finetuning specific dataset, . drop_label_prob: probability to drop label prompt. drop_point_prob: probability to drop point prompt. point_sampler: sampler to augment masks with supervoxel. diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 972c94d39d..98f23a7274 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -140,6 +140,8 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) def __call__(self, data): + """ data["label_prompt"] should not contain 0 + """ for keys in self.keys: if keys in data: pred = data[keys] diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 305460ec9e..7027c07d67 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1314,7 +1314,7 @@ def convert_points_to_disc( _array = [ torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3) ] - coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0]) + coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2]) # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index 6fe61682fb..07cc37bf7f 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -53,6 +53,32 @@ expected_shape, ] ) + image_size = (16, 32, 64) + point = torch.tensor([[[8,16,42], [2,8,21]]]) + point_label = torch.tensor([[1, 0]]) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + +TEST_CONVERT_POINTS_TO_DISC_VALUE = [] +image_size = (16, 32, 64) +point = torch.tensor([[[8,16,42], [2,8,21]]]) +point_label = torch.tensor([[1, 0]]) +expected_shape = (point.shape[0], 2, *image_size) +for radius in [5, 10]: + for disc in [True, False]: + TEST_CONVERT_POINTS_TO_DISC_VALUE.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + [point, point_label] + ] + ) + + TEST_LCC_MASK_POINT_TORCH = [] for bs in [1, 2]: @@ -104,6 +130,13 @@ def test_shape(self, input_data, expected_shape): result = convert_points_to_disc(**input_data) self.assertEqual(result.shape, expected_shape) + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC_VALUE) + def test_value(self, input_data, points): + result = convert_points_to_disc(**input_data) + point, point_label = points + for i in range(point.shape[0]): + for j in range(point.shape[1]): + self.assertEqual(result[i, point_label[i,j], point[i,j][0], point[i,j][1], point[i,j][2]], True) @skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") class TestKeepMergeComponentsWithPoints(unittest.TestCase): From e5c8baae7e599d5b13c14399acd40031e8434ba3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:13:53 +0000 Subject: [PATCH 30/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/vista3d/sampler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py index 60ab3145ce..b7aeb89a2e 100644 --- a/monai/apps/vista3d/sampler.py +++ b/monai/apps/vista3d/sampler.py @@ -58,16 +58,16 @@ def sample_prompt_pairs( labels: [1, 1, H, W, D], ground truth labels. label_set: the label list for the specific dataset. Note if 0 is included in label_set, it will be added into automatic branch training. Recommend removing 0 from label_set - for multi-partially-labeled-dataset training, and adding 0 for finetuning specific dataset. - The reason is region with 0 in one partially labeled dataset may contain foregrounds in + for multi-partially-labeled-dataset training, and adding 0 for finetuning specific dataset. + The reason is region with 0 in one partially labeled dataset may contain foregrounds in another dataset. max_prompt: int, max number of total prompt, including foreground and background. max_foreprompt: int, max number of prompt from foreground. max_backprompt: int, max number of prompt from background. max_point: maximum number of points for each object. include_background: if include 0 into training prompt. If included, background 0 is treated - the same as foreground. Always be False for multi-partial-dataset training. If needed, - can be true for finetuning specific dataset, . + the same as foreground. Always be False for multi-partial-dataset training. If needed, + can be true for finetuning specific dataset, . drop_label_prob: probability to drop label prompt. drop_point_prob: probability to drop point prompt. point_sampler: sampler to augment masks with supervoxel. From 7b5f5f6449fb6d81afd628face9bceb00cc309e2 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 23 Aug 2024 10:42:06 +0800 Subject: [PATCH 31/32] autofix Signed-off-by: Yiheng Wang --- monai/apps/vista3d/transforms.py | 3 +-- tests/test_vista3d_utils.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 98f23a7274..9a46a319db 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -140,8 +140,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) def __call__(self, data): - """ data["label_prompt"] should not contain 0 - """ + """data["label_prompt"] should not contain 0""" for keys in self.keys: if keys in data: pred = data[keys] diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index 07cc37bf7f..5a0caedd61 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -54,7 +54,7 @@ ] ) image_size = (16, 32, 64) - point = torch.tensor([[[8,16,42], [2,8,21]]]) + point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) point_label = torch.tensor([[1, 0]]) expected_shape = (point.shape[0], 2, *image_size) TEST_CONVERT_POINTS_TO_DISC.append( @@ -66,7 +66,7 @@ TEST_CONVERT_POINTS_TO_DISC_VALUE = [] image_size = (16, 32, 64) -point = torch.tensor([[[8,16,42], [2,8,21]]]) +point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) point_label = torch.tensor([[1, 0]]) expected_shape = (point.shape[0], 2, *image_size) for radius in [5, 10]: @@ -74,12 +74,11 @@ TEST_CONVERT_POINTS_TO_DISC_VALUE.append( [ {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, - [point, point_label] + [point, point_label], ] ) - TEST_LCC_MASK_POINT_TORCH = [] for bs in [1, 2]: for num_points in [1, 3]: @@ -136,7 +135,8 @@ def test_value(self, input_data, points): point, point_label = points for i in range(point.shape[0]): for j in range(point.shape[1]): - self.assertEqual(result[i, point_label[i,j], point[i,j][0], point[i,j][1], point[i,j][2]], True) + self.assertEqual(result[i, point_label[i, j], point[i, j][0], point[i, j][1], point[i, j][2]], True) + @skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") class TestKeepMergeComponentsWithPoints(unittest.TestCase): From 194bd77a6b547bbd1c38da952d15d2ec9d6e6b34 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 26 Aug 2024 11:29:23 +0800 Subject: [PATCH 32/32] resolve issues Signed-off-by: Yiheng Wang --- monai/apps/vista3d/inferer.py | 2 ++ monai/apps/vista3d/transforms.py | 3 ++- tests/test_vista3d_transforms.py | 6 +++--- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py index 6c3372550c..709f81f624 100644 --- a/monai/apps/vista3d/inferer.py +++ b/monai/apps/vista3d/inferer.py @@ -72,6 +72,8 @@ def point_based_window_inferer( """ if not point_coords.shape[0] == 1: raise ValueError("Only supports single object point click.") + if not len(inputs.shape) == 5: + raise ValueError("Input image should be 5D.") image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size) point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device) prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py index 9a46a319db..3e8145cd80 100644 --- a/monai/apps/vista3d/transforms.py +++ b/monai/apps/vista3d/transforms.py @@ -165,9 +165,10 @@ def __call__(self, data): pred[pred > 0] = 1.0 if "label_prompt" in data and data["label_prompt"] is not None: pred += 0.5 # inplace mapping to avoid cloning pred + label_prompt = data["label_prompt"].to(device) # Ensure label_prompt is on the same device for i in range(1, object_num + 1): frac = i + 0.5 - pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype) + pred[pred == frac] = label_prompt[i - 1].to(pred.dtype) pred[pred == 0.5] = 0.0 data[keys] = pred return data diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py index 4a647a8d1e..9d61fe2fc2 100644 --- a/tests/test_vista3d_transforms.py +++ b/tests/test_vista3d_transforms.py @@ -61,14 +61,14 @@ output2[:, 20:30, 20:30, 20:30] = 1 TEST_VISTA_POSTTRANSFORM = [ - [{"pred": pred1, "label_prompt": torch.tensor([2, 3]).to(device)}, output1], + [{"pred": pred1.to(device), "label_prompt": torch.tensor([2, 3]).to(device)}, output1.to(device)], [ { - "pred": pred2, + "pred": pred2.to(device), "points": torch.tensor([[25, 25, 25]]).to(device), "point_labels": torch.tensor([1]).to(device), }, - output2, + output2.to(device), ], ]