diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 226a69c869311..777fa01b04847 100644
--- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -299,8 +299,7 @@ def on_save_checkpoint(self) -> Dict:
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
         self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
-        # restore global step instead to make sure logging works correctly if checkpoints <v1.6.5 are loaded
-        self._batches_that_stepped = state_dict.get("_batches_that_stepped", state_dict["global_step"])
+        self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
 
     def _run_validation(self) -> None:
         # reload dataloaders
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index a3aa02dd6e09f..22d6f7955e6df 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -343,24 +343,6 @@ def restore_loops(self) -> None:
             return
 
         fit_loop = self.trainer.fit_loop
-        pl_module = self.trainer.lightning_module
-        assert pl_module is not None
-
-        if self.trainer.state.fn == TrainerFn.FITTING:
-            # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
-            # it will be overwritten by the loop's state if it was also saved
-            batch_loop = fit_loop.epoch_loop.batch_loop
-            if pl_module.automatic_optimization:
-                batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
-                    "global_step"
-                ]
-            else:
-                batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
-
-            # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
-            # it will be overwritten by the loop's state if it was also saved
-            fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
-
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
        if state_dict is not None:
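The two hunks above move the pre-1.6 backfill logic out of the restore path: instead of `restore_loops` patching loop progress from the raw "global_step"/"epoch" keys, the checkpoint is normalized before the loops ever see it. A minimal sketch of the resulting load-time idea (not the Trainer's exact call chain), assuming the loaded dict carries its "pytorch-lightning_version" key and the file path is hypothetical:

    import torch

    from pytorch_lightning.utilities.migration import migrate_checkpoint

    checkpoint = torch.load("old-model.ckpt", map_location="cpu")  # hypothetical pre-1.6 checkpoint
    # migrate_checkpoint applies every migration newer than the checkpoint's version, in order
    checkpoint, _ = migrate_checkpoint(checkpoint)
    # restore_loops can now read the progress-tracking state unconditionally
    loops_state = checkpoint["loops"]["fit_loop"]
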
diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py
index 3431c01709b89..ba1165288b949 100644
--- a/src/pytorch_lightning/utilities/migration/migration.py
+++ b/src/pytorch_lightning/utilities/migration/migration.py
@@ -41,6 +41,8 @@
 def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
     """Migration functions returned here will get executed in the order they are listed."""
     return {
         "0.10.0": [_migrate_model_checkpoint_early_stopping],
+        "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking],
+        "1.6.5": [_migrate_loop_batches_that_stepped],
     }
 
@@ -67,3 +69,94 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
                 checkpoint["callbacks"][callback_type][callback_key] = value
         del checkpoint[key]
     return checkpoint
+
+
+def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be
+    overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit: c67b075
+    PR: #13645, #11805
+    """
+    global_step = checkpoint["global_step"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    # for automatic optimization
+    optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
+    optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
+    # for manual optimization
+    optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
+    optim_step_progress["total"]["completed"] = global_step
+    return checkpoint
+
+
+def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be
+    overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit: aea96e4
+    PR: #11805
+    """
+    epoch = checkpoint["epoch"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
+    return checkpoint
+
+
+def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Sets the `_batches_that_stepped` default value for checkpoints before v1.6.5 which don't have this key.
+
+    Version: 1.6.5
+    Commit: c67b075
+    PR: #13645
+    """
+    global_step = checkpoint["global_step"]
+    checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"].setdefault("_batches_that_stepped", global_step)
+    return checkpoint
+
+
+_FIT_LOOP_INITIAL_STATE_1_6_0 = {
+    "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.batch_loop.manual_loop.state_dict": {},
+    "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
+        "optimizer": {
+            "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+            "zero_grad": {
+                "current": {"completed": 0, "ready": 0, "started": 0},
+                "total": {"completed": 0, "ready": 0, "started": 0},
+            },
+        },
+        "optimizer_position": 0,
+    },
+    "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
+    "epoch_loop.batch_loop.state_dict": {},
+    "epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+    "epoch_loop.state_dict": {"_batches_that_stepped": 0},
+    "epoch_loop.val_loop.dataloader_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.state_dict": {},
+    "epoch_loop.val_loop.state_dict": {},
+    "epoch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "state_dict": {},
+}
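To make the migration concrete, here is a hand-run sketch of what the two 1.6.0 entries do to a minimal pre-1.6 checkpoint (a toy dict; real checkpoints carry many more keys). `migrate_checkpoint` performs the same calls when it walks the migration index:

    from pytorch_lightning.utilities.migration.migration import (
        _migrate_loop_current_epoch_to_progress_tracking,
        _migrate_loop_global_step_to_progress_tracking,
    )

    ckpt = {"global_step": 15, "epoch": 2}  # toy stand-in for a v1.5.x checkpoint
    # "loops" is absent here, so setdefault() seeds it with _FIT_LOOP_INITIAL_STATE_1_6_0
    ckpt = _migrate_loop_global_step_to_progress_tracking(ckpt)
    ckpt = _migrate_loop_current_epoch_to_progress_tracking(ckpt)

    fit_loop = ckpt["loops"]["fit_loop"]
    optim_progress = fit_loop["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
    assert optim_progress["optimizer"]["step"]["total"]["completed"] == 15
    assert fit_loop["epoch_progress"]["current"]["completed"] == 2
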
diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py
index 36dd508ff92d0..3748264d42e9e 100644
--- a/tests/tests_pytorch/models/test_restore.py
+++ b/tests/tests_pytorch/models/test_restore.py
@@ -29,7 +29,7 @@
 from lightning_lite import seed_everything
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
+from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -254,29 +254,6 @@ def on_train_start(self) -> None:
     assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches
 
 
-@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
-def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class):
-    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
-    model = model_class()
-    trainer.fit(model)
-    ckpt_path = trainer.checkpoint_callback.best_model_path
-    ckpt = torch.load(ckpt_path)
-    # the key "_batches_that_stepped" doesn't exist in checkpoints generated with <v1.6.5
-    del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
-    torch.save(ckpt, ckpt_path)
-
-    class TestModel(model_class):
-        def on_train_start(self) -> None:
-            assert self.trainer.global_step == 1
-            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1
-
-    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
-    model = TestModel()
-    trainer.fit(model, ckpt_path=ckpt_path)
-    new_loop = trainer.fit_loop.epoch_loop
-    assert new_loop.global_step == new_loop._batches_that_stepped == 2
-
-
 def test_fit_twice(tmpdir):
     epochs = []
diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py
new file mode 100644
index 0000000000000..d6a94c76720f3
--- /dev/null
+++ b/tests/tests_pytorch/utilities/migration/test_migration.py
@@ -0,0 +1,111 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import ANY
+
+import pytest
+import torch
+
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
+from pytorch_lightning.utilities.migration import migrate_checkpoint
+from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
+
+
+@pytest.mark.parametrize(
+    "old_checkpoint, new_checkpoint",
+    [
+        (
+            {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
+            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY},
+        ),
+        (
+            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
+            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY},
+        ),
+        (
+            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
+            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY},
+        ),
+        (
+            {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
+            {
+                "epoch": 1,
+                "global_step": 23,
+                "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
+                "loops": ANY,
+            },
+        ),
+    ],
+)
+def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint):
+    _set_version(old_checkpoint, "0.9.0")
+    _set_legacy_version(new_checkpoint, "0.9.0")
+    _set_version(new_checkpoint, pl.__version__)
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    assert updated_checkpoint == old_checkpoint == new_checkpoint
+    assert _get_version(updated_checkpoint) == pl.__version__
+
+
+def test_migrate_loop_global_step_to_progress_tracking():
+    old_checkpoint = {"global_step": 15, "epoch": 2}
+    _set_version(old_checkpoint, "1.5.9")  # pretend a checkpoint prior to 1.6.0
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    # automatic optimization
+    assert (
+        updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][
+            "step"
+        ]["total"]["completed"]
+        == 15
+    )
+    # for manual optimization
+    assert (
+        updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]["total"][
+            "completed"
+        ]
+        == 15
+    )
+
+
+def test_migrate_loop_current_epoch_to_progress_tracking():
+    old_checkpoint = {"global_step": 15, "epoch": 2}
+    _set_version(old_checkpoint, "1.5.9")  # pretend a checkpoint prior to 1.6.0
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    assert updated_checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] == 2
+
+
+@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
+def test_migrate_loop_batches_that_stepped(tmpdir, model_class):
+    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
+    model = model_class()
+    trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+
+    # pretend we have a checkpoint produced in < v1.6.5; the key "_batches_that_stepped" didn't exist back then
+    ckpt = torch.load(ckpt_path)
+    del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
+    _set_version(ckpt, "1.6.4")
+    torch.save(ckpt, ckpt_path)
+
+    class TestModel(model_class):
+        def on_train_start(self) -> None:
+            assert self.trainer.global_step == 1
+            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1
+
+    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
+    model = TestModel()
+    trainer.fit(model, ckpt_path=ckpt_path)
+    new_loop = trainer.fit_loop.epoch_loop
+    assert new_loop.global_step == new_loop._batches_that_stepped == 2
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index d662cf5e89833..80b337f71a9d0 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import logging
 import sys
+from unittest.mock import ANY
 
 import pytest
 
@@ -109,26 +110,28 @@ def test_migrate_checkpoint_for_pl(caplog):
     """Test that the automatic migration in Lightning informs the user about how to make the upgrade permanent."""
 
     # simulate a very recent checkpoint, no migrations needed
-    loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123}
+    loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0}
     new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
-    assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123}
+    assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0}
 
     # simulate an old checkpoint that needed an upgrade
-    loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
+    loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0}
     with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"):
         new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
     assert new_checkpoint == {
         "legacy_pytorch-lightning_version": "0.0.1",
         "pytorch-lightning_version": pl.__version__,
         "callbacks": {},
-        "content": 123,
+        "global_step": 2,
+        "epoch": 0,
+        "loops": ANY,
     }
     assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text
 
 
 def test_migrate_checkpoint_legacy_version(monkeypatch):
     """Test that the legacy version gets set and does not change if migration is applied multiple times."""
-    loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
+    loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0}
 
     # pretend the current pl version is 2.0
     monkeypatch.setattr(pl, "__version__", "2.0.0")
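A brief aside on the `unittest.mock.ANY` additions above: `ANY` compares equal to any object, so the assertions can wildcard the bulky migrated "loops" sub-dict while still pinning every other key exactly. A self-contained illustration:

    from unittest.mock import ANY

    migrated = {"epoch": 0, "loops": {"fit_loop": {"...": "large nested state"}}}
    # the "loops" value matches regardless of its contents
    assert migrated == {"epoch": 0, "loops": ANY}
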
{"best_model_score": 0.99}}}, - ), - ( - {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}}, - ), - ( - {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}}, - ), - ], -) -def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - _set_version(old_checkpoint, "0.9.0") - _set_legacy_version(new_checkpoint, "0.9.0") - _set_version(new_checkpoint, pl.__version__) - updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) - assert updated_checkpoint == old_checkpoint == new_checkpoint - assert _get_version(updated_checkpoint) == pl.__version__ - - def test_upgrade_checkpoint_file_missing(tmp_path, caplog): # path to single file (missing) file = tmp_path / "checkpoint.ckpt"