diff --git a/CHANGELOG.md b/CHANGELOG.md index 361f564dbb7..b449c16417c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Deprecated `metric._update_called` ([#2141](https://github.com/Lightning-AI/torchmetrics/pull/2141)) +- Changed x-/y-axis order for `PrecisionRecallCurve` to be consistent with scikit-learn ([#2183](https://github.com/Lightning-AI/torchmetrics/pull/2183)) + + ### Removed - diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index c9e42d41b71..66265e5a6bf 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -210,13 +210,16 @@ def plot( """ curve_computed = curve or self.compute() + # switch order as the standard way is recall along x-axis and precision along y-axis + curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2]) + score = ( _auc_compute_without_check(curve_computed[0], curve_computed[1], 1.0) if not curve and score is True else None ) return plot_curve( - curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__ + curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__ ) @@ -408,11 +411,13 @@ def plot( """ curve_computed = curve or self.compute() + # switch order as the standard way is recall along x-axis and precision along y-axis + curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2]) score = ( _reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None ) return plot_curve( - curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__ + curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__ ) @@ -598,11 +603,13 @@ def plot( """ curve_computed = curve or self.compute() + # switch order as the standard way is recall along x-axis and precision along y-axis + curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2]) score = ( _reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None ) return plot_curve( - curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__ + curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__ )