Skip to content

Commit

Permalink
Make LightningModule torch.jit.script-able again (#15947)
Browse files Browse the repository at this point in the history
* Make LightningModule torch.jit.script-able again
* remove skip

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
awaelchli and pre-commit-ci[bot] authored Dec 8, 2022
1 parent 67a47d4 commit b5fa896
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 39 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620))


- Fixed `torch.jit.script`-ing a LightningModule causing an unintended error message about deprecated `use_amp` property ([#15947](https://github.com/Lightning-AI/lightning/pull/15947))


## [1.8.3] - 2022-11-22

### Changed
Expand Down
24 changes: 1 addition & 23 deletions src/pytorch_lightning/_graveyard/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any

from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import LightningDataModule


def _on_save_checkpoint(_: LightningDataModule, __: Any) -> None:
Expand All @@ -32,28 +32,6 @@ def _on_load_checkpoint(_: LightningDataModule, __: Any) -> None:
)


def _use_amp(_: LightningModule) -> None:
# Remove in v2.0.0 and the skip in `__jit_unused_properties__`
if not LightningModule._jit_is_scripting:
# cannot use `AttributeError` as it messes up with `nn.Module.__getattr__`
raise RuntimeError(
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
" Please use `Trainer.amp_backend`.",
)


def _use_amp_setter(_: LightningModule, __: bool) -> None:
# Remove in v2.0.0
# cannot use `AttributeError` as it messes up with `nn.Module.__getattr__`
raise RuntimeError(
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
" Please use `Trainer.amp_backend`.",
)


# Properties
LightningModule.use_amp = property(fget=_use_amp, fset=_use_amp_setter)

# Methods
LightningDataModule.on_save_checkpoint = _on_save_checkpoint
LightningDataModule.on_load_checkpoint = _on_load_checkpoint
1 change: 0 additions & 1 deletion src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ class LightningModule(
"automatic_optimization",
"truncated_bptt_steps",
"trainer",
"use_amp", # from graveyard
]
+ _DeviceDtypeModuleMixin.__jit_unused_properties__
+ HyperparametersMixin.__jit_unused_properties__
Expand Down
11 changes: 11 additions & 0 deletions tests/tests_pytorch/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,17 @@ def test_proper_refcount():
assert sys.getrefcount(torch_module) == sys.getrefcount(lightning_module)


def test_lightning_module_scriptable():
"""Test that the LightningModule is `torch.jit.script`-able.
Regression test for #15917.
"""
model = BoringModel()
trainer = Trainer()
model.trainer = trainer
torch.jit.script(model)


def test_trainer_reference_recursively():
ensemble = LightningModule()
inner = LightningModule()
Expand Down
15 changes: 0 additions & 15 deletions tests/tests_pytorch/graveyard/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,3 @@ def on_load_checkpoint(self, checkpoint):
match="`LightningDataModule.on_load_checkpoint`.*no longer supported as of v1.8.",
):
trainer.fit(model, OnLoadDataModule())


def test_v2_0_0_lightning_module_unsupported_use_amp():
model = BoringModel()
with pytest.raises(
RuntimeError,
match="`LightningModule.use_amp`.*no longer accessible as of v1.8.",
):
model.use_amp

with pytest.raises(
RuntimeError,
match="`LightningModule.use_amp`.*no longer accessible as of v1.8.",
):
model.use_amp = False

0 comments on commit b5fa896

Please sign in to comment.