diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index c4b852c2ad568..73ec649435337 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -277,6 +277,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed redundant input-type casting in FSDP precision ([#18630](https://github.com/Lightning-AI/lightning/pull/18630))
 
+- Fixed numerical issues when reducing values in low precision with `self.log` ([#18686](https://github.com/Lightning-AI/lightning/pull/18686))
+
+
 ## [2.0.9] - 2023-09-14
 
 
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index e0668da873373..dc6467f4d3167 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -194,8 +194,8 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
             default = float("inf")
         else:
             default = 0.0
-        # do not set a dtype in case the default dtype was changed
-        self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
+        # the logged value will be stored in float32 or higher to maintain accuracy
+        self.add_state("value", torch.tensor(default, dtype=_get_default_dtype()), dist_reduce_fx=torch.sum)
         if self.meta.is_mean_reduction:
             self.cumulated_batch_size: Tensor
             self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
@@ -205,14 +205,16 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
     def update(self, value: _VALUE, batch_size: int) -> None:
         if self.is_tensor:
             value = cast(Tensor, value)
+            dtype = _get_default_dtype()
             if not torch.is_floating_point(value):
-                dtype = torch.get_default_dtype()
                 warning_cache.warn(
                     # do not include the value to avoid cache misses
                     f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"
                     f" be floating point. Converting it to {dtype}."
                 )
                 value = value.to(dtype)
+            if value.dtype not in (torch.float32, torch.float64):
+                value = value.to(dtype)
 
             if self.meta.on_step:
                 self._forward_cache = self.meta.sync(value.clone())  # `clone` because `sync` is in-place
@@ -517,3 +519,9 @@ def __str__(self) -> str:
 
     def __repr__(self) -> str:
         return f"{{{self.training}, {super().__repr__()}}}"
+
+
+def _get_default_dtype() -> torch.dtype:
+    """The default dtype for new tensors, but no lower than float32."""
+    dtype = torch.get_default_dtype()
+    return dtype if dtype in (torch.float32, torch.float64) else torch.float32
diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py
index ce529b22a2df6..26488307e134f 100644
--- a/tests/tests_pytorch/core/test_metric_result_integration.py
+++ b/tests/tests_pytorch/core/test_metric_result_integration.py
@@ -479,27 +479,34 @@ def test_metric_result_computed_check():
     assert cache is computed_value
 
 
-@pytest.mark.parametrize("floating_dtype", [torch.float, torch.double])
-def test_metric_result_respects_dtype(floating_dtype):
+@pytest.mark.parametrize(
+    ("default_type", "converted_type"),
+    [
+        (torch.half, torch.float),
+        (torch.float, torch.float),
+        (torch.double, torch.double),
+    ],
+)
+def test_metric_result_respects_dtype(default_type, converted_type):
     from lightning.pytorch.trainer.connectors.logger_connector.result import warning_cache
 
     warning_cache.clear()
 
-    torch.set_default_dtype(floating_dtype)
+    torch.set_default_dtype(default_type)
     fixed_dtype = torch.long  # default by PyTorch
 
     metadata = _Metadata("foo", "bar")
     metadata.sync = _Sync()
     rm = _ResultMetric(metadata, is_tensor=True)
 
-    assert rm.value.dtype == floating_dtype
+    assert rm.value.dtype == converted_type
     assert rm.cumulated_batch_size.dtype == fixed_dtype
 
     # two fixed point numbers - should be converted
     value, batch_size = tensor(2), 3
     assert value.dtype == fixed_dtype
     with pytest.warns(
-        UserWarning, match=rf"`self.log\('bar', ...\)` in your `foo` .* Converting it to {floating_dtype}"
+        UserWarning, match=rf"`self.log\('bar', ...\)` in your `foo` .* Converting it to {converted_type}"
     ):
         rm.update(value, batch_size)
 
     # floating and fixed
@@ -508,7 +515,7 @@ def test_metric_result_respects_dtype(floating_dtype):
 
     total = rm.compute()
     assert total == (2 * 3 + 4 * 5) / (5 + 3)
-    assert total.dtype == floating_dtype
+    assert total.dtype == converted_type
 
     # restore to avoid impacting other tests
     torch.set_default_dtype(torch.float)
@@ -534,6 +541,25 @@ def test_metric_result_dtype_promotion(reduce_fx):
     assert total.dtype == torch.double
 
 
+@pytest.mark.parametrize("input_dtype", [torch.int8, torch.float16, torch.bfloat16])
+def test_metric_result_precision_no_lower_than_float32(input_dtype):
+    """Test that the ResultMetric only stores values in float32 or higher precision for numerical stability."""
+    metadata = _Metadata("foo", "bar", reduce_fx="sum")
+    metadata.sync = _Sync()
+    metric = _ResultMetric(metadata, is_tensor=True)
+    assert metric.value.dtype == torch.float
+
+    # in bfloat16, truncation would occur at 256 (8 bit exponent)
+    # in int8, overflow would occur at 128
+    for i in range(1000):
+        metric.update(tensor(1.0, dtype=input_dtype), 1)
+    assert metric.value.dtype == torch.float32
+
+    total = metric.compute()
+    assert total.item() == 1000.0
+    assert total.dtype == torch.float32
+
+
 @pytest.mark.parametrize(("reduce_fx", "expected"), [(max, -2), (min, 2)])
 def test_result_metric_max_min(reduce_fx, expected):
     metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
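
For context (not part of the patch), here is a minimal standalone sketch of the low-precision accumulation problem the new test guards against: repeatedly adding 1.0 in a bfloat16 accumulator stalls at 256, because with only 8 significand bits the value 257 is not representable and 256 + 1 rounds back to 256, while a float32 accumulator keeps integer sums exact far past 1000. The `accumulate` helper below is hypothetical and for illustration only; it is not part of the Lightning code paths touched by this diff.

import torch


def accumulate(dtype: torch.dtype, steps: int = 1000) -> torch.Tensor:
    # Hypothetical helper: repeatedly add 1.0 using an accumulator of the given dtype.
    total = torch.tensor(0.0, dtype=dtype)
    for _ in range(steps):
        total = total + torch.tensor(1.0, dtype=dtype)
    return total


# bfloat16 stalls once the running sum reaches 256: the increment of 1.0 is
# smaller than the spacing between adjacent representable values there.
print(accumulate(torch.bfloat16))  # tensor(256., dtype=torch.bfloat16)
# float32 represents integers exactly up to 2**24, so the sum stays exact.
print(accumulate(torch.float32))   # tensor(1000.)

Storing the `_ResultMetric` state in float32 (or float64) while still casting the user-facing value back through the default dtype is what keeps `self.log(..., reduce_fx="sum")` and mean reductions accurate when training runs under half or bfloat16 defaults.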