Maintain float32 precision at minimum in ResultMetric (#18686)
awaelchli authored Oct 3, 2023
1 parent 9d9220c commit b69f3c6
Showing 3 changed files with 46 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -277,6 +277,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed redundant input-type casting in FSDP precision ([#18630](https://github.com/Lightning-AI/lightning/pull/18630))


- Fixed numerical issues when reducing values in low precision with `self.log` ([#18686](https://github.com/Lightning-AI/lightning/pull/18686))



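The entry above refers to reductions drifting when the running value is kept in half precision. A minimal sketch of the underlying effect, in plain PyTorch and independent of Lightning:

```python
import torch

# Accumulate 1.0 four thousand times in float16 vs. float32.
# float16 has a 10-bit mantissa, so once the running total reaches 2048
# adding 1.0 no longer changes it; float32 keeps counting correctly.
total_fp16 = torch.tensor(0.0, dtype=torch.float16)
total_fp32 = torch.tensor(0.0, dtype=torch.float32)
for _ in range(4000):
    total_fp16 += 1.0
    total_fp32 += 1.0

print(total_fp16.item())  # 2048.0 -- saturated
print(total_fp32.item())  # 4000.0
```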
## [2.0.9] - 2023-09-14

14 changes: 11 additions & 3 deletions src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -194,8 +194,8 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
default = float("inf")
else:
default = 0.0
# do not set a dtype in case the default dtype was changed
self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
# the logged value will be stored in float32 or higher to maintain accuracy
self.add_state("value", torch.tensor(default, dtype=_get_default_dtype()), dist_reduce_fx=torch.sum)
if self.meta.is_mean_reduction:
self.cumulated_batch_size: Tensor
self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
@@ -205,14 +205,16 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
def update(self, value: _VALUE, batch_size: int) -> None:
if self.is_tensor:
value = cast(Tensor, value)
dtype = _get_default_dtype()
if not torch.is_floating_point(value):
dtype = torch.get_default_dtype()
warning_cache.warn(
# do not include the value to avoid cache misses
f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"
f" be floating point. Converting it to {dtype}."
)
value = value.to(dtype)
if value.dtype not in (torch.float32, torch.float64):
value = value.to(dtype)

if self.meta.on_step:
self._forward_cache = self.meta.sync(value.clone()) # `clone` because `sync` is in-place
@@ -517,3 +519,9 @@ def __str__(self) -> str:

def __repr__(self) -> str:
return f"{{{self.training}, {super().__repr__()}}}"


def _get_default_dtype() -> torch.dtype:
"""The default dtype for new tensors, but no lower than float32."""
dtype = torch.get_default_dtype()
return dtype if dtype in (torch.float32, torch.float64) else torch.float32
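Taken together, the hunks above ensure the running value is created in at least float32 and that integer or sub-float32 inputs are upcast before accumulation. A standalone sketch of that conversion rule, assuming plain PyTorch and mirroring (not importing) the Lightning logic:

```python
import torch


def _to_accumulation_dtype(value: torch.Tensor) -> torch.Tensor:
    """Mirror of the update() logic above: accumulate in float32/float64 only."""
    default = torch.get_default_dtype()
    dtype = default if default in (torch.float32, torch.float64) else torch.float32
    if not torch.is_floating_point(value):
        return value.to(dtype)  # integer inputs are converted (Lightning also warns here)
    if value.dtype not in (torch.float32, torch.float64):
        return value.to(dtype)  # half/bfloat16 inputs are upcast for the running sum
    return value  # float32/float64 inputs are kept as-is


print(_to_accumulation_dtype(torch.tensor(3)).dtype)                          # torch.float32
print(_to_accumulation_dtype(torch.tensor(1.0, dtype=torch.bfloat16)).dtype)  # torch.float32
print(_to_accumulation_dtype(torch.tensor(1.0, dtype=torch.float64)).dtype)   # torch.float64
```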
38 changes: 32 additions & 6 deletions tests/tests_pytorch/core/test_metric_result_integration.py
@@ -479,27 +479,34 @@ def test_metric_result_computed_check():
assert cache is computed_value


@pytest.mark.parametrize("floating_dtype", [torch.float, torch.double])
def test_metric_result_respects_dtype(floating_dtype):
@pytest.mark.parametrize(
("default_type", "converted_type"),
[
(torch.half, torch.float),
(torch.float, torch.float),
(torch.double, torch.double),
],
)
def test_metric_result_respects_dtype(default_type, converted_type):
from lightning.pytorch.trainer.connectors.logger_connector.result import warning_cache

warning_cache.clear()

torch.set_default_dtype(floating_dtype)
torch.set_default_dtype(default_type)
fixed_dtype = torch.long # default by PyTorch

metadata = _Metadata("foo", "bar")
metadata.sync = _Sync()
rm = _ResultMetric(metadata, is_tensor=True)

assert rm.value.dtype == floating_dtype
assert rm.value.dtype == converted_type
assert rm.cumulated_batch_size.dtype == fixed_dtype

# two fixed point numbers - should be converted
value, batch_size = tensor(2), 3
assert value.dtype == fixed_dtype
with pytest.warns(
UserWarning, match=rf"`self.log\('bar', ...\)` in your `foo` .* Converting it to {floating_dtype}"
UserWarning, match=rf"`self.log\('bar', ...\)` in your `foo` .* Converting it to {converted_type}"
):
rm.update(value, batch_size)
# floating and fixed
@@ -508,7 +515,7 @@ def test_metric_result_respects_dtype(floating_dtype):
total = rm.compute()

assert total == (2 * 3 + 4 * 5) / (5 + 3)
assert total.dtype == floating_dtype
assert total.dtype == converted_type

# restore to avoid impacting other tests
torch.set_default_dtype(torch.float)
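The manual restore above is needed because `torch.set_default_dtype` mutates process-global state. A hypothetical pytest fixture (not part of this change) could make the cleanup automatic:

```python
import pytest
import torch


@pytest.fixture
def restore_default_dtype():
    """Hypothetical helper: put the global default dtype back after a test."""
    previous = torch.get_default_dtype()
    yield
    torch.set_default_dtype(previous)
```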
@@ -534,6 +541,25 @@ def test_metric_result_dtype_promotion(reduce_fx):
assert total.dtype == torch.double


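The dtype-promotion test above relies on PyTorch's standard promotion rules: once a double value enters the reduction, the accumulated result is promoted to double. A quick illustration in plain PyTorch:

```python
import torch

running = torch.tensor(0.0)                     # float32
update = torch.tensor(1.0, dtype=torch.double)  # float64
print((running + update).dtype)  # torch.float64 -- float32 is promoted when combined with double
```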
@pytest.mark.parametrize("input_dtype", [torch.int8, torch.float16, torch.bfloat16])
def test_metric_result_precision_no_lower_than_float32(input_dtype):
"""Test that the ResultMetric only stores values in float32 or higher precision for numerical stability."""
metadata = _Metadata("foo", "bar", reduce_fx="sum")
metadata.sync = _Sync()
metric = _ResultMetric(metadata, is_tensor=True)
assert metric.value.dtype == torch.float

# in bfloat16, truncation would occur at 256 (8 bit exponent)
# in int8, overflow would occur at 128
for i in range(1000):
metric.update(tensor(1.0, dtype=input_dtype), 1)
assert metric.value.dtype == torch.float32

total = metric.compute()
assert total.item() == 1000.0
assert total.dtype == torch.float32


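The inline comment in the new test can be verified directly: bfloat16 keeps only a 7-bit mantissa, so a running sum of ones stalls at 256, and int8 tops out at 127. A quick check in plain PyTorch:

```python
import torch

total = torch.tensor(0.0, dtype=torch.bfloat16)
for _ in range(1000):
    total += 1.0
print(total.item())  # 256.0 -- adding 1.0 past 256 no longer changes a bfloat16 sum

print(torch.iinfo(torch.int8).max)  # 127 -- the largest value an int8 accumulator could hold
```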
@pytest.mark.parametrize(("reduce_fx", "expected"), [(max, -2), (min, 2)])
def test_result_metric_max_min(reduce_fx, expected):
metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
