Add vista3d inferers (#8021)
Fixes # .

### Description

This PR adds VISTA3D-specific components to `monai.apps`: a point-based sliding-window inferer
(`point_based_window_inferer`), a training prompt sampler (`sample_prompt_pairs`), and the
`VistaPreTransformd`, `VistaPostTransformd`, and `Relabeld` transforms, together with the
corresponding documentation entries in `docs/source/apps.rst`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
5 people authored Aug 26, 2024
1 parent a5fbe71 commit 872585d
Showing 13 changed files with 988 additions and 17 deletions.
16 changes: 16 additions & 0 deletions docs/source/apps.rst
@@ -248,6 +248,22 @@ FastMRIReader
~~~~~~~~~~~~~
.. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj

`Vista3d`
---------
.. automodule:: monai.apps.vista3d.inferer
.. autofunction:: point_based_window_inferer

.. automodule:: monai.apps.vista3d.transforms
.. autoclass:: VistaPreTransformd
:members:
.. autoclass:: VistaPostTransformd
:members:
.. autoclass:: Relabeld
:members:

.. automodule:: monai.apps.vista3d.sampler
.. autofunction:: sample_prompt_pairs

`Auto3DSeg`
-----------
.. automodule:: monai.apps.auto3dseg
File renamed without changes.
177 changes: 177 additions & 0 deletions monai/apps/vista3d/inferer.py
@@ -0,0 +1,177 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
from collections.abc import Sequence
from typing import Any

import torch

from monai.data.meta_tensor import MetaTensor
from monai.utils import optional_import

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["point_based_window_inferer"]


def point_based_window_inferer(
inputs: torch.Tensor | MetaTensor,
roi_size: Sequence[int],
predictor: torch.nn.Module,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
class_vector: torch.Tensor | None = None,
prompt_class: torch.Tensor | None = None,
prev_mask: torch.Tensor | MetaTensor | None = None,
point_start: int = 0,
center_only: bool = True,
margin: int = 5,
**kwargs: Any,
) -> torch.Tensor:
"""
Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
The algorithm crops the input image into patches centered at the given points, runs inference on each patch,
stitches the outputs by averaging the overlapping regions, and returns the resulting segmentation mask.
Args:
inputs: [1CHWD], input image to be processed.
roi_size: the spatial window size for inference. If a component of `roi_size` is None or non-positive,
the corresponding dimension of the input image is used instead. For example, `roi_size=(32, -1)` is
adapted to `(32, 64)` if the second spatial dimension of the image is `64`.
predictor: the model. For VISTA3D, the output is [B, 1, H, W, D], which needs to be transposed to
[1, B, H, W, D]; pass `transpose=True` in kwargs for VISTA3D.
point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.
point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
class_vector: [B]. Used for class-head automatic segmentation. May be None.
prompt_class: [B]. The same as class_vector, representing the point class; it informs the point head
whether a class is supported or zero-shot. Not used for automatic segmentation. If None, the point
head defaults to supported-class segmentation.
prev_mask: [1, B, H, W, D]. Pre-sigmoid values. An optional tensor of previously segmented masks.
point_start: only use points starting from this index. All points before this index are used to
generate prev_mask; this avoids re-computing points from previous iterations when prev_mask is given.
center_only: for each point, only crop the patch centered at that point. If False, crop three patches
for each point.
margin: if center_only is False, the distance from the point to the patch boundary.
Returns:
stitched_output: [1, B, H, W, D]. Pre-sigmoid values.
Note: this function only supports SINGLE OBJECT INFERENCE with B=1.
"""
if point_coords.shape[0] != 1:
raise ValueError("Only supports single object point click.")
if len(inputs.shape) != 5:
raise ValueError("Input image should be 5D.")
image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
stitched_output = None
for p in point_coords[0][point_start:]:
lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
for i in range(len(lx_)):
for j in range(len(ly_)):
for k in range(len(lz_)):
lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
unravel_slice = [
slice(None),
slice(None),
slice(int(lx), int(rx)),
slice(int(ly), int(ry)),
slice(int(lz), int(rz)),
]
batch_image = image[unravel_slice]
output = predictor(
batch_image,
point_coords=point_coords,
point_labels=point_labels,
class_vector=class_vector,
prompt_class=prompt_class,
patch_coords=unravel_slice,
prev_mask=prev_mask,
**kwargs,
)
if stitched_output is None:
stitched_output = torch.zeros(
[1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
)
stitched_mask = torch.zeros(
[1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
)
stitched_output[unravel_slice] += output.to("cpu")
stitched_mask[unravel_slice] = 1
# where stitched_mask is 0, the division yields NaN; those regions are filled from prev_mask below when available
stitched_output = stitched_output / stitched_mask
# revert padding
stitched_output = stitched_output[
:, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
]
stitched_mask = stitched_mask[
:, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
]
if prev_mask is not None:
prev_mask = prev_mask[
:,
:,
pad[4] : image.shape[-3] - pad[5],
pad[2] : image.shape[-2] - pad[3],
pad[0] : image.shape[-1] - pad[1],
]
prev_mask = prev_mask.to("cpu") # type: ignore
# for regions that were not computed, fall back to the previous mask
stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]
if isinstance(inputs, torch.Tensor):
inputs = MetaTensor(inputs)
if not hasattr(stitched_output, "meta"):
stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
return stitched_output


def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
"""Compute a window [left, right) of length `roi` centered at `p`, clamped to `[0, s]`."""
if p - roi // 2 < 0:
left, right = 0, roi
elif p + roi // 2 > s:
left, right = s - roi, s
else:
left, right = int(p) - roi // 2, int(p) + roi // 2
return left, right


def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
"""Get candidate window indices around point `p`: one centered window, or three windows if center_only is False."""
left, right = _get_window_idx_c(p, roi, s)
if center_only:
return [left], [right]
left_most = max(0, p - roi + margin)
right_most = min(s, p + roi - margin)
left_list = [left_most, right_most - roi, left]
right_list = [left_most + roi, right_most, right]
return left_list, right_list
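# Worked example (illustrative, not part of the original code): with p=10, roi=64, s=128,
# _get_window_idx_c returns (0, 64) since the centered window [-22, 42] would underflow;
# with center_only=False and margin=5, _get_window_idx instead proposes the clamped
# candidate windows (0, 64), (5, 69) and (0, 64), from left_most=0 and right_most=69.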


def _pad_previous_mask(
inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
) -> tuple[torch.Tensor | MetaTensor, list[int]]:
"""Helper function to pad inputs."""
pad_size = []
for k in range(len(inputs.shape) - 1, 1, -1):
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
half = diff // 2
pad_size.extend([half, diff - half])
if any(pad_size):
inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) # type: ignore
return inputs, pad_size
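
Below is a minimal usage sketch for `point_based_window_inferer` (not part of this PR): the trained network `model`, the tensor shapes, and the 0.5 threshold are illustrative assumptions.

import torch
from monai.apps.vista3d.inferer import point_based_window_inferer

# `model` is assumed to be a trained VISTA3D-style network (hypothetical placeholder)
image = torch.randn(1, 1, 128, 128, 128)       # [1, C, H, W, D]
point_coords = torch.tensor([[[64, 64, 64]]])  # [B=1, N=1, 3]: one click on the object
point_labels = torch.tensor([[1]])             # 1 = positive click for a regular supported class

logits = point_based_window_inferer(
    inputs=image,
    roi_size=(96, 96, 96),
    predictor=model,
    point_coords=point_coords,
    point_labels=point_labels,
    transpose=True,  # VISTA3D outputs [B, 1, H, W, D]; transpose back to [1, B, H, W, D]
)
mask = logits.sigmoid() > 0.5  # the stitched output is pre-sigmoid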
172 changes: 172 additions & 0 deletions monai/apps/vista3d/sampler.py
@@ -0,0 +1,172 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import random
from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
import torch
from torch import Tensor

__all__ = ["sample_prompt_pairs"]

ENABLE_SPECIAL = True
SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
MERGE_LIST = {
1: [25, 26], # hepatic tumor and vessel merge into liver
4: [24], # pancreatic tumor merge into pancreas
132: [57], # overlap with trachea merge into airway
}


def _get_point_label(id: int) -> tuple[int, int]:
if id in SPECIAL_INDEX and ENABLE_SPECIAL:
return 2, 3
else:
return 0, 1


def sample_prompt_pairs(
labels: Tensor,
label_set: Sequence[int],
max_prompt: int | None = None,
max_foreprompt: int | None = None,
max_backprompt: int = 1,
max_point: int = 20,
include_background: bool = False,
drop_label_prob: float = 0.2,
drop_point_prob: float = 0.2,
point_sampler: Callable | None = None,
**point_sampler_kwargs: Any,
) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
"""
Sample training pairs for VISTA3D training.
Args:
labels: [1, 1, H, W, D], ground truth labels.
label_set: the label list for the specific dataset. Note that if 0 is included in label_set,
it will be added to automatic-branch training. Removing 0 from label_set is recommended for
multi-partially-labeled-dataset training, while adding 0 is recommended when fine-tuning a
specific dataset, because a region labeled 0 in one partially labeled dataset may contain
foreground in another dataset.
max_prompt: maximum total number of prompts, including foreground and background.
max_foreprompt: maximum number of prompts from the foreground.
max_backprompt: maximum number of prompts from the background.
max_point: maximum number of points for each object.
include_background: whether to include 0 in the training prompts. If included, background 0 is
treated the same as foreground. Should always be False for multi-partially-labeled-dataset
training; it can be True when fine-tuning a specific dataset if needed.
drop_label_prob: probability to drop label prompt.
drop_point_prob: probability to drop point prompt.
point_sampler: sampler to augment masks with supervoxel.
point_sampler_kwargs: arguments for point_sampler.
Returns:
label_prompt: [B, 1]. The classes used for training automatic segmentation.
point: [B, N, 3]. The corresponding points for each class. Note that background label prompts
require matching points as well ([0, 0, 0] is used as padding).
point_label: [B, N]. The corresponding point labels for each point (negative or positive).
-1 is used to pad the background label prompt and will be ignored.
prompt_class: [B, 1], identical to label_prompt, used for label indexing in the training loss.
label_prompt can be None, in which case prompt_class is used to identify the point classes.
"""
if labels.shape[0] != 1:
raise ValueError("Only supports batch size 1.")
labels = labels[0, 0]
device = labels.device
unique_labels = labels.unique().cpu().numpy().tolist()
if include_background:
unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)))
else:
unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0})
background_labels = list(set(label_set) - set(unique_labels))
# during training, balance background and foreground prompts
if max_backprompt is not None:
if len(background_labels) > max_backprompt:
random.shuffle(background_labels)
background_labels = background_labels[:max_backprompt]

if max_foreprompt is not None:
if len(unique_labels) > max_foreprompt:
random.shuffle(unique_labels)
unique_labels = unique_labels[:max_foreprompt]

if max_prompt is not None:
if len(unique_labels) + len(background_labels) > max_prompt:
if len(unique_labels) > max_prompt:
unique_labels = random.sample(unique_labels, max_prompt)
background_labels = []
else:
background_labels = random.sample(background_labels, max_prompt - len(unique_labels))
_point = []
_point_label = []
# if use regular sampling
if point_sampler is None:
num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1)
num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))))
for id in unique_labels:
neg_id, pos_id = _get_point_label(id)
plabels = labels == int(id)
nlabels = ~plabels
plabelpoints = torch.nonzero(plabels)
nlabelpoints = torch.nonzero(nlabels)
# final sampled positive points
num_pa = min(len(plabelpoints), num_p)
# final sampled negative points
num_na = min(len(nlabelpoints), num_n)
_point.append(
torch.stack(
random.choices(plabelpoints, k=num_pa)
+ random.choices(nlabelpoints, k=num_na)
+ [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na)
)
)
_point_label.append(
torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to(
device
)
)
for _ in background_labels:
# pad the background labels
_point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0
_point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point
else:
_point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs)
for _ in background_labels:
# pad the background labels
_point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0
_point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1) # -1 not a point
if len(unique_labels) == 0 and len(background_labels) == 0:
# if max_backprompt is 0 and unique_labels is empty, there is no effective prompt and the
# iteration must be skipped; handle this in the trainer.
label_prompt, point, point_label, prompt_class = None, None, None, None
else:
label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long()
point = torch.stack(_point)
point_label = torch.stack(_point_label)
prompt_class = copy.deepcopy(label_prompt)
if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0:
label_prompt = None
# If label prompt is dropped, there is no need to pad with points with label -1.
pad = len(background_labels)
point = point[: len(point) - pad] # type: ignore
point_label = point_label[: len(point_label) - pad]
prompt_class = prompt_class[: len(prompt_class) - pad]
else:
if random.uniform(0, 1) < drop_point_prob:
point = None
point_label = None
return label_prompt, point, point_label, prompt_class
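
Below is a minimal usage sketch for `sample_prompt_pairs` (not part of this PR); the random label volume and the chosen `label_set` are illustrative assumptions.

import torch
from monai.apps.vista3d.sampler import sample_prompt_pairs

labels = torch.randint(0, 3, (1, 1, 96, 96, 96))  # [1, 1, H, W, D] ground-truth labels
label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
    labels,
    label_set=[1, 2],  # 0 excluded, as recommended for multi-partially-labeled training
    max_point=5,
)
if label_prompt is None and point is None:
    pass  # no effective prompt was sampled; skip this iteration in the trainer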
