Skip to content

Commit

Permalink
Multiclass-jaccard: fix off-by-one issue when ignore_index = num_clas…
Browse files Browse the repository at this point in the history
…ses + 1 (#1860)

fix off-by-one issue when ignore_index = num_classes + 1
  • Loading branch information
martinmeinke authored Jun 28, 2023
1 parent cb13126 commit 8fef294
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed number of bugs related to `average="macro"` in classification metrics ([#1821](https://github.com/Lightning-AI/torchmetrics/pull/1821))


- Fixed off-by-one issue when `ignore_index = num_classes + 1` in Multiclass-jaccard ([#1860](https://github.com/Lightning-AI/torchmetrics/pull/1860))


## [0.11.4] - 2023-03-10

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _jaccard_index_reduce(
if average == "binary":
return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])

ignore_index_cond = ignore_index is not None and 0 <= ignore_index <= confmat.shape[0]
ignore_index_cond = ignore_index is not None and 0 <= ignore_index < confmat.shape[0]
multilabel = confmat.ndim == 3
if multilabel:
num = confmat[:, 1, 1]
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/classification/test_jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _sklearn_jaccard_index_multiclass(preds, target, ignore_index=None, average=
preds = preds.flatten()
target = target.flatten()
target, preds = remove_ignore_index(target, preds, ignore_index)
if ignore_index is not None and 0 <= ignore_index <= NUM_CLASSES:
if ignore_index is not None and 0 <= ignore_index < NUM_CLASSES:
labels = [i for i in range(NUM_CLASSES) if i != ignore_index]
res = sk_jaccard_index(y_true=target, y_pred=preds, average=average, labels=labels)
return np.insert(res, ignore_index, 0.0) if average is None else res
Expand Down

0 comments on commit 8fef294

Please sign in to comment.