From a33c13a48292d797e58fbf54211b199e38470556 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 11 Jul 2023 12:09:58 +0200 Subject: [PATCH] Fix `Auroc` metric when `max_fpr` is set and a class is missing (#1895) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> (cherry picked from commit 3fed40f450194d39a0ef63ae2b3fe843c9a29634) --- CHANGELOG.md | 3 +++ .../functional/classification/auroc.py | 2 +- tests/unittests/classification/test_auroc.py | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ba3977f207..c5083315607 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed the use of `max_fpr` in `AUROC` metric when only one class is present ([#1895](https://github.com/Lightning-AI/torchmetrics/pull/1895)) + + - Fixed bug related to empty predictions for `IntersectionOverUnion` metric ([#1892](https://github.com/Lightning-AI/torchmetrics/pull/1892)) diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 94ca84877f0..8d0157fe7aa 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -86,7 +86,7 @@ def _binary_auroc_compute( pos_label: int = 1, ) -> Tensor: fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) - if max_fpr is None or max_fpr == 1: + if max_fpr is None or max_fpr == 1 or fpr.sum() == 0 or tpr.sum() == 0: return _auc_compute_without_check(fpr, tpr, 1.0) _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 0fdf7834da0..e51f564b02f 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -393,3 +393,17 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) +def test_corner_case_max_fpr(max_fpr): + """Check that metric returns 0 when one class is missing and `max_fpr` is set.""" + preds = torch.tensor([0.1, 0.2, 0.3, 0.4]) + target = torch.tensor([0, 0, 0, 0]) + metric = BinaryAUROC(max_fpr=max_fpr) + assert metric(preds, target) == 0.0 + + preds = torch.tensor([0.5, 0.6, 0.7, 0.8]) + target = torch.tensor([1, 1, 1, 1]) + metric = BinaryAUROC(max_fpr=max_fpr) + assert metric(preds, target) == 0.0