Prototype for annotator plugins (#304)
Refactor napari functionality into napari plugins.
constantinpape authored Jan 18, 2024
1 parent 8b0ec02 commit 4aff57f
Showing 27 changed files with 1,502 additions and 1,751 deletions.
38 changes: 23 additions & 15 deletions doc/development.md
@@ -1,48 +1,56 @@
# For Developers

This software consists of four different python (sub-)modules:
- The top-level `micro_sam` module implements general purpose functionality for using Segment Anything for multi-dimensional data.
- `micro_sam.evaluation` provides functionality to evaluate Segment Anything models on (microscopy) segmentation tasks.
- `micro_sam.training` implements the training functionality to finetune Segment Anything for custom segmentation datasets.
- `micro_sam.sam_annotator` implements the interactive annotation tools.

## Annotation Tools

The annotation tools are implemented as napari plugins.

There are four annotation tools:
- `micro_sam.sam_annotator.annotator_2d`: for interactive segmentation of 2d images.
- `micro_sam.sam_annotator.annotator_3d`: for interactive segmentation of volumetric images.
- `micro_sam.sam_annotator.annotator_tracking`: for interactive tracking in timeseries of 2d images.
- `micro_sam.sam_annotator.image_series_annotator`: for applying the 2d annotation tool to a series of images. This is not implemented as a separate plugin, but as a function that runs the 2d annotator for multiple images.

An overview of the functionality of the different tools:

| Functionality | annotator_2d | annotator_3d | annotator_tracking |
| ------------- | ------------ | ------------ | ------------------ |
| Interactive segmentation | Yes | Yes | Yes |
| Interactive segmentation for multiple objects at a time | Yes | No | No |
| Interactive 3d segmentation via projection | No | Yes | Yes |
| Support for dividing objects | No | No | Yes |
| Automatic segmentation | Yes | Yes | No |

The functionality for `image_series_annotator` is not listed because it is identical to that of `annotator_2d`.

Each tool implements the following core logic:
1. The image embeddings (predictions of the SAM image encoder) are pre-computed for the input data (2d image, image volume or time series). These embeddings can be cached to a zarr file (see the sketch after this list).
2. Interactive (and automatic) segmentation functionality is implemented by a UI based on `napari` and `magicgui` functionality.
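
A minimal sketch of this precomputation step, assuming the helpers `get_sam_model` and `precompute_image_embeddings` from `micro_sam.util` (the image path is a placeholder):

```python
import imageio.v3 as imageio
from micro_sam import util

# Load the input image (placeholder path).
image = imageio.imread("/path/to/image.tif")

# Get the Segment Anything model and precompute the image embeddings.
# Passing save_path caches the embeddings in a zarr file, so they are
# loaded from there instead of being recomputed on the next start.
predictor = util.get_sam_model(model_type="vit_h")
image_embeddings = util.precompute_image_embeddings(
    predictor, image, save_path="embeddings.zarr"
)
```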

Each tool has three different entry points:
- From the napari plugin menu, e.g. `plugin->micro_sam->Annotator 2d`. (Called *plugin* in the following.)
- From a python function, e.g. `micro_sam.sam_annotator.annotator_2d:annotator_2d`. (Called *function* in the following.)
- From the command line, e.g. `micro_sam.annotator_2d`. (Called *CLI* in the following; see the example after this list.)
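
For example, the *function* entry point for the 2d annotator can be used as follows (paths are placeholders; the parameter names follow the calls in the `examples` folder below):

```python
import imageio.v3 as imageio
from micro_sam.sam_annotator import annotator_2d

# Start the 2d annotator via the function entry point.
# The embeddings are cached to the given zarr path.
image = imageio.imread("/path/to/image.tif")
annotator_2d(image, embedding_path="embeddings.zarr", model_type="vit_h")
```

The equivalent *CLI* call is `micro_sam.annotator_2d -i /path/to/image.tif -e embeddings.zarr` (compare the CLI examples in the `examples` scripts further down).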

Each tool is implemented in its own submodule, e.g. `micro_sam.sam_annotator.annotator_2d`, with shared functionality implemented in `micro_sam.sam_annotator.util`.
The napari plugin is implemented by a class, e.g. `micro_sam.sam_annotator.annotator_2d:Annotator2d`, that inherits from `micro_sam.sam_annotator._annotator._AnnotatorBase`. This base class implements the core logic for the plugins.
The concrete annotation tools are instantiated by passing widgets from `micro_sam.sam_annotator._widgets` to it, which implement the interactive segmentation in 2d, 3d etc. (see the sketch below).
These plugins are designed so that image embeddings can be computed for user-specified image layers in napari.

The *function* and *CLI* entry points are implemented by `micro_sam.sam_annotator.annotator_2d:annotator_2d` (and corresponding functions for the other annotation tools). They are called with image data, precompute the embeddings for that image and start a napari viewer with this image and the annotation plugin.
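
A simplified sketch of how a concrete annotator is assembled from `_AnnotatorBase`; the widget names passed here are illustrative assumptions, not necessarily the exact names in `micro_sam.sam_annotator._widgets`:

```python
from micro_sam.sam_annotator import _widgets as widgets
from micro_sam.sam_annotator._annotator import _AnnotatorBase


class Annotator2d(_AnnotatorBase):
    """Napari plugin for interactive 2d segmentation."""

    def __init__(self, viewer):
        # Pass the widgets for interactive and automatic 2d segmentation
        # to the base class, which sets up the shared layers, widgets
        # and key bindings.
        super().__init__(
            viewer=viewer,
            ndim=2,
            segment_widget=widgets.segment,          # assumed widget name
            autosegment_widget=widgets.autosegment,  # assumed widget name
        )
```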

<!--
TODO update the flow chart so that it matches the new design.
The same overall design holds true for the other plugins. The flow chart below gives a simplified overview of the design of the 2d annotation tool. Rounded squares represent functions or the corresponding widget, squares represent napari layers or other data; orange marks the *plugin* entry point, cyan the *CLI*. Arrows without a label correspond to a simple input/output relation.
![annotator 2d flow diagram](./images/2d-annotator-flow.png)
-->

<!---
Source for the diagram is here:
7 changes: 5 additions & 2 deletions examples/annotator_2d.py
@@ -25,7 +25,7 @@ def livecell_annotator(use_finetuned_model):
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell.zarr")
model_type = "vit_h"

annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type)
annotator_2d(image, embedding_path, model_type=model_type)


def hela_2d_annotator(use_finetuned_model):
@@ -41,7 +41,7 @@ def hela_2d_annotator(use_finetuned_model):
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr")
model_type = "vit_h"

annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type, precompute_amg_state=True)
annotator_2d(image, embedding_path, model_type=model_type)


def wholeslide_annotator(use_finetuned_model):
@@ -77,5 +77,8 @@ def main():
    # wholeslide_annotator(use_finetuned_model)


# The corresponding CLI call for hela_2d_annotator:
# (replace with cache directory on your machine)
# $ micro_sam.annotator_2d -i /home/pape/.cache/micro_sam/sample_data/hela-2d-image.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-hela2d.zarr
if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions examples/annotator_3d.py
@@ -42,5 +42,8 @@ def main():
    em_3d_annotator(finetuned_model)


# The corresponding CLI call for em_3d_annotator:
# (replace with cache directory on your machine)
# $ micro_sam.annotator_3d -i /home/pape/.cache/micro_sam/sample_data/lucchi_pp.zip.unzip/Lucchi++/Test_In -k *.png -e /home/pape/.cache/micro_sam/embeddings/embeddings-lucchi.zarr
if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions examples/annotator_tracking.py
@@ -36,5 +36,8 @@ def main():
    track_ctc_data(use_finetuned_model)


# The corresponding CLI call for track_ctc_data:
# (replace with cache directory on your machine)
# $ micro_sam.annotator_tracking -i /home/pape/.cache/micro_sam/sample_data/DIC-C2DH-HeLa.zip.unzip/DIC-C2DH-HeLa/01 -k *.tif -e /home/pape/.cache/micro_sam/embeddings/embeddings-ctc.zarr
if __name__ == "__main__":
    main()
11 changes: 8 additions & 3 deletions examples/image_series_annotator.py
@@ -22,9 +22,11 @@ def series_annotation(use_finetuned_model):

    example_data = fetch_image_series_example_data(DATA_CACHE)
    image_folder_annotator(
        example_data, "./series-segmentation-result",
        pattern="*.tif",
        embedding_path=embedding_path,
        model_type=model_type,
        precompute_amg_state=False,
    )


@@ -34,5 +36,8 @@ def main():
    series_annotation(use_finetuned_model)


# The corresponding CLI call for series_annotation:
# (replace with cache directory on your machine)
# $ micro_sam.image_series_annotator -i /home/pape/.cache/micro_sam/sample_data/image-series.zip.unzip/series/ -e /home/pape/.cache/micro_sam/embeddings/series-embeddings/ -o segmentation_results
if __name__ == "__main__":
    main()
22 changes: 22 additions & 0 deletions micro_sam/_test_util.py
@@ -0,0 +1,22 @@
import numpy as np


def check_layer_initialization(viewer, expected_shape):
    """Utility function to check the initial layer setup is correct."""

    assert len(viewer.layers) == 6
    expected_layer_names = [
        "image", "auto_segmentation", "committed_objects", "current_object", "point_prompts", "prompts"
    ]

    for layer_name in expected_layer_names:
        assert layer_name in viewer.layers

    # Check the prompt layers.
    assert viewer.layers["prompts"].data == []  # shape data is a list, not a numpy array
    np.testing.assert_equal(viewer.layers["point_prompts"].data, 0)

    # Check the segmentation layers.
    for layer_name in ["auto_segmentation", "committed_objects", "current_object"]:
        assert viewer.layers[layer_name].data.shape == expected_shape
        np.testing.assert_equal(viewer.layers[layer_name].data, 0)
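
A sketch of how this utility might be used in a test, assuming napari's `make_napari_viewer` pytest fixture and the `Annotator2d` plugin class introduced in this commit:

```python
import numpy as np
from micro_sam._test_util import check_layer_initialization
from micro_sam.sam_annotator.annotator_2d import Annotator2d


def test_annotator_2d_initialization(make_napari_viewer):
    viewer = make_napari_viewer()
    viewer.add_image(np.zeros((512, 512), dtype="uint8"), name="image")

    # Create the annotator plugin and add it to the viewer.
    # Its label layers use the dummy 2d shape (256, 256) until
    # embeddings have been computed for a concrete image.
    annotator = Annotator2d(viewer)
    viewer.window.add_dock_widget(annotator)

    check_layer_initialization(viewer, expected_shape=(256, 256))
```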
34 changes: 25 additions & 9 deletions micro_sam/napari.yaml
@@ -2,6 +2,8 @@ name: micro-sam
display_name: SegmentAnything for Microscopy
contributions:
  commands:

    # Commands for sample data.
    - id: micro-sam.sample_data_image_series
      python_name: micro_sam.sample_data:sample_data_image_series
      title: Load image series sample data from micro-sam plugin
@@ -23,12 +25,21 @@ contributions:
    - id: micro-sam.sample_data_segmentation
      python_name: micro_sam.sample_data:sample_data_segmentation
      title: Load segmentation sample data from micro-sam plugin

    # Commands for plugins.
    - id: micro-sam.annotator_2d
      python_name: micro_sam.sam_annotator.annotator_2d:Annotator2d
      title: Start the 2d annotator
    - id: micro-sam.annotator_3d
      python_name: micro_sam.sam_annotator.annotator_3d:Annotator3d
      title: Start the 3d annotator
    - id: micro-sam.annotator_tracking
      python_name: micro_sam.sam_annotator.annotator_tracking:AnnotatorTracking
      title: Start the tracking annotator
    - id: micro-sam.settings
      python_name: micro_sam.sam_annotator._widgets:settings_widget
      title: Set cache directory

  sample_data:
    - command: micro-sam.sample_data_image_series
      display_name: Image series example data
@@ -51,8 +62,13 @@
    - command: micro-sam.sample_data_segmentation
      display_name: Segmentation sample dataset
      key: micro-sam-segmentation

  widgets:
    - command: micro-sam.annotator_2d
      display_name: Annotator 2d
    - command: micro-sam.annotator_3d
      display_name: Annotator 3d
    - command: micro-sam.annotator_tracking
      display_name: Annotator Tracking
    - command: micro-sam.settings
      display_name: Settings
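
With these contributions in place, the annotators can also be started programmatically through napari's plugin machinery; a sketch, assuming napari's `add_plugin_dock_widget` API:

```python
import napari

viewer = napari.Viewer()
# Open the dock widget contributed by the micro-sam plugin.
# The arguments must match the plugin name and the widget display_name.
dock_widget, plugin_widget = viewer.window.add_plugin_dock_widget(
    "micro-sam", "Annotator 2d"
)
napari.run()
```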
4 changes: 1 addition & 3 deletions micro_sam/sam_annotator/__init__.py
@@ -1,8 +1,6 @@
"""
The interactive annotation tools.
"""The interactive annotation tools.
"""

from .annotator import annotator
from .annotator_2d import annotator_2d
from .annotator_3d import annotator_3d
from .annotator_tracking import annotator_tracking
160 changes: 160 additions & 0 deletions micro_sam/sam_annotator/_annotator.py
@@ -0,0 +1,160 @@
import numpy as np

from magicgui.widgets import Container, Widget

from . import _widgets as widgets
from . import util as vutil
from ._state import AnnotatorState

from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
    import napari


class _AnnotatorBase(Container):
    """Base class for micro_sam annotation plugins.

    Implements the logic for the 2d, 3d and tracking annotator.
    The annotators differ in their data dimensionality and their widgets.
    """

    def _create_layers(self, segmentation_result):
        # Add the point layer for point prompts.
        self._point_labels = ["positive", "negative"]
        self._point_prompt_layer = self._viewer.add_points(
            name="point_prompts",
            property_choices={"label": self._point_labels},
            edge_color="label",
            edge_color_cycle=vutil.LABEL_COLOR_CYCLE,
            symbol="o",
            face_color="transparent",
            edge_width=0.5,
            size=12,
            ndim=self._ndim,
        )
        self._point_prompt_layer.edge_color_mode = "cycle"

        # Add the shape layer for box and other shape prompts.
        self._viewer.add_shapes(
            face_color="transparent", edge_color="green", edge_width=4, name="prompts", ndim=self._ndim,
        )

        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
        dummy_data = np.zeros(self._shape, dtype="uint32")
        self._viewer.add_labels(data=dummy_data, name="current_object")
        self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
        self._viewer.add_labels(
            data=dummy_data if segmentation_result is None else segmentation_result, name="committed_objects"
        )
        # Randomize colors so it is easy to see when objects are committed.
        self._viewer.layers["committed_objects"].new_colormap()

    def _create_widgets(self, segment_widget, segment_nd_widget, autosegment_widget, commit_widget, clear_widget):
        self._embedding_widget = widgets.embedding_widget()
        # Connect the call button of the embedding widget with a function
        # that updates all relevant layers when the image changes.
        self._embedding_widget.call_button.changed.connect(self._update_image)

        self._prompt_widget = widgets.create_prompt_menu(self._point_prompt_layer, self._point_labels)
        self._segment_widget = segment_widget()
        widget_list = [self._embedding_widget, self._prompt_widget, self._segment_widget]

        if segment_nd_widget is not None:
            self._segment_nd_widget = segment_nd_widget()
            widget_list.append(self._segment_nd_widget)

        if autosegment_widget is not None:
            self._autosegment_widget = autosegment_widget()
            widget_list.append(self._autosegment_widget)

        self._commit_widget = commit_widget()
        self._clear_widget = clear_widget()
        widget_list.extend([self._commit_widget, self._clear_widget])

        # Add the widgets to the container.
        self.extend(widget_list)

    def _create_keybindings(self):
        @self._viewer.bind_key("s")
        def _segment(viewer):
            self._segment_widget(viewer)

        @self._viewer.bind_key("c")
        def _commit(viewer):
            self._commit_widget(viewer)

        @self._viewer.bind_key("t")
        def _toggle_label(event=None):
            vutil.toggle_label(self._point_prompt_layer)

        @self._viewer.bind_key("Shift-C")
        def _clear_annotations(viewer):
            self._clear_widget(viewer)

        if hasattr(self, "_segment_nd_widget"):
            @self._viewer.bind_key("Shift-S")
            def _seg_nd(viewer):
                self._segment_nd_widget(viewer)

    # TODO
    # We could implement a better way of initializing the segmentation result,
    # so that instead of just passing a numpy array an existing layer from the napari
    # viewer can be chosen.
    # See https://github.com/computational-cell-analytics/micro-sam/issues/335
    def __init__(
        self,
        viewer: "napari.viewer.Viewer",
        ndim: int,
        segment_widget: Widget,
        segment_nd_widget: Optional[Widget] = None,
        autosegment_widget: Optional[Widget] = None,
        commit_widget: Widget = widgets.commit_segmentation_widget,
        clear_widget: Widget = widgets.clear_widget,
        segmentation_result: Optional[np.ndarray] = None,
    ) -> None:
        """
        Args:
            viewer: The napari viewer the annotator is added to.
            ndim: The dimensionality of the data (2 or 3).
            segment_widget: The widget for interactive segmentation.
            segment_nd_widget: The optional widget for interactive nd segmentation via projection.
            autosegment_widget: The optional widget for automatic segmentation.
            commit_widget: The widget for committing the current object.
            clear_widget: The widget for clearing the current annotations.
            segmentation_result: An optional initial segmentation result.
        """
        super().__init__()
        self._viewer = viewer

        # Add the layers for prompts and segmented objects.
        # We initialize these with a dummy shape, which is reset to the
        # correct shape once an image is set.
        self._ndim = ndim
        self._shape = (256, 256) if ndim == 2 else (16, 256, 256)
        self._create_layers(segmentation_result)

        # Add the widgets in common between all annotators.
        self._create_widgets(
            segment_widget, segment_nd_widget, autosegment_widget, commit_widget, clear_widget,
        )

        # Add the key bindings in common between all annotators.
        self._create_keybindings()

    def _update_image(self):
        state = AnnotatorState()

        # Update the image shape if it has changed.
        if state.image_shape != self._shape:
            if len(state.image_shape) != self._ndim:
                raise RuntimeError(
                    f"The dimensionality of the annotator ({self._ndim}) does not match "
                    f"the image data of shape {state.image_shape}."
                )
            self._shape = state.image_shape

        # Reset all layers.
        self._viewer.layers["current_object"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")

        vutil.clear_annotations(self._viewer, clear_segmentations=False)
