diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 48ef6e25dc465..6aed707726079 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -355,6 +355,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
+
+
 ## [1.6.4] - 2022-06-01
 
 ### Added
 
diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
index de07acdc90568..33ee2ce484d8b 100644
--- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -287,7 +287,8 @@ def on_save_checkpoint(self) -> Dict:
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
-        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
+        # restore global step instead to make sure logging works correctly if checkpoints <v1.6.5 are used
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)
 
     def _run_validation(self) -> None:
         # reload dataloaders
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 5bc591c7c6719..22f61c845360d 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -264,10 +264,19 @@ def restore_loops(self) -> None:
             return
 
         fit_loop = self.trainer.fit_loop
+        pl_module = self.trainer.lightning_module
+        assert pl_module is not None
+
         # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
-        optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop
-        optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"]
+        batch_loop = fit_loop.epoch_loop.batch_loop
+        if pl_module.automatic_optimization:
+            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
+                "global_step"
+            ]
+        else:
+            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
+
         # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
         # it will be overwritten by the loop's state if it was also saved
         fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py
index 4f167c08e8a05..39f623c73736d 100644
--- a/tests/tests_pytorch/models/test_restore.py
+++ b/tests/tests_pytorch/models/test_restore.py
@@ -28,7 +28,7 @@
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -255,6 +255,7 @@ class TestModel(BoringModel):
         def on_train_start(self) -> None:
             assert self.trainer.current_epoch == first_max_epochs
             assert self.trainer.global_step == first_max_epochs * train_batches
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches
 
     trainer.fit(TestModel(), ckpt_path=ckpt_path)
     assert trainer.current_epoch == max_epochs
@@ -262,6 +263,29 @@ def on_train_start(self) -> None:
     assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches
 
 
+@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
+def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
+    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
+    model = model_class()
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    ckpt = torch.load(ckpt_path)
+    # the key "_batches_that_stepped" doesn't exist in checkpoints generated with <1.6.5
+    del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
+    torch.save(ckpt, ckpt_path)
+
+    class TestModel(model_class):
+        def on_train_start(self) -> None:
+            assert self.trainer.global_step == 1
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1
+
+    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
+    model = TestModel()
+    trainer.fit(model, ckpt_path=ckpt_path)
+    new_loop = trainer.fit_loop.epoch_loop
+    assert new_loop.global_step == new_loop._batches_that_stepped == 2
+
+
 def test_fit_twice(tmpdir):
     epochs = []
 
diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
index cf640f794f259..02e3221f4dfd4 100644
--- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
@@ -172,6 +172,8 @@ def test_loops_restore(tmpdir):
     ckpt_path = str(tmpdir / "last.ckpt")
 
     trainer = Trainer(**trainer_args)
+    trainer.strategy.connect(model)
+
     for fn in TrainerFn:
         if fn != TrainerFn.TUNING:
             trainer_fn = getattr(trainer, f"{fn}_loop")
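
Note (illustrative, not part of the diff): a minimal sketch of the behaviour this patch enables, mirroring the new `test_logging_step_loaded_correctly_pre_1_6_5` test above. It assumes the Trainer's default `ModelCheckpoint` callback so that `best_model_path` points at the checkpoint written during the first run.

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel

    # train for a single step and grab the checkpoint written by the default ModelCheckpoint callback
    trainer = Trainer(max_steps=1, limit_val_batches=0)
    trainer.fit(BoringModel())
    ckpt_path = trainer.checkpoint_callback.best_model_path

    # simulate a checkpoint written before v1.6.5, which has no "_batches_that_stepped" entry
    ckpt = torch.load(ckpt_path)
    ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"].pop("_batches_that_stepped", None)
    torch.save(ckpt, ckpt_path)

    # resuming now falls back to the checkpoint's "global_step", so the logging step stays in sync
    trainer = Trainer(max_steps=2, limit_val_batches=0)
    trainer.fit(BoringModel(), ckpt_path=ckpt_path)
    assert trainer.fit_loop.epoch_loop._batches_that_stepped == trainer.global_step == 2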