Added functionality to get node features from segmentation hypotheses (#150)

* added functionality to get node features from segmentation hypotheses

* adding example

* add missing import
JoOkuma authored Oct 21, 2024
1 parent 2b73f41 commit 0ecaf1b
Showing 9 changed files with 353 additions and 30 deletions.
101 changes: 101 additions & 0 deletions examples/node_features.py
@@ -0,0 +1,101 @@
"""
This example how request for each segmentation hypotheses features and use them
to compute a custom edge weight between nodes, in this case, the cosine distance.
For this we consider a 3D image as a 2D video, it's more a didactic example than a real use case.
"""
import napari
import numpy as np
from scipy.spatial.distance import cdist
from skimage import morphology as morph
from skimage.data import cells3d

from ultrack import MainConfig, Tracker


def main() -> None:

    config = MainConfig()
    config.data_config.working_dir = "/tmp/ultrack/."

    # removing small segments
    config.segmentation_config.min_area = 1_000
    # disable division
    config.tracking_config.division_weight = -1_000_000

    tracker = Tracker(config)

    # mocking a 3D image as a 2D video
    image = cells3d()[:, 1]  # nuclei

    # simple foreground extraction
    foreground = image > image.mean()
    foreground = morph.opening(foreground, morph.disk(3)[None, :])
    foreground = morph.closing(foreground, morph.disk(3)[None, :])

    # contour as the inverse of the image
    contour = 1 - image / image.max()

    tracker.segment(
        foreground=foreground,
        contours=contour,
        image=image,
        properties=["equivalent_diameter_area", "intensity_mean", "inertia_tensor"],
        overwrite=True,
    )

    df = tracker.get_nodes_features()

    # multi-component properties are expanded with suffixes,
    # e.g. "inertia_tensor" into -0-0, -0-1, -1-0, -1-1
    cols = [
        "y",
        "x",
        "area",
        "intensity_mean",
        "inertia_tensor-0-0",
        "inertia_tensor-0-1",
        "inertia_tensor-1-0",
        "inertia_tensor-1-1",
        "equivalent_diameter_area",
    ]

    # normalizing features
    df[cols] -= df[cols].mean()
    df[cols] /= df[cols].std()

    df_by_t = df.groupby("t")
    t_max = df["t"].max()

    # iterating over time and querying pairs of adjacent frames
    for t in range(t_max + 1):
        try:
            # some frames might not contain any nodes
            source_df = df_by_t.get_group(t)
            target_df = df_by_t.get_group(t + 1)
        except KeyError:
            continue

        # the higher the weight, the more likely the link
        weights = 1 - cdist(source_df[cols], target_df[cols], "cosine").ravel()

        source_ids = np.repeat(source_df.index.to_numpy(), len(target_df))
        target_ids = np.tile(target_df.index.to_numpy(), len(source_df))

        # for very dense graphs this is not recommended because the ILP problem becomes huge
        tracker.add_links(sources=source_ids, targets=target_ids, weights=weights)

    tracker.solve()

    tracks, graph = tracker.to_tracks_layer()
    segments = tracker.to_zarr()

    viewer = napari.Viewer()
    viewer.add_image(image, name="cells")
    viewer.add_tracks(tracks[["track_id", "t", "y", "x"]], name="tracks", graph=graph)
    viewer.add_labels(segments, name="segments")

    napari.run()


if __name__ == "__main__":
    main()
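Note: the all-pairs link construction in the example relies on cdist(...).ravel() being row-major, so that np.repeat over the source ids and np.tile over the target ids line up with the flattened weight matrix. A tiny self-contained check of that pairing convention (the ids 10-11 and 20-22 are made up for illustration, not taken from the example):

import numpy as np
from scipy.spatial.distance import cdist

src = np.array([[1.0, 0.0], [0.0, 1.0]])               # 2 source feature vectors
tgt = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # 3 target feature vectors

weights = 1 - cdist(src, tgt, "cosine").ravel()          # shape (2 * 3,), row-major order
source_ids = np.repeat(np.array([10, 11]), len(tgt))     # [10 10 10 11 11 11]
target_ids = np.tile(np.array([20, 21, 22]), len(src))   # [20 21 22 20 21 22]
print(list(zip(source_ids, target_ids, weights.round(3))))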
2 changes: 1 addition & 1 deletion ultrack/__init__.py
@@ -10,7 +10,7 @@
# ignoring small float32/64 zero flush warning
warnings.filterwarnings("ignore", message="The value of the smallest subnormal for")

__version__ = "0.6.0"
__version__ = "0.6.1"

from ultrack.config.config import MainConfig, load_config
from ultrack.core.export.ctc import to_ctc
2 changes: 2 additions & 0 deletions ultrack/core/database.py
@@ -90,8 +90,10 @@ class NodeDB(Base):
y_shift = Column(Float, default=0.0)
x_shift = Column(Float, default=0.0)
area = Column(Integer)
frontier = Column(Float, default=-1.0)
selected = Column(Boolean, default=False)
pickle = Column(MaybePickleType)
features = Column(MaybePickleType, default=None)
node_prob = Column(Float, default=-1.0)
segm_annot = Column(Enum(NodeSegmAnnotation), default=NodeSegmAnnotation.UNKNOWN)
node_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN)
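The new features column stores the per-node feature vector computed during segmentation. A minimal sketch of reading it back for a single node, mirroring the pattern used in the updated test below; the config values and node_id here are hypothetical placeholders for an existing, already-populated database:

from ultrack import MainConfig
from ultrack.core.database import NodeDB, get_node_values

config = MainConfig()                   # hypothetical: assumes its database was already populated
config.data_config.working_dir = "/tmp/ultrack/."

node_id = 1                             # hypothetical: any id present in the node table
feats = get_node_values(config.data_config, node_id, NodeDB.features)  # np.ndarray of feature values
names = config.data_config.metadata["properties"]                      # matching feature names
assert len(feats) == len(names)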
1 change: 1 addition & 0 deletions ultrack/core/segmentation/__init__.py
@@ -0,0 +1 @@
from ultrack.core.segmentation.processing import get_nodes_features
19 changes: 18 additions & 1 deletion ultrack/core/segmentation/_test/test_segment_processing.py
@@ -8,7 +8,8 @@

from ultrack import segment
from ultrack.config.config import MainConfig
from ultrack.core.database import NodeDB, OverlapDB
from ultrack.core.database import NodeDB, OverlapDB, get_node_values
from ultrack.core.segmentation import get_nodes_features


@pytest.mark.parametrize(
@@ -43,6 +44,7 @@ def test_multiprocessing_segment(
foreground,
contours,
config_instance,
properties=["centroid"],
)

assert config_instance.data_config.metadata["shape"] == list(contours.shape)
@@ -83,3 +85,18 @@ def test_multiprocessing_segment(
continue
node_j = nodes[j]
assert node_i.IoU(node_j) == 0.0

feats = get_node_values(config_instance.data_config, i, NodeDB.features)
feats_name = config_instance.data_config.metadata["properties"]

assert len(feats) == len(feats_name)
assert isinstance(feats, np.ndarray)

df = get_nodes_features(config_instance)

centroids_cols = [f"centroid-{i}" for i in range(contours.ndim - 1)]

assert df.shape[0] == len(nodes)
np.testing.assert_array_equal(
df.columns.to_numpy(dtype=str), ["t", "z", "y", "x", "area"] + centroids_cols
)
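As the test shows, the functional segment entry point accepts the same properties argument, and multi-component regionprops are expanded into suffixed columns in the resulting table. A minimal sketch, not a definitive usage, built on a tiny synthetic 2D video and default configuration values:

import numpy as np

from ultrack import MainConfig, segment
from ultrack.core.segmentation import get_nodes_features

# tiny synthetic 2D video: two frames with a single square object each
foreground = np.zeros((2, 64, 64), dtype=bool)
foreground[:, 20:40, 20:40] = True
contours = (~foreground).astype(np.float32)   # high values on background / boundaries

config = MainConfig()
config.segmentation_config.min_area = 10

segment(foreground, contours, config, properties=["centroid", "inertia_tensor"])

df = get_nodes_features(config)
# multi-component properties are expanded into suffixed columns,
# e.g. "centroid-0", "centroid-1", "inertia_tensor-0-0", "inertia_tensor-0-1", ...
print(df.columns.tolist())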
7 changes: 6 additions & 1 deletion ultrack/core/segmentation/node.py
@@ -280,6 +280,7 @@ def from_mask(
time: int,
mask: ArrayLike,
bbox: Optional[ArrayLike] = None,
node_id: int = -1,
) -> "Node":
"""
Create a new node from a mask.
@@ -293,6 +294,8 @@
bbox : Optional[ArrayLike], optional
Bounding box of the node, (min_0, min_1, ..., max_0, max_1, ...).
When provided it assumes the mask is a crop of the original image, by default None
node_id : int, optional
Node ID, by default -1
Returns
-------
Expand All @@ -303,12 +306,14 @@ def from_mask(
if mask.dtype != bool:
raise ValueError(f"Mask should be a boolean array. Found {mask.dtype}")

node = Node(h_node_index=-1, time=time, parent=None)
node = Node(h_node_index=-1, id=node_id, time=time, parent=None)

if bbox is None:
bbox = ndi.find_objects(mask)[0]
mask = mask[bbox]

bbox = np.asarray(bbox)

if mask.ndim * 2 != len(bbox):
raise ValueError(
f"Bounding box {bbox} does not match 2x mask ndim {mask.ndim}"
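With the new node_id argument, a Node can be built directly from a boolean mask while keeping control over its id. A minimal, hypothetical sketch assuming Node.from_mask can be called as shown in the signature above, passing an explicit crop and bounding box as the docstring describes:

import numpy as np

from ultrack.core.segmentation.node import Node

crop = np.ones((5, 6), dtype=bool)        # mask already cropped to its bounding box
bbox = np.array([4, 5, 4 + 5, 5 + 6])     # (min_y, min_x, max_y, max_x)

node = Node.from_mask(time=0, mask=crop, bbox=bbox, node_id=42)  # node_id is the new optional argument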