diff --git a/src/lightning_lite/utilities/device_dtype_mixin.py b/src/lightning_lite/utilities/device_dtype_mixin.py
index 1f5164f0cda37..7f3a0acfe2432 100644
--- a/src/lightning_lite/utilities/device_dtype_mixin.py
+++ b/src/lightning_lite/utilities/device_dtype_mixin.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -20,7 +20,7 @@
 
 
 class _DeviceDtypeModuleMixin(Module):
-    __jit_unused_properties__ = ["device", "dtype"]
+    __jit_unused_properties__: List[str] = ["device", "dtype"]
 
     def __init__(self) -> None:
         super().__init__()
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 168f03d998057..9c03c4cf6715c 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
 
 -
 
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
 
 -
 
@@ -57,7 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.8.0] - 2022-11-01
 
-
 ### Added
 
 - Added support for requeueing slurm array jobs ([#15040](https://github.com/Lightning-AI/lightning/pull/15040))
diff --git a/src/pytorch_lightning/core/mixins/hparams_mixin.py b/src/pytorch_lightning/core/mixins/hparams_mixin.py
index b5b4a9b312312..2cdb785403a68 100644
--- a/src/pytorch_lightning/core/mixins/hparams_mixin.py
+++ b/src/pytorch_lightning/core/mixins/hparams_mixin.py
@@ -15,7 +15,7 @@
 import inspect
 import types
 from argparse import Namespace
-from typing import Any, MutableMapping, Optional, Sequence, Union
+from typing import Any, List, MutableMapping, Optional, Sequence, Union
 
 from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
 from pytorch_lightning.utilities.parsing import AttributeDict, save_hyperparameters
@@ -23,7 +23,7 @@
 
 
 class HyperparametersMixin:
-    __jit_unused_properties__ = ["hparams", "hparams_initial"]
+    __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]
 
     def __init__(self) -> None:
         super().__init__()
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index 7710e4d5c6b86..4314035c6f8a7 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -75,7 +75,7 @@ class LightningModule(
 ):
     # Below is for property support of JIT
    # since none of these are important when using JIT, we are going to ignore them.
-    __jit_unused_properties__ = (
+    __jit_unused_properties__: List[str] = (
         [
             "example_input_array",
             "on_gpu",
diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py
index fa464698314a1..fefdb122fc5d2 100644
--- a/src/pytorch_lightning/core/saving.py
+++ b/src/pytorch_lightning/core/saving.py
@@ -20,6 +20,7 @@
 from argparse import Namespace
 from copy import deepcopy
 from enum import Enum
+from pathlib import Path
 from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
 from warnings import warn
 
@@ -32,6 +33,7 @@
 from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.migration import pl_legacy_patch
+from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
 from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
@@ -156,6 +158,11 @@ def _load_from_checkpoint(
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)
 
+    # convert legacy checkpoints to the new format
+    checkpoint = _pl_migrate_checkpoint(
+        checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
+    )
+
     if hparams_file is not None:
         extension = str(hparams_file).split(".")[-1]
         if extension.lower() == "csv":
@@ -168,6 +175,7 @@ def _load_from_checkpoint(
         # overwrite hparams by the given file
         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
 
+    # TODO: make this a migration:
     # for past checkpoint need to add the new key
     checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
     # override the hparams with values that were passed in
@@ -198,6 +206,7 @@ def _load_state(
 
     if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
         if issubclass(cls, pl.LightningModule):
+            # TODO: make this a migration:
             # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
             for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                 cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 65d0b365d8c2d..2cb285167aeb0 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -33,8 +33,8 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.migration import pl_legacy_patch
+from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container
@@ -81,19 +81,9 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
             return
 
         rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
-        self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
-
-    def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         with pl_legacy_patch():
             loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
-        if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
-            raise ValueError(
-                "The checkpoint you're attempting to load follows an"
-                " outdated schema. You can upgrade to the current schema by running"
-                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
-                " where `model.ckpt` is your checkpoint file."
-            )
-        return loaded_checkpoint
+        self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
 
     def _set_ckpt_path(
         self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool
diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration.py
deleted file mode 100644
index ed71f25a571f7..0000000000000
--- a/src/pytorch_lightning/utilities/migration.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import annotations
-
-import sys
-import threading
-from types import ModuleType, TracebackType
-
-import pytorch_lightning.utilities.argparse
-
-# Create a global lock to ensure no race condition with deleting sys modules
-_lock = threading.Lock()
-
-
-class pl_legacy_patch:
-    """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
-    unpickling old checkpoints. The following patches apply.
-
-    1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
-       version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898
-    2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
-       but still needs to be available for import for legacy checkpoints.
-
-    Example:
-
-        with pl_legacy_patch():
-            torch.load("path/to/legacy/checkpoint.ckpt")
-    """
-
-    def __enter__(self) -> None:
-        _lock.acquire()
-        # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
-        legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
-        sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module
-
-        # `_gpus_arg_default` used to be imported from these locations
-        legacy_argparse_module._gpus_arg_default = lambda x: x
-        pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
-
-    def __exit__(
-        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None
-    ) -> None:
-        if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
-            delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
-        del sys.modules["pytorch_lightning.utilities.argparse_utils"]
-        _lock.release()
diff --git a/src/pytorch_lightning/utilities/migration/__init__.py b/src/pytorch_lightning/utilities/migration/__init__.py
new file mode 100644
index 0000000000000..199541c19034f
--- /dev/null
+++ b/src/pytorch_lightning/utilities/migration/__init__.py
@@ -0,0 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.utilities.migration.utils import migrate_checkpoint  # noqa: F401
+from pytorch_lightning.utilities.migration.utils import pl_legacy_patch  # noqa: F401
diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py
new file mode 100644
index 0000000000000..3431c01709b89
--- /dev/null
+++ b/src/pytorch_lightning/utilities/migration/migration.py
@@ -0,0 +1,69 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.
+
+When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
+see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`.
+
+For the Lightning developer: How to add a new migration?
+
+1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include
+   version information as well as the specific commit or PR where the breaking change happened.
+2. Add the function to the `_migration_index()` below. The key in the index is the version of Lightning in which the
+   change happened. Any checkpoint with a version lower than that one will have the given function applied on load.
+   Multiple migrations per version get executed in the provided list order.
+3. You can test the migration on a checkpoint (back up your files first) by running:
+
+   cp model.ckpt model.ckpt.backup
+   python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt
+"""
+
+from typing import Any, Callable, Dict, List
+
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+
+_CHECKPOINT = Dict[str, Any]
+
+
+def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
+    """Migration functions returned here will get executed in the order they are listed."""
+    return {
+        "0.10.0": [_migrate_model_checkpoint_early_stopping],
+    }
+
+
+def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """The checkpoint and early stopping keys were renamed.
+
+    Version: 0.10.0
+    Commit: a5d1176
+    """
+    keys_mapping = {
+        "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
+        "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
+        "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
+        "early_stop_callback_wait": (EarlyStopping, "wait_count"),
+        "early_stop_callback_patience": (EarlyStopping, "patience"),
+    }
+    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
+
+    for key, new_path in keys_mapping.items():
+        if key in checkpoint:
+            value = checkpoint[key]
+            callback_type, callback_key = new_path
+            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
+            checkpoint["callbacks"][callback_type][callback_key] = value
+            del checkpoint[key]
+    return checkpoint
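For illustration, a minimal sketch of the steps described in the module docstring above. The function name `_migrate_rename_global_step`, the version key `"1.5.0"`, and the checkpoint key `"old_global_step"` are hypothetical and exist only in this sketch; the only real migration added in this diff is `_migrate_model_checkpoint_early_stopping`.

    def _migrate_rename_global_step(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
        """Example only: moves a made-up legacy key to its new name.

        Version: 1.5.0
        Commit: n/a (illustrative example)
        """
        if "old_global_step" in checkpoint:
            checkpoint["global_step"] = checkpoint.pop("old_global_step")
        return checkpoint

    def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
        return {
            "0.10.0": [_migrate_model_checkpoint_early_stopping],
            # applied to any checkpoint saved with a version lower than 1.5.0
            "1.5.0": [_migrate_rename_global_step],
        }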
diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py
new file mode 100644
index 0000000000000..209262057b990
--- /dev/null
+++ b/src/pytorch_lightning/utilities/migration/utils.py
@@ -0,0 +1,140 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import sys
+from distutils.version import LooseVersion
+from types import ModuleType, TracebackType
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+from lightning_utilities.core.rank_zero import rank_zero_warn
+
+import pytorch_lightning as pl
+from lightning_lite.utilities.types import _PATH
+from lightning_lite.utilities.warnings import PossibleUserWarning
+from pytorch_lightning.utilities.migration.migration import _migration_index
+
+_log = logging.getLogger(__name__)
+_CHECKPOINT = Dict[str, Any]
+
+
+def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]:
+    """Applies Lightning version migrations to a checkpoint dictionary.
+
+    Note:
+        The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large
+        checkpoints and objects that do not support being deep-copied.
+    """
+    ckpt_version = _get_version(checkpoint)
+    if LooseVersion(ckpt_version) > LooseVersion(pl.__version__):
+        rank_zero_warn(
+            f"The loaded checkpoint was produced with Lightning v{ckpt_version}, which is newer than your current"
+            f" Lightning version: v{pl.__version__}",
+            category=PossibleUserWarning,
+        )
+        return checkpoint, {}
+
+    index = _migration_index()
+    applied_migrations = {}
+    for migration_version, migration_functions in index.items():
+        if not _should_upgrade(checkpoint, migration_version):
+            continue
+        for migration_function in migration_functions:
+            checkpoint = migration_function(checkpoint)
+
+        applied_migrations[migration_version] = [fn.__name__ for fn in migration_functions]
+
+    if ckpt_version != pl.__version__:
+        _set_legacy_version(checkpoint, ckpt_version)
+        _set_version(checkpoint, pl.__version__)
+    return checkpoint, applied_migrations
+
+
+class pl_legacy_patch:
+    """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
+    unpickling old checkpoints. The following patches apply.
+
+    1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
+       version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
+    2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
+       but still needs to be available for import for legacy checkpoints.
+
+    Example:
+
+        with pl_legacy_patch():
+            torch.load("path/to/legacy/checkpoint.ckpt")
+    """
+
+    def __enter__(self) -> "pl_legacy_patch":
+        # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
+        legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
+        sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module
+
+        # `_gpus_arg_default` used to be imported from these locations
+        legacy_argparse_module._gpus_arg_default = lambda x: x
+        pl.utilities.argparse._gpus_arg_default = lambda x: x
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
+    ) -> None:
+        if hasattr(pl.utilities.argparse, "_gpus_arg_default"):
+            delattr(pl.utilities.argparse, "_gpus_arg_default")
+        del sys.modules["pytorch_lightning.utilities.argparse_utils"]
+
+
+def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT:
+    """Applies Lightning version migrations to a checkpoint dictionary and logs information for the user.
+
+    This function is used by the Lightning Trainer when resuming from a checkpoint.
+    """
+    old_version = _get_version(checkpoint)
+    checkpoint, migrations = migrate_checkpoint(checkpoint)
+    new_version = _get_version(checkpoint)
+    if not migrations or checkpoint_path is None:
+        # no migrations were needed, or there is no path we could point the user to
+        return checkpoint
+
+    # include the full upgrade command, including the path to the loaded file, in the info message,
+    # so the user can copy-paste and run it if they want
+    path_hint = os.path.relpath(checkpoint_path, os.getcwd())
+    _log.info(
+        f"Lightning automatically upgraded your loaded checkpoint from v{old_version} to v{new_version}."
+        " To apply the upgrade to your files permanently, run"
+        f" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file {str(path_hint)}`"
+    )
+    return checkpoint
+
+
+def _get_version(checkpoint: _CHECKPOINT) -> str:
+    """Get the version of a Lightning checkpoint."""
+    return checkpoint["pytorch-lightning_version"]
+
+
+def _set_version(checkpoint: _CHECKPOINT, version: str) -> None:
+    """Set the version of a Lightning checkpoint."""
+    checkpoint["pytorch-lightning_version"] = version
+
+
+def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None:
+    """Set the legacy version of a Lightning checkpoint if a legacy version is not already set."""
+    checkpoint.setdefault("legacy_pytorch-lightning_version", version)
+
+
+def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
+    """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
+    return LooseVersion(_get_version(checkpoint)) < LooseVersion(target)
diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py
index 6f4dd5ca938dd..4bcfb4a86f5bd 100644
--- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py
+++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py
@@ -17,36 +17,10 @@
 
 import torch
 
-from lightning_lite.utilities.types import _PATH
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.migration import pl_legacy_patch
-
-KEYS_MAPPING = {
-    "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
-    "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
-    "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
-    "early_stop_callback_wait": (EarlyStopping, "wait_count"),
-    "early_stop_callback_patience": (EarlyStopping, "patience"),
-}
+from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
 
 log = logging.getLogger(__name__)
 
-
-def upgrade_checkpoint(filepath: _PATH) -> None:
-    checkpoint = torch.load(filepath)
-    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
-
-    for key, new_path in KEYS_MAPPING.items():
-        if key in checkpoint:
-            value = checkpoint[key]
-            callback_type, callback_key = new_path
-            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
-            checkpoint["callbacks"][callback_type][callback_key] = value
-            del checkpoint[key]
-
-    torch.save(checkpoint, filepath)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -61,4 +35,6 @@
     log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
     copyfile(args.file, args.file + ".bak")
     with pl_legacy_patch():
-        upgrade_checkpoint(args.file)
+        checkpoint = torch.load(args.file)
+        migrate_checkpoint(checkpoint)
+        torch.save(checkpoint, args.file)
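The `__main__` block above performs the same steps as this programmatic sketch of the new API (the checkpoint path is a placeholder):

    import torch

    from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

    # patch in removed legacy classes/attributes so old pickled checkpoints can be loaded
    with pl_legacy_patch():
        checkpoint = torch.load("path/to/model.ckpt")

    # migration happens in-place; the second return value maps Lightning versions
    # to the names of the migration functions that were applied
    checkpoint, applied_migrations = migrate_checkpoint(checkpoint)
    torch.save(checkpoint, "path/to/model.ckpt")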
diff --git a/tests/tests_pytorch/utilities/migration/__init__.py b/tests/tests_pytorch/utilities/migration/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
new file mode 100644
index 0000000000000..d662cf5e89833
--- /dev/null
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -0,0 +1,143 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+
+import pytest
+
+import pytorch_lightning as pl
+from lightning_lite.utilities.warnings import PossibleUserWarning
+from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
+from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
+
+
+def test_patch_legacy_argparse_utils():
+    with pl_legacy_patch():
+        from pytorch_lightning.utilities import argparse_utils
+
+        assert callable(argparse_utils._gpus_arg_default)
+        assert "pytorch_lightning.utilities.argparse_utils" in sys.modules
+
+    assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules
+
+
+def test_patch_legacy_gpus_arg_default():
+    with pl_legacy_patch():
+        from pytorch_lightning.utilities.argparse import _gpus_arg_default
+
+        assert callable(_gpus_arg_default)
+    assert not hasattr(pl.utilities.argparse, "_gpus_arg_default")
+    assert not hasattr(pl.utilities.argparse, "_gpus_arg_default")
+
+
+def test_migrate_checkpoint(monkeypatch):
+    """Test that the correct migration function gets executed given the current version of the checkpoint."""
+
+    # A checkpoint that is older than any migration point in the index
+    old_checkpoint = {"pytorch-lightning_version": "0.0.0", "content": 123}
+    new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint)
+    assert call_order == ["one", "two", "three", "four"]
+    assert (
+        new_checkpoint
+        == old_checkpoint
+        == {"legacy_pytorch-lightning_version": "0.0.0", "pytorch-lightning_version": pl.__version__, "content": 123}
+    )
+
+    # A checkpoint that is newer, but not the newest
+    old_checkpoint = {"pytorch-lightning_version": "1.0.3", "content": 123}
+    new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint)
+    assert call_order == ["four"]
+    assert (
+        new_checkpoint
+        == old_checkpoint
+        == {"legacy_pytorch-lightning_version": "1.0.3", "pytorch-lightning_version": pl.__version__, "content": 123}
+    )
+
+    # A checkpoint newer than any migration point in the index
+    old_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123}
+    new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint)
+    assert call_order == []
+    assert new_checkpoint == old_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123}
+
+
+def _run_simple_migration(monkeypatch, old_checkpoint):
+    call_order = []
+
+    def dummy_upgrade(tag):
+        def upgrade(ckpt):
+            call_order.append(tag)
+            return ckpt
+
+        return upgrade
+
+    index = {
+        "0.0.1": [dummy_upgrade("one")],
+        "0.0.2": [dummy_upgrade("two"), dummy_upgrade("three")],
+        "1.2.3": [dummy_upgrade("four")],
+    }
+    monkeypatch.setattr(pl.utilities.migration.utils, "_migration_index", lambda: index)
+    new_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    return new_checkpoint, call_order
+
+
+def test_migrate_checkpoint_too_new():
+    """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from a newer
+    version of Lightning than the one installed."""
+    super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123}
+    with pytest.warns(
+        PossibleUserWarning, match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}"
+    ):
+        new_checkpoint, migrations = migrate_checkpoint(super_new_checkpoint.copy())
+
+    # no version modification
+    assert not migrations
+    assert new_checkpoint == super_new_checkpoint
+
+
+def test_migrate_checkpoint_for_pl(caplog):
+    """Test that the automatic migration in Lightning informs the user about how to make the upgrade permanent."""
+
+    # simulate a very recent checkpoint, no migrations needed
+    loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123}
+    new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
+    assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123}
+
+    # simulate an old checkpoint that needed an upgrade
+    loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
+    with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"):
+        new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt")
+    assert new_checkpoint == {
+        "legacy_pytorch-lightning_version": "0.0.1",
+        "pytorch-lightning_version": pl.__version__,
+        "callbacks": {},
+        "content": 123,
+    }
+    assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text
+
+
+def test_migrate_checkpoint_legacy_version(monkeypatch):
+    """Test that the legacy version gets set and does not change if migration is applied multiple times."""
+    loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
+
+    # pretend the current pl version is 2.0
+    monkeypatch.setattr(pl, "__version__", "2.0.0")
+    new_checkpoint, _ = migrate_checkpoint(loaded_checkpoint)
+    assert new_checkpoint["pytorch-lightning_version"] == "2.0.0"
+    assert new_checkpoint["legacy_pytorch-lightning_version"] == "0.0.1"
+
+    # pretend the current pl version is even newer, we are migrating a second time
+    monkeypatch.setattr(pl, "__version__", "3.0.0")
+    new_new_checkpoint, _ = migrate_checkpoint(new_checkpoint)
+    assert new_new_checkpoint["pytorch-lightning_version"] == "3.0.0"
+    assert new_new_checkpoint["legacy_pytorch-lightning_version"] == "0.0.1"  # remains the same
diff --git a/tests/tests_pytorch/utilities/test_migration.py b/tests/tests_pytorch/utilities/test_migration.py
deleted file mode 100644
index ee94ee690e798..0000000000000
--- a/tests/tests_pytorch/utilities/test_migration.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import sys
-
-import pytorch_lightning
-from pytorch_lightning.utilities.migration import pl_legacy_patch
-
-
-def test_patch_legacy_argparse_utils():
-    with pl_legacy_patch():
-        from pytorch_lightning.utilities import argparse_utils
-
-        assert callable(argparse_utils._gpus_arg_default)
-        assert "pytorch_lightning.utilities.argparse_utils" in sys.modules
-
-    assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules
-
-
-def test_patch_legacy_gpus_arg_default():
-    with pl_legacy_patch():
-        from pytorch_lightning.utilities.argparse import _gpus_arg_default
-
-        assert callable(_gpus_arg_default)
-    assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
-    assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
index a58bdb5721bc7..c8866829581cb 100644
--- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
+++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-
 import pytest
-import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
+from pytorch_lightning.utilities.migration import migrate_checkpoint
+from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
 
 
 @pytest.mark.parametrize(
@@ -42,8 +41,9 @@
     ],
 )
 def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
-    filepath = os.path.join(tmpdir, "model.ckpt")
-    torch.save(old_checkpoint, filepath)
-    upgrade_checkpoint(filepath)
-    updated_checkpoint = torch.load(filepath)
-    assert updated_checkpoint == new_checkpoint
+    _set_version(old_checkpoint, "0.9.0")
+    _set_legacy_version(new_checkpoint, "0.9.0")
+    _set_version(new_checkpoint, pl.__version__)
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    assert updated_checkpoint == old_checkpoint == new_checkpoint
+    assert _get_version(updated_checkpoint) == pl.__version__
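A custom migration can be unit-tested in isolation with the same monkeypatching pattern the `_run_simple_migration` helper above uses. A hedged sketch; the `add_marker` migration, the `"1.5.0"` version key, and the `"marker"` checkpoint key are made up for this example:

    import pytorch_lightning as pl
    from pytorch_lightning.utilities.migration import migrate_checkpoint


    def test_custom_migration(monkeypatch):
        # hypothetical migration that stamps a marker key onto the checkpoint
        def add_marker(ckpt):
            ckpt["marker"] = True
            return ckpt

        # replace the real migration index with one containing only the hypothetical migration
        monkeypatch.setattr(pl.utilities.migration.utils, "_migration_index", lambda: {"1.5.0": [add_marker]})

        # a checkpoint older than 1.5.0 should get the migration applied
        new_checkpoint, applied = migrate_checkpoint({"pytorch-lightning_version": "1.0.0"})
        assert new_checkpoint["marker"]
        assert applied == {"1.5.0": ["add_marker"]}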