Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nodes' probabilities support to cost function #143

Merged
merged 5 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ultrack/cli/data_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

from ultrack.cli.utils import config_option
from ultrack.config import MainConfig
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB
from ultrack.core.database import LinkDB, NodeDB
from ultrack.core.export.utils import solution_dataframe_from_sql
from ultrack.tracks.graph import add_track_ids_to_tracks_df
from ultrack.utils.constants import NO_PARENT
from ultrack.utils.printing import pretty_print_df


Expand Down
2 changes: 1 addition & 1 deletion ultrack/config/trackingconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TrackingConfig(BaseModel):
"""``SPECIAL``: Solver method, `reference <https://docs.python-mip.com/en/latest/classes.html#lp-method>`_"""

link_function: LinkFunctionChoices = "power"
"""``SPECIAL``: Function used to transform the edge weights, `identity` or `power`"""
"""``SPECIAL``: Function used to transform the edge and node weights, `identity` or `power`"""

power: float = 4
r"""``SPECIAL``: Expoent :math:`\eta` of power transform, :math:`w_{pq}^\eta` """
Expand Down
8 changes: 5 additions & 3 deletions ultrack/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
from sqlalchemy.orm import Session, declarative_base

from ultrack.config.dataconfig import DatabaseChoices, DataConfig

# constant value to indicate it has no parent
NO_PARENT = -1
from ultrack.utils.array import assert_same_length
from ultrack.utils.constants import NO_PARENT

Base = declarative_base()

Expand Down Expand Up @@ -93,6 +92,7 @@ class NodeDB(Base):
area = Column(Integer)
selected = Column(Boolean, default=False)
pickle = Column(MaybePickleType)
node_prob = Column(Float, default=-1.0)
segm_annot = Column(Enum(NodeSegmAnnotation), default=NodeSegmAnnotation.UNKNOWN)
node_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN)
appear_annot = Column(Enum(VarAnnotation), default=VarAnnotation.UNKNOWN)
Expand Down Expand Up @@ -182,6 +182,8 @@ def set_node_values(
if hasattr(v, "tolist"):
kwargs[k] = v.tolist()

assert_same_length(**kwargs)

records = [
{k: v[i] for k, v in kwargs.items()} for i in range(len(kwargs["node_id"]))
]
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/_test/test_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import pandas as pd
import pytest

from ultrack.core.database import NO_PARENT
from ultrack.core.export.ctc import ctc_compress_forest, stitch_tracks_df
from ultrack.tracks.graph import add_track_ids_to_tracks_df
from ultrack.utils.constants import NO_PARENT


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/_test/test_networkx.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pandas as pd
import pytest

from ultrack.core.database import NO_PARENT
from ultrack.core.export import tracks_layer_to_networkx
from ultrack.utils.constants import NO_PARENT


@pytest.mark.parametrize("children_to_parent", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/_test/test_trackmate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pandas as pd
import pytest

from ultrack.core.database import NO_PARENT
from ultrack.core.export.trackmate import tracks_layer_to_trackmate
from ultrack.utils.constants import NO_PARENT

pytrackmate = pytest.importorskip("pytrackmate")

Expand Down
3 changes: 2 additions & 1 deletion ultrack/core/export/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ultrack.config.config import MainConfig
from ultrack.config.dataconfig import DataConfig
from ultrack.core.database import NO_PARENT, NodeDB
from ultrack.core.database import NodeDB
from ultrack.core.export.utils import (
export_segmentation_generic,
filter_nodes_generic,
Expand All @@ -32,6 +32,7 @@
tracks_df_forest,
)
from ultrack.tracks.stats import estimate_drift
from ultrack.utils.constants import NO_PARENT
from ultrack.utils.data import validate_and_overwrite_path

logging.basicConfig()
Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/networkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import pandas as pd

from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT
from ultrack.core.export.tracks_layer import to_tracks_layer
from ultrack.tracks.graph import _create_tracks_forest
from ultrack.utils.constants import NO_PARENT

LOG = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion ultrack/core/export/trackmate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pandas as pd

from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT
from ultrack.core.export.tracks_layer import to_tracks_layer
from ultrack.utils.constants import NO_PARENT


def _set_filter_elem(elem: ET.Element) -> None:
Expand Down
3 changes: 2 additions & 1 deletion ultrack/core/export/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from toolz import curry

from ultrack.config.dataconfig import DataConfig
from ultrack.core.database import NO_PARENT, NodeDB
from ultrack.core.database import NodeDB
from ultrack.core.segmentation.node import Node
from ultrack.utils.constants import NO_PARENT
from ultrack.utils.multiprocessing import multiprocessing_apply

LOG = logging.getLogger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion ultrack/core/linking/_test/test_link_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from sqlalchemy.orm import Session

from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB
from ultrack.core.database import LinkDB, NodeDB
from ultrack.core.linking.utils import clear_linking_data
from ultrack.utils.constants import NO_PARENT


@pytest.mark.parametrize(
Expand Down
5 changes: 2 additions & 3 deletions ultrack/core/solve/_test/test_sql_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from ultrack import solve, to_tracks_layer
from ultrack.config.config import MainConfig
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB, VarAnnotation
from ultrack.core.database import LinkDB, NodeDB, VarAnnotation
from ultrack.core.solve.sqltracking import SQLTracking
from ultrack.utils.constants import NO_PARENT

_CONFIG_PARAMS = {
"segmentation.n_workers": 4,
Expand Down Expand Up @@ -125,7 +126,6 @@ def test_annotations_sql_tracking(

solve(config, overwrite=True, use_annotations=True)
tracks_df, _ = to_tracks_layer(config)
print(tracks_df)

engine = sqla.create_engine(config.data_config.database_path)
with Session(engine) as session:
Expand All @@ -136,6 +136,5 @@ def test_annotations_sql_tracking(

solve(config, overwrite=True, use_annotations=True)
tracks_df_annot, _ = to_tracks_layer(config)
print(tracks_df_annot)

assert len(tracks_df) > len(tracks_df_annot)
104 changes: 82 additions & 22 deletions ultrack/core/solve/solver/_test/test_solvers.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
from itertools import product

import numpy as np
import pandas as pd
import pytest

from ultrack.config.config import MainConfig
from ultrack.core.solve.solver.base_solver import BaseSolver
from ultrack.core.solve.solver.heuristic.heuristic_solver import HeuristicSolver
from ultrack.core.solve.solver.mip_solver import MIPSolver


@pytest.mark.parametrize(
"solver,config_content",
list(
product(
[MIPSolver, HeuristicSolver],
[
{
"tracking.appear_weight": -0.25,
"tracking.disappear_weight": -1.0,
"tracking.division_weight": -0.5,
"tracking.link_function": "identity",
"tracking.bias": 0,
}
],
)
),
indirect=["config_content"],
"config_content",
[
{
"tracking.appear_weight": -0.25,
"tracking.disappear_weight": -1.0,
"tracking.division_weight": -0.5,
"tracking.link_function": "identity",
"tracking.bias": 0,
}
],
indirect=True,
)
def test_solvers_optimize(solver: BaseSolver, config_instance: MainConfig) -> None:
def test_solvers_optimize(config_instance: MainConfig) -> None:
"""
This demo builds a very simple graph with 7 nodes and a single overlap constraint (2,5)
and a two possible divisions on 2 and 6.
Expand All @@ -53,7 +44,7 @@ def test_solvers_optimize(solver: BaseSolver, config_instance: MainConfig) -> No

Result: 0.5 + 0.5 + 1.0 + 0.7 - division_weight
"""
solver = solver(config_instance.tracking_config)
solver = MIPSolver(config_instance.tracking_config)

nodes = np.array([1, 2, 3, 4, 5, 6, 7])
is_first = np.array([1, 0, 0, 0, 0, 0, 0], dtype=bool)
Expand Down Expand Up @@ -160,6 +151,75 @@ def test_fixed_nodes_constraint_solver(config_instance: MainConfig) -> None:
)


def test_solver_with_node_probabilities(config_instance: MainConfig) -> None:
"""
Edge -C- denotes contraint.

Graph:

0.3 0.7 1.0 1.0
1 - 0.5 - 2 - 0.5 - 3 - 0.5 - 4
| \\ / \\ due linting software
C 1.0 0.9
| \\ /
5 - 0.5 - 6 - 0.7 - 7
node w. 0.5 1.0 1.0

Solution:

0.3 0.7 1.0 1.0
1 - 0.5 - 2 - 0.5 - 3 - 0.5 - 4
\\
1.0
\\
6 - 0.7 - 7
node w. 0.5 1.0 1.0

Result: 0.3 + 0.7 + 1.0 + 1.0 + 1.0 + 1.0 +
0.5 + 0.5 + 0.5 + 1.0 + 0.7 - division_weight
"""
solver = MIPSolver(config_instance.tracking_config)

nodes = np.array([1, 2, 3, 4, 5, 6, 7])
nodes_probs = np.array([0.3, 0.7, 1.0, 1.0, 0.5, 1.0, 1.0])
is_first = np.array([1, 0, 0, 0, 0, 0, 0], dtype=bool)
is_last = np.array([0, 0, 0, 1, 0, 0, 1], dtype=bool)

solver.add_nodes(nodes, is_first, is_last, nodes_prob=nodes_probs)

edges = np.array([[1, 2], [2, 3], [2, 6], [3, 4], [5, 6], [6, 4], [6, 7]])

weights = np.array([0.5, 0.5, 1.0, 0.5, 0.5, 0.9, 0.7])
solver.add_edges(edges[:, 0], edges[:, 1], weights)

solver.set_standard_constraints()

solver.add_overlap_constraints([2], [5])

objective = solver.optimize()
solution = solver.solution()

expected_solution = pd.DataFrame(
data=[pd.NA, 1, 2, 3, 2, 6],
index=[1, 2, 3, 4, 6, 7],
columns=["parent_id"],
dtype=pd.Int64Dtype(),
)
expected_edges = np.array([1, 1, 1, 1, 0, 0, 1], dtype=bool)

assert solution.shape == expected_solution.shape
assert np.all(solution.index.isin(expected_solution.index))
assert np.all(
expected_solution.loc[solution.index, "parent_id"] == solution["parent_id"]
)
assert np.allclose(
objective,
nodes_probs[expected_solution.index.to_numpy() - 1].sum()
+ weights[expected_edges].sum()
+ config_instance.tracking_config.division_weight,
)


def test_fixed_edges_constraint_solver(config_instance: MainConfig) -> None:
"""
Same graph as before but with a fixed division on node 6.
Expand Down
20 changes: 8 additions & 12 deletions ultrack/core/solve/solver/base_solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Literal
from typing import Literal, Optional

import pandas as pd
from numpy.typing import ArrayLike
Expand All @@ -22,19 +22,13 @@ def __init__(
"""
self._config = config

@staticmethod
def _assert_same_length(**kwargs) -> None:
"""Validates if key-word arguments have the same length."""
for k1, v1 in kwargs.items():
for k2, v2 in kwargs.items():
if len(v2) != len(v1):
raise ValueError(
f"`{k1}` and `{k2}` must have the same length. Found {len(v1)} and {len(v2)}."
)

@abstractmethod
def add_nodes(
self, indices: ArrayLike, is_first_t: ArrayLike, is_last_t: ArrayLike
self,
indices: ArrayLike,
is_first_t: ArrayLike,
is_last_t: ArrayLike,
node_prob: Optional[ArrayLike] = None,
) -> None:
"""Add nodes variables solver.

Expand All @@ -46,6 +40,8 @@ def add_nodes(
Boolean array indicating if it belongs to first time point and it won't receive appearance penalization.
is_last_t : ArrayLike
Boolean array indicating if it belongs to last time point and it won't receive disappearance penalization.
node_prob: Optional[ArrayLike]
If provided assigns a node probability score to the objective function.
"""

@abstractmethod
Expand Down
9 changes: 4 additions & 5 deletions ultrack/core/solve/solver/heuristic/heuristic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from skimage.util._map_array import ArrayMap

from ultrack.config.config import TrackingConfig
from ultrack.core.database import NO_PARENT
from ultrack.core.solve.solver.base_solver import BaseSolver
from ultrack.core.solve.solver.heuristic._numba_heuristic_solver import (
NumbaHeuristicSolver,
)
from ultrack.utils.array import assert_same_length
from ultrack.utils.constants import NO_PARENT

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,9 +69,7 @@ def add_nodes(
if hasattr(self, "_forbidden"):
raise ValueError("Nodes have already been added.")

self._assert_same_length(
indices=indices, is_first_t=is_first_t, is_last_t=is_last_t
)
assert_same_length(indices=indices, is_first_t=is_first_t, is_last_t=is_last_t)

indices = np.asarray(indices)
size = len(indices)
Expand Down Expand Up @@ -111,7 +110,7 @@ def add_edges(
if hasattr(self, "_weights"):
raise ValueError("Edges have already been added.")

self._assert_same_length(weights=weights, sources=sources, targets=targets)
assert_same_length(weights=weights, sources=sources, targets=targets)

self._weights = np.asarray(
self._config.apply_link_function(weights), np.float32
Expand Down
Loading
Loading