diff --git a/ultrack/config/config.py b/ultrack/config/config.py
index 7216d08..bf8441d 100644
--- a/ultrack/config/config.py
+++ b/ultrack/config/config.py
@@ -72,3 +72,9 @@ def load_config(path: Union[str, Path]) -> MainConfig:
         data = toml.load(f)
         LOG.info(data)
     return MainConfig.parse_obj(data)
+
+
+def save_config(config: MainConfig, path: Union[str, Path]) -> None:
+    """Saves MainConfig to a TOML file."""
+    with open(path, mode="w") as f:
+        toml.dump(config.dict(by_alias=True), f)
diff --git a/ultrack/config/dataconfig.py b/ultrack/config/dataconfig.py
index 62f07a9..5aa4dcd 100644
--- a/ultrack/config/dataconfig.py
+++ b/ultrack/config/dataconfig.py
@@ -13,6 +13,7 @@ class DatabaseChoices(Enum):
     sqlite = "sqlite"
     postgresql = "postgresql"
+    memory = "memory"
 
 
 class DataConfig(BaseModel):
@@ -32,6 +33,12 @@ class DataConfig(BaseModel):
     address: Optional[str] = None
     """``SPECIAL``: Postgresql database path, for example, ``postgres@localhost:12345/example``"""
 
+    in_memory_db_id: int = 0
+    """
+    ``SPECIAL``: In-memory database id used to identify the database instance;
+    it must be set manually if multiple instances are used.
+    """
+
     class Config:
         validate_assignment = True
         use_enum_values = True
@@ -73,6 +80,9 @@ def database_path(self) -> str:
         if self.database == DatabaseChoices.sqlite.value:
             return f"sqlite:///{self.working_dir.absolute()}/data.db"
 
+        elif self.database == DatabaseChoices.memory.value:
+            return f"sqlite:///file:{self.in_memory_db_id}?mode=memory&cache=shared&uri=true"
+
         elif self.database == DatabaseChoices.postgresql.value:
             return f"postgresql://{self.address}"
diff --git a/ultrack/core/database.py b/ultrack/core/database.py
index 5988e11..5c47f84 100644
--- a/ultrack/core/database.py
+++ b/ultrack/core/database.py
@@ -118,6 +118,26 @@ class LinkDB(Base):
     annotation = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN)
 
 
+class GTNodeDB(Base):
+    __tablename__ = "gt_nodes"
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    t = Column(Integer)
+    label = Column(Integer)
+    pickle = Column(MaybePickleType)
+    z = Column(Float)
+    y = Column(Float)
+    x = Column(Float)
+
+
+class GTLinkDB(Base):
+    __tablename__ = "gt_links"
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    source_id = Column(BigInteger, ForeignKey(f"{NodeDB.__tablename__}.id"))
+    target_id = Column(BigInteger, ForeignKey(f"{GTNodeDB.__tablename__}.id"))
+    weight = Column(Float)
+    selected = Column(Boolean, default=False)
+
+
 def maximum_time_from_database(data_config: DataConfig) -> int:
     """Returns the maximum `t` found in the `NodesDB`."""
     engine = sqla.create_engine(data_config.database_path)
diff --git a/ultrack/core/gt_matching.py b/ultrack/core/gt_matching.py
new file mode 100644
index 0000000..f905255
--- /dev/null
+++ b/ultrack/core/gt_matching.py
@@ -0,0 +1,264 @@
+import logging
+from contextlib import nullcontext
+from typing import Dict, Optional
+
+import fasteners
+import numpy as np
+import pandas as pd
+import sqlalchemy as sqla
+from numpy.typing import ArrayLike
+from skimage.measure import regionprops
+from sqlalchemy.orm import Session
+from toolz import curry
+
+from ultrack.config.config import MainConfig
+from ultrack.core.database import NO_PARENT, GTLinkDB, GTNodeDB, NodeDB
+from ultrack.core.linking.processing import compute_spatial_neighbors
+from ultrack.core.segmentation.node import Node
+from ultrack.core.solve.sqlgtmatcher import SQLGTMatcher
+from ultrack.utils.multiprocessing import (
+    multiprocessing_apply,
+    multiprocessing_sqlite_lock,
+)
+
+LOG = logging.getLogger(__name__)
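+
+# Illustrative usage of this module (a sketch, not part of the API docs; it
+# assumes segmentation hypotheses were already computed, e.g. via `segment`):
+#
+#     from ultrack.core.gt_matching import match_to_ground_truth
+#
+#     df = match_to_ground_truth(config, gt_labels)  # gt_labels: (T, (Z), Y, X)
+#     df["gt_track_id"]  # matched ground-truth label per node, NO_PARENT if none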
+
+
+@curry
+def _match_ground_truth_frame(
+    time: int,
+    gt_labels: ArrayLike,
+    config: MainConfig,
+    scale: Optional[ArrayLike],
+    write_lock: Optional[fasteners.InterProcessLock],
+) -> None:
+    """
+    Matches candidate hypotheses to ground-truth labels for a given time point.
+    Segmentation hypotheses must be pre-computed.
+
+    Parameters
+    ----------
+    time : int
+        Time point to match.
+    gt_labels : ArrayLike
+        Ground-truth labels.
+    config : MainConfig
+        Configuration object.
+    scale : Optional[ArrayLike]
+        Scale of the data for distance computation.
+    write_lock : Optional[fasteners.InterProcessLock]
+        Lock for writing to the database.
+    """
+    gt_labels = np.asarray(gt_labels[time])
+    gt_props = regionprops(gt_labels)
+
+    if len(gt_props) == 0:
+        LOG.warning(f"No objects found in time point {time}")
+        return
+
+    LOG.info(f"Found {len(gt_props)} objects in time point {time}")
+
+    gt_db_rows = []
+    gt_nodes = []
+    # adding ground-truth nodes
+    for obj in gt_props:
+        node = Node.from_mask(
+            node_id=obj.label, time=time, mask=obj.image, bbox=obj.bbox
+        )
+
+        if len(node.centroid) == 2:
+            y, x = node.centroid
+            z = 0
+        else:
+            z, y, x = node.centroid
+
+        gt_db_rows.append(
+            GTNodeDB(
+                t=time,
+                label=obj.label,
+                pickle=node,
+                z=z,
+                y=y,
+                x=x,
+            )
+        )
+        gt_nodes.append(node)
+
+    with write_lock if write_lock is not None else nullcontext():
+        engine = sqla.create_engine(config.data_config.database_path)
+
+        with Session(engine) as session:
+            session.add_all(gt_db_rows)
+            session.commit()
+
+            source_nodes = [
+                n for n, in session.query(NodeDB.pickle).where(NodeDB.t == time)
+            ]
+
+        engine.dispose()
+
+    compute_spatial_neighbors(
+        time,
+        config=config.linking_config,
+        source_nodes=source_nodes,
+        target_nodes=gt_nodes,
+        target_shift=np.zeros((len(gt_nodes), 3), dtype=np.float32),
+        table_name=GTLinkDB.__tablename__,
+        db_path=config.data_config.database_path,
+        scale=scale,
+        images=[],
+        write_lock=write_lock,
+    )
+
+    # computing the ground-truth matching
+    gt_matcher = SQLGTMatcher(config, write_lock=write_lock)
+    total_score = gt_matcher(time=time)
+
+    if len(gt_db_rows) > 0:
+        mean_score = total_score / len(gt_db_rows)
+    else:
+        mean_score = 0.0
+
+    LOG.info(f"time {time} total score: {total_score:0.4f}")
+    LOG.info(f"time {time} mean score: {mean_score:0.4f}")
+
+
+def _get_nodes_df_with_matches(database_path: str) -> pd.DataFrame:
+    """
+    Gets the nodes data frame with matched ground-truth labels.
+
+    Parameters
+    ----------
+    database_path : str
+        Path to the database file.
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame with matched nodes.
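+
+    Example (illustrative; the database path below is a placeholder):
+
+    >>> df = _get_nodes_df_with_matches("sqlite:///working_dir/data.db")
+    >>> sorted(df.columns)
+    ['gt_track_id', 'hier_parent_id', 't_hier_id']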
+ """ + engine = sqla.create_engine(database_path) + + with Session(engine) as session: + node_query = session.query( + NodeDB.id, + NodeDB.hier_parent_id, + NodeDB.t_hier_id, + # NodeDB.area, + # NodeDB.frontier, + ) + node_df = pd.read_sql(node_query.statement, session.bind, index_col="id") + + gt_edge_query = ( + session.query( + GTLinkDB.source_id, + GTLinkDB.target_id, + # GTNodeDB.z, + # GTNodeDB.y, + # GTNodeDB.x, + ).where(GTLinkDB.selected) + # .join(GTNodeDB, GTNodeDB.id == GTLinkDB.target_id) + ) + gt_df = pd.read_sql( + gt_edge_query.statement, session.bind, index_col="source_id" + ) + gt_df.rename( + columns={ + "target_id": "gt_track_id" + }, # , "z": "gt_z", "y": "gt_y", "x": "gt_x"}, + inplace=True, + ) + + LOG.info(f"Found {len(node_df)} nodes and {len(gt_df)} ground-truth links") + + node_df = node_df.join(gt_df) + node_df["gt_track_id"] = node_df["gt_track_id"].fillna(NO_PARENT).astype(int) + + # frontiers = node_df["frontier"] + # node_df["parent_frontier"] = node_df["hier_parent_id"].map( + # lambda x: frontiers.get(x, -1.0) + # ) + # node_df.loc[node_df["parent_frontier"] < 0, "parent_frontier"] = node_df[ + # "frontier" + # ].max() + + return node_df + + +def match_to_ground_truth( + config: MainConfig, + gt_labels: ArrayLike, + scale: Optional[ArrayLike] = None, + track_id_graph: Optional[Dict[int, int]] = None, +) -> pd.DataFrame: + """ + Matches nodes to ground-truth labels returning additional features for automatic parameter tuning. + + Parameters + ---------- + config : MainConfig + Configuration object. + gt_labels : ArrayLike + Ground-truth labels. + scale : Optional[ArrayLike], optional + Scale of the data for distance computation, by default None. + track_id_graph : Optional[Dict[int, int]], optional + Ground-truth graph of track IDs, by default None. + + Returns + ------- + pd.DataFrame + Data frame containing matched ground-truth labels to their respective nodes. + """ + + with multiprocessing_sqlite_lock(config.data_config) as lock: + multiprocessing_apply( + _match_ground_truth_frame( + gt_labels=gt_labels, + config=config, + scale=scale, + write_lock=lock, + ), + range(gt_labels.shape[0]), + n_workers=config.segmentation_config.n_workers, + desc="Matching hierarchy nodes with ground-truth", + ) + + df_nodes = _get_nodes_df_with_matches(config.data_config.database_path) + + if track_id_graph is not None: + df_nodes["gt_parent_track_id"] = df_nodes["gt_track_id"].apply( + lambda x: track_id_graph.get(x, NO_PARENT) + ) + else: + df_nodes["gt_parent_track_id"] = NO_PARENT + + return df_nodes + + # if scale is not None: + # cols = ["z", "y", "x"][-len(scale) :] + # gt_df[cols] *= scale + + # if "z" not in gt_df.columns: + # gt_df["z"] = 0.0 + + # if len(gt_df) > 0: + # max_distance = estimate_drift(gt_df) + # if not np.isnan(max_distance) or max_distance > 0: + # config.linking_config.max_distance = max_distance + 1.0 + + # if "solution" in df.columns: + # matched_df = df[df["solution"] > 0.5] + # config.segmentation_config.min_area = matched_df["area"].min() * 0.95 + + # config.segmentation_config.max_area = matched_df["area"].max() * 1.025 + + # config.segmentation_config.min_frontier = max( + # matched_df["parent_frontier"].min() - 0.025, 0.0 + # ) + # else: + # LOG.warning("No nodes were matched. 
Keeping previous configuration.") + + # config.data_config.database = prev_db + + # return config, df diff --git a/ultrack/core/solve/sqlgtmatcher.py b/ultrack/core/solve/sqlgtmatcher.py new file mode 100644 index 0000000..5506636 --- /dev/null +++ b/ultrack/core/solve/sqlgtmatcher.py @@ -0,0 +1,152 @@ +import logging +from contextlib import nullcontext +from typing import Literal, Optional + +import fasteners +import mip +import mip.exceptions +import pandas as pd +import sqlalchemy as sqla +from sqlalchemy.orm import Session + +from ultrack.config.config import MainConfig +from ultrack.core.database import GTLinkDB, NodeDB, OverlapDB + +LOG = logging.getLogger(__name__) + + +class SQLGTMatcher: + def __init__( + self, + config: MainConfig, + solver: Literal["CBC", "GUROBI", ""] = "", + write_lock: Optional[fasteners.InterProcessLock] = None, + eps: float = 1e-3, + ) -> None: + # TODO + + self._data_config = config.data_config + self._write_lock = write_lock + self._eps = eps + + try: + self._model = mip.Model(sense=mip.MAXIMIZE, solver_name=solver) + except mip.exceptions.InterfacingError as e: + LOG.warning(e) + self._model = mip.Model(sense=mip.MAXIMIZE, solver_name="CBC") + + def _add_nodes(self, time: int) -> None: + # TODO + engine = sqla.create_engine(self._data_config.database_path) + + with Session(engine) as session: + query = session.query(NodeDB.id, NodeDB.t).where(NodeDB.t == time) + self._nodes_df = pd.read_sql(query.statement, session.bind, index_col="id") + + size = len(self._nodes_df) + self._nodes = self._model.add_var_tensor( + (size,), name="nodes", var_type=mip.BINARY + ) + + # hierarchy overlap constraints + with Session(engine) as session: + query = session.query(OverlapDB).join( + NodeDB, NodeDB.id == OverlapDB.node_id + ) + overlap_df = pd.read_sql(query.statement, session.bind) + + overlap_df["node_id"] = self._nodes_df.index.get_indexer(overlap_df["node_id"]) + overlap_df["ancestor_id"] = self._nodes_df.index.get_indexer( + overlap_df["ancestor_id"] + ) + + for node_id, anc_id in zip(overlap_df["node_id"], overlap_df["ancestor_id"]): + self._model.add_constr(self._nodes[node_id] + self._nodes[anc_id] <= 1) + + def _add_edges(self, time: int) -> None: + # TODO + + if not hasattr(self, "_nodes"): + raise ValueError("Nodes must be added before adding edges.") + + engine = sqla.create_engine(self._data_config.database_path) + + with Session(engine) as session: + + query = ( + session.query(GTLinkDB) + .join(NodeDB, NodeDB.id == GTLinkDB.source_id) + .where(NodeDB.t == time) + ) + + self._edges_df = pd.read_sql(query.statement, session.bind) + + self._edges_df["source_id"] = self._nodes_df.index.get_indexer( + self._edges_df["source_id"] + ) + self._edges_df.reset_index(drop=True, inplace=True) + + self._edges = self._model.add_var_tensor( + (len(self._edges_df),), + name="edges", + var_type=mip.BINARY, + ) + # small value to prefer not selecting edges than bad ones + # setting objective function + self._model.objective = mip.xsum( + (self._edges_df["weight"].to_numpy() - self._eps) * self._edges + ) + + # source_id is time point T (hierarchies id) + # target_id is time point T+1 (ground-truth) + for source_id, group in self._edges_df.groupby("source_id", as_index=False): + self._model.add_constr( + self._nodes[source_id] == mip.xsum(self._edges[group.index.to_numpy()]) + ) + + for _, group in self._edges_df.groupby("target_id", as_index=False): + self._model.add_constr(mip.xsum(self._edges[group.index.to_numpy()]) <= 1) + + def add_solution(self) -> None: + # TODO + 
+
+    def add_solution(self) -> None:
+        """Flags the links selected by the solver in the `gt_links` table."""
+        engine = sqla.create_engine(self._data_config.database_path)
+
+        edges_records = []
+        for idx, e_var in zip(self._edges_df["id"], self._edges):
+            if e_var.x > 0.5:
+                edges_records.append(
+                    {
+                        "link_id": idx,
+                        "selected": True,
+                    }
+                )
+
+        LOG.info(f"Selected {len(edges_records)} edges to ground-truth")
+
+        with self._write_lock if self._write_lock is not None else nullcontext():
+            with Session(engine) as session:
+                stmt = (
+                    sqla.update(GTLinkDB)
+                    .where(GTLinkDB.id == sqla.bindparam("link_id"))
+                    .values(selected=sqla.bindparam("selected"))
+                )
+                session.connection().execute(
+                    stmt,
+                    edges_records,
+                    execution_options={"synchronize_session": False},
+                )
+                session.commit()
+
+    def __call__(self, time: int) -> float:
+        """Builds and solves the matching problem for `time`, returning the total matched weight."""
+        LOG.info(f"Computing GT matching for time {time}")
+
+        self._add_nodes(time)
+        self._add_edges(time)
+        self._model.optimize()
+        self.add_solution()
+
+        n_selected_vars = sum(e_var.x > 0.5 for e_var in self._edges)
+
+        # add the epsilon penalty back to recover the raw sum of matched weights
+        return self._model.objective_value + n_selected_vars * self._eps
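+
+# Typical use (illustrative; assumes ground-truth nodes and candidate links
+# are already in the database, as done by `ultrack.core.gt_matching`):
+#
+#     matcher = SQLGTMatcher(config)
+#     total_weight = matcher(time=0)  # match hypotheses at t=0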
diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py
index 5a3a3df..8c8cf1b 100644
--- a/ultrack/core/tracker.py
+++ b/ultrack/core/tracker.py
@@ -16,6 +16,7 @@
     tracks_layer_to_trackmate,
     tracks_to_zarr,
 )
+from ultrack.core.gt_matching import match_to_ground_truth
 from ultrack.core.linking.processing import add_links, link
 from ultrack.core.main import track
 from ultrack.core.segmentation.processing import get_nodes_features, segment
@@ -171,6 +172,11 @@ def add_links(self, **kwargs) -> None:
         add_links(config=self.config, **kwargs)
         self.status |= TrackerStatus.LINKED
 
+    @functools.wraps(match_to_ground_truth)
+    def match_to_ground_truth(self, **kwargs) -> pd.DataFrame:
+        self._assert_segmented("match_to_ground_truth")
+        return match_to_ground_truth(config=self.config, **kwargs)
+
     @functools.wraps(add_nodes_prob)
     def add_nodes_prob(
         self,
diff --git a/ultrack/napari.yaml b/ultrack/napari.yaml
index 26e709e..4ba389f 100644
--- a/ultrack/napari.yaml
+++ b/ultrack/napari.yaml
@@ -17,6 +17,10 @@ contributions:
       python_name: ultrack.widgets:UltrackWidget
       title: Ultrack
 
+    - id: ultrack.hierarchy_viz_widget
+      python_name: ultrack.widgets:HierarchyVizWidget
+      title: Hierarchy visualization
+
     ###### DEPRECATED & WIP WIDGETS #####
     # - id: ultrack.labels_to_edges_widget
     #   python_name: ultrack.widgets:LabelsToContoursWidget
@@ -67,5 +71,8 @@ contributions:
     # - command: ultrack.division_annotation_widget
     #   display_name: Division annotation
 
-    # - command: ultrack.track_inspection
-    #   display_name: Track inspection
+    - command: ultrack.track_inspection
+      display_name: Track inspection
+
+    - command: ultrack.hierarchy_viz_widget
+      display_name: Hierarchy visualization
diff --git a/ultrack/utils/_test/test_utils_array.py b/ultrack/utils/_test/test_utils_array.py
index c70e7e1..76d77e3 100644
--- a/ultrack/utils/_test/test_utils_array.py
+++ b/ultrack/utils/_test/test_utils_array.py
@@ -1,8 +1,10 @@
 import numpy as np
 import pytest
+from typing import Tuple
 
 from ultrack.utils.array import array_apply
-
+from ultrack.config import MainConfig
+from ultrack.utils.array import UltrackArray
 
 @pytest.mark.parametrize("axis", [0, 1])
 def test_array_apply_parametrized(axis):
@@ -17,3 +19,31 @@ def sample_func(arr_1, arr_2):
     array_apply(in_data, in_data, out_array=out_data, func=sample_func, axis=axis)
     other_axes_length = in_data.shape[1 - axis]
     assert np.array_equal(out_data, 2 * in_data + other_axes_length)
+
+
+@pytest.mark.parametrize(
+    "key,timelapse_mock_data",
+    [
+        (1, {"n_dim": 3}),
+        (1, {"n_dim": 2}),
+        ((slice(None), 1), {"n_dim": 3}),
+        ((slice(None), 1), {"n_dim": 2}),
+        ((0, [1, 2]), {"n_dim": 3}),
+        ((0, [1, 2]), {"n_dim": 2}),
+        # ((-1, np.asarray([0, 3])), {"n_dim": 3}),  # does testing negative time make sense?
+        # ((-1, np.asarray([0, 3])), {"n_dim": 2}),
+        ((slice(1), -2), {"n_dim": 3}),
+        ((slice(1), -2), {"n_dim": 2}),
+        ((np.asarray(0),), {"n_dim": 3}),
+        ((np.asarray(0),), {"n_dim": 2}),
+        ((0, 0, slice(32)), {"n_dim": 3}),
+        ((0, 0, slice(32)), {"n_dim": 2}),
+    ],
+    indirect=["timelapse_mock_data"],
+)
+def test_ultrack_array(
+    segmentation_database_mock_data: MainConfig,
+    key: Tuple,
+) -> None:
+    ua = UltrackArray(segmentation_database_mock_data)
+    ua_numpy = ua[slice(None)]
+    np.testing.assert_equal(ua_numpy[key], ua[key])
diff --git a/ultrack/utils/array.py b/ultrack/utils/array.py
index 3a49fe1..6df876c 100644
--- a/ultrack/utils/array.py
+++ b/ultrack/utils/array.py
@@ -6,11 +6,16 @@
 from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
 
 import numpy as np
+import sqlalchemy as sqla
 import zarr
 from numpy.typing import ArrayLike
+from sqlalchemy.orm import Session
 from tqdm import tqdm
 from zarr.storage import Store
 
+from ultrack.config import MainConfig
+from ultrack.core.database import NodeDB
+
 LOG = logging.getLogger(__name__)
@@ -213,6 +218,180 @@ def create_zarr(
     return zarr.zeros(shape, dtype=dtype, store=store, chunks=chunks, **kwargs)
 
 
+class UltrackArray:
+    def __init__(
+        self,
+        config: MainConfig,
+        dtype: np.dtype = np.int32,
+    ):
+        """Creates an array that visualizes the segments in the ultrack database on demand.
+
+        Parameters
+        ----------
+        config : MainConfig
+            Configuration file of Ultrack.
+        dtype : np.dtype
+            Data type of the array.
+        """
+        self.config = config
+        self.shape = tuple(config.data_config.metadata["shape"])  # (t, (z), y, x)
+        self.dtype = dtype
+        self.t_max = self.shape[0]
+        self.ndim = len(self.shape)
+        self.array = np.zeros(self.shape[1:], dtype=self.dtype)
+
+        self.database_path = config.data_config.database_path
+        self.minmax = self.find_min_max_volume_entire_dataset()
+        self.volume = self.minmax.mean().astype(int)
+
+    def __getitem__(
+        self,
+        indexing: Union[Tuple[Union[int, slice]], int, slice],
+    ) -> np.ndarray:
+        """Indexes the ultrack array, painting the requested time points on demand.
+
+        Parameters
+        ----------
+        indexing : Union[Tuple[Union[int, slice]], int, slice]
+            Time point (first axis) followed by optional volume slicing.
+
+        Returns
+        -------
+        np.ndarray
+            Array with painted segments.
+        """
+        if isinstance(indexing, tuple):
+            time, volume_slicing = indexing[0], indexing[1:]
+        else:  # only the time index was provided
+            time = indexing
+            volume_slicing = tuple()
+
+        if isinstance(time, slice):  # all requested time points are stacked
+            return np.stack(
+                [
+                    self.__getitem__((t,) + volume_slicing)
+                    for t in range(*time.indices(self.shape[0]))
+                ]
+            )
+
+        try:
+            time = time.item()  # convert from numpy.int to int
+        except AttributeError:
+            pass
+
+        self.fill_array(time=time)
+
+        return self.array[volume_slicing]
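+
+    # Example (a sketch; assumes `config` points to a populated ultrack database):
+    #
+    #     ua = UltrackArray(config)
+    #     ua.volume = 500   # display only segments with at most 500 pixels
+    #     frame = ua[0]     # labels of time point 0, painted on demand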
+    def fill_array(
+        self,
+        time: int,
+    ) -> None:
+        """Paints all segments of a given time point whose area is at most ``self.volume``.
+
+        Parameters
+        ----------
+        time : int
+            Time point whose segments are painted.
+        """
+        engine = sqla.create_engine(self.database_path)
+        self.array.fill(0)
+
+        with Session(engine) as session:
+            query = list(
+                session.query(NodeDB.id, NodeDB.pickle, NodeDB.hier_parent_id).where(
+                    NodeDB.t == time
+                )
+            )
+
+        if len(query) == 0:
+            LOG.warning(f"Query is empty: no nodes found at time point {time}")
+
+        idx_to_plot = []
+        for idx, q in enumerate(query):
+            if q[1].area <= self.volume:
+                idx_to_plot.append(idx)
+
+        id_to_plot = [q[0] for idx, q in enumerate(query) if idx in idx_to_plot]
+        label_list = np.arange(1, len(query) + 1, dtype=int)
+
+        to_remove = []
+        for idx in idx_to_plot:
+            if query[idx][2] in id_to_plot:  # if the parent is also painted
+                to_remove.append(idx)
+
+        for idx in to_remove:
+            idx_to_plot.remove(idx)
+
+        for idx in idx_to_plot:
+            query[idx][1].paint_buffer(
+                self.array, value=label_list[idx], include_time=False
+            )
+
+    def get_tp_num_pixels(
+        self,
+        timeStart: int,
+        timeStop: int,
+    ) -> list:
+        """Gets the number of pixels of every segment in the time range [timeStart, timeStop].
+
+        Parameters
+        ----------
+        timeStart : int
+            First time point of the range (inclusive).
+        timeStop : int
+            Last time point of the range (inclusive).
+
+        Returns
+        -------
+        list
+            Number of pixels of every segment in the range.
+        """
+        engine = sqla.create_engine(self.database_path)
+        num_pix_list = []
+        with Session(engine) as session:
+            query = list(
+                session.query(NodeDB.area)
+                .where(NodeDB.t >= timeStart)
+                .where(NodeDB.t <= timeStop)
+            )
+            for num_pix in query:
+                num_pix_list.append(int(num_pix[0]))
+        return num_pix_list
+
+    def get_tp_num_pixels_minmax(
+        self,
+        time: int,
+    ) -> Tuple[int, int]:
+        """Finds the minimum and maximum segment volume for a single time point.
+
+        Parameters
+        ----------
+        time : int
+            Time point to inspect.
+
+        Returns
+        -------
+        Tuple[int, int]
+            Tuple with two elements: (min_volume, max_volume).
+        """
+        num_pix_list = self.get_tp_num_pixels(time, time)
+        return (min(num_pix_list), max(num_pix_list))
+
+    def find_min_max_volume_entire_dataset(self) -> np.ndarray:
+        """Finds the minimum and maximum segment volume over ALL time points.
+
+        Returns
+        -------
+        np.ndarray
+            Array with two elements: [min_volume, max_volume].
+        """
+        min_vol = np.inf
+        max_vol = 0
+        for t in range(self.t_max):
+            minmax = self.get_tp_num_pixels_minmax(t)
+            if minmax[0] < min_vol:
+                min_vol = minmax[0]
+            if minmax[1] > max_vol:
+                max_vol = minmax[1]
+
+        return np.array([min_vol, max_vol], dtype=int)
+
+
 def assert_same_length(**kwargs) -> None:
     """Validates if key-word arguments have the same length."""
     for k1, v1 in kwargs.items():
diff --git a/ultrack/utils/segmentation.py b/ultrack/utils/segmentation.py
index b29cf63..6d330cc 100644
--- a/ultrack/utils/segmentation.py
+++ b/ultrack/utils/segmentation.py
@@ -266,6 +266,6 @@ def copy_segments(
         )  # not sure why this is necessary in large datasets
     else:
         for t in tqdm(range(segments.shape[0]), "Copying segments"):
-            out_segments[t] = segments[t]
+            out_segments[t] = np.asarray(segments[t])
 
     return out_segments
diff --git a/ultrack/widgets/__init__.py b/ultrack/widgets/__init__.py
index fc9dedc..9338644 100644
--- a/ultrack/widgets/__init__.py
+++ b/ultrack/widgets/__init__.py
@@ -1,5 +1,6 @@
 from ultrack.widgets.division_annotation_widget import DivisionAnnotationWidget
+from ultrack.widgets.hierarchy_viz_widget import HierarchyVizWidget
 from ultrack.widgets.hypotheses_viz_widget import HypothesesVizWidget
 from ultrack.widgets.labels_to_edges_widget import LabelsToContoursWidget
 from ultrack.widgets.node_annotation_widget import NodeAnnotationWidget
 from ultrack.widgets.track_inspection_widget import TrackInspectionWidget
diff --git a/ultrack/widgets/_test/test_hierarchy_viz_widget.py b/ultrack/widgets/_test/test_hierarchy_viz_widget.py
new file mode 100644
index 0000000..889ff5c
--- /dev/null
+++ b/ultrack/widgets/_test/test_hierarchy_viz_widget.py
@@ -0,0 +1,95 @@
+from typing import Callable, Tuple
+
+import napari
+import zarr
+
+from ultrack.config import MainConfig
+from ultrack.widgets.hierarchy_viz_widget import HierarchyVizWidget
+from ultrack.widgets.ultrackwidget import UltrackWidget
+from ultrack.widgets.ultrackwidget.utils import UltrackInput
+from ultrack.widgets.ultrackwidget.workflows import WorkflowChoice
+
+
+def test_hierarchy_viz_widget(
+    make_napari_viewer: Callable[[], napari.Viewer],
+    segmentation_database_mock_data: MainConfig,
+    timelapse_mock_data: Tuple[zarr.Array, zarr.Array, zarr.Array],
+    request,
+) -> None:
+    ####################################################################################
+    # OPTION 1: run the widget with an explicit config
+    ####################################################################################
+    config = segmentation_database_mock_data
+    viewer = make_napari_viewer()
+    widget = HierarchyVizWidget(viewer, config)
+    viewer.window.add_dock_widget(widget)
+
+    assert "hierarchy" in viewer.layers
+
+    # test moving the slider
+    widget._slider_update(0.75)
+    widget._slider_update(0.25)
+
+    # test whether the layer data has the shape reported in the config
+    # (metadata["shape"] is a list, data.shape is a tuple)
+    assert tuple(config.data_config.metadata["shape"]) == viewer.layers["hierarchy"].data.shape
+
+    ####################################################################################
+    # OPTION 2: run the widget taking the config from the Ultrack widget
+    ####################################################################################
+    viewer2 = make_napari_viewer()
+
+    # add mock segmentation data to the viewer
+    segments = timelapse_mock_data[2]
+    viewer2.add_labels(segments, name="segments")
+
+    # open the Ultrack widget
+    widget_ultrack = UltrackWidget(viewer2)
+    viewer2.window.add_dock_widget(widget_ultrack)
+
+    # set up the Ultrack widget for 'Labels' input
+    layers = viewer2.layers
+    workflow = WorkflowChoice.AUTO_FROM_LABELS
+    workflow_idx = widget_ultrack._cb_workflow.findData(workflow)
+    widget_ultrack._cb_workflow.setCurrentIndex(workflow_idx)
+    widget_ultrack._cb_workflow.currentIndexChanged.emit(workflow_idx)
+    # setting combobox choices manually, because they were not working automatically
+    widget_ultrack._cb_images[UltrackInput.LABELS].choices = layers
+    # selecting layers
+    widget_ultrack._cb_images[UltrackInput.LABELS].value = layers["segments"]
+
+    # load config
+    widget_ultrack._data_forms.load_config(config)
+
+    widget_hier = HierarchyVizWidget(viewer2)
+    viewer2.window.add_dock_widget(widget_hier)
+
+    assert "hierarchy" in viewer2.layers
+
+    # test moving the slider
+    widget_hier._slider_update(0.75)
+    widget_hier._slider_update(0.25)
+
+    # test whether the layer data has the shape reported in the config
+    assert tuple(config.data_config.metadata["shape"]) == viewer2.layers["hierarchy"].data.shape
+
+    ####################################################################################
+    # TODO:
+    # - test input types other than labels (contours, detection, image, etc.),
+    #   although it should not make a difference
+    # - how to test that the 'hierarchy' layer actually shows data?
+    #   (apart from the layer having a .data property)
+    ####################################################################################
+
+    if request.config.getoption("--show-napari-viewer"):
+        napari.run()
\ No newline at end of file
diff --git a/ultrack/widgets/hierarchy_viz_widget.py b/ultrack/widgets/hierarchy_viz_widget.py
new file mode 100644
index 0000000..036f962
--- /dev/null
+++ b/ultrack/widgets/hierarchy_viz_widget.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Optional, Sequence
+
+import napari
+import numpy as np
+from magicgui.widgets import Container, FloatSlider, Label
+from scipy import interpolate
+
+from ultrack.config import MainConfig
+from ultrack.utils.array import UltrackArray
+from ultrack.widgets.ultrackwidget import UltrackWidget
+
+logging.basicConfig()
+logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
+
+LOG = logging.getLogger(__name__)
+
+
+class HierarchyVizWidget(Container):
+    def __init__(
+        self,
+        viewer: napari.Viewer,
+        config: Optional[MainConfig] = None,
+    ) -> None:
+        """
+        Initializes the HierarchyVizWidget.
+
+        Parameters
+        ----------
+        viewer : napari.Viewer
+            The napari viewer instance.
+        config : Optional[MainConfig]
+            Ultrack configuration; if not provided, it is taken from the UltrackWidget.
+        """
+        super().__init__(layout="horizontal")
+
+        self._viewer = viewer
+
+        if config is None:
+            self.config = self._get_config()
+        else:
+            self.config = config
+
+        self.ultrack_array = UltrackArray(self.config)
+
+        self.mapping = self._create_mapping()
+
+        self._area_threshold_w = FloatSlider(label="Area", min=0, max=1, readout=False)
+        self._area_threshold_w.value = 0.5
+        self.ultrack_array.volume = self.mapping(0.5)
+        self._area_threshold_w.changed.connect(self._slider_update)
+
+        self.slider_label = Label(
+            label=str(int(self.mapping(self._area_threshold_w.value)))
+        )
+        self.slider_label.native.setFixedWidth(25)
+
+        self.append(self._area_threshold_w)
+        self.append(self.slider_label)
+
+        # TODO: check whether a layer named 'hierarchy' already exists before adding it
+        self._viewer.add_labels(self.ultrack_array, name="hierarchy")
+        self._viewer.layers["hierarchy"].refresh()
+
+    def _on_config_changed(self) -> None:
+        self._ndim = len(self._shape)
+
+    @property
+    def _shape(self) -> Sequence[int]:
+        return self.config.data_config.metadata.get("shape", [])
+
+    def _slider_update(self, value: float) -> None:
+        self.ultrack_array.volume = self.mapping(value)
+        self.slider_label.label = str(int(self.mapping(value)))
+        self._viewer.layers["hierarchy"].refresh()
+
+    def _create_mapping(self):
+        """
+        Creates a pseudo-linear mapping from [0, 1] to the full range of segment
+        sizes (in pixels): num_pixels = mapping(slider_value).
+        """
+        # sample the size distribution at a single time point and pad it with
+        # the dataset-wide minimum and maximum
+        num_pixels_list = self.ultrack_array.get_tp_num_pixels(timeStart=5, timeStop=5)
+        num_pixels_list.append(self.ultrack_array.minmax[0])
+        num_pixels_list.append(self.ultrack_array.minmax[1])
+        num_pixels_list.sort()
+
+        x_vec = np.linspace(0, 1, len(num_pixels_list))
+        y_vec = np.array(num_pixels_list)
+        mapping = interpolate.interp1d(x_vec, y_vec)
+        return mapping
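+
+    # Worked example of the mapping (illustrative numbers): with sorted segment
+    # sizes [10, 40, 80, 200], mapping(0.0) == 10 and mapping(1.0) == 200, and
+    # intermediate slider positions interpolate between consecutive sizes, so
+    # the slider devotes equal range to each observed segment size rather than
+    # to absolute areas.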
+
+    def _get_config(self) -> MainConfig:
+        """
+        Gets the config from the Ultrack widget.
+        """
+        ultrack_widget = UltrackWidget.find_ultrack_widget(self._viewer)
+        if ultrack_widget is None:
+            raise TypeError(
+                "Config was not provided and no Ultrack widget was found in the viewer."
+            )
+
+        return ultrack_widget._data_forms.get_config()
\ No newline at end of file
diff --git a/ultrack/widgets/ultrackwidget/ultrackwidget.py b/ultrack/widgets/ultrackwidget/ultrackwidget.py
index 4f95fac..e2ebd3d 100644
--- a/ultrack/widgets/ultrackwidget/ultrackwidget.py
+++ b/ultrack/widgets/ultrackwidget/ultrackwidget.py
@@ -1,7 +1,7 @@
 import logging
 import webbrowser
 from contextlib import redirect_stderr, redirect_stdout
-from typing import Any, Generator
+from typing import Any, Generator, Optional
 
 import napari
 import qtawesome as qta
@@ -666,6 +666,15 @@ def _cancel(self):
         self._current_worker.quit()
         self._bt_cancel.setEnabled(False)
 
+    @staticmethod
+    def find_ultrack_widget(viewer: napari.Viewer) -> Optional["UltrackWidget"]:
+        """Finds and returns the Ultrack widget; returns None if it is not found."""
+        for _, w in viewer.window._dock_widgets.items():
+            if isinstance(w.widget(), UltrackWidget):
+                return w.widget()
+
+        return None
+
 
 if __name__ == "__main__":
     from napari import Viewer