Skip to content

Commit

Permalink
Semantic Segmentation Bugfix (#815)
Browse files Browse the repository at this point in the history
  • Loading branch information
czaloom authored Nov 5, 2024
1 parent ce8d261 commit 09f804b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
62 changes: 62 additions & 0 deletions lite/tests/semantic_segmentation/test_confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
from valor_lite.semantic_segmentation import (
Bitmask,
DataLoader,
MetricType,
Segmentation,
Expand Down Expand Up @@ -89,3 +91,63 @@ def test_confusion_matrix_segmentations_from_boxes(
assert m in expected_metrics
for m in expected_metrics:
assert m in actual_metrics


def test_confusion_matrix_intermediate_counting():

segmentation = Segmentation(
uid="uid1",
groundtruths=[
Bitmask(
mask=np.array([[False, False], [True, False]]),
label="a",
),
Bitmask(
mask=np.array([[False, False], [False, True]]),
label="b",
),
Bitmask(
mask=np.array([[True, False], [False, False]]),
label="c",
),
Bitmask(
mask=np.array([[False, True], [False, False]]),
label="d",
),
],
predictions=[
Bitmask(
mask=np.array([[False, False], [False, False]]),
label="a",
),
Bitmask(
mask=np.array([[False, False], [False, False]]),
label="b",
),
Bitmask(
mask=np.array([[True, True], [True, True]]),
label="c",
),
Bitmask(
mask=np.array([[False, False], [False, False]]),
label="d",
),
],
)

loader = DataLoader()
loader.add_data([segmentation])

assert len(loader.matrices) == 1
assert (
loader.matrices[0]
== np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 1, 0],
]
)
).all()
4 changes: 2 additions & 2 deletions lite/valor_lite/semantic_segmentation/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def compute_intermediate_confusion_matrices(
predictions.reshape(1, n_pd_labels, -1),
).sum(axis=2)

intersected_groundtruth_counts = intersection_counts.sum(axis=0)
intersected_prediction_counts = intersection_counts.sum(axis=1)
intersected_groundtruth_counts = intersection_counts.sum(axis=1)
intersected_prediction_counts = intersection_counts.sum(axis=0)

confusion_matrix = np.zeros((n_labels + 1, n_labels + 1), dtype=np.int32)
confusion_matrix[0, 0] = background_counts
Expand Down

0 comments on commit 09f804b

Please sign in to comment.