From 8ab4891f01507721fd6c3f05b6dc9bfcfb1800a1 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 20 Jun 2023 12:15:20 +0200 Subject: [PATCH 1/4] Fix issue in segment_from_prompts that would always use point prompts --- development/annotator_2d_tiled.py | 2 +- micro_sam/segment_from_prompts.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/development/annotator_2d_tiled.py b/development/annotator_2d_tiled.py index c27dc846..9ddfa060 100644 --- a/development/annotator_2d_tiled.py +++ b/development/annotator_2d_tiled.py @@ -16,7 +16,7 @@ def annotator_with_tiling(): # napari.run() embedding_path = "./embeddings/embeddings-tiled.zarr" - annotator_2d(im, embedding_path, tile_shape=(512, 512), halo=(64, 64)) + annotator_2d(im, embedding_path, tile_shape=(1024, 1024), halo=(256, 256)) def debug(): diff --git a/micro_sam/segment_from_prompts.py b/micro_sam/segment_from_prompts.py index 44e1d79a..388d7f06 100644 --- a/micro_sam/segment_from_prompts.py +++ b/micro_sam/segment_from_prompts.py @@ -65,7 +65,7 @@ def _compute_points_from_mask(mask, original_size): np.zeros(len(outer_maxima), dtype="uint8"), ] ) - return point_coords, point_labels + return point_coords[:, ::-1], point_labels def _compute_logits_from_mask(mask, eps=1e-3): @@ -278,11 +278,14 @@ def segment_from_mask( predictor, image_embeddings, i, mask, _mask_to_tile ) - point_coords, point_labels = _compute_points_from_mask(mask, original_size=original_size) + if use_points: + point_coords, point_labels = _compute_points_from_mask(mask, original_size=original_size) + else: + point_coords, point_labels = None, None box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension) if use_box else None logits = _compute_logits_from_mask(mask) if use_mask else None mask, scores, logits = predictor.predict( - point_coords=point_coords[:, ::-1], point_labels=point_labels, + point_coords=point_coords, point_labels=point_labels, mask_input=logits, box=box, multimask_output=multimask_output, return_logits=return_logits ) From 307d39521b13ea04396ec36cd54851c29423bd5c Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 20 Jun 2023 12:36:49 +0200 Subject: [PATCH 2/4] Fix issues in _compute_points_from_mask --- micro_sam/segment_from_prompts.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/micro_sam/segment_from_prompts.py b/micro_sam/segment_from_prompts.py index 388d7f06..cd643e95 100644 --- a/micro_sam/segment_from_prompts.py +++ b/micro_sam/segment_from_prompts.py @@ -33,7 +33,7 @@ def _compute_box_from_mask(mask, original_size=None, box_extension=0): # sample points from a mask. SAM expects the following point inputs: def _compute_points_from_mask(mask, original_size): - box = _compute_box_from_mask(mask, box_extension=5) + box = _compute_box_from_mask(mask, box_extension=7) # get slice and offset in python coordinate convention bb = (slice(box[1], box[3]), slice(box[0], box[2])) @@ -45,8 +45,8 @@ def _compute_points_from_mask(mask, original_size): outer_distances = gaussian(distance_transform_edt(cropped_mask == 0)) # sample positives and negatives from the distance maxima - inner_maxima = peak_local_max(inner_distances, exclude_border=False) - outer_maxima = peak_local_max(outer_distances, exclude_border=False) + inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3) + outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5) # derive the positive (=inner maxima) and negative (=outer maxima) points point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64") @@ -54,7 +54,7 @@ def _compute_points_from_mask(mask, original_size): if original_size is not None: scale_factor = np.array([ - float(mask.shape[0]) / original_size[0], float(mask.shape[1]) / original_size[1] + original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1]) ])[None] point_coords *= scale_factor @@ -284,6 +284,7 @@ def segment_from_mask( point_coords, point_labels = None, None box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension) if use_box else None logits = _compute_logits_from_mask(mask) if use_mask else None + mask, scores, logits = predictor.predict( point_coords=point_coords, point_labels=point_labels, mask_input=logits, box=box, From f6ae43cbc24a5bd6092495f8317b135a618f4d6f Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 20 Jun 2023 19:17:49 +0200 Subject: [PATCH 3/4] Implement 3d instance segmentation --- micro_sam/segment_instances.py | 180 +++++++++++++++++++++++++++------ 1 file changed, 148 insertions(+), 32 deletions(-) diff --git a/micro_sam/segment_instances.py b/micro_sam/segment_instances.py index aae7726a..bb478984 100644 --- a/micro_sam/segment_instances.py +++ b/micro_sam/segment_instances.py @@ -2,8 +2,11 @@ import torch import vigra +from elf.evaluation.matching import label_overlap, intersection_over_union from elf.segmentation import embeddings as embed from elf.segmentation.stitching import stitch_segmentation +from nifty.tools import takeDict +from scipy.optimize import linear_sum_assignment from segment_anything import SamAutomaticMaskGenerator from segment_anything.utils.amg import ( @@ -36,8 +39,8 @@ def _amg_to_seg(masks, shape, with_background): """Convert the output of the automatic mask generation to an instance segmentation.""" masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) - segmentation = np.zeros(shape[:2], dtype="uint32") + for seg_id, mask in enumerate(masks, 1): segmentation[mask["segmentation"]] = seg_id @@ -46,7 +49,7 @@ def _amg_to_seg(masks, shape, with_background): bg_id = seg_ids[np.argmax(sizes)] if bg_id != 0: segmentation[segmentation == bg_id] = 0 - vigra.analysis.relabelConsecutive(segmentation, out=segmentation) + vigra.analysis.relabelConsecutive(segmentation, out=segmentation) return segmentation @@ -68,25 +71,34 @@ def segment_instances_sam(sam, image, with_background=False, **kwargs): # https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py#L266 def _refine_mask( predictor, mask, original_size, - pred_iou_thresh, stability_score_offset, stability_score_thresh, + pred_iou_thresh, stability_score_offset, + stability_score_thresh, seg_id, verbose, ): # Predict masks and store them as mask data masks, iou_preds, _ = segment_from_mask( predictor, mask, original_size=original_size, multimask_output=True, return_logits=True, return_all=True, + use_box=True, use_mask=True, use_points=False, box_extension=4 ) data = MaskData( masks=torch.from_numpy(masks), iou_preds=torch.from_numpy(iou_preds), + seg_id=torch.from_numpy(np.full(len(masks), seg_id, dtype="int64")), ) del masks + n_masks = len(data["masks"]) # Filter by predicted IoU if pred_iou_thresh > 0.0: keep_mask = data["iou_preds"] > pred_iou_thresh data.filter(keep_mask) + n_masks_filtered = len(data["masks"]) + if verbose > 2: + print("Masks after IoU filter:", n_masks_filtered, "/", n_masks) + print("IoU Threshold is:", pred_iou_thresh) + # Calculate stability score data["stability_score"] = calculate_stability_score( data["masks"], predictor.model.mask_threshold, stability_score_offset @@ -95,6 +107,11 @@ def _refine_mask( keep_mask = data["stability_score"] >= stability_score_thresh data.filter(keep_mask) + n_masks_filtered_stability = len(data["masks"]) + if verbose > 2: + print("Masks after stability filter:", n_masks_filtered_stability, "/", n_masks_filtered) + print("Stability Threshold is:", stability_score_thresh) + # Threshold masks and calculate boxes data["masks"] = data["masks"] > predictor.model.mask_threshold data["boxes"] = batched_mask_to_box(data["masks"]) @@ -103,31 +120,37 @@ def _refine_mask( data["rles"] = mask_to_rle_pytorch(data["masks"]) del data["masks"] - return data + return data, n_masks - n_masks_filtered, n_masks_filtered - n_masks_filtered_stability def _refine_initial_segmentation( predictor, initial_seg, original_size, box_nms_thresh, with_background, - pred_iou_thresh, stability_score_offset, stability_score_thresh, verbose + pred_iou_thresh, stability_score_offset, + stability_score_thresh, verbose ): masks = MaskData() seg_ids = np.unique(initial_seg) - for seg_id in tqdm(seg_ids, disable=not verbose, desc="Refine masks for automatic instance segmentation"): + n_filtered_total, n_filtered_stability_total = 0, 0 + for seg_id in tqdm(seg_ids, disable=not bool(verbose), desc="Refine masks for automatic instance segmentation"): # refine the segmentations via sam for this prediction mask = (initial_seg == seg_id) assert mask.shape == (256, 256) - mask_data = _refine_mask( + mask_data, n_filtered, n_filtered_stability = _refine_mask( predictor, mask, original_size, - pred_iou_thresh, stability_score_offset, stability_score_thresh + pred_iou_thresh, stability_score_offset, + stability_score_thresh, seg_id, verbose, ) + n_filtered_total += n_filtered + n_filtered_stability_total += n_filtered_stability # append to the mask data masks.cat(mask_data) # apply non-max-suppression to only keep the likely objects + n_masks = len(masks["boxes"]) keep_by_nms = batched_nms( masks["boxes"].float(), masks["iou_preds"], @@ -135,9 +158,18 @@ def _refine_initial_segmentation( iou_threshold=box_nms_thresh, ) masks.filter(keep_by_nms) + n_masks_filtered = len(masks["boxes"]) + + if verbose > 1: + print(n_filtered_total, "masks were filtered out due to the IOU threshold", pred_iou_thresh) + print( + n_filtered_stability_total, "masks were filtered out due to the stability threshold", stability_score_thresh + ) + print(n_masks - n_masks_filtered, "masks were filtered out by nms with threshold", box_nms_thresh) # get the mask output (binary masks and area - masks = [{"segmentation": rle_to_mask(rle), "area": len(rle)} for rle in masks["rles"]] + masks = [{"segmentation": rle_to_mask(rle), "area": len(rle), "seg_id": seg_id} + for rle, seg_id in zip(masks["rles"], masks["seg_id"])] # convert to instance segmentation segmentation = _amg_to_seg(masks, original_size, with_background) @@ -153,7 +185,7 @@ def segment_instances_from_embeddings( stability_score_thresh=0.95, stability_score_offset=1.0, # general settings min_initial_size=10, min_size=0, with_background=False, - verbose=True, return_initial_seg=False, + verbose=1, return_initial_segmentation=False, ): """ """ @@ -161,43 +193,43 @@ def segment_instances_from_embeddings( embeddings = predictor.get_image_embedding().squeeze().cpu().numpy() assert embeddings.shape == (256, 64, 64), f"{embeddings.shape}" - initial_seg = embed.segment_embeddings_mws( - embeddings, distance_type=distance_type, offsets=offsets, bias=bias + + initial_segmentation = embed.segment_embeddings_mws( + embeddings, distance_type=distance_type, offsets=offsets, bias=bias, ).astype("uint32") - assert initial_seg.shape == (64, 64), f"{initial_seg.shape}" + assert initial_segmentation.shape == (64, 64), f"{initial_segmentation.shape}" # filter out small initial objects if min_initial_size > 0: - seg_ids, sizes = np.unique(initial_seg, return_counts=True) - initial_seg[np.isin(initial_seg, seg_ids[sizes < min_initial_size])] = 0 - vigra.analysis.relabelConsecutive(initial_seg, out=initial_seg) + seg_ids, sizes = np.unique(initial_segmentation, return_counts=True) + initial_segmentation[np.isin(initial_segmentation, seg_ids[sizes < min_initial_size])] = 0 # resize to 256 x 256, which is the mask input expected by SAM - initial_seg = resize( - initial_seg, (256, 256), order=0, preserve_range=True, anti_aliasing=False - ).astype(initial_seg.dtype) + initial_segmentation = resize( + initial_segmentation, (256, 256), order=0, preserve_range=True, anti_aliasing=False + ).astype(initial_segmentation.dtype) original_size = image_embeddings["original_size"] - seg = _refine_initial_segmentation( - predictor, initial_seg, original_size, + segmentation = _refine_initial_segmentation( + predictor, initial_segmentation, original_size, box_nms_thresh=box_nms_thresh, with_background=with_background, pred_iou_thresh=pred_iou_thresh, stability_score_offset=stability_score_offset, stability_score_thresh=stability_score_thresh, verbose=verbose, ) if min_size > 0: - seg_ids, counts = np.unique(seg, return_counts=True) - filter_ids = seg_ids[counts < min_size] - seg[np.isin(seg, filter_ids)] = 0 - vigra.analysis.relabelConsecutive(seg, out=seg) - - if return_initial_seg: - initial_seg = resize( - initial_seg, seg.shape, order=0, preserve_range=True, anti_aliasing=False - ).astype(seg.dtype) - return seg, initial_seg + segmentation_ids, counts = np.unique(segmentation, return_counts=True) + filter_ids = segmentation_ids[counts < min_size] + segmentation[np.isin(segmentation, filter_ids)] = 0 + vigra.analysis.relabelConsecutive(segmentation, out=segmentation) + + if return_initial_segmentation: + initial_segmentation = resize( + initial_segmentation, segmentation.shape, order=0, preserve_range=True, anti_aliasing=False + ).astype(segmentation.dtype) + return segmentation, initial_segmentation else: - return seg + return segmentation class FakeInput: @@ -238,3 +270,87 @@ def segment_tile(_, tile_id): input_, segment_tile, tile_shape, halo, with_background=with_background, verbose=verbose ) return segmentation + + +# this is still experimental and not yet ready to be integrated within the annotator_3d +# (will need to see how well it works with retrained models) +def segment_instances_from_embeddings_3d(predictor, image_embeddings, verbose=1, iou_threshold=0.50, **kwargs): + """ + """ + if image_embeddings["original_size"] is None: # tiled embeddings + is_tiled = True + image_shape = tuple(image_embeddings["features"].attrs["shape"]) + n_slices = len(image_embeddings["features"][0]) + + else: # normal embeddings (not tiled) + is_tiled = False + image_shape = tuple(image_embeddings["original_size"]) + n_slices = image_embeddings["features"].shape[0] + + shape = (n_slices,) + image_shape + segmentation_function = segment_instances_from_embeddings_with_tiling if is_tiled else\ + segment_instances_from_embeddings + + segmentation = np.zeros(shape, dtype="uint32") + + def match_segments(seg, prev_seg): + overlap, ignore_idx = label_overlap(seg, prev_seg, ignore_label=0) + scores = intersection_over_union(overlap) + # remove ignore_label (remapped to continuous object_ids) + if ignore_idx[0] is not None: + scores = np.delete(scores, ignore_idx[0], axis=0) + if ignore_idx[1] is not None: + scores = np.delete(scores, ignore_idx[1], axis=1) + + n_matched = min(scores.shape) + no_match = n_matched == 0 or (not np.any(scores >= iou_threshold)) + + max_id = segmentation.max() + if no_match: + seg[seg != 0] += max_id + + else: + # compute optimal matching with scores as tie-breaker + costs = -(scores >= iou_threshold).astype(float) - scores / (2*n_matched) + seg_ind, prev_ind = linear_sum_assignment(costs) + + seg_ids, prev_ids = np.unique(seg)[1:], np.unique(prev_seg)[1:] + match_ok = scores[seg_ind, prev_ind] >= iou_threshold + + id_updates = {0: 0} + matched_ids, matched_prev = seg_ids[seg_ind[match_ok]], prev_ids[prev_ind[match_ok]] + id_updates.update( + {seg_id: prev_id for seg_id, prev_id in zip(matched_ids, matched_prev) if seg_id != 0} + ) + + unmatched_ids = np.setdiff1d(seg_ids, np.concatenate([np.zeros(1, dtype=matched_ids.dtype), matched_ids])) + id_updates.update({seg_id: max_id + i for i, seg_id in enumerate(unmatched_ids, 1)}) + + seg = takeDict(id_updates, seg) + + return seg + + ids_to_slices = {} + # segment the objects starting from slice 0 + for z in tqdm( + range(0, n_slices), total=n_slices, desc="Run instance segmentation in 3d", disable=not bool(verbose) + ): + # TODO set to non verbose once the fix is in new napari version + seg = segmentation_function(predictor, image_embeddings, i=z, verbose=True, **kwargs) + if z > 0: + prev_seg = segmentation[z - 1] + seg = match_segments(seg, prev_seg) + + # keep track of the slices per object id to get rid of unconnected objects in the post-processing + this_ids = np.unique(seg)[1:] + for id_ in this_ids: + ids_to_slices[id_] = ids_to_slices.get(id_, []) + [z] + + segmentation[z] = seg + + # get rid of objects that are just in a single slice + filter_objects = [seg_id for seg_id, slice_list in ids_to_slices.items() if len(slice_list) == 1] + segmentation[np.isin(segmentation, filter_objects)] = 0 + vigra.analysis.relabelConsecutive(segmentation, out=segmentation) + + return segmentation From d9f69f946708428a38135777186287b4db93713e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 20 Jun 2023 19:19:21 +0200 Subject: [PATCH 4/4] Add test scripts for 3d instance segmentation --- development/annotator_3d_tiled.py | 19 +++++- development/instance_segmentation.py | 99 ++++++++++++++++++---------- 2 files changed, 82 insertions(+), 36 deletions(-) diff --git a/development/annotator_3d_tiled.py b/development/annotator_3d_tiled.py index 9430fc8d..12f06859 100644 --- a/development/annotator_3d_tiled.py +++ b/development/annotator_3d_tiled.py @@ -5,13 +5,28 @@ def annotator_with_tiling(): with z5py.File("/home/pape/Work/data/cremi/sampleA.n5", "r") as f: raw = f["volumes/raw/s0"][:25] - embedding_path = "./embeddings/embeddings-tiled_3d.zarr" annotator_3d(raw, embedding_path, tile_shape=(512, 512), halo=(64, 64)) +def segment_tiled(): + import micro_sam.util as util + from micro_sam.segment_instances import segment_instances_from_embeddings_3d + + with z5py.File("/home/pape/Work/data/cremi/sampleA.n5", "r") as f: + raw = f["volumes/raw/s0"][:25] + embedding_path = "./embeddings/embeddings-tiled_3d.zarr" + + predictor = util.get_sam_model() + image_embeddings = util.precompute_image_embeddings( + predictor, raw, embedding_path, tile_shape=(512, 512), halo=(64, 64) + ) + segment_instances_from_embeddings_3d(predictor, image_embeddings) + + def main(): - annotator_with_tiling() + # annotator_with_tiling() + segment_tiled() main() diff --git a/development/instance_segmentation.py b/development/instance_segmentation.py index 978e732e..42f072ee 100644 --- a/development/instance_segmentation.py +++ b/development/instance_segmentation.py @@ -5,44 +5,71 @@ from micro_sam.segment_instances import ( segment_instances_from_embeddings, segment_instances_sam, - # segment_instances_from_embeddings_with_tiling + segment_instances_from_embeddings_3d, ) from micro_sam.visualization import compute_pca -INPUT_PATH = "../examples/data/Lucchi++/Test_In" -EMBEDDINGS_PATH = "../examples/embeddings/embeddings-mito2d.zarr" +INPUT_PATH = "/home/pape/Work/data/mouse-embryo/Nuclei/for_sam.h5" +EMBEDDINGS_PATH = "./embeddings/embedding-nuclei-3d.zarr" TIMESERIES_PATH = "../examples/data/DIC-C2DH-HeLa/train/01" EMBEDDINGS_TRACKING_PATH = "../examples/embeddings/embeddings-ctc.zarr" -def mito_segmentation() -> None: - """Performs mito segmentation on the input image.""" +def nucleus_segmentation(use_sam=False, use_mws=False) -> None: + """Segment nuclei in 3d lightsheet data (one slice).""" with open_file(INPUT_PATH) as f: - raw = f["*.png"][-1, :768, :768] + raw = f["raw"][:] - predictor, sam = util.get_sam_model(return_sam=True) + z = 32 - print("Run SAM prediction ...") - seg_sam = segment_instances_sam(sam, raw) + predictor, sam = util.get_sam_model(return_sam=True, model_type="vit_b") + if use_sam: + print("Run SAM prediction ...") + seg_sam = segment_instances_sam(sam, raw[z]) + else: + seg_sam = None image_embeddings = util.precompute_image_embeddings(predictor, raw, EMBEDDINGS_PATH) - embedding_pca = compute_pca(image_embeddings["features"]) + embedding_pca = compute_pca(image_embeddings["features"])[z] - print("Run prediction from embeddings ...") - seg, initial_seg = segment_instances_from_embeddings( - predictor, image_embeddings=image_embeddings, return_initial_seg=True - ) + if use_mws: + print("Run prediction from embeddings ...") + seg, initial_seg = segment_instances_from_embeddings( + predictor, image_embeddings=image_embeddings, return_initial_segmentation=True, + pred_iou_thresh=0.8, verbose=1, stability_score_thresh=0.9, + i=z, + ) + else: + seg, initial_seg = None, None + + v = napari.Viewer() + v.add_image(raw[z]) + v.add_image(embedding_pca, scale=(12, 12), visible=False) + if seg_sam is not None: + v.add_labels(seg_sam) + if seg is not None: + v.add_labels(seg) + if initial_seg is not None: + v.add_labels(initial_seg, visible=False) + napari.run() + + +def nucleus_segmentation_3d() -> None: + """Segment nuclei in 3d lightsheet data (3d segmentation).""" + with open_file(INPUT_PATH) as f: + raw = f["raw"][:] + + predictor = util.get_sam_model(model_type="vit_b") + image_embeddings = util.precompute_image_embeddings(predictor, raw, EMBEDDINGS_PATH) + seg = segment_instances_from_embeddings_3d(predictor, image_embeddings) v = napari.Viewer() v.add_image(raw) - v.add_image(embedding_pca, scale=(12, 12)) - v.add_labels(seg_sam) v.add_labels(seg) - v.add_labels(initial_seg) napari.run() -def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None: +def cell_segmentation(use_sam=False, use_mws=False) -> None: """Performs cell segmentation on the input timeseries.""" with open_file(TIMESERIES_PATH, mode="r") as f: timeseries = f["*.tif"][:50] @@ -50,16 +77,15 @@ def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None: frame = 11 predictor, sam = util.get_sam_model(return_sam=True) - image_embeddings = util.precompute_image_embeddings( - predictor, timeseries, EMBEDDINGS_TRACKING_PATH) + image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH) embedding_pca = compute_pca(image_embeddings["features"][frame]) if use_mws: print("Run embedding segmentation ...") seg_mws, initial_seg = segment_instances_from_embeddings( - predictor, image_embeddings=image_embeddings, i=frame, return_initial_seg=True, - bias=0.0, distance_type="l2", + predictor, image_embeddings=image_embeddings, i=frame, return_initial_segmentation=True, + bias=0.0, distance_type="l2", verbose=2, ) else: seg_mws = None @@ -71,14 +97,6 @@ def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None: else: seg_sam = None - if use_tiling: - print("Run embedding segmentation with tiling ...") - seg_tiled = segment_instances_from_embeddings_with_tiling( - predictor, timeseries[frame], image_embeddings - ) - else: - seg_tiled = None - v = napari.Viewer() v.add_image(timeseries[frame]) v.add_image(embedding_pca, scale=(8, 8), visible=False) @@ -92,19 +110,32 @@ def cell_segmentation(use_sam=False, use_mws=False, use_tiling=False) -> None: if seg_sam is not None: v.add_labels(seg_sam) - if seg_tiled is not None: - v.add_labels(seg_tiled) + napari.run() + +def cell_segmentation_3d() -> None: + with open_file(TIMESERIES_PATH, mode="r") as f: + timeseries = f["*.tif"][:50] + + predictor = util.get_sam_model() + image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH) + + seg = segment_instances_from_embeddings_3d(predictor, image_embeddings) + + v = napari.Viewer() + v.add_image(timeseries) + v.add_labels(seg) napari.run() def main(): # automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py') - # mito_segmentation() + # nucleus_segmentation(use_mws=True) + nucleus_segmentation_3d() # automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py') # cell_segmentation(use_mws=True) - cell_segmentation(use_mws=True) + # cell_segmentation_3d() if __name__ == "__main__":