update matching method and tests
quantumjot committed Nov 17, 2023
1 parent 62f7eab commit 1c22d78
Showing 3 changed files with 99 additions and 30 deletions.
53 changes: 32 additions & 21 deletions src/umetrics/core.py
@@ -63,6 +63,11 @@ def find_matches(
         The reference (ground truth) segmentation.
     pred :
         The predicted segmentation.
+    strict : bool
+        Whether to use strict matching, i.e. only allowing matches above a
+        threshold IoU value.
+    iou_threshold :
+        A threshold value to use when strict matching.
 
     Return
     ------
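
With a cost of 1 - IoU (consistent with the `cost_threshold = 1.0 - iou_threshold` introduced later in this diff), an IoU threshold t corresponds to a cost ceiling of 1 - t. A minimal sketch of the thresholding rule, using hypothetical values:

    # Strict matching keeps only candidate pairs whose IoU clears the
    # threshold; in cost terms (cost = 1 - IoU) that means cost <= 1 - t.
    iou_threshold = 0.5
    candidate_ious = [0.9, 0.4, 0.7]
    kept = [iou for iou in candidate_ious if iou >= iou_threshold]
    assert kept == [0.9, 0.7]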
@@ -71,17 +76,9 @@ def find_matches(
     """
 
-    # return a default dictionary of no matches
-    matches = {
-        "true_matches": [],
-        "true_matches_IoU": [],
-        "in_ref_only": set(ref.labels),
-        "in_pred_only": set(pred.labels),
-    }
-
     # make an infinite cost matrix, so that we only consider matches where
     # there is some overlap in the masks
-    cost_matrix = np.full((len(ref.labels), len(pred.labels)), np.inf)
+    cost_matrix = np.full((len(ref.labels), len(pred.labels)), 1e8)
 
     for r_id, ref_label in enumerate(ref.labels):
         mask = ref.labeled == ref_label
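
The switch from `np.inf` to `1e8` matters because `scipy.optimize.linear_sum_assignment` raises a `ValueError` when no complete assignment with finite cost exists. A short sketch of the behaviour the finite sentinel works around:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # Any complete assignment of this matrix includes an infinite cost,
    # so the solver refuses it outright.
    costs = np.array([[0.5, np.inf], [np.inf, np.inf]])
    try:
        linear_sum_assignment(costs)
    except ValueError as err:
        print(err)  # typically: "cost matrix is infeasible"

    # A large finite sentinel always yields a solution; sentinel-valued
    # edges are then filtered out afterwards (see the `edges` list in the
    # next hunk).
    finite = np.where(np.isinf(costs), 1e8, costs)
    rows, cols = linear_sum_assignment(finite)
    edges = [(int(r), int(c)) for r, c in zip(rows, cols) if finite[r, c] <= 1]
    assert edges == [(0, 0)]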
@@ -95,13 +92,27 @@ def find_matches(
 
     # if it's strict, make sure every element is above the threshold
     if strict:
-        assert np.all(cost_matrix >= iou_threshold)
-
-    try:
-        sol_row, sol_col = linear_sum_assignment(cost_matrix)
-    except ValueError:
+        cost_threshold = 1.0 - iou_threshold
+        assert np.all(cost_matrix >= cost_threshold), cost_matrix
 
+    # solve
+    sol_row, sol_col = linear_sum_assignment(cost_matrix)
+
+    # remove infeasible solutions
+    edges = [(r, c) for r, c in zip(sol_row, sol_col) if cost_matrix[r, c] <= 1]
+
+    # return a default dictionary if there are no matches
+    if not edges:
+        matches = {
+            "true_matches": [],
+            "true_matches_IoU": [],
+            "in_ref_only": set(ref.labels),
+            "in_pred_only": set(pred.labels),
+        }
+        return matches
+
+    sol_row, sol_col = zip(*edges)
 
     # now that we've solved the LAP, find the matches that have been made
     used_ref = [ref.labels[row] for row in sol_row]
     used_pred = [pred.labels[col] for col in sol_col]
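
The early return for an empty `edges` list is load-bearing. The removed try/except handled the infeasible-matrix `ValueError` from `linear_sum_assignment`; with finite sentinel costs the solver always returns, so the no-match case now surfaces as an empty `edges` list instead, and unpacking `zip(*edges)` on an empty list would itself raise:

    # Unpacking an empty zip raises, so the no-matches case must be
    # handled before `sol_row, sol_col = zip(*edges)` is reached.
    edges = []
    try:
        sol_row, sol_col = zip(*edges)
    except ValueError as err:
        print(err)  # not enough values to unpack (expected 2, got 0)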
@@ -353,13 +364,14 @@ def n_false_positives(self):
     @property
     def per_object_IoU(self):
         """Intersection over Union (IoU) metric"""
-        iou = []
-        for m in self.true_positives:
-            mask_ref = self._reference.labeled == m[0]
-            mask_pred = self._predicted.labeled == m[1]
+        # iou = []
+        # for m in self.true_positives:
+        #     mask_ref = self._reference.labeled == m[0]
+        #     mask_pred = self._predicted.labeled == m[1]
 
-            iou.append(_IoU(mask_ref, mask_pred))
-        return iou
+        #     iou.append(_IoU(mask_ref, mask_pred))
+        # return iou
+        return self._matches["true_matches_IoU"]
 
     @property
     def per_image_pixel_identity(self):
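
The property now returns the IoU values cached during matching rather than recomputing per-object masks, at the cost of coupling it to the bookkeeping in `find_matches`. `_IoU` itself is not shown in this diff; for boolean masks it is presumably the standard ratio:

    import numpy as np
    import numpy.typing as npt

    def _IoU(y_true: npt.NDArray, y_pred: npt.NDArray) -> float:
        """Intersection over union of two binary masks (assumed implementation)."""
        intersection = np.logical_and(y_true, y_pred).sum()
        union = np.logical_or(y_true, y_pred).sum()
        return float(intersection / union) if union > 0 else 0.0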
@@ -438,7 +450,6 @@ def batch(files, **kwargs):
     """batch process a list of files"""
     metrix = []
     for f_ref, f_pred in files:
-        print(f_pred)
         true = imread(f_ref)
         pred = imread(f_pred)
         result = calculate(true, pred, **kwargs).results
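
A hypothetical call, assuming `**kwargs` is forwarded through `calculate` to the matcher (the file names below are placeholders):

    # Each entry pairs a ground-truth image path with a prediction path.
    files = [("ref_000.tif", "pred_000.tif"), ("ref_001.tif", "pred_001.tif")]
    results = batch(files, strict=True, iou_threshold=0.5)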
52 changes: 49 additions & 3 deletions tests/conftest.py
@@ -1,9 +1,11 @@
 import pytest
 import numpy as np
 import numpy.typing as npt
 
+from skimage.util import montage
 from typing import Tuple
 
-SEED = 12345
+SEED = 12347
 RNG = np.random.default_rng(seed=SEED)

@@ -22,7 +24,51 @@ def _IoU(y_true: npt.NDArray, y_pred: npt.NDArray) -> float:
 
 
 @pytest.fixture
-def image_pair() -> Tuple[npt.NDArray, npt.NDArray, float]:
+def image_grid(N: int = 3, sz: int = 32) -> Tuple[npt.NDArray, npt.NDArray, dict]:
+    image_types = RNG.choice(
+        ["pair", "missing_true", "missing_pred"], size=(N * N,)
+    ).tolist()
+    true_stack = np.zeros((N * N, sz, sz), dtype=np.uint8)
+    pred_stack = np.zeros((N * N, sz, sz), dtype=np.uint8)
+
+    ious = []
+
+    for idx, img_type in enumerate(image_types):
+        if img_type == "pair":
+            true_stack[idx, ...] = _synthetic_image()
+            pred_stack[idx, ...] = _synthetic_image()
+            ious.append(_IoU(true_stack[idx, ...], pred_stack[idx, ...]))
+        elif img_type == "missing_true":
+            pred_stack[idx, ...] = _synthetic_image()
+            ious.append(0.0)
+        else:
+            true_stack[idx, ...] = _synthetic_image()
+            ious.append(0.0)
+
+    n_pairs = image_types.count("pair")
+    n_missing_pred = image_types.count("missing_pred")
+    n_missing_true = image_types.count("missing_true")
+
+    stats = {
+        "n_pairs": n_pairs,
+        "n_true": n_pairs + n_missing_pred,
+        "n_pred": n_pairs + n_missing_true,
+        "n_missing_pred": n_missing_pred,
+        "n_missing_true": n_missing_true,
+        "n_total": len(image_types),
+        "IoU": ious,
+    }
+
+    return (
+        montage(true_stack, rescale_intensity=False, grid_shape=(sz, sz)),
+        montage(pred_stack, rescale_intensity=False, grid_shape=(sz, sz)),
+        stats,
+    )
+
+
+@pytest.fixture
+def image_pair() -> Tuple[npt.NDArray, npt.NDArray, dict]:
     y_true = _synthetic_image()
     y_pred = _synthetic_image()
-    return y_true, y_pred, _IoU(y_true, y_pred)
+    stats = {"IoU": _IoU(y_true, y_pred)}
+    return y_true, y_pred, stats
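
`skimage.util.montage` tiles an `(n, rows, cols)` stack into a single 2-D image. Note that the fixture passes `grid_shape=(sz, sz)`, i.e. it requests a 32x32 grid of tiles for the N*N stack; the sketch below instead uses `grid_shape=(N, N)` to tile a nine-image stack exactly:

    import numpy as np
    from skimage.util import montage

    stack = np.zeros((9, 32, 32), dtype=np.uint8)  # nine 32x32 tiles
    tiled = montage(stack, rescale_intensity=False, grid_shape=(3, 3))
    assert tiled.shape == (96, 96)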
24 changes: 18 additions & 6 deletions tests/test_metrics.py
@@ -7,7 +7,8 @@
 @pytest.mark.parametrize("strict", (False, True))
 def test_calculate(image_pair, strict):
     """Run the metrics on a pair of images."""
-    y_true, y_pred, IoU = image_pair
+    y_true, y_pred, stats = image_pair
+    IoU = stats["IoU"]
 
     result = umetrics.calculate(y_true, y_pred, strict=strict)
@@ -21,8 +22,8 @@ def test_calculate(image_pair, strict):
 
 
 def test_calculate_no_true(image_pair):
-    """Run the metrics on a pair of images where there is no object in the GT."""
-    y_true, y_pred, IoU = image_pair
+    """Test a pair of images where there is no object in the GT."""
+    y_true, y_pred, _ = image_pair
     y_true = np.zeros_like(y_pred)
 
     result = umetrics.calculate(y_true, y_pred)
@@ -34,9 +35,8 @@ def test_calculate_no_true(image_pair):
 
 
 def test_calculate_no_pred(image_pair):
-    """Run the metrics on a pair of images where there is no object in the
-    prediction."""
-    y_true, y_pred, IoU = image_pair
+    """Test a pair of images where there is no object in the prediction."""
+    y_true, y_pred, _ = image_pair
     y_pred = np.zeros_like(y_true)
 
     result = umetrics.calculate(y_true, y_pred)
@@ -45,3 +45,15 @@ def test_calculate_no_pred(image_pair):
     assert result.n_true_positives == 0
     assert result.n_false_negatives == 1
     assert result.n_false_positives == 0
+
+
+def test_calculate_grid(image_grid):
+    """Test a multi-instance segmentation."""
+    y_true, y_pred, stats = image_grid
+    result = umetrics.calculate(y_true, y_pred)
+
+    assert result.n_true_labels == stats["n_true"]
+    assert result.n_pred_labels == stats["n_pred"]
+    assert result.n_true_positives == stats["n_pairs"]
+    assert result.n_false_positives == stats["n_missing_true"]
+    assert result.n_false_negatives == stats["n_missing_pred"]
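
The expected counts follow directly from the fixture's bookkeeping: a "missing_true" tile contributes a predicted object with no ground-truth counterpart (a false positive), and a "missing_pred" tile the reverse (a false negative). A worked example with a hypothetical draw of nine tiles:

    image_types = ["pair", "missing_true", "pair", "missing_pred", "pair",
                   "pair", "missing_true", "pair", "missing_pred"]
    n_pairs = image_types.count("pair")                 # 5 true positives
    n_missing_true = image_types.count("missing_true")  # 2 false positives
    n_missing_pred = image_types.count("missing_pred")  # 2 false negatives
    assert n_pairs + n_missing_pred == 7  # labels present in the ground truth
    assert n_pairs + n_missing_true == 7  # labels present in the prediction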
