diff --git a/ultrack/core/tracker.py b/ultrack/core/tracker.py index f3d21a3..a285341 100644 --- a/ultrack/core/tracker.py +++ b/ultrack/core/tracker.py @@ -21,6 +21,7 @@ from ultrack.core.segmentation.processing import segment from ultrack.core.solve.processing import solve from ultrack.imgproc.flow import add_flow +from ultrack.ml.classification import add_nodes_prob from ultrack.utils.deprecation import rename_argument @@ -153,3 +154,11 @@ def to_tracks_layer(self, *args, **kwargs) -> Tuple[pd.DataFrame, Dict]: def export_by_extension(self, filename: str, overwrite: bool = False) -> None: self._assert_solved() export_tracks_by_extension(self.config, filename, overwrite=overwrite) + + @functools.wraps(add_nodes_prob) + def add_nodes_prob( + self, + indices: ArrayLike, + probs: ArrayLike, + ) -> None: + add_nodes_prob(self.config, indices, probs) diff --git a/ultrack/ml/__init__.py b/ultrack/ml/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ultrack/ml/classification.py b/ultrack/ml/classification.py new file mode 100644 index 0000000..de098fa --- /dev/null +++ b/ultrack/ml/classification.py @@ -0,0 +1,28 @@ +from numpy.typing import ArrayLike + +from ultrack.config.config import MainConfig +from ultrack.core.database import set_node_values + + +def add_nodes_prob( + config: MainConfig, + indices: ArrayLike, + probs: ArrayLike, +) -> None: + """ + Add nodes' probabilities to the segmentation/tracking database. + + Parameters + ---------- + config : MainConfig + Main configuration parameters. + indices : ArrayLike + Nodes' indices database index. + probs : ArrayLike + Nodes' probabilities. + """ + set_node_values( + config.data_config, + indices, + node_prob=probs, + )