diff --git a/tox.ini b/tox.ini index 3b6d49c..be330a9 100644 --- a/tox.ini +++ b/tox.ini @@ -23,7 +23,9 @@ BACKEND = [testenv] description = Run unit-testing -passenv = DISPLAY XAUTHORITY +passenv = + DISPLAY + XAUTHORITY deps = pytest coverage diff --git a/ultrack/cli/_test/test_cli.py b/ultrack/cli/_test/test_cli.py index 7220371..995c885 100644 --- a/ultrack/cli/_test/test_cli.py +++ b/ultrack/cli/_test/test_cli.py @@ -90,6 +90,8 @@ def test_ctc_export(self, instance_config_path: str, tmp_path: Path) -> None: "0.5,1,1", "-ma", "5", + "-di", + "1", "-o", str(tmp_path / "01_RES"), ] diff --git a/ultrack/cli/export.py b/ultrack/cli/export.py index 87408d2..22264cd 100644 --- a/ultrack/cli/export.py +++ b/ultrack/cli/export.py @@ -47,6 +47,14 @@ show_default=True, help="Optional first frame path used to select a subset of lineages connected to this reference annotations.", ) +@click.option( + "--dilation-iters", + "-di", + default=0, + type=int, + show_default=True, + help="Iterations of radius 1 morphological dilations on labels, applied after scaling.", +) @click.option( "--stitch-tracks", default=False, @@ -60,6 +68,7 @@ def ctc_cli( margin: int, scale: Optional[Tuple[float]], first_frame_path: Optional[Path], + dilation_iters: int, stitch_tracks: bool, overwrite: bool, ) -> None: @@ -73,11 +82,12 @@ def ctc_cli( to_ctc( output_directory, config.data_config, - margin, - scale, - first_frame, - stitch_tracks, - overwrite, + margin=margin, + scale=scale, + first_frame=first_frame, + dilation_iters=dilation_iters, + stitch_tracks=stitch_tracks, + overwrite=overwrite, ) diff --git a/ultrack/core/export/ctc.py b/ultrack/core/export/ctc.py index 14106a2..70281fc 100644 --- a/ultrack/core/export/ctc.py +++ b/ultrack/core/export/ctc.py @@ -6,8 +6,10 @@ import pandas as pd import sqlalchemy as sqla from numpy.typing import ArrayLike -from scipy.ndimage import zoom +from scipy.ndimage import generate_binary_structure, grey_dilation, zoom +from scipy.optimize import linear_sum_assignment from scipy.spatial import KDTree +from scipy.spatial.distance import cdist from skimage.measure import regionprops from skimage.segmentation import relabel_sequential from sqlalchemy.orm import Session @@ -211,19 +213,20 @@ def select_tracks_from_first_frame( query = session.query(NodeDB.pickle).where(NodeDB.t == 0, NodeDB.selected) starting_nodes = [n for n, in query] - selected_track_ids = set() - centroids = np.asarray([n.centroid for n in starting_nodes]) + root_centroids = np.asarray([n.centroid for n in starting_nodes]) + marker_centroids = np.asarray( + [n.centroid for n in regionprops(first_frame, cache=False)] + ) + D = cdist(marker_centroids, root_centroids) + + _, root_ids = linear_sum_assignment(D) + selected_track_ids = set() graph = tracks_forest(df) - for det in tqdm(regionprops(first_frame), "Selecting tracks from first trame"): - # select nearest node that contains the reference detection. - dist = np.square(centroids - det.centroid).sum(axis=1) - node = starting_nodes[np.argmin(dist)] - if node.contains(det.centroid): - # add the whole tree to the selection - track_id = df.loc[node.id, "track_id"] - selected_track_ids.update(connected_component(graph, track_id)) + for root in tqdm(root_ids, "Selecting tracks from first trame"): + track_id = df.loc[starting_nodes[root].id, "track_id"] + selected_track_ids.update(connected_component(graph, track_id)) if stitch_tracks: selected_df = stitch_tracks_df(graph, df, selected_track_ids) @@ -252,6 +255,7 @@ def _write_tiff_buffer( buffer: np.ndarray, output_dir: Path, scale: Optional[ArrayLike] = None, + dilation_iters: int = 0, ) -> None: """Writes a single tiff stack into `output_dir` / "mask%03d.tif" @@ -265,12 +269,19 @@ def _write_tiff_buffer( Output directory. scale : Optional[ArrayLike], optional Mask rescaling factor, by default None + dilation_iters: int + Iterations of radius 1 morphological dilations on labels, applied after scaling, by default 0. """ if scale is not None: buffer = zoom( buffer, scale[-buffer.ndim :], order=0, grid_mode=True, mode="grid-constant" ) + footprint = generate_binary_structure(buffer.ndim, 1) + for _ in range(dilation_iters): + dilated = grey_dilation(buffer, footprint=footprint) + np.putmask(buffer, buffer == 0, dilated) + imwrite(output_dir / f"mask{t:03}.tif", buffer) @@ -280,6 +291,7 @@ def to_ctc( margin: int = 0, scale: Optional[Tuple[float]] = None, first_frame: Optional[ArrayLike] = None, + dilation_iters: int = 0, stitch_tracks: bool = False, overwrite: bool = False, ) -> None: @@ -298,6 +310,8 @@ def to_ctc( Margin used to filter out nodes and splitting their tracklets first_frame : Optional[ArrayLike], optional Optional first frame detection mask to select a subset of tracks (e.g. Fluo-N3DL-DRO), by default None + dilation_iters: int + Iterations of radius 1 morphological dilations on labels, applied after scaling, by default 0. stitch_tracks: bool, optional Stitches (connects) incomplete tracks nearby tracks on subsequent time point, by default False overwrite : bool, optional @@ -372,7 +386,9 @@ def condition(node: Node) -> bool: export_segmentation_generic( data_config, df, - _write_tiff_buffer(scale=scale, output_dir=output_dir), + _write_tiff_buffer( + output_dir=output_dir, scale=scale, dilation_iters=dilation_iters + ), )