From 7c03e164f2821977688723c6a3a32d1ed6bf7744 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 30 Jun 2022 15:54:17 +0530 Subject: [PATCH 1/4] Restore log step during restart --- src/pytorch_lightning/loops/epoch/training_epoch_loop.py | 2 ++ tests/tests_pytorch/loops/test_loop_state_dict.py | 2 +- tests/tests_pytorch/models/test_restore.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 04e9d070a6d8e..e75f3c4b0e2cf 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -273,6 +273,7 @@ def teardown(self) -> None: def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() + state_dict["batches_that_stepped"] = self._batches_that_stepped if ( self.trainer is not None @@ -292,6 +293,7 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available self._dataloader_state_dict = state_dict.get("dataloader_state_dict") + self._batches_that_stepped = state_dict.get("batches_that_stepped") def _run_validation(self) -> None: # reload dataloaders diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py index 1e67fcc0ed8db..68571a183b868 100644 --- a/tests/tests_pytorch/loops/test_loop_state_dict.py +++ b/tests/tests_pytorch/loops/test_loop_state_dict.py @@ -47,7 +47,7 @@ def test_loops_state_dict_structure(): expected = { "fit_loop": { "state_dict": {}, - "epoch_loop.state_dict": {}, + "epoch_loop.state_dict": {"batches_that_stepped": 0}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 77f45928dd907..4f167c08e8a05 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -259,6 +259,7 @@ def on_train_start(self) -> None: trainer.fit(TestModel(), ckpt_path=ckpt_path) assert trainer.current_epoch == max_epochs assert trainer.global_step == max_epochs * train_batches + assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches def test_fit_twice(tmpdir): From d2df4b6bb0a19f3d9f91d354fda04085a74b6bc0 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 30 Jun 2022 15:57:52 +0530 Subject: [PATCH 2/4] chlog --- src/pytorch_lightning/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 28695785c367c..44c4096b9b204 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -302,6 +302,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the input validation for the accelerator Trainer argument when passed as a string ([#13417](https://github.com/PyTorchLightning/pytorch-lightning/pull/13417)) +- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467)) + ## [1.6.4] - 2022-06-01 From f8546c29ba9bdeefcfa8829df6f926a944e8d108 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 30 Jun 2022 16:48:42 +0530 Subject: [PATCH 3/4] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- src/pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ++-- tests/tests_pytorch/loops/test_loop_state_dict.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index e75f3c4b0e2cf..9f7988b8019e0 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -273,7 +273,7 @@ def teardown(self) -> None: def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() - state_dict["batches_that_stepped"] = self._batches_that_stepped + state_dict["_batches_that_stepped"] = self._batches_that_stepped if ( self.trainer is not None @@ -293,7 +293,7 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available self._dataloader_state_dict = state_dict.get("dataloader_state_dict") - self._batches_that_stepped = state_dict.get("batches_that_stepped") + self._batches_that_stepped = state_dict.get("_batches_that_stepped") def _run_validation(self) -> None: # reload dataloaders diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py index 68571a183b868..f9630095502d1 100644 --- a/tests/tests_pytorch/loops/test_loop_state_dict.py +++ b/tests/tests_pytorch/loops/test_loop_state_dict.py @@ -47,7 +47,7 @@ def test_loops_state_dict_structure(): expected = { "fit_loop": { "state_dict": {}, - "epoch_loop.state_dict": {"batches_that_stepped": 0}, + "epoch_loop.state_dict": {"_batches_that_stepped": 0}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, From 8873b4b8c9cc60e9d8b9301ed1a682c32bfee7a9 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 1 Jul 2022 12:59:53 +0530 Subject: [PATCH 4/4] Update src/pytorch_lightning/loops/epoch/training_epoch_loop.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/loops/epoch/training_epoch_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 9f7988b8019e0..36a594b45ae6f 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -293,7 +293,7 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available self._dataloader_state_dict = state_dict.get("dataloader_state_dict") - self._batches_that_stepped = state_dict.get("_batches_that_stepped") + self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) def _run_validation(self) -> None: # reload dataloaders