Merge pull request #48 from computational-cell-analytics/auto-seg-3d
Implement 3d auto segmentation
constantinpape authored Jun 20, 2023
2 parents: 26b8191 + d9f69f9, commit ca2f590
Showing 5 changed files with 242 additions and 76 deletions.
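
The core of the change is the new segment_instances_from_embeddings_3d function in micro_sam.segment_instances, which the updated development scripts call on precomputed (optionally tiled) image embeddings to get an automatic instance segmentation of a whole volume. A minimal sketch of that workflow, assembled from the scripts changed below; the random input volume, embedding path, and tile settings are placeholders, not values prescribed by the commit:

import numpy as np

import micro_sam.util as util
from micro_sam.segment_instances import segment_instances_from_embeddings_3d

# placeholder 3d volume; the development scripts load real data with z5py / open_file
raw = np.random.randint(0, 255, size=(25, 1024, 1024), dtype="uint8")

# precompute (optionally tiled) image embeddings, cached in a zarr container
predictor = util.get_sam_model()
image_embeddings = util.precompute_image_embeddings(
    predictor, raw, "./embeddings/embeddings-3d.zarr",
    tile_shape=(512, 512), halo=(64, 64),  # tiling is optional
)

# run automatic 3d instance segmentation from the precomputed embeddings
seg = segment_instances_from_embeddings_3d(predictor, image_embeddings)
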
development/annotator_2d_tiled.py (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@ def annotator_with_tiling():
# napari.run()

embedding_path = "./embeddings/embeddings-tiled.zarr"
annotator_2d(im, embedding_path, tile_shape=(512, 512), halo=(64, 64))
annotator_2d(im, embedding_path, tile_shape=(1024, 1024), halo=(256, 256))


def debug():
development/annotator_3d_tiled.py (19 changes: 17 additions & 2 deletions)
@@ -5,13 +5,28 @@
def annotator_with_tiling():
with z5py.File("/home/pape/Work/data/cremi/sampleA.n5", "r") as f:
raw = f["volumes/raw/s0"][:25]

embedding_path = "./embeddings/embeddings-tiled_3d.zarr"
annotator_3d(raw, embedding_path, tile_shape=(512, 512), halo=(64, 64))


def segment_tiled():
import micro_sam.util as util
from micro_sam.segment_instances import segment_instances_from_embeddings_3d

with z5py.File("/home/pape/Work/data/cremi/sampleA.n5", "r") as f:
raw = f["volumes/raw/s0"][:25]
embedding_path = "./embeddings/embeddings-tiled_3d.zarr"

predictor = util.get_sam_model()
image_embeddings = util.precompute_image_embeddings(
predictor, raw, embedding_path, tile_shape=(512, 512), halo=(64, 64)
)
segment_instances_from_embeddings_3d(predictor, image_embeddings)


def main():
annotator_with_tiling()
# annotator_with_tiling()
segment_tiled()


main()
development/instance_segmentation.py (99 changes: 65 additions & 34 deletions)
@@ -5,61 +5,87 @@
from micro_sam.segment_instances import (
segment_instances_from_embeddings,
segment_instances_sam,
# segment_instances_from_embeddings_with_tiling
segment_instances_from_embeddings_3d,
)
from micro_sam.visualization import compute_pca

INPUT_PATH = "../examples/data/Lucchi++/Test_In"
EMBEDDINGS_PATH = "../examples/embeddings/embeddings-mito2d.zarr"
INPUT_PATH = "/home/pape/Work/data/mouse-embryo/Nuclei/for_sam.h5"
EMBEDDINGS_PATH = "./embeddings/embedding-nuclei-3d.zarr"
TIMESERIES_PATH = "../examples/data/DIC-C2DH-HeLa/train/01"
EMBEDDINGS_TRACKING_PATH = "../examples/embeddings/embeddings-ctc.zarr"


def mito_segmentation() -> None:
"""Performs mito segmentation on the input image."""
def nucleus_segmentation(use_sam=False, use_mws=False) -> None:
"""Segment nuclei in 3d lightsheet data (one slice)."""
with open_file(INPUT_PATH) as f:
raw = f["*.png"][-1, :768, :768]
raw = f["raw"][:]

predictor, sam = util.get_sam_model(return_sam=True)
z = 32

print("Run SAM prediction ...")
seg_sam = segment_instances_sam(sam, raw)
predictor, sam = util.get_sam_model(return_sam=True, model_type="vit_b")
if use_sam:
print("Run SAM prediction ...")
seg_sam = segment_instances_sam(sam, raw[z])
else:
seg_sam = None

image_embeddings = util.precompute_image_embeddings(predictor, raw, EMBEDDINGS_PATH)
embedding_pca = compute_pca(image_embeddings["features"])
embedding_pca = compute_pca(image_embeddings["features"])[z]

print("Run prediction from embeddings ...")
seg, initial_seg = segment_instances_from_embeddings(
predictor, image_embeddings=image_embeddings, return_initial_seg=True
)
if use_mws:
print("Run prediction from embeddings ...")
seg, initial_seg = segment_instances_from_embeddings(
predictor, image_embeddings=image_embeddings, return_initial_segmentation=True,
pred_iou_thresh=0.8, verbose=1, stability_score_thresh=0.9,
i=z,
)
else:
seg, initial_seg = None, None

v = napari.Viewer()
v.add_image(raw[z])
v.add_image(embedding_pca, scale=(12, 12), visible=False)
if seg_sam is not None:
v.add_labels(seg_sam)
if seg is not None:
v.add_labels(seg)
if initial_seg is not None:
v.add_labels(initial_seg, visible=False)
napari.run()


def nucleus_segmentation_3d() -> None:
"""Segment nuclei in 3d lightsheet data (3d segmentation)."""
with open_file(INPUT_PATH) as f:
raw = f["raw"][:]

predictor = util.get_sam_model(model_type="vit_b")
image_embeddings = util.precompute_image_embeddings(predictor, raw, EMBEDDINGS_PATH)
seg = segment_instances_from_embeddings_3d(predictor, image_embeddings)

v = napari.Viewer()
v.add_image(raw)
v.add_image(embedding_pca, scale=(12, 12))
v.add_labels(seg_sam)
v.add_labels(seg)
v.add_labels(initial_seg)
napari.run()


def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None:
def cell_segmentation(use_sam=False, use_mws=False) -> None:
"""Performs cell segmentation on the input timeseries."""
with open_file(TIMESERIES_PATH, mode="r") as f:
timeseries = f["*.tif"][:50]

frame = 11
predictor, sam = util.get_sam_model(return_sam=True)

image_embeddings = util.precompute_image_embeddings(
predictor, timeseries, EMBEDDINGS_TRACKING_PATH)
image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH)

embedding_pca = compute_pca(image_embeddings["features"][frame])

if use_mws:
print("Run embedding segmentation ...")
seg_mws, initial_seg = segment_instances_from_embeddings(
predictor, image_embeddings=image_embeddings, i=frame, return_initial_seg=True,
bias=0.0, distance_type="l2",
predictor, image_embeddings=image_embeddings, i=frame, return_initial_segmentation=True,
bias=0.0, distance_type="l2", verbose=2,
)
else:
seg_mws = None
@@ -71,14 +71,6 @@ def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None:
else:
seg_sam = None

if use_tiling:
print("Run embedding segmentation with tiling ...")
seg_tiled = segment_instances_from_embeddings_with_tiling(
predictor, timeseries[frame], image_embeddings
)
else:
seg_tiled = None

v = napari.Viewer()
v.add_image(timeseries[frame])
v.add_image(embedding_pca, scale=(8, 8), visible=False)
@@ -92,19 +92,32 @@ def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None:
if seg_sam is not None:
v.add_labels(seg_sam)

if seg_tiled is not None:
v.add_labels(seg_tiled)
napari.run()


def cell_segmentation_3d() -> None:
with open_file(TIMESERIES_PATH, mode="r") as f:
timeseries = f["*.tif"][:50]

predictor = util.get_sam_model()
image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH)

seg = segment_instances_from_embeddings_3d(predictor, image_embeddings)

v = napari.Viewer()
v.add_image(timeseries)
v.add_labels(seg)
napari.run()


def main():
# automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py')
# mito_segmentation()
# nucleus_segmentation(use_mws=True)
nucleus_segmentation_3d()

# automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py')
# cell_segmentation(use_mws=True)
cell_segmentation(use_mws=True)
# cell_segmentation_3d()


if __name__ == "__main__":
micro_sam/segment_from_prompts.py (18 changes: 11 additions & 7 deletions)
@@ -33,7 +33,7 @@ def _compute_box_from_mask(mask, original_size=None, box_extension=0):

# sample points from a mask. SAM expects the following point inputs:
def _compute_points_from_mask(mask, original_size):
box = _compute_box_from_mask(mask, box_extension=5)
box = _compute_box_from_mask(mask, box_extension=7)

# get slice and offset in python coordinate convention
bb = (slice(box[1], box[3]), slice(box[0], box[2]))
@@ -45,16 +45,16 @@ def _compute_points_from_mask(mask, original_size):
outer_distances = gaussian(distance_transform_edt(cropped_mask == 0))

# sample positives and negatives from the distance maxima
inner_maxima = peak_local_max(inner_distances, exclude_border=False)
outer_maxima = peak_local_max(outer_distances, exclude_border=False)
inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)

# derive the positive (=inner maxima) and negative (=outer maxima) points
point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
point_coords += offset

if original_size is not None:
scale_factor = np.array([
float(mask.shape[0]) / original_size[0], float(mask.shape[1]) / original_size[1]
original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
])[None]
point_coords *= scale_factor

@@ -65,7 +65,7 @@ def _compute_points_from_mask(mask, original_size):
np.zeros(len(outer_maxima), dtype="uint8"),
]
)
return point_coords, point_labels
return point_coords[:, ::-1], point_labels


def _compute_logits_from_mask(mask, eps=1e-3):
@@ -278,11 +278,15 @@ def segment_from_mask(
predictor, image_embeddings, i, mask, _mask_to_tile
)

point_coords, point_labels = _compute_points_from_mask(mask, original_size=original_size)
if use_points:
point_coords, point_labels = _compute_points_from_mask(mask, original_size=original_size)
else:
point_coords, point_labels = None, None
box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension) if use_box else None
logits = _compute_logits_from_mask(mask) if use_mask else None

mask, scores, logits = predictor.predict(
point_coords=point_coords[:, ::-1], point_labels=point_labels,
point_coords=point_coords, point_labels=point_labels,
mask_input=logits, box=box,
multimask_output=multimask_output, return_logits=return_logits
)
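
The edits to _compute_points_from_mask and segment_from_mask tighten the prompt sampling (a min_distance for peak_local_max, a larger box_extension) and straighten out the coordinate handling: the sampled maxima are now rescaled from the mask resolution up to original_size (the inverse of the previous factor) and flipped to the (x, y) order that SAM's predictor expects before leaving the helper, so the predict call passes point_coords through unchanged. A small sketch of that conversion with illustrative values, assuming a mask at half the original resolution:

import numpy as np

mask_shape = (256, 256)      # resolution of the mask the maxima were sampled from
original_size = (512, 512)   # resolution the predictor expects prompts in

# peak_local_max returns maxima in numpy (row, col) order
maxima = np.array([[10.0, 40.0]])

# rescale from the mask resolution to the original image size
scale_factor = np.array([
    original_size[0] / float(mask_shape[0]), original_size[1] / float(mask_shape[1])
])[None]
point_coords = maxima * scale_factor   # [[20., 80.]], still (row, col)

# flip to the (x, y) order expected by predictor.predict(point_coords=...)
point_coords = point_coords[:, ::-1]   # [[80., 20.]]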