Skip to content

Commit

Permalink
Update ctc track selection to optimal matching (#16)
Browse files Browse the repository at this point in the history
* update ctc track selection to optimal matching

* fix passenv DISPLAY and XAUTHORIY variables

* add ctc label dilation option
  • Loading branch information
JoOkuma authored Jan 5, 2023
1 parent 4221d5a commit 26c5dfd
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 18 deletions.
4 changes: 3 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ BACKEND =

[testenv]
description = Run unit-testing
passenv = DISPLAY XAUTHORITY
passenv =
DISPLAY
XAUTHORITY
deps =
pytest
coverage
Expand Down
2 changes: 2 additions & 0 deletions ultrack/cli/_test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def test_ctc_export(self, instance_config_path: str, tmp_path: Path) -> None:
"0.5,1,1",
"-ma",
"5",
"-di",
"1",
"-o",
str(tmp_path / "01_RES"),
]
Expand Down
20 changes: 15 additions & 5 deletions ultrack/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@
show_default=True,
help="Optional first frame path used to select a subset of lineages connected to this reference annotations.",
)
@click.option(
"--dilation-iters",
"-di",
default=0,
type=int,
show_default=True,
help="Iterations of radius 1 morphological dilations on labels, applied after scaling.",
)
@click.option(
"--stitch-tracks",
default=False,
Expand All @@ -60,6 +68,7 @@ def ctc_cli(
margin: int,
scale: Optional[Tuple[float]],
first_frame_path: Optional[Path],
dilation_iters: int,
stitch_tracks: bool,
overwrite: bool,
) -> None:
Expand All @@ -73,11 +82,12 @@ def ctc_cli(
to_ctc(
output_directory,
config.data_config,
margin,
scale,
first_frame,
stitch_tracks,
overwrite,
margin=margin,
scale=scale,
first_frame=first_frame,
dilation_iters=dilation_iters,
stitch_tracks=stitch_tracks,
overwrite=overwrite,
)


Expand Down
40 changes: 28 additions & 12 deletions ultrack/core/export/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import pandas as pd
import sqlalchemy as sqla
from numpy.typing import ArrayLike
from scipy.ndimage import zoom
from scipy.ndimage import generate_binary_structure, grey_dilation, zoom
from scipy.optimize import linear_sum_assignment
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -211,19 +213,20 @@ def select_tracks_from_first_frame(
query = session.query(NodeDB.pickle).where(NodeDB.t == 0, NodeDB.selected)
starting_nodes = [n for n, in query]

selected_track_ids = set()
centroids = np.asarray([n.centroid for n in starting_nodes])
root_centroids = np.asarray([n.centroid for n in starting_nodes])
marker_centroids = np.asarray(
[n.centroid for n in regionprops(first_frame, cache=False)]
)
D = cdist(marker_centroids, root_centroids)

_, root_ids = linear_sum_assignment(D)

selected_track_ids = set()
graph = tracks_forest(df)

for det in tqdm(regionprops(first_frame), "Selecting tracks from first trame"):
# select nearest node that contains the reference detection.
dist = np.square(centroids - det.centroid).sum(axis=1)
node = starting_nodes[np.argmin(dist)]
if node.contains(det.centroid):
# add the whole tree to the selection
track_id = df.loc[node.id, "track_id"]
selected_track_ids.update(connected_component(graph, track_id))
for root in tqdm(root_ids, "Selecting tracks from first trame"):
track_id = df.loc[starting_nodes[root].id, "track_id"]
selected_track_ids.update(connected_component(graph, track_id))

if stitch_tracks:
selected_df = stitch_tracks_df(graph, df, selected_track_ids)
Expand Down Expand Up @@ -252,6 +255,7 @@ def _write_tiff_buffer(
buffer: np.ndarray,
output_dir: Path,
scale: Optional[ArrayLike] = None,
dilation_iters: int = 0,
) -> None:
"""Writes a single tiff stack into `output_dir` / "mask%03d.tif"
Expand All @@ -265,12 +269,19 @@ def _write_tiff_buffer(
Output directory.
scale : Optional[ArrayLike], optional
Mask rescaling factor, by default None
dilation_iters: int
Iterations of radius 1 morphological dilations on labels, applied after scaling, by default 0.
"""
if scale is not None:
buffer = zoom(
buffer, scale[-buffer.ndim :], order=0, grid_mode=True, mode="grid-constant"
)

footprint = generate_binary_structure(buffer.ndim, 1)
for _ in range(dilation_iters):
dilated = grey_dilation(buffer, footprint=footprint)
np.putmask(buffer, buffer == 0, dilated)

imwrite(output_dir / f"mask{t:03}.tif", buffer)


Expand All @@ -280,6 +291,7 @@ def to_ctc(
margin: int = 0,
scale: Optional[Tuple[float]] = None,
first_frame: Optional[ArrayLike] = None,
dilation_iters: int = 0,
stitch_tracks: bool = False,
overwrite: bool = False,
) -> None:
Expand All @@ -298,6 +310,8 @@ def to_ctc(
Margin used to filter out nodes and splitting their tracklets
first_frame : Optional[ArrayLike], optional
Optional first frame detection mask to select a subset of tracks (e.g. Fluo-N3DL-DRO), by default None
dilation_iters: int
Iterations of radius 1 morphological dilations on labels, applied after scaling, by default 0.
stitch_tracks: bool, optional
Stitches (connects) incomplete tracks nearby tracks on subsequent time point, by default False
overwrite : bool, optional
Expand Down Expand Up @@ -372,7 +386,9 @@ def condition(node: Node) -> bool:
export_segmentation_generic(
data_config,
df,
_write_tiff_buffer(scale=scale, output_dir=output_dir),
_write_tiff_buffer(
output_dir=output_dir, scale=scale, dilation_iters=dilation_iters
),
)


Expand Down

0 comments on commit 26c5dfd

Please sign in to comment.