From 6e7fed8af1f7e5e3ffe8b9d1df43abfcaa6a3a79 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 26 Aug 2023 19:42:54 +0200 Subject: [PATCH 1/6] classification --- src/torchmetrics/classification/accuracy.py | 12 ++++++------ src/torchmetrics/classification/auroc.py | 6 +++--- src/torchmetrics/classification/average_precision.py | 6 +++--- src/torchmetrics/classification/exact_match.py | 8 ++++---- src/torchmetrics/classification/group_fairness.py | 8 ++++---- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 8e7e6eb3455..8deb4fe5544 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -92,8 +92,8 @@ class BinaryAccuracy(BinaryStatScores): tensor([0.3333, 0.1667]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -240,8 +240,8 @@ class MulticlassAccuracy(MulticlassStatScores): [0.0000, 0.3333, 0.5000]]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -389,8 +389,8 @@ class MultilabelAccuracy(MultilabelStatScores): [0.0000, 0.0000, 0.5000]]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index f42c4b942f9..fb3465f4010 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -99,7 +99,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -241,7 +241,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -390,7 +390,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 8ade1884d73..2d29c30e05b 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -106,7 +106,7 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -239,7 +239,7 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -393,7 +393,7 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): """ is_differentiable: bool = False - higher_is_better: Optional[bool] = None + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 075c3430524..3d9bf724ccc 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -94,8 +94,8 @@ class MulticlassExactMatch(Metric): tensor([1., 0.]) """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -257,8 +257,8 @@ class MultilabelExactMatch(Metric): """ - is_differentiable = False - higher_is_better = True + is_differentiable: bool = False + higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 33ea8950fb7..9e46dd36891 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -101,8 +101,8 @@ class BinaryGroupStatRates(_AbstractGroupStatScores): {'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])} """ - is_differentiable = False - higher_is_better = False + is_differentiable: bool = False + higher_is_better: bool = False full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -207,8 +207,8 @@ class BinaryFairness(_AbstractGroupStatScores): {'DP_0_1': tensor(0.), 'EO_0_1': tensor(0.)} """ - is_differentiable = False - higher_is_better = False + is_differentiable: bool = False + higher_is_better: bool = False full_state_update: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 From 7025a9caa5e0ecafde3424aba871a083b8899d5b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 26 Aug 2023 19:43:29 +0200 Subject: [PATCH 2/6] add remaining --- src/torchmetrics/clustering/mutual_info_score.py | 4 ++-- src/torchmetrics/detection/ciou.py | 4 ++++ src/torchmetrics/detection/diou.py | 4 ++++ src/torchmetrics/detection/giou.py | 4 ++++ src/torchmetrics/image/perceptual_path_length.py | 3 +++ src/torchmetrics/regression/concordance.py | 4 ++++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 86118daf41c..504d12c2718 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -62,8 +62,8 @@ class MutualInfoScore(Metric): """ - is_differentiable = True - higher_is_better = None + is_differentiable: bool = True + higher_is_better: bool = True full_state_update: bool = True plot_lower_bound: float = 0.0 preds: List[Tensor] diff --git a/src/torchmetrics/detection/ciou.py b/src/torchmetrics/detection/ciou.py index 0adc57af4f6..5b62679a396 100644 --- a/src/torchmetrics/detection/ciou.py +++ b/src/torchmetrics/detection/ciou.py @@ -93,6 +93,10 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion): If torchvision is not installed with version 0.13.0 or newer. """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True + _iou_type: str = "ciou" _invalid_val: float = -2.0 # unsure, min val could be just -1.5 as well diff --git a/src/torchmetrics/detection/diou.py b/src/torchmetrics/detection/diou.py index 3508d80fee1..6778979b1c0 100644 --- a/src/torchmetrics/detection/diou.py +++ b/src/torchmetrics/detection/diou.py @@ -93,6 +93,10 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion): If torchvision is not installed with version 0.13.0 or newer. """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True + _iou_type: str = "diou" _invalid_val: float = -1.0 diff --git a/src/torchmetrics/detection/giou.py b/src/torchmetrics/detection/giou.py index d53d3e88777..e4ec9aee65c 100644 --- a/src/torchmetrics/detection/giou.py +++ b/src/torchmetrics/detection/giou.py @@ -93,6 +93,10 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion): If torchvision is not installed with version 0.8.0 or newer. """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True + _iou_type: str = "giou" _invalid_val: float = -1.0 diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index 829a5249bef..bb7561f7c4b 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -118,6 +118,9 @@ class PerceptualPathLength(Metric): tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921])) """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = True + full_state_update: bool = True def __init__( self, diff --git a/src/torchmetrics/regression/concordance.py b/src/torchmetrics/regression/concordance.py index d45a06cb943..be9e2c1989c 100644 --- a/src/torchmetrics/regression/concordance.py +++ b/src/torchmetrics/regression/concordance.py @@ -67,6 +67,10 @@ class ConcordanceCorrCoef(PearsonCorrCoef): tensor([0.7273, 0.9887]) """ + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 From 0497e53f506ea113101d160f81726ba905264c75 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 26 Aug 2023 19:46:28 +0200 Subject: [PATCH 3/6] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b23ce58d355..ef9681f1e79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017) +- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028) + ## [1.1.0] - 2023-08-22 ### Added From 28fec30218535e3a69957624850ff53c9c3c6a4b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 28 Aug 2023 09:21:06 +0200 Subject: [PATCH 4/6] fix --- src/torchmetrics/regression/concordance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/regression/concordance.py b/src/torchmetrics/regression/concordance.py index be9e2c1989c..8276c922b1a 100644 --- a/src/torchmetrics/regression/concordance.py +++ b/src/torchmetrics/regression/concordance.py @@ -69,7 +69,7 @@ class ConcordanceCorrCoef(PearsonCorrCoef): """ is_differentiable: bool = False higher_is_better: bool = True - full_state_update: bool = False + full_state_update: bool = True plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 From b0d9c42488f113fee954dd444db8171eabf29c6d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 28 Aug 2023 09:21:58 +0200 Subject: [PATCH 5/6] typing --- src/torchmetrics/regression/pearson.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 0fc575f7449..a25fa72ff7e 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -110,8 +110,8 @@ class PearsonCorrCoef(Metric): tensor([1., 1.]) """ - is_differentiable = True - higher_is_better = None # both -1 and 1 are optimal + is_differentiable: bool = True + higher_is_better: Optional[bool] = None # both -1 and 1 are optimal full_state_update: bool = True plot_lower_bound: float = -1.0 plot_upper_bound: float = 1.0 From b21f75da90cca54eadf54d1833d979d0091577fe Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 28 Aug 2023 11:13:52 +0200 Subject: [PATCH 6/6] fix wrongly set attr --- src/torchmetrics/regression/concordance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/regression/concordance.py b/src/torchmetrics/regression/concordance.py index 8276c922b1a..2a52e8a8036 100644 --- a/src/torchmetrics/regression/concordance.py +++ b/src/torchmetrics/regression/concordance.py @@ -67,7 +67,7 @@ class ConcordanceCorrCoef(PearsonCorrCoef): tensor([0.7273, 0.9887]) """ - is_differentiable: bool = False + is_differentiable: bool = True higher_is_better: bool = True full_state_update: bool = True