diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index d4bc362a539f7..081c567135fd6 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -276,6 +276,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
 
 
+- Called `LightningDataModule.load_state_dict` hook while restoring checkpoint using `LightningDataModule.load_from_checkpoint` ([#14883](https://github.com/Lightning-AI/lightning/pull/14883))
+
+
 - Fixed torchscript error with containers of LightningModules ([#14904](https://github.com/Lightning-AI/lightning/pull/14904))
 
 
diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py
index 9cc465184898c..da1d166add3b4 100644
--- a/src/pytorch_lightning/core/saving.py
+++ b/src/pytorch_lightning/core/saving.py
@@ -226,6 +226,8 @@ def _load_state(
     obj.on_load_checkpoint(checkpoint)
 
     if isinstance(obj, pl.LightningDataModule):
+        if obj.__class__.__qualname__ in checkpoint:
+            obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
         return obj
 
     # load the state_dict on the model automatically
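The two added lines in `_load_state` are the whole fix: `Trainer.save_checkpoint` already stores a datamodule's `state_dict()` in the checkpoint under the datamodule class's `__qualname__` (the test below relies on exactly that key), and restoring now looks that entry up and feeds it back through `load_state_dict`. A minimal sketch of the round trip from the user's side, assuming a hypothetical `SplitDataModule` whose only persistent state is a split seed (the class and its fields are illustrative, not part of this PR):

import pytorch_lightning as pl


class SplitDataModule(pl.LightningDataModule):
    """Hypothetical datamodule whose split seed should survive a checkpoint."""

    def __init__(self, split_seed: int = 0):
        super().__init__()
        self.split_seed = split_seed

    def state_dict(self):
        # Saved by Trainer.save_checkpoint under this class's __qualname__ key.
        return {"split_seed": self.split_seed}

    def load_state_dict(self, state_dict):
        # Previously skipped by load_from_checkpoint; with this fix it
        # receives the dict saved above.
        self.split_seed = state_dict["split_seed"]


# After a run that saved "example.ckpt" with this datamodule attached:
# dm = SplitDataModule.load_from_checkpoint("example.ckpt")
# dm.split_seed now matches the value at save time.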
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 158371a3097c5..95ea61240c030 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -25,6 +25,27 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+class HookedDataModule(BoringDataModule):
+    def __init__(self, called):
+        super().__init__()
+
+        def call(hook, fn, *args, **kwargs):
+            out = fn(*args, **kwargs)
+            d = {"name": hook}
+            if args:
+                d["args"] = args
+            if kwargs:
+                d["kwargs"] = kwargs
+            called.append(d)
+            return out
+
+        for h in get_members(LightningDataModule):
+            attr = getattr(self, h)
+            partial_h = partial(call, h, attr)
+            update_wrapper(partial_h, attr)
+            setattr(self, h, partial_h)
+
+
 @pytest.mark.parametrize("max_steps", [1, 2, 3])
 def test_on_before_zero_grad_called(tmpdir, max_steps):
     class CurrentTestModel(BoringModel):
@@ -911,26 +932,6 @@ def predict_dataloader(self):
 def test_trainer_datamodule_hook_system(tmpdir):
     """Test the LightningDataModule hook system."""
 
-    class HookedDataModule(BoringDataModule):
-        def __init__(self, called):
-            super().__init__()
-
-            def call(hook, fn, *args, **kwargs):
-                out = fn(*args, **kwargs)
-                d = {"name": hook}
-                if args:
-                    d["args"] = args
-                if kwargs:
-                    d["kwargs"] = kwargs
-                called.append(d)
-                return out
-
-            for h in get_members(LightningDataModule):
-                attr = getattr(self, h)
-                partial_h = partial(call, h, attr)
-                update_wrapper(partial_h, attr)
-                setattr(self, h, partial_h)
-
     model = BoringModel()
     batches = 2
     trainer = Trainer(
@@ -991,3 +992,43 @@ def call(hook, fn, *args, **kwargs):
         dict(name="teardown", kwargs=dict(stage="predict")),
     ]
     assert called == expected
+
+
+def test_load_from_checkpoint_hook_calls(tmpdir):
+    class CustomHookedDataModule(HookedDataModule):
+        def state_dict(self):
+            return {"foo": "bar"}
+
+    lm_called, ldm_called = [], []
+    model = HookedModel(lm_called)
+    datamodule = CustomHookedDataModule(ldm_called)
+    trainer = Trainer()
+    trainer.strategy.connect(model)
+    trainer._data_connector.attach_data(model, datamodule=datamodule)
+    ckpt_path = str(tmpdir / "file.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    datamodule_state_dict_key = datamodule.__class__.__qualname__
+    saved_ckpt = {
+        "callbacks": ANY,
+        "epoch": 0,
+        "global_step": 0,
+        "lr_schedulers": ANY,
+        "optimizer_states": ANY,
+        "pytorch-lightning_version": __version__,
+        "state_dict": ANY,
+        "loops": ANY,
+        datamodule_state_dict_key: {"foo": "bar"},
+    }
+
+    assert lm_called == [dict(name="on_save_checkpoint", args=(saved_ckpt,))]
+    assert ldm_called == [dict(name="state_dict"), dict(name="on_save_checkpoint", args=(saved_ckpt,))]
+
+    lm_called, ldm_called = [], []
+    model = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
+    datamodule = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
+    assert lm_called == [dict(name="on_load_checkpoint", args=({**saved_ckpt, "hyper_parameters": ANY},))]
+    assert ldm_called == [
+        dict(name="on_load_checkpoint", args=({**saved_ckpt, "datamodule_hyper_parameters": ANY},)),
+        dict(name="load_state_dict", args=(saved_ckpt[datamodule_state_dict_key],)),
+    ]
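`HookedDataModule` (hoisted to module level above so the new test can subclass it) is built on a small generic recorder: every hook is swapped for a `functools.partial` that logs the call and delegates to the original bound method, with `update_wrapper` copying the hook's metadata onto the wrapper. A self-contained sketch of the same technique outside Lightning, where `Greeter` and `record_calls` are illustrative names only:

from functools import partial, update_wrapper


class Greeter:
    def greet(self, name):
        return f"hello {name}"


def record_calls(obj, method_names, called):
    """Wrap each named method so every invocation is appended to `called`."""

    def call(hook, fn, *args, **kwargs):
        out = fn(*args, **kwargs)  # delegate to the original method
        entry = {"name": hook}
        if args:
            entry["args"] = args
        if kwargs:
            entry["kwargs"] = kwargs
        called.append(entry)
        return out

    for name in method_names:
        attr = getattr(obj, name)            # original bound method
        wrapped = partial(call, name, attr)
        update_wrapper(wrapped, attr)        # keep __name__, __doc__, etc.
        setattr(obj, name, wrapped)          # instance attribute shadows the method


called = []
g = Greeter()
record_calls(g, ["greet"], called)
assert g.greet("world") == "hello world"
assert called == [{"name": "greet", "args": ("world",)}]

Because `setattr` installs the wrapper as an instance attribute, the class itself is untouched and other instances keep their unwrapped methods.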