Checkpoint migration for loop's internal state #15500

Merged · 34 commits · Nov 8, 2022

Commits (34)
277b0b8
migration
awaelchli Oct 21, 2022
cc110a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2022
e13fcc6
import
awaelchli Oct 21, 2022
7cb1d24
Merge remote-tracking branch 'origin/feature/migration-utils' into fe…
awaelchli Oct 21, 2022
9838f00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2022
ed1ab3f
refactor
awaelchli Oct 24, 2022
756b2e7
protected
awaelchli Oct 24, 2022
d64b5ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2022
954b4d3
typo
awaelchli Oct 24, 2022
0febf26
Merge remote-tracking branch 'origin/feature/migration-utils' into fe…
awaelchli Oct 24, 2022
5223dff
Merge branch 'master' into feature/migration-utils
awaelchli Oct 24, 2022
0f988d1
tests
awaelchli Oct 24, 2022
d2f213a
Merge branch 'master' into feature/migration-functions
awaelchli Nov 3, 2022
b3069a9
update
awaelchli Nov 3, 2022
94cbba9
x
awaelchli Nov 3, 2022
dd51bb0
update
awaelchli Nov 3, 2022
a2d7972
tests
awaelchli Nov 3, 2022
698fc70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2022
cca1c2f
update test
awaelchli Nov 3, 2022
91f0c07
Merge remote-tracking branch 'origin/feature/migration-functions' int…
awaelchli Nov 3, 2022
85e588a
format
awaelchli Nov 3, 2022
4e6013c
notebook
awaelchli Nov 3, 2022
093676d
notebook
awaelchli Nov 3, 2022
d899bf1
reset
awaelchli Nov 3, 2022
d6e94ec
Merge branch 'master' into feature/migration-functions
awaelchli Nov 3, 2022
878d34f
Merge branch 'master' into feature/migration-functions
awaelchli Nov 3, 2022
aa6eca6
remove unused import
awaelchli Nov 3, 2022
1fe623b
add missing keys
awaelchli Nov 3, 2022
827b0a3
Merge branch 'master' into feature/migration-functions
awaelchli Nov 4, 2022
3792f35
notebook
awaelchli Nov 4, 2022
1244c8b
notebook
awaelchli Nov 4, 2022
2a2e35d
Merge branch 'master' into feature/migration-functions
awaelchli Nov 5, 2022
a16ada6
fix merge
awaelchli Nov 5, 2022
61aec28
Merge branch 'master' into feature/migration-functions
awaelchli Nov 8, 2022
3 changes: 1 addition & 2 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -299,8 +299,7 @@ def on_save_checkpoint(self) -> Dict:
def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
# restore global step instead to make sure logging works correctly if checkpoints <v1.6.5 used to resume
self._batches_that_stepped = state_dict.get("_batches_that_stepped", self.global_step)
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

def _run_validation(self) -> None:
# reload dataloaders
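The new default of `0` is safe because checkpoints written before v1.6.5 now get `_batches_that_stepped` backfilled from `global_step` by the migration added below, before the loop ever calls `state_dict.get`. A minimal sketch of that backfill on a plain dict (illustration only; the dict here is made up, and the real logic lives in `_migrate_loop_batches_that_stepped`):

```python
# A pre-1.6.5 checkpoint has a top-level "global_step" but no "_batches_that_stepped" key.
old_ckpt = {
    "global_step": 7,
    "loops": {"fit_loop": {"epoch_loop.state_dict": {}}},
}
# The migration copies global_step into the epoch loop's state dict if the key is missing.
old_ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"].setdefault("_batches_that_stepped", old_ckpt["global_step"])
assert old_ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"] == 7
```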
src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -343,24 +343,6 @@ def restore_loops(self) -> None:
return

fit_loop = self.trainer.fit_loop
pl_module = self.trainer.lightning_module
assert pl_module is not None

if self.trainer.state.fn == TrainerFn.FITTING:
# set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
batch_loop = fit_loop.epoch_loop.batch_loop
if pl_module.automatic_optimization:
batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
"global_step"
]
else:
batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]

# set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]

assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None:
93 changes: 93 additions & 0 deletions src/pytorch_lightning/utilities/migration/migration.py
@@ -41,6 +41,8 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
"""Migration functions returned here will get executed in the order they are listed."""
return {
"0.10.0": [_migrate_model_checkpoint_early_stopping],
"1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking],
"1.6.5": [_migrate_loop_batches_that_stepped],
}


@@ -67,3 +69,94 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
checkpoint["callbacks"][callback_type][callback_key] = value
del checkpoint[key]
return checkpoint


def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be
overwritten by the loop's state if it was also saved.

Version: 1.6.0
Commit: c67b075
PR: #13645, #11805
"""
global_step = checkpoint["global_step"]
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
# for automatic optimization
optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
# for manual optimization
optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
optim_step_progress["total"]["completed"] = global_step
return checkpoint


def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be
overwritten by the loop's state if it was also saved.

Version: 1.6.0
Commit: aea96e4
PR: #11805
"""
epoch = checkpoint["epoch"]
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
return checkpoint


def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Sets the `_batches_that_stepped` default value for checkpoints before v1.6.5 which don't have this key.

Version: 1.6.5
Commit: c67b075
PR: #13645
"""
global_step = checkpoint["global_step"]
checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"].setdefault("_batches_that_stepped", global_step)
return checkpoint


_FIT_LOOP_INITIAL_STATE_1_6_0 = {
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"current": {"completed": 0, "ready": 0},
"total": {"completed": 0, "ready": 0},
},
"epoch_loop.batch_loop.manual_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
"zero_grad": {
"current": {"completed": 0, "ready": 0, "started": 0},
"total": {"completed": 0, "ready": 0, "started": 0},
},
},
"optimizer_position": 0,
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"is_last_batch": False,
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
"epoch_loop.state_dict": {"_batches_that_stepped": 0},
"epoch_loop.val_loop.dataloader_progress": {
"current": {"completed": 0, "ready": 0},
"total": {"completed": 0, "ready": 0},
},
"epoch_loop.val_loop.epoch_loop.batch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"is_last_batch": False,
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"epoch_loop.val_loop.epoch_loop.state_dict": {},
"epoch_loop.val_loop.state_dict": {},
"epoch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"state_dict": {},
}
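For reference, a minimal usage sketch of these new migration functions, mirroring the tests added in this PR (`_set_version` is a private helper used here only to simulate an old checkpoint; the functions registered in `_migration_index` run in the listed version order):

```python
from pytorch_lightning.utilities.migration import migrate_checkpoint
from pytorch_lightning.utilities.migration.utils import _set_version

# Pretend this checkpoint was written before progress tracking existed (< v1.6.0).
old_checkpoint = {"global_step": 15, "epoch": 2}
_set_version(old_checkpoint, "1.5.9")

# Both 1.6.0 migrations and the 1.6.5 migration apply, creating the fit loop state
# and filling it from the flat `global_step` / `epoch` entries.
migrated, _ = migrate_checkpoint(old_checkpoint)
fit_loop_state = migrated["loops"]["fit_loop"]
assert fit_loop_state["epoch_progress"]["current"]["completed"] == 2
assert (
    fit_loop_state["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"]["step"]["total"]["completed"]
    == 15
)
```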
25 changes: 1 addition & 24 deletions tests/tests_pytorch/models/test_restore.py
@@ -29,7 +29,7 @@
from lightning_lite import seed_everything
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -254,29 +254,6 @@ def on_train_start(self) -> None:
assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
model = model_class()
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
ckpt = torch.load(ckpt_path)
# the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
torch.save(ckpt, ckpt_path)

class TestModel(model_class):
def on_train_start(self) -> None:
assert self.trainer.global_step == 1
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
model = TestModel()
trainer.fit(model, ckpt_path=ckpt_path)
new_loop = trainer.fit_loop.epoch_loop
assert new_loop.global_step == new_loop._batches_that_stepped == 2


def test_fit_twice(tmpdir):
epochs = []

111 changes: 111 additions & 0 deletions tests/tests_pytorch/utilities/migration/test_migration.py
@@ -0,0 +1,111 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import ANY

import pytest
import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.utilities.migration import migrate_checkpoint
from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version


@pytest.mark.parametrize(
"old_checkpoint, new_checkpoint",
[
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY},
),
(
{"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
{
"epoch": 1,
"global_step": 23,
"callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
"loops": ANY,
},
),
],
)
def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint):
_set_version(old_checkpoint, "0.9.0")
_set_legacy_version(new_checkpoint, "0.9.0")
_set_version(new_checkpoint, pl.__version__)
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
assert updated_checkpoint == old_checkpoint == new_checkpoint
assert _get_version(updated_checkpoint) == pl.__version__


def test_migrate_loop_global_step_to_progress_tracking():
old_checkpoint = {"global_step": 15, "epoch": 2}
_set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
# automatic optimization
assert (
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][
"step"
]["total"]["completed"]
== 15
)
# for manual optimization
assert (
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]["total"][
"completed"
]
== 15
)


def test_migrate_loop_current_epoch_to_progress_tracking():
old_checkpoint = {"global_step": 15, "epoch": 2}
_set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
assert updated_checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] == 2


@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_migrate_loop_batches_that_stepped(tmpdir, model_class):
trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
model = model_class()
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path

# pretend we have a checkpoint produced in < v1.6.5; the key "_batches_that_stepped" didn't exist back then
ckpt = torch.load(ckpt_path)
del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
_set_version(ckpt, "1.6.4")
torch.save(ckpt, ckpt_path)

class TestModel(model_class):
def on_train_start(self) -> None:
assert self.trainer.global_step == 1
assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
model = TestModel()
trainer.fit(model, ckpt_path=ckpt_path)
new_loop = trainer.fit_loop.epoch_loop
assert new_loop.global_step == new_loop._batches_that_stepped == 2
13 changes: 8 additions & 5 deletions tests/tests_pytorch/utilities/migration/test_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
import sys
from unittest.mock import ANY

import pytest

@@ -109,26 +110,28 @@ def test_migrate_checkpoint_for_pl(caplog):
"""Test that the automatic migration in Lightning informs the user about how to make the upgrade permanent."""

# simulate a very recent checkpoint, no migrations needed
loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123}
loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0}
new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123}
assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0}

# simulate an old checkpoint that needed an upgrade
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0}
with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"):
new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
assert new_checkpoint == {
"legacy_pytorch-lightning_version": "0.0.1",
"pytorch-lightning_version": pl.__version__,
"callbacks": {},
"content": 123,
"global_step": 2,
"epoch": 0,
"loops": ANY,
}
assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text


def test_migrate_checkpoint_legacy_version(monkeypatch):
"""Test that the legacy version gets set and does not change if migration is applied multiple times."""
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0}

# pretend the current pl version is 2.0
monkeypatch.setattr(pl, "__version__", "2.0.0")
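The info message asserted above points users at the `pytorch_lightning.utilities.upgrade_checkpoint` script (exercised in the next file) to make the upgrade permanent. A rough Python equivalent using the public `migrate_checkpoint` API, sketched here with a hypothetical file path:

```python
import torch

from pytorch_lightning.utilities.migration import migrate_checkpoint

path = "old_model.ckpt"  # hypothetical: any checkpoint saved with an older Lightning version
checkpoint = torch.load(path, map_location="cpu")
migrated, _ = migrate_checkpoint(checkpoint)
# Re-saving persists the upgraded format, so the migration is not re-applied on every resume.
torch.save(migrated, path)
```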
34 changes: 0 additions & 34 deletions tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -18,43 +18,9 @@

import pytest

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.migration import migrate_checkpoint
from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main


@pytest.mark.parametrize(
"old_checkpoint, new_checkpoint",
[
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}},
),
(
{"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
{"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}},
),
],
)
def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
_set_version(old_checkpoint, "0.9.0")
_set_legacy_version(new_checkpoint, "0.9.0")
_set_version(new_checkpoint, pl.__version__)
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
assert updated_checkpoint == old_checkpoint == new_checkpoint
assert _get_version(updated_checkpoint) == pl.__version__


def test_upgrade_checkpoint_file_missing(tmp_path, caplog):
# path to single file (missing)
file = tmp_path / "checkpoint.ckpt"