Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 3d auto segmentation #48

Merged
merged 4 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion development/annotator_2d_tiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def annotator_with_tiling():
# napari.run()

embedding_path = "./embeddings/embeddings-tiled.zarr"
annotator_2d(im, embedding_path, tile_shape=(512, 512), halo=(64, 64))
annotator_2d(im, embedding_path, tile_shape=(1024, 1024), halo=(256, 256))


def debug():
Expand Down
19 changes: 17 additions & 2 deletions development/annotator_3d_tiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,28 @@
def annotator_with_tiling():
with z5py.File("/home/pape/Work/data/cremi/sampleA.n5", "r") as f:
raw = f["volumes/raw/s0"][:25]

embedding_path = "./embeddings/embeddings-tiled_3d.zarr"
annotator_3d(raw, embedding_path, tile_shape=(512, 512), halo=(64, 64))


def segment_tiled():
import micro_sam.util as util
from micro_sam.segment_instances import segment_instances_from_embeddings_3d

with z5py.File("/home/pape/Work/data/cremi/sampleA.n5", "r") as f:
raw = f["volumes/raw/s0"][:25]
embedding_path = "./embeddings/embeddings-tiled_3d.zarr"

predictor = util.get_sam_model()
image_embeddings = util.precompute_image_embeddings(
predictor, raw, embedding_path, tile_shape=(512, 512), halo=(64, 64)
)
segment_instances_from_embeddings_3d(predictor, image_embeddings)


def main():
annotator_with_tiling()
# annotator_with_tiling()
segment_tiled()


main()
99 changes: 65 additions & 34 deletions development/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,87 @@
from micro_sam.segment_instances import (
segment_instances_from_embeddings,
segment_instances_sam,
# segment_instances_from_embeddings_with_tiling
segment_instances_from_embeddings_3d,
)
from micro_sam.visualization import compute_pca

INPUT_PATH = "../examples/data/Lucchi++/Test_In"
EMBEDDINGS_PATH = "../examples/embeddings/embeddings-mito2d.zarr"
INPUT_PATH = "/home/pape/Work/data/mouse-embryo/Nuclei/for_sam.h5"
EMBEDDINGS_PATH = "./embeddings/embedding-nuclei-3d.zarr"
TIMESERIES_PATH = "../examples/data/DIC-C2DH-HeLa/train/01"
EMBEDDINGS_TRACKING_PATH = "../examples/embeddings/embeddings-ctc.zarr"


def mito_segmentation() -> None:
"""Performs mito segmentation on the input image."""
def nucleus_segmentation(use_sam=False, use_mws=False) -> None:
"""Segment nuclei in 3d lightsheet data (one slice)."""
with open_file(INPUT_PATH) as f:
raw = f["*.png"][-1, :768, :768]
raw = f["raw"][:]

predictor, sam = util.get_sam_model(return_sam=True)
z = 32

print("Run SAM prediction ...")
seg_sam = segment_instances_sam(sam, raw)
predictor, sam = util.get_sam_model(return_sam=True, model_type="vit_b")
if use_sam:
print("Run SAM prediction ...")
seg_sam = segment_instances_sam(sam, raw[z])
else:
seg_sam = None

image_embeddings = util.precompute_image_embeddings(predictor, raw, EMBEDDINGS_PATH)
embedding_pca = compute_pca(image_embeddings["features"])
embedding_pca = compute_pca(image_embeddings["features"])[z]

print("Run prediction from embeddings ...")
seg, initial_seg = segment_instances_from_embeddings(
predictor, image_embeddings=image_embeddings, return_initial_seg=True
)
if use_mws:
print("Run prediction from embeddings ...")
seg, initial_seg = segment_instances_from_embeddings(
predictor, image_embeddings=image_embeddings, return_initial_segmentation=True,
pred_iou_thresh=0.8, verbose=1, stability_score_thresh=0.9,
i=z,
)
else:
seg, initial_seg = None, None

v = napari.Viewer()
v.add_image(raw[z])
v.add_image(embedding_pca, scale=(12, 12), visible=False)
if seg_sam is not None:
v.add_labels(seg_sam)
if seg is not None:
v.add_labels(seg)
if initial_seg is not None:
v.add_labels(initial_seg, visible=False)
napari.run()


def nucleus_segmentation_3d() -> None:
"""Segment nuclei in 3d lightsheet data (3d segmentation)."""
with open_file(INPUT_PATH) as f:
raw = f["raw"][:]

predictor = util.get_sam_model(model_type="vit_b")
image_embeddings = util.precompute_image_embeddings(predictor, raw, EMBEDDINGS_PATH)
seg = segment_instances_from_embeddings_3d(predictor, image_embeddings)

v = napari.Viewer()
v.add_image(raw)
v.add_image(embedding_pca, scale=(12, 12))
v.add_labels(seg_sam)
v.add_labels(seg)
v.add_labels(initial_seg)
napari.run()


def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None:
def cell_segmentation(use_sam=False, use_mws=False) -> None:
"""Performs cell segmentation on the input timeseries."""
with open_file(TIMESERIES_PATH, mode="r") as f:
timeseries = f["*.tif"][:50]

frame = 11
predictor, sam = util.get_sam_model(return_sam=True)

image_embeddings = util.precompute_image_embeddings(
predictor, timeseries, EMBEDDINGS_TRACKING_PATH)
image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH)

embedding_pca = compute_pca(image_embeddings["features"][frame])

if use_mws:
print("Run embedding segmentation ...")
seg_mws, initial_seg = segment_instances_from_embeddings(
predictor, image_embeddings=image_embeddings, i=frame, return_initial_seg=True,
bias=0.0, distance_type="l2",
predictor, image_embeddings=image_embeddings, i=frame, return_initial_segmentation=True,
bias=0.0, distance_type="l2", verbose=2,
)
else:
seg_mws = None
Expand All @@ -71,14 +97,6 @@ def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None:
else:
seg_sam = None

if use_tiling:
print("Run embedding segmentation with tiling ...")
seg_tiled = segment_instances_from_embeddings_with_tiling(
predictor, timeseries[frame], image_embeddings
)
else:
seg_tiled = None

v = napari.Viewer()
v.add_image(timeseries[frame])
v.add_image(embedding_pca, scale=(8, 8), visible=False)
Expand All @@ -92,19 +110,32 @@ def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None:
if seg_sam is not None:
v.add_labels(seg_sam)

if seg_tiled is not None:
v.add_labels(seg_tiled)
napari.run()


def cell_segmentation_3d() -> None:
with open_file(TIMESERIES_PATH, mode="r") as f:
timeseries = f["*.tif"][:50]

predictor = util.get_sam_model()
image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH)

seg = segment_instances_from_embeddings_3d(predictor, image_embeddings)

v = napari.Viewer()
v.add_image(timeseries)
v.add_labels(seg)
napari.run()


def main():
# automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py')
# mito_segmentation()
# nucleus_segmentation(use_mws=True)
nucleus_segmentation_3d()

# automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py')
# cell_segmentation(use_mws=True)
cell_segmentation(use_mws=True)
# cell_segmentation_3d()


if __name__ == "__main__":
Expand Down
18 changes: 11 additions & 7 deletions micro_sam/segment_from_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _compute_box_from_mask(mask, original_size=None, box_extension=0):

# sample points from a mask. SAM expects the following point inputs:
def _compute_points_from_mask(mask, original_size):
box = _compute_box_from_mask(mask, box_extension=5)
box = _compute_box_from_mask(mask, box_extension=7)

# get slice and offset in python coordinate convention
bb = (slice(box[1], box[3]), slice(box[0], box[2]))
Expand All @@ -45,16 +45,16 @@ def _compute_points_from_mask(mask, original_size):
outer_distances = gaussian(distance_transform_edt(cropped_mask == 0))

# sample positives and negatives from the distance maxima
inner_maxima = peak_local_max(inner_distances, exclude_border=False)
outer_maxima = peak_local_max(outer_distances, exclude_border=False)
inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)

# derive the positive (=inner maxima) and negative (=outer maxima) points
point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
point_coords += offset

if original_size is not None:
scale_factor = np.array([
float(mask.shape[0]) / original_size[0], float(mask.shape[1]) / original_size[1]
original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
])[None]
point_coords *= scale_factor

Expand All @@ -65,7 +65,7 @@ def _compute_points_from_mask(mask, original_size):
np.zeros(len(outer_maxima), dtype="uint8"),
]
)
return point_coords, point_labels
return point_coords[:, ::-1], point_labels


def _compute_logits_from_mask(mask, eps=1e-3):
Expand Down Expand Up @@ -278,11 +278,15 @@ def segment_from_mask(
predictor, image_embeddings, i, mask, _mask_to_tile
)

point_coords, point_labels = _compute_points_from_mask(mask, original_size=original_size)
if use_points:
point_coords, point_labels = _compute_points_from_mask(mask, original_size=original_size)
else:
point_coords, point_labels = None, None
box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension) if use_box else None
logits = _compute_logits_from_mask(mask) if use_mask else None

mask, scores, logits = predictor.predict(
point_coords=point_coords[:, ::-1], point_labels=point_labels,
point_coords=point_coords, point_labels=point_labels,
mask_input=logits, box=box,
multimask_output=multimask_output, return_logits=return_logits
)
Expand Down
Loading