From 8fef294a645105f192ee6348514d520218cb2cb1 Mon Sep 17 00:00:00 2001 From: martinmeinke Date: Wed, 28 Jun 2023 12:03:28 -0400 Subject: [PATCH] Multiclass-jaccard: fix off-by-one issue when ignore_index = num_classes + 1 (#1860) fix off-by-one issue when ignore_index = num_classes + 1 --- CHANGELOG.md | 3 +++ src/torchmetrics/functional/classification/jaccard.py | 2 +- tests/unittests/classification/test_jaccard.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e355c900ad..997938ee05b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -221,6 +221,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed number of bugs related to `average="macro"` in classification metrics ([#1821](https://github.com/Lightning-AI/torchmetrics/pull/1821)) +- Fixed off-by-one issue when `ignore_index = num_classes + 1` in Multiclass-jaccard ([#1860](https://github.com/Lightning-AI/torchmetrics/pull/1860)) + + ## [0.11.4] - 2023-03-10 ### Fixed diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 59c946ffe27..560e28996ca 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -65,7 +65,7 @@ def _jaccard_index_reduce( if average == "binary": return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]) - ignore_index_cond = ignore_index is not None and 0 <= ignore_index <= confmat.shape[0] + ignore_index_cond = ignore_index is not None and 0 <= ignore_index < confmat.shape[0] multilabel = confmat.ndim == 3 if multilabel: num = confmat[:, 1, 1] diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 5a2ccc1fa43..c6133acb435 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -131,7 +131,7 @@ def _sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average= preds = preds.flatten() target = target.flatten() target, preds = remove_ignore_index(target, preds, ignore_index) - if ignore_index is not None and 0 <= ignore_index <= NUM_CLASSES: + if ignore_index is not None and 0 <= ignore_index < NUM_CLASSES: labels = [i for i in range(NUM_CLASSES) if i != ignore_index] res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels) return np.insert(res, ignore_index, 0.0) if average is None else res