diff --git a/src/umetrics/core.py b/src/umetrics/core.py
index 64f47c7..26ed39c 100644
--- a/src/umetrics/core.py
+++ b/src/umetrics/core.py
@@ -63,6 +63,11 @@ def find_matches(
         The reference (ground truth) segmentation.
     pred :
         The predicted segmentation.
+    strict : bool
+        Whether to use strict matching, i.e. only allowing matches above a
+        threshold IoU value.
+    iou_threshold : float
+        The IoU threshold to use when strict matching is enabled.
 
     Return
     ------
@@ -71,17 +76,9 @@
 
     """
 
-    # return a default dictionary of no matches
-    matches = {
-        "true_matches": [],
-        "true_matches_IoU": [],
-        "in_ref_only": set(ref.labels),
-        "in_pred_only": set(pred.labels),
-    }
-
     # make an infinite cost matrix, so that we only consider matches where
     # there is some overlap in the masks
-    cost_matrix = np.full((len(ref.labels), len(pred.labels)), np.inf)
+    cost_matrix = np.full((len(ref.labels), len(pred.labels)), 1e8)
 
     for r_id, ref_label in enumerate(ref.labels):
         mask = ref.labeled == ref_label
@@ -95,13 +92,27 @@
 
     # if it's strict, make sure every element is above the threshold
     if strict:
-        assert np.all(cost_matrix >= iou_threshold)
-
-    try:
-        sol_row, sol_col = linear_sum_assignment(cost_matrix)
-    except ValueError:
+        cost_threshold = 1.0 - iou_threshold
+        assert np.all(cost_matrix >= cost_threshold), cost_matrix
+
+    # solve the linear assignment problem
+    sol_row, sol_col = linear_sum_assignment(cost_matrix)
+
+    # remove infeasible solutions
+    edges = [(r, c) for r, c in zip(sol_row, sol_col) if cost_matrix[r, c] <= 1]
+
+    # return a default dictionary if there are no matches
+    if not edges:
+        matches = {
+            "true_matches": [],
+            "true_matches_IoU": [],
+            "in_ref_only": set(ref.labels),
+            "in_pred_only": set(pred.labels),
+        }
         return matches
 
+    sol_row, sol_col = zip(*edges)
+
     # now that we've solved the LAP, find the matches that have been made
     used_ref = [ref.labels[row] for row in sol_row]
     used_pred = [pred.labels[col] for col in sol_col]
@@ -353,13 +364,14 @@ def n_false_positives(self):
     @property
     def per_object_IoU(self):
         """Intersection over Union (IoU) metric"""
-        iou = []
-        for m in self.true_positives:
-            mask_ref = self._reference.labeled == m[0]
-            mask_pred = self._predicted.labeled == m[1]
+        # iou = []
+        # for m in self.true_positives:
+        #     mask_ref = self._reference.labeled == m[0]
+        #     mask_pred = self._predicted.labeled == m[1]
 
-            iou.append(_IoU(mask_ref, mask_pred))
-        return iou
+        #     iou.append(_IoU(mask_ref, mask_pred))
+        # return iou
+        return self._matches["true_matches_IoU"]
 
     @property
     def per_image_pixel_identity(self):
@@ -438,7 +450,6 @@ def batch(files, **kwargs):
     """batch process a list of files"""
     metrix = []
     for f_ref, f_pred in files:
-        print(f_pred)
         true = imread(f_ref)
         pred = imread(f_pred)
         result = calculate(true, pred, **kwargs).results
diff --git a/tests/conftest.py b/tests/conftest.py
index 9bba98c..fc49f0d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,11 @@
 import pytest
 import numpy as np
 import numpy.typing as npt
+
+from skimage.util import montage
 from typing import Tuple
 
-SEED = 12345
+SEED = 12347
 RNG = np.random.default_rng(seed=SEED)
 
 
@@ -22,7 +24,51 @@ def _IoU(y_true: npt.NDArray, y_pred: npt.NDArray) -> float:
 
 
 @pytest.fixture
-def image_pair() -> Tuple[npt.NDArray, npt.NDArray, float]:
+def image_grid(N: int = 3, sz: int = 32) -> Tuple[npt.NDArray, npt.NDArray, dict]:
+    image_types = RNG.choice(
+        ["pair", "missing_true", "missing_pred"], size=(N * N,)
+    ).tolist()
+    true_stack = np.zeros((N * N, sz, sz), dtype=np.uint8)
+    pred_stack = np.zeros((N * N, sz, sz), dtype=np.uint8)
+
+    ious = []
+
+    for idx, img_type in enumerate(image_types):
+        if img_type == "pair":
+            true_stack[idx, ...] = _synthetic_image()
+            pred_stack[idx, ...] = _synthetic_image()
+            ious.append(_IoU(true_stack[idx, ...], pred_stack[idx, ...]))
+        elif img_type == "missing_true":
+            pred_stack[idx, ...] = _synthetic_image()
+            ious.append(0.0)
+        else:
+            true_stack[idx, ...] = _synthetic_image()
+            ious.append(0.0)
+
+    n_pairs = image_types.count("pair")
+    n_missing_pred = image_types.count("missing_pred")
+    n_missing_true = image_types.count("missing_true")
+
+    stats = {
+        "n_pairs": n_pairs,
+        "n_true": n_pairs + n_missing_pred,
+        "n_pred": n_pairs + n_missing_true,
+        "n_missing_pred": n_missing_pred,
+        "n_missing_true": n_missing_true,
+        "n_total": len(image_types),
+        "IoU": ious,
+    }
+
+    return (
+        montage(true_stack, rescale_intensity=False, grid_shape=(N, N)),
+        montage(pred_stack, rescale_intensity=False, grid_shape=(N, N)),
+        stats,
+    )
+
+
+@pytest.fixture
+def image_pair() -> Tuple[npt.NDArray, npt.NDArray, dict]:
     y_true = _synthetic_image()
     y_pred = _synthetic_image()
-    return y_true, y_pred, _IoU(y_true, y_pred)
+    stats = {"IoU": _IoU(y_true, y_pred)}
+    return y_true, y_pred, stats
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 97b4d52..b3ee3ce 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -7,7 +7,8 @@
 @pytest.mark.parametrize("strict", (False, True))
 def test_calculate(image_pair, strict):
     """Run the metrics on a pair of images."""
-    y_true, y_pred, IoU = image_pair
+    y_true, y_pred, stats = image_pair
+    IoU = stats["IoU"]
 
     result = umetrics.calculate(y_true, y_pred, strict=strict)
 
@@ -21,8 +22,8 @@
 
 
 def test_calculate_no_true(image_pair):
-    """Run the metrics on a pair of images where there is no object in the GT."""
-    y_true, y_pred, IoU = image_pair
+    """Test a pair of images where there is no object in the GT."""
+    y_true, y_pred, _ = image_pair
     y_true = np.zeros_like(y_pred)
 
     result = umetrics.calculate(y_true, y_pred)
@@ -34,9 +35,8 @@
 
 
 def test_calculate_no_pred(image_pair):
-    """Run the metrics on a pair of images where there is no object in the
-    prediction."""
-    y_true, y_pred, IoU = image_pair
+    """Test a pair of images where there is no object in the prediction."""
+    y_true, y_pred, _ = image_pair
     y_pred = np.zeros_like(y_true)
 
     result = umetrics.calculate(y_true, y_pred)
@@ -45,3 +45,15 @@
     assert result.n_true_positives == 0
     assert result.n_false_negatives == 1
     assert result.n_false_positives == 0
+
+
+def test_calculate_grid(image_grid):
+    """Test a multi-instance segmentation."""
+    y_true, y_pred, stats = image_grid
+    result = umetrics.calculate(y_true, y_pred)
+
+    assert result.n_true_labels == stats["n_true"]
+    assert result.n_pred_labels == stats["n_pred"]
+    assert result.n_true_positives == stats["n_pairs"]
+    assert result.n_false_positives == stats["n_missing_true"]
+    assert result.n_false_negatives == stats["n_missing_pred"]
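
Note on the `find_matches` change: `scipy.optimize.linear_sum_assignment` raises a `ValueError` on a cost matrix with `np.inf` entries whenever no all-finite assignment exists, which is what the removed `try`/`except` was guarding against. Using a large finite sentinel (`1e8`) keeps the solver feasible, and sentinel-cost assignments are filtered out afterwards (`cost_matrix[r, c] <= 1`). A minimal standalone sketch of the same scheme, assuming labelled mask arrays as input; the helper name `match_labels` is illustrative, not the umetrics API:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

INF_COST = 1e8  # finite sentinel; np.inf can make the solver raise ValueError


def match_labels(ref_labeled, pred_labeled):
    """Hypothetical helper: pair reference and predicted labels by IoU."""
    ref_ids = [int(i) for i in np.unique(ref_labeled) if i != 0]
    pred_ids = [int(j) for j in np.unique(pred_labeled) if j != 0]

    # cost is (1 - IoU) for overlapping pairs, the sentinel otherwise
    cost = np.full((len(ref_ids), len(pred_ids)), INF_COST)
    for r, i in enumerate(ref_ids):
        mask_i = ref_labeled == i
        for c, j in enumerate(pred_ids):
            mask_j = pred_labeled == j
            intersection = np.logical_and(mask_i, mask_j).sum()
            if intersection > 0:
                union = np.logical_or(mask_i, mask_j).sum()
                cost[r, c] = 1.0 - intersection / union

    # minimise the total cost, then drop any edge that fell back to the
    # sentinel, i.e. an assignment between masks that never overlapped
    rows, cols = linear_sum_assignment(cost)
    return [
        (ref_ids[r], pred_ids[c])
        for r, c in zip(rows, cols)
        if cost[r, c] <= 1.0
    ]
```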
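Note on the `image_grid` fixture: it tiles per-object test images into one multi-instance segmentation with `skimage.util.montage`. A small sketch of the tiling step under the same assumptions as the fixture (a stack of `N * N` single-object `sz`-pixel masks, here with a hypothetical square object); note that `grid_shape` counts tiles per row and column, hence `(N, N)` rather than the tile size in pixels:

```python
import numpy as np
from skimage.util import montage

N, sz = 3, 32  # same defaults as the fixture

# a stack of N * N tiles, each containing one square "object"
stack = np.zeros((N * N, sz, sz), dtype=np.uint8)
stack[:, 8:24, 8:24] = 1

# an N x N layout of sz-pixel tiles yields an (N * sz) x (N * sz) image
tiled = montage(stack, rescale_intensity=False, grid_shape=(N, N))
assert tiled.shape == (N * sz, N * sz)
```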