From 01e8f9b0be448ebddf36b93bf516346d3064dfba Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 14 Jul 2022 00:59:39 +0530 Subject: [PATCH 1/7] Use global_step while restoring logging step for old checkpoints --- .../loops/epoch/training_epoch_loop.py | 3 ++- tests/tests_pytorch/models/test_restore.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 36a594b45ae6f..4916d428fc8aa 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -293,7 +293,8 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available self._dataloader_state_dict = state_dict.get("dataloader_state_dict") - self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) + # restore global step instead to make sure logging works correctly if checkpoints None: # reload dataloaders diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 4f167c08e8a05..805fbd5f87e52 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -255,6 +255,7 @@ class TestModel(BoringModel): def on_train_start(self) -> None: assert self.trainer.current_epoch == first_max_epochs assert self.trainer.global_step == first_max_epochs * train_batches + assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches trainer.fit(TestModel(), ckpt_path=ckpt_path) assert trainer.current_epoch == max_epochs @@ -262,6 +263,27 @@ def on_train_start(self) -> None: assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches +def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir): + trainer = Trainer(max_steps=1, default_root_dir=tmpdir) + model = BoringModel() + trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path + ckpt = torch.load(ckpt_path) + del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"] + torch.save(ckpt, ckpt_path) + + class TestModel(BoringModel): + def on_train_start(self) -> None: + assert self.trainer.global_step == 1 + assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1 + + trainer = Trainer(max_steps=2, default_root_dir=tmpdir) + model = TestModel() + trainer.fit(model, ckpt_path=ckpt_path) + new_loop = trainer.fit_loop.epoch_loop + assert new_loop.global_step == new_loop._batches_that_stepped == 2 + + def test_fit_twice(tmpdir): epochs = [] From dbeead922fb8a831b958298bff603767c50f6cf5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 14 Jul 2022 01:13:54 +0530 Subject: [PATCH 2/7] chlog --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 0d42f819d9d1d..76f9305ea7eb1 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -314,6 +314,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467)) +- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645)) + + ## [1.6.4] - 2022-06-01 ### Added From f70d7bd98d5b322fe83ef789fa99106225efd1d6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 14 Jul 2022 14:14:08 +0530 Subject: [PATCH 3/7] disable validation --- tests/tests_pytorch/models/test_restore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 805fbd5f87e52..831f34888906f 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -264,7 +264,7 @@ def on_train_start(self) -> None: def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir): - trainer = Trainer(max_steps=1, default_root_dir=tmpdir) + trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir) model = BoringModel() trainer.fit(model) ckpt_path = trainer.checkpoint_callback.best_model_path @@ -277,7 +277,7 @@ def on_train_start(self) -> None: assert self.trainer.global_step == 1 assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1 - trainer = Trainer(max_steps=2, default_root_dir=tmpdir) + trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir) model = TestModel() trainer.fit(model, ckpt_path=ckpt_path) new_loop = trainer.fit_loop.epoch_loop From fb00257b30030035a5dd48269fd9e886b64f1696 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 14 Jul 2022 17:24:46 +0530 Subject: [PATCH 4/7] Update tests/tests_pytorch/models/test_restore.py Co-authored-by: Akihiro Nitta --- tests/tests_pytorch/models/test_restore.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 831f34888906f..bad8f3575e9cc 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -269,6 +269,7 @@ def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir): trainer.fit(model) ckpt_path = trainer.checkpoint_callback.best_model_path ckpt = torch.load(ckpt_path) + # the key "_batches_that_stepped" doesn't exist in checkpoints generated with Date: Fri, 15 Jul 2022 00:18:28 +0530 Subject: [PATCH 5/7] fix for manual opt --- .../trainer/connectors/checkpoint_connector.py | 12 ++++++++++-- tests/tests_pytorch/models/test_restore.py | 9 +++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5bc591c7c6719..67dd8768dd241 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -264,10 +264,18 @@ def restore_loops(self) -> None: return fit_loop = self.trainer.fit_loop + pl_module = self.trainer.lightning_module + # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. # it will be overwritten by the loop's state if it was also saved - optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop - optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"] + batch_loop = fit_loop.epoch_loop.batch_loop + if pl_module is None or pl_module.automatic_optimization: + batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ + "global_step" + ] + else: + batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"] + # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. # it will be overwritten by the loop's state if it was also saved fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index bad8f3575e9cc..39f623c73736d 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -28,7 +28,7 @@ import tests_pytorch.helpers.utils as tutils from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel from pytorch_lightning.trainer.states import TrainerFn from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -263,9 +263,10 @@ def on_train_start(self) -> None: assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches -def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir): +@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel]) +def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class): trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir) - model = BoringModel() + model = model_class() trainer.fit(model) ckpt_path = trainer.checkpoint_callback.best_model_path ckpt = torch.load(ckpt_path) @@ -273,7 +274,7 @@ def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir): del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"] torch.save(ckpt, ckpt_path) - class TestModel(BoringModel): + class TestModel(model_class): def on_train_start(self) -> None: assert self.trainer.global_step == 1 assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1 From 6ab0608c0f8901eed8872a5f7e5c6d165f3b0d6f Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Sun, 17 Jul 2022 02:49:21 +0530 Subject: [PATCH 6/7] add assertion --- .../trainer/connectors/checkpoint_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 67dd8768dd241..22f61c845360d 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -265,11 +265,12 @@ def restore_loops(self) -> None: fit_loop = self.trainer.fit_loop pl_module = self.trainer.lightning_module + assert pl_module is not None # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. # it will be overwritten by the loop's state if it was also saved batch_loop = fit_loop.epoch_loop.batch_loop - if pl_module is None or pl_module.automatic_optimization: + if pl_module.automatic_optimization: batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ "global_step" ] From 0c73b635c2820a9ea26333c014c286e7de637dd5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 19 Jul 2022 23:51:44 +0530 Subject: [PATCH 7/7] update test --- .../trainer/connectors/test_checkpoint_connector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index cf640f794f259..02e3221f4dfd4 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -172,6 +172,8 @@ def test_loops_restore(tmpdir): ckpt_path = str(tmpdir / "last.ckpt") trainer = Trainer(**trainer_args) + trainer.strategy.connect(model) + for fn in TrainerFn: if fn != TrainerFn.TUNING: trainer_fn = getattr(trainer, f"{fn}_loop")