Skip to content

Commit

Permalink
Add 3d segmentation functionality (#255)
Browse files Browse the repository at this point in the history
Implement automatic 3d segmentation for a given slice, add new example data
  • Loading branch information
constantinpape authored Nov 2, 2023
1 parent 88c0d24 commit cd5e818
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 55 deletions.
104 changes: 58 additions & 46 deletions examples/use_as_library/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import napari

from micro_sam import instance_segmentation, util
from micro_sam.multi_dimensional_segmentation import segment_3d_from_slice


def cell_segmentation():
Expand Down Expand Up @@ -32,36 +33,15 @@ def cell_segmentation():

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
instances_amg = amg.generate(pred_iou_thresh=0.88)
instances_amg = instance_segmentation.mask_data_to_segmentation(
instances_amg, shape=image.shape, with_background=True
)

# Use the mutex waterhsed based instance segmentation logic.
# Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
# These initial masks are used as prompts for the actual instance segmentation.
# This class uses the same overall design as 'AutomaticMaskGenerator'.

# Create the automatic mask generator class.
amg_mws = instance_segmentation.EmbeddingMaskGenerator(predictor, min_initial_size=10)

# Initialize the mask generator with the image and the pre-computed embeddings.
amg_mws.initialize(image, embeddings, verbose=True)

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
# NOTE: the main advantage of this method is that it's faster than the original implementation,
# however the quality is not as high as the original instance segmentation quality yet.
instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
instances_mws = instance_segmentation.mask_data_to_segmentation(
instances_mws, shape=image.shape, with_background=True
instances = amg.generate(pred_iou_thresh=0.88)
instances = instance_segmentation.mask_data_to_segmentation(
instances, shape=image.shape, with_background=True
)

# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances_amg)
v.add_labels(instances_mws)
v.add_labels(instances)
napari.run()


Expand Down Expand Up @@ -94,39 +74,71 @@ def cell_segmentation_with_tiling():

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
instances_amg = amg.generate(pred_iou_thresh=0.88)
instances_amg = instance_segmentation.mask_data_to_segmentation(
instances_amg, shape=image.shape, with_background=True
instances = amg.generate(pred_iou_thresh=0.88)
instances = instance_segmentation.mask_data_to_segmentation(
instances, shape=image.shape, with_background=True
)

# Use the mutex waterhsed based instance segmentation logic.
# Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
# These initial masks are used as prompts for the actual instance segmentation.
# This class uses the same overall design as 'AutomaticMaskGenerator'.

# Create the automatic mask generator class.
amg_mws = instance_segmentation.TiledEmbeddingMaskGenerator(predictor, min_initial_size=10)
# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances)
v.add_labels(instances)
napari.run()

# Initialize the mask generator with the image and the pre-computed embeddings.
amg_mws.initialize(image, embeddings, verbose=True)

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
# NOTE: the main advantage of this method is that it's faster than the original implementation.
# however the quality is not as high as the original instance segmentation quality yet.
instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
def segmentation_in_3d():
"""Run instance segmentation in 3d, for segmenting all objects that intersect
with a given slice. If you use a fine-tuned model for this then you should
first find good parameters for 2d segmentation.
"""
import imageio.v3 as imageio
from micro_sam.sample_data import fetch_nucleus_3d_example_data

# Load the example image data: 3d nucleus segmentation.
path = fetch_nucleus_3d_example_data("./data")
data = imageio.imread(path)

# Load the SAM model for prediction.
model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc.
checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here.
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)

# Run 3d segmentation for a given slice. Will segment all objects found in that slice
# throughout the volume.

# The slice that is used for segmentation in 2d. If you don't specify a slice
# then the middle slice is used.
z_slice = data.shape[0] // 2

# The threshold for filtering objects in the 2d segmentation based on the model's
# predicted iou score. If you use a custom model you should first find a good setting
# for this value, e.g. with the 2d annotation tool.
pred_iou_thresh = 0.88

# The threshold for filtering objects in the 2d segmentation based on the model's
# stability score for a given object. If you use a custom model you should first find a good setting
# for this value, e.g. with the 2d annotation tool.
stability_score_thresh = 0.95

instances = segment_3d_from_slice(
predictor, data, z=z_slice,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
verbose=True
)

# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances_amg)
v.add_labels(instances_mws)
v.add_image(data)
v.add_labels(instances)
napari.run()


def main():
cell_segmentation()
# cell_segmentation()
# cell_segmentation_with_tiling()
segmentation_in_3d()


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def mask_data_to_segmentation(
shape: Tuple[int, ...],
with_background: bool,
min_object_size: int = 0,
max_object_size: Optional[int] = None,
) -> np.ndarray:
"""Convert the output of the automatic mask generation to an instance segmentation.
Expand All @@ -63,6 +64,7 @@ def mask_data_to_segmentation(
with_background: Whether the segmentation has background. If yes this function assures that the largest
object in the output will be mapped to zero (the background value).
min_object_size: The minimal size of an object in pixels.
max_object_size: The maximal size of an object in pixels.
Returns:
The instance segmentation.
"""
Expand All @@ -77,6 +79,8 @@ def require_numpy(mask):
for mask in masks:
if mask["area"] < min_object_size:
continue
if max_object_size is not None and mask["area"] > max_object_size:
continue

this_seg_id = mask.get("seg_id", seg_id)
segmentation[require_numpy(mask["segmentation"])] = this_seg_id
Expand Down
89 changes: 84 additions & 5 deletions micro_sam/multi_dimensional_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""
Multi-dimensional segmentation with segment anything.
"""Multi-dimensional segmentation with segment anything.
"""

from typing import Any, Optional
import os
from typing import Any, Optional, Union

import numpy as np
from segment_anything.predictor import SamPredictor
from tqdm import tqdm

from . import util
from .instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation
from .precompute_state import cache_amg_state
from .prompt_based_segmentation import segment_from_mask


Expand All @@ -21,7 +24,7 @@ def segment_mask_in_volume(
iou_threshold: float,
projection: str,
progress_bar: Optional[Any] = None,
box_extension: int = 0,
box_extension: float = 0.0,
) -> np.ndarray:
"""Segment an object mask in in volumetric data.
Expand All @@ -35,7 +38,7 @@ def segment_mask_in_volume(
iou_threshold: The IOU threshold for continuing segmentation across 3d.
projection: The projection method to use. One of 'mask', 'bounding_box' or 'points'.
progress_bar: Optional progress bar.
box_extension: Extension factor for increasing the box size after projection
box_extension: Extension factor for increasing the box size after projection.
Returns:
Array with the volumetric segmentation
Expand Down Expand Up @@ -132,3 +135,79 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
_update_progress()

return segmentation


def segment_3d_from_slice(
predictor: SamPredictor,
raw: np.ndarray,
z: Optional[int] = None,
embedding_path: Optional[Union[str, os.PathLike]] = None,
projection: str = "mask",
box_extension: float = 0.0,
verbose: bool = True,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
min_object_size_z: int = 50,
max_object_size_z: Optional[int] = None,
iou_threshold: float = 0.8,
):
"""Segment all objects in a volume intersecting with a specific slice.
This function first segments the objects in the specified slice using the
automatic instance segmentation functionality. Then it segments all objects that
were found in that slice in the volume.
Args:
predictor: The segment anything predictor.
raw: The volumetric image data.
z: The slice from which to start segmentation.
If none is given the central slice will be used.
embedding_path: The path were embeddings will be cached.
If none is given embeddings will not be cached.
projection: The projection method to use. One of 'mask', 'bounding_box' or 'points'.
box_extension: Extension factor for increasing the box size after projection.
verbose: Whether to print progress bar and other status messages.
pred_iou_thresh: The predicted iou value to filter objects in `AutomaticMaskGenerator.generate`.
stability_score_thresh: The stability score to filter objects in `AutomaticMaskGenerator.generate`.
min_object_size_z: Minimal object size in the segmented frame.
max_object_size_z: Maximal object size in the segmented frame.
iou_threshold: The IOU threshold for linking objects across slices.
Returns:
Segmentation volume.
"""
# Compute the image embeddings.
image_embeddings = util.precompute_image_embeddings(predictor, raw, save_path=embedding_path, ndim=3)

# Select the middle slice if no slice is given.
if z is None:
z = raw.shape[0] // 2

# Perform automatic instance segmentation.
if embedding_path is not None:
amg = cache_amg_state(predictor, raw[z], image_embeddings, embedding_path, verbose=verbose, i=z)
else:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(raw[z], image_embeddings, i=z, verbose=verbose)

seg_z = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
seg_z = mask_data_to_segmentation(
seg_z, shape=raw.shape[1:], with_background=True,
min_object_size=min_object_size_z,
max_object_size=max_object_size_z,
)

# Segment all objects that were found in 3d.
seg_ids = np.unique(seg_z)[1:]
segmentation = np.zeros(raw.shape, dtype=seg_z.dtype)
for seg_id in tqdm(seg_ids, desc="Segment objects in 3d", disable=not verbose):
this_seg = np.zeros_like(segmentation)
this_seg[z][seg_z == seg_id] = 1
this_seg = segment_mask_in_volume(
this_seg, predictor, image_embeddings,
segmented_slices=np.array([z]), stop_lower=False, stop_upper=False,
iou_threshold=iou_threshold, projection=projection, box_extension=box_extension,
)
segmentation[this_seg > 0] = seg_id

return segmentation
13 changes: 11 additions & 2 deletions micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def cache_amg_state(
image_embeddings: util.ImageEmbeddings,
save_path: Union[str, os.PathLike],
verbose: bool = True,
i: Optional[int] = None,
**kwargs,
) -> instance_segmentation.AMGBase:
"""Compute and cache or load the state for the automatic mask generator.
Expand All @@ -32,6 +33,7 @@ def cache_amg_state(
image_embeddings: The image embeddings.
save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
verbose: Whether to run the computation verbose.
i: The index for which to cache the state.
kwargs: The keyword arguments for the amg class.
Returns:
Expand All @@ -40,7 +42,14 @@ def cache_amg_state(
is_tiled = image_embeddings["input_size"] is None
amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs)

save_path_amg = os.path.join(save_path, "amg_state.pickle")
# If i is given we compute the state for a given slice/frame.
# And we have to save the state for slices/frames separately.
if i is None:
save_path_amg = os.path.join(save_path, "amg_state.pickle")
else:
os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True)
save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl")

if os.path.exists(save_path_amg):
if verbose:
print("Load the AMG state from", save_path_amg)
Expand All @@ -52,7 +61,7 @@ def cache_amg_state(
if verbose:
print("Precomputing the state for instance segmentation.")

amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)
amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i)
amg_state = amg.get_state()

# put all state onto the cpu so that the state can be deserialized without a gpu
Expand Down
30 changes: 28 additions & 2 deletions micro_sam/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,12 @@ def sample_data_segmentation():
return [(data, add_image_kwargs)]


def synthetic_data(shape):
def synthetic_data(shape, seed=None):
"""Create synthetic image data and segmentation for training."""
ndim = len(shape)
assert ndim in (2, 3)
image_shape = shape if ndim == 2 else shape[1:]
image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15)
image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15, seed=seed)

if image_shape[1] != image_shape[0]:
image = resize(image, image_shape, order=0, anti_aliasing=False, preserve_range=True).astype(image.dtype)
Expand All @@ -337,3 +337,29 @@ def synthetic_data(shape):

segmentation = label(image)
return image, segmentation


def fetch_nucleus_3d_example_data(save_directory: Union[str, os.PathLike]) -> str:
"""Download the sample data for 3d segmentation of nuclei.
This data contains a small crop from a volume from the publication
"Efficient automatic 3D segmentation of cell nuclei for high-content screening"
https://doi.org/10.1186/s12859-022-04737-4
Args:
save_directory: Root folder to save the downloaded data.
Returns:
The path of the downloaded image.
"""
save_directory = Path(save_directory)
os.makedirs(save_directory, exist_ok=True)
print("Example data directory is:", save_directory.resolve())
fname = "3d-nucleus-data.tif"
pooch.retrieve(
url="https://owncloud.gwdg.de/index.php/s/eW0uNCo8gedzWU4/download",
known_hash="4946896f747dc1c3fc82fb2e1320226d92f99d22be88ea5f9c37e3ba4e281205",
fname=fname,
path=save_directory,
progressbar=True,
)
return os.path.join(save_directory, fname)

0 comments on commit cd5e818

Please sign in to comment.