Use global_step while restoring logging step for old checkpoints #13645

Merged (13 commits) on Jul 19, 2022
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -352,6 +352,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))


- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))


## [1.6.4] - 2022-06-01

### Added
3 changes: 2 additions & 1 deletion src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -287,7 +287,8 @@ def on_save_checkpoint(self) -> Dict:
def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
# restore the global step instead, so that logging works correctly when resuming from checkpoints saved with <v1.6.5
self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)

def _run_validation(self) -> None:
# reload dataloaders
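To make the intent of that one-line change concrete, here is a minimal stand-alone sketch (a plain dict and local variables, not the real loop attributes): a checkpoint that predates v1.6.5 has no `_batches_that_stepped` entry, so the logging step now falls back to the already-restored `global_step` instead of silently resetting to 0.

```python
# Minimal sketch of the fallback; `loop_state` and `global_step` are illustrative
# placeholders standing in for the loop's state_dict and its restored step counter.
loop_state = {"dataloader_state_dict": {}}  # pretend this was saved by Lightning < v1.6.5
global_step = 120                           # already restored from the checkpoint

# old behaviour: loop_state.get("_batches_that_stepped", 0)  -> logging restarts at 0
# new behaviour: fall back to global_step                    -> logging continues at 120
batches_that_stepped = loop_state.get("_batches_that_stepped", global_step)
assert batches_that_stepped == 120
```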
src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -264,10 +264,19 @@ def restore_loops(self) -> None:
return

fit_loop = self.trainer.fit_loop
pl_module = self.trainer.lightning_module
assert pl_module is not None

# set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop
optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"]
batch_loop = fit_loop.epoch_loop.batch_loop
if pl_module.automatic_optimization:
batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
"global_step"
]
else:
batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]

# set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
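As a rough illustration of what the branch above does (simplified stand-in variables, not the real progress-tracking objects): a pre-v1.6 checkpoint only carries a flat `global_step`, and `restore_loops` copies it into whichever step counter the resumed run will actually increment, which differs between automatic and manual optimization.

```python
# Illustrative stand-ins for the two counters touched by the diff above:
# - automatic optimization: optimizer_loop.optim_progress.optimizer.step.total.completed
# - manual optimization:    manual_loop.optim_step_progress.total.completed
auto_steps_completed = 0
manual_steps_completed = 0

old_checkpoint = {"global_step": 500, "epoch": 4}  # flat counters, no progress-tracking state
automatic_optimization = True                      # read from the LightningModule being restored

if automatic_optimization:
    auto_steps_completed = old_checkpoint["global_step"]
else:
    manual_steps_completed = old_checkpoint["global_step"]

assert auto_steps_completed == 500  # the loop's own state, if present, overwrites this later
```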
26 changes: 25 additions & 1 deletion tests/tests_pytorch/models/test_restore.py
@@ -28,7 +28,7 @@
import tests_pytorch.helpers.utils as tutils
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -255,13 +255,37 @@ class TestModel(BoringModel):
def on_train_start(self) -> None:
assert self.trainer.current_epoch == first_max_epochs
assert self.trainer.global_step == first_max_epochs * train_batches
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches

trainer.fit(TestModel(), ckpt_path=ckpt_path)
assert trainer.current_epoch == max_epochs
assert trainer.global_step == max_epochs * train_batches
assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
model = model_class()
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
ckpt = torch.load(ckpt_path)
# the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
torch.save(ckpt, ckpt_path)

class TestModel(model_class):
def on_train_start(self) -> None:
assert self.trainer.global_step == 1
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
model = TestModel()
trainer.fit(model, ckpt_path=ckpt_path)
new_loop = trainer.fit_loop.epoch_loop
assert new_loop.global_step == new_loop._batches_that_stepped == 2


def test_fit_twice(tmpdir):
epochs = []

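From the user's side, the behaviour exercised by the new `test_logging_step_loaded_correctly_pre_1_6_5` test above can be summarised with a hedged usage sketch (the checkpoint path and step budget are placeholders, not taken from the PR): resuming from a checkpoint written before v1.6.5 no longer resets the logging step to zero.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

# "old_run.ckpt" is a placeholder for a checkpoint saved with Lightning < v1.6.5.
trainer = Trainer(max_steps=200, limit_val_batches=0)
trainer.fit(BoringModel(), ckpt_path="old_run.ckpt")

# logged metrics continue from the checkpoint's global_step instead of restarting at step 0
assert trainer.fit_loop.epoch_loop._batches_that_stepped == trainer.global_step
```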
@@ -172,6 +172,8 @@ def test_loops_restore(tmpdir):
ckpt_path = str(tmpdir / "last.ckpt")

trainer = Trainer(**trainer_args)
trainer.strategy.connect(model)

for fn in TrainerFn:
if fn != TrainerFn.TUNING:
trainer_fn = getattr(trainer, f"{fn}_loop")
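A note on why this test now connects the model first (my reading of the diff, offered as an assumption rather than something stated in the PR): `restore_loops` asserts that `trainer.lightning_module` is not `None` and then branches on its `automatic_optimization` flag, so a model has to be attached to the trainer before the restore path runs. A hedged sketch of that dependency, reusing the test's `trainer_args` and `model` names:

```python
# Sketch only: `trainer_args` and `model` come from the surrounding test.
trainer = Trainer(**trainer_args)
trainer.strategy.connect(model)           # attaches the LightningModule to the strategy
assert trainer.lightning_module is model  # restore_loops can now read automatic_optimization
```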