From 277b0b811fb1419d6c06e7953941d6f6076eaf6d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 21 Oct 2022 13:44:35 +0200 Subject: [PATCH 01/36] migration --- src/pytorch_lightning/core/saving.py | 7 +- .../connectors/checkpoint_connector.py | 28 +-- src/pytorch_lightning/utilities/migration.py | 159 +++++++++++++++++- .../utilities/upgrade_checkpoint.py | 33 +--- .../utilities/test_upgrade_checkpoint.py | 14 +- 5 files changed, 169 insertions(+), 72 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 46f1663ed705c..247d864438d87 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -31,7 +31,7 @@ from lightning_lite.utilities.cloud_io import load as pl_load from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE -from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -156,6 +156,9 @@ def _load_from_checkpoint( with pl_legacy_patch(): checkpoint = pl_load(checkpoint_path, map_location=map_location) + # convert legacy checkpoints to the new format + checkpoint = migrate_checkpoint(checkpoint) + if hparams_file is not None: extension = str(hparams_file).split(".")[-1] if extension.lower() == "csv": @@ -168,6 +171,7 @@ def _load_from_checkpoint( # overwrite hparams by the given file checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + # TODO: make this a migration: # for past checkpoint need to add the new key checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}) # override the hparams with values that were passed in @@ -197,6 +201,7 @@ def _load_state( if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: if issubclass(cls, pl.LightningModule): 
+ # TODO: make this a migration: # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 50480e769b4ce..01d5a6a7e14ed 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -32,9 +32,8 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: from omegaconf import Container @@ -86,13 +85,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: with pl_legacy_patch(): loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) - if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): - raise ValueError( - "The checkpoint you're attempting to load follows an" - " outdated schema. You can upgrade to the current schema by running" - " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" - " where `model.ckpt` is your checkpoint file." 
- ) + loaded_checkpoint = migrate_checkpoint(loaded_checkpoint) return loaded_checkpoint def _set_ckpt_path( @@ -348,23 +341,6 @@ def restore_loops(self) -> None: return fit_loop = self.trainer.fit_loop - pl_module = self.trainer.lightning_module - assert pl_module is not None - - # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. - # it will be overwritten by the loop's state if it was also saved - batch_loop = fit_loop.epoch_loop.batch_loop - if pl_module.automatic_optimization: - batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ - "global_step" - ] - else: - batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"] - - # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. - # it will be overwritten by the loop's state if it was also saved - fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] - assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration.py index ed71f25a571f7..7f400fc56efa4 100644 --- a/src/pytorch_lightning/utilities/migration.py +++ b/src/pytorch_lightning/utilities/migration.py @@ -11,16 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version. + +When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially, +see :func:`migrate_checkpoint`. 
+""" import sys -import threading +from distutils.version import LooseVersion from types import ModuleType, TracebackType +from typing import Any, Dict, Optional, Type +import pytorch_lightning as pl import pytorch_lightning.utilities.argparse -# Create a global lock to ensure no race condition with deleting sys modules -_lock = threading.Lock() +_CHECKPOINT = Dict[str, Any] class pl_legacy_patch: @@ -28,7 +33,7 @@ class pl_legacy_patch: unpickling old checkpoints. The following patches apply. 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to - version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898 + version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4, but still needs to be available for import for legacy checkpoints. @@ -38,8 +43,7 @@ class pl_legacy_patch: torch.load("path/to/legacy/checkpoint.ckpt") """ - def __enter__(self) -> None: - _lock.acquire() + def __enter__(self) -> "pl_legacy_patch": # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module @@ -47,11 +51,148 @@ def __enter__(self) -> None: # `_gpus_arg_default` used to be imported from these locations legacy_argparse_module._gpus_arg_default = lambda x: x pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x + return self def __exit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], ) -> None: if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, 
"_gpus_arg_default") del sys.modules["pytorch_lightning.utilities.argparse_utils"] - _lock.release() + + +def get_version(checkpoint: _CHECKPOINT) -> str: + """Get the version of a Lightning checkpoint.""" + return checkpoint["pytorch-lightning_version"] + + +def set_version(checkpoint: _CHECKPOINT, version: str) -> None: + """Set the version of a Lightning checkpoint.""" + checkpoint["pytorch-lightning_version"] = version + + +def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: + """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" + return LooseVersion(get_version(checkpoint)) < LooseVersion(target) + + +def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Applies all migrations below in order.""" + if should_upgrade(checkpoint, "0.10.0"): + _migrate_model_checkpoint_early_stopping(checkpoint) + if should_upgrade(checkpoint, "1.6.0"): + _migrate_loop_global_step_to_progress_tracking(checkpoint) + _migrate_loop_current_epoch_to_progress_tracking(checkpoint) + + set_version(checkpoint, pl.__version__) + + # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert + # checkpoints permanently + return checkpoint + + +def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """The checkpoint and early stopping keys were renamed. 
+ + Version: 0.10.0 + Commit: + """ + from pytorch_lightning.callbacks.early_stopping import EarlyStopping + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + + keys_mapping = { + "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), + "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), + "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), + "early_stop_callback_wait": (EarlyStopping, "wait_count"), + "early_stop_callback_patience": (EarlyStopping, "patience"), + } + checkpoint["callbacks"] = checkpoint.get("callbacks") or {} + + for key, new_path in keys_mapping.items(): + if key in checkpoint: + value = checkpoint[key] + callback_type, callback_key = new_path + checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} + checkpoint["callbacks"][callback_type][callback_key] = value + del checkpoint[key] + return checkpoint + + +def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. + It will be overwritten by the loop's state if it was also saved. 
+ + Version: 1.6.0 + Commit: + """ + global_step = checkpoint["global_step"] + checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) + checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + # for automatic optimization + optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"] + optim_progress["optimizer"]["step"]["total"]["completed"] = global_step + # for manual optimization + optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"] + optim_step_progress["total"]["completed"] = global_step + return checkpoint + + +def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. + It will be overwritten by the loop's state if it was also saved. + + Version: 1.6.0 + Commit: + """ + epoch = checkpoint["epoch"] + checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) + checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch + + +_FIT_LOOP_INITIAL_STATE_1_6_0 = { + "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.batch_loop.manual_loop.state_dict": {}, + "epoch_loop.batch_loop.optimizer_loop.optim_progress": { + "optimizer": { + "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, + "zero_grad": { + "current": {"completed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "ready": 0, "started": 0}, + }, + }, + "optimizer_position": 0, + }, + "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 
0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, + "epoch_loop.state_dict": {"_batches_that_stepped": 0}, + "epoch_loop.val_loop.dataloader_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.val_loop.epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.state_dict": {}, + "epoch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "state_dict": {}, +} diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 6f4dd5ca938dd..46038701b1e52 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,36 +17,11 @@ import torch -from lightning_lite.utilities.types import _PATH -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.migration import pl_legacy_patch - -KEYS_MAPPING = { - "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), - "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), - "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), - "early_stop_callback_wait": (EarlyStopping, "wait_count"), - "early_stop_callback_patience": (EarlyStopping, "patience"), -} +from pytorch_lightning.utilities.migration.base import pl_legacy_patch +from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint log = 
logging.getLogger(__name__) - -def upgrade_checkpoint(filepath: _PATH) -> None: - checkpoint = torch.load(filepath) - checkpoint["callbacks"] = checkpoint.get("callbacks") or {} - - for key, new_path in KEYS_MAPPING.items(): - if key in checkpoint: - value = checkpoint[key] - callback_type, callback_key = new_path - checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} - checkpoint["callbacks"][callback_type][callback_key] = value - del checkpoint[key] - - torch.save(checkpoint, filepath) - - if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -61,4 +36,6 @@ def upgrade_checkpoint(filepath: _PATH) -> None: log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") copyfile(args.file, args.file + ".bak") with pl_legacy_patch(): - upgrade_checkpoint(args.file) + checkpoint = torch.load(args.file) + migrate_checkpoint(checkpoint) + torch.save(checkpoint, args.file) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index a58bdb5721bc7..c01fcf7eb249d 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os - import pytest -import torch +import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint +from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, set_version @pytest.mark.parametrize( @@ -42,8 +40,8 @@ ], ) def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - filepath = os.path.join(tmpdir, "model.ckpt") - torch.save(old_checkpoint, filepath) - upgrade_checkpoint(filepath) - updated_checkpoint = torch.load(filepath) + set_version(old_checkpoint, "0.9.0") + set_version(new_checkpoint, pl.__version__) + updated_checkpoint = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == new_checkpoint + assert get_version(updated_checkpoint) == pl.__version__ From cc110a3d1c7f21728ff74938317101667544d948 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Oct 2022 11:49:28 +0000 Subject: [PATCH 02/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/migration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration.py index 7f400fc56efa4..8df43b83dad48 100644 --- a/src/pytorch_lightning/utilities/migration.py +++ b/src/pytorch_lightning/utilities/migration.py @@ -123,8 +123,8 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. - It will be overwritten by the loop's state if it was also saved. + """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. 
It will be + overwritten by the loop's state if it was also saved. Version: 1.6.0 Commit: @@ -142,8 +142,8 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. - It will be overwritten by the loop's state if it was also saved. + """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be + overwritten by the loop's state if it was also saved. Version: 1.6.0 Commit: From e13fcc61ddc9dfa5b01c7e3d3d70066be5cd65d5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 21 Oct 2022 14:13:58 +0200 Subject: [PATCH 03/36] import --- src/pytorch_lightning/utilities/upgrade_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 46038701b1e52..03705c5287e8d 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,8 +17,8 @@ import torch -from pytorch_lightning.utilities.migration.base import pl_legacy_patch -from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint +from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration import migrate_checkpoint log = logging.getLogger(__name__) From 9838f008c61ad7a50d9ba7e7344ec16bd67111b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Oct 2022 12:16:07 +0000 Subject: [PATCH 04/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/upgrade_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 03705c5287e8d..4bcfb4a86f5bd 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,8 +17,7 @@ import torch -from pytorch_lightning.utilities.migration import pl_legacy_patch -from pytorch_lightning.utilities.migration import migrate_checkpoint +from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch log = logging.getLogger(__name__) From ed1ab3fce307e3b7f8e9dbdf5e79157ec821c98d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 12:58:36 +0200 Subject: [PATCH 05/36] refactor --- .../utilities/migration/__init__.py | 16 +++ .../{migration.py => migration/migrations.py} | 98 +++++-------------- .../utilities/migration/utils.py | 91 +++++++++++++++++ 3 files changed, 131 insertions(+), 74 deletions(-) create mode 100644 src/pytorch_lightning/utilities/migration/__init__.py rename src/pytorch_lightning/utilities/{migration.py => migration/migrations.py} (59%) create mode 100644 src/pytorch_lightning/utilities/migration/utils.py diff --git a/src/pytorch_lightning/utilities/migration/__init__.py b/src/pytorch_lightning/utilities/migration/__init__.py new file mode 100644 index 0000000000000..8e0f79f6904cb --- /dev/null +++ b/src/pytorch_lightning/utilities/migration/__init__.py @@ -0,0 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401 +from pytorch_lightning.utilities.migration.utils import migrate_checkpoint # noqa: F401 diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration/migrations.py similarity index 59% rename from src/pytorch_lightning/utilities/migration.py rename to src/pytorch_lightning/utilities/migration/migrations.py index 8df43b83dad48..a148f79ca5d7c 100644 --- a/src/pytorch_lightning/utilities/migration.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -14,84 +14,36 @@ """Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version. When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially, -see :func:`migrate_checkpoint`. -""" - -import sys -from distutils.version import LooseVersion -from types import ModuleType, TracebackType -from typing import Any, Dict, Optional, Type - -import pytorch_lightning as pl -import pytorch_lightning.utilities.argparse - -_CHECKPOINT = Dict[str, Any] - - -class pl_legacy_patch: - """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for - unpickling old checkpoints. The following patches apply. - - 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to - version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 - 2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4, - but still needs to be available for import for legacy checkpoints. - - Example: - - with pl_legacy_patch(): - torch.load("path/to/legacy/checkpoint.ckpt") - """ +see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`. 
- def __enter__(self) -> "pl_legacy_patch": - # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` - legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") - sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module +How to add a new migration? - # `_gpus_arg_default` used to be imported from these locations - legacy_argparse_module._gpus_arg_default = lambda x: x - pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x - return self +1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include + version information as well as the specific commit or PR where the breaking change happened. +2. Add the function to the `migration_index()` below. The key in the index is the version of Lightning in which the + change happened. Any checkpoint with a version lower than that version will apply the given function. + Multiple migrations per version get executed in the provided list order. +3. 
You can test the migration on a checkpoint (backup your files first) by running: - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], - ) -> None: - if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): - delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") - del sys.modules["pytorch_lightning.utilities.argparse_utils"] - - -def get_version(checkpoint: _CHECKPOINT) -> str: - """Get the version of a Lightning checkpoint.""" - return checkpoint["pytorch-lightning_version"] - - -def set_version(checkpoint: _CHECKPOINT, version: str) -> None: - """Set the version of a Lightning checkpoint.""" - checkpoint["pytorch-lightning_version"] = version + cp model.ckpt model.ckpt.backup + python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt +""" -def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: - """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" - return LooseVersion(get_version(checkpoint)) < LooseVersion(target) +from typing import Any, Dict, Callable, List +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Applies all migrations below in order.""" - if should_upgrade(checkpoint, "0.10.0"): - _migrate_model_checkpoint_early_stopping(checkpoint) - if should_upgrade(checkpoint, "1.6.0"): - _migrate_loop_global_step_to_progress_tracking(checkpoint) - _migrate_loop_current_epoch_to_progress_tracking(checkpoint) +_CHECKPOINT = Dict[str, Any] - set_version(checkpoint, pl.__version__) - # TODO: If any migrations apply, log a message. 
Suggest to run upgrade_checkpoint script to convert - # checkpoints permanently - return checkpoint +def migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: + """Migration functions returned here will get executed in the order they are listed.""" + return { + "0.10.0": [_migrate_model_checkpoint_early_stopping], + "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking] + } def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT: @@ -100,9 +52,6 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP Version: 0.10.0 Commit: """ - from pytorch_lightning.callbacks.early_stopping import EarlyStopping - from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint - keys_mapping = { "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), @@ -123,7 +72,7 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be + """Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be overwritten by the loop's state if it was also saved. Version: 1.6.0 @@ -142,7 +91,7 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be + """Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be overwritten by the loop's state if it was also saved. 
Version: 1.6.0 @@ -152,6 +101,7 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch + return checkpoint _FIT_LOOP_INITIAL_STATE_1_6_0 = { diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py new file mode 100644 index 0000000000000..a7e8443c78467 --- /dev/null +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -0,0 +1,91 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import pytorch_lightning as pl +from distutils.version import LooseVersion +from types import ModuleType, TracebackType +from typing import Optional, Type, Dict, Any + +from pytorch_lightning.utilities.migration.migrations import _migrate_model_checkpoint_early_stopping, \ + _migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking, migration_index + +_CHECKPOINT = Dict[str, Any] + + +class pl_legacy_patch: + """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for + unpickling old checkpoints. The following patches apply. + + 1. 
``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to + version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 + 2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4, + but still needs to be available for import for legacy checkpoints. + + Example: + + with pl_legacy_patch(): + torch.load("path/to/legacy/checkpoint.ckpt") + """ + + def __enter__(self) -> "pl_legacy_patch": + # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` + legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") + sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module + + # `_gpus_arg_default` used to be imported from these locations + legacy_argparse_module._gpus_arg_default = lambda x: x + pl.utilities.argparse._gpus_arg_default = lambda x: x + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ) -> None: + if hasattr(pl.utilities.argparse, "_gpus_arg_default"): + delattr(pl.utilities.argparse, "_gpus_arg_default") + del sys.modules["pytorch_lightning.utilities.argparse_utils"] + + +def get_version(checkpoint: _CHECKPOINT) -> str: + """Get the version of a Lightning checkpoint.""" + return checkpoint["pytorch-lightning_version"] + + +def set_version(checkpoint: _CHECKPOINT, version: str) -> None: + """Set the version of a Lightning checkpoint.""" + checkpoint["pytorch-lightning_version"] = version + + +def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: + """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" + return LooseVersion(get_version(checkpoint)) < LooseVersion(target) + + +def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Applies all migrations below in order.""" + index = 
migration_index() + for migration_version, migration_functions in index.items(): + if not should_upgrade(checkpoint, migration_version): + continue + for migration_function in migration_functions: + checkpoint = migration_function(checkpoint) + + set_version(checkpoint, pl.__version__) + + # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert + # checkpoints permanently + return checkpoint From 756b2e7ee153c39957c70bf5eb4d986eabb064d2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:05:21 +0200 Subject: [PATCH 06/36] protected --- .../utilities/migration/utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index a7e8443c78467..15769a3efa267 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -18,8 +18,7 @@ from types import ModuleType, TracebackType from typing import Optional, Type, Dict, Any -from pytorch_lightning.utilities.migration.migrations import _migrate_model_checkpoint_early_stopping, \ - _migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking, migration_index +from pytorch_lightning.utilities.migration.migrations import migration_index _CHECKPOINT = Dict[str, Any] @@ -60,31 +59,31 @@ def __exit__( del sys.modules["pytorch_lightning.utilities.argparse_utils"] -def get_version(checkpoint: _CHECKPOINT) -> str: +def _get_version(checkpoint: _CHECKPOINT) -> str: """Get the version of a Lightning checkpoint.""" return checkpoint["pytorch-lightning_version"] -def set_version(checkpoint: _CHECKPOINT, version: str) -> None: +def _set_version(checkpoint: _CHECKPOINT, version: str) -> None: """Set the version of a Lightning checkpoint.""" checkpoint["pytorch-lightning_version"] = version -def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> 
bool: +def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" - return LooseVersion(get_version(checkpoint)) < LooseVersion(target) + return LooseVersion(_get_version(checkpoint)) < LooseVersion(target) def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Applies all migrations below in order.""" + """Applies Lightning version migrations to a checkpoint.""" index = migration_index() for migration_version, migration_functions in index.items(): - if not should_upgrade(checkpoint, migration_version): + if not _should_upgrade(checkpoint, migration_version): continue for migration_function in migration_functions: checkpoint = migration_function(checkpoint) - set_version(checkpoint, pl.__version__) + _set_version(checkpoint, pl.__version__) # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert # checkpoints permanently From d64b5edd467049bb5e6bd29b8fbd6e204e8adba9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Oct 2022 11:07:07 +0000 Subject: [PATCH 07/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/migration/__init__.py | 2 +- src/pytorch_lightning/utilities/migration/migrations.py | 5 ++--- src/pytorch_lightning/utilities/migration/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/__init__.py b/src/pytorch_lightning/utilities/migration/__init__.py index 8e0f79f6904cb..199541c19034f 100644 --- a/src/pytorch_lightning/utilities/migration/__init__.py +++ b/src/pytorch_lightning/utilities/migration/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401 from pytorch_lightning.utilities.migration.utils import migrate_checkpoint # noqa: F401 +from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401 diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py index a148f79ca5d7c..e9ff88cc36f01 100644 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -27,10 +27,9 @@ cp model.ckpt model.ckpt.backup python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt - """ -from typing import Any, Dict, Callable, List +from typing import Any, Callable, Dict, List from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -42,7 +41,7 @@ def migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], - "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking] + "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], } diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 15769a3efa267..fc580471d463a 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -13,11 +13,11 @@ # limitations under the License. 
import sys -import pytorch_lightning as pl from distutils.version import LooseVersion from types import ModuleType, TracebackType -from typing import Optional, Type, Dict, Any +from typing import Any, Dict, Optional, Type +import pytorch_lightning as pl from pytorch_lightning.utilities.migration.migrations import migration_index _CHECKPOINT = Dict[str, Any] From 954b4d3c9cc1308ddff8b8a51713f47d04800e13 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:08:22 +0200 Subject: [PATCH 08/36] typo --- src/pytorch_lightning/utilities/migration/migrations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py index a148f79ca5d7c..a3dc604a8294d 100644 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -19,7 +19,7 @@ How to add a new migration? 1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include - version informatin as well as the specific commit or PR where the breaking change happened. + version information as well as the specific commit or PR where the breaking change happened. 2. Add the function to the `migration_index()` below. The key in the index is the version of Lightning in which the change happened. Any checkpoint with a version greater or equal to that version will apply the given function. Multiple migrations per version get executed in the provided list order. 
From 0f988d1746be1672e3c94f77e6329cd830a3b40c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:50:04 +0200 Subject: [PATCH 09/36] tests --- .../utilities/migration/utils.py | 32 ++++---- .../utilities/migration/__init__.py | 0 .../utilities/migration/test_utils.py | 75 +++++++++++++++++++ .../utilities/test_upgrade_checkpoint.py | 19 +++-- 4 files changed, 102 insertions(+), 24 deletions(-) create mode 100644 tests/tests_pytorch/utilities/migration/__init__.py create mode 100644 tests/tests_pytorch/utilities/migration/test_utils.py diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index fc580471d463a..a52aca9944e2c 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -23,6 +23,22 @@ _CHECKPOINT = Dict[str, Any] +def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Applies Lightning version migrations to a checkpoint dictionary.""" + index = migration_index() + for migration_version, migration_functions in index.items(): + if not _should_upgrade(checkpoint, migration_version): + continue + for migration_function in migration_functions: + checkpoint = migration_function(checkpoint) + + _set_version(checkpoint, pl.__version__) + + # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert + # checkpoints permanently + return checkpoint + + class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for unpickling old checkpoints. The following patches apply. 
@@ -72,19 +88,3 @@ def _set_version(checkpoint: _CHECKPOINT, version: str) -> None: def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" return LooseVersion(_get_version(checkpoint)) < LooseVersion(target) - - -def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Applies Lightning version migrations to a checkpoint.""" - index = migration_index() - for migration_version, migration_functions in index.items(): - if not _should_upgrade(checkpoint, migration_version): - continue - for migration_function in migration_functions: - checkpoint = migration_function(checkpoint) - - _set_version(checkpoint, pl.__version__) - - # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert - # checkpoints permanently - return checkpoint diff --git a/tests/tests_pytorch/utilities/migration/__init__.py b/tests/tests_pytorch/utilities/migration/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py new file mode 100644 index 0000000000000..227e6e0b590cd --- /dev/null +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -0,0 +1,75 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytorch_lightning as pl +from pytorch_lightning.utilities.migration import migrate_checkpoint + + +def test_migrate_checkpoint(monkeypatch): + """Test that the correct migration function gets executed given the current version of the checkpoint.""" + # A checkpoint that is older than any migration point in the index + old_checkpoint = { + "pytorch-lightning_version": "0.0.0", + "content": 123 + } + new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) + assert call_order == ["one", "two", "three", "four"] + assert new_checkpoint == { + "pytorch-lightning_version": pl.__version__, + "content": 123 + } + + # A checkpoint that is newer, but not the newest + old_checkpoint = { + "pytorch-lightning_version": "1.0.3", + "content": 123 + } + new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) + assert call_order == ["four"] + assert new_checkpoint == { + "pytorch-lightning_version": pl.__version__, + "content": 123 + } + + # A checkpoint newer than any migration point in the index + old_checkpoint = { + "pytorch-lightning_version": "2.0", + "content": 123 + } + new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) + assert call_order == [] + assert new_checkpoint == { + "pytorch-lightning_version": pl.__version__, + "content": 123 + } + + +def _run_simple_migration(monkeypatch, old_checkpoint): + call_order = [] + + def dummy_upgrade(tag): + def upgrade(ckpt): + call_order.append(tag) + return ckpt + + return upgrade + + index = { + "0.0.1": [dummy_upgrade("one")], + "0.0.2": [dummy_upgrade("two"), dummy_upgrade("three")], + "1.2.3": [dummy_upgrade("four")], + } + monkeypatch.setattr(pl.utilities.migration.utils, "migration_index", lambda: index) + new_checkpoint = migrate_checkpoint(old_checkpoint) + return new_checkpoint, call_order diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 
c01fcf7eb249d..2429067ed0557 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import ANY + import pytest import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, set_version +from pytorch_lightning.utilities.migration import migrate_checkpoint +from pytorch_lightning.utilities.migration.utils import _set_version, _get_version @pytest.mark.parametrize( @@ -23,25 +26,25 @@ [ ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY}, ), ( {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}}, + {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}, "loops": ANY}, ), ], ) def 
test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - set_version(old_checkpoint, "0.9.0") - set_version(new_checkpoint, pl.__version__) + _set_version(old_checkpoint, "0.9.0") + _set_version(new_checkpoint, pl.__version__) updated_checkpoint = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == new_checkpoint - assert get_version(updated_checkpoint) == pl.__version__ + assert _get_version(updated_checkpoint) == pl.__version__ From 3a6aa87ab48f246866cf6ece5502acfa50ec4a6e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Oct 2022 11:52:57 +0000 Subject: [PATCH 10/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../utilities/migration/test_utils.py | 30 ++++--------------- .../utilities/test_upgrade_checkpoint.py | 9 ++++-- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 227e6e0b590cd..2c28ed5755b2a 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -19,40 +19,22 @@ def test_migrate_checkpoint(monkeypatch): """Test that the correct migration function gets executed given the current version of the checkpoint.""" # A checkpoint that is older than any migration point in the index - old_checkpoint = { - "pytorch-lightning_version": "0.0.0", - "content": 123 - } + old_checkpoint = {"pytorch-lightning_version": "0.0.0", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == ["one", "two", "three", "four"] - assert new_checkpoint == { - "pytorch-lightning_version": pl.__version__, - "content": 123 - } + assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} # A checkpoint that is newer, but not the newest - 
old_checkpoint = { - "pytorch-lightning_version": "1.0.3", - "content": 123 - } + old_checkpoint = {"pytorch-lightning_version": "1.0.3", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == ["four"] - assert new_checkpoint == { - "pytorch-lightning_version": pl.__version__, - "content": 123 - } + assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} # A checkpoint newer than any migration point in the index - old_checkpoint = { - "pytorch-lightning_version": "2.0", - "content": 123 - } + old_checkpoint = {"pytorch-lightning_version": "2.0", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == [] - assert new_checkpoint == { - "pytorch-lightning_version": pl.__version__, - "content": 123 - } + assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} def _run_simple_migration(monkeypatch, old_checkpoint): diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 2429067ed0557..5beaa9b606721 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -18,7 +18,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities.migration import migrate_checkpoint -from pytorch_lightning.utilities.migration.utils import _set_version, _get_version +from pytorch_lightning.utilities.migration.utils import _get_version, _set_version @pytest.mark.parametrize( @@ -38,7 +38,12 @@ ), ( {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}, "loops": ANY}, + { + "epoch": 1, + "global_step": 23, + "callbacks": {EarlyStopping: 
{"wait_count": 2, "patience": 4}}, + "loops": ANY, + }, ), ], ) From ac0d2fc46a92c28a00478698762f8298f28c1424 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:56:22 +0200 Subject: [PATCH 11/36] prune --- .../utilities/migration/migrations.py | 78 ------------------- .../utilities/migration/utils.py | 3 - .../utilities/test_upgrade_checkpoint.py | 10 +-- 3 files changed, 4 insertions(+), 87 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py index 199337ec540be..6769f8ce0006b 100644 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -41,7 +41,6 @@ def migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], - "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], } @@ -68,80 +67,3 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP checkpoint["callbacks"][callback_type][callback_key] = value del checkpoint[key] return checkpoint - - -def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be - overwritten by the loop's state if it was also saved. 
- - Version: 1.6.0 - Commit: - """ - global_step = checkpoint["global_step"] - checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) - checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) - # for automatic optimization - optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"] - optim_progress["optimizer"]["step"]["total"]["completed"] = global_step - # for manual optimization - optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"] - optim_step_progress["total"]["completed"] = global_step - return checkpoint - - -def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be - overwritten by the loop's state if it was also saved. - - Version: 1.6.0 - Commit: - """ - epoch = checkpoint["epoch"] - checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) - checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) - checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch - return checkpoint - - -_FIT_LOOP_INITIAL_STATE_1_6_0 = { - "epoch_loop.batch_loop.manual_loop.optim_step_progress": { - "current": {"completed": 0, "ready": 0}, - "total": {"completed": 0, "ready": 0}, - }, - "epoch_loop.batch_loop.manual_loop.state_dict": {}, - "epoch_loop.batch_loop.optimizer_loop.optim_progress": { - "optimizer": { - "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, - "zero_grad": { - "current": {"completed": 0, "ready": 0, "started": 0}, - "total": {"completed": 0, "ready": 0, "started": 0}, - }, - }, - "optimizer_position": 0, - }, - "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_progress": { - "current": {"completed": 0, 
"processed": 0, "ready": 0, "started": 0}, - "is_last_batch": False, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, - "epoch_loop.state_dict": {"_batches_that_stepped": 0}, - "epoch_loop.val_loop.dataloader_progress": { - "current": {"completed": 0, "ready": 0}, - "total": {"completed": 0, "ready": 0}, - }, - "epoch_loop.val_loop.epoch_loop.batch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "is_last_batch": False, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "epoch_loop.val_loop.epoch_loop.state_dict": {}, - "epoch_loop.val_loop.state_dict": {}, - "epoch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "state_dict": {}, -} diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index a52aca9944e2c..adda962af08dd 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -33,9 +33,6 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: checkpoint = migration_function(checkpoint) _set_version(checkpoint, pl.__version__) - - # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert - # checkpoints permanently return checkpoint diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 2429067ed0557..804ca2a04ee70 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import ANY - import pytest import pytorch_lightning as pl @@ -26,19 +24,19 @@ [ ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}}, ), ( {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}, "loops": ANY}, + {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}}, ), ], ) From f7f12502728b96b3b295345d03b6e165ce1a756b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:59:40 +0200 Subject: [PATCH 12/36] reset --- .../connectors/checkpoint_connector.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 01d5a6a7e14ed..50480e769b4ce 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -32,8 +32,9 @@ 
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch +from pytorch_lightning.utilities.migration import pl_legacy_patch from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: from omegaconf import Container @@ -85,7 +86,13 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: with pl_legacy_patch(): loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) - loaded_checkpoint = migrate_checkpoint(loaded_checkpoint) + if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): + raise ValueError( + "The checkpoint you're attempting to load follows an" + " outdated schema. You can upgrade to the current schema by running" + " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" + " where `model.ckpt` is your checkpoint file." + ) return loaded_checkpoint def _set_ckpt_path( @@ -341,6 +348,23 @@ def restore_loops(self) -> None: return fit_loop = self.trainer.fit_loop + pl_module = self.trainer.lightning_module + assert pl_module is not None + + # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. 
+ # it will be overwritten by the loop's state if it was also saved + batch_loop = fit_loop.epoch_loop.batch_loop + if pl_module.automatic_optimization: + batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ + "global_step" + ] + else: + batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"] + + # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. + # it will be overwritten by the loop's state if it was also saved + fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] + assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: From 24a5d60931804a4acf2f012067f0cf9a54e83600 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 14:22:54 +0200 Subject: [PATCH 13/36] wip --- src/pytorch_lightning/core/saving.py | 2 +- .../connectors/checkpoint_connector.py | 14 ++-------- .../utilities/migration/utils.py | 27 ++++++++++++++++--- .../utilities/migration/test_utils.py | 2 +- .../utilities/test_upgrade_checkpoint.py | 2 +- 5 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 42bf676481abb..015368b2c9de8 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -157,7 +157,7 @@ def _load_from_checkpoint( checkpoint = pl_load(checkpoint_path, map_location=map_location) # convert legacy checkpoints to the new format - checkpoint = migrate_checkpoint(checkpoint) + checkpoint, _ = migrate_checkpoint(checkpoint) if hparams_file is not None: extension = str(hparams_file).split(".")[-1] diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 50480e769b4ce..1b6bf4689afb6 100644 --- 
a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -33,8 +33,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: from omegaconf import Container @@ -81,19 +81,9 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: return rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") - self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path) - - def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: with pl_legacy_patch(): loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) - if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): - raise ValueError( - "The checkpoint you're attempting to load follows an" - " outdated schema. You can upgrade to the current schema by running" - " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" - " where `model.ckpt` is your checkpoint file." 
- ) - return loaded_checkpoint + self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint) def _set_ckpt_path( self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index adda962af08dd..3ef507c6f20e5 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -11,29 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging import sys from distutils.version import LooseVersion from types import ModuleType, TracebackType -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Type, Tuple, List import pytorch_lightning as pl from pytorch_lightning.utilities.migration.migrations import migration_index +_log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] -def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: +def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: """Applies Lightning version migrations to a checkpoint dictionary.""" index = migration_index() + applied_migrations = {} for migration_version, migration_functions in index.items(): if not _should_upgrade(checkpoint, migration_version): continue for migration_function in migration_functions: checkpoint = migration_function(checkpoint) + applied_migrations[migration_version] = [fn.__name__ for fn in migration_functions] + _set_version(checkpoint, pl.__version__) - return checkpoint + return checkpoint, applied_migrations class pl_legacy_patch: @@ -72,6 +76,21 @@ def __exit__( del sys.modules["pytorch_lightning.utilities.argparse_utils"] +def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Applies Lightning version migrations 
to a checkpoint dictionary and prints infos for the user.""" + old_version = _get_version(checkpoint) + checkpoint, migrations = migrate_checkpoint(checkpoint) + new_version = _get_version(checkpoint) + if migrations: + _log.info( + f"Lightning automatically upgraded your loaded checkpoint from v{old_version} to v{new_version}." + " To apply the upgrade to your files permanently, run" + " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" + " where `model.ckpt` is your checkpoint file." + ) + return checkpoint + + def _get_version(checkpoint: _CHECKPOINT) -> str: """Get the version of a Lightning checkpoint.""" return checkpoint["pytorch-lightning_version"] diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 2c28ed5755b2a..304e7fffdc682 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -53,5 +53,5 @@ def upgrade(ckpt): "1.2.3": [dummy_upgrade("four")], } monkeypatch.setattr(pl.utilities.migration.utils, "migration_index", lambda: index) - new_checkpoint = migrate_checkpoint(old_checkpoint) + new_checkpoint, _ = migrate_checkpoint(old_checkpoint) return new_checkpoint, call_order diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 076fe2d630232..cb1f86e1e3135 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -43,6 +43,6 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): _set_version(old_checkpoint, "0.9.0") _set_version(new_checkpoint, pl.__version__) - updated_checkpoint = migrate_checkpoint(old_checkpoint) + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == new_checkpoint assert _get_version(updated_checkpoint) == pl.__version__ From 
c349ec7e495266821daea54e5d89f8ed2862f708 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 14:33:21 +0200 Subject: [PATCH 14/36] messaging --- src/pytorch_lightning/core/saving.py | 5 +-- .../connectors/checkpoint_connector.py | 2 +- .../utilities/migration/utils.py | 34 ++++++++++++++----- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 015368b2c9de8..36813f32ff7eb 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -31,7 +31,8 @@ from lightning_lite.utilities.cloud_io import load as pl_load from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE -from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch +from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -157,7 +158,7 @@ def _load_from_checkpoint( checkpoint = pl_load(checkpoint_path, map_location=map_location) # convert legacy checkpoints to the new format - checkpoint, _ = migrate_checkpoint(checkpoint) + checkpoint = _pl_migrate_checkpoint(checkpoint) if hparams_file is not None: extension = str(hparams_file).split(".")[-1] diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1b6bf4689afb6..cb7e80253e502 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -83,7 +83,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: rank_zero_info(f"Restoring states from the checkpoint path at 
{checkpoint_path}") with pl_legacy_patch(): loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) - self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint) + self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint) def _set_ckpt_path( self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 3ef507c6f20e5..8488e3ad371f3 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -14,11 +14,13 @@ import logging import sys from distutils.version import LooseVersion +from pathlib import Path from types import ModuleType, TracebackType from typing import Any, Dict, Optional, Type, Tuple, List import pytorch_lightning as pl from pytorch_lightning.utilities.migration.migrations import migration_index +from lightning_lite.utilities.types import _PATH _log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] @@ -76,18 +78,32 @@ def __exit__( del sys.modules["pytorch_lightning.utilities.argparse_utils"] -def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user.""" +def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: + """Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user. + + This function is used by the Lightning Trainer when resuming from a checkpoint. + """ old_version = _get_version(checkpoint) checkpoint, migrations = migrate_checkpoint(checkpoint) new_version = _get_version(checkpoint) - if migrations: - _log.info( - f"Lightning automatically upgraded your loaded checkpoint from v{old_version} to v{new_version}." 
- " To apply the upgrade to your files permanently, run" - " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" - " where `model.ckpt` is your checkpoint file." - ) + if not migrations: + # the checkpoint was already a new one, no migrations were needed + return checkpoint + + # include the full command including the checkpoint path to the checkpoint in the error message, + # so user can copy-paste if they want + path_hint = Path(checkpoint_path if checkpoint_path is not None else "model.ckpt") + path_hint = path_hint.relative_to(Path.cwd()) + + msg = ( + f"Lightning automatically upgraded your loaded checkpoint from v{old_version} to v{new_version}." + " To apply the upgrade to your files permanently, run" + f" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file {str(path_hint)}`" + ) + if checkpoint_path: + msg += f" where `{path_hint}` is your checkpoint file." + + _log.info(msg) return checkpoint From 0e6f867b989fc9fa2542148a56631d0eb3bec67e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Oct 2022 12:35:19 +0000 Subject: [PATCH 15/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/migration/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 8488e3ad371f3..4c8b0e9898b0c 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -16,11 +16,11 @@ from distutils.version import LooseVersion from pathlib import Path from types import ModuleType, TracebackType -from typing import Any, Dict, Optional, Type, Tuple, List +from typing import Any, Dict, List, Optional, Tuple, Type import pytorch_lightning as pl -from pytorch_lightning.utilities.migration.migrations import 
migration_index from lightning_lite.utilities.types import _PATH +from pytorch_lightning.utilities.migration.migrations import migration_index _log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] From cc333c66e0ae421bb8d74709b025fdb493d70ffd Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 14:38:59 +0200 Subject: [PATCH 16/36] checkpoint path --- src/pytorch_lightning/core/saving.py | 2 +- .../trainer/connectors/checkpoint_connector.py | 2 +- .../utilities/migration/utils.py | 16 +++++----------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 36813f32ff7eb..4087bcdf3a935 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -158,7 +158,7 @@ def _load_from_checkpoint( checkpoint = pl_load(checkpoint_path, map_location=map_location) # convert legacy checkpoints to the new format - checkpoint = _pl_migrate_checkpoint(checkpoint) + checkpoint = _pl_migrate_checkpoint(checkpoint, checkpoint_path) if hparams_file is not None: extension = str(hparams_file).split(".")[-1] diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index cb7e80253e502..209940597ba58 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -83,7 +83,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}") with pl_legacy_patch(): loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) - self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint) + self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) def _set_ckpt_path( self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool,
model_connected: bool diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 8488e3ad371f3..41c8216a64067 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -78,7 +78,7 @@ def __exit__( del sys.modules["pytorch_lightning.utilities.argparse_utils"] -def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: +def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: _PATH) -> _CHECKPOINT: """Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user. This function is used by the Lightning Trainer when resuming from a checkpoint. @@ -90,20 +90,14 @@ def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_P # the checkpoint was already a new one, no migrations were needed return checkpoint - # include the full command including the checkpoint path to the checkpoint in the error message, - # so user can copy-paste if they want - path_hint = Path(checkpoint_path if checkpoint_path is not None else "model.ckpt") - path_hint = path_hint.relative_to(Path.cwd()) - - msg = ( + # include the full upgrade command, including the path to the loaded file in the error message, + # so user can copy-paste and run if they want + path_hint = Path(checkpoint_path).relative_to(Path.cwd()) + _log.info( f"Lightning automatically upgraded your loaded checkpoint from v{old_version} to v{new_version}." " To apply the upgrade to your files permanently, run" f" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file {str(path_hint)}`" ) - if checkpoint_path: - msg += f" where `{path_hint}` is your checkpoint file." 
- - _log.info(msg) return checkpoint From 373498116ab56b89c425f54619722cf450c6fb7e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 15:00:21 +0200 Subject: [PATCH 17/36] type --- src/pytorch_lightning/utilities/migration/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 8cd4fec9e2dee..4f4c064f469de 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import os import sys from distutils.version import LooseVersion -from pathlib import Path from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type @@ -92,7 +92,7 @@ def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: _PATH) -> _ # include the full upgrade command, including the path to the loaded file in the error message, # so user can copy-paste and run if they want - path_hint = Path(checkpoint_path).relative_to(Path.cwd()) + path_hint = os.path.relpath(checkpoint_path, os.getcwd()) _log.info( f"Lightning automatically upgraded your loaded checkpoint from v{old_version} to v{new_version}." 
" To apply the upgrade to your files permanently, run" From da53672b458dcb239c2c49d93827c8eb70b6552d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 16:27:17 +0200 Subject: [PATCH 18/36] warn --- .../utilities/migration/utils.py | 19 ++++++- .../utilities/migration/test_utils.py | 49 +++++++++++++++++-- .../tests_pytorch/utilities/test_migration.py | 36 -------------- 3 files changed, 63 insertions(+), 41 deletions(-) delete mode 100644 tests/tests_pytorch/utilities/test_migration.py diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 4f4c064f469de..1829983fd34b1 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -18,8 +18,11 @@ from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type +from lightning_utilities.core.rank_zero import rank_zero_warn + import pytorch_lightning as pl from lightning_lite.utilities.types import _PATH +from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.migration.migrations import migration_index _log = logging.getLogger(__name__) @@ -27,7 +30,21 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: - """Applies Lightning version migrations to a checkpoint dictionary.""" + """Applies Lightning version migrations to a checkpoint dictionary. + + Note: + The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large + checkpoints and objects that do not support being deep-copied. 
+ """ + ckpt_version = _get_version(checkpoint) + if LooseVersion(ckpt_version) > LooseVersion(pl.__version__): + rank_zero_warn( + f"The loaded checkpoint was produced with Lightning v{ckpt_version}, which is newer than your current" + f" Lightning version: v{pl.__version__}", + category=PossibleUserWarning, + ) + return checkpoint, {} + index = migration_index() applied_migrations = {} for migration_version, migration_functions in index.items(): diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 304e7fffdc682..12f97bfae9451 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -11,9 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy + +import pytest import pytorch_lightning as pl -from pytorch_lightning.utilities.migration import migrate_checkpoint +from lightning_lite.utilities.warnings import PossibleUserWarning +from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch +import sys + +from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint + + +def test_patch_legacy_argparse_utils(): + with pl_legacy_patch(): + from pytorch_lightning.utilities import argparse_utils + + assert callable(argparse_utils._gpus_arg_default) + assert "pytorch_lightning.utilities.argparse_utils" in sys.modules + + assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules + + +def test_patch_legacy_gpus_arg_default(): + with pl_legacy_patch(): + from pytorch_lightning.utilities.argparse import _gpus_arg_default + + assert callable(_gpus_arg_default) + assert not hasattr(pl.utilities.argparse, "_gpus_arg_default") + assert not hasattr(pl.utilities.argparse, "_gpus_arg_default") def 
test_migrate_checkpoint(monkeypatch): @@ -22,19 +48,34 @@ def test_migrate_checkpoint(monkeypatch): old_checkpoint = {"pytorch-lightning_version": "0.0.0", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == ["one", "two", "three", "four"] - assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} # A checkpoint that is newer, but not the newest old_checkpoint = {"pytorch-lightning_version": "1.0.3", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == ["four"] - assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} # A checkpoint newer than any migration point in the index old_checkpoint = {"pytorch-lightning_version": "2.0", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == [] - assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + + +def test_migrate_checkpoint_too_new(): + """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer + version of Lightning than installed.""" + super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123} + with pytest.warns( + PossibleUserWarning, + match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}" + ): + new_checkpoint, migrations = migrate_checkpoint(super_new_checkpoint.copy()) + + # no version modification + assert not migrations + assert new_checkpoint == super_new_checkpoint def _run_simple_migration(monkeypatch, 
old_checkpoint): diff --git a/tests/tests_pytorch/utilities/test_migration.py b/tests/tests_pytorch/utilities/test_migration.py deleted file mode 100644 index ee94ee690e798..0000000000000 --- a/tests/tests_pytorch/utilities/test_migration.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys - -import pytorch_lightning -from pytorch_lightning.utilities.migration import pl_legacy_patch - - -def test_patch_legacy_argparse_utils(): - with pl_legacy_patch(): - from pytorch_lightning.utilities import argparse_utils - - assert callable(argparse_utils._gpus_arg_default) - assert "pytorch_lightning.utilities.argparse_utils" in sys.modules - - assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules - - -def test_patch_legacy_gpus_arg_default(): - with pl_legacy_patch(): - from pytorch_lightning.utilities.argparse import _gpus_arg_default - - assert callable(_gpus_arg_default) - assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") - assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") From f745259edc1226b90fffd124c46e2e2eff3cf1b3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 16:37:54 +0200 Subject: [PATCH 19/36] tests --- .../utilities/migration/test_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py 
b/tests/tests_pytorch/utilities/migration/test_utils.py index 12f97bfae9451..244d453ff1782 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from copy import deepcopy import pytest @@ -44,6 +45,7 @@ def test_patch_legacy_gpus_arg_default(): def test_migrate_checkpoint(monkeypatch): """Test that the correct migration function gets executed given the current version of the checkpoint.""" + # A checkpoint that is older than any migration point in the index old_checkpoint = {"pytorch-lightning_version": "0.0.0", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) @@ -96,3 +98,19 @@ def upgrade(ckpt): monkeypatch.setattr(pl.utilities.migration.utils, "migration_index", lambda: index) new_checkpoint, _ = migrate_checkpoint(old_checkpoint) return new_checkpoint, call_order + + +def test_migrate_checkpoint_for_pl(caplog): + """Test that the automatic migration in Lightning informs the user about how to make the upgrade permanent.""" + + # simulate a very recent checkpoint, no migrations needed + loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123} + new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt") + assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + + # simulate an old checkpoint that needed an upgrade + loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123} + with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"): + new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt") + assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, 'callbacks': {}, "content": 123} + 
assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text From 372dfaf9291ae2fe17831b23423fd6182bcace46 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 16:38:12 +0200 Subject: [PATCH 20/36] tests --- tests/tests_pytorch/utilities/migration/test_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 244d453ff1782..47664ffdf2994 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import sys from copy import deepcopy import pytest @@ -19,8 +20,6 @@ import pytorch_lightning as pl from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch -import sys - from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint @@ -70,8 +69,7 @@ def test_migrate_checkpoint_too_new(): version of Lightning than installed.""" super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123} with pytest.warns( - PossibleUserWarning, - match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}" + PossibleUserWarning, match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}" ): new_checkpoint, migrations = migrate_checkpoint(super_new_checkpoint.copy()) @@ -112,5 +110,5 @@ def test_migrate_checkpoint_for_pl(caplog): loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123} with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"): new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt") - assert new_checkpoint == 
{"pytorch-lightning_version": pl.__version__, 'callbacks': {}, "content": 123} + assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "callbacks": {}, "content": 123} assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text From ec9b0f8b68bfe30f1f0503ebb51e4032efee19ab Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 16:55:05 +0200 Subject: [PATCH 21/36] changelog --- src/pytorch_lightning/CHANGELOG.md | 4 +++- tests/tests_pytorch/utilities/migration/test_utils.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 5f687de11a8da..58893eb445be9 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -8,10 +8,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) -### Changed +### Changed +- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) ## [1.8.0] - 2022-MM-DD diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 47664ffdf2994..46acc5b3e975a 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -13,7 +13,6 @@ # limitations under the License. 
import logging import sys -from copy import deepcopy import pytest From 978b0f0f0f526805f4d8107397d16e4cfceac605 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 17:54:47 +0200 Subject: [PATCH 22/36] add commit info --- src/pytorch_lightning/utilities/migration/migrations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py index 6769f8ce0006b..43bb77a8fe2c4 100644 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -48,7 +48,7 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP """The checkpoint and early stopping keys were renamed. Version: 0.10.0 - Commit: + Commit: a5d1176 """ keys_mapping = { "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), From 815f7f6ddbbd8afe75eb7a4c4b3cc454e6f1b7d5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 18:33:10 +0200 Subject: [PATCH 23/36] mark index as protected --- src/pytorch_lightning/utilities/migration/migrations.py | 4 ++-- src/pytorch_lightning/utilities/migration/utils.py | 4 ++-- tests/tests_pytorch/utilities/migration/test_utils.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py index 43bb77a8fe2c4..10694501262fa 100644 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -20,7 +20,7 @@ 1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include version information as well as the specific commit or PR where the breaking change happened. -2. Add the function to the `migration_index()` below. The key in the index is the version of Lightning in which the +2. 
Add the function to the `_migration_index()` below. The key in the index is the version of Lightning in which the change happened. Any checkpoint with a version greater or equal to that version will apply the given function. Multiple migrations per version get executed in the provided list order. 3. You can test the migration on a checkpoint (backup your files first) by running: @@ -37,7 +37,7 @@ _CHECKPOINT = Dict[str, Any] -def migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: +def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 1829983fd34b1..f45e23a3d8bac 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -23,7 +23,7 @@ import pytorch_lightning as pl from lightning_lite.utilities.types import _PATH from lightning_lite.utilities.warnings import PossibleUserWarning -from pytorch_lightning.utilities.migration.migrations import migration_index +from pytorch_lightning.utilities.migration.migrations import _migration_index _log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] @@ -45,7 +45,7 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, ) return checkpoint, {} - index = migration_index() + index = _migration_index() applied_migrations = {} for migration_version, migration_functions in index.items(): if not _should_upgrade(checkpoint, migration_version): diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 46acc5b3e975a..18992034b7f27 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ 
b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -92,7 +92,7 @@ def upgrade(ckpt): "0.0.2": [dummy_upgrade("two"), dummy_upgrade("three")], "1.2.3": [dummy_upgrade("four")], } - monkeypatch.setattr(pl.utilities.migration.utils, "migration_index", lambda: index) + monkeypatch.setattr(pl.utilities.migration.utils, "_migration_index", lambda: index) new_checkpoint, _ = migrate_checkpoint(old_checkpoint) return new_checkpoint, call_order From d8900691e34accce77c56344391be97ced473663 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 18:37:14 +0200 Subject: [PATCH 24/36] test fix --- .../utilities/migration/test_utils.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 18992034b7f27..08c36348c2400 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -57,26 +57,12 @@ def test_migrate_checkpoint(monkeypatch): assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} # A checkpoint newer than any migration point in the index - old_checkpoint = {"pytorch-lightning_version": "2.0", "content": 123} + old_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == [] assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} -def test_migrate_checkpoint_too_new(): - """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer - version of Lightning than installed.""" - super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123} - with pytest.warns( - PossibleUserWarning, match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}" - 
): - new_checkpoint, migrations = migrate_checkpoint(super_new_checkpoint.copy()) - - # no version modification - assert not migrations - assert new_checkpoint == super_new_checkpoint - - def _run_simple_migration(monkeypatch, old_checkpoint): call_order = [] @@ -97,6 +83,20 @@ def upgrade(ckpt): return new_checkpoint, call_order +def test_migrate_checkpoint_too_new(): + """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer + version of Lightning than installed.""" + super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123} + with pytest.warns( + PossibleUserWarning, match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}" + ): + new_checkpoint, migrations = migrate_checkpoint(super_new_checkpoint.copy()) + + # no version modification + assert not migrations + assert new_checkpoint == super_new_checkpoint + + def test_migrate_checkpoint_for_pl(caplog): """Test that the automatic migration in Lightning informs the user about how to make the upgrade permanent.""" From 62ae6111122510586386e36afb028d882576fa0c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 18:40:36 +0200 Subject: [PATCH 25/36] extend test to check inplace modification --- tests/tests_pytorch/utilities/test_upgrade_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index cb1f86e1e3135..fc8bf9edf7cdd 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -44,5 +44,5 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): _set_version(old_checkpoint, "0.9.0") _set_version(new_checkpoint, pl.__version__) updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) - assert updated_checkpoint == new_checkpoint + assert updated_checkpoint == 
old_checkpoint == new_checkpoint assert _get_version(updated_checkpoint) == pl.__version__ From 9ff93506efe304ccd1a02456ce244d822913ad3d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Oct 2022 14:08:13 +0200 Subject: [PATCH 26/36] set legacy version if upgrade happened [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci x --- .../utilities/migration/utils.py | 7 +++++++ .../utilities/migration/test_utils.py | 19 ++++++++++++++++--- .../utilities/test_upgrade_checkpoint.py | 3 ++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index f45e23a3d8bac..7d8c1c7038575 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -55,6 +55,8 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, applied_migrations[migration_version] = [fn.__name__ for fn in migration_functions] + if ckpt_version != pl.__version__: + _set_legacy_version(checkpoint, _get_version(checkpoint)) _set_version(checkpoint, pl.__version__) return checkpoint, applied_migrations @@ -128,6 +130,11 @@ def _set_version(checkpoint: _CHECKPOINT, version: str) -> None: checkpoint["pytorch-lightning_version"] = version +def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None: + """Set the legacy version of a Lightning checkpoint.""" + checkpoint["legacy_pytorch-lightning_version"] = version + + def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" return LooseVersion(_get_version(checkpoint)) < LooseVersion(target) diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 08c36348c2400..83fe240551540 100644 --- 
a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -48,13 +48,21 @@ def test_migrate_checkpoint(monkeypatch): old_checkpoint = {"pytorch-lightning_version": "0.0.0", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == ["one", "two", "three", "four"] - assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + assert ( + new_checkpoint + == old_checkpoint + == {"legacy_pytorch-lightning_version": "0.0.0", "pytorch-lightning_version": pl.__version__, "content": 123} + ) # A checkpoint that is newer, but not the newest old_checkpoint = {"pytorch-lightning_version": "1.0.3", "content": 123} new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) assert call_order == ["four"] - assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + assert ( + new_checkpoint + == old_checkpoint + == {"legacy_pytorch-lightning_version": "1.0.3", "pytorch-lightning_version": pl.__version__, "content": 123} + ) # A checkpoint newer than any migration point in the index old_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123} @@ -109,5 +117,10 @@ def test_migrate_checkpoint_for_pl(caplog): loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123} with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"): new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt") - assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "callbacks": {}, "content": 123} + assert new_checkpoint == { + "legacy_pytorch-lightning_version": "0.0.1", + "pytorch-lightning_version": pl.__version__, + "callbacks": {}, + "content": 123, + } assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text diff --git 
a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index fc8bf9edf7cdd..c8866829581cb 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -16,7 +16,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities.migration import migrate_checkpoint -from pytorch_lightning.utilities.migration.utils import _get_version, _set_version +from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version @pytest.mark.parametrize( @@ -42,6 +42,7 @@ ) def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): _set_version(old_checkpoint, "0.9.0") + _set_legacy_version(new_checkpoint, "0.9.0") _set_version(new_checkpoint, pl.__version__) updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == old_checkpoint == new_checkpoint From 7419ecac62a7befb4e1ef7274600f2a941dca567 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 1 Nov 2022 11:25:00 +0100 Subject: [PATCH 27/36] changelog --- src/pytorch_lightning/CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index bb115e1d4855a..5b4d659c9b973 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -5,13 +5,20 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
-## [1.8.0] - 2022-11-01 +## [unreleased] - 2023-XX-XX + +### Added - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) + ### Changed - From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) + + +## [1.8.0] - 2022-11-01 + ### Added - Added support for requeueing slurm array jobs ([#15040](https://github.com/Lightning-AI/lightning/pull/15040)) From 173c2da3217ea58a7f6a623f0b96e660a54ce54c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 1 Nov 2022 15:36:25 +0100 Subject: [PATCH 28/36] fix typing --- src/pytorch_lightning/core/saving.py | 5 ++++- src/pytorch_lightning/utilities/migration/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 383b07817a088..fefdb122fc5d2 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -20,6 +20,7 @@ from argparse import Namespace from copy import deepcopy from enum import Enum +from pathlib import Path from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union from warnings import warn @@ -158,7 +159,9 @@ def _load_from_checkpoint( checkpoint = pl_load(checkpoint_path, map_location=map_location) # convert legacy checkpoints to the new format - checkpoint = _pl_migrate_checkpoint(checkpoint, checkpoint_path) + checkpoint = _pl_migrate_checkpoint( + checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None) + ) if hparams_file is not None: extension = str(hparams_file).split(".")[-1] diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 7d8c1c7038575..6d5005a36986d 100644 --- 
a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -97,7 +97,7 @@ def __exit__( del sys.modules["pytorch_lightning.utilities.argparse_utils"] -def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: _PATH) -> _CHECKPOINT: +def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: """Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user. This function is used by the Lightning Trainer when resuming from a checkpoint. @@ -105,7 +105,7 @@ def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: _PATH) -> _ old_version = _get_version(checkpoint) checkpoint, migrations = migrate_checkpoint(checkpoint) new_version = _get_version(checkpoint) - if not migrations: + if not migrations or checkpoint_path is None: # the checkpoint was already a new one, no migrations were needed return checkpoint From e3debf979351a81cec05449619fb3dafdfd0712b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Nov 2022 14:44:07 +0000 Subject: [PATCH 29/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 653787f542d9f..c635389948d98 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -21,9 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) -- +- -- +- ### Deprecated From 47d56ca774f8ff6fbfb19de2d85845de6aab7bd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 1 Nov 2022 16:13:58 +0100 Subject: [PATCH 30/36] notebook --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 6d5634b794218..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d5634b7942180e6ba4a30bfbd74926d1c22f1eb From 4d2767ec200582926f84bdecfbd00730bfc42b42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 1 Nov 2022 16:14:02 +0100 Subject: [PATCH 31/36] notebook --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..0ad097a6fec2b --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 0ad097a6fec2b2c3f8ddd5d2263e178c41d614f5 From 15dbfdf109d3405509c141c0f0cda0304eb3732a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 14:39:09 +0100 Subject: [PATCH 32/36] fix mypy error --- src/lightning_lite/utilities/device_dtype_mixin.py | 4 ++-- src/pytorch_lightning/core/mixins/hparams_mixin.py | 4 ++-- src/pytorch_lightning/core/module.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lightning_lite/utilities/device_dtype_mixin.py b/src/lightning_lite/utilities/device_dtype_mixin.py index 1f5164f0cda37..7f3a0acfe2432 100644 --- a/src/lightning_lite/utilities/device_dtype_mixin.py +++ b/src/lightning_lite/utilities/device_dtype_mixin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch.nn import Module @@ -20,7 +20,7 @@ class _DeviceDtypeModuleMixin(Module): - __jit_unused_properties__ = ["device", "dtype"] + __jit_unused_properties__: List[str] = ["device", "dtype"] def __init__(self) -> None: super().__init__() diff --git a/src/pytorch_lightning/core/mixins/hparams_mixin.py b/src/pytorch_lightning/core/mixins/hparams_mixin.py index b5b4a9b312312..2cdb785403a68 100644 --- a/src/pytorch_lightning/core/mixins/hparams_mixin.py +++ b/src/pytorch_lightning/core/mixins/hparams_mixin.py @@ -15,7 +15,7 @@ import inspect import types from argparse import Namespace -from typing import Any, MutableMapping, Optional, Sequence, Union +from typing import Any, List, MutableMapping, Optional, Sequence, Union from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES from pytorch_lightning.utilities.parsing import AttributeDict, save_hyperparameters @@ -23,7 +23,7 @@ class HyperparametersMixin: - __jit_unused_properties__ = ["hparams", "hparams_initial"] + __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"] def __init__(self) -> None: super().__init__() diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 7710e4d5c6b86..4314035c6f8a7 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -75,7 +75,7 @@ class LightningModule( ): # Below is for property support of JIT # since none of these are important when using JIT, we are going to ignore them. 
- __jit_unused_properties__ = ( + __jit_unused_properties__: List[str] = ( [ "example_input_array", "on_gpu", From 59e05b0fc89b94115eff999b048a7ae8a51c4262 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 2 Nov 2022 09:44:16 -0400 Subject: [PATCH 33/36] Update src/pytorch_lightning/utilities/migration/utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- src/pytorch_lightning/utilities/migration/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 6d5005a36986d..af0c420f01074 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -56,7 +56,7 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, applied_migrations[migration_version] = [fn.__name__ for fn in migration_functions] if ckpt_version != pl.__version__: - _set_legacy_version(checkpoint, _get_version(checkpoint)) + _set_legacy_version(checkpoint, ckpt_version) _set_version(checkpoint, pl.__version__) return checkpoint, applied_migrations From 1743f640f3fec35eab3ca0a7fe1a66fce71aefc6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 14:46:01 +0100 Subject: [PATCH 34/36] rename --- .../utilities/migration/{migrations.py => migration.py} | 0 src/pytorch_lightning/utilities/migration/utils.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/pytorch_lightning/utilities/migration/{migrations.py => migration.py} (100%) diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migration.py similarity index 100% rename from src/pytorch_lightning/utilities/migration/migrations.py rename to src/pytorch_lightning/utilities/migration/migration.py diff --git 
a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 6d5005a36986d..e6b80418d7db4 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -23,7 +23,7 @@ import pytorch_lightning as pl from lightning_lite.utilities.types import _PATH from lightning_lite.utilities.warnings import PossibleUserWarning -from pytorch_lightning.utilities.migration.migrations import _migration_index +from pytorch_lightning.utilities.migration.migration import _migration_index _log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] From 40353c30688e6cf83b561a28a8b813387bde3c37 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 14:56:12 +0100 Subject: [PATCH 35/36] add test for legacy migration twice --- .../utilities/migration/utils.py | 4 ++-- .../utilities/migration/test_utils.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index f168c825bb854..209262057b990 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -131,8 +131,8 @@ def _set_version(checkpoint: _CHECKPOINT, version: str) -> None: def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None: - """Set the legacy version of a Lightning checkpoint.""" - checkpoint["legacy_pytorch-lightning_version"] = version + """Set the legacy version of a Lightning checkpoint if a legacy version is not already set.""" + checkpoint.setdefault("legacy_pytorch-lightning_version", version) def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index 83fe240551540..d662cf5e89833 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ 
b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -124,3 +124,20 @@ def test_migrate_checkpoint_for_pl(caplog): "content": 123, } assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text + + +def test_migrate_checkpoint_legacy_version(monkeypatch): + """Test that the legacy version gets set and does not change if migration is applied multiple times.""" + loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123} + + # pretend the current pl version is 2.0 + monkeypatch.setattr(pl, "__version__", "2.0.0") + new_checkpoint, _ = migrate_checkpoint(loaded_checkpoint) + assert new_checkpoint["pytorch-lightning_version"] == "2.0.0" + assert new_checkpoint["legacy_pytorch-lightning_version"] == "0.0.1" + + # pretend the current pl version is even newer, we are migrating a second time + monkeypatch.setattr(pl, "__version__", "3.0.0") + new_new_checkpoint, _ = migrate_checkpoint(new_checkpoint) + assert new_new_checkpoint["pytorch-lightning_version"] == "3.0.0" + assert new_new_checkpoint["legacy_pytorch-lightning_version"] == "0.0.1" # remains the same From a73e89b2a5a86d3ebdf1659c0fc421009f6d1f96 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 15:25:19 +0100 Subject: [PATCH 36/36] clarify the instructions are for Lightning developers --- src/pytorch_lightning/utilities/migration/migration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py index 10694501262fa..3431c01709b89 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -16,7 +16,7 @@ When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially, see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`. -How to add a new migration? 
+For the Lightning developer: How to add a new migration? 1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include version information as well as the specific commit or PR where the breaking change happened.