Skip to content

Commit

Permalink
Reducing memory usage (#17)
Browse files Browse the repository at this point in the history
* fixed deep learning modules utilities imports

* optimized hierarchy segmentation memory usage

* deleting viewer after usage

* updated README doc section

* fixing plantseg/microsam test skip

* updated examples documentation and dependencies

* updated examples
  • Loading branch information
JoOkuma authored Sep 29, 2023
1 parent aa47dae commit 883c7d3
Show file tree
Hide file tree
Showing 22 changed files with 1,680 additions and 506 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ Usage examples can be found [here](examples), including their environment files

## Documentation

The `ultrack` library relies on a configuration schema, its description is [here](ultrack/config/README.md).
The official documentation is available [here](https://royerlab.github.io/ultrack/).

The segmentation and tracking data are stored in an SQL database, described [here](ultrack/core/README.md).
These additional developer documentation are available:

Helper functions to export to the cell tracking challenge and napari formats are available [here](ultrack/core/export).
- Parameter [configuration schema](ultrack/config/README.md).
- Intermediate segmentation and tracking SQL database are [here](ultrack/core/README.md).

## Citing

Expand Down
3 changes: 2 additions & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ conda activate <your new env>

The existing examples are:

- [multi_color_ensemble](./multi_color_ensemble): Multi-colored cytoplasm cell tracking using Cellpose and Watershed segmentation ensemble. Data provided by [The Lammerding Lab ](https://lammerding.wicmb.cornell.edu/).
- [flow_field_3d](./flow_field_3d): Tracking demo on a cartographic projection of Tribolium Castaneum embryo from the [cell-tracking challenge](http://celltrackingchallenge.net/3d-datasets/), using a flow field estimation to assist tracking of motile cells.
- [stardist_2d](./stardist_2d): Tracking demo on HeLa GPF nuclei from the [cell-tracking challenge](http://celltrackingchallenge.net/2d-datasets/) using Stardist 2D fluorescence images pre-trained model.
- [zebrahub](./zebrahub/): Tracking demo on zebrafish tail data from the [zebrahub](https://zebrahub.ds.czbiohub.org/) acquired with [DaXi](https://www.nature.com/articles/s41592-022-01417-2) using Ultrack's image processing helper functions.
- [neuromast_plantseg](./neuromast_plantseg/): Tracking demo membrane-labeled zebrafish neuromast from the [TODO](TODO) using [PlantSeg's](https://github.com/hci-unihd/plant-seg) membrane detection model.
- [neuromast_plantseg](./neuromast_plantseg/): Tracking demo membrane-labeled zebrafish neuromast from [Jacobo Group of CZ Biohub](https://www.czbiohub.org/jacobo/) using [PlantSeg's](https://github.com/hci-unihd/plant-seg) membrane detection model.
- [micro_sam](./micro_sam/): Tracking demo with [MicroSAM](https://github.com/computational-cell-analytics/micro-sam) instance segmentation package.


Expand Down
765 changes: 435 additions & 330 deletions examples/flow_field_3d/tribolium_cartograph.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/multi_color_ensemble/environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dependencies:
- napari==0.4.18
- traccuracy
- napari-arboretum
- pyift
1 change: 1 addition & 0 deletions examples/multi_color_ensemble/environment_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dependencies:
- napari==0.4.18
- traccuracy
- napari-arboretum
- pyift

Large diffs are not rendered by default.

122 changes: 86 additions & 36 deletions examples/neuromast_plantseg/neuromast_plantseg.ipynb

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions examples/refresh_examples.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#! /bin/bash

# Terminates if any command fails
set -e

# subject to change depending on your conda setup
CONDA_SETUP_PATH=$HOME/miniconda3/etc/profile.d/conda.sh
UPDATE_JUPYTER="jupyter nbconvert --execute --to notebook --inplace"
Expand All @@ -12,10 +15,6 @@ function install () {
pip install -e ..
}

# stardist
install ultrack-stardist stardist_2d
$UPDATE_JUPYTER stardist_2d/2d_tracking.ipynb

# multi color
install ultrack-multi-color multi_color_ensemble
$UPDATE_JUPYTER multi_color_ensemble/multi_color_ensemble.ipynb
Expand All @@ -35,3 +34,7 @@ $UPDATE_JUPYTER neuromast_plantseg/neuromast_plantseg.ipynb
# micro-sam
install ultrack-micro-sam micro_sam
$UPDATE_JUPYTER micro_sam/micro_sam_tracking.ipynb

# stardist
install ultrack-stardist stardist_2d
$UPDATE_JUPYTER stardist_2d/2d_tracking.ipynb
1 change: 1 addition & 0 deletions examples/stardist_2d/environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- gurobi
- jupyter
- pip
- pytorch
- zarr
- pip:
- pyqt5
Expand Down
1 change: 1 addition & 0 deletions examples/stardist_2d/environment_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- coin-or-cbc
- gurobi
- jupyter
- pytorch
- pip
- zarr
- pip:
Expand Down
158 changes: 103 additions & 55 deletions examples/zebrahub/zebrahub.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions ultrack/cli/estimate_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def estimate_params_cli(
except (KeyError, ValueError):
labels = viewer.layers[layer_key].data

del viewer

df = estimate_parameters_from_labels(labels, is_timelapse=timelapse)

covariables = {"area", "distance"}
Expand Down
1 change: 1 addition & 0 deletions ultrack/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def zarr_napari_cli(
layer.data[0] if layer.multiscale else layer.data
for layer in viewer.open(image_path, plugin=reader_plugin)
]
del viewer

tracks_w_measures = tracks_properties(
segments=segments,
Expand Down
1 change: 1 addition & 0 deletions ultrack/cli/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def add_flow_cli(
layer.data[0] if layer.multiscale else layer.data
for layer in viewer.open(paths, channel_axis=channel_axis, plugin=reader_plugin)
]
del viewer

add_flow(
config,
Expand Down
5 changes: 4 additions & 1 deletion ultrack/cli/labels_to_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def labels_to_edges_cli(
viewer = ViewerModel()
viewer.open(path=paths, plugin=reader_plugin)

labels = [layer.data for layer in viewer.layers]
del viewer

labels_to_edges(
[layer.data for layer in viewer.layers],
labels,
sigma=sigma,
detection_store_or_path=detection_path,
edges_store_or_path=edges_path,
Expand Down
1 change: 1 addition & 0 deletions ultrack/cli/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def link_cli(
layer.data[0] if layer.multiscale else layer.data
for layer in viewer.open(paths, **kwargs, plugin=reader_plugin)
]
del viewer

link(
config,
Expand Down
2 changes: 2 additions & 0 deletions ultrack/cli/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def segmentation_cli(
{"scale": viewer.layers[edge_layer].scale.tolist()}
)

del viewer

segment(
detection,
edge,
Expand Down
23 changes: 11 additions & 12 deletions ultrack/core/segmentation/hierarchy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import List
from typing import Iterator

import numpy as np
import scipy.ndimage as ndi
from numpy.typing import ArrayLike
from skimage import measure, morphology
from skimage.measure._regionprops import RegionProperties
Expand Down Expand Up @@ -33,7 +34,7 @@ def create_hierarchies(
binary_detection: ArrayLike,
edge: ArrayLike,
**kwargs,
) -> List[Hierarchy]:
) -> Iterator[Hierarchy]:
"""Computes a collection of hierarchical watersheds inside `binary_detection` mask.
Parameters
Expand All @@ -46,26 +47,25 @@ def create_hierarchies(
Returns
-------
List[Hierarchy]
Iterator[Hierarchy]
List of hierarchical watersheds.
"""
binary_detection = np.asarray(binary_detection)
edge = np.asarray(edge)

assert (
issubclass(binary_detection.dtype.type, np.integer)
or binary_detection.dtype == bool
)

LOG.info("Labeling connected components.")
labels, num_labels = measure.label(
binary_detection, return_num=True, connectivity=1
)
labels = labels.astype(np.int32)
labels, num_labels = ndi.label(binary_detection, output=np.int32)
del binary_detection

if "min_area" in kwargs and num_labels > 1:
LOG.info("Filtering small connected components.")
labels = morphology.remove_small_objects(labels, min_size=kwargs["min_area"])
morphology.remove_small_objects(labels, min_size=kwargs["min_area"], out=labels)

edge = np.asarray(edge)

if "max_area" in kwargs:
LOG.info("Oversegmenting connected components.")
Expand All @@ -77,6 +77,5 @@ def create_hierarchies(
)

LOG.info("Creating hierarchies (lazy).")
return [
Hierarchy(c, **kwargs) for c in measure.regionprops(labels, edge, cache=True)
]
for c in measure.regionprops(labels, edge, cache=True):
yield Hierarchy(c, **kwargs)
16 changes: 10 additions & 6 deletions ultrack/imgproc/_test/test_plantseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

plantseg = pytest.importorskip("ultrack.imgproc.plantseg")
from ultrack.imgproc.plantseg import PlantSeg


@pytest.mark.parametrize(
Expand All @@ -19,9 +19,13 @@ def test_plantseg(

image = np.random.rand(100, 100, 100)

seg_model = plantseg.PlantSeg(
model_name="generic_light_sheet_3D_unet",
batch_size=1,
patch=image.shape,
)
try:
seg_model = PlantSeg(
model_name="generic_light_sheet_3D_unet",
batch_size=1,
patch=image.shape,
)
except ModuleNotFoundError:
pytest.skip("PlantSeg not installed")

seg_model(image, transpose=transpose)
9 changes: 7 additions & 2 deletions ultrack/imgproc/_test/test_sam.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import numpy as np
import pytest

sam = pytest.importorskip("ultrack.imgproc.sam")
from ultrack.imgproc import sam


def test_sam() -> None:
image = np.random.rand(100, 100)
seg_model = sam.MicroSAM()

try:
seg_model = sam.MicroSAM()
except ModuleNotFoundError:
pytest.skip("MicroSAM not installed")

seg_model(image)

# with maxima prompt
Expand Down
15 changes: 8 additions & 7 deletions ultrack/imgproc/plantseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@
import numpy as np
import torch as th
from numpy.typing import ArrayLike
from plantseg.dataprocessing import fix_input_shape
from plantseg.predictions.functional.array_predictor import ArrayPredictor
from plantseg.predictions.functional.utils import (
get_array_dataset,
get_model_config,
get_patch_halo,
)

from ultrack.utils.cuda import import_module, to_cpu, torch_default_device

Expand Down Expand Up @@ -67,6 +60,12 @@ def __init__(
"""
Initialized Plant-Seg model.
"""
from plantseg.predictions.functional.array_predictor import ArrayPredictor
from plantseg.predictions.functional.utils import (
get_model_config,
get_patch_halo,
)

if device is None:
device = torch_default_device()

Expand Down Expand Up @@ -121,6 +120,8 @@ def __call__(
np.ndarray
Segmentation boundary probability map as a numpy array.
"""
from plantseg.dataprocessing import fix_input_shape
from plantseg.predictions.functional.utils import get_array_dataset

if isinstance(image, da.Array):
# avoiding building a large compute graph
Expand Down
7 changes: 4 additions & 3 deletions ultrack/imgproc/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import numpy as np
import scipy.ndimage as ndi
import torch as th
from micro_sam import instance_segmentation, util
from numpy.typing import ArrayLike
from segment_anything.utils.amg import area_from_rle, rle_to_mask
from skimage.feature import peak_local_max
from skimage.segmentation import find_boundaries

Expand Down Expand Up @@ -92,6 +90,7 @@ def __init__(
tile_shape: Tuple[int, int] = (512, 512),
halo_shape: Tuple[int, int] = (128, 128),
) -> None:
from micro_sam import instance_segmentation, util

if device is None:
device = torch_default_device()
Expand Down Expand Up @@ -148,6 +147,8 @@ def __call__(self, image: ArrayLike) -> np.ndarray:
np.ndarray
The processed image with contours derived from the identified masks.
"""
from micro_sam.util import precompute_image_embeddings
from segment_anything.utils.amg import area_from_rle, rle_to_mask

image = np.asarray(image)

Expand All @@ -159,7 +160,7 @@ def __call__(self, image: ArrayLike) -> np.ndarray:
if embedding_path.exists():
shutil.rmtree(embedding_path)

embeddings = util.precompute_image_embeddings(
embeddings = precompute_image_embeddings(
self._predictor,
image,
save_path=str(embedding_path),
Expand Down

0 comments on commit 883c7d3

Please sign in to comment.