Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TiledAutomaticMaskGenerator #90

Merged
merged 7 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- cpuonly
- napari
- pooch
- python-elf
- python-elf >=0.4.8
- pytorch
- torchvision
- tqdm
Expand Down
2 changes: 1 addition & 1 deletion environment_gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name:
dependencies:
- napari
- pooch
- python-elf
- python-elf >=0.4.8
- pytorch
- pytorch-cuda>=11.7 # you may need to update the cuda version to match your system
- torchvision
Expand Down
10 changes: 10 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# micro_sam examples

Examples for using the micro_sam annotation tools:
- `sam_annotator_2d.py`: run the interactive 2d annotation tool
- `sam_annotator_3d.py`: run the interactive 3d annotation tool
- `sam_annotator_tracking.py`: run the interactive tracking annotation tool
- `sam_image_series_annotator.py`: run the annotation tool for a series of images

The folder `use_as_library` contains example scripts that show how `micro_sam` can be used as a python
library to apply Segment Anything to multi-dimensional data.
131 changes: 131 additions & 0 deletions examples/use_as_library/instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import imageio.v3 as imageio
import napari

from micro_sam import instance_segmentation, util


def cell_segmentation():
    """Segment HeLa cells with the two automatic mask generators of micro_sam.

    The example data and embeddings must already exist: run
    examples/sam_annotator_2d.py:hela_2d_annotator once beforehand so that
    everything is downloaded and pre-computed.
    """
    # Load the image and the SAM model, then fetch the embeddings stored at the save path.
    image = imageio.imread("../data/hela-2d-image.png")
    sam_predictor = util.get_sam_model()
    image_embeddings = util.precompute_image_embeddings(
        sam_predictor, image, save_path="../embeddings/embeddings-hela2d.zarr"
    )

    # --- Variant 1: the instance segmentation logic of SegmentAnything. ---
    # The image is covered with a grid of points, masks are predicted for all points,
    # and only the plausible masks (according to the model predictions) are kept.
    # Compared to the SegmentAnything implementation, hyperparameters such as
    # 'pred_iou_thresh' can be changed without recomputing masks and embeddings,
    # which supports (interactive) evaluation of different hyperparameters.
    grid_generator = instance_segmentation.AutomaticMaskGenerator(sam_predictor)

    # 'initialize' receives the image together with the pre-computed embeddings.
    grid_generator.initialize(image, image_embeddings, verbose=True)

    # 'generate' may be called repeatedly with other values of 'pred_iou_thresh';
    # 'initialize' does not have to be called again for that.
    grid_masks = grid_generator.generate(pred_iou_thresh=0.88)
    instances_amg = instance_segmentation.mask_data_to_segmentation(
        grid_masks, shape=image.shape, with_background=True
    )

    # --- Variant 2: the mutex watershed based instance segmentation logic. ---
    # Initial segmentation masks are derived from the image embeddings with the
    # mutex watershed algorithm and then used as prompts for the actual instance
    # segmentation. The class follows the same overall design as 'AutomaticMaskGenerator'.
    embedding_generator = instance_segmentation.EmbeddingMaskGenerator(sam_predictor, min_initial_size=10)
    embedding_generator.initialize(image, image_embeddings, verbose=True)

    # As above, 'generate' can be re-run for other 'pred_iou_thresh' values without
    # re-initializing.
    # NOTE: the main advantage of this method is that it's considerably faster
    # than the original implementation.
    embedding_masks = embedding_generator.generate(pred_iou_thresh=0.88)
    instances_mws = instance_segmentation.mask_data_to_segmentation(
        embedding_masks, shape=image.shape, with_background=True
    )

    # Display the image and both segmentation results in napari.
    viewer = napari.Viewer()
    viewer.add_image(image)
    viewer.add_labels(instances_amg)
    viewer.add_labels(instances_mws)
    napari.run()


def segmentation_with_tiling():
    """Segment cells in a large image with the tiled automatic mask generators.

    The example data and embeddings must already exist: run
    examples/sam_annotator_2d.py:wholeslide_annotator once beforehand so that
    everything is downloaded and pre-computed.
    """
    # Load the image and the SAM model, then fetch the tiled embeddings stored at the save path.
    image = imageio.imread("../data/whole-slide-example-image.tif")
    sam_predictor = util.get_sam_model()
    image_embeddings = util.precompute_image_embeddings(
        sam_predictor,
        image,
        save_path="../embeddings/whole-slide-embeddings.zarr",
        tile_shape=(1024, 1024),
        halo=(256, 256),
    )

    # --- Variant 1: the instance segmentation logic of SegmentAnything. ---
    # The image is covered with a grid of points, masks are predicted for all points,
    # and only the plausible masks (according to the model predictions) are kept.
    # This is similar to the instance segmentation in Segment Anything,
    # but it uses the pre-computed tiled embeddings.
    grid_generator = instance_segmentation.TiledAutomaticMaskGenerator(sam_predictor)

    # 'initialize' receives the image together with the pre-computed embeddings.
    grid_generator.initialize(image, image_embeddings, verbose=True)

    # 'generate' may be called repeatedly with other values of 'pred_iou_thresh';
    # 'initialize' does not have to be called again for that.
    grid_masks = grid_generator.generate(pred_iou_thresh=0.88)
    instances_amg = instance_segmentation.mask_data_to_segmentation(
        grid_masks, shape=image.shape, with_background=True
    )

    # --- Variant 2: the mutex watershed based instance segmentation logic. ---
    # Initial segmentation masks are derived from the image embeddings with the
    # mutex watershed algorithm and then used as prompts for the actual instance
    # segmentation. The class follows the same overall design as 'AutomaticMaskGenerator'.
    embedding_generator = instance_segmentation.TiledEmbeddingMaskGenerator(sam_predictor, min_initial_size=10)
    embedding_generator.initialize(image, image_embeddings, verbose=True)

    # As above, 'generate' can be re-run for other 'pred_iou_thresh' values without
    # re-initializing. Its result is handed to the viewer as is.
    # NOTE: the main advantage of this method is that it's considerably faster
    # than the original implementation.
    instances_mws = embedding_generator.generate(pred_iou_thresh=0.88)

    # Display the image and the segmentation result in napari.
    viewer = napari.Viewer()
    viewer.add_image(image)
    # viewer.add_labels(instances_amg)
    viewer.add_labels(instances_mws)
    napari.run()


def main():
    """Run the example; only the HeLa cell segmentation is enabled by default."""
    cell_segmentation()
    # Uncomment the next line to run the tiled segmentation example instead / in addition.
    # segmentation_with_tiling()


if __name__ == "__main__":
    main()
174 changes: 140 additions & 34 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def mask_data_to_segmentation(
#


class _AMGBase(ABC):
"""
class AMGBase(ABC):
"""Base class for the automatic mask generators.
"""
def __init__(self):
# the state that has to be computed by the 'initialize' method of the child classes
Expand Down Expand Up @@ -277,7 +277,7 @@ def set_state(self, state: Dict[str, Any]) -> None:
self._is_initialized = True


class AutomaticMaskGenerator(_AMGBase):
class AutomaticMaskGenerator(AMGBase):
"""Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as
Expand Down Expand Up @@ -358,8 +358,11 @@ def _process_batch(self, points, im_size):

def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
# crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
cropped_im = image[y0:y1, x0:x1, :]
if crop_box is None:
cropped_im = image
else:
x0, y0, x1, y1 = crop_box
cropped_im = image[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[:2]

if not precomputed_embeddings:
Expand Down Expand Up @@ -477,7 +480,7 @@ def generate(
)
data.cat(crop_data)

if len(self.crop_boxes) > 1:
if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
# Prefer masks from smaller crops
scores = 1 / box_area(data["crop_boxes"])
scores = scores.to(data["boxes"].device)
Expand All @@ -494,7 +497,7 @@ def generate(
return masks


class EmbeddingMaskGenerator(_AMGBase):
class EmbeddingMaskGenerator(AMGBase):
"""Generates an instance segmentation without prompts, using an initial segmentations derived from image embeddings.

Uses an intial segmentation derived from the image embeddings via the Mutex Watershed,
Expand Down Expand Up @@ -718,6 +721,133 @@ def set_state(self, state: Dict[str, Any]) -> None:
super().set_state(state)


def _compute_tiled_embeddings(predictor, image, image_embeddings, embedding_save_path, tile_shape, halo):
    """Resolve the tiled image embeddings and the tiling parameters that go with them.

    If no embeddings are passed they are computed with
    `util.precompute_image_embeddings`, which requires both tiling parameters and a
    save path. If embeddings are passed, the tiling parameters stored in the
    attributes of their "features" entry always take precedence over the ones
    given as arguments.

    Args:
        predictor: The segment anything predictor, forwarded to the embedding computation.
        image: The input image, forwarded to the embedding computation.
        image_embeddings: Optional precomputed tiled image embeddings.
        embedding_save_path: Where to save the embeddings if they have to be computed.
        tile_shape: The tile shape for embedding prediction. May be None if embeddings are passed.
        halo: The overlap between tiles. May be None if embeddings are passed.

    Returns:
        The image embeddings, the tile shape and the halo that should be used.

    Raises:
        ValueError: If neither embeddings nor complete tiling parameters are passed,
            or if embeddings have to be computed but no save path is given.
    """
    # Tiling parameters only count as given if both of them are present.
    have_tiling_params = (tile_shape is not None) and (halo is not None)

    if image_embeddings is None and have_tiling_params:
        if embedding_save_path is None:
            # The trailing space keeps the two sentences apart after literal concatenation.
            raise ValueError(
                "You have passed neither pre-computed embeddings nor a path for saving embeddings. "
                "Embeddings with tiling can only be computed if a save path is given."
            )
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
        )
    elif image_embeddings is None and not have_tiling_params:
        raise ValueError("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)")
    else:
        # The embeddings carry the tiling parameters they were computed with.
        feats = image_embeddings["features"]
        tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
        # Compare as lists so that tuples and lists holding the same values are treated as equal.
        if have_tiling_params and (
            (list(tile_shape) != list(tile_shape_)) or
            (list(halo) != list(halo_))
        ):
            warnings.warn(
                "You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and "
                "the values of the tiling parameters from the embeddings disagree with the ones that were passed. "
                "The tiling parameters you have passed will be ignored."
            )
        tile_shape = tile_shape_
        halo = halo_

    return image_embeddings, tile_shape, halo


class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
    """Generates an instance segmentation without prompts, using a point grid.

    Provides the functionality of `AutomaticMaskGenerator` for tiled embeddings:
    every tile is processed like a crop of the parent class.

    Args:
        predictor: The segment anything predictor.
        points_per_side: The number of points to be sampled along one side of the image.
            If None, `point_grids` must provide explicit point sampling.
        points_per_batch: The number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory.
        point_grids: A list of explicit grids of points used for sampling masks.
            Normalized to [0, 1] with respect to the image coordinate system.
    """

    # Only the arguments that are meaningful for the tiled mask generator are exposed.
    # The crop functionality of the parent class is re-used for the tiling, so any
    # crop related parameter would not have an effect here.
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        point_grids: Optional[List[np.ndarray]] = None,
    ) -> None:
        super().__init__(
            predictor=predictor,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            point_grids=point_grids,
        )

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        embedding_save_path: Optional[str] = None,
    ) -> None:
        """Initialize image embeddings and masks for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: The tile shape for embedding prediction.
            halo: The overlap between tiles.
            verbose: Whether to print computation progress.
            embedding_save_path: Where to save the image embeddings.
        """
        full_shape = image.shape[:2]
        image_embeddings, tile_shape, halo = _compute_tiled_embeddings(
            self._predictor, image, image_embeddings, embedding_save_path, tile_shape, halo
        )

        tiling = blocking([0, 0], full_shape, tile_shape)
        n_tiles = tiling.numberOfBlocks
        halo_list = list(halo)

        per_tile_masks = []
        per_tile_boxes = []
        for tile_id in tqdm(range(n_tiles), total=n_tiles, desc="Compute masks for tile", disable=not verbose):
            # The outer block is the tile enlarged by the halo; cut it out of the image.
            outer = tiling.getBlockWithHalo(tile_id, halo_list).outerBlock
            begin, end = outer.begin, outer.end
            tile_data = image[tuple(slice(start, stop) for start, stop in zip(begin, end))]

            # Hand the pre-computed embeddings of this tile to the predictor.
            tile_features = image_embeddings["features"][tile_id]
            util.set_precomputed(
                self._predictor,
                {
                    "features": tile_features,
                    "input_size": tile_features.attrs["input_size"],
                    "original_size": tile_features.attrs["original_size"],
                },
                i,
            )

            # The tile is processed as a whole (crop_box=None) with the embeddings set above.
            per_tile_masks.append(
                self._process_crop(
                    tile_data, crop_box=None, crop_layer_idx=0, verbose=verbose, precomputed_embeddings=True
                )
            )

            # The crop box is always the full local tile, stored with the two axes swapped
            # relative to the order of the block coordinates.
            per_tile_boxes.append([begin[1], begin[0], end[1], end[0]])

        # Store the state computed above on the instance.
        self._is_initialized = True
        self._crop_list = per_tile_masks
        self._original_size = full_shape
        self._crop_boxes = per_tile_boxes


class TiledEmbeddingMaskGenerator(EmbeddingMaskGenerator):
"""Generates an instance segmentation without prompts, using an initial segmentations derived from image embeddings.

Expand Down Expand Up @@ -812,33 +942,9 @@ def initialize(
embedding_save_path: Where to save the image embeddings.
"""
original_size = image.shape[:2]

have_tiling_params = (tile_shape is not None) and (halo is not None)
if image_embeddings is None and have_tiling_params:
if embedding_save_path is None:
raise ValueError(
"You have passed neither pre-computed embeddings nor a path for saving embeddings."
"Embeddings with tiling can only be computed if a save path is given."
)
image_embeddings = util.precompute_image_embeddings(
self._predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
)
elif image_embeddings is None and not have_tiling_params:
raise ValueError("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)")
else:
feats = image_embeddings["features"]
tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
if have_tiling_params and (
(list(tile_shape) != list(tile_shape_)) or
(list(halo) != list(halo_))
):
warnings.warn(
"You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and"
"the values of the tiling parameters from the embeddings disagree with the ones that were passed."
"The tiling parameters you have passed wil be ignored."
)
tile_shape = tile_shape_
halo = halo_
image_embeddings, tile_shape, halo = _compute_tiled_embeddings(
self._predictor, image, image_embeddings, embedding_save_path, tile_shape, halo
)

tiling = blocking([0, 0], original_size, tile_shape)
n_tiles = tiling.numberOfBlocks
Expand Down
Loading
Loading