added functionality to get node features from segmentation hypotheses
JoOkuma committed Oct 18, 2024
1 parent 23e41db commit 7d9092f
Showing 7 changed files with 263 additions and 38 deletions.
2 changes: 1 addition & 1 deletion ultrack/__init__.py
@@ -10,7 +10,7 @@
# ignoring small float32/64 zero flush warning
warnings.filterwarnings("ignore", message="The value of the smallest subnormal for")

__version__ = "0.6.0"
__version__ = "0.6.1"

from ultrack.config.config import MainConfig, load_config
from ultrack.core.export.ctc import to_ctc
2 changes: 2 additions & 0 deletions ultrack/core/database.py
@@ -90,8 +90,10 @@ class NodeDB(Base):
y_shift = Column(Float, default=0.0)
x_shift = Column(Float, default=0.0)
area = Column(Integer)
frontier = Column(Float, default=-1.0)
selected = Column(Boolean, default=False)
pickle = Column(MaybePickleType)
features = Column(MaybePickleType, default=None)
node_prob = Column(Float, default=-1.0)
segm_annot = Column(Enum(NodeSegmAnnotation), default=NodeSegmAnnotation.UNKNOWN)
node_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN)
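
The new `features` column holds a per-node feature vector stored through the same `MaybePickleType` mechanism as the node pickle. A minimal sketch of reading it back with `get_node_values` (the config path and node id are placeholder assumptions; the updated test further down does the same against a real database):

import numpy as np

from ultrack.config.config import load_config
from ultrack.core.database import NodeDB, get_node_values

config = load_config("config.toml")  # hypothetical config file
node_id = 1  # assumed to be an existing node id
feats = get_node_values(config.data_config, node_id, NodeDB.features)
assert isinstance(feats, np.ndarray)
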
19 changes: 18 additions & 1 deletion ultrack/core/segmentation/_test/test_segment_processing.py
@@ -8,7 +8,8 @@

from ultrack import segment
from ultrack.config.config import MainConfig
from ultrack.core.database import NodeDB, OverlapDB
from ultrack.core.database import NodeDB, OverlapDB, get_node_values
from ultrack.core.segmentation import get_nodes_features


@pytest.mark.parametrize(
@@ -43,6 +44,7 @@ def test_multiprocessing_segment(
foreground,
contours,
config_instance,
properties=["centroid"],
)

assert config_instance.data_config.metadata["shape"] == list(contours.shape)
@@ -83,3 +85,18 @@ def test_multiprocessing_segment(
continue
node_j = nodes[j]
assert node_i.IoU(node_j) == 0.0

feats = get_node_values(config_instance.data_config, i, NodeDB.features)
feats_name = config_instance.data_config.metadata["properties"]

assert len(feats) == len(feats_name)
assert isinstance(feats, np.ndarray)

df = get_nodes_features(config_instance)

centroids_cols = [f"centroid-{i}" for i in range(contours.ndim - 1)]

assert df.shape[0] == len(nodes)
np.testing.assert_array_equal(
df.columns.to_numpy(dtype=str), ["t", "z", "y", "x", "area"] + centroids_cols
)
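
The `centroid-{i}` column names checked above come from scikit-image: `regionprops_table` expands multi-component properties into one key per component, which is also how the `_get_properties_names` helper added to `processing.py` derives the stored names. A quick self-contained illustration:

import numpy as np
from skimage.measure import regionprops_table

labels = np.zeros((4, 4), dtype=np.uint32)
labels[:2, :2] = 1  # a single square segment

# "centroid" expands to one column per spatial axis
print(list(regionprops_table(labels, properties=["centroid"]).keys()))
# ['centroid-0', 'centroid-1']
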
7 changes: 6 additions & 1 deletion ultrack/core/segmentation/node.py
@@ -280,6 +280,7 @@ def from_mask(
time: int,
mask: ArrayLike,
bbox: Optional[ArrayLike] = None,
node_id: int = -1,
) -> "Node":
"""
Create a new node from a mask.
@@ -293,6 +294,8 @@
bbox : Optional[ArrayLike], optional
Bounding box of the node, (min_0, min_1, ..., max_0, max_1, ...).
When provided it assumes the mask is a crop of the original image, by default None
node_id : int, optional
Node ID, by default -1
Returns
-------
@@ -303,12 +306,14 @@
if mask.dtype != bool:
raise ValueError(f"Mask should be a boolean array. Found {mask.dtype}")

node = Node(h_node_index=-1, time=time, parent=None)
node = Node(h_node_index=-1, id=node_id, time=time, parent=None)

if bbox is None:
bbox = ndi.find_objects(mask)[0]
mask = mask[bbox]

bbox = np.asarray(bbox)

if mask.ndim * 2 != len(bbox):
raise ValueError(
f"Bounding box {bbox} does not match 2x mask ndim {mask.ndim}"
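
A minimal usage sketch of the extended signature; the new `node_id` argument is forwarded to the `Node` constructor as `id`, and the bounding box is inferred with `ndi.find_objects` when not provided:

import numpy as np

from ultrack.core.segmentation.node import Node

mask = np.zeros((16, 16), dtype=bool)  # mask must be boolean
mask[4:8, 5:9] = True  # a single connected component

node = Node.from_mask(time=0, mask=mask, node_id=42)
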
185 changes: 181 additions & 4 deletions ultrack/core/segmentation/processing.py
@@ -1,20 +1,29 @@
import logging
import pickle
from contextlib import nullcontext
from typing import List, Optional
from typing import Callable, List, Optional

import fasteners
import numpy as np
import pandas as pd
import sqlalchemy as sqla
import zarr
from numpy.typing import ArrayLike
from skimage.measure._regionprops import RegionProperties, regionprops_table
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from toolz import curry

from ultrack.config.config import MainConfig, SegmentationConfig
from ultrack.core.database import Base, NodeDB, OverlapDB, clear_all_data
from ultrack.core.database import (
Base,
NodeDB,
OverlapDB,
clear_all_data,
get_node_values,
)
from ultrack.core.segmentation.hierarchy import create_hierarchies
from ultrack.core.segmentation.node import Node
from ultrack.utils.array import check_array_chunk
from ultrack.utils.deprecation import rename_argument
from ultrack.utils.multiprocessing import (
@@ -77,6 +86,76 @@ def _insert_db(
overlaps.clear()


class _ImageCachedLazyLoader:
"""
Wrapper class to cache dask/zarr data loading for feature computation.
"""

def __init__(self, image: ArrayLike):
self._image = image
self._current_t = -1
self._frame = None

def __getitem__(self, index: int) -> np.ndarray:
if index != self._current_t:
self._frame = np.asarray(self._image[index])
self._current_t = index
return self._frame


def create_feats_callback(
shape: ArrayLike, image: Optional[ArrayLike], properties: List[str]
) -> Callable[[Node], np.ndarray]:
"""
Create a callback function to compute features for each node.
Parameters
----------
shape : ArrayLike
Volume (plane) shape.
image : Optional[ArrayLike]
Image array for segment properties; it may have a channel dimension on the last axis.
properties : List[str]
List of properties to compute for each segment, see skimage.measure.regionprops documentation.
Returns
-------
Callable[[Node], np.ndarray]
Callback function that computes the features of a node and returns them as a numpy array.
"""
mask = np.zeros(shape, dtype=bool)

if image is not None:
image = _ImageCachedLazyLoader(image)

def _feats_callback(node: Node) -> np.ndarray:

node.paint_buffer(mask, True, include_time=False)

if image is None:
frame = None
else:
frame = image[node.time]

obj = RegionProperties(
node.slice,
label=True,
label_image=mask,
intensity_image=frame,
cache_active=True,
)

feats = np.concatenate(
[np.ravel(getattr(obj, p)) for p in properties], dtype=np.float32
)

node.paint_buffer(mask, False, include_time=False)

return feats

return _feats_callback


@curry
def _process(
time: int,
@@ -88,6 +167,8 @@
write_lock: Optional[fasteners.InterProcessLock] = None,
catch_duplicates_expection: bool = False,
insertion_throttle_rate: int = 50,
image: Optional[ArrayLike] = None,
properties: Optional[List[str]] = None,
) -> None:
"""Process `foreground` and `edge` of current time and add data to database.
@@ -111,6 +192,10 @@
If True, catches duplicates exception, by default False.
insertion_throttle_rate : int
Throttling rate for insertion, by default 50.
image : Optional[ArrayLike], optional
Image array for segment properties; the channel dimension on the last axis is optional, by default None.
properties : Optional[List[str]], optional
List of properties to compute for each segment, by default None.
"""
np.random.seed(time)

@@ -138,6 +223,12 @@
nodes = []
overlaps = []

node_feats = None
feats_callback = None

if properties is not None:
feats_callback = create_feats_callback(foreground.shape[1:], image, properties)

for h, hierarchy in enumerate(hiers):
hierarchy.cache = True

@@ -154,6 +245,9 @@
else:
z, y, x = centroid

if feats_callback is not None:
node_feats = feats_callback(hier_node)

node = NodeDB(
id=hier_node.id,
t_node_id=index,
@@ -163,7 +257,9 @@
y=int(y),
x=int(x),
area=int(hier_node.area),
frontier=hier_node.frontier,
pickle=pickle.dumps(hier_node), # pickling to reduce memory usage
features=node_feats,
)

hier_index_map[hier_node._h_node_index] = node
@@ -229,7 +325,9 @@
)
return
else:
raise e
raise ValueError(
"Duplicated nodes found. Set `overwrite=True` to overwrite existing data."
) from e

# pushes any remaining data
with write_lock if write_lock is not None else nullcontext():
Expand All @@ -246,6 +344,40 @@ def _check_zarr_memory_store(arr: ArrayLike) -> None:
)


def _get_properties_names(
shape: ArrayLike,
image: Optional[ArrayLike],
properties: Optional[List[str]],
) -> Optional[List[str]]:
"""
Get the expanded property names for the provided properties list.
Parameters
----------
shape : ArrayLike
Volume (plane) shape.
image : Optional[ArrayLike]
Image array for segment properties; it may have a channel dimension on the last axis.
properties : Optional[List[str]]
List of properties to compute for each segment, see skimage.measure.regionprops documentation.
"""

if properties is None:
return None

if image is None:
dummy_image = None
else:
dummy_image = np.ones((4,) * (image.ndim - 1), dtype=np.float32)

dummy_labels = np.zeros((4,) * len(shape), dtype=np.uint32)
dummy_labels[:2, :2] = 1

data_dict = regionprops_table(dummy_labels, dummy_image, properties=properties)

return list(data_dict.keys())


@rename_argument("detection", "foreground")
@rename_argument("edge", "contours")
def segment(
@@ -256,6 +388,8 @@
batch_index: Optional[int] = None,
overwrite: bool = False,
insertion_throttle_rate: int = 50,
image: Optional[ArrayLike] = None,
properties: Optional[List[str]] = None,
) -> None:
"""Add candidate segmentation (nodes) from `foreground` and `edge` to database.
@@ -276,6 +410,11 @@
insertion_throttle_rate : int
Throttling rate for insertion, by default 50.
Only used with non-sqlite databases.
image : Optional[ArrayLike], optional
Image array of shape (T, (Z), Y, X, (C)) for segment properties, by default None.
Channel and Z dimensions are optional.
properties : Optional[List[str]], optional
List of properties to compute for each segment, see skimage.measure.regionprops documentation.
"""
LOG.info(f"Adding nodes with SegmentationConfig:\n{config.segmentation_config}")

@@ -306,7 +445,14 @@
clear_all_data(config.data_config.database_path)

Base.metadata.create_all(engine)
config.data_config.metadata_add({"shape": foreground.shape})
config.data_config.metadata_add(
{
"shape": foreground.shape,
"properties": _get_properties_names(
foreground.shape[1:], image, properties=properties
),
}
)

with multiprocessing_sqlite_lock(config.data_config) as lock:
process = _process(
@@ -318,6 +464,8 @@
max_segments_per_time=max_segments_per_time,
catch_duplicates_expection=batch_index is not None,
insertion_throttle_rate=insertion_throttle_rate,
image=image,
properties=properties,
)

multiprocessing_apply(
@@ -326,3 +474,32 @@
config.segmentation_config.n_workers,
desc="Adding nodes to database",
)


def get_nodes_features(
config: MainConfig,
indices: Optional[ArrayLike] = None,
) -> pd.DataFrame:
"""
Creates a pandas dataframe of the node features computed during segmentation,
plus area and coordinates.
Parameters
----------
config : MainConfig
Configuration parameters.
indices : Optional[ArrayLike], optional
List of node indices, by default None (all nodes).
Returns
-------
pd.DataFrame
Dataframe with node features.
"""
feats_cols = [NodeDB.t, NodeDB.z, NodeDB.y, NodeDB.x, NodeDB.area, NodeDB.features]
df = get_node_values(config.data_config, indices=indices, values=feats_cols)
feat_columns = config.data_config.metadata["properties"]
df.loc[:, feat_columns] = np.vstack(df["features"].to_numpy())
df.drop(columns=["features"], inplace=True)

return df
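
Putting the pieces together, a rough end-to-end sketch of the new feature path (the random arrays and default `MainConfig()` are placeholder assumptions, not a tested workflow):

import numpy as np

from ultrack import segment
from ultrack.config.config import MainConfig
from ultrack.core.segmentation import get_nodes_features

config = MainConfig()

rng = np.random.default_rng(42)
foreground = rng.random((2, 64, 64)) > 0.5  # placeholder (T, Y, X) foreground
contours = rng.random((2, 64, 64))  # placeholder (T, Y, X) contour map
image = rng.random((2, 64, 64))  # placeholder intensity image

segment(
    foreground,
    contours,
    config,
    image=image,
    properties=["centroid"],  # any skimage.measure.regionprops property name
    overwrite=True,
)

df = get_nodes_features(config)
print(df.columns)  # t, z, y, x, area, centroid-0, centroid-1
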
