diff --git a/CHANGELOG.md b/CHANGELOG.md index fb1f0b5ce05..05fb07fbd12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `VisualInformationFidelity` to image package ([#1830](https://github.com/Lightning-AI/torchmetrics/pull/1830)) +- Added warning to `PearsonCorrCoeff` if input has a very small variance for its given dtype ([#1926](https://github.com/Lightning-AI/torchmetrics/pull/1926)) + ### Changed - diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index 547b03ccf92..8c8a4896a38 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Tuple import torch from torch import Tensor from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _check_same_shape @@ -100,6 +102,15 @@ def _pearson_corrcoef_compute( var_x = var_x.bfloat16() var_y = var_y.bfloat16() + bound = math.sqrt(torch.finfo(var_x.dtype).eps) + if (var_x < bound).any() or (var_y < bound).any(): + rank_zero_warn( + "The variance of predictions or target is close to zero. This can cause instability in Pearson correlation" + "coefficient, leading to wrong results. Consider re-scaling the input if possible or computing using a" + f"larger dtype (currently using {var_x.dtype}).", + UserWarning, + ) + corrcoef = (corr_xy / (var_x * var_y).sqrt()).squeeze() return torch.clamp(corrcoef, -1.0, 1.0) diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 4decf8f16fb..043ca470d5c 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -140,3 +140,12 @@ def test_final_aggregation_function(shapes): output = _final_aggregation(input_fn(), input_fn(), input_fn(), input_fn(), input_fn(), torch.randint(10, shapes)) assert all(isinstance(out, torch.Tensor) for out in output) assert all(out.ndim == input_fn().ndim - 1 for out in output) + + +@pytest.mark.parametrize(("dtype", "scale"), [(torch.float16, 1e-4), (torch.float32, 1e-8), (torch.float64, 1e-16)]) +def test_pearsons_warning_on_small_input(dtype, scale): + """Check that a user warning is raised for small input.""" + preds = scale * torch.randn(100, dtype=dtype) + target = scale * torch.randn(100, dtype=dtype) + with pytest.warns(UserWarning, match="The variance of predictions or target is close to zero.*"): + pearson_corrcoef(preds, target)