Skip to content

Commit

Permalink
Add nodes' probabilities support to cost function (#143)
Browse files Browse the repository at this point in the history
* added nodes' probabilities support to cost function

* improved assert same length usage

* add option to update nodes' probabilities

* fixing circular import by moving NO_PARENTS to constants files
  • Loading branch information
JoOkuma authored Oct 2, 2024
1 parent 6ddaf9f commit 5ca946d
Show file tree
Hide file tree
Showing 34 changed files with 214 additions and 74 deletions.
3 changes: 2 additions & 1 deletion ultrack/cli/data_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

from ultrack.cli.utils import config_option
from ultrack.config import MainConfig
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB
from ultrack.core.database import LinkDB, NodeDB
from ultrack.core.export.utils import solution_dataframe_from_sql
from ultrack.tracks.graph import add_track_ids_to_tracks_df
from ultrack.utils.constants import NO_PARENT
from ultrack.utils.printing import pretty_print_df


Expand Down
2 changes: 1 addition & 1 deletion ultrack/config/trackingconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TrackingConfig(BaseModel):
"""``SPECIAL``: Solver method, `reference <https://docs.python-mip.com/en/latest/classes.html#lp-method>`_"""

link_function: LinkFunctionChoices = "power"
"""``SPECIAL``: Function used to transform the edge weights, `identity` or `power`"""
"""``SPECIAL``: Function used to transform the edge and node weights, `identity` or `power`"""

power: float = 4
r"""``SPECIAL``: Expoent :math:`\eta` of power transform, :math:`w_{pq}^\eta` """
Expand Down
8 changes: 5 additions & 3 deletions ultrack/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
from sqlalchemy.orm import Session, declarative_base

from ultrack.config.dataconfig import DatabaseChoices, DataConfig

# constant value to indicate it has no parent
NO_PARENT = -1
from ultrack.utils.array import assert_same_length
from ultrack.utils.constants import NO_PARENT

Base = declarative_base()

Expand Down Expand Up @@ -93,6 +92,7 @@ class NodeDB(Base):
area = Column(Integer)
selected = Column(Boolean, default=False)
pickle = Column(MaybePickleType)
node_prob = Column(Float, default=-1.0)
segm_annot = Column(Enum(NodeSegmAnnotation), default=NodeSegmAnnotation.UNKNOWN)
node_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN)
appear_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN)
Expand Down Expand Up @@ -182,6 +182,8 @@ def set_node_values(
if hasattr(v, "tolist"):
kwargs[k] = v.tolist()

assert_same_length(**kwargs)

records = [
{k: v[i] for k, v in kwargs.items()} for i in range(len(kwargs["node_id"]))
]
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/_test/test_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import pandas as pd
import pytest

from ultrack.core.database import NO_PARENT
from ultrack.core.export.ctc import ctc_compress_forest, stitch_tracks_df
from ultrack.tracks.graph import add_track_ids_to_tracks_df
from ultrack.utils.constants import NO_PARENT


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/_test/test_networkx.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pandas as pd
import pytest

from ultrack.core.database import NO_PARENT
from ultrack.core.export import tracks_layer_to_networkx
from ultrack.utils.constants import NO_PARENT


@pytest.mark.parametrize("children_to_parent", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/_test/test_trackmate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pandas as pd
import pytest

from ultrack.core.database import NO_PARENT
from ultrack.core.export.trackmate import tracks_layer_to_trackmate
from ultrack.utils.constants import NO_PARENT

pytrackmate = pytest.importorskip("pytrackmate")

Expand Down
3 changes: 2 additions & 1 deletion ultrack/core/export/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ultrack.config.config import MainConfig
from ultrack.config.dataconfig import DataConfig
from ultrack.core.database import NO_PARENT, NodeDB
from ultrack.core.database import NodeDB
from ultrack.core.export.utils import (
export_segmentation_generic,
filter_nodes_generic,
Expand All @@ -32,6 +32,7 @@
tracks_df_forest,
)
from ultrack.tracks.stats import estimate_drift
from ultrack.utils.constants import NO_PARENT
from ultrack.utils.data import validate_and_overwrite_path

logging.basicConfig()
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/networkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import pandas as pd

from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT
from ultrack.core.export.tracks_layer import to_tracks_layer
from ultrack.tracks.graph import _create_tracks_forest
from ultrack.utils.constants import NO_PARENT

LOG = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/trackmate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pandas as pd

from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT
from ultrack.core.export.tracks_layer import to_tracks_layer
from ultrack.utils.constants import NO_PARENT


def _set_filter_elem(elem: ET.Element) -> None:
Expand Down
3 changes: 2 additions & 1 deletion ultrack/core/export/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from toolz import curry

from ultrack.config.dataconfig import DataConfig
from ultrack.core.database import NO_PARENT, NodeDB
from ultrack.core.database import NodeDB
from ultrack.core.segmentation.node import Node
from ultrack.utils.constants import NO_PARENT
from ultrack.utils.multiprocessing import multiprocessing_apply

LOG = logging.getLogger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion ultrack/core/linking/_test/test_link_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from sqlalchemy.orm import Session

from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB
from ultrack.core.database import LinkDB, NodeDB
from ultrack.core.linking.utils import clear_linking_data
from ultrack.utils.constants import NO_PARENT


@pytest.mark.parametrize(
Expand Down
5 changes: 2 additions & 3 deletions ultrack/core/solve/_test/test_sql_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from ultrack import solve, to_tracks_layer
from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB, VarAnnotation
from ultrack.core.database import LinkDB, NodeDB, VarAnnotation
from ultrack.core.solve.sqltracking import SQLTracking
from ultrack.utils.constants import NO_PARENT

_CONFIG_PARAMS = {
"segmentation.n_workers": 4,
Expand Down Expand Up @@ -125,7 +126,6 @@ def test_annotations_sql_tracking(

solve(config, overwrite=True, use_annotations=True)
tracks_df, _ = to_tracks_layer(config)
print(tracks_df)

engine = sqla.create_engine(config.data_config.database_path)
with Session(engine) as session:
Expand All @@ -136,6 +136,5 @@ def test_annotations_sql_tracking(

solve(config, overwrite=True, use_annotations=True)
tracks_df_annot, _ = to_tracks_layer(config)
print(tracks_df_annot)

assert len(tracks_df) > len(tracks_df_annot)
104 changes: 82 additions & 22 deletions ultrack/core/solve/solver/_test/test_solvers.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
from itertools import product

import numpy as np
import pandas as pd
import pytest

from ultrack.config.config import MainConfig
from ultrack.core.solve.solver.base_solver import BaseSolver
from ultrack.core.solve.solver.heuristic.heuristic_solver import HeuristicSolver
from ultrack.core.solve.solver.mip_solver import MIPSolver


@pytest.mark.parametrize(
"solver,config_content",
list(
product(
[MIPSolver, HeuristicSolver],
[
{
"tracking.appear_weight": -0.25,
"tracking.disappear_weight": -1.0,
"tracking.division_weight": -0.5,
"tracking.link_function": "identity",
"tracking.bias": 0,
}
],
)
),
indirect=["config_content"],
"config_content",
[
{
"tracking.appear_weight": -0.25,
"tracking.disappear_weight": -1.0,
"tracking.division_weight": -0.5,
"tracking.link_function": "identity",
"tracking.bias": 0,
}
],
indirect=True,
)
def test_solvers_optimize(solver: BaseSolver, config_instance: MainConfig) -> None:
def test_solvers_optimize(config_instance: MainConfig) -> None:
"""
This demo builds a very simple graph with 7 nodes and a single overlap constraint (2,5)
and a two possible divisions on 2 and 6.
Expand All @@ -53,7 +44,7 @@ def test_solvers_optimize(solver: BaseSolver, config_instance: MainConfig) -> No
Result: 0.5 + 0.5 + 1.0 + 0.7 - division_weight
"""
solver = solver(config_instance.tracking_config)
solver = MIPSolver(config_instance.tracking_config)

nodes = np.array([1, 2, 3, 4, 5, 6, 7])
is_first = np.array([1, 0, 0, 0, 0, 0, 0], dtype=bool)
Expand Down Expand Up @@ -160,6 +151,75 @@ def test_fixed_nodes_constraint_solver(config_instance: MainConfig) -> None:
)


def test_solver_with_node_probabilities(config_instance: MainConfig) -> None:
"""
Edge -C- denotes contraint.
Graph:
0.3 0.7 1.0 1.0
1 - 0.5 - 2 - 0.5 - 3 - 0.5 - 4
| \\ / \\ due linting software
C 1.0 0.9
| \\ /
5 - 0.5 - 6 - 0.7 - 7
node w. 0.5 1.0 1.0
Solution:
0.3 0.7 1.0 1.0
1 - 0.5 - 2 - 0.5 - 3 - 0.5 - 4
\\
1.0
\\
6 - 0.7 - 7
node w. 0.5 1.0 1.0
Result: 0.3 + 0.7 + 1.0 + 1.0 + 1.0 + 1.0 +
0.5 + 0.5 + 0.5 + 1.0 + 0.7 - division_weight
"""
solver = MIPSolver(config_instance.tracking_config)

nodes = np.array([1, 2, 3, 4, 5, 6, 7])
nodes_probs = np.array([0.3, 0.7, 1.0, 1.0, 0.5, 1.0, 1.0])
is_first = np.array([1, 0, 0, 0, 0, 0, 0], dtype=bool)
is_last = np.array([0, 0, 0, 1, 0, 0, 1], dtype=bool)

solver.add_nodes(nodes, is_first, is_last, nodes_prob=nodes_probs)

edges = np.array([[1, 2], [2, 3], [2, 6], [3, 4], [5, 6], [6, 4], [6, 7]])

weights = np.array([0.5, 0.5, 1.0, 0.5, 0.5, 0.9, 0.7])
solver.add_edges(edges[:, 0], edges[:, 1], weights)

solver.set_standard_constraints()

solver.add_overlap_constraints([2], [5])

objective = solver.optimize()
solution = solver.solution()

expected_solution = pd.DataFrame(
data=[pd.NA, 1, 2, 3, 2, 6],
index=[1, 2, 3, 4, 6, 7],
columns=["parent_id"],
dtype=pd.Int64Dtype(),
)
expected_edges = np.array([1, 1, 1, 1, 0, 0, 1], dtype=bool)

assert solution.shape == expected_solution.shape
assert np.all(solution.index.isin(expected_solution.index))
assert np.all(
expected_solution.loc[solution.index, "parent_id"] == solution["parent_id"]
)
assert np.allclose(
objective,
nodes_probs[expected_solution.index.to_numpy() - 1].sum()
+ weights[expected_edges].sum()
+ config_instance.tracking_config.division_weight,
)


def test_fixed_edges_constraint_solver(config_instance: MainConfig) -> None:
"""
Same graph as before but with a fixed division on node 6.
Expand Down
20 changes: 8 additions & 12 deletions ultrack/core/solve/solver/base_solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Literal
from typing import Literal, Optional

import pandas as pd
from numpy.typing import ArrayLike
Expand All @@ -22,19 +22,13 @@ def __init__(
"""
self._config = config

@staticmethod
def _assert_same_length(**kwargs) -> None:
"""Validates if key-word arguments have the same length."""
for k1, v1 in kwargs.items():
for k2, v2 in kwargs.items():
if len(v2) != len(v1):
raise ValueError(
f"`{k1}` and `{k2}` must have the same length. Found {len(v1)} and {len(v2)}."
)

@abstractmethod
def add_nodes(
self, indices: ArrayLike, is_first_t: ArrayLike, is_last_t: ArrayLike
self,
indices: ArrayLike,
is_first_t: ArrayLike,
is_last_t: ArrayLike,
node_prob: Optional[ArrayLike] = None,
) -> None:
"""Add nodes variables solver.
Expand All @@ -46,6 +40,8 @@ def add_nodes(
Boolean array indicating if it belongs to first time point and it won't receive appearance penalization.
is_last_t : ArrayLike
Boolean array indicating if it belongs to last time point and it won't receive disappearance penalization.
node_prob: Optional[ArrayLike]
If provided assigns a node probability score to the objective function.
"""

@abstractmethod
Expand Down
9 changes: 4 additions & 5 deletions ultrack/core/solve/solver/heuristic/heuristic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from skimage.util._map_array import ArrayMap

from ultrack.config.config import TrackingConfig
from ultrack.core.database import NO_PARENT
from ultrack.core.solve.solver.base_solver import BaseSolver
from ultrack.core.solve.solver.heuristic._numba_heuristic_solver import (
NumbaHeuristicSolver,
)
from ultrack.utils.array import assert_same_length
from ultrack.utils.constants import NO_PARENT

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,9 +69,7 @@ def add_nodes(
if hasattr(self, "_forbidden"):
raise ValueError("Nodes have already been added.")

self._assert_same_length(
indices=indices, is_first_t=is_first_t, is_last_t=is_last_t
)
assert_same_length(indices=indices, is_first_t=is_first_t, is_last_t=is_last_t)

indices = np.asarray(indices)
size = len(indices)
Expand Down Expand Up @@ -111,7 +110,7 @@ def add_edges(
if hasattr(self, "_weights"):
raise ValueError("Edges have already been added.")

self._assert_same_length(weights=weights, sources=sources, targets=targets)
assert_same_length(weights=weights, sources=sources, targets=targets)

self._weights = np.asarray(
self._config.apply_link_function(weights), np.float32
Expand Down
Loading

0 comments on commit 5ca946d

Please sign in to comment.