Checkpoint migration for loop's internal state (#15500)
awaelchli authored Nov 8, 2022
1 parent 175603c commit bc2cf45
Showing 7 changed files with 214 additions and 83 deletions.
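At a glance: the commit moves the legacy `global_step`/`epoch` restoration out of the checkpoint connector (see the deleted block below) and into standalone, version-keyed migration functions. A minimal end-to-end sketch, assuming only the entry points the new tests import; the checkpoint values are illustrative:

```python
from pytorch_lightning.utilities.migration import migrate_checkpoint
from pytorch_lightning.utilities.migration.utils import _set_version

# A bare pre-1.6 checkpoint: loop progress exists only as flat counters.
old_ckpt = {"global_step": 15, "epoch": 2}
_set_version(old_ckpt, "1.5.9")  # pretend it was written by Lightning 1.5.9

new_ckpt, _ = migrate_checkpoint(old_ckpt)

# After migration, the loop state exists in the progress-tracking format
# that `restore_loops` can consume directly.
assert new_ckpt["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] == 2
```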
3 changes: 1 addition & 2 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -299,8 +299,7 @@ def on_save_checkpoint(self) -> Dict:
def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
# restore global step instead to make sure logging works correctly if checkpoints <v1.6.5 used to resume
self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

def _run_validation(self) -> None:
# reload dataloaders
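Why the fallback changed from `self.global_step` to `0`: pre-1.6.5 checkpoints are now backfilled by `_migrate_loop_batches_that_stepped` (added below) before the loop ever sees them, so a missing key means a genuinely fresh counter. A tiny illustration with hypothetical values:

```python
# After migration, an old checkpoint already carries the key and is restored as-is.
migrated_state = {"dataloader_state_dict": {}, "_batches_that_stepped": 7}
assert migrated_state.get("_batches_that_stepped", 0) == 7

# Without the old global_step fallback, an absent key now starts the counter at zero.
assert {}.get("_batches_that_stepped", 0) == 0
```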
18 changes: 0 additions & 18 deletions src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -343,24 +343,6 @@ def restore_loops(self) -> None:
return

fit_loop = self.trainer.fit_loop
pl_module = self.trainer.lightning_module
assert pl_module is not None

if self.trainer.state.fn == TrainerFn.FITTING:
# set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
batch_loop = fit_loop.epoch_loop.batch_loop
if pl_module.automatic_optimization:
batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
"global_step"
]
else:
batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]

# set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]

assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None:
93 changes: 93 additions & 0 deletions src/pytorch_lightning/utilities/migration/migration.py
@@ -41,6 +41,8 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
"""Migration functions returned here will get executed in the order they are listed."""
return {
"0.10.0": [_migrate_model_checkpoint_early_stopping],
"1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking],
"1.6.5": [_migrate_loop_batches_that_stepped],
}
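The consuming side lives in `pytorch_lightning.utilities.migration.utils`: roughly, every function keyed at a version newer than the checkpoint's recorded version runs, oldest target version first. A simplified sketch of that loop (not the actual implementation):

```python
from packaging.version import Version

def _apply_migrations_sketch(checkpoint, index):
    """Simplified stand-in for the real upgrade logic in migration/utils.py."""
    ckpt_version = checkpoint.get("pytorch-lightning_version", "0.0.0")
    for target in sorted(index, key=Version):  # oldest migration target first
        if Version(ckpt_version) < Version(target):
            for migration_fn in index[target]:
                checkpoint = migration_fn(checkpoint)
    return checkpoint
```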


@@ -67,3 +69,94 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP
checkpoint["callbacks"][callback_type][callback_key] = value
del checkpoint[key]
return checkpoint


def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be
overwritten by the loop's state if it was also saved.
Version: 1.6.0
Commit: c67b075
PR: #13645, #11805
"""
global_step = checkpoint["global_step"]
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
# for automatic optimization
optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
# for manual optimization
optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
optim_step_progress["total"]["completed"] = global_step
return checkpoint
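A worked example, using the same illustrative values as the new tests:

```python
ckpt = {"global_step": 15, "epoch": 2}
ckpt = _migrate_loop_global_step_to_progress_tracking(ckpt)

# Automatic optimization: the optimizer's step counter now carries global_step.
optim = ckpt["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
assert optim["optimizer"]["step"]["total"]["completed"] == 15

# Manual optimization gets the same value.
manual = ckpt["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
assert manual["total"]["completed"] == 15
```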


def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be
overwritten by the loop's state if it was also saved.
Version: 1.6.0
Commit: aea96e4
PR: #11805
"""
epoch = checkpoint["epoch"]
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
return checkpoint
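The epoch counterpart, again with the test's illustrative values:

```python
ckpt = {"global_step": 15, "epoch": 2}
ckpt = _migrate_loop_current_epoch_to_progress_tracking(ckpt)
assert ckpt["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] == 2
```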


def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Sets the `_batches_that_stepped` default value for checkpoints before v1.6.5 which don't have this key.
Version: 1.6.5
Commit: c67b075
PR: #13645
"""
global_step = checkpoint["global_step"]
checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"].setdefault("_batches_that_stepped", global_step)
return checkpoint
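A checkpoint written by v1.6.0-v1.6.4 already has the `loops` section but lacks this key; the migration backfills it from `global_step`. A minimal sketch with a hand-built checkpoint (real ones carry far more state):

```python
ckpt = {"global_step": 2, "loops": {"fit_loop": {"epoch_loop.state_dict": {}}}}
ckpt = _migrate_loop_batches_that_stepped(ckpt)
assert ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"] == 2
```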


_FIT_LOOP_INITIAL_STATE_1_6_0 = {
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"current": {"completed": 0, "ready": 0},
"total": {"completed": 0, "ready": 0},
},
"epoch_loop.batch_loop.manual_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
"zero_grad": {
"current": {"completed": 0, "ready": 0, "started": 0},
"total": {"completed": 0, "ready": 0, "started": 0},
},
},
"optimizer_position": 0,
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"is_last_batch": False,
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
"epoch_loop.state_dict": {"_batches_that_stepped": 0},
"epoch_loop.val_loop.dataloader_progress": {
"current": {"completed": 0, "ready": 0},
"total": {"completed": 0, "ready": 0},
},
"epoch_loop.val_loop.epoch_loop.batch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"is_last_batch": False,
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"epoch_loop.val_loop.epoch_loop.state_dict": {},
"epoch_loop.val_loop.state_dict": {},
"epoch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"state_dict": {},
}
25 changes: 1 addition & 24 deletions tests/tests_pytorch/models/test_restore.py
@@ -29,7 +29,7 @@
from lightning_lite import seed_everything
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -254,29 +254,6 @@ def on_train_start(self) -> None:
assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
model = model_class()
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
ckpt = torch.load(ckpt_path)
# the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
torch.save(ckpt, ckpt_path)

class TestModel(model_class):
def on_train_start(self) -> None:
assert self.trainer.global_step == 1
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
model = TestModel()
trainer.fit(model, ckpt_path=ckpt_path)
new_loop = trainer.fit_loop.epoch_loop
assert new_loop.global_step == new_loop._batches_that_stepped == 2


def test_fit_twice(tmpdir):
epochs = []

111 changes: 111 additions & 0 deletions tests/tests_pytorch/utilities/migration/test_migration.py
@@ -0,0 +1,111 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import ANY

import pytest
import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.utilities.migration import migrate_checkpoint
from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version


@pytest.mark.parametrize(
"old_checkpoint, new_checkpoint",
[
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY},
),
(
{"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
{
"epoch": 1,
"global_step": 23,
"callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
"loops": ANY,
},
),
],
)
def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint):
_set_version(old_checkpoint, "0.9.0")
_set_legacy_version(new_checkpoint, "0.9.0")
_set_version(new_checkpoint, pl.__version__)
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
assert updated_checkpoint == old_checkpoint == new_checkpoint
assert _get_version(updated_checkpoint) == pl.__version__


def test_migrate_loop_global_step_to_progress_tracking():
old_checkpoint = {"global_step": 15, "epoch": 2}
_set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
# automatic optimization
assert (
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][
"step"
]["total"]["completed"]
== 15
)
# for manual optimization
assert (
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]["total"][
"completed"
]
== 15
)


def test_migrate_loop_current_epoch_to_progress_tracking():
old_checkpoint = {"global_step": 15, "epoch": 2}
_set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
assert updated_checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] == 2


@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_migrate_loop_batches_that_stepped(tmpdir, model_class):
trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
model = model_class()
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path

# pretend we have a checkpoint produced in < v1.6.5; the key "_batches_that_stepped" didn't exist back then
ckpt = torch.load(ckpt_path)
del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
_set_version(ckpt, "1.6.4")
torch.save(ckpt, ckpt_path)

class TestModel(model_class):
def on_train_start(self) -> None:
assert self.trainer.global_step == 1
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
model = TestModel()
trainer.fit(model, ckpt_path=ckpt_path)
new_loop = trainer.fit_loop.epoch_loop
assert new_loop.global_step == new_loop._batches_that_stepped == 2
13 changes: 8 additions & 5 deletions tests/tests_pytorch/utilities/migration/test_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
import sys
from unittest.mock import ANY

import pytest

@@ -109,26 +110,28 @@ def test_migrate_checkpoint_for_pl(caplog):
"""Test that the automatic migration in Lightning informs the user about how to make the upgrade permanent."""

# simulate a very recent checkpoint, no migrations needed
loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123}
loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0}
new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123}
assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0}

# simulate an old checkpoint that needed an upgrade
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0}
with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"):
new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
assert new_checkpoint == {
"legacy_pytorch-lightning_version": "0.0.1",
"pytorch-lightning_version": pl.__version__,
"callbacks": {},
"content": 123,
"global_step": 2,
"epoch": 0,
"loops": ANY,
}
assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text


def test_migrate_checkpoint_legacy_version(monkeypatch):
"""Test that the legacy version gets set and does not change if migration is applied multiple times."""
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0}

# pretend the current pl version is 2.0
monkeypatch.setattr(pl, "__version__", "2.0.0")
34 changes: 0 additions & 34 deletions tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -18,43 +18,9 @@

import pytest

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.migration import migrate_checkpoint
from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main


@pytest.mark.parametrize(
"old_checkpoint, new_checkpoint",
[
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}},
),
(
{"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
{"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}},
),
],
)
def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
_set_version(old_checkpoint, "0.9.0")
_set_legacy_version(new_checkpoint, "0.9.0")
_set_version(new_checkpoint, pl.__version__)
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
assert updated_checkpoint == old_checkpoint == new_checkpoint
assert _get_version(updated_checkpoint) == pl.__version__


def test_upgrade_checkpoint_file_missing(tmp_path, caplog):
# path to single file (missing)
file = tmp_path / "checkpoint.ckpt"