Call `LightningDataModule.load_state_dict` hook while restoring checkpoint using `LightningDataModule.load_from_checkpoint` (#14883)
rohitgr7 authored Sep 29, 2022
1 parent 93e802a commit 3a70e5d
Showing 3 changed files with 66 additions and 20 deletions.
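In practical terms, a datamodule restored with `load_from_checkpoint` now gets its `load_state_dict` hook called with the state it saved earlier. A minimal sketch of that flow, assuming a hypothetical `MyDataModule` with a `batch_size` field (the names and paths below are illustrative, not taken from this commit):

import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=16):
        super().__init__()
        self.batch_size = batch_size

    def state_dict(self):
        # Saved into the checkpoint under this class's __qualname__ ("MyDataModule").
        return {"batch_size": self.batch_size}

    def load_state_dict(self, state_dict):
        # With this commit, also invoked by MyDataModule.load_from_checkpoint(...),
        # not only when the Trainer restores training from a checkpoint.
        self.batch_size = state_dict["batch_size"]

# After training with this datamodule and calling trainer.save_checkpoint("example.ckpt"):
# dm = MyDataModule.load_from_checkpoint("example.ckpt")
# dm.batch_size  # restored from the saved datamodule state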
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -276,6 +276,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))


- Called `LightningDataModule.load_state_dict` hook while restoring checkpoint using `LightningDataModule.load_from_checkpoint` ([#14883](https://github.com/Lightning-AI/lightning/pull/14883))


- Fixed torchscript error with containers of LightningModules ([#14904](https://github.com/Lightning-AI/lightning/pull/14904))


2 changes: 2 additions & 0 deletions src/pytorch_lightning/core/saving.py
@@ -226,6 +226,8 @@ def _load_state(
    obj.on_load_checkpoint(checkpoint)

    if isinstance(obj, pl.LightningDataModule):
        if obj.__class__.__qualname__ in checkpoint:
            obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
        return obj

    # load the state_dict on the model automatically
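The new guard keys the datamodule state by the class's `__qualname__` and only calls `load_state_dict` when that entry exists, so checkpoints written before a datamodule saved any state still load cleanly. A self-contained sketch of the lookup, using a toy checkpoint dict and an assumed `MyDataModule` stand-in rather than the real Lightning classes:

class MyDataModule:  # stand-in for a LightningDataModule subclass (assumed name)
    def load_state_dict(self, state_dict):
        self.batch_size = state_dict["batch_size"]

# Toy checkpoint: the datamodule state sits under the class qualname,
# next to the usual "state_dict", "epoch", etc. entries.
checkpoint = {"state_dict": {}, "MyDataModule": {"batch_size": 32}}

obj = MyDataModule()
if obj.__class__.__qualname__ in checkpoint:  # "MyDataModule"
    obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
assert obj.batch_size == 32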
81 changes: 61 additions & 20 deletions tests/tests_pytorch/models/test_hooks.py
@@ -25,6 +25,27 @@
from tests_pytorch.helpers.runif import RunIf


class HookedDataModule(BoringDataModule):
    def __init__(self, called):
        super().__init__()

        def call(hook, fn, *args, **kwargs):
            out = fn(*args, **kwargs)
            d = {"name": hook}
            if args:
                d["args"] = args
            if kwargs:
                d["kwargs"] = kwargs
            called.append(d)
            return out

        for h in get_members(LightningDataModule):
            attr = getattr(self, h)
            partial_h = partial(call, h, attr)
            update_wrapper(partial_h, attr)
            setattr(self, h, partial_h)
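
# Illustrative usage sketch (assumed, not part of the committed diff): HookedDataModule
# replaces every LightningDataModule hook on the instance with a recording wrapper,
# so a test can assert on the exact call sequence, e.g.:
#
#     called = []
#     dm = HookedDataModule(called)
#     dm.prepare_data()
#     assert called == [{"name": "prepare_data"}]
#
# update_wrapper keeps each wrapped hook's name and signature on the partial.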


@pytest.mark.parametrize("max_steps", [1, 2, 3])
def test_on_before_zero_grad_called(tmpdir, max_steps):
    class CurrentTestModel(BoringModel):
@@ -911,26 +932,6 @@ def predict_dataloader(self):
def test_trainer_datamodule_hook_system(tmpdir):
"""Test the LightningDataModule hook system."""

    class HookedDataModule(BoringDataModule):
        def __init__(self, called):
            super().__init__()

            def call(hook, fn, *args, **kwargs):
                out = fn(*args, **kwargs)
                d = {"name": hook}
                if args:
                    d["args"] = args
                if kwargs:
                    d["kwargs"] = kwargs
                called.append(d)
                return out

            for h in get_members(LightningDataModule):
                attr = getattr(self, h)
                partial_h = partial(call, h, attr)
                update_wrapper(partial_h, attr)
                setattr(self, h, partial_h)

    model = BoringModel()
    batches = 2
    trainer = Trainer(
@@ -991,3 +992,43 @@ def call(hook, fn, *args, **kwargs):
dict(name="teardown", kwargs=dict(stage="predict")),
]
assert called == expected


def test_load_from_checkpoint_hook_calls(tmpdir):
    class CustomHookedDataModule(HookedDataModule):
        def state_dict(self):
            return {"foo": "bar"}

    lm_called, ldm_called = [], []
    model = HookedModel(lm_called)
    datamodule = CustomHookedDataModule(ldm_called)
    trainer = Trainer()
    trainer.strategy.connect(model)
    trainer._data_connector.attach_data(model, datamodule=datamodule)
    ckpt_path = str(tmpdir / "file.ckpt")
    trainer.save_checkpoint(ckpt_path)

    datamodule_state_dict_key = datamodule.__class__.__qualname__
    saved_ckpt = {
        "callbacks": ANY,
        "epoch": 0,
        "global_step": 0,
        "lr_schedulers": ANY,
        "optimizer_states": ANY,
        "pytorch-lightning_version": __version__,
        "state_dict": ANY,
        "loops": ANY,
        datamodule_state_dict_key: {"foo": "bar"},
    }

    assert lm_called == [dict(name="on_save_checkpoint", args=(saved_ckpt,))]
    assert ldm_called == [dict(name="state_dict"), dict(name="on_save_checkpoint", args=(saved_ckpt,))]

    lm_called, ldm_called = [], []
    model = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
    datamodule = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
    assert lm_called == [dict(name="on_load_checkpoint", args=({**saved_ckpt, "hyper_parameters": ANY},))]
    assert ldm_called == [
        dict(name="on_load_checkpoint", args=({**saved_ckpt, "datamodule_hyper_parameters": ANY},)),
        dict(name="load_state_dict", args=(saved_ckpt[datamodule_state_dict_key],)),
    ]
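
The exact-equality assertions above work because `ANY` from `unittest.mock` compares equal to any value, so only the entries the test cares about, such as the qualname-keyed datamodule state, are pinned down precisely. A small illustration of that comparison, with an assumed `MyDataModule` key:

from unittest.mock import ANY

saved = {"state_dict": {"layer.weight": 0.0}, "MyDataModule": {"foo": "bar"}}
assert saved == {
    "state_dict": ANY,               # matches whatever model weights were saved
    "MyDataModule": {"foo": "bar"},  # the datamodule entry is checked exactly
}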
