Embedding widget for napari plugin #235

Merged
30 commits
21829d5
Start on embedding widget
GenevieveBuckley Oct 17, 2023
ac0412e
Add missing typing import in visualization.py file
GenevieveBuckley Oct 17, 2023
d8eb08b
Fix npe2 napari manifest validation
GenevieveBuckley Oct 17, 2023
4a2ed76
Graceful error handling for torch device backend selection
GenevieveBuckley Oct 19, 2023
38734f0
Embedding widget for napari plugin
GenevieveBuckley Oct 19, 2023
9bbf661
Merge branch 'dev' into embedding-widget
GenevieveBuckley Oct 19, 2023
f0b3f49
Fix _available_devices utility function
GenevieveBuckley Oct 19, 2023
c18af45
Merge branch 'embedding-widget' of github.com:GenevieveBuckley/micro-…
GenevieveBuckley Oct 19, 2023
45c1a57
More clear variable names
GenevieveBuckley Oct 19, 2023
b9c37f0
Add GUI test for embedding widget
GenevieveBuckley Oct 19, 2023
3405039
Thread worker function doesn't actually return object the same way si…
GenevieveBuckley Oct 19, 2023
0314ef2
Fix rgb ndim calculation
GenevieveBuckley Oct 19, 2023
522b27b
Move location where global IMAGE_EMBEDDINGS is defined
GenevieveBuckley Oct 19, 2023
b39d5c9
Order of output files does not matter for test
GenevieveBuckley Oct 23, 2023
174a9c6
Fix ndim for rgb images in embedding_widget
GenevieveBuckley Oct 23, 2023
147797c
Let's be careful since now the progress bar is an unnamed argument to …
GenevieveBuckley Oct 23, 2023
9b1730f
Match progress bar with thread worker example from napari/examples
GenevieveBuckley Oct 23, 2023
5c573b9
Sanitize user string input to _get_device()
GenevieveBuckley Oct 24, 2023
464d7c3
Merge branch 'dev' into embedding-widget-singleton
GenevieveBuckley Nov 3, 2023
ae9b611
Workaround for issue 246
GenevieveBuckley Nov 3, 2023
c1fe954
Image embedding widget now uses singleton AnnotatorState
GenevieveBuckley Nov 3, 2023
8e75871
Embedding widget, ensure save directory exists and is empty
GenevieveBuckley Nov 3, 2023
1de747e
Don't set a device for custom checkpoint export
constantinpape Nov 4, 2023
1c68c03
More concise code with os.makedirs exist_ok
GenevieveBuckley Nov 5, 2023
856da60
Merge branch 'embedding-widget' of github.com:GenevieveBuckley/micro-…
GenevieveBuckley Nov 5, 2023
bfa8c57
Upgrade invalid embeddings path from user warning to runtime error (s…
GenevieveBuckley Nov 7, 2023
51e96bf
Move all computation into thread worker, allow previously computed em…
GenevieveBuckley Nov 7, 2023
c0c83a4
Add reset_state method to clear all attributes held in state
GenevieveBuckley Nov 7, 2023
a4ad1dc
Remove data_signature attribute from AnnotatorState attributes
GenevieveBuckley Nov 7, 2023
ca3c77e
Embedding widget, image_shape in annotator state
GenevieveBuckley Nov 7, 2023
6 changes: 6 additions & 0 deletions micro_sam/napari.yaml
@@ -23,6 +23,9 @@ contributions:
- id: micro-sam.sample_data_segmentation
python_name: micro_sam.sample_data:sample_data_segmentation
title: Load segmentation sample data from micro-sam plugin
- id: micro-sam.embedding_widget
python_name: micro_sam.sam_annotator._widgets:embedding_widget
title: Embedding widget
sample_data:
- command: micro-sam.sample_data_image_series
display_name: Image series example data
@@ -45,3 +48,6 @@ contributions:
- command: micro-sam.sample_data_segmentation
display_name: Segmentation sample dataset
key: micro-sam-segmentation
widgets:
- command: micro-sam.embedding_widget
display_name: Embedding widget
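The manifest entries above map napari command ids to `python_name` references of the form `module.path:attribute`. As a rough illustration of how such a reference resolves to a Python object (a simplified sketch, not npe2's actual loader; `resolve_python_name` is a hypothetical helper):

```python
import importlib

def resolve_python_name(python_name: str):
    """Resolve an npe2-style 'module.path:attribute' reference to the
    actual Python object (illustrative sketch only)."""
    module_path, attribute = python_name.split(":")
    module = importlib.import_module(module_path)
    return getattr(module, attribute)

# Resolving a stdlib reference the same way, as a neutral example:
join_func = resolve_python_name("os.path:join")
```

Under this scheme, `micro_sam.sam_annotator._widgets:embedding_widget` would resolve to the `embedding_widget` factory in that module.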
10 changes: 10 additions & 0 deletions micro_sam/sam_annotator/_state.py
@@ -46,10 +46,10 @@
init_sum = sum((have_image_embeddings, have_predictor, have_image_shape))
if init_sum == 3:
return True
elif init_sum == 0:
return False

else:
raise RuntimeError(

f"Invalid AnnotatorState: {init_sum} / 3 parts of the state "
"needed for interactive segmentation are initialized."
)
@@ -60,10 +60,20 @@
init_sum = sum((have_current_track_id, have_lineage))
if init_sum == 2:
return True
elif init_sum == 0:
return False

else:
raise RuntimeError(

f"Invalid AnnotatorState: {init_sum} / 2 parts of the state "
"needed for tracking are initialized."
)

def reset_state(self):
"""Reset state, clear all attributes."""
self.image_embeddings = None
self.predictor = None
self.image_shape = None
self.amg = None
self.amg_state = None
self.current_track_id = None
self.lineage = None
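Per the commit "Image embedding widget now uses singleton AnnotatorState", every part of the plugin reads and writes one shared state object, and `reset_state` clears it between images. A minimal standalone sketch of that singleton-with-reset pattern (illustrative names and attributes; not the real micro_sam implementation):

```python
class SingletonState:
    """Sketch of a singleton state holder with a reset method,
    mirroring how AnnotatorState is used by the widgets."""
    _instance = None

    def __new__(cls):
        # Every call returns the same instance, so the embedding widget
        # and the annotator share one state object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.reset_state()
        return cls._instance

    def reset_state(self):
        """Reset state, clear all attributes."""
        self.image_embeddings = None
        self.predictor = None
        self.image_shape = None

state_a = SingletonState()
state_a.image_shape = (512, 512)
state_b = SingletonState()  # the same object as state_a
```

Because construction always hands back the same instance, a widget can call `SingletonState().reset_state()` at the start of a run to guarantee no stale embeddings leak in, which is exactly what `embedding_widget` does below.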
84 changes: 84 additions & 0 deletions micro_sam/sam_annotator/_widgets.py
@@ -0,0 +1,84 @@
from enum import Enum
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from magicgui import magic_factory, widgets
from napari.qt.threading import thread_worker
import zarr
from zarr.errors import PathNotFoundError

from micro_sam.sam_annotator._state import AnnotatorState
from micro_sam.util import (
ImageEmbeddings,
get_sam_model,
precompute_image_embeddings,
_MODEL_URLS,
_DEFAULT_MODEL,
_available_devices,
)

if TYPE_CHECKING:
import napari


Model = Enum("Model", _MODEL_URLS)
available_devices_list = ["auto"] + _available_devices()


@magic_factory(
pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'},
call_button="Compute image embeddings",
device={"choices": available_devices_list},
save_path={"mode": "d"}, # choose a directory
)
def embedding_widget(
pbar: widgets.ProgressBar,
image: "napari.layers.Image",
model: Model = Model.__getitem__(_DEFAULT_MODEL),
device = "auto",
save_path: Optional[Path] = None, # where embeddings for this image are cached (optional)
optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.
) -> ImageEmbeddings:
"""Image embedding widget."""
state = AnnotatorState()
state.reset_state()
# Get image dimensions
if image.rgb:
ndim = image.data.ndim - 1
state.image_shape = image.data.shape[:-1]

else:
ndim = image.data.ndim
state.image_shape = image.data.shape

@thread_worker(connect={'started': pbar.show, 'finished': pbar.hide})
def _compute_image_embedding(state, image_data, save_path, ndim=None,
device="auto", model=Model.__getitem__(_DEFAULT_MODEL),
optional_custom_weights=None):
# Make sure save directory exists and is an empty directory
if save_path is not None:
os.makedirs(save_path, exist_ok=True)
if not save_path.is_dir():
raise NotADirectoryError(

f"The user selected 'save_path' is not a directory: {save_path}"
)
if len(os.listdir(save_path)) > 0:
try:
zarr.open(save_path, "r")
except PathNotFoundError:
raise RuntimeError(

"The user selected 'save_path' is not a zarr array "
f"or empty directory: {save_path}"
)
# Initialize the model
state.predictor = get_sam_model(device=device, model_type=model.name,

checkpoint_path=optional_custom_weights)
# Compute the image embeddings
state.image_embeddings = precompute_image_embeddings(

predictor=state.predictor,
input_=image_data,
save_path=str(save_path),
ndim=ndim,
)
return state # returns napari._qt.qthreading.FunctionWorker


return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights)
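The widget above keeps the GUI responsive by running the embedding computation in a thread worker, with `started`/`finished` callbacks driving the progress bar (`connect={'started': pbar.show, 'finished': pbar.hide}`). A standalone sketch of that pattern, using `concurrent.futures` in place of napari's `thread_worker` decorator (function and callback names here are illustrative, not napari API):

```python
from concurrent.futures import ThreadPoolExecutor

def run_in_worker(compute, on_started, on_finished):
    """Sketch of the thread-worker pattern used by embedding_widget:
    run `compute` off the calling thread and fire started/finished
    callbacks around it."""
    on_started()  # e.g. pbar.show

    def wrapped():
        try:
            return compute()
        finally:
            on_finished()  # e.g. pbar.hide, fired in the worker thread

    executor = ThreadPoolExecutor(max_workers=1)
    return executor.submit(wrapped)

events = []
future = run_in_worker(
    compute=lambda: sum(range(10)),  # stands in for the embedding computation
    on_started=lambda: events.append("started"),
    on_finished=lambda: events.append("finished"),
)
result = future.result()  # blocks, like worker.await_workers() in the test below
```

Returning the worker (rather than the computed value) is what lets the GUI test block on completion before asserting on the shared state.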
39 changes: 33 additions & 6 deletions micro_sam/util.py
@@ -127,10 +127,7 @@
return checkpoint_path


def _get_device(device):
if device is not None:
return device

def _get_default_device():
# Use cuda enabled gpu if it's available.
if torch.cuda.is_available():
device = "cuda"
@@ -145,6 +142,36 @@
return device


def _get_device(device=None):
if device is None or device == "auto":
device = _get_default_device()
else:
if device.lower() == "cuda":
if not torch.cuda.is_available():
raise RuntimeError("PyTorch CUDA backend is not available.")
elif device.lower() == "mps":
if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
elif device.lower() == "cpu":
pass # cpu is always available
else:
raise RuntimeError(f"Unsupported device: {device}\n"

"Please choose from 'cpu', 'cuda', or 'mps'.")
return device


def _available_devices():
available_devices = []
for i in ["cuda", "mps", "cpu"]:
try:
device = _get_device(i)
except RuntimeError:
pass
else:
available_devices.append(device)
return available_devices


def get_sam_model(
model_type: str = _DEFAULT_MODEL,
device: Optional[str] = None,
@@ -269,7 +296,7 @@
save_path: Where to save the exported model.
"""
_, state = get_custom_sam_model(
checkpoint_path, model_type=model_type, return_state=True, device=torch.device("cpu"),
checkpoint_path, model_type=model_type, return_state=True, device="cpu",
)
model_state = state["model_state"]
prefix = "sam."
@@ -552,7 +579,7 @@
continue
# check whether the key signature does not match or is not in the file
if key not in f.attrs or f.attrs[key] != val:
warnings.warn(
raise RuntimeError(

f"Embeddings file {save_path} is invalid due to mismatched {key}: "
f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
)

GenevieveBuckley (Collaborator, Author): Upgrading this from a user warning to an actual error allows users to see a little pop-up in the bottom of the napari viewer if an error happens. I think this is sufficient for what we need.

constantinpape (Contributor): That's good! I think it's good to update this to a real error in any case, since having different embeddings will lead to nonsensical results. I think we had a warning there in the beginning because we were still testing this functionality out, and then never got around to updating it to an error.
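The `_get_device` / `_available_devices` pair above validates a user-supplied device name and falls back through cuda, mps, cpu for "auto". A standalone sketch of that logic with the torch availability queries replaced by plain flags (illustrative only; the real code asks `torch.cuda.is_available()` and `torch.backends.mps`):

```python
def pick_device(requested=None, cuda_ok=False, mps_ok=False):
    """Sketch of _get_device: 'auto'/None falls back in preference
    order, explicit names are validated against availability."""
    available = {"cuda": cuda_ok, "mps": mps_ok, "cpu": True}
    if requested is None or requested == "auto":
        for name in ("cuda", "mps", "cpu"):  # cpu is always available
            if available[name]:
                return name
    name = requested.lower()  # sanitize user string input
    if name not in available:
        raise RuntimeError(f"Unsupported device: {requested}\n"
                           "Please choose from 'cpu', 'cuda', or 'mps'.")
    if not available[name]:
        raise RuntimeError(f"PyTorch {name.upper()} backend is not available.")
    return name

def list_devices(cuda_ok=False, mps_ok=False):
    """Sketch of _available_devices: keep every device that validates."""
    devices = []
    for name in ("cuda", "mps", "cpu"):
        try:
            devices.append(pick_device(name, cuda_ok, mps_ok))
        except RuntimeError:
            pass  # backend not usable on this machine, skip it
    return devices
```

This is why the widget can offer `["auto"] + _available_devices()` as its dropdown choices: every listed entry has already passed the same validation that `_get_device` applies.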
2 changes: 0 additions & 2 deletions micro_sam/visualization.py
@@ -3,8 +3,6 @@
"""
from typing import Tuple

from typing import Tuple

import numpy as np

from elf.segmentation.embeddings import embedding_pca
46 changes: 46 additions & 0 deletions test/test_sam_annotator/test_widgets.py
@@ -0,0 +1,46 @@
import json
import os

from mobile_sam.predictor import SamPredictor as MobileSamPredictor
from segment_anything.predictor import SamPredictor
import torch
import zarr

from micro_sam.sam_annotator._state import AnnotatorState
from micro_sam.sam_annotator._widgets import embedding_widget, Model
from micro_sam.util import _compute_data_signature


# make_napari_viewer is a pytest fixture that returns a napari viewer object
# you don't need to import it, as long as napari is installed
# in your testing environment.
# tmp_path is a regular pytest fixture.
def test_embedding_widget(make_napari_viewer, tmp_path):
"""Test embedding widget for micro-sam napari plugin."""
# setup
viewer = make_napari_viewer()
layer = viewer.open_sample('napari', 'camera')[0]
my_widget = embedding_widget()
# run image embedding widget
worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path)
worker.await_workers() # blocks until thread worker is finished the embedding
# Check in-memory state - predictor
assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor))
# Check in-memory state - image embeddings
assert AnnotatorState().image_embeddings is not None
assert 'features' in AnnotatorState().image_embeddings.keys()
assert 'input_size' in AnnotatorState().image_embeddings.keys()
assert 'original_size' in AnnotatorState().image_embeddings.keys()
assert isinstance(AnnotatorState().image_embeddings["features"], torch.Tensor)
assert AnnotatorState().image_embeddings["original_size"] == layer.data.shape
# Check saved embedding results are what we expect to have
temp_path_files = os.listdir(tmp_path)
temp_path_files.sort()
assert temp_path_files == ['.zattrs', '.zgroup', 'features']
with open(os.path.join(tmp_path, ".zattrs")) as f:
content = f.read()
zarr_dict = json.loads(content)
assert zarr_dict.get("original_size") == list(layer.data.shape)
assert zarr_dict.get("data_signature") == _compute_data_signature(layer.data)
assert zarr.open(os.path.join(tmp_path, "features")).shape == (1, 256, 64, 64)
viewer.close() # must close the viewer at the end of tests
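The tail of the test inspects the on-disk zarr metadata by reading the `.zattrs` JSON file directly. A stdlib-only sketch of that round-trip check (example values, not the real embedding attributes; zarr itself is not needed for this part):

```python
import json
import os
import tempfile

# Write a .zattrs-style JSON file and read it back, the same way the
# test compares "original_size" and "data_signature" against the layer.
attrs = {"original_size": [512, 512], "data_signature": "abc123"}  # example values

with tempfile.TemporaryDirectory() as tmp:
    zattrs_path = os.path.join(tmp, ".zattrs")
    with open(zattrs_path, "w") as f:
        json.dump(attrs, f)
    with open(zattrs_path) as f:
        loaded = json.load(f)

# JSON stores tuples as lists, which is why the test compares against
# list(layer.data.shape) rather than the shape tuple itself.
original_size_matches = loaded.get("original_size") == [512, 512]
```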