From beae9289ace323e175d32d2cf0fd683d6c7f8062 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 18 Nov 2023 06:45:46 +0100 Subject: [PATCH] Switch plotting order for pr curves (#2183) --- CHANGELOG.md | 3 +++ .../classification/precision_recall_curve.py | 13 ++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 361f564dbb7..b449c16417c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Deprecated `metric._update_called` ([#2141](https://github.com/Lightning-AI/torchmetrics/pull/2141)) +- Changed x-/y-axis order for `PrecisionRecallCurve` to be consistent with scikit-learn ([#2183](https://github.com/Lightning-AI/torchmetrics/pull/2183)) + + ### Removed - diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index c9e42d41b71..66265e5a6bf 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -210,13 +210,16 @@ def plot( """ curve_computed = curve or self.compute() + # switch order as the standard way is recall along x-axis and precision along y-axis + curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2]) + score = ( _auc_compute_without_check(curve_computed[0], curve_computed[1], 1.0) if not curve and score is True else None ) return plot_curve( - curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__ + curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__ ) @@ -408,11 +411,13 @@ def plot( """ curve_computed = curve or self.compute() + # switch order as the standard way is recall along x-axis and precision along y-axis + curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2]) score = ( _reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None ) return plot_curve( - curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__ + curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__ ) @@ -598,11 +603,13 @@ def plot( """ curve_computed = curve or self.compute() + # switch order as the standard way is recall along x-axis and precision along y-axis + curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2]) score = ( _reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None ) return plot_curve( - curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__ + curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__ )