diff --git a/ultrack/cli/data_summary.py b/ultrack/cli/data_summary.py index 7d577ee..2661113 100644 --- a/ultrack/cli/data_summary.py +++ b/ultrack/cli/data_summary.py @@ -10,9 +10,10 @@ from ultrack.cli.utils import config_option from ultrack.config import MainConfig -from ultrack.core.database import NO_PARENT, LinkDB, NodeDB +from ultrack.core.database import LinkDB, NodeDB from ultrack.core.export.utils import solution_dataframe_from_sql from ultrack.tracks.graph import add_track_ids_to_tracks_df +from ultrack.utils.constants import NO_PARENT from ultrack.utils.printing import pretty_print_df diff --git a/ultrack/config/trackingconfig.py b/ultrack/config/trackingconfig.py index b4c7850..ef04269 100644 --- a/ultrack/config/trackingconfig.py +++ b/ultrack/config/trackingconfig.py @@ -49,7 +49,7 @@ class TrackingConfig(BaseModel): """``SPECIAL``: Solver method, `reference `_""" link_function: LinkFunctionChoices = "power" - """``SPECIAL``: Function used to transform the edge weights, `identity` or `power`""" + """``SPECIAL``: Function used to transform the edge and node weights, `identity` or `power`""" power: float = 4 r"""``SPECIAL``: Expoent :math:`\eta` of power transform, :math:`w_{pq}^\eta` """ diff --git a/ultrack/core/database.py b/ultrack/core/database.py index eb0a74f..d7a2422 100644 --- a/ultrack/core/database.py +++ b/ultrack/core/database.py @@ -23,9 +23,8 @@ from sqlalchemy.orm import Session, declarative_base from ultrack.config.dataconfig import DatabaseChoices, DataConfig - -# constant value to indicate it has no parent -NO_PARENT = -1 +from ultrack.utils.array import assert_same_length +from ultrack.utils.constants import NO_PARENT Base = declarative_base() @@ -93,6 +92,7 @@ class NodeDB(Base): area = Column(Integer) selected = Column(Boolean, default=False) pickle = Column(MaybePickleType) + node_prob = Column(Float, default=-1.0) segm_annot = Column(Enum(NodeSegmAnnotation), default=NodeSegmAnnotation.UNKNOWN) node_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN) appear_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN) @@ -182,6 +182,8 @@ def set_node_values( if hasattr(v, "tolist"): kwargs[k] = v.tolist() + assert_same_length(**kwargs) + records = [ {k: v[i] for k, v in kwargs.items()} for i in range(len(kwargs["node_id"])) ] diff --git a/ultrack/core/export/_test/test_ctc.py b/ultrack/core/export/_test/test_ctc.py index c527a82..f451874 100644 --- a/ultrack/core/export/_test/test_ctc.py +++ b/ultrack/core/export/_test/test_ctc.py @@ -2,9 +2,9 @@ import pandas as pd import pytest -from ultrack.core.database import NO_PARENT from ultrack.core.export.ctc import ctc_compress_forest, stitch_tracks_df from ultrack.tracks.graph import add_track_ids_to_tracks_df +from ultrack.utils.constants import NO_PARENT @pytest.fixture diff --git a/ultrack/core/export/_test/test_networkx.py b/ultrack/core/export/_test/test_networkx.py index 70eb9a7..5222d5f 100644 --- a/ultrack/core/export/_test/test_networkx.py +++ b/ultrack/core/export/_test/test_networkx.py @@ -1,8 +1,8 @@ import pandas as pd import pytest -from ultrack.core.database import NO_PARENT from ultrack.core.export import tracks_layer_to_networkx +from ultrack.utils.constants import NO_PARENT @pytest.mark.parametrize("children_to_parent", [True, False]) diff --git a/ultrack/core/export/_test/test_trackmate.py b/ultrack/core/export/_test/test_trackmate.py index 4cb4f29..7cfcee3 100644 --- a/ultrack/core/export/_test/test_trackmate.py +++ b/ultrack/core/export/_test/test_trackmate.py @@ -4,8 +4,8 @@ import pandas as pd import pytest -from ultrack.core.database import NO_PARENT from ultrack.core.export.trackmate import tracks_layer_to_trackmate +from ultrack.utils.constants import NO_PARENT pytrackmate = pytest.importorskip("pytrackmate") diff --git a/ultrack/core/export/ctc.py b/ultrack/core/export/ctc.py index f0d0f13..6475b0e 100644 --- a/ultrack/core/export/ctc.py +++ b/ultrack/core/export/ctc.py @@ -19,7 +19,7 @@ from ultrack.config.config import MainConfig from ultrack.config.dataconfig import DataConfig -from ultrack.core.database import NO_PARENT, NodeDB +from ultrack.core.database import NodeDB from ultrack.core.export.utils import ( export_segmentation_generic, filter_nodes_generic, @@ -32,6 +32,7 @@ tracks_df_forest, ) from ultrack.tracks.stats import estimate_drift +from ultrack.utils.constants import NO_PARENT from ultrack.utils.data import validate_and_overwrite_path logging.basicConfig() diff --git a/ultrack/core/export/networkx.py b/ultrack/core/export/networkx.py index 66681cc..f429ec8 100644 --- a/ultrack/core/export/networkx.py +++ b/ultrack/core/export/networkx.py @@ -4,9 +4,9 @@ import pandas as pd from ultrack.config.config import MainConfig -from ultrack.core.database import NO_PARENT from ultrack.core.export.tracks_layer import to_tracks_layer from ultrack.tracks.graph import _create_tracks_forest +from ultrack.utils.constants import NO_PARENT LOG = logging.getLogger(__name__) diff --git a/ultrack/core/export/trackmate.py b/ultrack/core/export/trackmate.py index ba17acc..4a03a50 100644 --- a/ultrack/core/export/trackmate.py +++ b/ultrack/core/export/trackmate.py @@ -6,8 +6,8 @@ import pandas as pd from ultrack.config.config import MainConfig -from ultrack.core.database import NO_PARENT from ultrack.core.export.tracks_layer import to_tracks_layer +from ultrack.utils.constants import NO_PARENT def _set_filter_elem(elem: ET.Element) -> None: diff --git a/ultrack/core/export/utils.py b/ultrack/core/export/utils.py index 27b7976..cc08e09 100644 --- a/ultrack/core/export/utils.py +++ b/ultrack/core/export/utils.py @@ -9,8 +9,9 @@ from toolz import curry from ultrack.config.dataconfig import DataConfig -from ultrack.core.database import NO_PARENT, NodeDB +from ultrack.core.database import NodeDB from ultrack.core.segmentation.node import Node +from ultrack.utils.constants import NO_PARENT from ultrack.utils.multiprocessing import multiprocessing_apply LOG = logging.getLogger(__name__) diff --git a/ultrack/core/linking/_test/test_link_utils.py b/ultrack/core/linking/_test/test_link_utils.py index e7d57de..3cd1303 100644 --- a/ultrack/core/linking/_test/test_link_utils.py +++ b/ultrack/core/linking/_test/test_link_utils.py @@ -3,8 +3,9 @@ from sqlalchemy.orm import Session from ultrack.config.config import MainConfig -from ultrack.core.database import NO_PARENT, LinkDB, NodeDB +from ultrack.core.database import LinkDB, NodeDB from ultrack.core.linking.utils import clear_linking_data +from ultrack.utils.constants import NO_PARENT @pytest.mark.parametrize( diff --git a/ultrack/core/solve/_test/test_sql_tracking.py b/ultrack/core/solve/_test/test_sql_tracking.py index e7b1496..b7c05ef 100644 --- a/ultrack/core/solve/_test/test_sql_tracking.py +++ b/ultrack/core/solve/_test/test_sql_tracking.py @@ -6,8 +6,9 @@ from ultrack import solve, to_tracks_layer from ultrack.config.config import MainConfig -from ultrack.core.database import NO_PARENT, LinkDB, NodeDB, VarAnnotation +from ultrack.core.database import LinkDB, NodeDB, VarAnnotation from ultrack.core.solve.sqltracking import SQLTracking +from ultrack.utils.constants import NO_PARENT _CONFIG_PARAMS = { "segmentation.n_workers": 4, @@ -125,7 +126,6 @@ def test_annotations_sql_tracking( solve(config, overwrite=True, use_annotations=True) tracks_df, _ = to_tracks_layer(config) - print(tracks_df) engine = sqla.create_engine(config.data_config.database_path) with Session(engine) as session: @@ -136,6 +136,5 @@ def test_annotations_sql_tracking( solve(config, overwrite=True, use_annotations=True) tracks_df_annot, _ = to_tracks_layer(config) - print(tracks_df_annot) assert len(tracks_df) > len(tracks_df_annot) diff --git a/ultrack/core/solve/solver/_test/test_solvers.py b/ultrack/core/solve/solver/_test/test_solvers.py index 8999ec5..8829342 100644 --- a/ultrack/core/solve/solver/_test/test_solvers.py +++ b/ultrack/core/solve/solver/_test/test_solvers.py @@ -1,34 +1,25 @@ -from itertools import product - import numpy as np import pandas as pd import pytest from ultrack.config.config import MainConfig -from ultrack.core.solve.solver.base_solver import BaseSolver -from ultrack.core.solve.solver.heuristic.heuristic_solver import HeuristicSolver from ultrack.core.solve.solver.mip_solver import MIPSolver @pytest.mark.parametrize( - "solver,config_content", - list( - product( - [MIPSolver, HeuristicSolver], - [ - { - "tracking.appear_weight": -0.25, - "tracking.disappear_weight": -1.0, - "tracking.division_weight": -0.5, - "tracking.link_function": "identity", - "tracking.bias": 0, - } - ], - ) - ), - indirect=["config_content"], + "config_content", + [ + { + "tracking.appear_weight": -0.25, + "tracking.disappear_weight": -1.0, + "tracking.division_weight": -0.5, + "tracking.link_function": "identity", + "tracking.bias": 0, + } + ], + indirect=True, ) -def test_solvers_optimize(solver: BaseSolver, config_instance: MainConfig) -> None: +def test_solvers_optimize(config_instance: MainConfig) -> None: """ This demo builds a very simple graph with 7 nodes and a single overlap constraint (2,5) and a two possible divisions on 2 and 6. @@ -53,7 +44,7 @@ def test_solvers_optimize(solver: BaseSolver, config_instance: MainConfig) -> No Result: 0.5 + 0.5 + 1.0 + 0.7 - division_weight """ - solver = solver(config_instance.tracking_config) + solver = MIPSolver(config_instance.tracking_config) nodes = np.array([1, 2, 3, 4, 5, 6, 7]) is_first = np.array([1, 0, 0, 0, 0, 0, 0], dtype=bool) @@ -160,6 +151,75 @@ def test_fixed_nodes_constraint_solver(config_instance: MainConfig) -> None: ) +def test_solver_with_node_probabilities(config_instance: MainConfig) -> None: + """ + Edge -C- denotes contraint. + + Graph: + + 0.3 0.7 1.0 1.0 + 1 - 0.5 - 2 - 0.5 - 3 - 0.5 - 4 + | \\ / \\ due linting software + C 1.0 0.9 + | \\ / + 5 - 0.5 - 6 - 0.7 - 7 + node w. 0.5 1.0 1.0 + + Solution: + + 0.3 0.7 1.0 1.0 + 1 - 0.5 - 2 - 0.5 - 3 - 0.5 - 4 + \\ + 1.0 + \\ + 6 - 0.7 - 7 + node w. 0.5 1.0 1.0 + + Result: 0.3 + 0.7 + 1.0 + 1.0 + 1.0 + 1.0 + + 0.5 + 0.5 + 0.5 + 1.0 + 0.7 - division_weight + """ + solver = MIPSolver(config_instance.tracking_config) + + nodes = np.array([1, 2, 3, 4, 5, 6, 7]) + nodes_probs = np.array([0.3, 0.7, 1.0, 1.0, 0.5, 1.0, 1.0]) + is_first = np.array([1, 0, 0, 0, 0, 0, 0], dtype=bool) + is_last = np.array([0, 0, 0, 1, 0, 0, 1], dtype=bool) + + solver.add_nodes(nodes, is_first, is_last, nodes_prob=nodes_probs) + + edges = np.array([[1, 2], [2, 3], [2, 6], [3, 4], [5, 6], [6, 4], [6, 7]]) + + weights = np.array([0.5, 0.5, 1.0, 0.5, 0.5, 0.9, 0.7]) + solver.add_edges(edges[:, 0], edges[:, 1], weights) + + solver.set_standard_constraints() + + solver.add_overlap_constraints([2], [5]) + + objective = solver.optimize() + solution = solver.solution() + + expected_solution = pd.DataFrame( + data=[pd.NA, 1, 2, 3, 2, 6], + index=[1, 2, 3, 4, 6, 7], + columns=["parent_id"], + dtype=pd.Int64Dtype(), + ) + expected_edges = np.array([1, 1, 1, 1, 0, 0, 1], dtype=bool) + + assert solution.shape == expected_solution.shape + assert np.all(solution.index.isin(expected_solution.index)) + assert np.all( + expected_solution.loc[solution.index, "parent_id"] == solution["parent_id"] + ) + assert np.allclose( + objective, + nodes_probs[expected_solution.index.to_numpy() - 1].sum() + + weights[expected_edges].sum() + + config_instance.tracking_config.division_weight, + ) + + def test_fixed_edges_constraint_solver(config_instance: MainConfig) -> None: """ Same graph as before but with a fixed division on node 6. diff --git a/ultrack/core/solve/solver/base_solver.py b/ultrack/core/solve/solver/base_solver.py index 277b419..5a56351 100644 --- a/ultrack/core/solve/solver/base_solver.py +++ b/ultrack/core/solve/solver/base_solver.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Literal +from typing import Literal, Optional import pandas as pd from numpy.typing import ArrayLike @@ -22,19 +22,13 @@ def __init__( """ self._config = config - @staticmethod - def _assert_same_length(**kwargs) -> None: - """Validates if key-word arguments have the same length.""" - for k1, v1 in kwargs.items(): - for k2, v2 in kwargs.items(): - if len(v2) != len(v1): - raise ValueError( - f"`{k1}` and `{k2}` must have the same length. Found {len(v1)} and {len(v2)}." - ) - @abstractmethod def add_nodes( - self, indices: ArrayLike, is_first_t: ArrayLike, is_last_t: ArrayLike + self, + indices: ArrayLike, + is_first_t: ArrayLike, + is_last_t: ArrayLike, + node_prob: Optional[ArrayLike] = None, ) -> None: """Add nodes variables solver. @@ -46,6 +40,8 @@ def add_nodes( Boolean array indicating if it belongs to first time point and it won't receive appearance penalization. is_last_t : ArrayLike Boolean array indicating if it belongs to last time point and it won't receive disappearance penalization. + node_prob: Optional[ArrayLike] + If provided assigns a node probability score to the objective function. """ @abstractmethod diff --git a/ultrack/core/solve/solver/heuristic/heuristic_solver.py b/ultrack/core/solve/solver/heuristic/heuristic_solver.py index 39a809c..5f45b2a 100644 --- a/ultrack/core/solve/solver/heuristic/heuristic_solver.py +++ b/ultrack/core/solve/solver/heuristic/heuristic_solver.py @@ -7,11 +7,12 @@ from skimage.util._map_array import ArrayMap from ultrack.config.config import TrackingConfig -from ultrack.core.database import NO_PARENT from ultrack.core.solve.solver.base_solver import BaseSolver from ultrack.core.solve.solver.heuristic._numba_heuristic_solver import ( NumbaHeuristicSolver, ) +from ultrack.utils.array import assert_same_length +from ultrack.utils.constants import NO_PARENT LOG = logging.getLogger(__name__) @@ -68,9 +69,7 @@ def add_nodes( if hasattr(self, "_forbidden"): raise ValueError("Nodes have already been added.") - self._assert_same_length( - indices=indices, is_first_t=is_first_t, is_last_t=is_last_t - ) + assert_same_length(indices=indices, is_first_t=is_first_t, is_last_t=is_last_t) indices = np.asarray(indices) size = len(indices) @@ -111,7 +110,7 @@ def add_edges( if hasattr(self, "_weights"): raise ValueError("Edges have already been added.") - self._assert_same_length(weights=weights, sources=sources, targets=targets) + assert_same_length(weights=weights, sources=sources, targets=targets) self._weights = np.asarray( self._config.apply_link_function(weights), np.float32 diff --git a/ultrack/core/solve/solver/mip_solver.py b/ultrack/core/solve/solver/mip_solver.py index 1b624bf..b80fb6e 100644 --- a/ultrack/core/solve/solver/mip_solver.py +++ b/ultrack/core/solve/solver/mip_solver.py @@ -1,7 +1,7 @@ import logging import uuid from pathlib import Path -from typing import Literal +from typing import Literal, Optional import mip import numpy as np @@ -10,8 +10,9 @@ from skimage.util._map_array import ArrayMap from ultrack.config.config import TrackingConfig -from ultrack.core.database import NO_PARENT from ultrack.core.solve.solver.base_solver import BaseSolver +from ultrack.utils.array import assert_same_length +from ultrack.utils.constants import NO_PARENT LOG = logging.getLogger(__name__) @@ -75,7 +76,11 @@ def _setup_model_parameters(self) -> None: self._model.max_mip_gap = self._config.solution_gap def add_nodes( - self, indices: ArrayLike, is_first_t: ArrayLike, is_last_t: ArrayLike + self, + indices: ArrayLike, + is_first_t: ArrayLike, + is_last_t: ArrayLike, + nodes_prob: Optional[ArrayLike] = None, ) -> None: """Add nodes slack variables to gurobi model. @@ -87,12 +92,17 @@ def add_nodes( Boolean array indicating if it belongs to first time point and it won't receive appearance penalization. is_last_t : ArrayLike Boolean array indicating if it belongs to last time point and it won't receive disappearance penalization. + nodes_prob: Optional[ArrayLike] + If provided assigns a node probability score to the objective function. """ if self._nodes is not None: raise ValueError("Nodes have already been added.") - self._assert_same_length( - indices=indices, is_first_t=is_first_t, is_last_t=is_last_t + assert_same_length( + indices=indices, + is_first_t=is_first_t, + is_last_t=is_last_t, + nodes_prob=nodes_prob, ) LOG.info("# %s nodes at starting `t`.", np.sum(is_first_t)) @@ -119,10 +129,17 @@ def add_nodes( size, name="division", var_type=mip.BINARY ) + if nodes_prob is None: + node_weights = 0 + else: + nodes_prob = self._config.apply_link_function(np.asarray(nodes_prob)) + node_weights = mip.xsum(nodes_prob * self._nodes) + self._model.objective = ( mip.xsum(self._divisions * self._config.division_weight) + mip.xsum(self._appearances * appear_weight) + mip.xsum(self._disappearances * disappear_weight) + + node_weights ) def add_edges( @@ -142,7 +159,7 @@ def add_edges( if self._edges is not None: raise ValueError("Edges have already been added.") - self._assert_same_length(sources=sources, targets=targets, weights=weights) + assert_same_length(sources=sources, targets=targets, weights=weights) weights = self._config.apply_link_function(weights.astype(float)) diff --git a/ultrack/core/solve/sqltracking.py b/ultrack/core/solve/sqltracking.py index 2d047f5..fe31a5e 100644 --- a/ultrack/core/solve/sqltracking.py +++ b/ultrack/core/solve/sqltracking.py @@ -241,7 +241,7 @@ def _add_nodes(self, solver: BaseSolver, index: int) -> None: engine = sqla.create_engine(self._data_config.database_path) with Session(engine) as session: - query = session.query(NodeDB.id, NodeDB.t).where( + query = session.query(NodeDB.id, NodeDB.t, NodeDB.node_prob).where( NodeDB.t.between(start_time, end_time) ) df = pd.read_sql(query.statement, session.bind) @@ -251,10 +251,22 @@ def _add_nodes(self, solver: BaseSolver, index: int) -> None: LOG.info(f"Batch {index}, nodes with t between {start_time} and {end_time}") + n_invalid_prob = (df["node_prob"] < 0).sum() + if n_invalid_prob == df.shape[0]: + nodes_prob = None + elif n_invalid_prob == 0: + nodes_prob = df["node_prob"] + else: + raise ValueError( + "None or all nodes' probabilities must be provided found " + f"Found {df.shape[0] - n_invalid_prob} / {df.shape[0]} valid probs." + ) + solver.add_nodes( df["id"], df["t"] == start_time, df["t"] == end_time, + nodes_prob=nodes_prob, ) def _add_edges(self, solver: BaseSolver, index: int) -> None: diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py index f3d21a3..a285341 100644 --- a/ultrack/core/tracker.py +++ b/ultrack/core/tracker.py @@ -21,6 +21,7 @@ from ultrack.core.segmentation.processing import segment from ultrack.core.solve.processing import solve from ultrack.imgproc.flow import add_flow +from ultrack.ml.classification import add_nodes_prob from ultrack.utils.deprecation import rename_argument @@ -153,3 +154,11 @@ def to_tracks_layer(self, *args, **kwargs) -> Tuple[pd.DataFrame, Dict]: def export_by_extension(self, filename: str, overwrite: bool = False) -> None: self._assert_solved() export_tracks_by_extension(self.config, filename, overwrite=overwrite) + + @functools.wraps(add_nodes_prob) + def add_nodes_prob( + self, + indices: ArrayLike, + probs: ArrayLike, + ) -> None: + add_nodes_prob(self.config, indices, probs) diff --git a/ultrack/ml/__init__.py b/ultrack/ml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ultrack/ml/classification.py b/ultrack/ml/classification.py new file mode 100644 index 0000000..de098fa --- /dev/null +++ b/ultrack/ml/classification.py @@ -0,0 +1,28 @@ +from numpy.typing import ArrayLike + +from ultrack.config.config import MainConfig +from ultrack.core.database import set_node_values + + +def add_nodes_prob( + config: MainConfig, + indices: ArrayLike, + probs: ArrayLike, +) -> None: + """ + Add nodes' probabilities to the segmentation/tracking database. + + Parameters + ---------- + config : MainConfig + Main configuration parameters. + indices : ArrayLike + Nodes' indices database index. + probs : ArrayLike + Nodes' probabilities. + """ + set_node_values( + config.data_config, + indices, + node_prob=probs, + ) diff --git a/ultrack/reader/_test/test_napari_reader.py b/ultrack/reader/_test/test_napari_reader.py index 43f508e..19c3636 100644 --- a/ultrack/reader/_test/test_napari_reader.py +++ b/ultrack/reader/_test/test_napari_reader.py @@ -7,8 +7,8 @@ from napari.plugins import _initialize_plugins from napari.viewer import ViewerModel -from ultrack.core.database import NO_PARENT from ultrack.reader.napari_reader import napari_get_reader +from ultrack.utils.constants import NO_PARENT @pytest.fixture diff --git a/ultrack/tracks/_test/test_tracks_gap_closing.py b/ultrack/tracks/_test/test_tracks_gap_closing.py index c7c659a..cba2e47 100644 --- a/ultrack/tracks/_test/test_tracks_gap_closing.py +++ b/ultrack/tracks/_test/test_tracks_gap_closing.py @@ -3,8 +3,8 @@ import zarr import zarr.storage -from ultrack.core.database import NO_PARENT from ultrack.tracks import close_tracks_gaps +from ultrack.utils.constants import NO_PARENT def test_gap_closing() -> None: diff --git a/ultrack/tracks/_test/test_tracks_graph.py b/ultrack/tracks/_test/test_tracks_graph.py index fe7ea94..82f2155 100644 --- a/ultrack/tracks/_test/test_tracks_graph.py +++ b/ultrack/tracks/_test/test_tracks_graph.py @@ -2,13 +2,13 @@ import pandas as pd import pytest -from ultrack.core.database import NO_PARENT from ultrack.tracks import ( filter_short_sibling_tracks, get_paths_to_roots, get_subgraph, split_trees, ) +from ultrack.utils.constants import NO_PARENT @pytest.fixture diff --git a/ultrack/tracks/_test/test_tracks_sorting.py b/ultrack/tracks/_test/test_tracks_sorting.py index bc9bac8..ba2caad 100644 --- a/ultrack/tracks/_test/test_tracks_sorting.py +++ b/ultrack/tracks/_test/test_tracks_sorting.py @@ -4,13 +4,13 @@ import pandas as pd from numba import typed, types -from ultrack.core.database import NO_PARENT from ultrack.tracks import ( left_first_search, sort_track_ids, sort_trees_by_length, sort_trees_by_max_radius, ) +from ultrack.utils.constants import NO_PARENT def test_sortrees_by_length() -> None: diff --git a/ultrack/tracks/_test/test_tracks_stats.py b/ultrack/tracks/_test/test_tracks_stats.py index 9a4903a..6c97bfc 100644 --- a/ultrack/tracks/_test/test_tracks_stats.py +++ b/ultrack/tracks/_test/test_tracks_stats.py @@ -4,13 +4,13 @@ import pandas as pd import pytest -from ultrack.core.database import NO_PARENT from ultrack.tracks.stats import ( estimate_drift, tracks_df_movement, tracks_length, tracks_profile_matrix, ) +from ultrack.utils.constants import NO_PARENT def spatial_df(group_drift: Sequence[int], length_per_group: int = 10) -> pd.DataFrame: diff --git a/ultrack/tracks/_test/test_tracks_video.py b/ultrack/tracks/_test/test_tracks_video.py index 00c0641..f002541 100644 --- a/ultrack/tracks/_test/test_tracks_video.py +++ b/ultrack/tracks/_test/test_tracks_video.py @@ -6,12 +6,12 @@ import pytest from napari.viewer import ViewerModel -from ultrack.core.database import NO_PARENT from ultrack.tracks.video import ( tracks_df_to_3D_video, tracks_df_to_moving_2D_plane_video, tracks_df_to_videos, ) +from ultrack.utils.constants import NO_PARENT pytest.importorskip("napari_animation") diff --git a/ultrack/tracks/gap_closing.py b/ultrack/tracks/gap_closing.py index f94047f..368c9c1 100644 --- a/ultrack/tracks/gap_closing.py +++ b/ultrack/tracks/gap_closing.py @@ -7,7 +7,7 @@ from scipy.spatial.distance import cdist from zarr.storage import Store -from ultrack.core.database import NO_PARENT +from ultrack.utils.constants import NO_PARENT from ultrack.utils.segmentation import SegmentationPainter, copy_segments diff --git a/ultrack/tracks/graph.py b/ultrack/tracks/graph.py index a48f4bb..377e204 100644 --- a/ultrack/tracks/graph.py +++ b/ultrack/tracks/graph.py @@ -8,7 +8,7 @@ from numpy.typing import ArrayLike from zarr.storage import Store -from ultrack.core.database import NO_PARENT +from ultrack.utils.constants import NO_PARENT from ultrack.utils.segmentation import SegmentationPainter, copy_segments LOG = logging.getLogger(__name__) diff --git a/ultrack/tracks/sorting.py b/ultrack/tracks/sorting.py index 68fad7e..833b248 100644 --- a/ultrack/tracks/sorting.py +++ b/ultrack/tracks/sorting.py @@ -7,13 +7,13 @@ from scipy.spatial.distance import pdist from tqdm import tqdm -from ultrack.core.database import NO_PARENT from ultrack.tracks.graph import ( inv_tracks_df_forest, left_first_search, split_trees, tracks_df_forest, ) +from ultrack.utils.constants import NO_PARENT LOG = logging.getLogger(__name__) diff --git a/ultrack/tracks/stats.py b/ultrack/tracks/stats.py index 379edfe..72847d0 100644 --- a/ultrack/tracks/stats.py +++ b/ultrack/tracks/stats.py @@ -4,8 +4,8 @@ import numpy as np import pandas as pd -from ultrack.core.database import NO_PARENT from ultrack.tracks.sorting import sort_track_ids +from ultrack.utils.constants import NO_PARENT LOG = logging.getLogger(__name__) diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py index 9a46b9d..8ad457f 100644 --- a/ultrack/utils/array.py +++ b/ultrack/utils/array.py @@ -193,3 +193,15 @@ def create_zarr( chunks = large_chunk_size(shape, dtype=dtype) return zarr.zeros(shape, dtype=dtype, store=store, chunks=chunks, **kwargs) + + +def assert_same_length(**kwargs) -> None: + """Validates if key-word arguments have the same length.""" + for k1, v1 in kwargs.items(): + if v1 is None: + continue + for k2, v2 in kwargs.items(): + if v2 is not None and len(v2) != len(v1): + raise ValueError( + f"`{k1}` and `{k2}` must have the same length. Found {len(v1)} and {len(v2)}." + ) diff --git a/ultrack/utils/constants.py b/ultrack/utils/constants.py index 0c6257c..8019bf7 100644 --- a/ultrack/utils/constants.py +++ b/ultrack/utils/constants.py @@ -1,3 +1,5 @@ import os +NO_PARENT = -1 + ULTRACK_DEBUG = bool(int(os.environ.get("ULTRACK_DEBUG", False))) diff --git a/ultrack/validation/_test/test_link_validation.py b/ultrack/validation/_test/test_link_validation.py index 7ea2bbb..4dd4eb6 100644 --- a/ultrack/validation/_test/test_link_validation.py +++ b/ultrack/validation/_test/test_link_validation.py @@ -4,7 +4,7 @@ import pandas as pd from napari.viewer import ViewerModel -from ultrack.core.database import NO_PARENT +from ultrack.utils.constants import NO_PARENT from ultrack.validation.link_validation import Annotation, LinkValidation diff --git a/ultrack/validation/link_validation.py b/ultrack/validation/link_validation.py index 8f7a0cb..ccc826a 100644 --- a/ultrack/validation/link_validation.py +++ b/ultrack/validation/link_validation.py @@ -10,7 +10,7 @@ from qtpy.QtGui import QKeySequence from qtpy.QtWidgets import QLabel, QPushButton, QVBoxLayout, QWidget -from ultrack.core.database import NO_PARENT +from ultrack.utils.constants import NO_PARENT LOG = logging.getLogger(__name__)