Restore log step during restart (#13467)
Co-authored-by: Carlos Mocholí <[email protected]>
rohitgr7 and carmocca committed Jul 12, 2022
1 parent 9e6997c commit b286088
Showing 4 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))
 - Fixed bug with Python version check that prevented use with development versions of Python ([#13420](https://github.com/PyTorchLightning/pytorch-lightning/pull/13420))
 - The loops now call `.set_epoch()` also on batch samplers if the dataloader has one wrapped in a distributed sampler ([#13396](https://github.com/PyTorchLightning/pytorch-lightning/pull/13396))
+- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))


 ## [1.6.4] - 2022-06-01
2 changes: 2 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -281,6 +281,7 @@ def teardown(self) -> None:

     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
+        state_dict["_batches_that_stepped"] = self._batches_that_stepped

         if (
             self.trainer is not None
@@ -300,6 +301,7 @@ def on_save_checkpoint(self) -> Dict:
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

     def _run_validation(self) -> None:
         # reload dataloaders
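Taken together, the two added lines round-trip one integer through the loop's checkpoint state: `on_save_checkpoint` writes `_batches_that_stepped`, and `on_load_checkpoint` reads it back, defaulting to 0 for older checkpoints. A minimal standalone sketch of that pattern (the `CounterLoop` class is an illustrative stand-in, not Lightning's actual loop):

```python
from typing import Dict


class CounterLoop:
    """Illustrative stand-in for a loop that checkpoints a single counter."""

    def __init__(self) -> None:
        self._batches_that_stepped = 0

    def on_save_checkpoint(self) -> Dict:
        # persist the counter so a restarted run resumes logging at the right step
        return {"_batches_that_stepped": self._batches_that_stepped}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # `.get(..., 0)` keeps checkpoints written before this fix loadable
        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)


loop = CounterLoop()
loop._batches_that_stepped = 42
ckpt = loop.on_save_checkpoint()

restored = CounterLoop()
restored.on_load_checkpoint(ckpt)
assert restored._batches_that_stepped == 42
```

Using `dict.get` with a default keeps backward compatibility: a checkpoint written before this commit loads with the counter reset to 0 rather than raising a `KeyError`.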
2 changes: 1 addition & 1 deletion tests/loops/test_loop_state_dict.py
@@ -47,7 +47,7 @@ def test_loops_state_dict_structure():
     expected = {
         "fit_loop": {
             "state_dict": {},
-            "epoch_loop.state_dict": {},
+            "epoch_loop.state_dict": {"_batches_that_stepped": 0},
             "epoch_loop.batch_progress": {
                 "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
                 "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
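Because the counter now lives in the epoch loop's state dict, it also ends up inside saved checkpoint files. A quick way to confirm that on disk; the path `last.ckpt` is a hypothetical example, and the `loops`/`fit_loop` nesting assumes the checkpoint layout implied by the test above:

```python
import torch

ckpt = torch.load("last.ckpt")
loop_state = ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]
print(loop_state["_batches_that_stepped"])  # 0 for a checkpoint saved before any step
```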
1 change: 1 addition & 0 deletions tests/models/test_restore.py
@@ -259,6 +259,7 @@ def on_train_start(self) -> None:
     trainer.fit(TestModel(), ckpt_path=ckpt_path)
     assert trainer.current_epoch == max_epochs
     assert trainer.global_step == max_epochs * train_batches
+    assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


 def test_fit_twice(tmpdir):
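For context, a self-contained sketch of the resume flow this test exercises; the model, dataset, directory, and trainer arguments are illustrative stand-ins, not the test's own code:

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Eight random feature vectors: two batches of four per epoch."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        # a dummy loss; the point is only to make the optimizer step
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(RandomDataset(), batch_size=4)

# First run: one epoch (2 batches), after which Lightning saves a checkpoint.
trainer = pl.Trainer(max_epochs=1, default_root_dir="demo", enable_progress_bar=False)
trainer.fit(TinyModel(), loader)
ckpt_path = trainer.checkpoint_callback.best_model_path

# Resumed run: with this fix the counter picks up at 2 and ends at 4
# (2 epochs x 2 batches), instead of restarting from 0.
trainer = pl.Trainer(max_epochs=2, default_root_dir="demo", enable_progress_bar=False)
trainer.fit(TinyModel(), loader, ckpt_path=ckpt_path)
assert trainer.fit_loop.epoch_loop._batches_that_stepped == 4
```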
