Use global_step while restoring logging step for old checkpoints (#13645)

Co-authored-by: Akihiro Nitta <[email protected]>
rohitgr7 and akihironitta authored Jul 19, 2022
1 parent 6cbd9d7 commit c67b075
Showing 5 changed files with 43 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -355,6 +355,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))


- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))


## [1.6.4] - 2022-06-01

### Added
3 changes: 2 additions & 1 deletion src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -287,7 +287,8 @@ def on_save_checkpoint(self) -> Dict:
def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
# restore global step instead to make sure logging works correctly if checkpoints <v1.6.5 used to resume
self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)

def _run_validation(self) -> None:
# reload dataloaders
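A minimal sketch of the fallback above, using a plain dict in place of the loop's checkpoint state (the step value 128 is illustrative, not taken from the diff): when the `_batches_that_stepped` key is absent, as in checkpoints written before v1.6.5, the logging step now resumes from `global_step` instead of restarting at 0.

```python
# Sketch only: mimics the `state_dict.get(...)` fallback in on_load_checkpoint.
old_epoch_loop_state = {"dataloader_state_dict": {}}  # no "_batches_that_stepped", as in <v1.6.5
global_step = 128  # illustrative value restored from the checkpoint

previous_behavior = old_epoch_loop_state.get("_batches_that_stepped", 0)
new_behavior = old_epoch_loop_state.get("_batches_that_stepped", global_step)

assert previous_behavior == 0        # logger step would have restarted from zero
assert new_behavior == global_step   # logger step continues where training left off
```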
13 changes: 11 additions & 2 deletions src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -264,10 +264,19 @@ def restore_loops(self) -> None:
return

fit_loop = self.trainer.fit_loop
pl_module = self.trainer.lightning_module
assert pl_module is not None

# set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop
optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"]
batch_loop = fit_loop.epoch_loop.batch_loop
if pl_module.automatic_optimization:
batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
"global_step"
]
else:
batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]

# set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
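A simplified sketch of the branching above, written as a standalone helper rather than the Trainer's actual code path: with automatic optimization the optimizer loop owns the step progress, while with manual optimization the manual loop tracks `optimizer.step()` calls, so the pre-v1.6 `global_step` has to be seeded into whichever counter applies.

```python
# Sketch only: attribute names mirror the diff above; `fit_loop` stands in for
# the Trainer's fit loop object.
def seed_step_counter(fit_loop, automatic_optimization: bool, global_step: int) -> None:
    batch_loop = fit_loop.epoch_loop.batch_loop
    if automatic_optimization:
        # automatic optimization: the optimizer loop tracks completed optimizer steps
        batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
    else:
        # manual optimization: the manual loop counts `optimizer.step()` calls itself
        batch_loop.manual_loop.optim_step_progress.total.completed = global_step
```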
26 changes: 25 additions & 1 deletion tests/tests_pytorch/models/test_restore.py
@@ -28,7 +28,7 @@
import tests_pytorch.helpers.utils as tutils
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -255,13 +255,37 @@ class TestModel(BoringModel):
def on_train_start(self) -> None:
assert self.trainer.current_epoch == first_max_epochs
assert self.trainer.global_step == first_max_epochs * train_batches
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches

trainer.fit(TestModel(), ckpt_path=ckpt_path)
assert trainer.current_epoch == max_epochs
assert trainer.global_step == max_epochs * train_batches
assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
model = model_class()
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
ckpt = torch.load(ckpt_path)
# the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
torch.save(ckpt, ckpt_path)

class TestModel(model_class):
def on_train_start(self) -> None:
assert self.trainer.global_step == 1
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
model = TestModel()
trainer.fit(model, ckpt_path=ckpt_path)
new_loop = trainer.fit_loop.epoch_loop
assert new_loop.global_step == new_loop._batches_that_stepped == 2


def test_fit_twice(tmpdir):
epochs = []

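The new test simulates a pre-v1.6.5 checkpoint by deleting the `_batches_that_stepped` key. The same key layout can be used to check an existing checkpoint by hand — a small sketch, assuming a local checkpoint path:

```python
import torch

ckpt_path = "lightning_logs/version_0/checkpoints/epoch=0-step=1.ckpt"  # hypothetical path
ckpt = torch.load(ckpt_path, map_location="cpu")
epoch_loop_state = ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]

# Checkpoints written before v1.6.5 lack this key; with this commit the Trainer
# falls back to ckpt["global_step"] when it is missing.
print("global_step:", ckpt["global_step"])
print("has _batches_that_stepped:", "_batches_that_stepped" in epoch_loop_state)
```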
@@ -172,6 +172,8 @@ def test_loops_restore(tmpdir):
ckpt_path = str(tmpdir / "last.ckpt")

trainer = Trainer(**trainer_args)
trainer.strategy.connect(model)

for fn in TrainerFn:
if fn != TrainerFn.TUNING:
trainer_fn = getattr(trainer, f"{fn}_loop")
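The extra `trainer.strategy.connect(model)` call is needed because `restore_loops` now reads `trainer.lightning_module` to decide between the automatic and manual optimization counters. A rough sketch of the requirement, assuming `BoringModel` as imported earlier in this diff (the exact behaviour of a bare `Trainer` may vary by version):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

trainer = Trainer(max_steps=1)
model = BoringModel()

# A bare Trainer has no LightningModule attached yet, so the new assertion in
# restore_loops (`assert pl_module is not None`) would fail without connecting.
assert trainer.lightning_module is None
trainer.strategy.connect(model)
assert trainer.lightning_module is model
```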
