From a813be852905fd9a15a70a2921b29ae2e21d76a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 9 Nov 2021 15:33:48 +0100 Subject: [PATCH] Fix converting only float type tensors in Lite (#10429) * fix * less code * add test case * add test cases * update input * add test cases * add type hint * add changelog note Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> --- CHANGELOG.md | 1 + pytorch_lightning/lite/wrappers.py | 9 +++++++-- tests/lite/test_wrappers.py | 11 ++++++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 28e609813e0f2..22615f329dbfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an import error being caused by `PostLocalSGD` when `torch.distributed` not available ([#10359](https://github.com/PyTorchLightning/pytorch-lightning/pull/10359)) - Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409)) - Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428)) +- Fixed an issue where the model wrapper in Lite converted non-floating point tensors to float ([#10429](https://github.com/PyTorchLightning/pytorch-lightning/pull/10429)) ## [1.5.0] - 2021-11-02 diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 881a663fdb9e5..615f461055204 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -95,12 +95,17 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: } # TODO (@awaelchli): let the precision plugin handle the conversion to_type = precision_to_type[precision] - args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor) + + def _convert_float_tensor(t: Tensor) -> Tensor: + return t.to(to_type) if torch.is_floating_point(t) else t + + args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor) with self._precision_plugin.forward_context(): output = self.module(*args, **kwargs) - output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor) + to_type = torch.get_default_dtype() + output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor) return output diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 6741bf59b4dca..4993a10c8dbc2 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -40,8 +40,13 @@ def test_lite_module_wraps(): (32, torch.float16, torch.float32), (32, torch.float32, torch.float32), (32, torch.float64, torch.float32), + (32, torch.int, torch.int), (16, torch.float32, torch.float16), (16, torch.float64, torch.float16), + (16, torch.long, torch.long), + pytest.param("bf16", torch.float32, torch.bfloat16, marks=RunIf(min_torch="1.10")), + pytest.param("bf16", torch.float64, torch.bfloat16, marks=RunIf(min_torch="1.10")), + pytest.param("bf16", torch.bool, torch.bool, marks=RunIf(min_torch="1.10")), ], ) def test_lite_module_forward_conversion(precision, input_type, expected_type): @@ -53,11 +58,11 @@ def check_autocast(forward_input): assert precision != 16 or torch.is_autocast_enabled() return forward_input - module = Mock(wraps=torch.nn.Linear(1, 1), side_effect=check_autocast) + module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast) lite_module = _LiteModule(module, lite._precision_plugin).to(device) - out = lite_module(torch.rand(1, dtype=input_type, device=device)) + out = lite_module(torch.tensor([1, 2, 3], dtype=input_type, device=device)) assert module.call_args[0][0].dtype == expected_type - assert out.dtype == torch.get_default_dtype() + assert out.dtype == input_type or out.dtype == torch.get_default_dtype() def test_lite_dataloader_iterator():