From 9233b96433aadc4b5a4c3f7225a593dbb7de2693 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 14 Jun 2024 11:56:12 -0700 Subject: [PATCH 01/45] added save_config function --- ultrack/config/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ultrack/config/config.py b/ultrack/config/config.py index a8bed2c..3b3123b 100644 --- a/ultrack/config/config.py +++ b/ultrack/config/config.py @@ -44,3 +44,8 @@ def load_config(path: Union[str, Path]) -> MainConfig: data = toml.load(f) LOG.info(data) return MainConfig.parse_obj(data) + +def save_config(config, path: Union[str, Path]): + """Saved MainConfig to TOML file.""" + with open(path,mode='w') as f: + toml.dump(config.dict(by_alias=True),f) \ No newline at end of file From 094b3e07f9c98ecd6ef3770a6214045a40dbd5de Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 14 Jun 2024 19:19:22 -0700 Subject: [PATCH 02/45] added volume attribute to node --- ultrack/core/segmentation/node.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ultrack/core/segmentation/node.py b/ultrack/core/segmentation/node.py index c7a1d8a..ca2c5d5 100644 --- a/ultrack/core/segmentation/node.py +++ b/ultrack/core/segmentation/node.py @@ -210,6 +210,10 @@ def _centroid(self) -> np.ndarray: ) return centroid.round().astype(int) + def _volume(self) -> int: + volume = self.mask.sum() + return volume.astype(int) + def __eq__(self, other: object) -> bool: if not isinstance(other, Node): return False From bad1d878c2b317ff5e7d4b85c8a104b848b4b5ff Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 21 Jun 2024 15:46:24 -0700 Subject: [PATCH 03/45] close_tracks_gaps now works for segments being a dask_array --- ultrack/utils/segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultrack/utils/segmentation.py b/ultrack/utils/segmentation.py index e34c3f1..26e895f 100644 --- a/ultrack/utils/segmentation.py +++ b/ultrack/utils/segmentation.py @@ -262,6 +262,6 @@ def copy_segments( ) # not sure why this is necessary in large datasets else: for t in tqdm(range(segments.shape[0]), "Copying segments"): - out_segments[t] = segments[t] + out_segments[t] = np.asarray(segments[t]) return out_segments From 777404a9c902ea2fbfb3ccb7e2cfec7b7c5485fe Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 21 Jun 2024 16:09:21 -0700 Subject: [PATCH 04/45] removed node._volume method, since precomputed area already exists --- ultrack/core/segmentation/node.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ultrack/core/segmentation/node.py b/ultrack/core/segmentation/node.py index ca2c5d5..c7a1d8a 100644 --- a/ultrack/core/segmentation/node.py +++ b/ultrack/core/segmentation/node.py @@ -210,10 +210,6 @@ def _centroid(self) -> np.ndarray: ) return centroid.round().astype(int) - def _volume(self) -> int: - volume = self.mask.sum() - return volume.astype(int) - def __eq__(self, other: object) -> bool: if not isinstance(other, Node): return False From d612e039fe6db081d5d1cfc8647e10eba34440ad Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Thu, 11 Jul 2024 11:57:33 -0700 Subject: [PATCH 05/45] added UltrackArray class to utils/array.py --- ultrack/utils/array.py | 122 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index c17884a..049f6e2 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -10,6 +10,9 @@ from numpy.typing import ArrayLike from tqdm import tqdm from zarr.storage import Store +import sqlalchemy as sqla +from sqlalchemy.orm import 
Session +from ultrack.core.database import NodeDB LOG = logging.getLogger(__name__) @@ -193,3 +196,122 @@ def create_zarr( chunks = large_chunk_size(shape, dtype=dtype) return zarr.zeros(shape, dtype=dtype, store=store, chunks=chunks, **kwargs) + +class UltrackArray: + def __init__(self, + config, + dtype: np.dtype = np.int32, + ): + self.config = config + self.database_path = config.data_config.database_path + self.shape = tuple(config.data_config.metadata["shape"]) #(t,(z),y,x) + self.dtype = dtype + self.ndim = len(self.shape) + + self.array = np.zeros(self.shape[1:] ,dtype=self.dtype) + self.minmax = self.find_min_max_volume_entire_dataset() + self.volume = self.minmax.mean().astype(int) + self.export_func = self.array.__setitem__ + +#proper documentation!! + + def __getitem__(self, indexing): + if isinstance(indexing, tuple): + time, volume_slicing = indexing[0], indexing[1:] + else: + time = indexing + volume_slicing = ... + + try: + time = time.item() #convert from numpy.int to int + except: + time = time + + self.query_volume( + time = time, + buffer = self.array, + ) + + return self.array[volume_slicing] + + + def query_volume(self, + time: int, + buffer: np.array, + ) -> None: + engine = sqla.create_engine(self.database_path) + buffer.fill(0) + + with Session(engine) as session: + query = list( + session.query(NodeDB.id, NodeDB.pickle, NodeDB.hier_parent_id).where( + NodeDB.t == time + ) + ) + + idx_to_plot = [] + + for idx,q in enumerate(query): + if q[1].area <= self.volume: + idx_to_plot.append(idx) + + id_to_plot = [q[0] for idx,q in enumerate(query) if idx in idx_to_plot ] + label_list = np.arange(1,len(query)+1,dtype=int) + + to_remove = [] + for idx in idx_to_plot: + if query[idx][2] in id_to_plot: #if parent is also printed + to_remove.append(idx) + + for idx in to_remove: + idx_to_plot.remove(idx) + + if len(query) == 0: + print('query is empty!') + + for idx in idx_to_plot: + query[idx][1].paint_buffer(buffer, value=label_list[idx], include_time=False) + + return query + + + def find_minmax_volumes_1_timepoint(self, + time: int, + ) -> np.ndarray: + + ## + # returns an np.array: [minVolume, maxVolume] of all nodes in the hierarchy for a single time point + ## + + engine = sqla.create_engine(self.database_path) + min_vol = np.inf + max_vol = 0 + with Session(engine) as session: + query = list( + session.query(NodeDB.pickle).where( + NodeDB.t == time + ) + ) + for node in query: + vol = node[0].area + if vol < min_vol: + min_vol = vol + if vol > max_vol: + max_vol = vol + return np.array([min_vol, max_vol]).astype(int) + + def find_min_max_volume_entire_dataset(self): + ## + # loops over all time points in the stack and returns an + # np.array: [minVolume, maxVolume] of all nodes in the hierarchy over all times + ## + min_vol = np.inf + max_vol = 0 + for t in range(self.shape[0]): + minmax = self.find_minmax_volumes_1_timepoint(t) + if minmax[0] < min_vol: + min_vol = minmax[0] + if minmax[1] > max_vol: + max_vol = minmax[1] + + return np.array([min_vol, max_vol],dtype=int) \ No newline at end of file From cc0b1ed4473b2edd77d217dd302d233a875df9e3 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Thu, 11 Jul 2024 13:46:13 -0700 Subject: [PATCH 06/45] ran pre-commit --- ultrack/api/__init__.py | 3 +- ultrack/api/database.py | 4 +-- ultrack/api/main.py | 7 ++-- ultrack/config/config.py | 5 +-- ultrack/utils/array.py | 73 ++++++++++++++++++++-------------------- 5 files changed, 47 insertions(+), 45 deletions(-) diff --git a/ultrack/api/__init__.py b/ultrack/api/__init__.py 
index 29d72d1..3fb5a15 100644 --- a/ultrack/api/__init__.py +++ b/ultrack/api/__init__.py @@ -1,3 +1,2 @@ +from ultrack.api.database import Experiment, ExperimentStatus from ultrack.api.main import start_server -from ultrack.api.database import Experiment -from ultrack.api.database import ExperimentStatus \ No newline at end of file diff --git a/ultrack/api/database.py b/ultrack/api/database.py index 1959b34..8e48729 100644 --- a/ultrack/api/database.py +++ b/ultrack/api/database.py @@ -7,7 +7,7 @@ import sqlalchemy as sqla from pydantic import BaseModel, Json, validator -from sqlalchemy import JSON, Column, DateTime, Enum, Integer, String, Text +from sqlalchemy import Column, Enum, Integer, JSON, String, Text from sqlalchemy.orm import declarative_base, sessionmaker from ultrack import MainConfig @@ -262,8 +262,6 @@ def update_experiment(experiment: Experiment) -> None: session.close() - - def get_experiment(id: int) -> Experiment: """Get an experiment from the database. diff --git a/ultrack/api/main.py b/ultrack/api/main.py index 86560c3..45d231f 100644 --- a/ultrack/api/main.py +++ b/ultrack/api/main.py @@ -1,7 +1,6 @@ import os from multiprocessing import Process from pathlib import Path -from threading import Thread from typing import Union import uvicorn @@ -13,7 +12,8 @@ def _in_notebook(): try: from IPython import get_ipython - if 'IPKernelApp' not in get_ipython().config: # pragma: no cover + + if "IPKernelApp" not in get_ipython().config: # pragma: no cover return False except ImportError: return False @@ -21,6 +21,7 @@ def _in_notebook(): return False return True + def start_server( api_results_path: Union[Path, str, None] = None, ultrack_data_config: Union[DataConfig, None] = None, @@ -41,8 +42,10 @@ def start_server( os.environ["ULTRACK_DATA_CONFIG"] = ultrack_data_config.json() if _in_notebook(): + def start_in_notebook(): uvicorn.run(app.app, host=host, port=port) + Process(target=start_in_notebook).start() else: uvicorn.run(app.app, host=host, port=port) diff --git a/ultrack/config/config.py b/ultrack/config/config.py index 3b3123b..f583428 100644 --- a/ultrack/config/config.py +++ b/ultrack/config/config.py @@ -45,7 +45,8 @@ def load_config(path: Union[str, Path]) -> MainConfig: LOG.info(data) return MainConfig.parse_obj(data) + def save_config(config, path: Union[str, Path]): """Saved MainConfig to TOML file.""" - with open(path,mode='w') as f: - toml.dump(config.dict(by_alias=True),f) \ No newline at end of file + with open(path, mode="w") as f: + toml.dump(config.dict(by_alias=True), f) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index 049f6e2..b4dd744 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -6,12 +6,13 @@ from typing import Callable, Literal, Optional, Tuple, Union import numpy as np +import sqlalchemy as sqla import zarr from numpy.typing import ArrayLike +from sqlalchemy.orm import Session from tqdm import tqdm from zarr.storage import Store -import sqlalchemy as sqla -from sqlalchemy.orm import Session + from ultrack.core.database import NodeDB LOG = logging.getLogger(__name__) @@ -197,25 +198,27 @@ def create_zarr( return zarr.zeros(shape, dtype=dtype, store=store, chunks=chunks, **kwargs) + class UltrackArray: - def __init__(self, + def __init__( + self, config, dtype: np.dtype = np.int32, ): - self.config = config - self.database_path = config.data_config.database_path - self.shape = tuple(config.data_config.metadata["shape"]) #(t,(z),y,x) - self.dtype = dtype - self.ndim = len(self.shape) - - self.array = 
np.zeros(self.shape[1:] ,dtype=self.dtype) + self.config = config + self.database_path = config.data_config.database_path + self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) + self.dtype = dtype + self.ndim = len(self.shape) + + self.array = np.zeros(self.shape[1:], dtype=self.dtype) self.minmax = self.find_min_max_volume_entire_dataset() self.volume = self.minmax.mean().astype(int) self.export_func = self.array.__setitem__ -#proper documentation!! + # proper documentation!! - def __getitem__(self, indexing): + def __getitem__(self, indexing): if isinstance(indexing, tuple): time, volume_slicing = indexing[0], indexing[1:] else: @@ -223,19 +226,19 @@ def __getitem__(self, indexing): volume_slicing = ... try: - time = time.item() #convert from numpy.int to int + time = time.item() # convert from numpy.int to int except: time = time self.query_volume( - time = time, - buffer = self.array, + time=time, + buffer=self.array, ) return self.array[volume_slicing] - - def query_volume(self, + def query_volume( + self, time: int, buffer: np.array, ) -> None: @@ -251,34 +254,36 @@ def query_volume(self, idx_to_plot = [] - for idx,q in enumerate(query): + for idx, q in enumerate(query): if q[1].area <= self.volume: idx_to_plot.append(idx) - id_to_plot = [q[0] for idx,q in enumerate(query) if idx in idx_to_plot ] - label_list = np.arange(1,len(query)+1,dtype=int) + id_to_plot = [q[0] for idx, q in enumerate(query) if idx in idx_to_plot] + label_list = np.arange(1, len(query) + 1, dtype=int) to_remove = [] for idx in idx_to_plot: - if query[idx][2] in id_to_plot: #if parent is also printed + if query[idx][2] in id_to_plot: # if parent is also printed to_remove.append(idx) for idx in to_remove: idx_to_plot.remove(idx) if len(query) == 0: - print('query is empty!') + print("query is empty!") for idx in idx_to_plot: - query[idx][1].paint_buffer(buffer, value=label_list[idx], include_time=False) + query[idx][1].paint_buffer( + buffer, value=label_list[idx], include_time=False + ) return query + def find_minmax_volumes_1_timepoint( + self, + time: int, + ) -> np.ndarray: - def find_minmax_volumes_1_timepoint(self, - time: int, - ) -> np.ndarray: - ## # returns an np.array: [minVolume, maxVolume] of all nodes in the hierarchy for a single time point ## @@ -287,11 +292,7 @@ def find_minmax_volumes_1_timepoint(self, min_vol = np.inf max_vol = 0 with Session(engine) as session: - query = list( - session.query(NodeDB.pickle).where( - NodeDB.t == time - ) - ) + query = list(session.query(NodeDB.pickle).where(NodeDB.t == time)) for node in query: vol = node[0].area if vol < min_vol: @@ -299,10 +300,10 @@ def find_minmax_volumes_1_timepoint(self, if vol > max_vol: max_vol = vol return np.array([min_vol, max_vol]).astype(int) - + def find_min_max_volume_entire_dataset(self): ## - # loops over all time points in the stack and returns an + # loops over all time points in the stack and returns an # np.array: [minVolume, maxVolume] of all nodes in the hierarchy over all times ## min_vol = np.inf @@ -313,5 +314,5 @@ def find_min_max_volume_entire_dataset(self): min_vol = minmax[0] if minmax[1] > max_vol: max_vol = minmax[1] - - return np.array([min_vol, max_vol],dtype=int) \ No newline at end of file + + return np.array([min_vol, max_vol], dtype=int) From 4d3b84145fe9fb902fb7e665645ab193233ed5e4 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 16 Jul 2024 10:52:35 -0700 Subject: [PATCH 07/45] added path to database as optional input parameter. 
If provided, the database is taken from the provided path. If not provided, the database is loaded from the path stored in config --- ultrack/utils/array.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index b4dd744..fd03909 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -203,19 +203,26 @@ class UltrackArray: def __init__( self, config, + Tmax, + database_path: Union[str,None] = None, dtype: np.dtype = np.int32, ): self.config = config - self.database_path = config.data_config.database_path self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype + self.Tmax = Tmax self.ndim = len(self.shape) - self.array = np.zeros(self.shape[1:], dtype=self.dtype) - self.minmax = self.find_min_max_volume_entire_dataset() - self.volume = self.minmax.mean().astype(int) self.export_func = self.array.__setitem__ + if database_path is None: + self.database_path = config.data_config.database_path + else: + self.database_path = database_path + + self.minmax = self.find_min_max_volume_entire_dataset() + self.volume = self.minmax.mean().astype(int) + # proper documentation!! def __getitem__(self, indexing): @@ -308,7 +315,7 @@ def find_min_max_volume_entire_dataset(self): ## min_vol = np.inf max_vol = 0 - for t in range(self.shape[0]): + for t in range(self.Tmax): #range(self.shape[0]): minmax = self.find_minmax_volumes_1_timepoint(t) if minmax[0] < min_vol: min_vol = minmax[0] From dafa92acf11806eb3b1904a6f1bc56c72d19f6c9 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Mon, 29 Jul 2024 09:49:55 -0700 Subject: [PATCH 08/45] adding frontier to database --- ultrack/__init__.py | 2 +- ultrack/core/database.py | 1 + ultrack/core/segmentation/processing.py | 1 + .../core/segmentation/vendored/hierarchy.py | 58 ++++++++++++------- 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/ultrack/__init__.py b/ultrack/__init__.py index fec873f..0fc5b29 100644 --- a/ultrack/__init__.py +++ b/ultrack/__init__.py @@ -10,7 +10,7 @@ # ignoring small float32/64 zero flush warning warnings.filterwarnings("ignore", message="The value of the smallest subnormal for") -__version__ = "0.6.0" +__version__ = "0.6.1" from ultrack.config.config import MainConfig, load_config from ultrack.core.export.ctc import to_ctc diff --git a/ultrack/core/database.py b/ultrack/core/database.py index 3b1232d..d7d1ff3 100644 --- a/ultrack/core/database.py +++ b/ultrack/core/database.py @@ -83,6 +83,7 @@ class NodeDB(Base): y_shift = Column(Float, default=0.0) x_shift = Column(Float, default=0.0) area = Column(Integer) + frontier = Column(Float, default=-1.0) selected = Column(Boolean, default=False) pickle = Column(MaybePickleType) segm_annot = Column(Enum(NodeSegmAnnotation), default=NodeSegmAnnotation.UNKNOWN) diff --git a/ultrack/core/segmentation/processing.py b/ultrack/core/segmentation/processing.py index f5b22f6..10d7d76 100644 --- a/ultrack/core/segmentation/processing.py +++ b/ultrack/core/segmentation/processing.py @@ -163,6 +163,7 @@ def _process( y=int(y), x=int(x), area=int(hier_node.area), + frontier=hier_node.frontier, pickle=pickle.dumps(hier_node), # pickling to reduce memory usage ) diff --git a/ultrack/core/segmentation/vendored/hierarchy.py b/ultrack/core/segmentation/vendored/hierarchy.py index 9ae6129..b70c62e 100644 --- a/ultrack/core/segmentation/vendored/hierarchy.py +++ b/ultrack/core/segmentation/vendored/hierarchy.py @@ -126,17 +126,13 @@ def cache(self, status: bool) -> 
None: def _filter_contour_strength( tree: hg.Tree, alt: ArrayLike, - graph: hg.UndirectedGraph, - weights: ArrayLike, + frontier: ArrayLike, threshold: float, max_area: float, ) -> Tuple[hg.Tree, ArrayLike]: LOG.info("Filtering hierarchy by contour strength.") - - hg.set_attribute(graph, "no_border_vertex_out_degree", None) - irrelevant_nodes = hg.attribute_contour_strength(tree, weights) < threshold - hg.set_attribute(graph, "no_border_vertex_out_degree", 6) + irrelevant_nodes = frontier < threshold if max_area is not None: # Avoid filtering nodes where merge leads to a node with maximum area above threshold @@ -144,9 +140,9 @@ def _filter_contour_strength( irrelevant_nodes[parent_area > max_area] = False tree, node_map = hg.simplify_tree(tree, irrelevant_nodes) - return tree, alt[node_map] + return tree, alt[node_map], frontier[node_map] - def watershed_hierarchy(self) -> Tuple[hg.Tree, ArrayLike]: + def watershed_hierarchy(self) -> Tuple[hg.Tree, ArrayLike, ArrayLike]: """ Creates and filters the watershed hierarchy. @@ -175,12 +171,15 @@ def watershed_hierarchy(self) -> Tuple[hg.Tree, ArrayLike]: LOG.info("Filtering small nodes of hierarchy.") tree, alt = hg.filter_small_nodes_from_tree(tree, alt, self._min_area) + hg.set_attribute(graph, "no_border_vertex_out_degree", None) + frontier = hg.attribute_contour_strength(tree, weights) + hg.set_attribute(graph, "no_border_vertex_out_degree", 2 * mask.ndim) + if self._min_frontier > 0.0: - tree, alt = self._filter_contour_strength( + tree, alt, frontier = self._filter_contour_strength( tree, alt, - graph, - weights, + frontier, self._min_frontier, self._max_area, ) @@ -190,8 +189,9 @@ def watershed_hierarchy(self) -> Tuple[hg.Tree, ArrayLike]: tree, hg.attribute_area(tree) > self._max_area ) alt = alt[node_map] + frontier = frontier[node_map] - return tree, alt + return tree, alt, frontier @property @_cached @@ -212,19 +212,33 @@ def dynamics(self) -> ArrayLike: @property @_cached def tree(self) -> hg.Tree: - tree, alt = self.watershed_hierarchy() + tree, alt, frontier = self.watershed_hierarchy() if self.cache: - self._cache["tree"], self._cache["alt"] = tree, alt + self._cache["tree"] = tree + self._cache["alt"] = alt + self._cache["frontier"] = frontier return tree @property @_cached def alt(self) -> ArrayLike: - tree, alt = self.watershed_hierarchy() + tree, alt, frontier = self.watershed_hierarchy() if self.cache: - self._cache["tree"], self._cache["alt"] = tree, alt + self._cache["tree"] = tree + self._cache["alt"] = alt + self._cache["frontier"] = frontier return alt + @property + @_cached + def frontier(self) -> hg.Tree: + tree, alt, frontier = self.watershed_hierarchy() + if self.cache: + self._cache["tree"] = tree + self._cache["alt"] = alt + self._cache["frontier"] = frontier + return frontier + @property @_cached def cut(self) -> hg.HorizontalCutExplorer: @@ -255,17 +269,16 @@ def compute_nodes(self) -> None: """ tree = self.tree area = self.area - num_leaves = tree.num_leaves() + frontier = self.frontier - for i, node_idx in enumerate( - tree.leaves_to_root_iterator(include_leaves=False) - ): - if area[num_leaves + i] > self._max_area: + for node_idx in tree.leaves_to_root_iterator(include_leaves=False): + if area[node_idx] > self._max_area: continue self._nodes[node_idx] = self.create_node( node_idx, self, - area=area[num_leaves + i].item(), + area=area[node_idx].item(), + frontier=frontier[node_idx].item(), ) def _fix_empty_nodes(self) -> None: @@ -281,6 +294,7 @@ def _fix_empty_nodes(self) -> None: root_index, self, 
area=self.props.area, + frontier=-1.0, ) @property From dbc25784313de3a8eed90531ce2a15de1b48aaf1 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 30 Jul 2024 14:20:00 -0700 Subject: [PATCH 09/45] hierarchy widget, work in progress --- ultrack/napari.yaml | 7 + ultrack/utils/array.py | 22 ++- ultrack/widgets/__init__.py | 1 + ultrack/widgets/hierarchy_viz_widget.py | 182 ++++++++++++++++++++++++ 4 files changed, 210 insertions(+), 2 deletions(-) create mode 100644 ultrack/widgets/hierarchy_viz_widget.py diff --git a/ultrack/napari.yaml b/ultrack/napari.yaml index 8bcfb3c..6fe654a 100644 --- a/ultrack/napari.yaml +++ b/ultrack/napari.yaml @@ -17,6 +17,10 @@ contributions: python_name: ultrack.widgets:UltrackWidget title: Ultrack + - id: ultrack.hierarchy_viz_widget + python_name: ultrack.widgets:HierarchyVizWidget + title: Hierarchy visualization + ###### DEPRECATED & WIP WIDGETS ##### # - id: ultrack.labels_to_edges_widget # python_name: ultrack.widgets:LabelsToContoursWidget @@ -69,3 +73,6 @@ contributions: - command: ultrack.track_inspection display_name: Track inspection + + - command: ultrack.hierarchy_viz_widget + display_name: Hierarchy visualization diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index fd03909..855f99b 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -203,14 +203,13 @@ class UltrackArray: def __init__( self, config, - Tmax, database_path: Union[str,None] = None, dtype: np.dtype = np.int32, ): self.config = config self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype - self.Tmax = Tmax + self.Tmax = config.data_config.metadata["shape"][0] #first channel must the T!! self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) self.export_func = self.array.__setitem__ @@ -222,6 +221,7 @@ def __init__( self.minmax = self.find_min_max_volume_entire_dataset() self.volume = self.minmax.mean().astype(int) + self.initial_volume = self.volume.copy() # proper documentation!! 
@@ -323,3 +323,21 @@ def find_min_max_volume_entire_dataset(self): max_vol = minmax[1] return np.array([min_vol, max_vol], dtype=int) + + def get_volume_list( + self, + ) -> np.ndarray: + + ## + # get a list of the volumes of ALL segments in the database (all time frames) + ## + + engine = sqla.create_engine(self.database_path) + vol_list = [] + with Session(engine) as session: + query = list(session.query(NodeDB.pickle)) + for node in query: + vol = node[0].area + vol_list.append(vol) + + return vol_list diff --git a/ultrack/widgets/__init__.py b/ultrack/widgets/__init__.py index fc9dedc..9338644 100644 --- a/ultrack/widgets/__init__.py +++ b/ultrack/widgets/__init__.py @@ -1,5 +1,6 @@ from ultrack.widgets.division_annotation_widget import DivisionAnnotationWidget from ultrack.widgets.hypotheses_viz_widget import HypothesesVizWidget +from ultrack.widgets.hierarchy_viz_widget import HierarchyVizWidget from ultrack.widgets.labels_to_edges_widget import LabelsToContoursWidget from ultrack.widgets.node_annotation_widget import NodeAnnotationWidget from ultrack.widgets.track_inspection_widget import TrackInspectionWidget diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py new file mode 100644 index 0000000..a27cccf --- /dev/null +++ b/ultrack/widgets/hierarchy_viz_widget.py @@ -0,0 +1,182 @@ +import logging +from typing import Dict, List, Optional, Sequence +# from warnings import warn + +import napari +import numpy as np +import pandas as pd +# import sqlalchemy as sqla +from magicgui.widgets import CheckBox, FloatSlider, PushButton, Container +from qtpy.QtWidgets import QWidget, QVBoxLayout, QLabel, QPushButton, QSlider +from qtpy.QtCore import Qt + +from napari.layers import Labels +from sqlalchemy.orm import Session + +# from ultrack.core.database import LinkDB, NodeDB +# from ultrack.core.segmentation.node import Node +# from ultrack.widgets._generic_data_widget import GenericDataWidget +from ultrack.utils.array import UltrackArray + +logging.basicConfig() +logging.getLogger("sqlachemy.engine").setLevel(logging.INFO) + +LOG = logging.getLogger(__name__) + + +class HierarchyVizWidget(Container): + def __init__(self, + viewer: napari.Viewer, + new_config = None, + ) -> None: + super().__init__() + + self._viewer = viewer + + if new_config is None: + print('ULTRACK WIDGET NOT OPEN!!!') + #load the config from Ultrack widget + else: + self.new_config = new_config + + self._area_threshold_w = FloatSlider(label="Area", min=0, max=0) + self._area_threshold_w.changed.connect(self._slider_update) + self.append(self._area_threshold_w) + + print('Check if hierarchy doesnt exist already!') + self.ultrack_layer = UltrackArray(self.new_config) + self._viewer.add_labels(self.ultrack_layer, name='hierarchy') + self._area_threshold_w.max = self.ultrack_layer.minmax[1]+1 + self._area_threshold_w.min = self.ultrack_layer.minmax[0]-1 + self._area_threshold_w.value = self.ultrack_layer.initial_volume + + def _on_config_changed(self) -> None: + self._ndim = len(self._shape) + + @property + def _shape(self) -> Sequence[int]: + return self.config.metadata.get("shape", []) + + def _slider_update(self, value: float) -> None: + # print('updated slider:',value) + self.ultrack_layer.volume = value + self._viewer.layers['hierarchy'].refresh() + + + +# def _on_load_segm(self) -> None: +# time = self._time +# engine = sqla.create_engine(self.config.database_path) +# with Session(engine) as session: +# query = ( +# session.query(NodeDB.pickle, NodeDB.t_hier_id) +# 
.where(NodeDB.t == time) +# .order_by(NodeDB.area) +# ) +# self._nodes, self._hier_ids = zip(*query) +# # overlaps = ( +# # session.query(OverlapDB) +# # .join(NodeDB, NodeDB.id == OverlapDB.node_id) +# # .where(NodeDB.t == time) +# # ) + +# self._nodes = {node.id: node for node in self._nodes} + +# if len(self._nodes) == 0: +# raise ValueError(f"Could not find segmentations at time {time}") + +# area = np.asarray([node.area for node in self._nodes.values()]) +# self._area_threshold_w.min = area.min() +# self._area_threshold_w.max = area.max() +# self._area_threshold_w.value = np.median(area) + +# def _on_threshold_update(self, value: float) -> None: +# segmentation = self._get_segmentation(threshold=value) +# if self._segm_layer_name in self._viewer.layers: +# self._viewer.layers[self._segm_layer_name].data = segmentation +# else: +# layer = self._viewer.add_labels(segmentation, name=self._segm_layer_name) +# layer.mouse_move_callbacks.append(self._on_mouse_move) + +# def _get_segmentation(self, threshold: float) -> np.ndarray: +# """ +# NOTE: +# when making this interactive it could be interesting to use the overlap data +# to avoid empty regions when visualizing segments +# """ +# if self._ndim == 0: +# raise ValueError( +# "Could not find `shape` metadata. It should be saved during `segmentation` on your `workdir`." +# ) + +# seen_hierarchies = set() + +# buffer = np.zeros(self._shape[1:], dtype=np.uint32) # ignoring time +# for node, hier_id in zip(self._nodes.values(), self._hier_ids): +# if node.area <= threshold or hier_id not in seen_hierarchies: +# # paint segments larger than threshold on empty regions +# node.paint_buffer(buffer, node.id, include_time=False) +# seen_hierarchies.add(hier_id) + +# return buffer + +# @property +# def _time(self) -> None: +# available_ndim = self._viewer.dims.ndim +# if available_ndim < self._ndim: +# warn( +# "Napari `ndims` smaller than dataset `ndims`. " +# f"Expected {self._ndim}, found {available_ndim}. 
Using time = 0" +# ) +# return 0 + +# return self._viewer.dims.point[-self._ndim] + +# def _on_mouse_move(self, layer: Optional[Labels], event) -> None: +# if not self._link_w.value: +# return +# self._load_neighbors(layer.get_value(event.position, world=True)) + +# def _load_neighbors(self, index: int) -> None: +# if index is None or index <= 0: +# return + +# index = int(index) # might be numpy array + +# LOG.info(f"Loading node index = {index}") + +# engine = sqla.create_engine(self.config.database_path, echo=True) +# with Session(engine) as session: +# query = session.query(NodeDB.z, NodeDB.y, NodeDB.x, LinkDB.weight).where( +# LinkDB.target_id == NodeDB.id, LinkDB.source_id == index +# ) +# df = pd.read_sql(query.statement, session.bind) + +# LOG.info(f"Found {len(df)} neighbors") + +# if len(df) == 0: +# return + +# node = self._nodes[index] +# ndim = len(node.centroid) +# centroids = df[["z", "y", "x"]].values[:, -ndim:] # removing z if 2D + +# vectors = np.tile(self._nodes[index].centroid, (len(df), 2, 1)) +# vectors[:, 1, :] = centroids - vectors[:, 0, :] + +# if self._link_layer_name in self._viewer.layers: +# self._viewer.layers.remove(self._link_layer_name) + +# self._viewer.add_vectors( +# data=vectors, +# name=self._link_layer_name, +# features={"weights": df["weight"]}, +# edge_color="weights", +# opacity=1.0, +# ) + +# LOG.info(f"vectors:\n{vectors}") + +# self._viewer.layers.selection.active = self._viewer.layers[ +# self._segm_layer_name +# ] From 57fd8031b09f32e051498b1bac23712b3811e508 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 30 Jul 2024 16:21:02 -0700 Subject: [PATCH 10/45] added non-linear slider for uniform addition of volumes' --- ultrack/utils/array.py | 3 +- ultrack/widgets/hierarchy_viz_widget.py | 170 +++++------------------- 2 files changed, 35 insertions(+), 138 deletions(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index 855f99b..2839ed0 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -326,6 +326,7 @@ def find_min_max_volume_entire_dataset(self): def get_volume_list( self, + timeLimit, ) -> np.ndarray: ## @@ -335,7 +336,7 @@ def get_volume_list( engine = sqla.create_engine(self.database_path) vol_list = [] with Session(engine) as session: - query = list(session.query(NodeDB.pickle)) + query = list(session.query(NodeDB.pickle).where(NodeDB.t <= timeLimit)) for node in query: vol = node[0].area vol_list.append(vol) diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py index a27cccf..6ce4f70 100644 --- a/ultrack/widgets/hierarchy_viz_widget.py +++ b/ultrack/widgets/hierarchy_viz_widget.py @@ -4,18 +4,13 @@ import napari import numpy as np -import pandas as pd +# import pandas as pd +from scipy import interpolate # import sqlalchemy as sqla -from magicgui.widgets import CheckBox, FloatSlider, PushButton, Container -from qtpy.QtWidgets import QWidget, QVBoxLayout, QLabel, QPushButton, QSlider -from qtpy.QtCore import Qt +from magicgui.widgets import FloatSlider, Container, Label +# from qtpy.QtWidgets import QWidget, QVBoxLayout, QLabel, QPushButton, QSlider +# from qtpy.QtCore import Qt -from napari.layers import Labels -from sqlalchemy.orm import Session - -# from ultrack.core.database import LinkDB, NodeDB -# from ultrack.core.segmentation.node import Node -# from ultrack.widgets._generic_data_widget import GenericDataWidget from ultrack.utils.array import UltrackArray logging.basicConfig() @@ -29,7 +24,7 @@ def __init__(self, viewer: napari.Viewer, new_config = None, ) 
-> None: - super().__init__() + super().__init__(layout='horizontal') self._viewer = viewer @@ -39,16 +34,24 @@ def __init__(self, else: self.new_config = new_config - self._area_threshold_w = FloatSlider(label="Area", min=0, max=0) - self._area_threshold_w.changed.connect(self._slider_update) - self.append(self._area_threshold_w) - print('Check if hierarchy doesnt exist already!') self.ultrack_layer = UltrackArray(self.new_config) self._viewer.add_labels(self.ultrack_layer, name='hierarchy') - self._area_threshold_w.max = self.ultrack_layer.minmax[1]+1 - self._area_threshold_w.min = self.ultrack_layer.minmax[0]-1 - self._area_threshold_w.value = self.ultrack_layer.initial_volume + + self.mapping = self._create_mapping() + + self._area_threshold_w = FloatSlider(label="Area", min=0, max=1, readout=False) + self._area_threshold_w.value = 0.5 + self._area_threshold_w.changed.connect(self._slider_update) + + self.slider_label = Label(label=str(self.mapping(self._area_threshold_w.value))) + self.slider_label.native.setFixedWidth(100) + self.append(self._area_threshold_w) + self.append(self.slider_label) + + # self._area_threshold_w.max = self.ultrack_layer.minmax[1]+1 + # self._area_threshold_w.min = self.ultrack_layer.minmax[0]-1 + # self._area_threshold_w.value = self.ultrack_layer.initial_volume def _on_config_changed(self) -> None: self._ndim = len(self._shape) @@ -58,125 +61,18 @@ def _shape(self) -> Sequence[int]: return self.config.metadata.get("shape", []) def _slider_update(self, value: float) -> None: - # print('updated slider:',value) - self.ultrack_layer.volume = value + self.ultrack_layer.volume = self.mapping(value) + self.slider_label.label = str(int(self.mapping(value))) + # print(len(self._area_threshold_w.label)) self._viewer.layers['hierarchy'].refresh() + def _create_mapping(self): + volume_list = self.ultrack_layer.get_volume_list(timeLimit=5) + volume_list.append(self.ultrack_layer.minmax[0]) + volume_list.append(self.ultrack_layer.minmax[1]) + volume_list.sort() - -# def _on_load_segm(self) -> None: -# time = self._time -# engine = sqla.create_engine(self.config.database_path) -# with Session(engine) as session: -# query = ( -# session.query(NodeDB.pickle, NodeDB.t_hier_id) -# .where(NodeDB.t == time) -# .order_by(NodeDB.area) -# ) -# self._nodes, self._hier_ids = zip(*query) -# # overlaps = ( -# # session.query(OverlapDB) -# # .join(NodeDB, NodeDB.id == OverlapDB.node_id) -# # .where(NodeDB.t == time) -# # ) - -# self._nodes = {node.id: node for node in self._nodes} - -# if len(self._nodes) == 0: -# raise ValueError(f"Could not find segmentations at time {time}") - -# area = np.asarray([node.area for node in self._nodes.values()]) -# self._area_threshold_w.min = area.min() -# self._area_threshold_w.max = area.max() -# self._area_threshold_w.value = np.median(area) - -# def _on_threshold_update(self, value: float) -> None: -# segmentation = self._get_segmentation(threshold=value) -# if self._segm_layer_name in self._viewer.layers: -# self._viewer.layers[self._segm_layer_name].data = segmentation -# else: -# layer = self._viewer.add_labels(segmentation, name=self._segm_layer_name) -# layer.mouse_move_callbacks.append(self._on_mouse_move) - -# def _get_segmentation(self, threshold: float) -> np.ndarray: -# """ -# NOTE: -# when making this interactive it could be interesting to use the overlap data -# to avoid empty regions when visualizing segments -# """ -# if self._ndim == 0: -# raise ValueError( -# "Could not find `shape` metadata. 
It should be saved during `segmentation` on your `workdir`." -# ) - -# seen_hierarchies = set() - -# buffer = np.zeros(self._shape[1:], dtype=np.uint32) # ignoring time -# for node, hier_id in zip(self._nodes.values(), self._hier_ids): -# if node.area <= threshold or hier_id not in seen_hierarchies: -# # paint segments larger than threshold on empty regions -# node.paint_buffer(buffer, node.id, include_time=False) -# seen_hierarchies.add(hier_id) - -# return buffer - -# @property -# def _time(self) -> None: -# available_ndim = self._viewer.dims.ndim -# if available_ndim < self._ndim: -# warn( -# "Napari `ndims` smaller than dataset `ndims`. " -# f"Expected {self._ndim}, found {available_ndim}. Using time = 0" -# ) -# return 0 - -# return self._viewer.dims.point[-self._ndim] - -# def _on_mouse_move(self, layer: Optional[Labels], event) -> None: -# if not self._link_w.value: -# return -# self._load_neighbors(layer.get_value(event.position, world=True)) - -# def _load_neighbors(self, index: int) -> None: -# if index is None or index <= 0: -# return - -# index = int(index) # might be numpy array - -# LOG.info(f"Loading node index = {index}") - -# engine = sqla.create_engine(self.config.database_path, echo=True) -# with Session(engine) as session: -# query = session.query(NodeDB.z, NodeDB.y, NodeDB.x, LinkDB.weight).where( -# LinkDB.target_id == NodeDB.id, LinkDB.source_id == index -# ) -# df = pd.read_sql(query.statement, session.bind) - -# LOG.info(f"Found {len(df)} neighbors") - -# if len(df) == 0: -# return - -# node = self._nodes[index] -# ndim = len(node.centroid) -# centroids = df[["z", "y", "x"]].values[:, -ndim:] # removing z if 2D - -# vectors = np.tile(self._nodes[index].centroid, (len(df), 2, 1)) -# vectors[:, 1, :] = centroids - vectors[:, 0, :] - -# if self._link_layer_name in self._viewer.layers: -# self._viewer.layers.remove(self._link_layer_name) - -# self._viewer.add_vectors( -# data=vectors, -# name=self._link_layer_name, -# features={"weights": df["weight"]}, -# edge_color="weights", -# opacity=1.0, -# ) - -# LOG.info(f"vectors:\n{vectors}") - -# self._viewer.layers.selection.active = self._viewer.layers[ -# self._segm_layer_name -# ] + x_vec = np.linspace(0,1,len(volume_list)) + y_vec = np.array(volume_list) + mapping = interpolate.interp1d(x_vec,y_vec) + return mapping \ No newline at end of file From 4e1437d1700074d6a9e841d2c760a08891cd60f9 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Tue, 30 Jul 2024 18:11:37 -0700 Subject: [PATCH 11/45] added documentation to ultrack-array and HierarchyVizWidget --- ultrack/utils/array.py | 83 ++++++++++++++----- ultrack/widgets/hierarchy_viz_widget.py | 64 +++++++++----- .../widgets/ultrackwidget/ultrackwidget.py | 11 ++- 3 files changed, 112 insertions(+), 46 deletions(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index 2839ed0..bde29ec 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -203,9 +203,18 @@ class UltrackArray: def __init__( self, config, - database_path: Union[str,None] = None, dtype: np.dtype = np.int32, ): + """Create an array that directly visualizes the segments in the ultrack database. + + Parameters + ---------- + config : MainConfig + Configuration file of Ultrack. + dtype : np.dtype + Data type of the array. 
+ """ + self.config = config self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype @@ -214,18 +223,24 @@ def __init__( self.array = np.zeros(self.shape[1:], dtype=self.dtype) self.export_func = self.array.__setitem__ - if database_path is None: - self.database_path = config.data_config.database_path - else: - self.database_path = database_path - + self.database_path = config.data_config.database_path self.minmax = self.find_min_max_volume_entire_dataset() self.volume = self.minmax.mean().astype(int) self.initial_volume = self.volume.copy() - # proper documentation!! - def __getitem__(self, indexing): + """Indexing the ultrack-array + + Parameters + ---------- + indexing : Tuple or Array + + Returns + ------- + array : numpy array + array with painted segments + """ + if isinstance(indexing, tuple): time, volume_slicing = indexing[0], indexing[1:] else: @@ -249,6 +264,15 @@ def query_volume( time: int, buffer: np.array, ) -> None: + """Paint all segments of specific time point which volume is bigger than self.volume + Parameters + ---------- + time : int + time point to paint the segments + buffer : np.array + np.zeros to be filled with segments + """ + engine = sqla.create_engine(self.database_path) buffer.fill(0) @@ -290,11 +314,17 @@ def find_minmax_volumes_1_timepoint( self, time: int, ) -> np.ndarray: + """Find minimum and maximum segment volume for single time point - ## - # returns an np.array: [minVolume, maxVolume] of all nodes in the hierarchy for a single time point - ## + Parameters + ---------- + time : int + Returns + ------- + np.array : np.array + array with two elements: [min_volume, max_volume] + """ engine = sqla.create_engine(self.database_path) min_vol = np.inf max_vol = 0 @@ -309,10 +339,13 @@ def find_minmax_volumes_1_timepoint( return np.array([min_vol, max_vol]).astype(int) def find_min_max_volume_entire_dataset(self): - ## - # loops over all time points in the stack and returns an - # np.array: [minVolume, maxVolume] of all nodes in the hierarchy over all times - ## + """Find minimum and maximum segment volume for ALL time point + + Returns + ------- + np.array : np.array + array with two elements: [min_volume, max_volume] + """ min_vol = np.inf max_vol = 0 for t in range(self.Tmax): #range(self.shape[0]): @@ -326,13 +359,19 @@ def find_min_max_volume_entire_dataset(self): def get_volume_list( self, - timeLimit, - ) -> np.ndarray: - - ## - # get a list of the volumes of ALL segments in the database (all time frames) - ## - + timeLimit: int, + ) -> list: + """Creates a list of the volumes of all segments in the database (up untill t=timeLimit) + + Parameters + ---------- + timeLimit : int + + Returns + ------- + vol_list : list + list with all volumes from t=0 to t=timeLimit + """ engine = sqla.create_engine(self.database_path) vol_list = [] with Session(engine) as session: diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py index 6ce4f70..a0a429d 100644 --- a/ultrack/widgets/hierarchy_viz_widget.py +++ b/ultrack/widgets/hierarchy_viz_widget.py @@ -1,17 +1,14 @@ import logging -from typing import Dict, List, Optional, Sequence -# from warnings import warn +from typing import List, Optional, Sequence import napari import numpy as np -# import pandas as pd from scipy import interpolate -# import sqlalchemy as sqla from magicgui.widgets import FloatSlider, Container, Label -# from qtpy.QtWidgets import QWidget, QVBoxLayout, QLabel, QPushButton, QSlider -# from qtpy.QtCore import Qt from 
ultrack.utils.array import UltrackArray +from ultrack.config import MainConfig +from ultrack.widgets.ultrackwidget import UltrackWidget logging.basicConfig() logging.getLogger("sqlachemy.engine").setLevel(logging.INFO) @@ -22,37 +19,43 @@ class HierarchyVizWidget(Container): def __init__(self, viewer: napari.Viewer, - new_config = None, + config = None, ) -> None: - super().__init__(layout='horizontal') + """ + Initialize the HierarchyVizWidget. + + Parameters + ---------- + viewer : napari.Viewer + The napari viewer instance. + config : MainConfig of Ultrack + if not provided, config will be taken from UltrackWidget + """ + + super().__init__() #layout='horizontal') self._viewer = viewer - if new_config is None: - print('ULTRACK WIDGET NOT OPEN!!!') - #load the config from Ultrack widget + if config is None: + self.config = self._get_config() else: - self.new_config = new_config + self.config = config - print('Check if hierarchy doesnt exist already!') - self.ultrack_layer = UltrackArray(self.new_config) + self.ultrack_layer = UltrackArray(self.config) self._viewer.add_labels(self.ultrack_layer, name='hierarchy') self.mapping = self._create_mapping() self._area_threshold_w = FloatSlider(label="Area", min=0, max=1, readout=False) - self._area_threshold_w.value = 0.5 + self._area_threshold_w.value = 0.5 self._area_threshold_w.changed.connect(self._slider_update) - self.slider_label = Label(label=str(self.mapping(self._area_threshold_w.value))) - self.slider_label.native.setFixedWidth(100) + self.slider_label = Label(label=str(int(self.mapping(self._area_threshold_w.value)))) + self.slider_label.native.setFixedWidth(25) + self.append(self._area_threshold_w) self.append(self.slider_label) - # self._area_threshold_w.max = self.ultrack_layer.minmax[1]+1 - # self._area_threshold_w.min = self.ultrack_layer.minmax[0]-1 - # self._area_threshold_w.value = self.ultrack_layer.initial_volume - def _on_config_changed(self) -> None: self._ndim = len(self._shape) @@ -63,10 +66,13 @@ def _shape(self) -> Sequence[int]: def _slider_update(self, value: float) -> None: self.ultrack_layer.volume = self.mapping(value) self.slider_label.label = str(int(self.mapping(value))) - # print(len(self._area_threshold_w.label)) self._viewer.layers['hierarchy'].refresh() def _create_mapping(self): + """ + Creates a pseudo-linear mapping from U[0,1] to full range of segment volumes: + volume = mapping([0,1]) + """ volume_list = self.ultrack_layer.get_volume_list(timeLimit=5) volume_list.append(self.ultrack_layer.minmax[0]) volume_list.append(self.ultrack_layer.minmax[1]) @@ -75,4 +81,16 @@ def _create_mapping(self): x_vec = np.linspace(0,1,len(volume_list)) y_vec = np.array(volume_list) mapping = interpolate.interp1d(x_vec,y_vec) - return mapping \ No newline at end of file + return mapping + + def _get_config(self) -> MainConfig: + """ + Gets config from the Ultrack widget + """ + ultrack_widget = UltrackWidget.find_ultrack_widget(self._viewer) + if ultrack_widget is None: + raise TypeError( + "config not provided and was not found within ultrack widget" + ) + + return ultrack_widget._data_forms.get_config() \ No newline at end of file diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py index 64f960c..8740a2d 100644 --- a/ultrack/widgets/ultrackwidget/ultrackwidget.py +++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py @@ -1,7 +1,7 @@ import logging import webbrowser from contextlib import redirect_stderr, redirect_stdout -from typing import Any, Generator +from 
typing import Any, Generator, Union import napari import qtawesome as qta @@ -573,6 +573,15 @@ def _cancel(self): self._current_worker.quit() self._bt_cancel.setEnabled(False) + @staticmethod + def find_ultrack_widget(viewer: napari.Viewer) -> Union["UltrackWidget", None]: + + for _, w in viewer.window._dock_widgets.items(): + if isinstance(w.widget(), UltrackWidget): + return w.widget() + + return None + if __name__ == "__main__": from napari import Viewer From 6ccbed306c45f632cbab8bf7f53011b2d6eadf99 Mon Sep 17 00:00:00 2001 From: Teun Huijben <45037215+TeunHuijben@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:11:31 -0700 Subject: [PATCH 12/45] Update ultrack/widgets/ultrackwidget/ultrackwidget.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jordão Bragantini --- ultrack/widgets/ultrackwidget/ultrackwidget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py index 8740a2d..fe327e7 100644 --- a/ultrack/widgets/ultrackwidget/ultrackwidget.py +++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py @@ -574,7 +574,7 @@ def _cancel(self): self._bt_cancel.setEnabled(False) @staticmethod - def find_ultrack_widget(viewer: napari.Viewer) -> Union["UltrackWidget", None]: + def find_ultrack_widget(viewer: napari.Viewer) -> Optional["UltrackWidget"]: for _, w in viewer.window._dock_widgets.items(): if isinstance(w.widget(), UltrackWidget): From 336a928368ae0c61ca8e4349edc10a9aa80d7691 Mon Sep 17 00:00:00 2001 From: Teun Huijben <45037215+TeunHuijben@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:11:48 -0700 Subject: [PATCH 13/45] Update ultrack/widgets/ultrackwidget/ultrackwidget.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jordão Bragantini --- ultrack/widgets/ultrackwidget/ultrackwidget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py index fe327e7..45fa720 100644 --- a/ultrack/widgets/ultrackwidget/ultrackwidget.py +++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py @@ -1,7 +1,7 @@ import logging import webbrowser from contextlib import redirect_stderr, redirect_stdout -from typing import Any, Generator, Union +from typing import Any, Generator, Optional import napari import qtawesome as qta From 2daaa31878d0492da44c0868fddc5e92a863218a Mon Sep 17 00:00:00 2001 From: Teun Huijben <45037215+TeunHuijben@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:15:38 -0700 Subject: [PATCH 14/45] Update ultrack/utils/array.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jordão Bragantini --- ultrack/utils/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index bde29ec..ba7b346 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -218,7 +218,7 @@ def __init__( self.config = config self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype - self.Tmax = config.data_config.metadata["shape"][0] #first channel must the T!! 
+ self.Tmax = self.shape[0] self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) self.export_func = self.array.__setitem__ From 925c5643fa7d36c606c1c31bb31875ff4dba8a5e Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Wed, 31 Jul 2024 19:17:24 -0700 Subject: [PATCH 15/45] implementing Jordaos revisions --- ultrack/utils/array.py | 11 +++++++---- ultrack/widgets/hierarchy_viz_widget.py | 14 +++++++------- ultrack/widgets/ultrackwidget/ultrackwidget.py | 1 + 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index bde29ec..bfbbfe4 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -14,6 +14,7 @@ from zarr.storage import Store from ultrack.core.database import NodeDB +from ultrack import MainConfig LOG = logging.getLogger(__name__) @@ -202,7 +203,7 @@ def create_zarr( class UltrackArray: def __init__( self, - config, + config: MainConfig, dtype: np.dtype = np.int32, ): """Create an array that directly visualizes the segments in the ultrack database. @@ -218,7 +219,7 @@ def __init__( self.config = config self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype - self.Tmax = config.data_config.metadata["shape"][0] #first channel must the T!! + self.t_max = config.data_config.metadata["shape"][0] #first channel must the T!! self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) self.export_func = self.array.__setitem__ @@ -228,7 +229,9 @@ def __init__( self.volume = self.minmax.mean().astype(int) self.initial_volume = self.volume.copy() - def __getitem__(self, indexing): + def __getitem__(self, + indexing: tuple, + ) -> np.ndarray: """Indexing the ultrack-array Parameters @@ -348,7 +351,7 @@ def find_min_max_volume_entire_dataset(self): """ min_vol = np.inf max_vol = 0 - for t in range(self.Tmax): #range(self.shape[0]): + for t in range(self.t_max): #range(self.shape[0]): minmax = self.find_minmax_volumes_1_timepoint(t) if minmax[0] < min_vol: min_vol = minmax[0] diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py index a0a429d..2f2ddf4 100644 --- a/ultrack/widgets/hierarchy_viz_widget.py +++ b/ultrack/widgets/hierarchy_viz_widget.py @@ -32,7 +32,7 @@ def __init__(self, if not provided, config will be taken from UltrackWidget """ - super().__init__() #layout='horizontal') + super().__init__(layout='horizontal') self._viewer = viewer @@ -41,8 +41,8 @@ def __init__(self, else: self.config = config - self.ultrack_layer = UltrackArray(self.config) - self._viewer.add_labels(self.ultrack_layer, name='hierarchy') + self.ultrack_array = UltrackArray(self.config) + self._viewer.add_labels(self.ultrack_array, name='hierarchy') self.mapping = self._create_mapping() @@ -64,7 +64,7 @@ def _shape(self) -> Sequence[int]: return self.config.metadata.get("shape", []) def _slider_update(self, value: float) -> None: - self.ultrack_layer.volume = self.mapping(value) + self.ultrack_array.volume = self.mapping(value) self.slider_label.label = str(int(self.mapping(value))) self._viewer.layers['hierarchy'].refresh() @@ -73,9 +73,9 @@ def _create_mapping(self): Creates a pseudo-linear mapping from U[0,1] to full range of segment volumes: volume = mapping([0,1]) """ - volume_list = self.ultrack_layer.get_volume_list(timeLimit=5) - volume_list.append(self.ultrack_layer.minmax[0]) - volume_list.append(self.ultrack_layer.minmax[1]) + volume_list = self.ultrack_array.get_volume_list(timeLimit=5) + 
volume_list.append(self.ultrack_array.minmax[0]) + volume_list.append(self.ultrack_array.minmax[1]) volume_list.sort() x_vec = np.linspace(0,1,len(volume_list)) diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py index 8740a2d..372ff4f 100644 --- a/ultrack/widgets/ultrackwidget/ultrackwidget.py +++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py @@ -575,6 +575,7 @@ def _cancel(self): @staticmethod def find_ultrack_widget(viewer: napari.Viewer) -> Union["UltrackWidget", None]: + """Find the ultrack, if the widget is open, otherwise returns None.""" for _, w in viewer.window._dock_widgets.items(): if isinstance(w.widget(), UltrackWidget): From 840c8521e68dc0ebb3461e957ef3896ed9bd6c4f Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Thu, 1 Aug 2024 09:15:34 -0700 Subject: [PATCH 16/45] revising PR --- .gitattributes | 2 + .github/dependabot.yaml | 10 + .gitmodules | 0 docs/README.md | 29 ++ docs/source/_static/css/style.css | 27 ++ docs/source/configuration.rst | 37 ++ docs/source/examples.rst | 15 + docs/source/fiji.rst | 6 + docs/source/getting_started.rst | 98 ++++++ docs/source/install.rst | 82 +++++ docs/source/napari.rst | 55 +++ docs/source/optimizing.rst | 106 ++++++ docs/source/rest_api.rst | 300 +++++++++++++++++ examples/README.rst | 34 ++ ultrack/core/_test/test_interactive.py | 57 ++++ ultrack/core/export/_test/test_exporter.py | 42 +++ ultrack/core/export/exporter.py | 74 ++++ ultrack/core/interactive.py | 316 ++++++++++++++++++ ultrack/imgproc/_test/test_register.py | 38 +++ ultrack/imgproc/register.py | 121 +++++++ ultrack/utils/array.py | 4 +- .../widgets/ultrackwidget/ultrackwidget.py | 2 +- 22 files changed, 1452 insertions(+), 3 deletions(-) create mode 100644 .gitattributes create mode 100644 .github/dependabot.yaml create mode 100644 .gitmodules create mode 100644 docs/README.md create mode 100644 docs/source/_static/css/style.css create mode 100644 docs/source/configuration.rst create mode 100644 docs/source/examples.rst create mode 100644 docs/source/fiji.rst create mode 100644 docs/source/getting_started.rst create mode 100644 docs/source/install.rst create mode 100644 docs/source/napari.rst create mode 100644 docs/source/optimizing.rst create mode 100644 docs/source/rest_api.rst create mode 100644 examples/README.rst create mode 100644 ultrack/core/_test/test_interactive.py create mode 100644 ultrack/core/export/_test/test_exporter.py create mode 100644 ultrack/core/export/exporter.py create mode 100644 ultrack/core/interactive.py create mode 100644 ultrack/imgproc/_test/test_register.py create mode 100644 ultrack/imgproc/register.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..07fe41c --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# GitHub syntax highlighting +pixi.lock linguist-language=YAML linguist-generated=true diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..2390d8c --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + github-actions: + patterns: + - "*" diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..e69de29 diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..f9bde74 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,29 @@ +# Building docs instructions + +This assumes you have already cloned the repository and are in the root 
directory of the repository. + +Go to the docs directory and install the requirements + +```bash +cd docs +pip install '..[docs]' +``` + +Clean and build the docs with + +```bash +make clean +make html +``` + +In Linux, open the generated html file with + +```bash +xdg-open build/html/index.html +``` + +or in macOS + +```bash +open build/html/index.html +``` diff --git a/docs/source/_static/css/style.css b/docs/source/_static/css/style.css new file mode 100644 index 0000000..3273ab9 --- /dev/null +++ b/docs/source/_static/css/style.css @@ -0,0 +1,27 @@ +p { + line-height: 1.5; /* Adjust this value as needed */ + margin-top: 5px; /* Space above each paragraph */ + margin-bottom: 0px; /* Space below each paragraph */ +} + +li { + margin-top: 0px; /* Adjust this value as needed */ + margin-bottom: 0px; /* Adjust this value as needed */ +} + +ul, ol { + margin-top: 0px; /* Margin around the list */ +} + +h1 { + margin-top: 40px; +} + +h2 { + margin-top: 30px; +} + +h3 { + margin-top: 20px; + font-size: 1.0em; +} diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst new file mode 100644 index 0000000..60ade3a --- /dev/null +++ b/docs/source/configuration.rst @@ -0,0 +1,37 @@ +Configuration +------------- + +The configuration is at the heart of ultrack, it is used to define the parameters for each step of the pipeline and where to store the intermediate results. +The `MainConfig` is the main configuration that contains the other configurations of the individual steps plus the data configuration. + +The configurations are documented below, the parameters are ordered by importance, most important parameters are at the top of the list. Parameters that should not be changed in most of the cases are at the bottom of the list and contain a ``SPECIAL`` tag. + +.. autosummary:: + + ultrack.config.MainConfig + ultrack.config.DataConfig + ultrack.config.SegmentationConfig + ultrack.config.LinkingConfig + ultrack.config.TrackingConfig + +--------------- + +.. autopydantic_model:: ultrack.config.MainConfig + +--------------- + +.. autopydantic_model:: ultrack.config.DataConfig + +--------------- + +.. autopydantic_model:: ultrack.config.SegmentationConfig + +--------------- + +.. autopydantic_model:: ultrack.config.LinkingConfig + +--------------- + +.. _tracking_config: + +.. autopydantic_model:: ultrack.config.TrackingConfig diff --git a/docs/source/examples.rst b/docs/source/examples.rst new file mode 100644 index 0000000..be31503 --- /dev/null +++ b/docs/source/examples.rst @@ -0,0 +1,15 @@ +Examples +-------- + +.. include:: examples/README.rst + :start-line: 2 + :end-line: 25 + +Notebooks +--------- + +.. nbgallery:: + :maxdepth: 2 + :glob: + + examples/**/* diff --git a/docs/source/fiji.rst b/docs/source/fiji.rst new file mode 100644 index 0000000..f68e75e --- /dev/null +++ b/docs/source/fiji.rst @@ -0,0 +1,6 @@ +FIJI plugin +----------- + +Ultrack is also available as a `FIJI `_ plugin. + +Its usage and installation instructions are in FIJI's ultrack `documentation `_. 
diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst new file mode 100644 index 0000000..56aabaa --- /dev/null +++ b/docs/source/getting_started.rst @@ -0,0 +1,98 @@ +Getting started +--------------- + +Ultrack's tracking pipeline is divided into three main steps: + +- ``segment``: Creating the candidate segmentation hypotheses; +- ``link``: Finding candidate links between segmentation hypotheses of adjacent frames; +- ``solve``: Solving the tracking problem by finding the best segmentation and trajectory for each cell. + +Each of these steps has a function of the same name and its own configuration, and they are summarized in the ``track`` function or the ``Tracker`` class, which are the main entry points for the tracking pipeline. + +You'll notice that these functions do not return any results. Instead, they store the results in a database. This enables processing datasets larger than memory, as well as distributed or parallel computing. We provide auxiliary functions to export the results to a format of your choice. + +The ``MainConfig.data_config`` provides the interface to interact with the database, so be mindful of the ``overwrite`` parameter when re-executing these functions: use it to erase previous results, otherwise new results will build on top of existing data. + +If you want to go deep into the weeds of our backend, we recommend looking at the ``ultrack.core.database.py`` file. + +Each of the main steps is explained in detail below; a detailed description of the parameters can be found in :doc:`configuration`. + +Segmentation +```````````` + +Ultrack's canonical inputs are a ``foreground`` and a ``contours`` map; there are several ways to obtain these inputs, which will be explained below. For now, let's consider we are working with them directly. + +Both ``foreground`` and ``contours`` maps must have the same shape, with the first dimension being time (``T``) and the remaining being the spatial dimensions (``Z`` optional, and ``Y``, ``X``). + +``foreground`` is used with ``config.segmentation_config.threshold`` to create a binary mask indicating the presence of the objects of interest; the threshold is 0.5 by default. Values above the threshold are considered foreground, and values below are considered background. + +``contours`` indicates the confidence of each pixel (voxel) being a cell boundary (contour). The higher the value, the more likely it is a cell boundary. It couples with ``config.segmentation_config.min_frontier``, which fuses segmentation candidates separated by a boundary with an average value below this threshold; it is 0 by default, so no fusion is performed. + +The segmentation is the most important step, as it will define the candidates for the tracking problem. +If your cells of interest are not present in the ``foreground`` after the threshold, you won't be able to track them. +If there isn't any faint boundary in the ``contours``, you won't be able to split the region into individual cells. That's why it is preferable to have a lot of contours (more hypotheses), even incorrect ones, than to have none. + +Linking +``````` + +The linking step is responsible for finding candidate links between segmentation hypotheses of adjacent frames. Usually, not a lot of candidates are needed (``config.linking_config.max_neighbors = 5``), unless you have several segmentation hypotheses (``contours`` with several gray levels). + +The parameter ``config.linking_config.max_distance`` must be at least the maximum distance between two cells in consecutive frames.
It's used to filter out candidates that are too far from each other. If this value is too small, you won't be able to link cells that are far apart. + +Solving +``````` + +The solving step is responsible for solving the tracking problem by finding the best segmentation and trajectory for each cell. The parameters for this step are harder to interpret, as they are related to the optimization problem. The most important ones are: + +- ``config.tracking_config.appear_weight``: The penalization for a cell to appear, which means to start a new lineage; +- ``config.tracking_config.division_weight``: The penalization for a cell to divide, breaking a single tracklet into two; +- ``config.tracking_config.disappear_weight``: The penalization for a cell to disappear, which means to end a lineage. + +These weights are negative or zero, as they balance the cost of including new lineages in the final solution. The connections (links) between segmentation hypotheses are positive and measure the quality of the tracks, so only lineages with a total linking weight higher than the penalizations are included in the final solution. At the same time, our optimization problem finds the combination of connections that maximizes the sum of weights of all lineages. + +See the :ref:`tracking configuration description ` for more information and :doc:`optimizing` for details on how to select these parameters. + + +Exporting +````````` + +Once the above steps have been applied, the tracking solutions are recorded in the database and can be exported to a format of your choice through functions such as ``to_networkx``, ``to_trackmate``, ``to_tracks_layer``, ``tracks_to_zarr`` and others. + +See the :ref:`export API reference ` for all available options and their parameters. + +Example of exporting solutions to the napari tracks layer: + +.. code-block:: python + + # ... tracking computation + + # Exporting to napari format using `Tracker` class + tracks, graph = tracker.to_tracks_layer() + + # Exporting using config file + tracks, graph = to_tracks_layer(config) + + +Post-processing +``````````````` + +We also provide additional post-processing functions to remove, join, or analyze your tracks. Most of them are available in ``ultrack.tracks``. Some examples are: + +- ``close_tracks_gaps``: Closes gaps by joining tracklets and interpolating the missing segments; +- ``filter_short_sibling_tracks``: Removes short tracklets generated by false divisions; +- ``get_subgraph``: Returns the whole lineage(s) of a given tracklet. + +Other functionalities can be found in ``ultrack.utils`` or ``ultrack.imgproc``; one notable example is: + +- ``tracks_properties``: Computes statistics from the tracks, segmentation masks and images. + +For additional information, please refer to the :ref:`tracks post-processing API reference `. + +Image processing +```````````````` + +Despite being presented here last, ultrack's image processing module provides auxiliary functions to process your images before the segmentation step. It's not mandatory to use it, but it might reduce the amount of code you need to write to preprocess your images. + +Most of them are available in the ``ultrack.imgproc``, ``ultrack.utils.array`` and ``ultrack.utils.cuda`` modules. + +Refer to the :ref:`image processing API reference ` for more information.
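Putting the steps above together, a rough end-to-end sketch could look like the following. Here ``foreground`` and ``contours`` are placeholder arrays you must provide, and the keyword names of ``Tracker.track`` are assumed from the terminology used above and may differ between versions, so treat this as an illustration rather than a reference.

.. code-block:: python

    from ultrack import Tracker
    from ultrack.config import MainConfig

    config = MainConfig()
    config.segmentation_config.min_area = 50  # placeholder value

    tracker = Tracker(config)
    # runs segment -> link -> solve; keyword names assumed from the text above
    tracker.track(foreground=foreground, contours=contours)

    # export the solution, e.g. to the napari tracks-layer format
    tracks, graph = tracker.to_tracks_layer()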
diff --git a/docs/source/install.rst b/docs/source/install.rst new file mode 100644 index 0000000..1d5d6bc --- /dev/null +++ b/docs/source/install.rst @@ -0,0 +1,82 @@ +Installation +============ + +The easiest way to install the package is to use the conda (or mamba) package manager. +If you do not have conda installed, we recommend to install mamba first, which is a faster alternative to conda. +You can find mamba installation instructions `here `_. + +Once you have conda (mamba) installed, you should create an environment for ``ultrack`` as follows: + +.. code-block:: bash + + conda create -n ultrack python=3.11 gurobi pytorch pyqt -c pytorch -c gurobi -c conda-forge + +Then, you can activate the environment and install ``ultrack``: + +.. code-block:: bash + + conda activate ultrack + pip install ultrack + +If you're using OSX you may need to install ``higra`` from source. You can do this by running the following commands: + +.. code-block:: bash + + conda activate ultrack + pip install numpy + pip install -vv git+https://github.com/higra/Higra + pip install ultrack + +You can check if the installation was successful by running: + +.. code-block:: bash + + ultrack --help + + +Gurobi setup +------------ + +Gurobi is a commercial optimization solver that is used in the tracking module of ``ultrack``. +While it is not a requirement, it is recommended to install it for the best performance. + +To use it, you need to obtain a license (free for academics) and activate it. + +Install gurobi using conda +`````````````````````````` + +You can skip this step if you have already installed Gurobi. + +In your existing Conda environment, install Gurobi with the following command: + +.. code-block:: bash + + conda install -c gurobi gurobi + +Obtain and activate an academic license +``````````````````````````````````````` + +**Obtaining a license:** register for an account using your academic email at `Gurobi's website `_. +Navigate to the Gurobi's `named academic license page `_, and follow the instructions to get your academic license key. + +**Activating license:** In your Conda environment, run: + +.. code-block:: bash + + grbgetkey YOUR_LICENSE_KEY + +Replace YOUR_LICENSE_KEY with the key you received. Follow the prompts to complete activation. + +Test the installation +````````````````````` + +Verify Gurobi's installation by running: + +.. code-block:: bash + + ultrack check_gurobi + +Troubleshooting +``````````````` + +Depending on the operating system, the gurobi library might be missing and you need to install it from `here `_. diff --git a/docs/source/napari.rst b/docs/source/napari.rst new file mode 100644 index 0000000..6429c47 --- /dev/null +++ b/docs/source/napari.rst @@ -0,0 +1,55 @@ +Napari plugin +------------- + +We wrapped up most of the functionality in a napari widget. The widget is already installed +by default, but you must have napari installed to use it. + +To use it, open napari and select the widget from the plugins menu selecting ``ultrack`` and then ``Ultrack`` +from the dropdown menu. + +The plugin is built around the concept of a tracking workflow. Any workflow is a sequence +of pre-processing steps, segmentation, (candidate segments) linking, and the tracking problem solver. +We explain the different workflows in the following sections. + +Workflows +````````` + +The difference between the workflows is the way the user provides the information to the plugin, +and the way it processes the information. The remaining steps are the same for all workflows. 
+In that sense, ``segmentation``, ``linking``, and ``solver`` are the same for all workflows. +For each step, the widget provides direct access to the parameters of the step, and the user can +change the parameters to adapt the workflow to the specific problem. We explain how these +parameters behave in :doc:`Configuration docs `, and, more specifically, in the +:class:`Experiment `, +:class:`Linking `, and +:class:`Tracking ` sections. Every input requested by the plugin +should be loaded beforehand as a layer in ``Napari``. + +There are three workflows available in the plugin: + +- **Automatic tracking from image**: This workflow is designed to track cells in a sequence of images. + It uses classical image processing techniques to detect the cells (foreground) and their possible contours. + In this workflow, you can change the parameters of the image processing steps. + Refer to the documentation of the functions used in the image processing steps: + + - :func:`ultrack.imgproc.detect_foreground` + - :func:`ultrack.imgproc.robust_invert` + +- **Manual tracking**: Since ultrack is designed to work with precomputed cell detection and + contour detection, this workflow is designed for the situation where the user has already + computed the cell detection and the contours of the cells. In this situation, no additional + parameter is needed, you only need to provide the cell detection and the contours of the cells. + +- **Automatic tracking from segmentation labels**: This workflow is designed to track cells + in a sequence of images where the user has already computed the segmentation of the cells. + This workflow wraps the function :meth:`ultrack.utils.labels_to_contours` to compute the foreground and + contours of the cells from the segmentation labels, refer to its documentation for additional details. + + +Flow Field Estimation +````````````````````` + +Every workflow allows the use of a flow field to improve the tracking of dynamic cells. +This method estimates the movement of the cells in the sequence +of images through the function :func:`ultrack.imgproc.flow.timelapse_flow`. +See the documentation of this function for additional details. diff --git a/docs/source/optimizing.rst b/docs/source/optimizing.rst new file mode 100644 index 0000000..17a1dbb --- /dev/null +++ b/docs/source/optimizing.rst @@ -0,0 +1,106 @@ +Tuning tracking performance +------------------------------- + +Once you have a working ultrack pipeline, the next step is optimizing the tracking performance. +Here we describe our guidelines for optimizing the tracking performance and up to what point you can expect to improve the tracking performance. + +It will be divided into a few sections: + +- Pre-processing: How to make tracking easier by pre-processing the data; +- Input verification: Guidelines to check if you have good `labels` or `foreground` and `contours` maps; +- Hard constraints: Parameters must be adjusted so the hypotheses include the correct solution; +- Tracking tuning: Guidelines to adjust the weights to make the correct solution more likely. + +Pre-processing +`````````````` + +Registration +^^^^^^^^^^^^ + +Before tracking, the first question to ask yourself is, are your frames correctly aligned? + +If not, we recommend aligning them. To do that, we provide the ``ultrack.imgproc.register_timelapse`` to align translations, see the :ref:`registration API `. 
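For instance, a translation-only alignment could be applied before tracking along these lines. This is a minimal sketch: ``timelapse`` is a placeholder ``(T, (Z), Y, X)`` array, and the keyword arguments mirror the ``register_timelapse`` signature introduced elsewhere in this PR.

.. code-block:: python

    from ultrack.imgproc import register_timelapse

    # phase-cross-correlation based translation alignment, frame by frame
    registered = register_timelapse(
        timelapse,
        padding=16,          # extra border so shifted content is not cropped
        overlap_ratio=0.25,  # forwarded to skimage's phase_cross_correlation
    )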
+ +If the movement is more complex, with cells moving in different directions, we recommend using the ``flow`` functionalities to align individual segments with distinct transforms; see the :doc:`flow tutorial `. +See the :ref:`flow estimation API ` for more information. + +Deep learning +^^^^^^^^^^^^^ + +Some deep learning models are sensitive to the contrast of your data; we recommend adjusting the contrast and removing the background before applying them to improve their predictions. +See the :ref:`image processing utilities API ` for more information. + +Input verification +`````````````````` + +At this point, we assume you already have a ``labels`` image or ``foreground`` and ``contours`` maps. + +You should check if ``labels`` or ``foreground`` contains every cell you want to track. +Any region that is not included in the ``labels`` or ``foreground`` will not be tracked and can only be fixed with post-processing. + +If you are using ``foreground`` and ``contours`` maps, you should check if the contours induce hierarchies that lead to your desired segmentation. + +This can be done by loading the ``contours`` in napari and viewing them over your original image with ``blending='additive'``. + +You want your ``contours`` image to have higher values on the boundary of cells and lower values inside them. +This indicates that these regions are more likely to be boundaries than the interior of cells. +Notice that this notion is much more flexible than a real contour map, which is why we can use an intensity image or an inverted distance transform as a `contours` map. + +In cells where this is not the case, it is less likely that ultrack will be able to separate them into individual segments. + +If your cells (nuclei) are convex, it is worth trying ``ultrack.imgproc.inverted_edt`` for the ``contours``. + +If you still don't have successful results after going through the next steps, we suggest looking into specialized segmentation solutions. +Some of these solutions are `PlantSeg `_ for membranes or `GoNuclear `_ for nuclei. + + +Hard constraints +```````````````` + +This section is about adjusting the parameters so that the hypotheses include the correct solution. + +Please refer to the :doc:`Configuration docs ` as we refer to different parameters. + +1. The expected cell size should be between ``segmentation_config.min_area`` and ``segmentation_config.max_area``. +Having a tight range assists in finding a good segmentation and significantly reduces the computation. +Our rule of thumb is to set ``min_area`` to half the size of the expected cell or of the smallest cell, *disregarding outliers*, +and ``max_area`` to 1.25~1.5 times the size of the largest cell; this is less problematic than ``min_area``. + +2. ``linking_config.max_distance`` should be set to the maximum distance a cell can move between frames. +We recommend adding some tolerance, for example, 1.5 times the expected movement. + +Tracking tuning +``````````````` + +Once you have gone through the previous steps, you should have a working pipeline, and we can now focus on the results and what can be done in each scenario. + +1.
My cells are oversegmented (excessive splitting of cells): + - Increase the ``segmentation_config.min_area`` to merge smaller cells; + - Increase the ``segmentation_config.max_area`` to avoid splitting larger cells; + - If you have clear boundaries and the oversegmentation happens around weak boundaries, you can increase the ``segmentation_config.min_frontier`` to merge them (steps of 0.05 recommended); + - If you're using ``labels`` as input or to create your contours, you can also try increasing the ``sigma`` parameter to create a better surface for segmentation by avoiding flat regions (full of zeros or ones). + +2. My cells are undersegmented (cells are fused): + - Decrease the ``segmentation_config.min_area`` to enable segmenting smaller cells; + - Decrease the ``segmentation_config.max_area`` to remove larger segments that are likely to be fused cells; + - Decrease the ``segmentation_config.min_frontier`` to avoid merging cells that have weak boundaries; + - **EXPERIMENTAL**: Set ``segmentation_config.max_noise`` to a value greater than 0 to create more diverse hierarchies. The scale of this value should be proportional to the ``contours`` values; for example, if ``contours`` is in the range 0-1, a ``max_noise`` around 0-0.05 should be enough. Play with it. **NOTE**: the solve step will take longer because of the increased number of hypotheses. + +3. I have missing segments that are present in the ``labels`` or ``foreground``: + - Check if these cells are above the ``segmentation_config.threshold`` value; if not, decrease it; + - Check if ``linking_config.max_distance`` is too low and increase it; when cells don't have connections, they are unlikely to be included in the solutions; + - Your ``tracking_config.appear_weight``, ``tracking_config.disappear_weight`` & ``tracking_config.division_weight`` penalization weights are too high (too negative); try bringing them closer to 0.0 (see the sketch after this list). **TIP**: We recommend adjusting ``disappear_weight`` first, because when tuning ``appear_weight`` you should balance out ``division_weight`` so appearing cells don't become fake divisions. A rule of thumb is to keep ``division_weight`` equal to or higher (more negative) than ``appear_weight``. + +4. I'm not detecting enough dividing cells: + - Bring ``tracking_config.division_weight`` closer to 0; + - Depending on your time resolution and cell type, dividing cells may move further apart; in this case, you should tune ``linking_config.max_distance`` accordingly. + +5. I'm detecting too many dividing cells: + - Make ``tracking_config.division_weight`` more negative. + +6. My tracks are short and not continuous enough: + - This is tricky; once you have tried the previous steps, you can try making the ``tracking_config.{appear, division, disappear}_weight`` more negative, but this will remove low-quality tracks; + - Another option is to use ``ultrack.tracks.close_tracks_gaps`` to post-process the tracks. + +7. I have many incorrect tracks connecting distant cells: + - Decrease the ``linking_config.max_distance`` to avoid connecting distant cells. If that can't be done because you would lose correct connections, then you should increase ``linking_config.distance_weight`` slightly above 0, usually in very small steps (0.01).
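As a concrete illustration of cases 3-5 above, the weight adjustments are made on the ``MainConfig`` instance and only the solver needs to be re-run afterwards. This is a sketch; the values are arbitrary starting points, not recommendations.

.. code-block:: python

    from ultrack import solve

    # bring the penalizations closer to zero to recover missing segments (case 3)
    config.tracking_config.appear_weight = -0.001
    config.tracking_config.division_weight = -0.001
    config.tracking_config.disappear_weight = -0.01

    # re-solve with the new weights, overwriting the previous solution
    solve(config, overwrite=True)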
diff --git a/docs/source/rest_api.rst b/docs/source/rest_api.rst new file mode 100644 index 0000000..07871a2 --- /dev/null +++ b/docs/source/rest_api.rst @@ -0,0 +1,300 @@ +REST API +======== + +The ultrack REST API is a set of HTTP/Websockets endpoints that allow you to track your +data from an Ultrack server. +This is what enables the :doc:`Ultrack FIJI plugin `. + +The communication between the Ultrack server and the client is mainly done through websockets. +Allowing real-time responses for efficient communication between the server and the client. + +All the messages sent through the websocket are JSON messages. And there is always an +:class:`Experiment ` object encoded and sent within the message. +This object contains all the information about the experiment that is being run, including +the configuration (:class:`MainConfig `) of the experiment, the status of the +experiment (:class:`ExperimentStatus `), the experiment ID, and the experiment name. +When the experiment is concluded, this object will also contain the results of the +experiment, encoded in the fields ``final_segments_url`` (URL to the tracked segments path) +and ``tracks`` (JSON of napari complaint tracking format). + +**IMPORTANT:** The server must have access to the data shared by the client, for example, through the web, or a shared file system. +Because the server does not store the input data being processed, only the results of the experiment. + +Endpoints +--------- + +In the following sections, we will describe the available endpoints and the expected +payloads for each endpoint. + +Meta endpoint +^^^^^^^^^^^^^ + +To avoid keeping track of each endpoint, there is a single endpoint that returns the available +endpoints for the Ultrack server. This also allows for the Ultrack server to be more dynamic, +as the available endpoints can be changed without changing the client. This endpoint is +described below. + +.. describe:: GET /config/available + + This endpoint returns all the available endpoints for the Ultrack server. + The response is a JSON object with the following structure: + + .. code-block:: JSON + + { + "id_endpoint": { + "link": "/url/to/endpoint", + "human_name": "The title of the endpoint", + "config": { + "experiment": { + "name": "Experiment Name", + "config": "MainConfig()" + }, + "set_of_kwargs_1": {}, + "set_of_kwargs_2": {}, + "...", + "set_of_kwargs_n": {} + } + }, + "..." + } + + As you can see, the response is a JSON object with the keys being the endpoint ID + and the values being a JSON object with the keys `link`, `human_name`, and `config`. + The `link` key is the URL to the endpoint, the `human_name` key is the title of the endpoint, + and the `config` key is the expected input payload for the endpoint. + + The `config` key comprises the initial configuration of the experiment + (an instance of :class:`Experiment `), and a + possible set of keyword arguments that are expected by the endpoint. Those keyword + arguments are dependent on the endpoint and are described in the following sections. + + The `experiment` instance is initialized with the default configuration of the + :class:`MainConfig ` class. This configuration can be + changed by the client and sent to update the server. + +Experiment endpoints +^^^^^^^^^^^^^^^^^^^^ + +The experiment endpoints are the main endpoints of the Ultrack server. +They allow the client to run the experiments and get their respective results. + +.. _segment_auto_detect: +.. 
describe:: WEBSOCKET /segment/auto_detect + + This endpoint is a websocket endpoint that allows you to send an image (referenced + as ``image_channel_or_path``) to the server and get the segmentation of the image. + + This endpoint wraps the :func:`ultrack.imgproc.detect_foreground` function and the + :func:`ultrack.imgproc.robust_invert` function, which estimates + the foreground of the image and its contours by image processing techniques. + For that reason, one can override the default parameters of those functions by sending + the ``detect_foreground_kwargs`` and ``robust_invert_kwargs`` as keyword arguments. + Those keyword arguments will be passed to their respective function. + + This endpoint requires a JSON payload with the following structure: + + .. code-block:: JSON + + { + "experiment": { + "name": "Experiment Name", + "config": "..." + "image_channel_or_path": "/path/to/image", + }, + "detect_foreground_kwargs": {}, + "robust_invert_kwargs": {}, + } + + and the reserver repeatedly returns the :class:`Experiment ` + JSON payload. For example: + + .. code-block:: JSON + + { + "id": 1, + "name": "Experiment Name", + "status": "segmenting", + "config": { + "..." + } + "start_time": "2021-01-01T00:00:00", + "end_time": "", + "std_log": "Segmenting frame 1...", + "err_log": "", + "data_url": "", + "image_channel_or_path": "/path/to/image", + "edges_channel_or_path": "", + "detection_channel_or_path": "", + "segmentation_channel_or_path": "", + "labels_channel_or_path": "", + "final_segments_url": "", + "tracks": "" + } + + Alternatively, if the image is an OME-ZARR file, the input data could be + a specific channel. In this case, the input data could be referenced as: + + .. code-block:: JSON + + { + "experiment": { + "name": "Experiment Name", + "config": "..." + "data_url": "/path/to/image.ome.zarr", + "image_channel_or_path": "image_channel", + }, + "detect_foreground_kwargs": {}, + "robust_invert_kwargs": {}, + } + +All the other endpoints are similar to the :ref:`/segment/auto_detect ` endpoint, but they +are more specific to the segmentation labels of the image. The endpoints are described below. + +.. _segment_manual: +.. describe:: WEBSOCKET /segment/manual + + This endpoint is similar to the :ref:`/segment/auto_detect ` endpoint, but it allows the + client to manually provide cells' foreground mask and their multilevel contours. + This endpoint requires the following JSON payload: + + .. code-block:: JSON + + { + "experiment": { + "name": "Experiment Name", + "config": "..." + "detection_channel_or_path": "/path/to/detection", + "edges_channel_or_path": "/path/to/contours", + }, + } + + Alternatively, if the image is an OME-ZARR file, the input data could be + a specific channel. In this case, the input data could be referenced as: + + .. code-block:: JSON + + { + "experiment": { + "name": "Experiment Name", + "config": "..." + "data_url": "/path/to/image.ome.zarr", + "detection_channel_or_path": "detection_channel", + "edges_channel_or_path": "contours_channel", + }, + } + + For both cases, the server will send the :class:`Experiment ` + JSON payload. For example: + + .. code-block:: JSON + + { + "id": 1, + "name": "Experiment Name", + "status": "segmenting", + "config": { + "..." 
+ } + "start_time": "2021-01-01T00:00:00", + "end_time": "", + "std_log": "Linking cells...", + "err_log": "", + "data_url": "", + "image_channel_or_path": "", + "edges_channel_or_path": "/path/to/contours", + "detection_channel_or_path": "/path/to/detection", + "segmentation_channel_or_path": "", + "labels_channel_or_path": "", + "final_segments_url": "", + "tracks": "" + } + +Last but not least, the following endpoint could be used in a situation where the client already has +the instance segmentation of the cells, for example, from Cellpose or StarDist. + +.. _segment_labels: +.. describe:: WEBSOCKET /segment/labels + + This endpoint is similar to the :ref:`/segment/auto_detect ` endpoint, but it allows the + client to provide pre-computed instance segmentation of the cells. + This endpoint wraps the :meth:`ultrack.utils.labels_to_contours` function, which computes the foreground and contours + of the cells from the instance segmentation. + + This endpoint requires the following JSON payload: + + .. code-block:: JSON + + { + "experiment": { + "name": "Experiment Name", + "config": "..." + "labels_channel_or_path": "/path/to/labels", + }, + "labels_to_edges_kwargs": {}, + } + + Alternatively, if the image is an OME-ZARR file, the input data could be + a specific channel. In this case, the input data could be referenced as: + + .. code-block:: JSON + + { + "experiment": { + "name": "Experiment Name", + "config": "..." + "data_url": "/path/to/image.ome.zarr", + "labels_channel_or_path": "labels_channel", + }, + } + + For both cases, the server will send the :class:`Experiment ` + JSON payload. For example: + + .. code-block:: JSON + + { + "id": 1, + "name": "Experiment Name", + "status": "segmenting", + "config": { + "..." + } + "start_time": "2021-01-01T00:00:00", + "end_time": "", + "std_log": "Linking cells...", + "err_log": "", + "data_url": "", + "image_channel_or_path": "", + "edges_channel_or_path": "", + "detection_channel_or_path": "", + "segmentation_channel_or_path": "", + "labels_channel_or_path": "/path/to/labels", + "final_segments_url": "", + "tracks": "" + } + +Data export endpoints +^^^^^^^^^^^^^^^^^^^^^ + +.. describe:: GET /experiment/{experiment_id}/trackmate + + This endpoint allows the client to download the results of the experiment in the TrackMate + XML format. The client must provide the ``experiment_id`` in the URL. This id is obtained from + the :class:`Experiment ` instance that was executed. + The server will return an XML encoded within the response. + +Database Schema +--------------- + +All the data that is being processed by the Ultrack server is stored in a database. This +database is a SQLite database that is created when the server is started. The database +is used to store the results of the experiments and the configuration of the experiments. + +The database schema is the same as the one used in Ultrack, but with an additional table +to store the configuration and the status of the experiments. The schema is described below. + +.. autopydantic_model:: ultrack.api.database.Experiment + :members: + +.. autoclass:: ultrack.api.database.ExperimentStatus diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 0000000..3170303 --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,34 @@ +Ultrack's Usage Examples +======================== + +Here we provide some examples of how to use Ultrack for cell tracking. 
+ +Some examples are provided as Jupyter notebooks with additional documentation, but we do not recommend using Jupyter notebooks for your day-to-day analysis. + +Other examples as Python scripts can be found in `here `_. + +Additional packages might be required. Therefore, conda environment files are provided, which can be installed using: + +.. code-block:: bash + + conda env create -f + conda activate + pip install git+https://github.com/royerlab/ultrack + +The existing examples are: + +- `multi_color_ensemble <./multi_color_ensemble>`_ : Multi-colored cytoplasm cell tracking using Cellpose and Watershed segmentation ensemble. Data provided by `The Lammerding Lab `_. +- `flow_field_3d <./flow_field_3d>`_ : Tracking demo on a cartographic projection of Tribolium Castaneum embryo from the `cell-tracking challenge `_, using a flow field estimation to assist tracking of motile cells. +- `stardist_2d <./stardist_2d>`_ : Tracking demo on HeLa GPF nuclei from the `cell-tracking challenge `_ using Stardist 2D fluorescence images pre-trained model. +- `zebrahub <./zebrahub/>`_ : Tracking demo on zebrafish tail data from `zebrahub `_ acquired with `DaXi `_ using Ultrack's image processing helper functions. +- `neuromast_plantseg <./neuromast_plantseg/>`_ : Tracking demo membrane-labeled zebrafish neuromast from `Jacobo Group of CZ Biohub `_ using `PlantSeg's `_ membrane detection model. +- `micro_sam <./micro_sam/>`_ : Tracking demo with `MicroSAM `_ instance segmentation package. + +Development Notes +^^^^^^^^^^^^^^^^^ + +To run all the examples and update the notebooks in headless mode, run: + +.. code-block:: bash + + bash refresh_examples.sh diff --git a/ultrack/core/_test/test_interactive.py b/ultrack/core/_test/test_interactive.py new file mode 100644 index 0000000..bfa86a8 --- /dev/null +++ b/ultrack/core/_test/test_interactive.py @@ -0,0 +1,57 @@ +from typing import Tuple + +import numpy as np +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from ultrack import MainConfig, add_new_node +from ultrack.core.database import LinkDB, NodeDB, OverlapDB + + +def _get_table_sizes(session: Session) -> Tuple[int, int, int]: + return ( + session.query(NodeDB).count(), + session.query(LinkDB).count(), + session.query(OverlapDB).count(), + ) + + +@pytest.mark.parametrize( + "config_content", + [ + { + "data.database": "sqlite", + "segmentation.n_workers": 4, + "linking.n_workers": 4, + "linking.max_distance": 500, # too big and ignored + }, + ], + indirect=True, +) +def test_clear_solution( + linked_database_mock_data: MainConfig, +) -> None: + + mask = np.ones((7, 12, 12), dtype=bool) + bbox = np.array([15, 24, 24, 22, 36, 36], dtype=int) + + engine = create_engine(linked_database_mock_data.data_config.database_path) + with Session(engine) as session: + n_nodes, n_links, n_overlaps = _get_table_sizes(session) + + add_new_node( + linked_database_mock_data, + 0, + mask, + bbox, + ) + + new_n_nodes, new_n_links, new_n_overlaps = _get_table_sizes(session) + + assert new_n_nodes == n_nodes + 1 + assert new_n_overlaps > n_overlaps + # could smaller than max neighbors because of radius + assert ( + new_n_links == n_links + linked_database_mock_data.linking_config.max_neighbors + ) diff --git a/ultrack/core/export/_test/test_exporter.py b/ultrack/core/export/_test/test_exporter.py new file mode 100644 index 0000000..8595b90 --- /dev/null +++ b/ultrack/core/export/_test/test_exporter.py @@ -0,0 +1,42 @@ +from pathlib import Path + +from ultrack import MainConfig, 
export_tracks_by_extension + + +def test_exporter(tracked_database_mock_data: MainConfig, tmp_path: Path) -> None: + file_ext_list = [".xml", ".csv", ".zarr", ".dot", ".json"] + last_modified_time = {} + for file_ext in file_ext_list: + tmp_file = tmp_path / f"tracks{file_ext}" + export_tracks_by_extension(tracked_database_mock_data, tmp_file) + + # assert file exists + assert (tmp_path / f"tracks{file_ext}").exists() + # assert file size is not zero + assert (tmp_path / f"tracks{file_ext}").stat().st_size > 0 + + # store last modified time + last_modified_time[str(tmp_file)] = tmp_file.stat().st_mtime + + # loop again testing overwrite=False + for file_ext in file_ext_list: + tmp_file = tmp_path / f"tracks{file_ext}" + try: + export_tracks_by_extension( + tracked_database_mock_data, tmp_file, overwrite=False + ) + assert False, "FileExistsError should be raised" + except FileExistsError: + pass + + # loop again testing overwrite=True + for file_ext in file_ext_list: + tmp_file = tmp_path / f"tracks{file_ext}" + export_tracks_by_extension(tracked_database_mock_data, tmp_file, overwrite=True) + + # assert file exists + assert (tmp_path / f"tracks{file_ext}").exists() + # assert file size is not zero + assert (tmp_path / f"tracks{file_ext}").stat().st_size > 0 + + assert last_modified_time[str(tmp_file)] != tmp_file.stat().st_mtime diff --git a/ultrack/core/export/exporter.py b/ultrack/core/export/exporter.py new file mode 100644 index 0000000..9583251 --- /dev/null +++ b/ultrack/core/export/exporter.py @@ -0,0 +1,74 @@ +import json +from pathlib import Path +from typing import Union + +import networkx as nx + +from ultrack.config import MainConfig +from ultrack.core.export import ( + to_networkx, + to_trackmate, + to_tracks_layer, + tracks_to_zarr, +) + + +def export_tracks_by_extension( + config: MainConfig, filename: Union[str, Path], overwrite: bool = False +) -> None: + """ + Export tracks to a file given the file extension. + + Supported file extensions are .xml, .csv, .zarr, .dot, and .json. + - `.xml` exports to a TrackMate compatible XML file. + - `.csv` exports to a CSV file. + - `.zarr` exports the tracks to dense segments in a `zarr` array format. + - `.dot` exports to a Graphviz DOT file. + - `.json` exports to a networkx JSON file. + + Parameters + ---------- + filename : str or Path + The name of the file to save the tracks to. + config : MainConfig + The configuration object. + overwrite : bool, optional + Whether to overwrite the file if it already exists, by default False. + + See Also + -------- + to_trackmate : + Export tracks to a TrackMate compatible XML file. + to_tracks_layer : + Export tracks to a CSV file. + tracks_to_zarr : + Export tracks to a `zarr` array. + to_networkx : + Export tracks to a networkx graph. + """ + if Path(filename).exists() and not overwrite: + raise FileExistsError( + f"File {filename} already exists. 
Set `overwrite=True` to overwrite the file" + ) + + file_ext = Path(filename).suffix + if file_ext.lower() == ".xml": + to_trackmate(config, filename, overwrite=True) + elif file_ext.lower() == ".csv": + df, _ = to_tracks_layer(config, include_parents=True) + df.to_csv(filename, index=False) + elif file_ext.lower() == ".zarr": + df, _ = to_tracks_layer(config) + tracks_to_zarr(config, df, filename, overwrite=True) + elif file_ext.lower() == ".dot": + G = to_networkx(config) + nx.drawing.nx_pydot.write_dot(G, filename) + elif file_ext.lower() == ".json": + G = to_networkx(config) + json_data = nx.node_link_data(G) + with open(filename, "w") as f: + json.dump(json_data, f) + else: + raise ValueError( + f"Unknown file extension: {file_ext}. Supported extensions are .xml, .csv, .zarr, .dot, and .json." + ) diff --git a/ultrack/core/interactive.py b/ultrack/core/interactive.py new file mode 100644 index 0000000..18a8ad8 --- /dev/null +++ b/ultrack/core/interactive.py @@ -0,0 +1,316 @@ +from typing import List, Optional, Tuple + +import numpy as np +from numpy.typing import ArrayLike +from sqlalchemy import create_engine, func +from sqlalchemy.orm import Session + +from ultrack.config.config import LinkingConfig, MainConfig +from ultrack.core.database import LinkDB, NodeDB, OverlapDB +from ultrack.core.segmentation.node import Node + + +def _nearest_neighbors( + data_arr: ArrayLike, + node: Node, + n_neighbors: int, + max_distance: Optional[float], + scale: Optional[ArrayLike], +) -> np.ndarray: + """ + Returns the indices of the `n_neighbors` nearest neighbors to `node` in `data_arr`. + + Parameters + ---------- + data_arr : ArrayLike + Array of (id, z, y, x) coordinates of the nodes (or centroids). + node : Node + Node to find neighbors for. + n_neighbors : int + Number of neighbors to be considered. + max_distance : float + Maximum distance to be considered. + scale : Optional[ArrayLike], optional + Scaling factor for the distance, by default None. + + Returns + ------- + np.ndarray + Indices of the nearest neighbors. + """ + if scale is None: + scale = np.ones(len(node.centroid)) + + differences = data_arr[:, -len(node.centroid) :] - node.centroid + differences *= scale + sqdist = np.square(differences).sum(axis=1) + + if max_distance is not None: + valid = sqdist <= (max_distance * max_distance) + data_arr = data_arr[valid] + sqdist = sqdist[valid] + + indices = np.argsort(sqdist) + + return data_arr[indices[:n_neighbors], 0] + + +def _find_links( + session: Session, + node: Node, + adj_time: int, + scale: Optional[ArrayLike], + link_config: LinkingConfig, +) -> List[Tuple[NodeDB, float]]: + """ + Finds links for a given node and time. + + Parameters + ---------- + session : Session + SQLAlchemy session. + node : Node + Node to search for neighbors. + adj_time : int + Adjacent time point. + scale : Optional[ArrayLike], optional + Scaling factor for the distance, by default None. + link_config : LinkingConfig + Linking configuration parameters. + + Returns + ------- + List[NodeDB, float] + List of nodes and their weights. 
+ """ + data = np.asarray( + session.query( + NodeDB.id, + NodeDB.z, + NodeDB.y, + NodeDB.x, + ) + .where(NodeDB.t == adj_time) + .all() + ) + if len(data) == 0: + return [] + + ind = _nearest_neighbors( + data, + node, + 2 * link_config.max_neighbors, + link_config.max_distance, + scale=scale, + ) + + neigh_nodes = session.query(NodeDB.pickle).where(NodeDB.id.in_(ind)).all() + + if scale is None: + scale = np.ones(len(node.centroid)) + + neigh_nodes_with_dist = [] + for (n,) in neigh_nodes: + dist = np.linalg.norm((n.centroid - node.centroid) * scale) + w = node.IoU(n) - link_config.distance_weight * dist + neigh_nodes_with_dist.append((n, w)) + + neigh_nodes_with_dist.sort(key=lambda x: x[1], reverse=True) + + return neigh_nodes_with_dist[: link_config.max_neighbors] + + +def _add_overlaps( + session: Session, + node: Node, + n_neighbors: int = 10, + scale: Optional[ArrayLike] = None, +) -> None: + """ + Adds overlaps to the database. + + Parameters + ---------- + session : Session + SQLAlchemy session. + node : Node + Node to find overlaps with. + n_neighbors : int, optional + Number of neighbors to be considered, by default 10. + scale : Optional[ArrayLike], optional + Scaling factor for the distance, by default None. + """ + data = np.asarray( + session.query( + NodeDB.id, + NodeDB.z, + NodeDB.y, + NodeDB.x, + ) + .where(NodeDB.t == node.time, NodeDB.id != node.id) + .all() + ) + ind = _nearest_neighbors(data, node, n_neighbors, max_distance=None, scale=scale) + + overlaps = [] + + for (neigh_node,) in session.query(NodeDB.pickle).where(NodeDB.id.in_(ind)).all(): + if node.IoU(neigh_node) > 0.0: + overlaps.append( + OverlapDB( + node_id=node.id, + ancestor_id=neigh_node.id, + ) + ) + + session.add_all(overlaps) + + +def _add_links( + session: Session, + node: Node, + link_config: LinkingConfig, + scale: Optional[ArrayLike] = None, +) -> None: + """ + Adds T - 1 and T + 1 links to the database. + + NOTE: this is not taking node shifts into account. + + Parameters + ---------- + session : Session + SQLAlchemy session. + node : Node + Node to search for neighbors. + link_config : LinkingConfig + Linking configuration parameters. + scale : Optional[ArrayLike], optional + Scaling factor for the distance, by default None. + """ + links = [] + + before_links = _find_links( + session=session, + node=node, + adj_time=node.time - 1, + scale=scale, + link_config=link_config, + ) + for before_node, w in before_links: + links.append( + LinkDB( + source_id=before_node.id, + target_id=node.id, + weight=w, + ) + ) + + after_links = _find_links( + session=session, + node=node, + adj_time=node.time + 1, + scale=scale, + link_config=link_config, + ) + for after_node, w in after_links: + links.append( + LinkDB( + source_id=node.id, + target_id=after_node.id, + weight=w, + ) + ) + + session.add_all(links) + + +def add_new_node( + config: MainConfig, + time: int, + mask: ArrayLike, + bbox: Optional[ArrayLike] = None, + index: Optional[int] = None, + include_overlaps: bool = True, +) -> int: + """ + Adds a new node to the database. + + NOTE: this is not taking node shifts or image features (color) into account. + + Parameters + ---------- + config : MainConfig + Ultrack configuration parameters. + time : int + Time point of the node. + mask : ArrayLike + Binary mask of the node. + bbox : Optional[ArrayLike], optional + Bounding box of the node, (min_0, min_1, ..., max_0, max_1, ...). 
+ When provided it assumes the mask is a crop of the original image, by default None + index : Optional[int], optional + Node index, otherwise it is automatically generated, and returned. + include_overlaps : bool, optional + Include overlaps in the database, by default True + When False it will allow oclusions between new node and existing nodes. + + Returns + ------- + int + New node index. + """ + + node = Node.from_mask( + time=time, + mask=mask, + bbox=bbox, + ) + if node.area == 0: + raise ValueError("Node area is zero. Something went wrong.") + + scale = config.data_config.metadata.get("scale") + + engine = create_engine(config.data_config.database_path) + with Session(engine) as session: + + # querying required data + if index is None: + node.id = ( + int(session.query(func.max(NodeDB.id)).where(NodeDB.t == time).scalar()) + + 1 + ) + else: + node.id = index + + # adding node + if len(node.centroid) == 2: + y, x = node.centroid + z = 0 + else: + z, y, x = node.centroid + + node_db_obj = NodeDB( + id=node.id, + t=node.time, + z=z, + y=y, + x=x, + area=node.area, + pickle=node, + ) + session.add(node_db_obj) + + if include_overlaps: + _add_overlaps(session=session, node=node, scale=scale) + + _add_links( + session=session, + node=node, + link_config=config.linking_config, + scale=scale, + ) + + session.commit() + + return node.id diff --git a/ultrack/imgproc/_test/test_register.py b/ultrack/imgproc/_test/test_register.py new file mode 100644 index 0000000..ebfc742 --- /dev/null +++ b/ultrack/imgproc/_test/test_register.py @@ -0,0 +1,38 @@ +from typing import Tuple + +import numpy as np +import pytest +import scipy.ndimage as ndi +import zarr + +from ultrack.imgproc import register_timelapse + + +@pytest.mark.parametrize( + "timelapse_mock_data", + [ + {"length": 3, "size": 32, "n_dim": 3}, + ], + indirect=True, +) +def test_register_timelapse( + timelapse_mock_data: Tuple[zarr.Array, zarr.Array, zarr.Array], +) -> None: + _, moved_edges, _ = timelapse_mock_data + + shift = 8 + + # adding a new to emulate a channel + moved_edges = moved_edges[...][:, None] + + for i in range(moved_edges.shape[0]): + moved_edges[i] = ndi.shift(moved_edges[i], (0, i * shift // 2, 0, 0), order=1) + + fixed_edges = register_timelapse(moved_edges, reference_channel=0, padding=shift) + + for i in range(moved_edges.shape[0] - 1): + # removing padding and out of fov regions + volume = fixed_edges[i, :, : -2 * shift] + next_vol = fixed_edges[i + 1, :, : -2 * shift] + + np.testing.assert_allclose(volume, next_vol) diff --git a/ultrack/imgproc/register.py b/ultrack/imgproc/register.py new file mode 100644 index 0000000..5a32bb1 --- /dev/null +++ b/ultrack/imgproc/register.py @@ -0,0 +1,121 @@ +import logging +from typing import Callable, Optional, Union + +import numpy as np +import zarr +from numpy.typing import ArrayLike +from tqdm import tqdm +from zarr.storage import Store + +from ultrack.utils.array import create_zarr +from ultrack.utils.cuda import import_module, to_cpu + +LOG = logging.getLogger(__name__) + + +def register_timelapse( + timelapse: ArrayLike, + *, + store_or_path: Union[Store, str, None] = None, + overwrite: bool = False, + to_device: Callable[[ArrayLike], ArrayLike] = lambda x: x, + reference_channel: Optional[int] = None, + overlap_ratio: float = 0.25, + normalization: Optional[str] = None, + padding: Optional[int] = None, + **kwargs, +) -> zarr.Array: + """ + Register a timelapse sequence using phase cross correlation. 
+ + Parameters + ---------- + timelapse : ArrayLike + Input timelapse sequence, T(CZ)YX array C and Z are optional. + NOTE: when provided, C must be the second dimension after T. + store_or_path : Union[Store, str, None], optional + Zarr storage or a file path, to save the output, useful for larger than memory datasets. + By default it loads the data into memory. + overwrite : bool, optional + Overwrite output file if it already exists, when using directory store or a path. + to_device : Callable[[ArrayLike], ArrayLike], optional + Function to move the input data to cuda device, by default lambda x: x (CPU). + reference_channel : Optional[int], optional + Reference channel for registration, by default None. + It must be provided when it contains a channel dimension. + overlap_ratio : float, optional + Overlap ratio for phase cross correlation, by default 0.25. + normalization : Optional[str], optional + Normalization method for phase cross correlation, by default None. + padding : Optional[int], optional + Padding for registration, by default None. + **kwargs + Additional arguments for phase cross correlation. See `skimage.registration phase_cross_correlation + `_. # noqa: E501 + + Returns + ------- + zarr.Array + Registered timelapse sequence. + """ + shape = list(timelapse.shape) + + if padding is not None: + offset = 1 if reference_channel is None else 2 + pads = [(0, 0)] * (offset - 1) + + for i in range(offset, len(shape)): + shape[i] += 2 * padding + pads.append((padding, padding)) + + def maybe_pad(x: np.ndarray) -> np.ndarray: + x = np.asarray(x) + x = to_device(x) + return np.pad(x, pads, mode="constant") + + else: + + def maybe_pad(x: np.ndarray) -> np.ndarray: + x = np.asarray(x) + return to_device(x) + + out_arr = create_zarr( + tuple(shape), + dtype=timelapse.dtype, + store_or_path=store_or_path, + overwrite=overwrite, + ) + + if reference_channel is None: + channel = ... 
+ else: + channel = reference_channel + + prev_frame = maybe_pad(timelapse[0]) + out_arr[0] = to_cpu(prev_frame) + + ndi = import_module("scipy", "ndimage", arr=prev_frame) + skreg = import_module("skimage", "registration", arr=prev_frame) + + for t in tqdm(range(timelapse.shape[0] - 1), "Registration"): + next_frame = maybe_pad(timelapse[t + 1]) + shift, _, _ = skreg.phase_cross_correlation( + prev_frame[channel], + next_frame[channel], + overlap_ratio=overlap_ratio, + normalization=normalization, + **kwargs, + ) + + LOG.info("Shift at %d: %s", t, shift) + + if reference_channel is not None: + shift = (0, *shift) + + next_frame = ndi.shift(next_frame, shift, order=1) + out_arr[t + 1] = to_cpu(next_frame) + + prev_frame = next_frame + + return out_arr diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index e8c2358..057728b 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -14,7 +14,7 @@ from zarr.storage import Store from ultrack.core.database import NodeDB -from ultrack import MainConfig +from ultrack.config import MainConfig LOG = logging.getLogger(__name__) @@ -219,7 +219,7 @@ def __init__( self.config = config self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype - self.Tmax = self.shape[0] + self.t_max = self.shape[0] self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) self.export_func = self.array.__setitem__ diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py index 45fa720..36bb39c 100644 --- a/ultrack/widgets/ultrackwidget/ultrackwidget.py +++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py @@ -575,7 +575,7 @@ def _cancel(self): @staticmethod def find_ultrack_widget(viewer: napari.Viewer) -> Optional["UltrackWidget"]: - + """Find and returns Ultrack widget.
If widget not found, returns None""" for _, w in viewer.window._dock_widgets.items(): if isinstance(w.widget(), UltrackWidget): return w.widget() From 5cb9a389139d15213f1cf2c892f3c0c82382a59b Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Thu, 1 Aug 2024 16:00:43 -0700 Subject: [PATCH 17/45] renaming functions in ultrack-array --- ultrack/utils/array.py | 96 +++++++++++-------------- ultrack/widgets/hierarchy_viz_widget.py | 23 +++--- 2 files changed, 55 insertions(+), 64 deletions(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index 057728b..38c106a 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -222,12 +222,10 @@ def __init__( self.t_max = self.shape[0] self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) - self.export_func = self.array.__setitem__ self.database_path = config.data_config.database_path self.minmax = self.find_min_max_volume_entire_dataset() self.volume = self.minmax.mean().astype(int) - self.initial_volume = self.volume.copy() def __getitem__(self, indexing: tuple, @@ -252,32 +250,29 @@ def __getitem__(self, try: time = time.item() # convert from numpy.int to int - except: + except AttributeError: time = time - self.query_volume( + self.fill_array( time=time, - buffer=self.array, ) return self.array[volume_slicing] - def query_volume( + def fill_array( self, time: int, - buffer: np.array, ) -> None: """Paint all segments of specific time point which volume is bigger than self.volume Parameters ---------- time : int time point to paint the segments - buffer : np.array - np.zeros to be filled with segments """ - + + self.array.fill(0) + engine = sqla.create_engine(self.database_path) - buffer.fill(0) with Session(engine) as session: query = list( @@ -308,12 +303,36 @@ def query_volume( for idx in idx_to_plot: query[idx][1].paint_buffer( - buffer, value=label_list[idx], include_time=False + self.array, value=label_list[idx], include_time=False ) - return query + def get_tp_num_pixels( + self, + timeStart:int, + timeStop:int, + ) -> list: + """Gets a list of number of pixels of all segments range of time points (timeStart to timeStop) + + Parameters + ---------- + timeStart : int + timeStop : int + + Returns + ------- + np.array : np.array + array with two elements: [min_volume, max_volume] + """ + engine = sqla.create_engine(self.database_path) + num_pix_list = [] + with Session(engine) as session: + query = list(session.query(NodeDB.area).where(NodeDB.t >= timeStart).where(NodeDB.t <= timeStop)) + for num_pix in query: + num_pix_list.append(int(np.array(num_pix))) + return num_pix_list + - def find_minmax_volumes_1_timepoint( + def get_tp_num_pixels_minmax( self, time: int, ) -> np.ndarray: @@ -325,21 +344,13 @@ def find_minmax_volumes_1_timepoint( Returns ------- - np.array : np.array - array with two elements: [min_volume, max_volume] + num_pix_list : list + list with all num_pixels from t=0 to t=timeLimit """ - engine = sqla.create_engine(self.database_path) - min_vol = np.inf - max_vol = 0 - with Session(engine) as session: - query = list(session.query(NodeDB.pickle).where(NodeDB.t == time)) - for node in query: - vol = node[0].area - if vol < min_vol: - min_vol = vol - if vol > max_vol: - max_vol = vol - return np.array([min_vol, max_vol]).astype(int) + num_pixels_list = self.get_tp_num_pixels(time,time) + return (min(num_pixels_list),max(num_pixels_list)) + + def find_min_max_volume_entire_dataset(self): """Find minimum and maximum segment volume for ALL time point @@ -352,35 +363,10 @@ def 
find_min_max_volume_entire_dataset(self): min_vol = np.inf max_vol = 0 for t in range(self.t_max): #range(self.shape[0]): - minmax = self.find_minmax_volumes_1_timepoint(t) + minmax = self.get_tp_num_pixels_minmax(t) if minmax[0] < min_vol: min_vol = minmax[0] if minmax[1] > max_vol: max_vol = minmax[1] - return np.array([min_vol, max_vol], dtype=int) - - def get_volume_list( - self, - timeLimit: int, - ) -> list: - """Creates a list of the volumes of all segments in the database (up untill t=timeLimit) - - Parameters - ---------- - timeLimit : int - - Returns - ------- - vol_list : list - list with all volumes from t=0 to t=timeLimit - """ - engine = sqla.create_engine(self.database_path) - vol_list = [] - with Session(engine) as session: - query = list(session.query(NodeDB.pickle).where(NodeDB.t <= timeLimit)) - for node in query: - vol = node[0].area - vol_list.append(vol) - - return vol_list + return np.array([min_vol, max_vol], dtype=int) \ No newline at end of file diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py index 2f2ddf4..4b5ee34 100644 --- a/ultrack/widgets/hierarchy_viz_widget.py +++ b/ultrack/widgets/hierarchy_viz_widget.py @@ -42,12 +42,12 @@ def __init__(self, self.config = config self.ultrack_array = UltrackArray(self.config) - self._viewer.add_labels(self.ultrack_array, name='hierarchy') self.mapping = self._create_mapping() self._area_threshold_w = FloatSlider(label="Area", min=0, max=1, readout=False) self._area_threshold_w.value = 0.5 + self.ultrack_array.volume = self.mapping(0.5) self._area_threshold_w.changed.connect(self._slider_update) self.slider_label = Label(label=str(int(self.mapping(self._area_threshold_w.value)))) @@ -56,6 +56,11 @@ def __init__(self, self.append(self._area_threshold_w) self.append(self.slider_label) + self._viewer.add_labels(self.ultrack_array, name='hierarchy') + #THERE SHOULD BE CHECK HERE IF THERE EXISTS A LAYER WITH THE NAME 'HIERARCHY' + self._viewer.layers['hierarchy'].refresh() + + def _on_config_changed(self) -> None: self._ndim = len(self._shape) @@ -70,16 +75,16 @@ def _slider_update(self, value: float) -> None: def _create_mapping(self): """ - Creates a pseudo-linear mapping from U[0,1] to full range of segment volumes: - volume = mapping([0,1]) + Creates a pseudo-linear mapping from U[0,1] to full range of number of pixels: + num_pixels = mapping([0,1]) """ - volume_list = self.ultrack_array.get_volume_list(timeLimit=5) - volume_list.append(self.ultrack_array.minmax[0]) - volume_list.append(self.ultrack_array.minmax[1]) - volume_list.sort() + num_pixels_list = self.ultrack_array.get_tp_num_pixels(timeStart=0,timeStop=5) + num_pixels_list.append(self.ultrack_array.minmax[0]) + num_pixels_list.append(self.ultrack_array.minmax[1]) + num_pixels_list.sort() - x_vec = np.linspace(0,1,len(volume_list)) - y_vec = np.array(volume_list) + x_vec = np.linspace(0,1,len(num_pixels_list)) + y_vec = np.array(num_pixels_list) mapping = interpolate.interp1d(x_vec,y_vec) return mapping From a1f8c4c80df021f7e1d3026586ce94e29c7a1cd2 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Thu, 1 Aug 2024 16:12:16 -0700 Subject: [PATCH 18/45] reverting dirty history changes --- .gitattributes | 2 - .github/dependabot.yaml | 10 - .gitmodules | 0 docs/README.md | 29 -- docs/source/_static/css/style.css | 27 -- docs/source/configuration.rst | 37 -- docs/source/examples.rst | 15 - docs/source/fiji.rst | 6 - docs/source/getting_started.rst | 98 ------ docs/source/install.rst | 82 ----- docs/source/napari.rst | 55 --- 
docs/source/optimizing.rst | 106 ------ docs/source/rest_api.rst | 300 ----------------- examples/README.rst | 34 -- ultrack/core/_test/test_interactive.py | 57 ---- ultrack/core/export/_test/test_exporter.py | 42 --- ultrack/core/export/exporter.py | 74 ---- ultrack/core/interactive.py | 316 ------------------ ultrack/imgproc/_test/test_register.py | 38 --- ultrack/imgproc/register.py | 121 ------- ultrack/utils/array.py | 100 +++--- ultrack/widgets/hierarchy_viz_widget.py | 23 +- .../widgets/ultrackwidget/ultrackwidget.py | 2 +- 23 files changed, 67 insertions(+), 1507 deletions(-) delete mode 100644 .gitattributes delete mode 100644 .github/dependabot.yaml delete mode 100644 .gitmodules delete mode 100644 docs/README.md delete mode 100644 docs/source/_static/css/style.css delete mode 100644 docs/source/configuration.rst delete mode 100644 docs/source/examples.rst delete mode 100644 docs/source/fiji.rst delete mode 100644 docs/source/getting_started.rst delete mode 100644 docs/source/install.rst delete mode 100644 docs/source/napari.rst delete mode 100644 docs/source/optimizing.rst delete mode 100644 docs/source/rest_api.rst delete mode 100644 examples/README.rst delete mode 100644 ultrack/core/_test/test_interactive.py delete mode 100644 ultrack/core/export/_test/test_exporter.py delete mode 100644 ultrack/core/export/exporter.py delete mode 100644 ultrack/core/interactive.py delete mode 100644 ultrack/imgproc/_test/test_register.py delete mode 100644 ultrack/imgproc/register.py diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index 07fe41c..0000000 --- a/.gitattributes +++ /dev/null @@ -1,2 +0,0 @@ -# GitHub syntax highlighting -pixi.lock linguist-language=YAML linguist-generated=true diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml deleted file mode 100644 index 2390d8c..0000000 --- a/.github/dependabot.yaml +++ /dev/null @@ -1,10 +0,0 @@ -version: 2 -updates: - - package-ecosystem: "github-actions" - directory: "/" - schedule: - interval: "monthly" - groups: - github-actions: - patterns: - - "*" diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index e69de29..0000000 diff --git a/docs/README.md b/docs/README.md deleted file mode 100644 index f9bde74..0000000 --- a/docs/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Building docs instructions - -This assumes you have already cloned the repository and are in the root directory of the repository. 
- -Go to the docs directory and install the requirements - -```bash -cd docs -pip install '..[docs]' -``` - -Clean and build the docs with - -```bash -make clean -make html -``` - -In Linux, open the generated html file with - -```bash -xdg-open build/html/index.html -``` - -or in macOS - -```bash -open build/html/index.html -``` diff --git a/docs/source/_static/css/style.css b/docs/source/_static/css/style.css deleted file mode 100644 index 3273ab9..0000000 --- a/docs/source/_static/css/style.css +++ /dev/null @@ -1,27 +0,0 @@ -p { - line-height: 1.5; /* Adjust this value as needed */ - margin-top: 5px; /* Space above each paragraph */ - margin-bottom: 0px; /* Space below each paragraph */ -} - -li { - margin-top: 0px; /* Adjust this value as needed */ - margin-bottom: 0px; /* Adjust this value as needed */ -} - -ul, ol { - margin-top: 0px; /* Margin around the list */ -} - -h1 { - margin-top: 40px; -} - -h2 { - margin-top: 30px; -} - -h3 { - margin-top: 20px; - font-size: 1.0em; -} diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst deleted file mode 100644 index 60ade3a..0000000 --- a/docs/source/configuration.rst +++ /dev/null @@ -1,37 +0,0 @@ -Configuration -------------- - -The configuration is at the heart of ultrack, it is used to define the parameters for each step of the pipeline and where to store the intermediate results. -The `MainConfig` is the main configuration that contains the other configurations of the individual steps plus the data configuration. - -The configurations are documented below, the parameters are ordered by importance, most important parameters are at the top of the list. Parameters that should not be changed in most of the cases are at the bottom of the list and contain a ``SPECIAL`` tag. - -.. autosummary:: - - ultrack.config.MainConfig - ultrack.config.DataConfig - ultrack.config.SegmentationConfig - ultrack.config.LinkingConfig - ultrack.config.TrackingConfig - ---------------- - -.. autopydantic_model:: ultrack.config.MainConfig - ---------------- - -.. autopydantic_model:: ultrack.config.DataConfig - ---------------- - -.. autopydantic_model:: ultrack.config.SegmentationConfig - ---------------- - -.. autopydantic_model:: ultrack.config.LinkingConfig - ---------------- - -.. _tracking_config: - -.. autopydantic_model:: ultrack.config.TrackingConfig diff --git a/docs/source/examples.rst b/docs/source/examples.rst deleted file mode 100644 index be31503..0000000 --- a/docs/source/examples.rst +++ /dev/null @@ -1,15 +0,0 @@ -Examples --------- - -.. include:: examples/README.rst - :start-line: 2 - :end-line: 25 - -Notebooks ---------- - -.. nbgallery:: - :maxdepth: 2 - :glob: - - examples/**/* diff --git a/docs/source/fiji.rst b/docs/source/fiji.rst deleted file mode 100644 index f68e75e..0000000 --- a/docs/source/fiji.rst +++ /dev/null @@ -1,6 +0,0 @@ -FIJI plugin ------------ - -Ultrack is also available as a `FIJI `_ plugin. - -Its usage and installation instructions are in FIJI's ultrack `documentation `_. 
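The configuration models listed in the configuration section above can also be instantiated directly in Python. A minimal sketch, assuming only the default ``MainConfig`` constructor and the ``load_config`` helper; the TOML path and the overridden field are illustrative:

.. code-block:: python

    from ultrack.config import MainConfig, load_config

    # start from defaults and override individual fields
    config = MainConfig()
    config.segmentation_config.min_area = 50

    # or load an existing TOML configuration file
    config = load_config("config.toml")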
diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst deleted file mode 100644 index 56aabaa..0000000 --- a/docs/source/getting_started.rst +++ /dev/null @@ -1,98 +0,0 @@ -Getting started ---------------- - -Ultrack tracking pipeline is divided into three main steps: - -- ``segment``: Creating the candidate segmentation hypotheses; -- ``link``: Finding candidate links between segmentation hypotheses of adjacent frames; -- ``solve``: Solving the tracking problem by finding the best segmentation and trajectory for each cell. - -These three steps have their respective function with the same names and configurations but are summarized in the ``track`` function or the ``Tracker`` class, which are the main entry point for the tracking pipeline. - -You'll notice that these functions do not return any results. Instead, they store the results in a database. This enables us to process datasets larger than memory, and distributed or parallel computing. We provide auxiliary functions to export the results to a format of your choice. - -The ``MainConfig.data_config`` provides the interface to interact with the database, so beware of using ``overwrite`` parameter when re-executing these functions, to erase previous results otherwise it will build on top of existing data. - -If you want to go deep into the weeds of our backend. We recommend looking at the ``ultrack.core.database.py`` file. - -Each one of the main steps will be explained in detail below, a detailed description of the parameters can be found in :doc:`configuration`. - -Segmentation -```````````` - -Ultrack's canonical inputs are a ``foreground`` and a ``contours``, there are several ways to obtain these inputs, which will be explained below. For now, let's consider we are working with them directly. - -Both ``foreground`` and ``contours`` maps must have the same shape, with the first dimension being time (``T``) and the remaining being the spatial dimensions (``Z`` optional, and ``Y``, ``X``). - -``foreground`` is used with ``config.segmentation_config.threshold`` to create a binary mask indicating the presence of the object of interest, it's by default 0.5. Values above the threshold are considered as foreground, and values below are considered as background. - -``contours`` indicates the confidence of each pixel (voxel) being a cell boundary (contour). The higher the value, the more likely it is a cell boundary. It couples with ``config.segmentation_config.min_frontier`` which fuses segmentation candidates separated by a boundary with an average value below this threshold, it's by default 0, so no fusion is performed. - -The segmentation is the most important step, as it will define the candidates for the tracking problem. -If your cells of interest are not present in the ``foreground`` after the threshold, you won't be able to track them. -If there isn't any faint boundary in the ``contours``, you won't be able to split into individual cells. That's why it's preferred to have a lot of contours (more hypotheses), even incorrect ones than having none. - -Linking -``````` - -The linking step is responsible for finding candidate links between segmentation hypotheses of adjacent frames. Usually, not a lot of candidates are needed (``config.linking_config.max_neighbors = 5``), unless you have several segmentation hypotheses (``contours`` with several gray levels). - -The parameter ``config.linking_config.max_distance`` must be at least the maximum distance between two cells in consecutive frames. 
It's used to filter out candidates that are too far from each other. If this value is too small, you won't be able to link cells that are far from each other. - -Solving -``````` - -The solving step is responsible for solving the tracking problem by finding the best segmentation and trajectory for each cell. The parameters for this step are harder to interpret, as they are related to the optimization problem. The most important ones are: - -- ``config.tracking_config.appear_weight``: The penalization for a cell to appear, which means to start a new lineage; -- ``config.tracking_config.division_weight``: The penalization for a cell to divide, breaking a single tracklet into two; -- ``config.tracking_config.disappear_weight``: The penalization for a cell to disappear, which means to end a lineage; - -These weights are negative or zero, as they try to balance the cost of including new lineages in the final solution. The connections (links) between segmentation hypotheses are positive and measure the quality of the tracks, so only lineages with a total linking weight higher than the penalizations are included in the final solution. At the same time, our optimization problem is finding the combination of connections that maximize the sum of weights of all lineages. - -See the :ref:`tracking configuration description ` for more information and :doc:`optimizing` for details on how to select these parameters. - - -Exporting -````````` - -Once the above steps have been applied, the tracking solutions are recorded in the database and they can be exported to a format of your choice, them being, ``to_networkx``, ``to_trackmate``, ``to_tracks_layer``, ``tracks_to_zarr`` and others. - -See the :ref:`export API reference ` for all available options and their parameters. - -Example of exporting solutions to napari tracks layer: - -.. code-block:: python - - # ... tracking computation - - # Exporting to napari format using `Tracker` class - tracks, graph = tracker.to_tracks_layer() - - # Exporting using config file - tracks, graph = to_tracks_layer(config) - - -Post-processing -``````````````` - -We also provide some additional post-processing functions, to remove, join, or analyze your tracks. Most of them are available in ``ultrack.tracks``. Some examples are: - -- ``close_tracks_gaps``: That closes gaps by joining tracklets and interpolating the missing segments; -- ``filter_short_sibling_tracks``: That removes short tracklets generated by false divisions; -- ``get_subgraph``: Which returns the whole lineage(s) of a given tracklet. - -Other functionalities can be found in ``ultrack.utils`` or ``ultrack.imgproc``, one notable example is: - -- ``tracks_properties``: Which returns compute statistics from the tracks, segmentation masks and images. - -For additional information, please refer to the :ref:`tracks post-processing API reference `. - -Image processing -```````````````` - -Despite being presented here last, ultrack's image processing module provides auxiliary functions to process your image before the segmentation step. It's not mandatory to use it, but it might reduce the amount of code you need to write to preprocess your images. - -Most of them are available in ``ultrack.imgproc`` , ``ultrack.utils.array`` and ``ultrack.utils.cuda`` modules. - -Refer to the :ref:`image processing API reference ` for more information. 
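Putting the pieces above together, a minimal end-to-end sketch is shown below. It assumes the ``segment``, ``link`` and ``solve`` entry points described earlier, with ``foreground`` and ``contours`` arrays prepared beforehand; argument names, values and the ``overwrite`` flags are illustrative rather than a definitive API:

.. code-block:: python

    from ultrack import MainConfig, segment, link, solve, to_tracks_layer

    config = MainConfig()
    config.segmentation_config.min_area = 50   # smallest expected cell
    config.linking_config.max_distance = 25    # max movement between frames

    # foreground and contours are (T, (Z), Y, X) arrays prepared beforehand
    segment(foreground, contours, config, overwrite=True)
    link(config, overwrite=True)
    solve(config)

    # export the solution to the napari tracks format
    tracks_df, graph = to_tracks_layer(config)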
diff --git a/docs/source/install.rst b/docs/source/install.rst deleted file mode 100644 index 1d5d6bc..0000000 --- a/docs/source/install.rst +++ /dev/null @@ -1,82 +0,0 @@ -Installation -============ - -The easiest way to install the package is to use the conda (or mamba) package manager. -If you do not have conda installed, we recommend to install mamba first, which is a faster alternative to conda. -You can find mamba installation instructions `here `_. - -Once you have conda (mamba) installed, you should create an environment for ``ultrack`` as follows: - -.. code-block:: bash - - conda create -n ultrack python=3.11 gurobi pytorch pyqt -c pytorch -c gurobi -c conda-forge - -Then, you can activate the environment and install ``ultrack``: - -.. code-block:: bash - - conda activate ultrack - pip install ultrack - -If you're using OSX you may need to install ``higra`` from source. You can do this by running the following commands: - -.. code-block:: bash - - conda activate ultrack - pip install numpy - pip install -vv git+https://github.com/higra/Higra - pip install ultrack - -You can check if the installation was successful by running: - -.. code-block:: bash - - ultrack --help - - -Gurobi setup ------------- - -Gurobi is a commercial optimization solver that is used in the tracking module of ``ultrack``. -While it is not a requirement, it is recommended to install it for the best performance. - -To use it, you need to obtain a license (free for academics) and activate it. - -Install gurobi using conda -`````````````````````````` - -You can skip this step if you have already installed Gurobi. - -In your existing Conda environment, install Gurobi with the following command: - -.. code-block:: bash - - conda install -c gurobi gurobi - -Obtain and activate an academic license -``````````````````````````````````````` - -**Obtaining a license:** register for an account using your academic email at `Gurobi's website `_. -Navigate to the Gurobi's `named academic license page `_, and follow the instructions to get your academic license key. - -**Activating license:** In your Conda environment, run: - -.. code-block:: bash - - grbgetkey YOUR_LICENSE_KEY - -Replace YOUR_LICENSE_KEY with the key you received. Follow the prompts to complete activation. - -Test the installation -````````````````````` - -Verify Gurobi's installation by running: - -.. code-block:: bash - - ultrack check_gurobi - -Troubleshooting -``````````````` - -Depending on the operating system, the gurobi library might be missing and you need to install it from `here `_. diff --git a/docs/source/napari.rst b/docs/source/napari.rst deleted file mode 100644 index 6429c47..0000000 --- a/docs/source/napari.rst +++ /dev/null @@ -1,55 +0,0 @@ -Napari plugin -------------- - -We wrapped up most of the functionality in a napari widget. The widget is already installed -by default, but you must have napari installed to use it. - -To use it, open napari and select the widget from the plugins menu selecting ``ultrack`` and then ``Ultrack`` -from the dropdown menu. - -The plugin is built around the concept of a tracking workflow. Any workflow is a sequence -of pre-processing steps, segmentation, (candidate segments) linking, and the tracking problem solver. -We explain the different workflows in the following sections. - -Workflows -````````` - -The difference between the workflows is the way the user provides the information to the plugin, -and the way it processes the information. The remaining steps are the same for all workflows. 
-In that sense, ``segmentation``, ``linking``, and ``solver`` are the same for all workflows. -For each step, the widget provides direct access to the parameters of the step, and the user can -change the parameters to adapt the workflow to the specific problem. We explain how these -parameters behave in :doc:`Configuration docs `, and, more specifically, in the -:class:`Experiment `, -:class:`Linking `, and -:class:`Tracking ` sections. Every input requested by the plugin -should be loaded beforehand as a layer in ``Napari``. - -There are three workflows available in the plugin: - -- **Automatic tracking from image**: This workflow is designed to track cells in a sequence of images. - It uses classical image processing techniques to detect the cells (foreground) and their possible contours. - In this workflow, you can change the parameters of the image processing steps. - Refer to the documentation of the functions used in the image processing steps: - - - :func:`ultrack.imgproc.detect_foreground` - - :func:`ultrack.imgproc.robust_invert` - -- **Manual tracking**: Since ultrack is designed to work with precomputed cell detection and - contour detection, this workflow is designed for the situation where the user has already - computed the cell detection and the contours of the cells. In this situation, no additional - parameter is needed, you only need to provide the cell detection and the contours of the cells. - -- **Automatic tracking from segmentation labels**: This workflow is designed to track cells - in a sequence of images where the user has already computed the segmentation of the cells. - This workflow wraps the function :meth:`ultrack.utils.labels_to_contours` to compute the foreground and - contours of the cells from the segmentation labels, refer to its documentation for additional details. - - -Flow Field Estimation -````````````````````` - -Every workflow allows the use of a flow field to improve the tracking of dynamic cells. -This method estimates the movement of the cells in the sequence -of images through the function :func:`ultrack.imgproc.flow.timelapse_flow`. -See the documentation of this function for additional details. diff --git a/docs/source/optimizing.rst b/docs/source/optimizing.rst deleted file mode 100644 index 17a1dbb..0000000 --- a/docs/source/optimizing.rst +++ /dev/null @@ -1,106 +0,0 @@ -Tuning tracking performance -------------------------------- - -Once you have a working ultrack pipeline, the next step is optimizing the tracking performance. -Here we describe our guidelines for optimizing the tracking performance and up to what point you can expect to improve the tracking performance. - -It will be divided into a few sections: - -- Pre-processing: How to make tracking easier by pre-processing the data; -- Input verification: Guidelines to check if you have good `labels` or `foreground` and `contours` maps; -- Hard constraints: Parameters must be adjusted so the hypotheses include the correct solution; -- Tracking tuning: Guidelines to adjust the weights to make the correct solution more likely. - -Pre-processing -`````````````` - -Registration -^^^^^^^^^^^^ - -Before tracking, the first question to ask yourself is, are your frames correctly aligned? - -If not, we recommend aligning them. To do that, we provide the ``ultrack.imgproc.register_timelapse`` to align translations, see the :ref:`registration API `. 
- -If the movement is more complex, with cells moving in different directions, we recommend using the ``flow`` functionalities to align individual segments with distinct transforms, see the :doc:`flow tutorial `. -See the :ref:`flow estimation API ` for more information. - -Deep learning -^^^^^^^^^^^^^ - -Some deep learning models are sensitive to the contrast of your data, we recommend adjusting the contrast and removing background before applying them to improve their predictions. -See the :ref:`image processing utilities API ` for more information. - -Input verification -`````````````````` - -At this point, we assume you already have a ``labels`` image or a ``foreground`` and ``contours`` maps; - -You should check if ``labels`` or ``foreground`` contains every cell you want to track. -Any region that is not included in the ``labels`` or ``foreground`` will not be tracked and can only be fixed with post-processing. - -If you are using ``foreground`` and ``contours`` maps, you should check if the contours induce hierarchies that lead to your desired segmentation. - -This can be done by loading the ``contours`` in napari and viewing them over your original image with ``blending='additive'``. - -You want your ``contours`` image to have higher values in the boundary of cells and lower values inside it. -This indicates that these regions are more likely to be boundaries than the interior of cells. -Notice, that this notion is much more flexible than a real contour map, which is we can use an intensity image as a `contours` map or an inverted distance transform. - -In cells where this is not the case it is less likely ultrack will be able to separate them into individual segments. - -If your cells (nuclei) are convex it is worth trying the ``ultrack.imgproc.inverted_edt`` for the ``contours``. - -If even after going through the next steps you don't have successful results, I suggest looking for specialized solutions once you have a working pipeline. -Some of these solutions are `PlantSeg `_ for membranes or `GoNuclear `_ for nuclei. - - -Hard constraints -```````````````` - -This section is about adjusting the parameters so we have hypotheses that include the correct solution. - -Please refer to the :doc:`Configuration docs ` as we refer to different parameters. - -1. The expected cell size should be between ``segmentation_config.min_area`` and ``segmentation_config.max_area``. -Having a tight range assists in finding a good segmentation and significantly reduces the computation. -Our rule of thumb is to set the ``min_area`` to half the size of the expected cell or the smallest cell, *disregarding outliers*. -And the ``max_area`` to 1.25~1.5 the size of the largest cell, this is less problematic than the ``min_area``. - -2. ``linking_config.max_distance`` should be set to the maximum distance a cell can move between frames. -We recommend setting some tolerance, for example, 1.5 times the expected movement. - -Tracking tuning -``````````````` - -Once you have gone through the previous steps, you should have a working pipeline and now we can focus on the results and what can be done in each scenario. - -1. 
My cells are oversegmented (excessive splitting of cells): - Increase the ``segmentation_config.min_area`` to merge smaller cells; - Increase the ``segmentation_config.max_area`` to avoid splitting larger cells; - If you have clear boundaries and the oversegmentation occurs around weak boundaries, you can increase the ``segmentation_config.min_frontier`` to merge them (steps of 0.05 recommended). - If you're using ``labels`` as input or to create your contours, you can also try to increase the ``sigma`` parameter to create a better surface for segmentation by avoiding flat regions (full of zeros or ones). - -2. My cells are undersegmented (cells are fused): - Decrease the ``segmentation_config.min_area`` to enable segmenting smaller cells; - Decrease the ``segmentation_config.max_area`` to remove larger segments that are likely to be fused cells; - Decrease the ``segmentation_config.min_frontier`` to avoid merging cells that have weak boundaries; - **EXPERIMENTAL**: Set ``segmentation_config.max_noise`` to a value greater than 0 to create more diverse hierarchies; the scale of this value should be proportional to the ``contours`` value, for example, if the ``contours`` is in the range of 0-1, a ``max_noise`` around 0-0.05 should be enough. Play with it. **NOTE**: the solve step will take longer because of the increased number of hypotheses. - -3. I have missing segments that are present on the ``labels`` or ``foreground``: - Check if these cells are above the ``segmentation_config.threshold`` value; if not, decrease it; - Check if ``linking_config.max_distance`` is too low and increase it; when cells don't have connections they are unlikely to be included in the solutions; - Your ``tracking_config.appear_weight``, ``tracking_config.disappear_weight`` & ``tracking_config.division_weight`` penalization weights are too high (too negative), try bringing them closer to 0.0. **TIP**: We recommend adjusting ``disappear_weight`` first, because when tuning ``appear_weight`` you should balance out ``division_weight`` so appearing cells don't become fake divisions. A rule of thumb is to keep ``division_weight`` equal to or higher (more negative) than ``appear_weight``. - -4. I'm not detecting enough dividing cells: - Bring ``tracking_config.division_weight`` to a value closer to 0. - Depending on your time resolution and your cell type, it might be the case that dividing cells move further apart; in this case, you should tune the ``linking_config.max_distance`` accordingly. - -5. I'm detecting too many dividing cells: - Make ``tracking_config.division_weight`` more negative. - -6. My tracks are short and not continuous enough: - This is tricky; once you have tried the previous steps, you can try making the ``tracking_config.{appear, division, disappear}_weight`` more negative, but this will remove low-quality tracks. - Another option is to use ``ultrack.tracks.close_tracks_gaps`` to post-process the tracks. - -7. I have many incorrect tracks connecting distant cells: - Decrease the ``linking_config.max_distance`` to avoid connecting distant cells. If that can't be done because you will lose correct connections, then you should set ``linking_config.distance_weight`` to a value slightly higher than 0, usually in very small steps (0.01).
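The knobs above all live on the same ``MainConfig`` object; a sketch of typical adjustments, assuming ``config`` was created as in the getting started guide (all numbers are illustrative starting points, not recommendations):

.. code-block:: python

    # segmentation hypotheses
    config.segmentation_config.min_area = 250       # ~half of the smallest expected cell
    config.segmentation_config.max_area = 10_000    # ~1.5x the largest expected cell
    config.segmentation_config.min_frontier = 0.05  # merge candidates split by weak boundaries

    # linking
    config.linking_config.max_distance = 30         # ~1.5x the expected movement per frame

    # tracking penalizations (negative or zero)
    config.tracking_config.appear_weight = -0.25
    config.tracking_config.division_weight = -0.5
    config.tracking_config.disappear_weight = -0.5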
diff --git a/docs/source/rest_api.rst b/docs/source/rest_api.rst deleted file mode 100644 index 07871a2..0000000 --- a/docs/source/rest_api.rst +++ /dev/null @@ -1,300 +0,0 @@ -REST API -======== - -The ultrack REST API is a set of HTTP/Websockets endpoints that allow you to track your -data from an Ultrack server. -This is what enables the :doc:`Ultrack FIJI plugin `. - -The communication between the Ultrack server and the client is mainly done through websockets. -Allowing real-time responses for efficient communication between the server and the client. - -All the messages sent through the websocket are JSON messages. And there is always an -:class:`Experiment ` object encoded and sent within the message. -This object contains all the information about the experiment that is being run, including -the configuration (:class:`MainConfig `) of the experiment, the status of the -experiment (:class:`ExperimentStatus `), the experiment ID, and the experiment name. -When the experiment is concluded, this object will also contain the results of the -experiment, encoded in the fields ``final_segments_url`` (URL to the tracked segments path) -and ``tracks`` (JSON of napari complaint tracking format). - -**IMPORTANT:** The server must have access to the data shared by the client, for example, through the web, or a shared file system. -Because the server does not store the input data being processed, only the results of the experiment. - -Endpoints ---------- - -In the following sections, we will describe the available endpoints and the expected -payloads for each endpoint. - -Meta endpoint -^^^^^^^^^^^^^ - -To avoid keeping track of each endpoint, there is a single endpoint that returns the available -endpoints for the Ultrack server. This also allows for the Ultrack server to be more dynamic, -as the available endpoints can be changed without changing the client. This endpoint is -described below. - -.. describe:: GET /config/available - - This endpoint returns all the available endpoints for the Ultrack server. - The response is a JSON object with the following structure: - - .. code-block:: JSON - - { - "id_endpoint": { - "link": "/url/to/endpoint", - "human_name": "The title of the endpoint", - "config": { - "experiment": { - "name": "Experiment Name", - "config": "MainConfig()" - }, - "set_of_kwargs_1": {}, - "set_of_kwargs_2": {}, - "...", - "set_of_kwargs_n": {} - } - }, - "..." - } - - As you can see, the response is a JSON object with the keys being the endpoint ID - and the values being a JSON object with the keys `link`, `human_name`, and `config`. - The `link` key is the URL to the endpoint, the `human_name` key is the title of the endpoint, - and the `config` key is the expected input payload for the endpoint. - - The `config` key comprises the initial configuration of the experiment - (an instance of :class:`Experiment `), and a - possible set of keyword arguments that are expected by the endpoint. Those keyword - arguments are dependent on the endpoint and are described in the following sections. - - The `experiment` instance is initialized with the default configuration of the - :class:`MainConfig ` class. This configuration can be - changed by the client and sent to update the server. - -Experiment endpoints -^^^^^^^^^^^^^^^^^^^^ - -The experiment endpoints are the main endpoints of the Ultrack server. -They allow the client to run the experiments and get their respective results. - -.. _segment_auto_detect: -.. 
describe:: WEBSOCKET /segment/auto_detect - - This endpoint is a websocket endpoint that allows you to send an image (referenced - as ``image_channel_or_path``) to the server and get the segmentation of the image. - - This endpoint wraps the :func:`ultrack.imgproc.detect_foreground` function and the - :func:`ultrack.imgproc.robust_invert` function, which estimates - the foreground of the image and its contours by image processing techniques. - For that reason, one can override the default parameters of those functions by sending - the ``detect_foreground_kwargs`` and ``robust_invert_kwargs`` as keyword arguments. - Those keyword arguments will be passed to their respective function. - - This endpoint requires a JSON payload with the following structure: - - .. code-block:: JSON - - { - "experiment": { - "name": "Experiment Name", - "config": "..." - "image_channel_or_path": "/path/to/image", - }, - "detect_foreground_kwargs": {}, - "robust_invert_kwargs": {}, - } - - and the reserver repeatedly returns the :class:`Experiment ` - JSON payload. For example: - - .. code-block:: JSON - - { - "id": 1, - "name": "Experiment Name", - "status": "segmenting", - "config": { - "..." - } - "start_time": "2021-01-01T00:00:00", - "end_time": "", - "std_log": "Segmenting frame 1...", - "err_log": "", - "data_url": "", - "image_channel_or_path": "/path/to/image", - "edges_channel_or_path": "", - "detection_channel_or_path": "", - "segmentation_channel_or_path": "", - "labels_channel_or_path": "", - "final_segments_url": "", - "tracks": "" - } - - Alternatively, if the image is an OME-ZARR file, the input data could be - a specific channel. In this case, the input data could be referenced as: - - .. code-block:: JSON - - { - "experiment": { - "name": "Experiment Name", - "config": "..." - "data_url": "/path/to/image.ome.zarr", - "image_channel_or_path": "image_channel", - }, - "detect_foreground_kwargs": {}, - "robust_invert_kwargs": {}, - } - -All the other endpoints are similar to the :ref:`/segment/auto_detect ` endpoint, but they -are more specific to the segmentation labels of the image. The endpoints are described below. - -.. _segment_manual: -.. describe:: WEBSOCKET /segment/manual - - This endpoint is similar to the :ref:`/segment/auto_detect ` endpoint, but it allows the - client to manually provide cells' foreground mask and their multilevel contours. - This endpoint requires the following JSON payload: - - .. code-block:: JSON - - { - "experiment": { - "name": "Experiment Name", - "config": "..." - "detection_channel_or_path": "/path/to/detection", - "edges_channel_or_path": "/path/to/contours", - }, - } - - Alternatively, if the image is an OME-ZARR file, the input data could be - a specific channel. In this case, the input data could be referenced as: - - .. code-block:: JSON - - { - "experiment": { - "name": "Experiment Name", - "config": "..." - "data_url": "/path/to/image.ome.zarr", - "detection_channel_or_path": "detection_channel", - "edges_channel_or_path": "contours_channel", - }, - } - - For both cases, the server will send the :class:`Experiment ` - JSON payload. For example: - - .. code-block:: JSON - - { - "id": 1, - "name": "Experiment Name", - "status": "segmenting", - "config": { - "..." 
- } - "start_time": "2021-01-01T00:00:00", - "end_time": "", - "std_log": "Linking cells...", - "err_log": "", - "data_url": "", - "image_channel_or_path": "", - "edges_channel_or_path": "/path/to/contours", - "detection_channel_or_path": "/path/to/detection", - "segmentation_channel_or_path": "", - "labels_channel_or_path": "", - "final_segments_url": "", - "tracks": "" - } - -Last but not least, the following endpoint could be used in a situation where the client already has -the instance segmentation of the cells, for example, from Cellpose or StarDist. - -.. _segment_labels: -.. describe:: WEBSOCKET /segment/labels - - This endpoint is similar to the :ref:`/segment/auto_detect ` endpoint, but it allows the - client to provide pre-computed instance segmentation of the cells. - This endpoint wraps the :meth:`ultrack.utils.labels_to_contours` function, which computes the foreground and contours - of the cells from the instance segmentation. - - This endpoint requires the following JSON payload: - - .. code-block:: JSON - - { - "experiment": { - "name": "Experiment Name", - "config": "..." - "labels_channel_or_path": "/path/to/labels", - }, - "labels_to_edges_kwargs": {}, - } - - Alternatively, if the image is an OME-ZARR file, the input data could be - a specific channel. In this case, the input data could be referenced as: - - .. code-block:: JSON - - { - "experiment": { - "name": "Experiment Name", - "config": "..." - "data_url": "/path/to/image.ome.zarr", - "labels_channel_or_path": "labels_channel", - }, - } - - For both cases, the server will send the :class:`Experiment ` - JSON payload. For example: - - .. code-block:: JSON - - { - "id": 1, - "name": "Experiment Name", - "status": "segmenting", - "config": { - "..." - } - "start_time": "2021-01-01T00:00:00", - "end_time": "", - "std_log": "Linking cells...", - "err_log": "", - "data_url": "", - "image_channel_or_path": "", - "edges_channel_or_path": "", - "detection_channel_or_path": "", - "segmentation_channel_or_path": "", - "labels_channel_or_path": "/path/to/labels", - "final_segments_url": "", - "tracks": "" - } - -Data export endpoints -^^^^^^^^^^^^^^^^^^^^^ - -.. describe:: GET /experiment/{experiment_id}/trackmate - - This endpoint allows the client to download the results of the experiment in the TrackMate - XML format. The client must provide the ``experiment_id`` in the URL. This id is obtained from - the :class:`Experiment ` instance that was executed. - The server will return an XML encoded within the response. - -Database Schema ---------------- - -All the data that is being processed by the Ultrack server is stored in a database. This -database is a SQLite database that is created when the server is started. The database -is used to store the results of the experiments and the configuration of the experiments. - -The database schema is the same as the one used in Ultrack, but with an additional table -to store the configuration and the status of the experiments. The schema is described below. - -.. autopydantic_model:: ultrack.api.database.Experiment - :members: - -.. autoclass:: ultrack.api.database.ExperimentStatus diff --git a/examples/README.rst b/examples/README.rst deleted file mode 100644 index 3170303..0000000 --- a/examples/README.rst +++ /dev/null @@ -1,34 +0,0 @@ -Ultrack's Usage Examples -======================== - -Here we provide some examples of how to use Ultrack for cell tracking. 
- -Some examples are provided as Jupyter notebooks with additional documentation, but we do not recommend using Jupyter notebooks for your day-to-day analysis. - -Other examples as Python scripts can be found in `here `_. - -Additional packages might be required. Therefore, conda environment files are provided, which can be installed using: - -.. code-block:: bash - - conda env create -f - conda activate - pip install git+https://github.com/royerlab/ultrack - -The existing examples are: - -- `multi_color_ensemble <./multi_color_ensemble>`_ : Multi-colored cytoplasm cell tracking using Cellpose and Watershed segmentation ensemble. Data provided by `The Lammerding Lab `_. -- `flow_field_3d <./flow_field_3d>`_ : Tracking demo on a cartographic projection of Tribolium Castaneum embryo from the `cell-tracking challenge `_, using a flow field estimation to assist tracking of motile cells. -- `stardist_2d <./stardist_2d>`_ : Tracking demo on HeLa GPF nuclei from the `cell-tracking challenge `_ using Stardist 2D fluorescence images pre-trained model. -- `zebrahub <./zebrahub/>`_ : Tracking demo on zebrafish tail data from `zebrahub `_ acquired with `DaXi `_ using Ultrack's image processing helper functions. -- `neuromast_plantseg <./neuromast_plantseg/>`_ : Tracking demo membrane-labeled zebrafish neuromast from `Jacobo Group of CZ Biohub `_ using `PlantSeg's `_ membrane detection model. -- `micro_sam <./micro_sam/>`_ : Tracking demo with `MicroSAM `_ instance segmentation package. - -Development Notes -^^^^^^^^^^^^^^^^^ - -To run all the examples and update the notebooks in headless mode, run: - -.. code-block:: bash - - bash refresh_examples.sh diff --git a/ultrack/core/_test/test_interactive.py b/ultrack/core/_test/test_interactive.py deleted file mode 100644 index bfa86a8..0000000 --- a/ultrack/core/_test/test_interactive.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Tuple - -import numpy as np -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import Session - -from ultrack import MainConfig, add_new_node -from ultrack.core.database import LinkDB, NodeDB, OverlapDB - - -def _get_table_sizes(session: Session) -> Tuple[int, int, int]: - return ( - session.query(NodeDB).count(), - session.query(LinkDB).count(), - session.query(OverlapDB).count(), - ) - - -@pytest.mark.parametrize( - "config_content", - [ - { - "data.database": "sqlite", - "segmentation.n_workers": 4, - "linking.n_workers": 4, - "linking.max_distance": 500, # too big and ignored - }, - ], - indirect=True, -) -def test_clear_solution( - linked_database_mock_data: MainConfig, -) -> None: - - mask = np.ones((7, 12, 12), dtype=bool) - bbox = np.array([15, 24, 24, 22, 36, 36], dtype=int) - - engine = create_engine(linked_database_mock_data.data_config.database_path) - with Session(engine) as session: - n_nodes, n_links, n_overlaps = _get_table_sizes(session) - - add_new_node( - linked_database_mock_data, - 0, - mask, - bbox, - ) - - new_n_nodes, new_n_links, new_n_overlaps = _get_table_sizes(session) - - assert new_n_nodes == n_nodes + 1 - assert new_n_overlaps > n_overlaps - # could smaller than max neighbors because of radius - assert ( - new_n_links == n_links + linked_database_mock_data.linking_config.max_neighbors - ) diff --git a/ultrack/core/export/_test/test_exporter.py b/ultrack/core/export/_test/test_exporter.py deleted file mode 100644 index 8595b90..0000000 --- a/ultrack/core/export/_test/test_exporter.py +++ /dev/null @@ -1,42 +0,0 @@ -from pathlib import Path - -from ultrack import 
MainConfig, export_tracks_by_extension - - -def test_exporter(tracked_database_mock_data: MainConfig, tmp_path: Path) -> None: - file_ext_list = [".xml", ".csv", ".zarr", ".dot", ".json"] - last_modified_time = {} - for file_ext in file_ext_list: - tmp_file = tmp_path / f"tracks{file_ext}" - export_tracks_by_extension(tracked_database_mock_data, tmp_file) - - # assert file exists - assert (tmp_path / f"tracks{file_ext}").exists() - # assert file size is not zero - assert (tmp_path / f"tracks{file_ext}").stat().st_size > 0 - - # store last modified time - last_modified_time[str(tmp_file)] = tmp_file.stat().st_mtime - - # loop again testing overwrite=False - for file_ext in file_ext_list: - tmp_file = tmp_path / f"tracks{file_ext}" - try: - export_tracks_by_extension( - tracked_database_mock_data, tmp_file, overwrite=False - ) - assert False, "FileExistsError should be raised" - except FileExistsError: - pass - - # loop again testing overwrite=True - for file_ext in file_ext_list: - tmp_file = tmp_path / f"tracks{file_ext}" - export_tracks_by_extension(tracked_database_mock_data, tmp_file, overwrite=True) - - # assert file exists - assert (tmp_path / f"tracks{file_ext}").exists() - # assert file size is not zero - assert (tmp_path / f"tracks{file_ext}").stat().st_size > 0 - - assert last_modified_time[str(tmp_file)] != tmp_file.stat().st_mtime diff --git a/ultrack/core/export/exporter.py b/ultrack/core/export/exporter.py deleted file mode 100644 index 9583251..0000000 --- a/ultrack/core/export/exporter.py +++ /dev/null @@ -1,74 +0,0 @@ -import json -from pathlib import Path -from typing import Union - -import networkx as nx - -from ultrack.config import MainConfig -from ultrack.core.export import ( - to_networkx, - to_trackmate, - to_tracks_layer, - tracks_to_zarr, -) - - -def export_tracks_by_extension( - config: MainConfig, filename: Union[str, Path], overwrite: bool = False -) -> None: - """ - Export tracks to a file given the file extension. - - Supported file extensions are .xml, .csv, .zarr, .dot, and .json. - - `.xml` exports to a TrackMate compatible XML file. - - `.csv` exports to a CSV file. - - `.zarr` exports the tracks to dense segments in a `zarr` array format. - - `.dot` exports to a Graphviz DOT file. - - `.json` exports to a networkx JSON file. - - Parameters - ---------- - filename : str or Path - The name of the file to save the tracks to. - config : MainConfig - The configuration object. - overwrite : bool, optional - Whether to overwrite the file if it already exists, by default False. - - See Also - -------- - to_trackmate : - Export tracks to a TrackMate compatible XML file. - to_tracks_layer : - Export tracks to a CSV file. - tracks_to_zarr : - Export tracks to a `zarr` array. - to_networkx : - Export tracks to a networkx graph. - """ - if Path(filename).exists() and not overwrite: - raise FileExistsError( - f"File {filename} already exists. 
Set `overwrite=True` to overwrite the file" - ) - - file_ext = Path(filename).suffix - if file_ext.lower() == ".xml": - to_trackmate(config, filename, overwrite=True) - elif file_ext.lower() == ".csv": - df, _ = to_tracks_layer(config, include_parents=True) - df.to_csv(filename, index=False) - elif file_ext.lower() == ".zarr": - df, _ = to_tracks_layer(config) - tracks_to_zarr(config, df, filename, overwrite=True) - elif file_ext.lower() == ".dot": - G = to_networkx(config) - nx.drawing.nx_pydot.write_dot(G, filename) - elif file_ext.lower() == ".json": - G = to_networkx(config) - json_data = nx.node_link_data(G) - with open(filename, "w") as f: - json.dump(json_data, f) - else: - raise ValueError( - f"Unknown file extension: {file_ext}. Supported extensions are .xml, .csv, .zarr, .dot, and .json." - ) diff --git a/ultrack/core/interactive.py b/ultrack/core/interactive.py deleted file mode 100644 index 18a8ad8..0000000 --- a/ultrack/core/interactive.py +++ /dev/null @@ -1,316 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy as np -from numpy.typing import ArrayLike -from sqlalchemy import create_engine, func -from sqlalchemy.orm import Session - -from ultrack.config.config import LinkingConfig, MainConfig -from ultrack.core.database import LinkDB, NodeDB, OverlapDB -from ultrack.core.segmentation.node import Node - - -def _nearest_neighbors( - data_arr: ArrayLike, - node: Node, - n_neighbors: int, - max_distance: Optional[float], - scale: Optional[ArrayLike], -) -> np.ndarray: - """ - Returns the indices of the `n_neighbors` nearest neighbors to `node` in `data_arr`. - - Parameters - ---------- - data_arr : ArrayLike - Array of (id, z, y, x) coordinates of the nodes (or centroids). - node : Node - Node to find neighbors for. - n_neighbors : int - Number of neighbors to be considered. - max_distance : float - Maximum distance to be considered. - scale : Optional[ArrayLike], optional - Scaling factor for the distance, by default None. - - Returns - ------- - np.ndarray - Indices of the nearest neighbors. - """ - if scale is None: - scale = np.ones(len(node.centroid)) - - differences = data_arr[:, -len(node.centroid) :] - node.centroid - differences *= scale - sqdist = np.square(differences).sum(axis=1) - - if max_distance is not None: - valid = sqdist <= (max_distance * max_distance) - data_arr = data_arr[valid] - sqdist = sqdist[valid] - - indices = np.argsort(sqdist) - - return data_arr[indices[:n_neighbors], 0] - - -def _find_links( - session: Session, - node: Node, - adj_time: int, - scale: Optional[ArrayLike], - link_config: LinkingConfig, -) -> List[Tuple[NodeDB, float]]: - """ - Finds links for a given node and time. - - Parameters - ---------- - session : Session - SQLAlchemy session. - node : Node - Node to search for neighbors. - adj_time : int - Adjacent time point. - scale : Optional[ArrayLike], optional - Scaling factor for the distance, by default None. - link_config : LinkingConfig - Linking configuration parameters. - - Returns - ------- - List[NodeDB, float] - List of nodes and their weights. 
- """ - data = np.asarray( - session.query( - NodeDB.id, - NodeDB.z, - NodeDB.y, - NodeDB.x, - ) - .where(NodeDB.t == adj_time) - .all() - ) - if len(data) == 0: - return [] - - ind = _nearest_neighbors( - data, - node, - 2 * link_config.max_neighbors, - link_config.max_distance, - scale=scale, - ) - - neigh_nodes = session.query(NodeDB.pickle).where(NodeDB.id.in_(ind)).all() - - if scale is None: - scale = np.ones(len(node.centroid)) - - neigh_nodes_with_dist = [] - for (n,) in neigh_nodes: - dist = np.linalg.norm((n.centroid - node.centroid) * scale) - w = node.IoU(n) - link_config.distance_weight * dist - neigh_nodes_with_dist.append((n, w)) - - neigh_nodes_with_dist.sort(key=lambda x: x[1], reverse=True) - - return neigh_nodes_with_dist[: link_config.max_neighbors] - - -def _add_overlaps( - session: Session, - node: Node, - n_neighbors: int = 10, - scale: Optional[ArrayLike] = None, -) -> None: - """ - Adds overlaps to the database. - - Parameters - ---------- - session : Session - SQLAlchemy session. - node : Node - Node to find overlaps with. - n_neighbors : int, optional - Number of neighbors to be considered, by default 10. - scale : Optional[ArrayLike], optional - Scaling factor for the distance, by default None. - """ - data = np.asarray( - session.query( - NodeDB.id, - NodeDB.z, - NodeDB.y, - NodeDB.x, - ) - .where(NodeDB.t == node.time, NodeDB.id != node.id) - .all() - ) - ind = _nearest_neighbors(data, node, n_neighbors, max_distance=None, scale=scale) - - overlaps = [] - - for (neigh_node,) in session.query(NodeDB.pickle).where(NodeDB.id.in_(ind)).all(): - if node.IoU(neigh_node) > 0.0: - overlaps.append( - OverlapDB( - node_id=node.id, - ancestor_id=neigh_node.id, - ) - ) - - session.add_all(overlaps) - - -def _add_links( - session: Session, - node: Node, - link_config: LinkingConfig, - scale: Optional[ArrayLike] = None, -) -> None: - """ - Adds T - 1 and T + 1 links to the database. - - NOTE: this is not taking node shifts into account. - - Parameters - ---------- - session : Session - SQLAlchemy session. - node : Node - Node to search for neighbors. - link_config : LinkingConfig - Linking configuration parameters. - scale : Optional[ArrayLike], optional - Scaling factor for the distance, by default None. - """ - links = [] - - before_links = _find_links( - session=session, - node=node, - adj_time=node.time - 1, - scale=scale, - link_config=link_config, - ) - for before_node, w in before_links: - links.append( - LinkDB( - source_id=before_node.id, - target_id=node.id, - weight=w, - ) - ) - - after_links = _find_links( - session=session, - node=node, - adj_time=node.time + 1, - scale=scale, - link_config=link_config, - ) - for after_node, w in after_links: - links.append( - LinkDB( - source_id=node.id, - target_id=after_node.id, - weight=w, - ) - ) - - session.add_all(links) - - -def add_new_node( - config: MainConfig, - time: int, - mask: ArrayLike, - bbox: Optional[ArrayLike] = None, - index: Optional[int] = None, - include_overlaps: bool = True, -) -> int: - """ - Adds a new node to the database. - - NOTE: this is not taking node shifts or image features (color) into account. - - Parameters - ---------- - config : MainConfig - Ultrack configuration parameters. - time : int - Time point of the node. - mask : ArrayLike - Binary mask of the node. - bbox : Optional[ArrayLike], optional - Bounding box of the node, (min_0, min_1, ..., max_0, max_1, ...). 
- When provided it assumes the mask is a crop of the original image, by default None - index : Optional[int], optional - Node index, otherwise it is automatically generated, and returned. - include_overlaps : bool, optional - Include overlaps in the database, by default True - When False it will allow oclusions between new node and existing nodes. - - Returns - ------- - int - New node index. - """ - - node = Node.from_mask( - time=time, - mask=mask, - bbox=bbox, - ) - if node.area == 0: - raise ValueError("Node area is zero. Something went wrong.") - - scale = config.data_config.metadata.get("scale") - - engine = create_engine(config.data_config.database_path) - with Session(engine) as session: - - # querying required data - if index is None: - node.id = ( - int(session.query(func.max(NodeDB.id)).where(NodeDB.t == time).scalar()) - + 1 - ) - else: - node.id = index - - # adding node - if len(node.centroid) == 2: - y, x = node.centroid - z = 0 - else: - z, y, x = node.centroid - - node_db_obj = NodeDB( - id=node.id, - t=node.time, - z=z, - y=y, - x=x, - area=node.area, - pickle=node, - ) - session.add(node_db_obj) - - if include_overlaps: - _add_overlaps(session=session, node=node, scale=scale) - - _add_links( - session=session, - node=node, - link_config=config.linking_config, - scale=scale, - ) - - session.commit() - - return node.id diff --git a/ultrack/imgproc/_test/test_register.py b/ultrack/imgproc/_test/test_register.py deleted file mode 100644 index ebfc742..0000000 --- a/ultrack/imgproc/_test/test_register.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Tuple - -import numpy as np -import pytest -import scipy.ndimage as ndi -import zarr - -from ultrack.imgproc import register_timelapse - - -@pytest.mark.parametrize( - "timelapse_mock_data", - [ - {"length": 3, "size": 32, "n_dim": 3}, - ], - indirect=True, -) -def test_register_timelapse( - timelapse_mock_data: Tuple[zarr.Array, zarr.Array, zarr.Array], -) -> None: - _, moved_edges, _ = timelapse_mock_data - - shift = 8 - - # adding a new to emulate a channel - moved_edges = moved_edges[...][:, None] - - for i in range(moved_edges.shape[0]): - moved_edges[i] = ndi.shift(moved_edges[i], (0, i * shift // 2, 0, 0), order=1) - - fixed_edges = register_timelapse(moved_edges, reference_channel=0, padding=shift) - - for i in range(moved_edges.shape[0] - 1): - # removing padding and out of fov regions - volume = fixed_edges[i, :, : -2 * shift] - next_vol = fixed_edges[i + 1, :, : -2 * shift] - - np.testing.assert_allclose(volume, next_vol) diff --git a/ultrack/imgproc/register.py b/ultrack/imgproc/register.py deleted file mode 100644 index 5a32bb1..0000000 --- a/ultrack/imgproc/register.py +++ /dev/null @@ -1,121 +0,0 @@ -import logging -from typing import Callable, Optional, Union - -import numpy as np -import zarr -from numpy.typing import ArrayLike -from tqdm import tqdm -from zarr.storage import Store - -from ultrack.utils.array import create_zarr -from ultrack.utils.cuda import import_module, to_cpu - -LOG = logging.getLogger(__name__) - - -def register_timelapse( - timelapse: ArrayLike, - *, - store_or_path: Union[Store, str, None] = None, - overwrite: bool = False, - to_device: Callable[[ArrayLike], ArrayLike] = lambda x: x, - reference_channel: Optional[int] = None, - overlap_ratio: float = 0.25, - normalization: Optional[str] = None, - padding: Optional[int] = None, - **kwargs, -) -> zarr.Array: - """ - Register a timelapse sequence using phase cross correlation. 
- - Parameters - ---------- - timelapse : ArrayLike - Input timelapse sequence, T(CZ)YX array C and Z are optional. - NOTE: when provided, C must be the second dimension after T. - store_or_path : Union[Store, str, None], optional - Zarr storage or a file path, to save the output, useful for larger than memory datasets. - By default it loads the data into memory. - overwrite : bool, optional - Overwrite output file if it already exists, when using directory store or a path. - to_device : Callable[[ArrayLike], ArrayLike], optional - Function to move the input data to cuda device, by default lambda x: x (CPU). - reference_channel : Optional[int], optional - Reference channel for registration, by default None. - It must be provided when it contains a channel dimension. - overlap_ratio : float, optional - Overlap ratio for phase cross correlation, by default 0.25. - normalization : Optional[str], optional - Normalization method for phase cross correlation, by default None. - padding : Optional[int], optional - Padding for registration, by default None. - **kwargs - Additional arguments for phase cross correlation. See `skimage.registration phase_cross_correlation - `_. # noqa: E501 - - Returns - ------- - zarr.Array - Registered timelapse sequence. - """ - shape = list(timelapse.shape) - - if padding is not None: - offset = 1 if reference_channel is None else 2 - pads = [(0, 0)] * (offset - 1) - - for i in range(offset, len(shape)): - shape[i] += 2 * padding - pads.append((padding, padding)) - - def maybe_pad(x: np.ndarray) -> np.ndarray: - x = np.asarray(x) - x = to_device(x) - return np.pad(x, pads, mode="constant") - - else: - - def maybe_pad(x: np.ndarray) -> np.ndarray: - x = np.asarray(x) - return to_device(x) - - out_arr = create_zarr( - tuple(shape), - dtype=timelapse.dtype, - store_or_path=store_or_path, - overwrite=overwrite, - ) - - if reference_channel is None: - channel = ... 
- else: - channel = reference_channel - - prev_frame = maybe_pad(timelapse[0]) - out_arr[0] = to_cpu(prev_frame) - - ndi = import_module("scipy", "ndimage", arr=prev_frame) - skreg = import_module("skimage", "registration", arr=prev_frame) - - for t in tqdm(range(timelapse.shape[0] - 1), "Registration"): - next_frame = maybe_pad(timelapse[t + 1]) - shift, _, _ = skreg.phase_cross_correlation( - prev_frame[channel], - next_frame[channel], - overlap_ratio=overlap_ratio, - normalization=normalization, - **kwargs, - ) - - LOG.info("Shift at {t}: {shift}", t=t, shift=shift) - print(f"Shift at {t}: {shift}") - - if reference_channel is not None: - shift = (0, *shift) - - next_frame = ndi.shift(next_frame, shift, order=1) - out_arr[t + 1] = to_cpu(next_frame) - - prev_frame = next_frame - - return out_arr diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index 38c106a..e8c2358 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -14,7 +14,7 @@ from zarr.storage import Store from ultrack.core.database import NodeDB -from ultrack.config import MainConfig +from ultrack import MainConfig LOG = logging.getLogger(__name__) @@ -219,13 +219,15 @@ def __init__( self.config = config self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype - self.t_max = self.shape[0] + self.Tmax = self.shape[0] self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) + self.export_func = self.array.__setitem__ self.database_path = config.data_config.database_path self.minmax = self.find_min_max_volume_entire_dataset() self.volume = self.minmax.mean().astype(int) + self.initial_volume = self.volume.copy() def __getitem__(self, indexing: tuple, @@ -250,29 +252,32 @@ def __getitem__(self, try: time = time.item() # convert from numpy.int to int - except AttributeError: + except: time = time - self.fill_array( + self.query_volume( time=time, + buffer=self.array, ) return self.array[volume_slicing] - def fill_array( + def query_volume( self, time: int, + buffer: np.array, ) -> None: """Paint all segments of specific time point which volume is bigger than self.volume Parameters ---------- time : int time point to paint the segments + buffer : np.array + np.zeros to be filled with segments """ - - self.array.fill(0) - + engine = sqla.create_engine(self.database_path) + buffer.fill(0) with Session(engine) as session: query = list( @@ -303,36 +308,12 @@ def fill_array( for idx in idx_to_plot: query[idx][1].paint_buffer( - self.array, value=label_list[idx], include_time=False + buffer, value=label_list[idx], include_time=False ) - def get_tp_num_pixels( - self, - timeStart:int, - timeStop:int, - ) -> list: - """Gets a list of number of pixels of all segments range of time points (timeStart to timeStop) - - Parameters - ---------- - timeStart : int - timeStop : int - - Returns - ------- - np.array : np.array - array with two elements: [min_volume, max_volume] - """ - engine = sqla.create_engine(self.database_path) - num_pix_list = [] - with Session(engine) as session: - query = list(session.query(NodeDB.area).where(NodeDB.t >= timeStart).where(NodeDB.t <= timeStop)) - for num_pix in query: - num_pix_list.append(int(np.array(num_pix))) - return num_pix_list - + return query - def get_tp_num_pixels_minmax( + def find_minmax_volumes_1_timepoint( self, time: int, ) -> np.ndarray: @@ -344,13 +325,21 @@ def get_tp_num_pixels_minmax( Returns ------- - num_pix_list : list - list with all num_pixels from t=0 to t=timeLimit + np.array : np.array + array with 
two elements: [min_volume, max_volume] """ - num_pixels_list = self.get_tp_num_pixels(time,time) - return (min(num_pixels_list),max(num_pixels_list)) - - + engine = sqla.create_engine(self.database_path) + min_vol = np.inf + max_vol = 0 + with Session(engine) as session: + query = list(session.query(NodeDB.pickle).where(NodeDB.t == time)) + for node in query: + vol = node[0].area + if vol < min_vol: + min_vol = vol + if vol > max_vol: + max_vol = vol + return np.array([min_vol, max_vol]).astype(int) def find_min_max_volume_entire_dataset(self): """Find minimum and maximum segment volume for ALL time point @@ -363,10 +352,35 @@ def find_min_max_volume_entire_dataset(self): min_vol = np.inf max_vol = 0 for t in range(self.t_max): #range(self.shape[0]): - minmax = self.get_tp_num_pixels_minmax(t) + minmax = self.find_minmax_volumes_1_timepoint(t) if minmax[0] < min_vol: min_vol = minmax[0] if minmax[1] > max_vol: max_vol = minmax[1] - return np.array([min_vol, max_vol], dtype=int) \ No newline at end of file + return np.array([min_vol, max_vol], dtype=int) + + def get_volume_list( + self, + timeLimit: int, + ) -> list: + """Creates a list of the volumes of all segments in the database (up untill t=timeLimit) + + Parameters + ---------- + timeLimit : int + + Returns + ------- + vol_list : list + list with all volumes from t=0 to t=timeLimit + """ + engine = sqla.create_engine(self.database_path) + vol_list = [] + with Session(engine) as session: + query = list(session.query(NodeDB.pickle).where(NodeDB.t <= timeLimit)) + for node in query: + vol = node[0].area + vol_list.append(vol) + + return vol_list diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py index 4b5ee34..2f2ddf4 100644 --- a/ultrack/widgets/hierarchy_viz_widget.py +++ b/ultrack/widgets/hierarchy_viz_widget.py @@ -42,12 +42,12 @@ def __init__(self, self.config = config self.ultrack_array = UltrackArray(self.config) + self._viewer.add_labels(self.ultrack_array, name='hierarchy') self.mapping = self._create_mapping() self._area_threshold_w = FloatSlider(label="Area", min=0, max=1, readout=False) self._area_threshold_w.value = 0.5 - self.ultrack_array.volume = self.mapping(0.5) self._area_threshold_w.changed.connect(self._slider_update) self.slider_label = Label(label=str(int(self.mapping(self._area_threshold_w.value)))) @@ -56,11 +56,6 @@ def __init__(self, self.append(self._area_threshold_w) self.append(self.slider_label) - self._viewer.add_labels(self.ultrack_array, name='hierarchy') - #THERE SHOULD BE CHECK HERE IF THERE EXISTS A LAYER WITH THE NAME 'HIERARCHY' - self._viewer.layers['hierarchy'].refresh() - - def _on_config_changed(self) -> None: self._ndim = len(self._shape) @@ -75,16 +70,16 @@ def _slider_update(self, value: float) -> None: def _create_mapping(self): """ - Creates a pseudo-linear mapping from U[0,1] to full range of number of pixels: - num_pixels = mapping([0,1]) + Creates a pseudo-linear mapping from U[0,1] to full range of segment volumes: + volume = mapping([0,1]) """ - num_pixels_list = self.ultrack_array.get_tp_num_pixels(timeStart=0,timeStop=5) - num_pixels_list.append(self.ultrack_array.minmax[0]) - num_pixels_list.append(self.ultrack_array.minmax[1]) - num_pixels_list.sort() + volume_list = self.ultrack_array.get_volume_list(timeLimit=5) + volume_list.append(self.ultrack_array.minmax[0]) + volume_list.append(self.ultrack_array.minmax[1]) + volume_list.sort() - x_vec = np.linspace(0,1,len(num_pixels_list)) - y_vec = np.array(num_pixels_list) + x_vec = 
np.linspace(0,1,len(volume_list)) + y_vec = np.array(volume_list) mapping = interpolate.interp1d(x_vec,y_vec) return mapping diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py index 36bb39c..45fa720 100644 --- a/ultrack/widgets/ultrackwidget/ultrackwidget.py +++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py @@ -575,7 +575,7 @@ def _cancel(self): @staticmethod def find_ultrack_widget(viewer: napari.Viewer) -> Optional["UltrackWidget"]: - """Find and returns Ultrack widget. If widget not found, returns None""" + for _, w in viewer.window._dock_widgets.items(): if isinstance(w.widget(), UltrackWidget): return w.widget() From 1208da8a1a3125822b62760e631cdb99c17b0103 Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Thu, 1 Aug 2024 16:33:02 -0700 Subject: [PATCH 19/45] redid the changes lost in merge --- ultrack/utils/array.py | 89 ++++++++----------- ultrack/widgets/hierarchy_viz_widget.py | 22 +++-- .../widgets/ultrackwidget/ultrackwidget.py | 2 +- 3 files changed, 49 insertions(+), 64 deletions(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index e8c2358..c03f38f 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -14,7 +14,7 @@ from zarr.storage import Store from ultrack.core.database import NodeDB -from ultrack import MainConfig +from ultrack.config import MainConfig LOG = logging.getLogger(__name__) @@ -219,15 +219,13 @@ def __init__( self.config = config self.shape = tuple(config.data_config.metadata["shape"]) # (t,(z),y,x) self.dtype = dtype - self.Tmax = self.shape[0] + self.t_max = self.shape[0] self.ndim = len(self.shape) self.array = np.zeros(self.shape[1:], dtype=self.dtype) - self.export_func = self.array.__setitem__ self.database_path = config.data_config.database_path self.minmax = self.find_min_max_volume_entire_dataset() self.volume = self.minmax.mean().astype(int) - self.initial_volume = self.volume.copy() def __getitem__(self, indexing: tuple, @@ -252,32 +250,28 @@ def __getitem__(self, try: time = time.item() # convert from numpy.int to int - except: + except AttributeError: time = time - self.query_volume( + self.fill_array( time=time, - buffer=self.array, ) return self.array[volume_slicing] - def query_volume( + def fill_array( self, time: int, - buffer: np.array, ) -> None: """Paint all segments of specific time point which volume is bigger than self.volume Parameters ---------- time : int time point to paint the segments - buffer : np.array - np.zeros to be filled with segments """ engine = sqla.create_engine(self.database_path) - buffer.fill(0) + self.array.fill(0) with Session(engine) as session: query = list( @@ -308,12 +302,33 @@ def query_volume( for idx in idx_to_plot: query[idx][1].paint_buffer( - buffer, value=label_list[idx], include_time=False + self.array, value=label_list[idx], include_time=False ) - return query + def get_tp_num_pixels( + self, + timeStart:int, + timeStop:int, + ) -> list: + """Gets a list of number of pixels of all segments range of time points (timeStart to timeStop) + Parameters + ---------- + timeStart : int + timeStop : int + Returns + ------- + num_pix_list : list + list with all num_pixels for timeStart to timeStop + """ + engine = sqla.create_engine(self.database_path) + num_pix_list = [] + with Session(engine) as session: + query = list(session.query(NodeDB.area).where(NodeDB.t >= timeStart).where(NodeDB.t <= timeStop)) + for num_pix in query: + num_pix_list.append(int(np.array(num_pix))) + return num_pix_list - def 
find_minmax_volumes_1_timepoint( + def get_tp_num_pixels_minmax( self, time: int, ) -> np.ndarray: @@ -325,21 +340,12 @@ def find_minmax_volumes_1_timepoint( Returns ------- - np.array : np.array + num_pix_list : list array with two elements: [min_volume, max_volume] """ - engine = sqla.create_engine(self.database_path) - min_vol = np.inf - max_vol = 0 - with Session(engine) as session: - query = list(session.query(NodeDB.pickle).where(NodeDB.t == time)) - for node in query: - vol = node[0].area - if vol < min_vol: - min_vol = vol - if vol > max_vol: - max_vol = vol - return np.array([min_vol, max_vol]).astype(int) + num_pix_list = self.get_tp_num_pixels(time,time) + return (min(num_pix_list), max(num_pix_list)) + def find_min_max_volume_entire_dataset(self): """Find minimum and maximum segment volume for ALL time point @@ -352,35 +358,10 @@ def find_min_max_volume_entire_dataset(self): min_vol = np.inf max_vol = 0 for t in range(self.t_max): #range(self.shape[0]): - minmax = self.find_minmax_volumes_1_timepoint(t) + minmax = self.get_tp_num_pixels_minmax(t) if minmax[0] < min_vol: min_vol = minmax[0] if minmax[1] > max_vol: max_vol = minmax[1] return np.array([min_vol, max_vol], dtype=int) - - def get_volume_list( - self, - timeLimit: int, - ) -> list: - """Creates a list of the volumes of all segments in the database (up untill t=timeLimit) - - Parameters - ---------- - timeLimit : int - - Returns - ------- - vol_list : list - list with all volumes from t=0 to t=timeLimit - """ - engine = sqla.create_engine(self.database_path) - vol_list = [] - with Session(engine) as session: - query = list(session.query(NodeDB.pickle).where(NodeDB.t <= timeLimit)) - for node in query: - vol = node[0].area - vol_list.append(vol) - - return vol_list diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py index 2f2ddf4..036f962 100644 --- a/ultrack/widgets/hierarchy_viz_widget.py +++ b/ultrack/widgets/hierarchy_viz_widget.py @@ -42,12 +42,12 @@ def __init__(self, self.config = config self.ultrack_array = UltrackArray(self.config) - self._viewer.add_labels(self.ultrack_array, name='hierarchy') self.mapping = self._create_mapping() self._area_threshold_w = FloatSlider(label="Area", min=0, max=1, readout=False) self._area_threshold_w.value = 0.5 + self.ultrack_array.volume = self.mapping(0.5) self._area_threshold_w.changed.connect(self._slider_update) self.slider_label = Label(label=str(int(self.mapping(self._area_threshold_w.value)))) @@ -56,6 +56,10 @@ def __init__(self, self.append(self._area_threshold_w) self.append(self.slider_label) + #THERE SHOULD BE CHECK HERE IF THERE EXISTS A LAYER WITH THE NAME 'HIERARCHY' + self._viewer.add_labels(self.ultrack_array, name='hierarchy') + self._viewer.layers['hierarchy'].refresh() + def _on_config_changed(self) -> None: self._ndim = len(self._shape) @@ -70,16 +74,16 @@ def _slider_update(self, value: float) -> None: def _create_mapping(self): """ - Creates a pseudo-linear mapping from U[0,1] to full range of segment volumes: - volume = mapping([0,1]) + Creates a pseudo-linear mapping from U[0,1] to full range of number of pixels + num_pixels = mapping([0,1]) """ - volume_list = self.ultrack_array.get_volume_list(timeLimit=5) - volume_list.append(self.ultrack_array.minmax[0]) - volume_list.append(self.ultrack_array.minmax[1]) - volume_list.sort() + num_pixels_list = self.ultrack_array.get_tp_num_pixels(timeStart=5,timeStop=5) + num_pixels_list.append(self.ultrack_array.minmax[0]) + 
num_pixels_list.append(self.ultrack_array.minmax[1]) + num_pixels_list.sort() - x_vec = np.linspace(0,1,len(volume_list)) - y_vec = np.array(volume_list) + x_vec = np.linspace(0,1,len(num_pixels_list)) + y_vec = np.array(num_pixels_list) mapping = interpolate.interp1d(x_vec,y_vec) return mapping diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py index 45fa720..36bb39c 100644 --- a/ultrack/widgets/ultrackwidget/ultrackwidget.py +++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py @@ -575,7 +575,7 @@ def _cancel(self): @staticmethod def find_ultrack_widget(viewer: napari.Viewer) -> Optional["UltrackWidget"]: - + """Find and returns Ultrack widget. If widget not found, returns None""" for _, w in viewer.window._dock_widgets.items(): if isinstance(w.widget(), UltrackWidget): return w.widget() From 765556a714ed5518e0f0c0b2f3269af43d9cb95a Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 2 Aug 2024 09:07:17 -0700 Subject: [PATCH 20/45] added indexing test for ultrack-array --- ultrack/utils/_test/test_utils_array.py | 33 ++++++++++++++++++++++++- ultrack/utils/array.py | 21 ++++++++++------ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/ultrack/utils/_test/test_utils_array.py b/ultrack/utils/_test/test_utils_array.py index c70e7e1..600ae4d 100644 --- a/ultrack/utils/_test/test_utils_array.py +++ b/ultrack/utils/_test/test_utils_array.py @@ -1,8 +1,10 @@ import numpy as np import pytest +from typing import Tuple from ultrack.utils.array import array_apply - +from ultrack.config import MainConfig +from ultrack.utils.array import UltrackArray @pytest.mark.parametrize("axis", [0, 1]) def test_array_apply_parametrized(axis): @@ -17,3 +19,32 @@ def sample_func(arr_1, arr_2): array_apply(in_data, in_data, out_array=out_data, func=sample_func, axis=axis) other_axes_length = in_data.shape[1 - axis] assert np.array_equal(out_data, 2 * in_data + other_axes_length) + +@pytest.mark.parametrize( + "key,timelapse_mock_data", + [ + (1,{'n_dim':3}), + (1,{'n_dim':2}), + ((slice(None), 1),{'n_dim':3}), + ((slice(None), 1),{'n_dim':2}), + ((0, [1, 2]),{'n_dim':3}), + ((0, [1, 2]),{'n_dim':2}), + # ((-1, np.asarray([0, 3])),{'n_dim':3}), #does testing negative time make sense? + # ((-1, np.asarray([0, 3])),{'n_dim':2}), + ((slice(1), -2),{'n_dim':3}), + ((slice(1), -2),{'n_dim':2}), + ((np.asarray(0),),{'n_dim':3}), + ((np.asarray(0),),{'n_dim':2}), + ((0, 0, slice(32)),{'n_dim':3}), + ((0, 0, slice(32)),{'n_dim':2}), + ], + indirect=["timelapse_mock_data",], +) +def test_ultrack_array( + segmentation_database_mock_data: MainConfig, + tmp_path: str, + key: Tuple, + ): + ua = UltrackArray(segmentation_database_mock_data) + ua_numpy = ua[slice(None)] + np.testing.assert_equal(ua_numpy[key], ua[key]) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index c03f38f..17fc1fb 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -228,7 +228,7 @@ def __init__( self.volume = self.minmax.mean().astype(int) def __getitem__(self, - indexing: tuple, + indexing: Union[Tuple[Union[int, slice]], int, slice], ) -> np.ndarray: """Indexing the ultrack-array @@ -241,17 +241,24 @@ def __getitem__(self, array : numpy array array with painted segments """ + # print('indexing in getitem:',indexing) if isinstance(indexing, tuple): time, volume_slicing = indexing[0], indexing[1:] - else: + else: #if only 1 (time) is provided time = indexing - volume_slicing = ... 
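        # With the slice handling added just below, an UltrackArray can be indexed
        # much like a read-only NumPy array backed by the database. A small sketch,
        # assuming `config` points at a populated database:
        #   ua = UltrackArray(config)
        #   ua[0]          # full labeled volume at t=0
        #   ua[:, 5]       # plane/slab 5 stacked over every time point
        #   ua[0, 0, :32]  # cropped view of t=0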
+ volume_slicing = tuple() - try: - time = time.item() # convert from numpy.int to int - except AttributeError: - time = time + if isinstance(time, slice): #if all time points are requested + return np.stack([ + self.__getitem__((t,) + volume_slicing) + for t in range(*time.indices(self.shape[0])) + ]) + else: + try: + time = time.item() # convert from numpy.int to int + except AttributeError: + time = time self.fill_array( time=time, From 590fad155df61c4ec9f1271ac349a113cb4b5b6d Mon Sep 17 00:00:00 2001 From: TeunHuijben Date: Fri, 2 Aug 2024 13:04:54 -0700 Subject: [PATCH 21/45] added test for hierarchy widget --- ultrack/utils/_test/test_utils_array.py | 1 - .../_test/test_hierarchy_viz_widget.py | 95 +++++++++++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 ultrack/widgets/_test/test_hierarchy_viz_widget.py diff --git a/ultrack/utils/_test/test_utils_array.py b/ultrack/utils/_test/test_utils_array.py index 600ae4d..76d77e3 100644 --- a/ultrack/utils/_test/test_utils_array.py +++ b/ultrack/utils/_test/test_utils_array.py @@ -42,7 +42,6 @@ def sample_func(arr_1, arr_2): ) def test_ultrack_array( segmentation_database_mock_data: MainConfig, - tmp_path: str, key: Tuple, ): ua = UltrackArray(segmentation_database_mock_data) diff --git a/ultrack/widgets/_test/test_hierarchy_viz_widget.py b/ultrack/widgets/_test/test_hierarchy_viz_widget.py new file mode 100644 index 0000000..889ff5c --- /dev/null +++ b/ultrack/widgets/_test/test_hierarchy_viz_widget.py @@ -0,0 +1,95 @@ +from typing import Callable, Tuple + +import napari +import numpy as np +import pytest +import zarr + +from ultrack.config import MainConfig +from ultrack.widgets.ultrackwidget import UltrackWidget +from ultrack.widgets.hierarchy_viz_widget import HierarchyVizWidget +from ultrack.widgets.ultrackwidget.workflows import WorkflowChoice +from ultrack.widgets.ultrackwidget.utils import UltrackInput + + + + + +def test_hierarchy_viz_widget( + make_napari_viewer: Callable[[],napari.Viewer], + segmentation_database_mock_data: MainConfig, + timelapse_mock_data: Tuple[zarr.Array, zarr.Array, zarr.Array], + request, + ): + + #################################################################################### + #OPTION 1: run widget using config + #################################################################################### + config = segmentation_database_mock_data + viewer = make_napari_viewer() + widget = HierarchyVizWidget(viewer,config) + viewer.window.add_dock_widget(widget) + + assert "hierarchy" in viewer.layers + + #test moving sliders: + widget._slider_update(0.75) + widget._slider_update(0.25) + + #test is shape of layer.data has same shape as the data shape reported in config: + assert tuple(config.data_config.metadata["shape"]) == viewer.layers['hierarchy'].data.shape #metadata["shape"] is a list, data.shape is a tuple + + + #################################################################################### + #OPTION 2: run widget by taking config from Ultrack-widget + #################################################################################### + #make napari viewer + viewer2 = make_napari_viewer() + + #get mock segmentation data + add to viewer + segments = timelapse_mock_data[2] + print('segments shape',segments.shape) + viewer2.add_labels(segments,name='segments') + + #open ultrack widget + widget_ultrack = UltrackWidget(viewer2) + viewer2.window.add_dock_widget(widget_ultrack) + + #setup ultrack widget for 'Labels' input + layers = viewer2.layers + workflow = 
WorkflowChoice.AUTO_FROM_LABELS + workflow_idx = widget_ultrack._cb_workflow.findData(workflow) + widget_ultrack._cb_workflow.setCurrentIndex(workflow_idx) + widget_ultrack._cb_workflow.currentIndexChanged.emit(workflow_idx) + # setting combobox choices manually, because they were not working automatically + widget_ultrack._cb_images[UltrackInput.LABELS].choices = layers + # # selecting layers + widget_ultrack._cb_images[UltrackInput.LABELS].value = layers['segments'] + + #load config + widget_ultrack._data_forms.load_config(config) + + + widget_hier = HierarchyVizWidget(viewer2) + viewer2.window.add_dock_widget(widget_hier) + + assert "hierarchy" in viewer.layers + + #test moving sliders: + widget._slider_update(0.75) + widget._slider_update(0.25) + + # test is shape of layer.data has same shape as the data shape reported in config: + assert tuple(config.data_config.metadata["shape"]) == viewer2.layers['hierarchy'].data.shape #metadata["shape"] is a list, data.shape in layer is a tuple + + #################################################################################### + #TO DO: + # - test other datatypes than labels (contours, detection, image, etc.), but shouldn't make a difference + # - how to test that the 'hierarchy' layer actually shows data? (apart from the layer having a .data property) + #################################################################################### + + + if request.config.getoption("--show-napari-viewer"): + napari.run() + + \ No newline at end of file From 8f3daa9d4140347a1e9b1a036b8b6b40a01ca0ab Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 16 Aug 2024 14:52:54 -0700 Subject: [PATCH 22/45] WIP basic autotuning implementation --- ultrack/config/dataconfig.py | 4 + ultrack/core/autotune.py | 286 +++++++++++++++++++++++++++++++++++ ultrack/core/interactive.py | 2 +- 3 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 ultrack/core/autotune.py diff --git a/ultrack/config/dataconfig.py b/ultrack/config/dataconfig.py index 62f07a9..6d838c4 100644 --- a/ultrack/config/dataconfig.py +++ b/ultrack/config/dataconfig.py @@ -13,6 +13,7 @@ class DatabaseChoices(Enum): sqlite = "sqlite" postgresql = "postgresql" + memory = "memory" class DataConfig(BaseModel): @@ -73,6 +74,9 @@ def database_path(self) -> str: if self.database == DatabaseChoices.sqlite.value: return f"sqlite:///{self.working_dir.absolute()}/data.db" + elif self.database == DatabaseChoices.memory.value: + return "sqlite+pysqlite:///:memory:" + elif self.database == DatabaseChoices.postgresql.value: return f"postgresql://{self.address}" diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py new file mode 100644 index 0000000..efe23ef --- /dev/null +++ b/ultrack/core/autotune.py @@ -0,0 +1,286 @@ +import logging +from typing import Callable, List, Literal, Optional, Tuple + +import mip +import mip.exceptions +import numpy as np +import pandas as pd +import sqlalchemy as sqla +from numpy.typing import ArrayLike +from skimage.measure import regionprops +from sqlalchemy.orm import Session +from toolz import curry + +from ultrack.config.config import MainConfig +from ultrack.core.database import LinkDB, NodeDB, OverlapDB, clear_all_data +from ultrack.core.interactive import add_new_node +from ultrack.core.linking.processing import link +from ultrack.core.segmentation.processing import _generate_id, segment +from ultrack.tracks.stats import estimate_drift +from ultrack.utils.multiprocessing import multiprocessing_apply + +LOG = logging.getLogger(__name__) + + +class 
SQLGTMatching: + def __init__( + self, + config: MainConfig, + solver: Literal["CBC", "GUROBI", ""] = "", + ) -> None: + # TODO + + self._data_config = config.data_config + + try: + self._model = mip.Model(sense=mip.MAXIMIZE, solver_name=solver) + except mip.exceptions.InterfacingError as e: + LOG.warning(e) + self._model = mip.Model(sense=mip.MAXIMIZE, solver_name="CBC") + + def _add_nodes(self) -> None: + # TODO + + engine = sqla.create_engine(self._data_config.database_path) + + # t = 0 is hierarchies + # t = 1 is ground-truth nodes + with Session(engine) as session: + query = session.query(NodeDB.id, NodeDB.t).where(NodeDB.t == 0) + self._nodes_df = pd.read_sql(query.statement, session.bind, index_col="id") + + size = len(self._nodes_df) + self._nodes = self._model.add_var_tensor( + (size,), name="nodes", var_type=mip.BINARY + ) + + # hierarchy overlap constraints + with Session(engine) as session: + query = session.query(OverlapDB).join( + NodeDB, NodeDB.id == OverlapDB.node_id + ) + overlap_df = pd.read_sql(query.statement, session.bind) + + overlap_df["node_id"] = self._nodes_df.index.get_indexer(overlap_df["node_id"]) + overlap_df["ancestor_id"] = self._nodes_df.index.get_indexer( + overlap_df["ancestor_id"] + ) + + for node_id, anc_id in zip(overlap_df["node_id"], overlap_df["ancestor_id"]): + self._model.add_constr(self._nodes[node_id] + self._nodes[anc_id] <= 1) + + def _add_edges(self) -> None: + # TODO + + if not hasattr(self, "_nodes"): + raise ValueError("Nodes must be added before adding edges.") + + engine = sqla.create_engine(self._data_config.database_path) + with Session(engine) as session: + query = session.query(LinkDB).join(NodeDB, NodeDB.id == LinkDB.source_id) + df = pd.read_sql(query.statement, session.bind) + + df["source_id"] = self._nodes_df.index.get_indexer(df["source_id"]) + df.reset_index(drop=True, inplace=True) + + self._edges = self._model.add_var_tensor( + (len(df),), + name="edges", + var_type=mip.BINARY, + ) + # setting objective function + self._model.objective = mip.xsum(df["weight"].to_numpy() * self._edges) + + # source_id is time point T (hierarchies id) + # target_id is time point T+1 (ground-truth) + for source_id, group in df.groupby("source_id", as_index=False): + self._model.add_constr( + self._nodes[source_id] == mip.xsum(self._edges[group.index.to_numpy()]) + ) + + for _, group in df.groupby("target_id", as_index=False): + self._model.add_constr(mip.xsum(self._edges[group.index.to_numpy()]) <= 1) + + def __call__(self) -> Tuple[pd.Series, float]: + # TODO + self._add_nodes() + self._add_edges() + self._model.optimize() + + score = self._model.objective_value + solution = pd.Series( + data=[var.x for var in self._nodes], + index=self._nodes_df.index, + name="solution", + ) + + return score, solution + + +@curry +def _tune_time_point( + t: int, + foreground: ArrayLike, + contours: ArrayLike, + gt_labels: ArrayLike, + config: MainConfig, + scale: Optional[ArrayLike], +) -> Optional[Tuple[MainConfig, pd.DataFrame]]: + # TODO + + config = config.copy(deep=True) + + clear_all_data(config.data_config.database_path) + + gt_labels = np.asarray(gt_labels[t]) + gt_rows = [] + + props = regionprops(gt_labels) + + if len(props) == 0: + return None + + foreground = np.asarray(foreground[t]) + contours = np.asarray(contours[t]) + + # adding hierarchy nodes + segment( + foreground=foreground[None, ...], + contours=contours[None, ...], + config=config, + overwrite=False, + ) + + # adding ground-truth nodes + for obj in props: + add_new_node( + 
config=config, + time=1, + mask=obj.image, + bbox=obj.bbox, + index=_generate_id(obj.label, 1, 10_000_000), + include_overlaps=False, + ) + row = {c: v for c, v in zip("xyz", obj.centroid[::-1])} + row["track_id"] = obj.label + gt_rows.append(row) + + gt_df = pd.DataFrame.from_records(gt_rows) + gt_df["t"] = t + + # computing GT matching + link(config, scale=scale, overwrite=False) + + matching = SQLGTMatching(config) + total_score, solution = matching() + mean_score = total_score / len(gt_df) + + print(f"Total score: {total_score:0.4f}") + print(f"Mean score: {mean_score:0.4f}") + + engine = sqla.create_engine(config.data_config.database_path) + + with Session(engine) as session: + query = session.query( + NodeDB.id, + NodeDB.hier_parent_id, + NodeDB.area, + NodeDB.frontier, + ).where(NodeDB.t == 0) + + df = pd.read_sql(query.statement, session.bind, index_col="id") + + df = df.merge(solution, left_index=True, right_index=True) + + frontiers = df["frontier"] + + df["parent_frontier"] = df["hier_parent_id"].map(lambda x: frontiers.get(x, -1.0)) + df.loc[df["parent_frontier"] < 0, "parent_frontier"] = df["frontier"].max() + + # selecting only nodes in solution + # must be after parent_frontier computation + df = df[df["solution"] > 0.5] + + config.segmentation_config.min_frontier = df["parent_frontier"].min() + config.segmentation_config.min_area = df["area"].min() + config.segmentation_config.max_area = df["area"].max() + + return config, gt_df + + +def _min_attr( + configs: List[MainConfig], + getter: Callable[[MainConfig], float], +) -> float: + return min(getter(cfg) for cfg in configs) + # return np.quantile([getter(cfg) for cfg in configs], 0.01) + + +def _max_attr( + configs: List[MainConfig], + getter: Callable[[MainConfig], float], +) -> float: + return max(getter(cfg) for cfg in configs) + # return np.quantile([getter(cfg) for cfg in configs], 0.99) + + +def auto_tune_config( + foreground: ArrayLike, + contours: ArrayLike, + ground_truth_labels: ArrayLike, + config: Optional[MainConfig] = None, + scale: Optional[ArrayLike] = None, +) -> MainConfig: + + prev_database = config.data_config.database + # config.data_config.database = "memory" + + if config is None: + config = MainConfig() + else: + config = config.copy(deep=True) + + tuning_tup = multiprocessing_apply( + _tune_time_point( + foreground=foreground, + contours=contours, + gt_labels=ground_truth_labels, + config=config, + scale=scale, + ), + range(foreground.shape[0]), + n_workers=config.segmentation_config.n_workers, + desc="Auto-tuning individual time points", + ) + tuning_tup = tuple(zip(*tuning_tup)) + new_configs: List[MainConfig] = tuning_tup[0] + gt_df = pd.concat(tuning_tup[1], ignore_index=True) + + if scale is not None: + cols = ["z", "y", "x"][-len(scale) :] + gt_df[cols] *= scale + + if "z" not in gt_df.columns: + gt_df["z"] = 0.0 + + max_distance = estimate_drift(gt_df) + if not np.isnan(max_distance) or max_distance > 0: + config.linking_config.max_distance = max_distance + 1.0 + + config.segmentation_config.min_area = ( + _min_attr(new_configs, lambda cfg: cfg.segmentation_config.min_area) * 0.95 + ) + + config.segmentation_config.max_area = ( + _max_attr(new_configs, lambda cfg: cfg.segmentation_config.max_area) * 1.025 + ) + + config.segmentation_config.min_frontier = max( + _min_attr(new_configs, lambda cfg: cfg.segmentation_config.min_frontier) + - 0.025, + 0.0, + ) + + config.data_config.database = prev_database + + return config diff --git a/ultrack/core/interactive.py b/ultrack/core/interactive.py 
index 18a8ad8..cabd657 100644 --- a/ultrack/core/interactive.py +++ b/ultrack/core/interactive.py @@ -264,7 +264,7 @@ def add_new_node( node = Node.from_mask( time=time, mask=mask, - bbox=bbox, + bbox=np.asarray(bbox), ) if node.area == 0: raise ValueError("Node area is zero. Something went wrong.") From 4f563376c5aada1bfcb6fd518c5f2088c8029a18 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Mon, 19 Aug 2024 08:31:12 -0700 Subject: [PATCH 23/45] removing in memory database --- ultrack/config/dataconfig.py | 4 ---- ultrack/core/autotune.py | 5 ----- 2 files changed, 9 deletions(-) diff --git a/ultrack/config/dataconfig.py b/ultrack/config/dataconfig.py index 6d838c4..62f07a9 100644 --- a/ultrack/config/dataconfig.py +++ b/ultrack/config/dataconfig.py @@ -13,7 +13,6 @@ class DatabaseChoices(Enum): sqlite = "sqlite" postgresql = "postgresql" - memory = "memory" class DataConfig(BaseModel): @@ -74,9 +73,6 @@ def database_path(self) -> str: if self.database == DatabaseChoices.sqlite.value: return f"sqlite:///{self.working_dir.absolute()}/data.db" - elif self.database == DatabaseChoices.memory.value: - return "sqlite+pysqlite:///:memory:" - elif self.database == DatabaseChoices.postgresql.value: return f"postgresql://{self.address}" diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py index efe23ef..c266b01 100644 --- a/ultrack/core/autotune.py +++ b/ultrack/core/autotune.py @@ -232,9 +232,6 @@ def auto_tune_config( scale: Optional[ArrayLike] = None, ) -> MainConfig: - prev_database = config.data_config.database - # config.data_config.database = "memory" - if config is None: config = MainConfig() else: @@ -281,6 +278,4 @@ def auto_tune_config( 0.0, ) - config.data_config.database = prev_database - return config From f1f05c539e6c9cdb816021c7df08550127aa404a Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Mon, 19 Aug 2024 09:03:39 -0700 Subject: [PATCH 24/45] in memory DB --- ultrack/config/dataconfig.py | 10 ++++++++++ ultrack/core/autotune.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/ultrack/config/dataconfig.py b/ultrack/config/dataconfig.py index 62f07a9..5aa4dcd 100644 --- a/ultrack/config/dataconfig.py +++ b/ultrack/config/dataconfig.py @@ -13,6 +13,7 @@ class DatabaseChoices(Enum): sqlite = "sqlite" postgresql = "postgresql" + memory = "memory" class DataConfig(BaseModel): @@ -32,6 +33,12 @@ class DataConfig(BaseModel): address: Optional[str] = None """``SPECIAL``: Postgresql database path, for example, ``postgres@localhost:12345/example``""" + in_memory_db_id: int = 0 + """ + ``SPECIAL``: Memory database id used to identify the database in memory, + must be altered manually if multiple instances are used + """ + class Config: validate_assignment = True use_enum_values = True @@ -73,6 +80,9 @@ def database_path(self) -> str: if self.database == DatabaseChoices.sqlite.value: return f"sqlite:///{self.working_dir.absolute()}/data.db" + elif self.database == DatabaseChoices.memory.value: + return f"sqlite:///file:{self.in_memory_db_id}?mode=memory&cache=shared&uri=true" + elif self.database == DatabaseChoices.postgresql.value: return f"postgresql://{self.address}" diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py index c266b01..a1a4628 100644 --- a/ultrack/core/autotune.py +++ b/ultrack/core/autotune.py @@ -130,6 +130,9 @@ def _tune_time_point( config = config.copy(deep=True) + prev_in_memory_db_id = config.data_config.in_memory_db_id + config.data_config.in_memory_db_id = t + clear_all_data(config.data_config.database_path) 
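    # Each time point is tuned against its own in-memory database (in_memory_db_id = t),
    # so the parallel workers spawned by multiprocessing_apply do not collide. Within
    # that database, candidate hierarchy nodes are stored at t=0 and the ground-truth
    # objects at t=1, so the regular linking step produces the candidate-to-ground-truth
    # edges that SQLGTMatching solves as an assignment problem.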
gt_labels = np.asarray(gt_labels[t]) @@ -205,6 +208,8 @@ def _tune_time_point( config.segmentation_config.min_area = df["area"].min() config.segmentation_config.max_area = df["area"].max() + config.data_config.in_memory_db_id = prev_in_memory_db_id + return config, gt_df @@ -237,6 +242,9 @@ def auto_tune_config( else: config = config.copy(deep=True) + prev_db = config.data_config.database + config.data_config.database = "memory" + tuning_tup = multiprocessing_apply( _tune_time_point( foreground=foreground, @@ -278,4 +286,6 @@ def auto_tune_config( 0.0, ) + config.data_config.database = prev_db + return config From 1a05547777eaed791d6706ede5c920a4d762293f Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sat, 31 Aug 2024 09:47:54 -0400 Subject: [PATCH 25/45] WIP database centric autotuning --- ultrack/core/autotune.py | 127 +++++++++++++++++++-------------------- 1 file changed, 63 insertions(+), 64 deletions(-) diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py index a1a4628..f6237d9 100644 --- a/ultrack/core/autotune.py +++ b/ultrack/core/autotune.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, List, Literal, Optional, Tuple +from typing import Literal, Optional, Tuple import mip import mip.exceptions @@ -15,7 +15,7 @@ from ultrack.core.database import LinkDB, NodeDB, OverlapDB, clear_all_data from ultrack.core.interactive import add_new_node from ultrack.core.linking.processing import link -from ultrack.core.segmentation.processing import _generate_id, segment +from ultrack.core.segmentation.processing import segment from ultrack.tracks.stats import estimate_drift from ultrack.utils.multiprocessing import multiprocessing_apply @@ -78,43 +78,56 @@ def _add_edges(self) -> None: engine = sqla.create_engine(self._data_config.database_path) with Session(engine) as session: query = session.query(LinkDB).join(NodeDB, NodeDB.id == LinkDB.source_id) - df = pd.read_sql(query.statement, session.bind) + self._edges_df = pd.read_sql(query.statement, session.bind) - df["source_id"] = self._nodes_df.index.get_indexer(df["source_id"]) - df.reset_index(drop=True, inplace=True) + self._edges_df["source_id"] = self._nodes_df.index.get_indexer( + self._edges_df["source_id"] + ) + self._edges_df.reset_index(drop=True, inplace=True) self._edges = self._model.add_var_tensor( - (len(df),), + (len(self._edges_df),), name="edges", var_type=mip.BINARY, ) # setting objective function - self._model.objective = mip.xsum(df["weight"].to_numpy() * self._edges) + self._model.objective = mip.xsum( + self._edges_df["weight"].to_numpy() * self._edges + ) # source_id is time point T (hierarchies id) # target_id is time point T+1 (ground-truth) - for source_id, group in df.groupby("source_id", as_index=False): + for source_id, group in self._edges_df.groupby("source_id", as_index=False): self._model.add_constr( self._nodes[source_id] == mip.xsum(self._edges[group.index.to_numpy()]) ) - for _, group in df.groupby("target_id", as_index=False): + for _, group in self._edges_df.groupby("target_id", as_index=False): self._model.add_constr(mip.xsum(self._edges[group.index.to_numpy()]) <= 1) - def __call__(self) -> Tuple[pd.Series, float]: + def __call__(self) -> Tuple[float, pd.DataFrame]: # TODO self._add_nodes() self._add_edges() self._model.optimize() + data = [] + + for i, e_var in enumerate(self._edges): + if e_var.x > 0.5: + data.append( + { + "id": self._nodes_df.index.get_indexer( + self._edges_df.iloc[i]["source_id"] + ), + "gt_id": self._edges_df.iloc[i]["target_id"], + } + ) + score = 
self._model.objective_value - solution = pd.Series( - data=[var.x for var in self._nodes], - index=self._nodes_df.index, - name="solution", - ) + matching_df = pd.DataFrame(data) - return score, solution + return score, matching_df @curry @@ -125,12 +138,10 @@ def _tune_time_point( gt_labels: ArrayLike, config: MainConfig, scale: Optional[ArrayLike], -) -> Optional[Tuple[MainConfig, pd.DataFrame]]: +) -> Tuple[pd.DataFrame, pd.DataFrame]: # TODO config = config.copy(deep=True) - - prev_in_memory_db_id = config.data_config.in_memory_db_id config.data_config.in_memory_db_id = t clear_all_data(config.data_config.database_path) @@ -141,7 +152,7 @@ def _tune_time_point( props = regionprops(gt_labels) if len(props) == 0: - return None + LOG.warning(f"No objects found in time point {t}") foreground = np.asarray(foreground[t]) contours = np.asarray(contours[t]) @@ -161,7 +172,7 @@ def _tune_time_point( time=1, mask=obj.image, bbox=obj.bbox, - index=_generate_id(obj.label, 1, 10_000_000), + index=obj.label, # _generate_id(obj.label, 1, 10_000_000), include_overlaps=False, ) row = {c: v for c, v in zip("xyz", obj.centroid[::-1])} @@ -175,8 +186,12 @@ def _tune_time_point( link(config, scale=scale, overwrite=False) matching = SQLGTMatching(config) - total_score, solution = matching() - mean_score = total_score / len(gt_df) + total_score, solution_df = matching() + + if len(gt_df) > 0: + mean_score = total_score / len(gt_df) + else: + mean_score = 0.0 print(f"Total score: {total_score:0.4f}") print(f"Mean score: {mean_score:0.4f}") @@ -187,13 +202,14 @@ def _tune_time_point( query = session.query( NodeDB.id, NodeDB.hier_parent_id, + NodeDB.t_hier_id, NodeDB.area, NodeDB.frontier, ).where(NodeDB.t == 0) df = pd.read_sql(query.statement, session.bind, index_col="id") - df = df.merge(solution, left_index=True, right_index=True) + df = df.join(solution_df) frontiers = df["frontier"] @@ -202,31 +218,15 @@ def _tune_time_point( # selecting only nodes in solution # must be after parent_frontier computation - df = df[df["solution"] > 0.5] - - config.segmentation_config.min_frontier = df["parent_frontier"].min() - config.segmentation_config.min_area = df["area"].min() - config.segmentation_config.max_area = df["area"].max() + # matched_df = df[df["solution"] > 0.5] - config.data_config.in_memory_db_id = prev_in_memory_db_id + # config.segmentation_config.min_frontier = matched_df["parent_frontier"].min() + # config.segmentation_config.min_area = matched_df["area"].min() + # config.segmentation_config.max_area = matched_df["area"].max() - return config, gt_df + # config.data_config.in_memory_db_id = prev_in_memory_db_id - -def _min_attr( - configs: List[MainConfig], - getter: Callable[[MainConfig], float], -) -> float: - return min(getter(cfg) for cfg in configs) - # return np.quantile([getter(cfg) for cfg in configs], 0.01) - - -def _max_attr( - configs: List[MainConfig], - getter: Callable[[MainConfig], float], -) -> float: - return max(getter(cfg) for cfg in configs) - # return np.quantile([getter(cfg) for cfg in configs], 0.99) + return df, gt_df def auto_tune_config( @@ -235,7 +235,7 @@ def auto_tune_config( ground_truth_labels: ArrayLike, config: Optional[MainConfig] = None, scale: Optional[ArrayLike] = None, -) -> MainConfig: +) -> Tuple[MainConfig, pd.DataFrame]: if config is None: config = MainConfig() @@ -258,7 +258,7 @@ def auto_tune_config( desc="Auto-tuning individual time points", ) tuning_tup = tuple(zip(*tuning_tup)) - new_configs: List[MainConfig] = tuning_tup[0] + df = pd.concat(tuning_tup[0], 
ignore_index=True) gt_df = pd.concat(tuning_tup[1], ignore_index=True) if scale is not None: @@ -268,24 +268,23 @@ def auto_tune_config( if "z" not in gt_df.columns: gt_df["z"] = 0.0 - max_distance = estimate_drift(gt_df) - if not np.isnan(max_distance) or max_distance > 0: - config.linking_config.max_distance = max_distance + 1.0 + if len(gt_df) > 0: + max_distance = estimate_drift(gt_df) + if not np.isnan(max_distance) or max_distance > 0: + config.linking_config.max_distance = max_distance + 1.0 - config.segmentation_config.min_area = ( - _min_attr(new_configs, lambda cfg: cfg.segmentation_config.min_area) * 0.95 - ) + if "solution" in df.columns: + matched_df = df[df["solution"] > 0.5] + config.segmentation_config.min_area = matched_df["area"].min() * 0.95 - config.segmentation_config.max_area = ( - _max_attr(new_configs, lambda cfg: cfg.segmentation_config.max_area) * 1.025 - ) + config.segmentation_config.max_area = matched_df["area"].max() * 1.025 - config.segmentation_config.min_frontier = max( - _min_attr(new_configs, lambda cfg: cfg.segmentation_config.min_frontier) - - 0.025, - 0.0, - ) + config.segmentation_config.min_frontier = max( + matched_df["parent_frontier"].min() - 0.025, 0.0 + ) - config.data_config.database = prev_db + config.data_config.database = prev_db + else: + LOG.warning("No nodes were matched. Keeping previous configuration.") - return config + return config, df From cc75143448d4bc563be9bf9d44bdf95411f09029 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sat, 31 Aug 2024 10:12:45 -0400 Subject: [PATCH 26/45] fix db bug --- ultrack/core/autotune.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py index f6237d9..ba57cd8 100644 --- a/ultrack/core/autotune.py +++ b/ultrack/core/autotune.py @@ -282,9 +282,9 @@ def auto_tune_config( config.segmentation_config.min_frontier = max( matched_df["parent_frontier"].min() - 0.025, 0.0 ) - - config.data_config.database = prev_db else: LOG.warning("No nodes were matched. 
Keeping previous configuration.") + config.data_config.database = prev_db + return config, df From 90aad46ff4919b21a64b1cb5a0aba7c916166747 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 1 Sep 2024 09:35:47 -0400 Subject: [PATCH 27/45] added feature extraction during segmentation --- .../_test/test_segment_processing.py | 9 +- ultrack/core/segmentation/processing.py | 93 ++++++++++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/ultrack/core/segmentation/_test/test_segment_processing.py b/ultrack/core/segmentation/_test/test_segment_processing.py index 093a5c9..41e82a6 100644 --- a/ultrack/core/segmentation/_test/test_segment_processing.py +++ b/ultrack/core/segmentation/_test/test_segment_processing.py @@ -8,7 +8,7 @@ from ultrack import segment from ultrack.config.config import MainConfig -from ultrack.core.database import NodeDB, OverlapDB +from ultrack.core.database import NodeDB, OverlapDB, get_node_values @pytest.mark.parametrize( @@ -43,6 +43,7 @@ def test_multiprocessing_segment( foreground, contours, config_instance, + properties=["centroid"], ) assert config_instance.data_config.metadata["shape"] == list(contours.shape) @@ -83,3 +84,9 @@ def test_multiprocessing_segment( continue node_j = nodes[j] assert node_i.IoU(node_j) == 0.0 + + feats = get_node_values(config_instance.data_config, i, NodeDB.features) + feats_name = config_instance.data_config.metadata["properties"] + + assert len(feats) == len(feats_name) + assert isinstance(feats, np.ndarray) diff --git a/ultrack/core/segmentation/processing.py b/ultrack/core/segmentation/processing.py index f5b22f6..2440977 100644 --- a/ultrack/core/segmentation/processing.py +++ b/ultrack/core/segmentation/processing.py @@ -1,13 +1,14 @@ import logging import pickle from contextlib import nullcontext -from typing import List, Optional +from typing import Callable, List, Optional import fasteners import numpy as np import sqlalchemy as sqla import zarr from numpy.typing import ArrayLike +from skimage.measure._regionprops import RegionProperties, regionprops_table from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from toolz import curry @@ -15,6 +16,7 @@ from ultrack.config.config import MainConfig, SegmentationConfig from ultrack.core.database import Base, NodeDB, OverlapDB, clear_all_data from ultrack.core.segmentation.hierarchy import create_hierarchies +from ultrack.core.segmentation.node import Node from ultrack.utils.array import check_array_chunk from ultrack.utils.deprecation import rename_argument from ultrack.utils.multiprocessing import ( @@ -77,6 +79,40 @@ def _insert_db( overlaps.clear() +def create_feats_callback( + shape: ArrayLike, image: Optional[ArrayLike], properties: List[str] +) -> Callable[[Node], np.ndarray]: + + mask = np.zeros(shape, dtype=bool) + + def _feats_callback(node: Node) -> np.ndarray: + + node.paint_buffer(mask, True, include_time=False) + + if image is None: + frame = None + else: + frame = np.asarray(image[node.time]) + + obj = RegionProperties( + node.slice, + label=True, + label_image=mask, + intensity_image=frame, + cache_active=True, + ) + + feats = np.concatenate( + [np.ravel(getattr(obj, p)) for p in properties], dtype=np.float32 + ) + + node.paint_buffer(mask, False, include_time=False) + + return feats + + return _feats_callback + + @curry def _process( time: int, @@ -88,6 +124,8 @@ def _process( write_lock: Optional[fasteners.InterProcessLock] = None, catch_duplicates_expection: bool = False, insertion_throttle_rate: int = 50, + image: 
Optional[ArrayLike] = None, + properties: Optional[List[str]] = None, ) -> None: """Process `foreground` and `edge` of current time and add data to database. @@ -111,6 +149,10 @@ def _process( If True, catches duplicates exception, by default False. insertion_throttle_rate : int Throttling rate for insertion, by default 50. + image : Optional[ArrayLike], optional + Image array for segments properties, by default None. + properties : Optional[List[str]], optional + List of properties to compute for each segment, by default None. """ np.random.seed(time) @@ -138,6 +180,12 @@ def _process( nodes = [] overlaps = [] + node_feats = None + feats_callback = None + + if properties is not None: + feats_callback = create_feats_callback(foreground.shape[1:], image, properties) + for h, hierarchy in enumerate(hiers): hierarchy.cache = True @@ -154,6 +202,9 @@ def _process( else: z, y, x = centroid + if feats_callback is not None: + node_feats = feats_callback(hier_node) + node = NodeDB( id=hier_node.id, t_node_id=index, @@ -164,6 +215,7 @@ def _process( x=int(x), area=int(hier_node.area), pickle=pickle.dumps(hier_node), # pickling to reduce memory usage + features=node_feats, ) hier_index_map[hier_node._h_node_index] = node @@ -246,6 +298,28 @@ def _check_zarr_memory_store(arr: ArrayLike) -> None: ) +def _get_properties_names( + shape: ArrayLike, + image: Optional[ArrayLike], + properties: Optional[List[str]], +) -> Optional[List[str]]: + + if properties is None: + return None + + if image is None: + dummy_image = None + else: + dummy_image = np.ones((4,) * (image.ndim - 1), dtype=np.float32) + + dummy_labels = np.zeros(shape, dtype=np.uint32) + dummy_labels[:2, :2] = 1 + + data_dict = regionprops_table(dummy_labels, dummy_image, properties=properties) + + return list(data_dict.keys()) + + @rename_argument("detection", "foreground") @rename_argument("edge", "contours") def segment( @@ -256,6 +330,8 @@ def segment( batch_index: Optional[int] = None, overwrite: bool = False, insertion_throttle_rate: int = 50, + image: Optional[ArrayLike] = None, + properties: Optional[List[str]] = None, ) -> None: """Add candidate segmentation (nodes) from `foreground` and `edge` to database. @@ -276,6 +352,10 @@ def segment( insertion_throttle_rate : int Throttling rate for insertion, by default 50. Only used with non-sqlite databases. + image : Optional[ArrayLike], optional + Image array of shape (T, (Z), Y, X) for segments properties, by default None. + properties : Optional[List[str]], optional + List of properties to compute for each segment, by default None. 
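    A minimal sketch of the new parameters, assuming `foreground`, `contours` and an
    intensity `image` of matching (T, (Z), Y, X) shape are already in memory:

        segment(
            foreground,
            contours,
            config,
            image=image,
            properties=["area", "intensity_mean"],
        )

    The computed values are stored in the `features` column of each node and the
    property names are recorded under `metadata["properties"]`.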
""" LOG.info(f"Adding nodes with SegmentationConfig:\n{config.segmentation_config}") @@ -306,7 +386,14 @@ def segment( clear_all_data(config.data_config.database_path) Base.metadata.create_all(engine) - config.data_config.metadata_add({"shape": foreground.shape}) + config.data_config.metadata_add( + { + "shape": foreground.shape, + "properties": _get_properties_names( + foreground.shape[1:], image, properties=properties + ), + } + ) with multiprocessing_sqlite_lock(config.data_config) as lock: process = _process( @@ -318,6 +405,8 @@ def segment( max_segments_per_time=max_segments_per_time, catch_duplicates_expection=batch_index is not None, insertion_throttle_rate=insertion_throttle_rate, + image=image, + properties=properties, ) multiprocessing_apply( From 3a6acc4d9771206276da0431ffe9deb2a28b05b8 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 1 Sep 2024 14:04:53 -0400 Subject: [PATCH 28/45] added get_nodes_features functionality --- ultrack/core/_test/test_database.py | 20 ++++++++ ultrack/core/database.py | 51 ++++++++++++++++--- ultrack/core/segmentation/__init__.py | 1 + .../_test/test_segment_processing.py | 10 ++++ ultrack/core/segmentation/processing.py | 38 +++++++++++++- ultrack/core/tracker.py | 8 ++- 6 files changed, 119 insertions(+), 9 deletions(-) diff --git a/ultrack/core/_test/test_database.py b/ultrack/core/_test/test_database.py index 6e6f005..ea14763 100644 --- a/ultrack/core/_test/test_database.py +++ b/ultrack/core/_test/test_database.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from ultrack.config.config import MainConfig @@ -39,6 +40,7 @@ def test_set_get_node_values( segmentation_database_mock_data: MainConfig, ) -> None: + # test single node index = _generate_id(1, 1, 1_000_000) set_node_values( @@ -54,3 +56,21 @@ def test_set_get_node_values( ) assert value == 0 + + # test multiple nodes + indices = np.asarray([_generate_id(i, 1, 1_000_000) for i in range(1, 3)]) + + for i, index in enumerate(indices): + set_node_values( + segmentation_database_mock_data.data_config, + index.item(), + area=i, + ) + + value = get_node_values( + segmentation_database_mock_data.data_config, + indices, + NodeDB.area, + ) + + np.testing.assert_array_equal(value, np.arange(len(indices))) diff --git a/ultrack/core/database.py b/ultrack/core/database.py index 3b1232d..5d2c86b 100644 --- a/ultrack/core/database.py +++ b/ultrack/core/database.py @@ -1,9 +1,12 @@ import enum import logging from pathlib import Path -from typing import Any, List, Union +from typing import Any, List, Optional, Union +import numpy as np +import pandas as pd import sqlalchemy as sqla +from numpy.typing import ArrayLike from sqlalchemy import ( BigInteger, Boolean, @@ -85,6 +88,7 @@ class NodeDB(Base): area = Column(Integer) selected = Column(Boolean, default=False) pickle = Column(MaybePickleType) + features = Column(MaybePickleType, default=None) segm_annot = Column(Enum(NodeSegmAnnotation), default=NodeSegmAnnotation.UNKNOWN) node_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN) appear_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN) @@ -156,6 +160,9 @@ def set_node_values( annot : NodeAnnotation Node annotation. 
""" + if hasattr(node_id, "item"): + node_id = node_id.item() + engine = sqla.create_engine(data_config.database_path) with Session(engine) as session: stmt = sqla.update(NodeDB).where(NodeDB.id == node_id).values(**kwargs) @@ -164,27 +171,57 @@ def set_node_values( def get_node_values( - data_config: DataConfig, node_id: int, values: Union[Column, List[Column]] -) -> Any: + data_config: DataConfig, + indices: Optional[Union[int, ArrayLike]], + values: Union[Column, List[Column]], +) -> List[Any]: """Get the annotation of `node_id`. Parameters ---------- data_config : DataConfig Data configuration parameters. - node_id : int - Node database index. + indices : int + Node database indices. values : List[Column] List of columns to be queried. """ if not isinstance(values, List): values = [values] + values.insert(0, NodeDB.id) + + is_scalar = False + if isinstance(indices, int): + indices = [indices] + is_scalar = True + + elif isinstance(indices, np.ndarray): + indices = indices.astype(int).tolist() + engine = sqla.create_engine(data_config.database_path) with Session(engine) as session: - annotation = session.query(*values).where(NodeDB.id == node_id).first()[0] + query = session.query(*values) + + if indices is not None: + query = query.where(NodeDB.id.in_(indices)) + + df = pd.read_sql_query(query.statement, session.bind, index_col="id") + + if indices is not None and len(df) != len(indices): + raise ValueError( + f"Query returned {len(df)} rows, expected {len(indices)}." + "\nCheck if node_id exists in database." + ) + + df = df.squeeze() + if is_scalar: + try: + df = df.item() + except ValueError: + pass - return annotation + return df def clear_all_data(database_path: str) -> None: diff --git a/ultrack/core/segmentation/__init__.py b/ultrack/core/segmentation/__init__.py index e69de29..90f6f51 100644 --- a/ultrack/core/segmentation/__init__.py +++ b/ultrack/core/segmentation/__init__.py @@ -0,0 +1 @@ +from ultrack.core.segmentation.processing import get_nodes_features diff --git a/ultrack/core/segmentation/_test/test_segment_processing.py b/ultrack/core/segmentation/_test/test_segment_processing.py index 41e82a6..9a8a6d6 100644 --- a/ultrack/core/segmentation/_test/test_segment_processing.py +++ b/ultrack/core/segmentation/_test/test_segment_processing.py @@ -9,6 +9,7 @@ from ultrack import segment from ultrack.config.config import MainConfig from ultrack.core.database import NodeDB, OverlapDB, get_node_values +from ultrack.core.segmentation import get_nodes_features @pytest.mark.parametrize( @@ -90,3 +91,12 @@ def test_multiprocessing_segment( assert len(feats) == len(feats_name) assert isinstance(feats, np.ndarray) + + df = get_nodes_features(config_instance) + + centroids_cols = [f"centroid-{i}" for i in range(contours.ndim - 1)] + + assert df.shape[0] == len(nodes) + np.testing.assert_array_equal( + df.columns.to_numpy(dtype=str), ["t", "z", "y", "x", "area"] + centroids_cols + ) diff --git a/ultrack/core/segmentation/processing.py b/ultrack/core/segmentation/processing.py index 2440977..09770da 100644 --- a/ultrack/core/segmentation/processing.py +++ b/ultrack/core/segmentation/processing.py @@ -5,6 +5,7 @@ import fasteners import numpy as np +import pandas as pd import sqlalchemy as sqla import zarr from numpy.typing import ArrayLike @@ -14,7 +15,13 @@ from toolz import curry from ultrack.config.config import MainConfig, SegmentationConfig -from ultrack.core.database import Base, NodeDB, OverlapDB, clear_all_data +from ultrack.core.database import ( + Base, + NodeDB, + OverlapDB, 
+ clear_all_data, + get_node_values, +) from ultrack.core.segmentation.hierarchy import create_hierarchies from ultrack.core.segmentation.node import Node from ultrack.utils.array import check_array_chunk @@ -415,3 +422,32 @@ def segment( config.segmentation_config.n_workers, desc="Adding nodes to database", ) + + +def get_nodes_features( + config: MainConfig, + indices: Optional[ArrayLike] = None, +) -> pd.DataFrame: + """ + Creates a pandas dataframe from nodes features defined during segmentation + plus area and coordinates. + + Parameters + ---------- + config : MainConfig + Configuration parameters. + indices : Optional[ArrayLike], optional + List of node indices, by default + + Returns + ------- + pd.DataFrame + Dataframe with nodes features + """ + feats_cols = [NodeDB.t, NodeDB.z, NodeDB.y, NodeDB.x, NodeDB.area, NodeDB.features] + df = get_node_values(config.data_config, indices=indices, values=feats_cols) + feat_columns = config.data_config.metadata["properties"] + df.loc[:, feat_columns] = np.vstack(df["features"].to_numpy()) + df.drop(columns=["features"], inplace=True) + + return df diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py index f3d21a3..2e23bac 100644 --- a/ultrack/core/tracker.py +++ b/ultrack/core/tracker.py @@ -18,7 +18,7 @@ ) from ultrack.core.linking.processing import link from ultrack.core.main import track -from ultrack.core.segmentation.processing import segment +from ultrack.core.segmentation.processing import get_nodes_features, segment from ultrack.core.solve.processing import solve from ultrack.imgproc.flow import add_flow from ultrack.utils.deprecation import rename_argument @@ -153,3 +153,9 @@ def to_tracks_layer(self, *args, **kwargs) -> Tuple[pd.DataFrame, Dict]: def export_by_extension(self, filename: str, overwrite: bool = False) -> None: self._assert_solved() export_tracks_by_extension(self.config, filename, overwrite=overwrite) + + @functools.wraps(get_nodes_features) + def get_nodes_features(self, **kwargs) -> pd.DataFrame: + self._assert_solved() + nodes_features_df = get_nodes_features(self.config, **kwargs) + return nodes_features_df From bb97ad88c686561d9e5386f78baedda501599dbc Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 1 Sep 2024 15:30:13 -0400 Subject: [PATCH 29/45] added manual linking value to ultrack --- ultrack/core/linking/processing.py | 37 ++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/ultrack/core/linking/processing.py b/ultrack/core/linking/processing.py index f02a944..dcc5fb8 100644 --- a/ultrack/core/linking/processing.py +++ b/ultrack/core/linking/processing.py @@ -230,3 +230,40 @@ def link( multiprocessing_apply( process, time_points, config.linking_config.n_workers, desc="Linking nodes." ) + + +def add_links( + config: MainConfig, + source: ArrayLike, + target: ArrayLike, + weight: ArrayLike, +) -> None: + """ + Adds user-defined links to the database. + + Parameters + ---------- + config : MainConfig + Configuration parameters. + source : ArrayLike + Source (t) node id. + target : ArrayLike + Target (t + 1) node id. + weight : ArrayLike + Link weight. 
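    A minimal sketch, using placeholder node ids that are assumed to already exist
    in the database:

        import numpy as np

        add_links(
            config,
            source=np.array([10, 11]),
            target=np.array([20, 20]),
            weight=np.array([0.9, 0.4]),
        )

    Higher weights make a link more likely to be selected by the solver.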
+ """ + df = pd.DataFrame( + { + "source_id": np.asarray(source, dtype=int), + "target_id": np.asarray(target, dtype=int), + "weight": weight, + } + ) + + engine = sqla.create_engine( + config.data_config.database_path, + hide_parameters=True, + ) + + with engine.begin() as conn: + df.to_sql(name=LinkDB.__tablename__, con=conn, if_exists="append", index=False) From 5c6ae642ad3eac2d8865f25b99197ba4ee26e637 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 1 Sep 2024 15:38:31 -0400 Subject: [PATCH 30/45] improving documentation --- ultrack/core/segmentation/processing.py | 35 ++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/ultrack/core/segmentation/processing.py b/ultrack/core/segmentation/processing.py index 09770da..bdf9bac 100644 --- a/ultrack/core/segmentation/processing.py +++ b/ultrack/core/segmentation/processing.py @@ -89,7 +89,23 @@ def _insert_db( def create_feats_callback( shape: ArrayLike, image: Optional[ArrayLike], properties: List[str] ) -> Callable[[Node], np.ndarray]: + """ + Create a callback function to compute features for each node. + + Parameters + ---------- + shape : ArrayLike + Volume (plane) shape. + image : Optional[ArrayLike] + Image array for segments properties, could have channel dimension on last axis. + properties : List[str] + List of properties to compute for each segment, see skimage.measure.regionprops documentation. + Returns + ------- + Callable[[Node], np.ndarray] + Callback function to compute features for each node returning a numpy array. + """ mask = np.zeros(shape, dtype=bool) def _feats_callback(node: Node) -> np.ndarray: @@ -157,7 +173,7 @@ def _process( insertion_throttle_rate : int Throttling rate for insertion, by default 50. image : Optional[ArrayLike], optional - Image array for segments properties, by default None. + Image array for segments properties, channel dimension is optional on last axis, by default None. properties : Optional[List[str]], optional List of properties to compute for each segment, by default None. """ @@ -310,6 +326,18 @@ def _get_properties_names( image: Optional[ArrayLike], properties: Optional[List[str]], ) -> Optional[List[str]]: + """ + Get properties names from provided properties list. + + Parameters + ---------- + shape : ArrayLike + Volume (plane) shape. + image : Optional[ArrayLike] + Image array for segments properties, could have channel dimension on last axis. + properties : Optional[List[str]] + List of properties to compute for each segment, see skimage.measure.regionprops documentation. + """ if properties is None: return None @@ -360,9 +388,10 @@ def segment( Throttling rate for insertion, by default 50. Only used with non-sqlite databases. image : Optional[ArrayLike], optional - Image array of shape (T, (Z), Y, X) for segments properties, by default None. + Image array of shape (T, (Z), Y, X, (C)) for segments properties, by default None. + Channel and Z dimensions are optional. properties : Optional[List[str]], optional - List of properties to compute for each segment, by default None. + List of properties to compute for each segment, see skimage.measure.regionprops documentation. 
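
A rough sketch of how the `image` and `properties` arguments combine with the node
feature table introduced earlier in this series; the array shapes, property names,
and toy data are illustrative only and assume a writable default configuration:

    import numpy as np

    from ultrack import MainConfig, segment
    from ultrack.core.segmentation import get_nodes_features

    config = MainConfig()

    # toy 2D video of shape (T, Y, X) with a single square object
    foreground = np.zeros((2, 64, 64), dtype=bool)
    foreground[:, 20:40, 20:40] = True
    contours = (~foreground).astype(np.float32)
    image = np.random.rand(2, 64, 64).astype(np.float32)

    segment(
        foreground=foreground,
        contours=contours,
        config=config,
        image=image,
        properties=["intensity_mean", "equivalent_diameter_area"],
        overwrite=True,
    )

    # one row per candidate segment; the requested properties become extra columns
    df = get_nodes_features(config)
    print(df.head())
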
""" LOG.info(f"Adding nodes with SegmentationConfig:\n{config.segmentation_config}") From ad5673a9a8415ebf5fa10253a581be788f06db61 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 1 Sep 2024 18:09:27 -0400 Subject: [PATCH 31/45] examples of using node features --- examples/node_features.py | 89 +++++++++++++++++++++++++ ultrack/core/linking/__init__.py | 1 + ultrack/core/linking/processing.py | 24 +++---- ultrack/core/segmentation/processing.py | 6 +- ultrack/core/tracker.py | 30 ++++++--- 5 files changed, 126 insertions(+), 24 deletions(-) create mode 100644 examples/node_features.py diff --git a/examples/node_features.py b/examples/node_features.py new file mode 100644 index 0000000..7af26a5 --- /dev/null +++ b/examples/node_features.py @@ -0,0 +1,89 @@ +import napari +import numpy as np +from scipy.spatial.distance import cdist +from skimage import morphology as morph +from skimage.data import cells3d + +from ultrack import MainConfig, Tracker + + +def main() -> None: + + config = MainConfig() + config.data_config.working_dir = "/tmp/ultrack/." + + # removing small segments + config.segmentation_config.min_area = 1_000 + # disable division + config.tracking_config.division_weight = -1_000_000 + + tracker = Tracker(config) + + # mocking a 3D image as 2D video + image = cells3d()[:, 1] # nuclei + + foreground = image > image.mean() + foreground = morph.opening(foreground, morph.disk(3)[None, :]) + foreground = morph.closing(foreground, morph.disk(3)[None, :]) + + contour = 1 - image / image.max() + + tracker.segment( + foreground=foreground, + contours=contour, + image=image, + properties=["equivalent_diameter_area", "intensity_mean", "inertia_tensor"], + overwrite=True, + ) + + df = tracker.get_nodes_features() + + # extending properties, some include -0-0, -0-1, -1-0, -1-1 + cols = [ + "y", + "x", + "area", + "intensity_mean", + "inertia_tensor-0-0", + "inertia_tensor-0-1", + "inertia_tensor-1-0", + "inertia_tensor-1-1", + "equivalent_diameter_area", + ] + + df[cols] -= df[cols].mean() + df[cols] /= df[cols].std() + + df_by_t = df.groupby("t") + t_max = df["t"].max() + + for t in range(t_max + 1): + try: + source_df = df_by_t.get_group(t) + target_df = df_by_t.get_group(t + 1) + except KeyError: + continue + + # the higher the weights the more likely the link + weights = 1 - cdist(source_df[cols], target_df[cols], "cosine").ravel() + + source_ids = np.repeat(source_df.index.to_numpy(), len(target_df)) + target_ids = np.tile(target_df.index.to_numpy(), len(source_df)) + + # very dense graph, not recommended, select k-nearest neighbors + tracker.add_links(sources=source_ids, targets=target_ids, weights=weights) + + tracker.solve() + tracks, graph = tracker.to_tracks_layer() + segments = tracker.to_zarr() + + viewer = napari.Viewer() + viewer.add_image(image, name="cells") + viewer.add_tracks(tracks[["track_id", "t", "y", "x"]], name="tracks", graph=graph) + viewer.add_labels(segments, name="segments") + + napari.run() + + +if __name__ == "__main__": + main() diff --git a/ultrack/core/linking/__init__.py b/ultrack/core/linking/__init__.py index e69de29..6623968 100644 --- a/ultrack/core/linking/__init__.py +++ b/ultrack/core/linking/__init__.py @@ -0,0 +1 @@ +from ultrack.core.linking.processing import add_links diff --git a/ultrack/core/linking/processing.py b/ultrack/core/linking/processing.py index dcc5fb8..f63857f 100644 --- a/ultrack/core/linking/processing.py +++ b/ultrack/core/linking/processing.py @@ -234,9 +234,9 @@ def link( def add_links( config: MainConfig, - source: 
ArrayLike, - target: ArrayLike, - weight: ArrayLike, + sources: ArrayLike, + targets: ArrayLike, + weights: ArrayLike, ) -> None: """ Adds user-defined links to the database. @@ -245,18 +245,18 @@ def add_links( ---------- config : MainConfig Configuration parameters. - source : ArrayLike - Source (t) node id. - target : ArrayLike - Target (t + 1) node id. - weight : ArrayLike - Link weight. + sources : ArrayLike + Sources (t) node id. + targets : ArrayLike + Targets (t + 1) node id. + weights : ArrayLike + Link weights, the higher the weight the more likely the link. """ df = pd.DataFrame( { - "source_id": np.asarray(source, dtype=int), - "target_id": np.asarray(target, dtype=int), - "weight": weight, + "source_id": np.asarray(sources, dtype=int), + "target_id": np.asarray(targets, dtype=int), + "weight": weights, } ) diff --git a/ultrack/core/segmentation/processing.py b/ultrack/core/segmentation/processing.py index bdf9bac..30828f0 100644 --- a/ultrack/core/segmentation/processing.py +++ b/ultrack/core/segmentation/processing.py @@ -304,7 +304,9 @@ def _process( ) return else: - raise e + raise ValueError( + "Duplicated nodes found. Set `overwrite=True` to overwrite existing data." + ) from e # pushes any remaning data with write_lock if write_lock is not None else nullcontext(): @@ -347,7 +349,7 @@ def _get_properties_names( else: dummy_image = np.ones((4,) * (image.ndim - 1), dtype=np.float32) - dummy_labels = np.zeros(shape, dtype=np.uint32) + dummy_labels = np.zeros((4,) * len(shape), dtype=np.uint32) dummy_labels[:2, :2] = 1 data_dict = regionprops_table(dummy_labels, dummy_image, properties=properties) diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py index 2e23bac..33cf9f8 100644 --- a/ultrack/core/tracker.py +++ b/ultrack/core/tracker.py @@ -16,7 +16,7 @@ tracks_layer_to_trackmate, tracks_to_zarr, ) -from ultrack.core.linking.processing import link +from ultrack.core.linking.processing import add_links, link from ultrack.core.main import track from ultrack.core.segmentation.processing import get_nodes_features, segment from ultrack.core.solve.processing import solve @@ -70,32 +70,31 @@ def __init__(self, config: MainConfig) -> None: @rename_argument("edges", "contours") def segment(self, foreground: ArrayLike, contours: ArrayLike, **kwargs) -> None: segment(foreground=foreground, contours=contours, config=self.config, **kwargs) - self.status = TrackerStatus.SEGMENTED + self.status &= ~TrackerStatus.NOT_COMPUTED + self.status |= TrackerStatus.SEGMENTED @functools.wraps(add_flow) def add_flow(self, vector_field: ArrayLike) -> None: - if TrackerStatus.SEGMENTED not in self.status: - raise ValueError("You must call `segment` before calling `add_flow`.") + self._assert_segmented("add_flow") add_flow(config=self.config, vector_field=vector_field) @functools.wraps(link) def link(self, *args, **kwargs) -> None: - if TrackerStatus.SEGMENTED not in self.status: - raise ValueError("You must call `segment` before calling `link`.") + self._assert_segmented("link") link(config=self.config, *args, **kwargs) - self.status = TrackerStatus.LINKED + self.status |= TrackerStatus.LINKED @functools.wraps(solve) def solve(self, *args, **kwargs) -> None: if TrackerStatus.LINKED not in self.status: raise ValueError("You must call `segment` & `link` before calling `solve`.") solve(config=self.config, *args, **kwargs) - self.status = TrackerStatus.SOLVED + self.status |= TrackerStatus.SOLVED @functools.wraps(track) def track(self, *args, **kwargs) -> None: track(config=self.config, *args, 
**kwargs) - self.status = TrackerStatus.SOLVED + self.status |= TrackerStatus.SOLVED def _assert_solved(self) -> None: """Raise an error if the tracking is not solved.""" @@ -105,6 +104,11 @@ def _assert_solved(self) -> None: "called `segment` &a `link` & `solve` or `track`." ) + def _assert_segmented(self, method_name: str) -> None: + """Raise an error if segmentation is not done.""" + if TrackerStatus.SEGMENTED not in self.status: + raise ValueError(f"You must call `segment` before calling `{method_name}`.") + @functools.wraps(tracks_layer_to_networkx) def to_networkx( self, *, tracks_df: Optional[pd.DataFrame] = None, **kwargs @@ -156,6 +160,12 @@ def export_by_extension(self, filename: str, overwrite: bool = False) -> None: @functools.wraps(get_nodes_features) def get_nodes_features(self, **kwargs) -> pd.DataFrame: - self._assert_solved() + self._assert_segmented("get_nodes_features") nodes_features_df = get_nodes_features(self.config, **kwargs) return nodes_features_df + + @functools.wraps(add_links) + def add_links(self, **kwargs) -> None: + self._assert_segmented("add_links") + add_links(config=self.config, **kwargs) + self.status |= TrackerStatus.LINKED From e36082c5f515ea77e38aab26ecd54d161150db29 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 1 Sep 2024 18:10:08 -0400 Subject: [PATCH 32/45] added CTC usage --- examples/node_features.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/node_features.py b/examples/node_features.py index 7af26a5..442fd5d 100644 --- a/examples/node_features.py +++ b/examples/node_features.py @@ -74,6 +74,10 @@ def main() -> None: tracker.add_links(sources=source_ids, targets=target_ids, weights=weights) tracker.solve() + + # for CTC use this + # tracker.to_ctc() + tracks, graph = tracker.to_tracks_layer() segments = tracker.to_zarr() From 967ce940b122d522c69708b1d42d70ca3018c696 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 3 Sep 2024 15:35:01 -0400 Subject: [PATCH 33/45] adding ground-truth database --- ultrack/core/database.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ultrack/core/database.py b/ultrack/core/database.py index 5d2c86b..b1110ec 100644 --- a/ultrack/core/database.py +++ b/ultrack/core/database.py @@ -112,6 +112,16 @@ class LinkDB(Base): annotation = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN) +class GroundTruthDB(Base): + __tablename__ = "ground_truth" + t = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) + node_id = Column(BigInteger, ForeignKey(f"{NodeDB.__tablename__}.id")) + weight = Column(Float) + pickle = Column(MaybePickleType) + label = Column(Integer) + + def maximum_time_from_database(data_config: DataConfig) -> int: """Returns the maximum `t` found in the `NodesDB`.""" engine = sqla.create_engine(data_config.database_path) From c0a1001779c000b253e585f01f99bbf4bef0fa54 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 3 Sep 2024 16:21:13 -0400 Subject: [PATCH 34/45] improving linking codebase --- ultrack/core/linking/processing.py | 81 +++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/ultrack/core/linking/processing.py b/ultrack/core/linking/processing.py index f63857f..f461133 100644 --- a/ultrack/core/linking/processing.py +++ b/ultrack/core/linking/processing.py @@ -48,6 +48,61 @@ def _compute_features( ] +def color_filtering_mask( + time: int, + current_nodes: List[Node], + next_nodes: List[Node], + images: Sequence[ArrayLike], + neighbors: 
ArrayLike, + z_score_threshold: float, +) -> ArrayLike: + """ + Filtering by color z-score. + + Parameters + ---------- + time : int + Current time. + current_nodes : List[Node] + List of source nodes. + next_nodes : List[Node] + List of target nodes. + images : Sequence[ArrayLike] + Sequence of images to extract color features for filtering. + neighbors : ArrayLike + Neighbors indices (current/source) for each target (next) node. + z_score_threshold : float + Z-score threshold for color filtering. + + Returns + ------- + ArrayLike + Boolean mask of neighboring nodes within color z-score threshold. + + """ + LOG.info(f"computing filtering by color z-score from t={time}") + (current_features,) = _compute_features( + time, current_nodes, images, [Node.intensity_mean] + ) + # inserting dummy value for missing neighbors + current_features = np.append( + current_features, + np.zeros((1, current_features.shape[1])), + axis=0, + ) + next_features, next_features_std = _compute_features( + time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std] + ) + LOG.info( + f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}" + ) + next_features_std[next_features_std <= 1e-6] = 1.0 + difference = next_features[:, None, ...] - current_features[neighbors] + difference /= next_features_std[:, None, ...] + filtered_by_color = np.abs(difference).max(axis=-1) <= z_score_threshold + return filtered_by_color + + @curry def _process( time: int, @@ -116,26 +171,14 @@ def _process( ) if len(images) > 0: - LOG.info(f"computing filtering by color z-score from t={time}") - (current_features,) = _compute_features( - time, current_nodes, images, [Node.intensity_mean] - ) - # inserting dummy value for missing neighbors - current_features = np.append( - current_features, - np.zeros((1, current_features.shape[1])), - axis=0, - ) - next_features, next_features_std = _compute_features( - time + 1, next_nodes, images, [Node.intensity_mean, Node.intensity_std] - ) - LOG.info( - f"Features Std. Dev. range {next_features_std.min()} {next_features_std.max()}" + filtered_by_color = color_filtering_mask( + time, + current_nodes, + next_nodes, + images, + neighbors, + config.z_score_threshold, ) - next_features_std[next_features_std <= 1e-6] = 1.0 - difference = next_features[:, None, ...] - current_features[neighbors] - difference /= next_features_std[:, None, ...] 
- filtered_by_color = np.abs(difference).max(axis=-1) <= config.z_score_threshold else: filtered_by_color = np.ones_like(neighbors, dtype=bool) From b572f50fe4f7b1a801b6dcc78496c319a8d076d2 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 4 Sep 2024 09:35:45 -0400 Subject: [PATCH 35/45] WIP auto memory GT linking --- ultrack/core/autotune.py | 364 +++++++++++++++--------------- ultrack/core/database.py | 20 +- ultrack/core/segmentation/node.py | 5 +- ultrack/core/tracker.py | 6 + 4 files changed, 207 insertions(+), 188 deletions(-) diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py index ba57cd8..56ed230 100644 --- a/ultrack/core/autotune.py +++ b/ultrack/core/autotune.py @@ -1,6 +1,9 @@ import logging +import pickle +from contextlib import nullcontext from typing import Literal, Optional, Tuple +import fasteners import mip import mip.exceptions import numpy as np @@ -12,25 +15,168 @@ from toolz import curry from ultrack.config.config import MainConfig -from ultrack.core.database import LinkDB, NodeDB, OverlapDB, clear_all_data -from ultrack.core.interactive import add_new_node -from ultrack.core.linking.processing import link -from ultrack.core.segmentation.processing import segment -from ultrack.tracks.stats import estimate_drift +from ultrack.core.database import GTLinkDB, GTNodeDB, NodeDB, OverlapDB +from ultrack.core.segmentation.node import Node from ultrack.utils.multiprocessing import multiprocessing_apply LOG = logging.getLogger(__name__) -class SQLGTMatching: +def _link_gt( + time: int, + config: MainConfig, + db_path: str, + scale: Optional[ArrayLike], + write_lock: Optional[fasteners.InterProcessLock], +) -> None: + pass + + +@curry +def _match_ground_truth_frame( + time: int, + gt_labels: ArrayLike, + config: MainConfig, + scale: Optional[ArrayLike], + write_lock: Optional[fasteners.InterProcessLock], +) -> Tuple[pd.DataFrame, pd.DataFrame]: + # TODO + + gt_labels = np.asarray(gt_labels[time]) + gt_props = regionprops(gt_labels) + + if len(gt_props) == 0: + LOG.warning(f"No objects found in time point {time}") + + gt_nodes = [] + # adding ground-truth nodes + for obj in gt_props: + node = Node.from_mask( + node_id=obj.label, time=time, mask=obj.image, bbox=obj.bbox + ) + + gt_nodes.append( + GTNodeDB( + t=time, + label=obj.label, + pickle=pickle.dumps(node), + ) + ) + + with write_lock if write_lock is not None else nullcontext(): + engine = sqla.create_engine(config.data_config.database_path) + + with Session(engine) as session: + session.add_all(gt_nodes) + session.commit() + + engine.dispose() + + _link_gt(config, scale=scale, overwrite=False) + + # computing GT matching + gt_matcher = SQLGTMatcher(config, write_lock=write_lock) + total_score = gt_matcher() + + if len(gt_nodes) > 0: + mean_score = total_score / len(gt_nodes) + else: + mean_score = 0.0 + + LOG.info(f"time {time} total score: {total_score:0.4f}") + LOG.info(f"time {time} mean score: {mean_score:0.4f}") + + +def _get_matched_nodes_df(database_path: str) -> pd.DataFrame: + # TODO + engine = sqla.create_engine(database_path) + + with Session(engine) as session: + node_query = session.query( + NodeDB.id, + NodeDB.hier_parent_id, + NodeDB.t_hier_id, + NodeDB.area, + NodeDB.frontier, + ).where(NodeDB.t == 0) + node_df = pd.read_sql(node_query.statement, session.bind, index_col="id") + + gt_query = session.query(GTLinkDB.source_id, GTLinkDB.target_id).where( + GTLinkDB.selected + ) + gt_df = pd.read_sql(gt_query.statement, session.bind, index="source_id") + + node_df = node_df.join(gt_df) + 
+ frontiers = node_df["frontier"] + node_df["parent_frontier"] = node_df["hier_parent_id"].map( + lambda x: frontiers.get(x, -1.0) + ) + node_df.loc[node_df["parent_frontier"] < 0, "parent_frontier"] = node_df[ + "frontier" + ].max() + + return node_df + + +def match_to_ground_truth( + config: MainConfig, + gt_labels: ArrayLike, + scale: Optional[ArrayLike] = None, +) -> pd.DataFrame: + # TODO + + multiprocessing_apply( + _match_ground_truth_frame( + gt_labels=gt_labels, + config=config, + scale=scale, + ), + range(gt_labels.shape[0]), + n_workers=config.segmentation_config.n_workers, + desc="Matching hierarchy nodes with ground-truth", + ) + + # if scale is not None: + # cols = ["z", "y", "x"][-len(scale) :] + # gt_df[cols] *= scale + + # if "z" not in gt_df.columns: + # gt_df["z"] = 0.0 + + # if len(gt_df) > 0: + # max_distance = estimate_drift(gt_df) + # if not np.isnan(max_distance) or max_distance > 0: + # config.linking_config.max_distance = max_distance + 1.0 + + # if "solution" in df.columns: + # matched_df = df[df["solution"] > 0.5] + # config.segmentation_config.min_area = matched_df["area"].min() * 0.95 + + # config.segmentation_config.max_area = matched_df["area"].max() * 1.025 + + # config.segmentation_config.min_frontier = max( + # matched_df["parent_frontier"].min() - 0.025, 0.0 + # ) + # else: + # LOG.warning("No nodes were matched. Keeping previous configuration.") + + # config.data_config.database = prev_db + + # return config, df + + +class SQLGTMatcher: def __init__( self, config: MainConfig, solver: Literal["CBC", "GUROBI", ""] = "", + write_lock: Optional[fasteners.InterProcessLock] = None, ) -> None: # TODO self._data_config = config.data_config + self._write_lock = write_lock try: self._model = mip.Model(sense=mip.MAXIMIZE, solver_name=solver) @@ -40,7 +186,6 @@ def __init__( def _add_nodes(self) -> None: # TODO - engine = sqla.create_engine(self._data_config.database_path) # t = 0 is hierarchies @@ -77,7 +222,9 @@ def _add_edges(self) -> None: engine = sqla.create_engine(self._data_config.database_path) with Session(engine) as session: - query = session.query(LinkDB).join(NodeDB, NodeDB.id == LinkDB.source_id) + query = session.query(GTLinkDB).join( + NodeDB, NodeDB.id == GTLinkDB.source_id + ) self._edges_df = pd.read_sql(query.statement, session.bind) self._edges_df["source_id"] = self._nodes_df.index.get_indexer( @@ -105,186 +252,39 @@ def _add_edges(self) -> None: for _, group in self._edges_df.groupby("target_id", as_index=False): self._model.add_constr(mip.xsum(self._edges[group.index.to_numpy()]) <= 1) - def __call__(self) -> Tuple[float, pd.DataFrame]: + def add_solution(self) -> None: # TODO - self._add_nodes() - self._add_edges() - self._model.optimize() - - data = [] + engine = sqla.create_engine(self._data_config.database_path) - for i, e_var in enumerate(self._edges): + edges_records = [] + for idx, e_var in zip(self._edges_df.index, self._edges): if e_var.x > 0.5: - data.append( + edges_records.append( { - "id": self._nodes_df.index.get_indexer( - self._edges_df.iloc[i]["source_id"] - ), - "gt_id": self._edges_df.iloc[i]["target_id"], + "id": idx, + "selected": e_var.x > 0.5, } ) - score = self._model.objective_value - matching_df = pd.DataFrame(data) - - return score, matching_df - - -@curry -def _tune_time_point( - t: int, - foreground: ArrayLike, - contours: ArrayLike, - gt_labels: ArrayLike, - config: MainConfig, - scale: Optional[ArrayLike], -) -> Tuple[pd.DataFrame, pd.DataFrame]: - # TODO - - config = config.copy(deep=True) - 
config.data_config.in_memory_db_id = t - - clear_all_data(config.data_config.database_path) - - gt_labels = np.asarray(gt_labels[t]) - gt_rows = [] - - props = regionprops(gt_labels) - - if len(props) == 0: - LOG.warning(f"No objects found in time point {t}") - - foreground = np.asarray(foreground[t]) - contours = np.asarray(contours[t]) - - # adding hierarchy nodes - segment( - foreground=foreground[None, ...], - contours=contours[None, ...], - config=config, - overwrite=False, - ) - - # adding ground-truth nodes - for obj in props: - add_new_node( - config=config, - time=1, - mask=obj.image, - bbox=obj.bbox, - index=obj.label, # _generate_id(obj.label, 1, 10_000_000), - include_overlaps=False, - ) - row = {c: v for c, v in zip("xyz", obj.centroid[::-1])} - row["track_id"] = obj.label - gt_rows.append(row) - - gt_df = pd.DataFrame.from_records(gt_rows) - gt_df["t"] = t - - # computing GT matching - link(config, scale=scale, overwrite=False) - - matching = SQLGTMatching(config) - total_score, solution_df = matching() - - if len(gt_df) > 0: - mean_score = total_score / len(gt_df) - else: - mean_score = 0.0 - - print(f"Total score: {total_score:0.4f}") - print(f"Mean score: {mean_score:0.4f}") - - engine = sqla.create_engine(config.data_config.database_path) - - with Session(engine) as session: - query = session.query( - NodeDB.id, - NodeDB.hier_parent_id, - NodeDB.t_hier_id, - NodeDB.area, - NodeDB.frontier, - ).where(NodeDB.t == 0) - - df = pd.read_sql(query.statement, session.bind, index_col="id") - - df = df.join(solution_df) - - frontiers = df["frontier"] - - df["parent_frontier"] = df["hier_parent_id"].map(lambda x: frontiers.get(x, -1.0)) - df.loc[df["parent_frontier"] < 0, "parent_frontier"] = df["frontier"].max() - - # selecting only nodes in solution - # must be after parent_frontier computation - # matched_df = df[df["solution"] > 0.5] - - # config.segmentation_config.min_frontier = matched_df["parent_frontier"].min() - # config.segmentation_config.min_area = matched_df["area"].min() - # config.segmentation_config.max_area = matched_df["area"].max() - - # config.data_config.in_memory_db_id = prev_in_memory_db_id - - return df, gt_df - - -def auto_tune_config( - foreground: ArrayLike, - contours: ArrayLike, - ground_truth_labels: ArrayLike, - config: Optional[MainConfig] = None, - scale: Optional[ArrayLike] = None, -) -> Tuple[MainConfig, pd.DataFrame]: - - if config is None: - config = MainConfig() - else: - config = config.copy(deep=True) - - prev_db = config.data_config.database - config.data_config.database = "memory" - - tuning_tup = multiprocessing_apply( - _tune_time_point( - foreground=foreground, - contours=contours, - gt_labels=ground_truth_labels, - config=config, - scale=scale, - ), - range(foreground.shape[0]), - n_workers=config.segmentation_config.n_workers, - desc="Auto-tuning individual time points", - ) - tuning_tup = tuple(zip(*tuning_tup)) - df = pd.concat(tuning_tup[0], ignore_index=True) - gt_df = pd.concat(tuning_tup[1], ignore_index=True) - - if scale is not None: - cols = ["z", "y", "x"][-len(scale) :] - gt_df[cols] *= scale - - if "z" not in gt_df.columns: - gt_df["z"] = 0.0 - - if len(gt_df) > 0: - max_distance = estimate_drift(gt_df) - if not np.isnan(max_distance) or max_distance > 0: - config.linking_config.max_distance = max_distance + 1.0 - - if "solution" in df.columns: - matched_df = df[df["solution"] > 0.5] - config.segmentation_config.min_area = matched_df["area"].min() * 0.95 - - config.segmentation_config.max_area = matched_df["area"].max() * 
1.025 - - config.segmentation_config.min_frontier = max( - matched_df["parent_frontier"].min() - 0.025, 0.0 - ) - else: - LOG.warning("No nodes were matched. Keeping previous configuration.") + with self._write_lock if self._write_lock is not None else nullcontext(): + with Session(engine) as session: + stmt = ( + sqla.update(GTLinkDB) + .where(GTLinkDB.id == sqla.bindparam("id")) + .values(selected=sqla.bindparam("selected")) + ) + session.connection().execute( + stmt, + edges_records, + execution_options={"synchronize_session": False}, + ) + session.commit() - config.data_config.database = prev_db + def __call__(self) -> float: + # TODO + self._add_nodes() + self._add_edges() + self._model.optimize() + self.add_solution() - return config, df + return self._model.objective_value diff --git a/ultrack/core/database.py b/ultrack/core/database.py index 3787a40..3af2b3c 100644 --- a/ultrack/core/database.py +++ b/ultrack/core/database.py @@ -118,14 +118,24 @@ class LinkDB(Base): annotation = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN) -class GroundTruthDB(Base): - __tablename__ = "ground_truth" +class GTNodeDB(Base): + __tablename__ = "gt_nodes" t = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True, autoincrement=True) - node_id = Column(BigInteger, ForeignKey(f"{NodeDB.__tablename__}.id")) - weight = Column(Float) - pickle = Column(MaybePickleType) label = Column(Integer) + pickle = Column(MaybePickleType) + z = Column(Float) + y = Column(Float) + x = Column(Float) + + +class GTLinkDB(Base): + __tablename__ = "gt_links" + id = Column(Integer, primary_key=True, autoincrement=True) + source_id = Column(BigInteger, ForeignKey(f"{NodeDB.__tablename__}.id")) + target_id = Column(BigInteger, ForeignKey(f"{GTNodeDB.__tablename__}.id")) + weight = Column(Float) + selected = Column(Boolean, default=False) def maximum_time_from_database(data_config: DataConfig) -> int: diff --git a/ultrack/core/segmentation/node.py b/ultrack/core/segmentation/node.py index b58880f..0aaac3d 100644 --- a/ultrack/core/segmentation/node.py +++ b/ultrack/core/segmentation/node.py @@ -280,6 +280,7 @@ def from_mask( time: int, mask: ArrayLike, bbox: Optional[ArrayLike] = None, + node_id: int = -1, ) -> "Node": """ Create a new node from a mask. @@ -293,6 +294,8 @@ def from_mask( bbox : Optional[ArrayLike], optional Bounding box of the node, (min_0, min_1, ..., max_0, max_1, ...). When provided it assumes the mask is a crop of the original image, by default None + node_id : int, optional + Node ID, by default -1 Returns ------- @@ -303,7 +306,7 @@ def from_mask( if mask.dtype != bool: raise ValueError(f"Mask should be a boolean array. 
Found {mask.dtype}") - node = Node(h_node_index=-1, time=time, parent=None) + node = Node(h_node_index=-1, id=node_id, time=time, parent=None) if bbox is None: bbox = ndi.find_objects(mask)[0] diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py index 33cf9f8..d1c6a77 100644 --- a/ultrack/core/tracker.py +++ b/ultrack/core/tracker.py @@ -9,6 +9,7 @@ from ultrack import export_tracks_by_extension from ultrack.config import MainConfig +from ultrack.core.autotune import match_to_ground_truth from ultrack.core.export import ( to_ctc, to_tracks_layer, @@ -169,3 +170,8 @@ def add_links(self, **kwargs) -> None: self._assert_segmented("add_links") add_links(config=self.config, **kwargs) self.status |= TrackerStatus.LINKED + + @functools.wraps(match_to_ground_truth) + def match_to_ground_truth(self, **kwargs) -> None: + self._assert_segmented("match_to_ground_truth") + match_to_ground_truth(config=self.config, **kwargs) From 9c78385f3f62850f678fb477c8cc16450c0a5cb8 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 4 Sep 2024 11:35:52 -0400 Subject: [PATCH 36/45] bad duplicated link code --- ultrack/core/autotune.py | 103 ++++++++++++++++++++++++++++++++++----- 1 file changed, 91 insertions(+), 12 deletions(-) diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py index 56ed230..b8d33fa 100644 --- a/ultrack/core/autotune.py +++ b/ultrack/core/autotune.py @@ -1,7 +1,6 @@ import logging -import pickle from contextlib import nullcontext -from typing import Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple import fasteners import mip @@ -10,6 +9,7 @@ import pandas as pd import sqlalchemy as sqla from numpy.typing import ArrayLike +from scipy.spatial import KDTree from skimage.measure import regionprops from sqlalchemy.orm import Session from toolz import curry @@ -24,12 +24,84 @@ def _link_gt( time: int, + gt_nodes: List[Node], config: MainConfig, - db_path: str, scale: Optional[ArrayLike], write_lock: Optional[fasteners.InterProcessLock], -) -> None: - pass +) -> pd.DataFrame: + + if len(gt_nodes) == 0: + LOG.warn(f"No ground-truth nodes found at {time}") + return + + db_path = config.data_config.database_path + + engine = sqla.create_engine(db_path) + with Session(engine) as session: + h_nodes = [n for n, in session.query(NodeDB.id).where(NodeDB.t == time)] + + h_nodes_pos = np.asarray([n.centroids for n in h_nodes]) + gt_pos = np.asarray([n.centroids for n in gt_nodes]) + + n_dim = h_nodes_pos.shape[-1] + + if scale is not None: + min_n_dim = min(n_dim, len(scale)) + scale = scale[-min_n_dim:] + h_nodes_pos = h_nodes_pos[..., -min_n_dim:] * scale + gt_pos = gt_pos[..., -min_n_dim:] * scale + + # finds neighbors nodes within the radius + # and connect the pairs with highest edge weight + current_kdtree = KDTree(h_nodes_pos) + + distances, neighbors = current_kdtree.query( + gt_pos, + # twice as expected because we select the nearest with highest edge weight + k=2 * config.linking_config.max_neighbors, + distance_upper_bound=config.linking_config.max_distance, + ) + + gt_links = [] + + for i, node in enumerate(gt_nodes): + valid = ~np.isinf(distances[i]) + valid_neighbors = neighbors[i, valid] + neigh_distances = distances[i, valid] + + neighborhood = [] + for neigh_idx, neigh_dist in zip(valid_neighbors, neigh_distances): + neigh = h_nodes[neigh_idx] + edge_weight = node.IoU(neigh) + # using dist as a tie-breaker + neighborhood.append( + (edge_weight, -neigh_dist, neigh.id, node.id) + ) # current, next + + neighborhood = sorted(neighborhood, 
reverse=True)[: config.max_neighbors] + LOG.info("Node %s links %s", node.id, neighborhood) + gt_links += neighborhood + + if len(gt_links) == 0: + raise ValueError( + f"No links found for time {time}. Increase `linking_config.max_distance` parameter." + ) + + gt_links = np.asarray(gt_links)[:, [0, 2, 3]] + df = pd.DataFrame(gt_links, columns=["weight", "source_id", "target_id"]) + + with write_lock if write_lock is not None else nullcontext(): + LOG.info(f"Pushing gt links from time {time} to {db_path}") + engine = sqla.create_engine( + db_path, + hide_parameters=True, + ) + with engine.begin() as conn: + df.to_sql( + name=GTLinkDB.__tablename__, con=conn, if_exists="append", index=False + ) + + return df @curry @@ -48,6 +120,7 @@ def _match_ground_truth_frame( if len(gt_props) == 0: LOG.warning(f"No objects found in time point {time}") + gt_db_rows = [] gt_nodes = [] # adding ground-truth nodes for obj in gt_props: @@ -55,19 +128,20 @@ def _match_ground_truth_frame( node_id=obj.label, time=time, mask=obj.image, bbox=obj.bbox ) - gt_nodes.append( + gt_db_rows.append( GTNodeDB( t=time, label=obj.label, - pickle=pickle.dumps(node), + pickle=node, ) ) + gt_nodes.append(node) with write_lock if write_lock is not None else nullcontext(): engine = sqla.create_engine(config.data_config.database_path) with Session(engine) as session: - session.add_all(gt_nodes) + session.add_all(gt_db_rows) session.commit() engine.dispose() @@ -78,8 +152,8 @@ def _match_ground_truth_frame( gt_matcher = SQLGTMatcher(config, write_lock=write_lock) total_score = gt_matcher() - if len(gt_nodes) > 0: - mean_score = total_score / len(gt_nodes) + if len(gt_db_rows) > 0: + mean_score = total_score / len(gt_db_rows) else: mean_score = 0.0 @@ -172,11 +246,13 @@ def __init__( config: MainConfig, solver: Literal["CBC", "GUROBI", ""] = "", write_lock: Optional[fasteners.InterProcessLock] = None, + eps=1e-3, ) -> None: # TODO self._data_config = config.data_config self._write_lock = write_lock + self._eps = eps try: self._model = mip.Model(sense=mip.MAXIMIZE, solver_name=solver) @@ -237,9 +313,10 @@ def _add_edges(self) -> None: name="edges", var_type=mip.BINARY, ) + # small value to prefer not selecting edges than bad ones # setting objective function self._model.objective = mip.xsum( - self._edges_df["weight"].to_numpy() * self._edges + (self._edges_df["weight"].to_numpy() - self._eps) * self._edges ) # source_id is time point T (hierarchies id) @@ -287,4 +364,6 @@ def __call__(self) -> float: self._model.optimize() self.add_solution() - return self._model.objective_value + n_selected_vars = sum(e_var.x > 0.5 for e_var in self._edges) + + return self._model.objective_value + n_selected_vars * self._eps From f8109ea52fa7c873ded897c9f7d1cb32af5f1664 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 4 Sep 2024 15:26:39 -0400 Subject: [PATCH 37/45] WIP working version of ground-truth matching, missing documentation --- ultrack/core/autotune.py | 369 ----------------------------- ultrack/core/database.py | 2 +- ultrack/core/gt_matching.py | 211 +++++++++++++++++ ultrack/core/linking/processing.py | 64 +++-- ultrack/core/segmentation/node.py | 2 + ultrack/core/solve/sqlgtmatcher.py | 156 ++++++++++++ ultrack/core/tracker.py | 6 +- 7 files changed, 419 insertions(+), 391 deletions(-) delete mode 100644 ultrack/core/autotune.py create mode 100644 ultrack/core/gt_matching.py create mode 100644 ultrack/core/solve/sqlgtmatcher.py diff --git a/ultrack/core/autotune.py b/ultrack/core/autotune.py deleted file mode 100644 index 
b8d33fa..0000000 --- a/ultrack/core/autotune.py +++ /dev/null @@ -1,369 +0,0 @@ -import logging -from contextlib import nullcontext -from typing import List, Literal, Optional, Tuple - -import fasteners -import mip -import mip.exceptions -import numpy as np -import pandas as pd -import sqlalchemy as sqla -from numpy.typing import ArrayLike -from scipy.spatial import KDTree -from skimage.measure import regionprops -from sqlalchemy.orm import Session -from toolz import curry - -from ultrack.config.config import MainConfig -from ultrack.core.database import GTLinkDB, GTNodeDB, NodeDB, OverlapDB -from ultrack.core.segmentation.node import Node -from ultrack.utils.multiprocessing import multiprocessing_apply - -LOG = logging.getLogger(__name__) - - -def _link_gt( - time: int, - gt_nodes: List[Node], - config: MainConfig, - scale: Optional[ArrayLike], - write_lock: Optional[fasteners.InterProcessLock], -) -> pd.DataFrame: - - if len(gt_nodes) == 0: - LOG.warn(f"No ground-truth nodes found at {time}") - return - - db_path = config.data_config.database_path - - engine = sqla.create_engine(db_path) - with Session(engine) as session: - h_nodes = [n for n, in session.query(NodeDB.id).where(NodeDB.t == time)] - - h_nodes_pos = np.asarray([n.centroids for n in h_nodes]) - gt_pos = np.asarray([n.centroids for n in gt_nodes]) - - n_dim = h_nodes_pos.shape[-1] - - if scale is not None: - min_n_dim = min(n_dim, len(scale)) - scale = scale[-min_n_dim:] - h_nodes_pos = h_nodes_pos[..., -min_n_dim:] * scale - gt_pos = gt_pos[..., -min_n_dim:] * scale - - # finds neighbors nodes within the radius - # and connect the pairs with highest edge weight - current_kdtree = KDTree(h_nodes_pos) - - distances, neighbors = current_kdtree.query( - gt_pos, - # twice as expected because we select the nearest with highest edge weight - k=2 * config.linking_config.max_neighbors, - distance_upper_bound=config.linking_config.max_distance, - ) - - gt_links = [] - - for i, node in enumerate(gt_nodes): - valid = ~np.isinf(distances[i]) - valid_neighbors = neighbors[i, valid] - neigh_distances = distances[i, valid] - - neighborhood = [] - for neigh_idx, neigh_dist in zip(valid_neighbors, neigh_distances): - neigh = h_nodes[neigh_idx] - edge_weight = node.IoU(neigh) - # using dist as a tie-breaker - neighborhood.append( - (edge_weight, -neigh_dist, neigh.id, node.id) - ) # current, next - - neighborhood = sorted(neighborhood, reverse=True)[: config.max_neighbors] - LOG.info("Node %s links %s", node.id, neighborhood) - gt_links += neighborhood - - if len(gt_links) == 0: - raise ValueError( - f"No links found for time {time}. Increase `linking_config.max_distance` parameter." 
- ) - - gt_links = np.asarray(gt_links)[:, [0, 2, 3]] - df = pd.DataFrame(gt_links, columns=["weight", "source_id", "target_id"]) - - with write_lock if write_lock is not None else nullcontext(): - LOG.info(f"Pushing gt links from time {time} to {db_path}") - engine = sqla.create_engine( - db_path, - hide_parameters=True, - ) - with engine.begin() as conn: - df.to_sql( - name=GTLinkDB.__tablename__, con=conn, if_exists="append", index=False - ) - - return df - - -@curry -def _match_ground_truth_frame( - time: int, - gt_labels: ArrayLike, - config: MainConfig, - scale: Optional[ArrayLike], - write_lock: Optional[fasteners.InterProcessLock], -) -> Tuple[pd.DataFrame, pd.DataFrame]: - # TODO - - gt_labels = np.asarray(gt_labels[time]) - gt_props = regionprops(gt_labels) - - if len(gt_props) == 0: - LOG.warning(f"No objects found in time point {time}") - - gt_db_rows = [] - gt_nodes = [] - # adding ground-truth nodes - for obj in gt_props: - node = Node.from_mask( - node_id=obj.label, time=time, mask=obj.image, bbox=obj.bbox - ) - - gt_db_rows.append( - GTNodeDB( - t=time, - label=obj.label, - pickle=node, - ) - ) - gt_nodes.append(node) - - with write_lock if write_lock is not None else nullcontext(): - engine = sqla.create_engine(config.data_config.database_path) - - with Session(engine) as session: - session.add_all(gt_db_rows) - session.commit() - - engine.dispose() - - _link_gt(config, scale=scale, overwrite=False) - - # computing GT matching - gt_matcher = SQLGTMatcher(config, write_lock=write_lock) - total_score = gt_matcher() - - if len(gt_db_rows) > 0: - mean_score = total_score / len(gt_db_rows) - else: - mean_score = 0.0 - - LOG.info(f"time {time} total score: {total_score:0.4f}") - LOG.info(f"time {time} mean score: {mean_score:0.4f}") - - -def _get_matched_nodes_df(database_path: str) -> pd.DataFrame: - # TODO - engine = sqla.create_engine(database_path) - - with Session(engine) as session: - node_query = session.query( - NodeDB.id, - NodeDB.hier_parent_id, - NodeDB.t_hier_id, - NodeDB.area, - NodeDB.frontier, - ).where(NodeDB.t == 0) - node_df = pd.read_sql(node_query.statement, session.bind, index_col="id") - - gt_query = session.query(GTLinkDB.source_id, GTLinkDB.target_id).where( - GTLinkDB.selected - ) - gt_df = pd.read_sql(gt_query.statement, session.bind, index="source_id") - - node_df = node_df.join(gt_df) - - frontiers = node_df["frontier"] - node_df["parent_frontier"] = node_df["hier_parent_id"].map( - lambda x: frontiers.get(x, -1.0) - ) - node_df.loc[node_df["parent_frontier"] < 0, "parent_frontier"] = node_df[ - "frontier" - ].max() - - return node_df - - -def match_to_ground_truth( - config: MainConfig, - gt_labels: ArrayLike, - scale: Optional[ArrayLike] = None, -) -> pd.DataFrame: - # TODO - - multiprocessing_apply( - _match_ground_truth_frame( - gt_labels=gt_labels, - config=config, - scale=scale, - ), - range(gt_labels.shape[0]), - n_workers=config.segmentation_config.n_workers, - desc="Matching hierarchy nodes with ground-truth", - ) - - # if scale is not None: - # cols = ["z", "y", "x"][-len(scale) :] - # gt_df[cols] *= scale - - # if "z" not in gt_df.columns: - # gt_df["z"] = 0.0 - - # if len(gt_df) > 0: - # max_distance = estimate_drift(gt_df) - # if not np.isnan(max_distance) or max_distance > 0: - # config.linking_config.max_distance = max_distance + 1.0 - - # if "solution" in df.columns: - # matched_df = df[df["solution"] > 0.5] - # config.segmentation_config.min_area = matched_df["area"].min() * 0.95 - - # config.segmentation_config.max_area = 
matched_df["area"].max() * 1.025 - - # config.segmentation_config.min_frontier = max( - # matched_df["parent_frontier"].min() - 0.025, 0.0 - # ) - # else: - # LOG.warning("No nodes were matched. Keeping previous configuration.") - - # config.data_config.database = prev_db - - # return config, df - - -class SQLGTMatcher: - def __init__( - self, - config: MainConfig, - solver: Literal["CBC", "GUROBI", ""] = "", - write_lock: Optional[fasteners.InterProcessLock] = None, - eps=1e-3, - ) -> None: - # TODO - - self._data_config = config.data_config - self._write_lock = write_lock - self._eps = eps - - try: - self._model = mip.Model(sense=mip.MAXIMIZE, solver_name=solver) - except mip.exceptions.InterfacingError as e: - LOG.warning(e) - self._model = mip.Model(sense=mip.MAXIMIZE, solver_name="CBC") - - def _add_nodes(self) -> None: - # TODO - engine = sqla.create_engine(self._data_config.database_path) - - # t = 0 is hierarchies - # t = 1 is ground-truth nodes - with Session(engine) as session: - query = session.query(NodeDB.id, NodeDB.t).where(NodeDB.t == 0) - self._nodes_df = pd.read_sql(query.statement, session.bind, index_col="id") - - size = len(self._nodes_df) - self._nodes = self._model.add_var_tensor( - (size,), name="nodes", var_type=mip.BINARY - ) - - # hierarchy overlap constraints - with Session(engine) as session: - query = session.query(OverlapDB).join( - NodeDB, NodeDB.id == OverlapDB.node_id - ) - overlap_df = pd.read_sql(query.statement, session.bind) - - overlap_df["node_id"] = self._nodes_df.index.get_indexer(overlap_df["node_id"]) - overlap_df["ancestor_id"] = self._nodes_df.index.get_indexer( - overlap_df["ancestor_id"] - ) - - for node_id, anc_id in zip(overlap_df["node_id"], overlap_df["ancestor_id"]): - self._model.add_constr(self._nodes[node_id] + self._nodes[anc_id] <= 1) - - def _add_edges(self) -> None: - # TODO - - if not hasattr(self, "_nodes"): - raise ValueError("Nodes must be added before adding edges.") - - engine = sqla.create_engine(self._data_config.database_path) - with Session(engine) as session: - query = session.query(GTLinkDB).join( - NodeDB, NodeDB.id == GTLinkDB.source_id - ) - self._edges_df = pd.read_sql(query.statement, session.bind) - - self._edges_df["source_id"] = self._nodes_df.index.get_indexer( - self._edges_df["source_id"] - ) - self._edges_df.reset_index(drop=True, inplace=True) - - self._edges = self._model.add_var_tensor( - (len(self._edges_df),), - name="edges", - var_type=mip.BINARY, - ) - # small value to prefer not selecting edges than bad ones - # setting objective function - self._model.objective = mip.xsum( - (self._edges_df["weight"].to_numpy() - self._eps) * self._edges - ) - - # source_id is time point T (hierarchies id) - # target_id is time point T+1 (ground-truth) - for source_id, group in self._edges_df.groupby("source_id", as_index=False): - self._model.add_constr( - self._nodes[source_id] == mip.xsum(self._edges[group.index.to_numpy()]) - ) - - for _, group in self._edges_df.groupby("target_id", as_index=False): - self._model.add_constr(mip.xsum(self._edges[group.index.to_numpy()]) <= 1) - - def add_solution(self) -> None: - # TODO - engine = sqla.create_engine(self._data_config.database_path) - - edges_records = [] - for idx, e_var in zip(self._edges_df.index, self._edges): - if e_var.x > 0.5: - edges_records.append( - { - "id": idx, - "selected": e_var.x > 0.5, - } - ) - - with self._write_lock if self._write_lock is not None else nullcontext(): - with Session(engine) as session: - stmt = ( - sqla.update(GTLinkDB) - 
.where(GTLinkDB.id == sqla.bindparam("id")) - .values(selected=sqla.bindparam("selected")) - ) - session.connection().execute( - stmt, - edges_records, - execution_options={"synchronize_session": False}, - ) - session.commit() - - def __call__(self) -> float: - # TODO - self._add_nodes() - self._add_edges() - self._model.optimize() - self.add_solution() - - n_selected_vars = sum(e_var.x > 0.5 for e_var in self._edges) - - return self._model.objective_value + n_selected_vars * self._eps diff --git a/ultrack/core/database.py b/ultrack/core/database.py index 3af2b3c..d550827 100644 --- a/ultrack/core/database.py +++ b/ultrack/core/database.py @@ -120,8 +120,8 @@ class LinkDB(Base): class GTNodeDB(Base): __tablename__ = "gt_nodes" - t = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True, autoincrement=True) + t = Column(Integer) label = Column(Integer) pickle = Column(MaybePickleType) z = Column(Float) diff --git a/ultrack/core/gt_matching.py b/ultrack/core/gt_matching.py new file mode 100644 index 0000000..7bcecc9 --- /dev/null +++ b/ultrack/core/gt_matching.py @@ -0,0 +1,211 @@ +import logging +from contextlib import nullcontext +from typing import Optional, Tuple + +import fasteners +import numpy as np +import pandas as pd +import sqlalchemy as sqla +from numpy.typing import ArrayLike +from skimage.measure import regionprops +from sqlalchemy.orm import Session +from toolz import curry + +from ultrack.config.config import MainConfig +from ultrack.core.database import NO_PARENT, GTLinkDB, GTNodeDB, NodeDB +from ultrack.core.linking.processing import compute_spatial_neighbors +from ultrack.core.segmentation.node import Node +from ultrack.core.solve.sqlgtmatcher import SQLGTMatcher +from ultrack.utils.multiprocessing import ( + multiprocessing_apply, + multiprocessing_sqlite_lock, +) + +LOG = logging.getLogger(__name__) + + +@curry +def _match_ground_truth_frame( + time: int, + gt_labels: ArrayLike, + config: MainConfig, + scale: Optional[ArrayLike], + write_lock: Optional[fasteners.InterProcessLock], +) -> Tuple[pd.DataFrame, pd.DataFrame]: + # TODO + + gt_labels = np.asarray(gt_labels[time]) + gt_props = regionprops(gt_labels) + + if len(gt_props) == 0: + LOG.warning(f"No objects found in time point {time}") + return + + LOG.info(f"Found {len(gt_props)} objects in time point {time}") + + gt_db_rows = [] + gt_nodes = [] + # adding ground-truth nodes + for obj in gt_props: + node = Node.from_mask( + node_id=obj.label, time=time, mask=obj.image, bbox=obj.bbox + ) + + if len(node.centroid) == 2: + y, x = node.centroid + z = 0 + else: + z, y, x = node.centroid + + gt_db_rows.append( + GTNodeDB( + t=time, + label=obj.label, + pickle=node, + z=z, + y=y, + x=x, + ) + ) + gt_nodes.append(node) + + with write_lock if write_lock is not None else nullcontext(): + engine = sqla.create_engine(config.data_config.database_path) + + with Session(engine) as session: + session.add_all(gt_db_rows) + session.commit() + + source_nodes = [ + n for n, in session.query(NodeDB.pickle).where(NodeDB.t == time) + ] + + engine.dispose() + + compute_spatial_neighbors( + time, + config=config.linking_config, + source_nodes=source_nodes, + target_nodes=gt_nodes, + target_shift=np.zeros((len(gt_nodes), 3), dtype=np.float32), + table_name=GTLinkDB.__tablename__, + db_path=config.data_config.database_path, + scale=scale, + images=[], + write_lock=write_lock, + ) + + # computing GT matching + gt_matcher = SQLGTMatcher(config, write_lock=write_lock) + total_score = gt_matcher(time=time) + + if len(gt_db_rows) 
> 0: + mean_score = total_score / len(gt_db_rows) + else: + mean_score = 0.0 + + LOG.info(f"time {time} total score: {total_score:0.4f}") + LOG.info(f"time {time} mean score: {mean_score:0.4f}") + + +def _get_matched_nodes_df(database_path: str) -> pd.DataFrame: + # TODO + engine = sqla.create_engine(database_path) + + with Session(engine) as session: + node_query = session.query( + NodeDB.id, + NodeDB.hier_parent_id, + NodeDB.t_hier_id, + # NodeDB.area, + # NodeDB.frontier, + ) + node_df = pd.read_sql(node_query.statement, session.bind, index_col="id") + + gt_edge_query = ( + session.query( + GTLinkDB.source_id, + GTLinkDB.target_id, + GTNodeDB.z, + GTNodeDB.y, + GTNodeDB.x, + ) + .join(GTNodeDB, GTNodeDB.id == GTLinkDB.target_id) + .where(GTLinkDB.selected) + ) + gt_df = pd.read_sql( + gt_edge_query.statement, session.bind, index_col="source_id" + ) + gt_df.rename( + # columns={"target_id": "gt_id"}, # , "z": "gt_z", "y": "gt_y", "x": "gt_x"}, + columns={"target_id": "gt_id", "z": "gt_z", "y": "gt_y", "x": "gt_x"}, + inplace=True, + ) + + LOG.info(f"Found {len(node_df)} nodes and {len(gt_df)} ground-truth links") + + node_df = node_df.join(gt_df) + node_df["gt_id"] = node_df["gt_id"].fillna(NO_PARENT).astype(int) + + # frontiers = node_df["frontier"] + # node_df["parent_frontier"] = node_df["hier_parent_id"].map( + # lambda x: frontiers.get(x, -1.0) + # ) + # node_df.loc[node_df["parent_frontier"] < 0, "parent_frontier"] = node_df[ + # "frontier" + # ].max() + + return node_df + + +def match_to_ground_truth( + config: MainConfig, + gt_labels: ArrayLike, + scale: Optional[ArrayLike] = None, +) -> pd.DataFrame: + # TODO + + with multiprocessing_sqlite_lock(config.data_config) as lock: + multiprocessing_apply( + _match_ground_truth_frame( + gt_labels=gt_labels, + config=config, + scale=scale, + write_lock=lock, + ), + range(gt_labels.shape[0]), + n_workers=config.segmentation_config.n_workers, + desc="Matching hierarchy nodes with ground-truth", + ) + + df_nodes = _get_matched_nodes_df(config.data_config.database_path) + + return df_nodes + + # if scale is not None: + # cols = ["z", "y", "x"][-len(scale) :] + # gt_df[cols] *= scale + + # if "z" not in gt_df.columns: + # gt_df["z"] = 0.0 + + # if len(gt_df) > 0: + # max_distance = estimate_drift(gt_df) + # if not np.isnan(max_distance) or max_distance > 0: + # config.linking_config.max_distance = max_distance + 1.0 + + # if "solution" in df.columns: + # matched_df = df[df["solution"] > 0.5] + # config.segmentation_config.min_area = matched_df["area"].min() * 0.95 + + # config.segmentation_config.max_area = matched_df["area"].max() * 1.025 + + # config.segmentation_config.min_frontier = max( + # matched_df["parent_frontier"].min() - 0.025, 0.0 + # ) + # else: + # LOG.warning("No nodes were matched. 
Keeping previous configuration.") + + # config.data_config.database = prev_db + + # return config, df diff --git a/ultrack/core/linking/processing.py b/ultrack/core/linking/processing.py index f461133..1941d8e 100644 --- a/ultrack/core/linking/processing.py +++ b/ultrack/core/linking/processing.py @@ -146,25 +146,52 @@ def _process( next_nodes = [row[0] for row in query] next_shift = np.asarray([row[1:] for row in query]) - current_pos = np.asarray([n.centroid for n in current_nodes]) - next_pos = np.asarray([n.centroid for n in next_nodes], dtype=np.float32) + compute_spatial_neighbors( + time, + config, + current_nodes, + next_nodes, + next_shift, + scale=scale, + table_name=LinkDB.__tablename__, + db_path=db_path, + images=images, + write_lock=write_lock, + ) - n_dim = next_pos.shape[1] - next_shift = next_shift[:, -n_dim:] # matching positions dimensions - next_pos += next_shift + +def compute_spatial_neighbors( + time: int, + config: LinkingConfig, + source_nodes: List[Node], + target_nodes: List[Node], + target_shift: ArrayLike, + scale: Optional[Sequence[float]], + table_name: str, + db_path: str, + images: Sequence[ArrayLike], + write_lock: Optional[fasteners.InterProcessLock] = None, +) -> pd.DataFrame: + + source_pos = np.asarray([n.centroid for n in source_nodes]) + target_pos = np.asarray([n.centroid for n in target_nodes], dtype=np.float32) + + n_dim = target_pos.shape[1] + target_shift = target_shift[:, -n_dim:] # matching positions dimensions + target_pos += target_shift if scale is not None: min_n_dim = min(n_dim, len(scale)) scale = scale[-min_n_dim:] - current_pos = current_pos[..., -min_n_dim:] * scale - next_pos = next_pos[..., -min_n_dim:] * scale + source_pos = source_pos[..., -min_n_dim:] * scale + target_pos = target_pos[..., -min_n_dim:] * scale # finds neighbors nodes within the radius # and connect the pairs with highest edge weight - current_kdtree = KDTree(current_pos) + current_kdtree = KDTree(source_pos) distances, neighbors = current_kdtree.query( - next_pos, + target_pos, # twice as expected because we select the nearest with highest edge weight k=2 * config.max_neighbors, distance_upper_bound=config.max_distance, @@ -173,8 +200,8 @@ def _process( if len(images) > 0: filtered_by_color = color_filtering_mask( time, - current_nodes, - next_nodes, + source_nodes, + target_nodes, images, neighbors, config.z_score_threshold, @@ -182,23 +209,23 @@ def _process( else: filtered_by_color = np.ones_like(neighbors, dtype=bool) - int_next_shift = np.round(next_shift).astype(int) + int_next_shift = np.round(target_shift).astype(int) # NOTE: moving bbox with shift, MUST be after `feature computation` - for node, shift in zip(next_nodes, int_next_shift): + for node, shift in zip(target_nodes, int_next_shift): node.bbox[:n_dim] += shift node.bbox[-n_dim:] += shift distance_w = config.distance_weight links = [] - for i, node in enumerate(next_nodes): + for i, node in enumerate(target_nodes): valid = (~np.isinf(distances[i])) & filtered_by_color[i] valid_neighbors = neighbors[i, valid] neigh_distances = distances[i, valid] neighborhood = [] for neigh_idx, neigh_dist in zip(valid_neighbors, neigh_distances): - neigh = current_nodes[neigh_idx] + neigh = source_nodes[neigh_idx] edge_weight = node.IoU(neigh) - distance_w * neigh_dist # using dist as a tie-breaker neighborhood.append( @@ -219,13 +246,14 @@ def _process( with write_lock if write_lock is not None else nullcontext(): LOG.info(f"Pushing links from time {time} to {db_path}") + connect_args = {"timeout": 45} if 
write_lock is not None else {} engine = sqla.create_engine( db_path, hide_parameters=True, connect_args=connect_args ) with engine.begin() as conn: - df.to_sql( - name=LinkDB.__tablename__, con=conn, if_exists="append", index=False - ) + df.to_sql(name=table_name, con=conn, if_exists="append", index=False) + + return df def link( diff --git a/ultrack/core/segmentation/node.py b/ultrack/core/segmentation/node.py index 0aaac3d..dec7e74 100644 --- a/ultrack/core/segmentation/node.py +++ b/ultrack/core/segmentation/node.py @@ -312,6 +312,8 @@ def from_mask( bbox = ndi.find_objects(mask)[0] mask = mask[bbox] + bbox = np.asarray(bbox) + if mask.ndim * 2 != len(bbox): raise ValueError( f"Bounding box {bbox} does not match 2x mask ndim {mask.ndim}" diff --git a/ultrack/core/solve/sqlgtmatcher.py b/ultrack/core/solve/sqlgtmatcher.py new file mode 100644 index 0000000..e6522d0 --- /dev/null +++ b/ultrack/core/solve/sqlgtmatcher.py @@ -0,0 +1,156 @@ +import logging +from contextlib import nullcontext +from typing import Literal, Optional + +import fasteners +import mip +import mip.exceptions +import pandas as pd +import sqlalchemy as sqla +from sqlalchemy.orm import Session + +from ultrack.config.config import MainConfig +from ultrack.core.database import GTLinkDB, NodeDB, OverlapDB + +LOG = logging.getLogger(__name__) + + +class SQLGTMatcher: + def __init__( + self, + config: MainConfig, + solver: Literal["CBC", "GUROBI", ""] = "", + write_lock: Optional[fasteners.InterProcessLock] = None, + eps: float = 1e-3, + ) -> None: + # TODO + + self._data_config = config.data_config + self._write_lock = write_lock + self._eps = eps + + try: + self._model = mip.Model(sense=mip.MAXIMIZE, solver_name=solver) + except mip.exceptions.InterfacingError as e: + LOG.warning(e) + self._model = mip.Model(sense=mip.MAXIMIZE, solver_name="CBC") + + def _add_nodes(self, time: int) -> None: + # TODO + engine = sqla.create_engine(self._data_config.database_path) + + # t = 0 is hierarchies + # t = 1 is ground-truth nodes + with Session(engine) as session: + query = session.query(NodeDB.id, NodeDB.t).where(NodeDB.t == time) + self._nodes_df = pd.read_sql(query.statement, session.bind, index_col="id") + + size = len(self._nodes_df) + self._nodes = self._model.add_var_tensor( + (size,), name="nodes", var_type=mip.BINARY + ) + + # hierarchy overlap constraints + with Session(engine) as session: + query = session.query(OverlapDB).join( + NodeDB, NodeDB.id == OverlapDB.node_id + ) + overlap_df = pd.read_sql(query.statement, session.bind) + + overlap_df["node_id"] = self._nodes_df.index.get_indexer(overlap_df["node_id"]) + overlap_df["ancestor_id"] = self._nodes_df.index.get_indexer( + overlap_df["ancestor_id"] + ) + + for node_id, anc_id in zip(overlap_df["node_id"], overlap_df["ancestor_id"]): + self._model.add_constr(self._nodes[node_id] + self._nodes[anc_id] <= 1) + + def _add_edges(self, time: int) -> None: + # TODO + + if not hasattr(self, "_nodes"): + raise ValueError("Nodes must be added before adding edges.") + + engine = sqla.create_engine(self._data_config.database_path) + + with Session(engine) as session: + + query = ( + session.query(GTLinkDB) + .join(NodeDB, NodeDB.id == GTLinkDB.source_id) + .where(NodeDB.t == time) + ) + + self._edges_df = pd.read_sql(query.statement, session.bind) + + self._edges_df["source_id"] = self._nodes_df.index.get_indexer( + self._edges_df["source_id"] + ) + self._edges_df.reset_index(drop=True, inplace=True) + + self._edges = self._model.add_var_tensor( + (len(self._edges_df),), + 
name="edges", + var_type=mip.BINARY, + ) + # small value to prefer not selecting edges than bad ones + # setting objective function + self._model.objective = mip.xsum( + (self._edges_df["weight"].to_numpy() - self._eps) * self._edges + ) + + # source_id is time point T (hierarchies id) + # target_id is time point T+1 (ground-truth) + for source_id, group in self._edges_df.groupby("source_id", as_index=False): + print("SOURCE", group) + self._model.add_constr( + self._nodes[source_id] == mip.xsum(self._edges[group.index.to_numpy()]) + ) + + for _, group in self._edges_df.groupby("target_id", as_index=False): + print("TARGET", group) + self._model.add_constr(mip.xsum(self._edges[group.index.to_numpy()]) <= 1) + + def add_solution(self) -> None: + # TODO + engine = sqla.create_engine(self._data_config.database_path) + + edges_records = [] + for idx, e_var in zip(self._edges_df["id"], self._edges): + if e_var.x > 0.5: + edges_records.append( + { + "link_id": idx, + "selected": e_var.x > 0.5, + } + ) + + LOG.info(f"Selected {len(edges_records)} edges to ground-truth") + + with self._write_lock if self._write_lock is not None else nullcontext(): + with Session(engine) as session: + stmt = ( + sqla.update(GTLinkDB) + .where(GTLinkDB.id == sqla.bindparam("link_id")) + .values(selected=sqla.bindparam("selected")) + ) + session.connection().execute( + stmt, + edges_records, + execution_options={"synchronize_session": False}, + ) + session.commit() + + def __call__(self, time: int) -> float: + # TODO + + LOG.info(f"Computing GT matching for time {time}") + + self._add_nodes(time) + self._add_edges(time) + self._model.optimize() + self.add_solution() + + n_selected_vars = sum(e_var.x > 0.5 for e_var in self._edges) + + return self._model.objective_value + n_selected_vars * self._eps diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py index d1c6a77..b83eee1 100644 --- a/ultrack/core/tracker.py +++ b/ultrack/core/tracker.py @@ -9,7 +9,6 @@ from ultrack import export_tracks_by_extension from ultrack.config import MainConfig -from ultrack.core.autotune import match_to_ground_truth from ultrack.core.export import ( to_ctc, to_tracks_layer, @@ -17,6 +16,7 @@ tracks_layer_to_trackmate, tracks_to_zarr, ) +from ultrack.core.gt_matching import match_to_ground_truth from ultrack.core.linking.processing import add_links, link from ultrack.core.main import track from ultrack.core.segmentation.processing import get_nodes_features, segment @@ -172,6 +172,6 @@ def add_links(self, **kwargs) -> None: self.status |= TrackerStatus.LINKED @functools.wraps(match_to_ground_truth) - def match_to_ground_truth(self, **kwargs) -> None: + def match_to_ground_truth(self, **kwargs) -> pd.DataFrame: self._assert_segmented("match_to_ground_truth") - match_to_ground_truth(config=self.config, **kwargs) + return match_to_ground_truth(config=self.config, **kwargs) From 94f3dc5bbba80bc1fa7167d424f4c3bec889945f Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 4 Sep 2024 15:41:26 -0400 Subject: [PATCH 38/45] bug fixing --- ultrack/core/gt_matching.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ultrack/core/gt_matching.py b/ultrack/core/gt_matching.py index 7bcecc9..b3e2ccf 100644 --- a/ultrack/core/gt_matching.py +++ b/ultrack/core/gt_matching.py @@ -126,26 +126,26 @@ def _get_matched_nodes_df(database_path: str) -> pd.DataFrame: session.query( GTLinkDB.source_id, GTLinkDB.target_id, - GTNodeDB.z, - GTNodeDB.y, - GTNodeDB.x, - ) - .join(GTNodeDB, GTNodeDB.id == 
GTLinkDB.target_id) - .where(GTLinkDB.selected) + # GTNodeDB.z, + # GTNodeDB.y, + # GTNodeDB.x, + ).where(GTLinkDB.selected) + # .join(GTNodeDB, GTNodeDB.id == GTLinkDB.target_id) ) gt_df = pd.read_sql( gt_edge_query.statement, session.bind, index_col="source_id" ) gt_df.rename( - # columns={"target_id": "gt_id"}, # , "z": "gt_z", "y": "gt_y", "x": "gt_x"}, - columns={"target_id": "gt_id", "z": "gt_z", "y": "gt_y", "x": "gt_x"}, + columns={ + "target_id": "gt_track_id" + }, # , "z": "gt_z", "y": "gt_y", "x": "gt_x"}, inplace=True, ) LOG.info(f"Found {len(node_df)} nodes and {len(gt_df)} ground-truth links") node_df = node_df.join(gt_df) - node_df["gt_id"] = node_df["gt_id"].fillna(NO_PARENT).astype(int) + node_df["gt_track_id"] = node_df["gt_track_id"].fillna(NO_PARENT).astype(int) # frontiers = node_df["frontier"] # node_df["parent_frontier"] = node_df["hier_parent_id"].map( From 76bdcebe69c9dc8b91f7fc5186857eddfde0ca81 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 2 Oct 2024 10:17:26 +0200 Subject: [PATCH 39/45] addings docs gt matching --- ultrack/core/gt_matching.py | 59 ++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/ultrack/core/gt_matching.py b/ultrack/core/gt_matching.py index b3e2ccf..ab3bfcd 100644 --- a/ultrack/core/gt_matching.py +++ b/ultrack/core/gt_matching.py @@ -1,6 +1,6 @@ import logging from contextlib import nullcontext -from typing import Optional, Tuple +from typing import Optional import fasteners import numpy as np @@ -31,9 +31,24 @@ def _match_ground_truth_frame( config: MainConfig, scale: Optional[ArrayLike], write_lock: Optional[fasteners.InterProcessLock], -) -> Tuple[pd.DataFrame, pd.DataFrame]: - # TODO - +) -> None: + """ + Matches candidate hypotheses to ground-truth labels for a given time point. + Segmentation hypotheses must be pre-computed. + + Parameters + ---------- + time : int + Time point to match. + gt_labels : ArrayLike + Ground-truth labels. + config : MainConfig + Configuration object. + scale : Optional[ArrayLike] + Scale of the data for distance computation. + write_lock : Optional[fasteners.InterProcessLock] + Lock for writing to the database. + """ gt_labels = np.asarray(gt_labels[time]) gt_props = regionprops(gt_labels) @@ -108,8 +123,20 @@ def _match_ground_truth_frame( LOG.info(f"time {time} mean score: {mean_score:0.4f}") -def _get_matched_nodes_df(database_path: str) -> pd.DataFrame: - # TODO +def _get_nodes_df_with_matches(database_path: str) -> pd.DataFrame: + """ + Gets nodes data frame with matched ground-truth labels. + + Parameters + ---------- + database_path : str + Path to the database file. + + Returns + ------- + pd.DataFrame + DataFrame with matched nodes. + """ engine = sqla.create_engine(database_path) with Session(engine) as session: @@ -163,7 +190,23 @@ def match_to_ground_truth( gt_labels: ArrayLike, scale: Optional[ArrayLike] = None, ) -> pd.DataFrame: - # TODO + """ + Matches nodes to ground-truth labels returning additional features for automatic parameter tuning. + + Parameters + ---------- + config : MainConfig + Configuration object. + gt_labels : ArrayLike + Ground-truth labels. + scale : Optional[ArrayLike], optional + Scale of the data for distance computation, by default None. + + Returns + ------- + pd.DataFrame + Data frame containing matched ground-truth labels to their respective nodes. 
+ """ with multiprocessing_sqlite_lock(config.data_config) as lock: multiprocessing_apply( @@ -178,7 +221,7 @@ def match_to_ground_truth( desc="Matching hierarchy nodes with ground-truth", ) - df_nodes = _get_matched_nodes_df(config.data_config.database_path) + df_nodes = _get_nodes_df_with_matches(config.data_config.database_path) return df_nodes From d213a7c2881dacafd26eb5458691a3c80747b8eb Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 3 Oct 2024 14:57:14 +0200 Subject: [PATCH 40/45] removed unnecessary prints --- ultrack/core/solve/sqlgtmatcher.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ultrack/core/solve/sqlgtmatcher.py b/ultrack/core/solve/sqlgtmatcher.py index e6522d0..5506636 100644 --- a/ultrack/core/solve/sqlgtmatcher.py +++ b/ultrack/core/solve/sqlgtmatcher.py @@ -39,8 +39,6 @@ def _add_nodes(self, time: int) -> None: # TODO engine = sqla.create_engine(self._data_config.database_path) - # t = 0 is hierarchies - # t = 1 is ground-truth nodes with Session(engine) as session: query = session.query(NodeDB.id, NodeDB.t).where(NodeDB.t == time) self._nodes_df = pd.read_sql(query.statement, session.bind, index_col="id") @@ -102,13 +100,11 @@ def _add_edges(self, time: int) -> None: # source_id is time point T (hierarchies id) # target_id is time point T+1 (ground-truth) for source_id, group in self._edges_df.groupby("source_id", as_index=False): - print("SOURCE", group) self._model.add_constr( self._nodes[source_id] == mip.xsum(self._edges[group.index.to_numpy()]) ) for _, group in self._edges_df.groupby("target_id", as_index=False): - print("TARGET", group) self._model.add_constr(mip.xsum(self._edges[group.index.to_numpy()]) <= 1) def add_solution(self) -> None: From f36836a3f42cdbf4cdb35e7eb7b46235b963a074 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 6 Oct 2024 16:54:45 -0700 Subject: [PATCH 41/45] added image wrapper for feature computation --- ultrack/core/segmentation/processing.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/ultrack/core/segmentation/processing.py b/ultrack/core/segmentation/processing.py index 8e0d38c..36c62fc 100644 --- a/ultrack/core/segmentation/processing.py +++ b/ultrack/core/segmentation/processing.py @@ -86,6 +86,23 @@ def _insert_db( overlaps.clear() +class _ImageCachedLazyLoader: + """ + Wrapper class to cache dask/zarr data loading for feature computation. 
+ """ + + def __init__(self, image: ArrayLike): + self._image = image + self._current_t = -1 + self._frame = None + + def __getitem__(self, index: int) -> np.ndarray: + if index != self._current_t: + self._frame = np.asarray(self._image[index]) + self._current_t = index + return self._frame + + def create_feats_callback( shape: ArrayLike, image: Optional[ArrayLike], properties: List[str] ) -> Callable[[Node], np.ndarray]: @@ -108,6 +125,9 @@ def create_feats_callback( """ mask = np.zeros(shape, dtype=bool) + if image is not None: + image = _ImageCachedLazyLoader(image) + def _feats_callback(node: Node) -> np.ndarray: node.paint_buffer(mask, True, include_time=False) @@ -115,7 +135,7 @@ def _feats_callback(node: Node) -> np.ndarray: if image is None: frame = None else: - frame = np.asarray(image[node.time]) + frame = image[node.time] obj = RegionProperties( node.slice, From b53de0dcb80776fb349166a3a6e5529e5a17f4ce Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 6 Oct 2024 17:43:32 -0700 Subject: [PATCH 42/45] fix inverted EDT processing for dask arrays / non-binary masks --- ultrack/imgproc/segmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ultrack/imgproc/segmentation.py b/ultrack/imgproc/segmentation.py index c0a8cc9..24a8f52 100644 --- a/ultrack/imgproc/segmentation.py +++ b/ultrack/imgproc/segmentation.py @@ -204,6 +204,7 @@ def inverted_edt( ArrayLike Inverted and normalized EDT. """ + mask = np.asarray(mask) if axis is None: dist = edt.edt(mask, anisotropy=voxel_size) else: @@ -216,7 +217,7 @@ def inverted_edt( ) dist = dist / dist.max() dist = 1.0 - dist - dist[~mask] = 1 + dist[mask == 0] = 1 return dist From 9b485ba115d267d44259e7e1ee95bd319637101e Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 6 Oct 2024 18:03:03 -0700 Subject: [PATCH 43/45] improved array apply, adding automatic output array creation --- ultrack/utils/array.py | 40 +++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index 8ad457f..e7432ac 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -3,7 +3,7 @@ import shutil import warnings from pathlib import Path -from typing import Callable, Literal, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union import numpy as np import zarr @@ -95,30 +95,46 @@ def check_array_chunk(array: ArrayLike) -> None: def array_apply( *in_arrays: ArrayLike, - out_array: ArrayLike, func: Callable, + out_array: Optional[ArrayLike] = None, axis: Union[Tuple[int], int] = 0, + out_zarr_kwargs: Optional[Dict[str, Any]] = {}, **kwargs, -) -> None: +) -> zarr.Array: """Apply a function over a given dimension of an array. Parameters ---------- in_arrays : ArrayLike Arrays to apply function to. - out_array : ArrayLike - Array to store result of function. func : function Function to apply over time. + out_array : ArrayLike, optional + Array to store result of function if not provided a new array is created, by default None. + See `create_zarr` for more information. axis : Union[Tuple[int], int], optional Axis of data to apply func, by default 0. args : tuple Positional arguments to pass to func. + out_zarr_kwargs : Dict[str, Any], optional + Keyword arguments to pass to `create_zarr`. + If `dtype` and `shape` are not provided, they are inferred from the first input array. **kwargs : Keyword arguments to pass to func. 
+ + Returns + ------- + zarr.Array + `out_array` or new array with result of function. """ name = func.__name__ if hasattr(func, "__name__") else type(func).__name__ + if out_array is None: + for param in ("shape", "dtype"): + if param not in out_zarr_kwargs: + out_zarr_kwargs[param] = getattr(in_arrays[0], param) + out_array = create_zarr(**out_zarr_kwargs) + try: in_shape = [arr.shape for arr in in_arrays] np.broadcast_shapes(out_array.shape, *in_shape) @@ -142,6 +158,8 @@ def array_apply( output_shape = out_array[indexing].shape out_array[indexing] = np.broadcast_to(func_result, output_shape) + return out_array + def create_zarr( shape: Tuple[int, ...], @@ -193,15 +211,3 @@ def create_zarr( chunks = large_chunk_size(shape, dtype=dtype) return zarr.zeros(shape, dtype=dtype, store=store, chunks=chunks, **kwargs) - - -def assert_same_length(**kwargs) -> None: - """Validates if key-word arguments have the same length.""" - for k1, v1 in kwargs.items(): - if v1 is None: - continue - for k2, v2 in kwargs.items(): - if v2 is not None and len(v2) != len(v1): - raise ValueError( - f"`{k1}` and `{k2}` must have the same length. Found {len(v1)} and {len(v2)}." - ) From 14696e52316a25ff9f02a4079ff1916b8e5fa319 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Sun, 6 Oct 2024 18:07:07 -0700 Subject: [PATCH 44/45] undoing wrong removal --- ultrack/utils/array.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index e7432ac..3a49fe1 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -211,3 +211,15 @@ def create_zarr( chunks = large_chunk_size(shape, dtype=dtype) return zarr.zeros(shape, dtype=dtype, store=store, chunks=chunks, **kwargs) + + +def assert_same_length(**kwargs) -> None: + """Validates if key-word arguments have the same length.""" + for k1, v1 in kwargs.items(): + if v1 is None: + continue + for k2, v2 in kwargs.items(): + if v2 is not None and len(v2) != len(v1): + raise ValueError( + f"`{k1}` and `{k2}` must have the same length. Found {len(v1)} and {len(v2)}." + ) From ce6c575231c15ad2067bb9d08e11c7a5ff660468 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 10 Oct 2024 15:00:35 -0700 Subject: [PATCH 45/45] enabling cell division into matching --- ultrack/core/gt_matching.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ultrack/core/gt_matching.py b/ultrack/core/gt_matching.py index ab3bfcd..f905255 100644 --- a/ultrack/core/gt_matching.py +++ b/ultrack/core/gt_matching.py @@ -1,6 +1,6 @@ import logging from contextlib import nullcontext -from typing import Optional +from typing import Dict, Optional import fasteners import numpy as np @@ -189,6 +189,7 @@ def match_to_ground_truth( config: MainConfig, gt_labels: ArrayLike, scale: Optional[ArrayLike] = None, + track_id_graph: Optional[Dict[int, int]] = None, ) -> pd.DataFrame: """ Matches nodes to ground-truth labels returning additional features for automatic parameter tuning. @@ -201,6 +202,8 @@ def match_to_ground_truth( Ground-truth labels. scale : Optional[ArrayLike], optional Scale of the data for distance computation, by default None. + track_id_graph : Optional[Dict[int, int]], optional + Ground-truth graph of track IDs, by default None. 
Returns ------- @@ -223,6 +226,13 @@ def match_to_ground_truth( df_nodes = _get_nodes_df_with_matches(config.data_config.database_path) + if track_id_graph is not None: + df_nodes["gt_parent_track_id"] = df_nodes["gt_track_id"].apply( + lambda x: track_id_graph.get(x, NO_PARENT) + ) + else: + df_nodes["gt_parent_track_id"] = NO_PARENT + return df_nodes # if scale is not None:
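
Illustrative usage sketch of the ground-truth matching API added in PATCH 39 and extended for cell divisions in PATCH 45. The file paths, the label array, and the `parent_graph` dictionary below are hypothetical; the `match_to_ground_truth` signature, the `gt_track_id` / `gt_parent_track_id` columns, and the `NO_PARENT` sentinel come from the patches themselves, and the segmentation database is assumed to be already populated.

    import numpy as np

    from ultrack.config.config import MainConfig, load_config
    from ultrack.core.gt_matching import match_to_ground_truth

    config: MainConfig = load_config("config.toml")  # hypothetical configuration path

    # hypothetical ground-truth segmentation, shape (t, y, x), integer labels
    gt_labels = np.load("gt_labels.npy")

    # hypothetical lineage: ground-truth track id -> parent track id
    # (daughter tracks 2 and 3 dividing from track 1)
    parent_graph = {2: 1, 3: 1}

    df_nodes = match_to_ground_truth(
        config=config,
        gt_labels=gt_labels,
        scale=None,  # optional voxel scaling used in the distance computation
        track_id_graph=parent_graph,
    )

    # each candidate node carries its matched ground-truth track id; unmatched
    # nodes, and tracks without a parent, hold the NO_PARENT sentinel
    print(df_nodes[["gt_track_id", "gt_parent_track_id"]].head())

An earlier patch in this series wires the same call into `Tracker.match_to_ground_truth(...)`, which forwards its keyword arguments and returns the same data frame.

A second minimal sketch, for the automatic output-array creation added to `array_apply` in PATCH 43; the input stack and the normalization function are hypothetical:

    import numpy as np

    from ultrack.utils.array import array_apply

    frames = np.random.rand(5, 64, 64).astype(np.float32)  # hypothetical (t, y, x) stack

    def normalize(frame: np.ndarray) -> np.ndarray:
        # per-frame min-max normalization, applied slice by slice along axis 0
        return (frame - frame.min()) / (frame.max() - frame.min() + 1e-8)

    # out_array is omitted, so a zarr array with the first input's shape and
    # dtype is created internally and returned
    normalized = array_apply(frames, func=normalize, axis=0)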