Skip to content
This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Commit

Permalink
Fix converting only float type tensors in Lite (Lightning-AI#10429)
Browse files Browse the repository at this point in the history
* fix

* less code

* add test case

* add test cases

* update input

* add test cases

* add type hint

* add changelog note

Co-authored-by: Kaushik B <[email protected]>
  • Loading branch information
2 people authored and Raalsky committed Nov 23, 2021
1 parent f6ea60f commit 2fcb55d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409))


- - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
- Fixed an issue where the model wrapper in Lite converted non-floating point tensors to float ([#10429](https://github.com/PyTorchLightning/pytorch-lightning/pull/10429))


-
- - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))


-
Expand Down
9 changes: 7 additions & 2 deletions pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,17 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
}
# TODO (@awaelchli): let the precision plugin handle the conversion
to_type = precision_to_type[precision]
args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor)

def _convert_float_tensor(t: Tensor) -> Tensor:
return t.to(to_type) if torch.is_floating_point(t) else t

args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor)

with self._precision_plugin.forward_context():
output = self.module(*args, **kwargs)

output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor)
to_type = torch.get_default_dtype()
output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor)
return output


Expand Down
11 changes: 8 additions & 3 deletions tests/lite/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,13 @@ def test_lite_module_wraps():
(32, torch.float16, torch.float32),
(32, torch.float32, torch.float32),
(32, torch.float64, torch.float32),
(32, torch.int, torch.int),
(16, torch.float32, torch.float16),
(16, torch.float64, torch.float16),
(16, torch.long, torch.long),
pytest.param("bf16", torch.float32, torch.bfloat16, marks=RunIf(min_torch="1.10")),
pytest.param("bf16", torch.float64, torch.bfloat16, marks=RunIf(min_torch="1.10")),
pytest.param("bf16", torch.bool, torch.bool, marks=RunIf(min_torch="1.10")),
],
)
def test_lite_module_forward_conversion(precision, input_type, expected_type):
Expand All @@ -53,11 +58,11 @@ def check_autocast(forward_input):
assert precision != 16 or torch.is_autocast_enabled()
return forward_input

module = Mock(wraps=torch.nn.Linear(1, 1), side_effect=check_autocast)
module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
lite_module = _LiteModule(module, lite._precision_plugin).to(device)
out = lite_module(torch.rand(1, dtype=input_type, device=device))
out = lite_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
assert module.call_args[0][0].dtype == expected_type
assert out.dtype == torch.get_default_dtype()
assert out.dtype == input_type or out.dtype == torch.get_default_dtype()


def test_lite_dataloader_iterator():
Expand Down

0 comments on commit 2fcb55d

Please sign in to comment.