Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce checkpoint migration #15237

Merged
merged 49 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
277b0b8
migration
awaelchli Oct 21, 2022
cc110a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2022
e13fcc6
import
awaelchli Oct 21, 2022
7cb1d24
Merge remote-tracking branch 'origin/feature/migration-utils' into fe…
awaelchli Oct 21, 2022
9838f00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2022
ed1ab3f
refactor
awaelchli Oct 24, 2022
756b2e7
protected
awaelchli Oct 24, 2022
d64b5ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2022
954b4d3
typo
awaelchli Oct 24, 2022
0febf26
Merge remote-tracking branch 'origin/feature/migration-utils' into fe…
awaelchli Oct 24, 2022
5223dff
Merge branch 'master' into feature/migration-utils
awaelchli Oct 24, 2022
0f988d1
tests
awaelchli Oct 24, 2022
3a6aa87
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2022
ac0d2fc
prune
awaelchli Oct 24, 2022
aec0257
Merge remote-tracking branch 'origin/feature/migration-utils' into fe…
awaelchli Oct 24, 2022
f7f1250
reset
awaelchli Oct 24, 2022
24a5d60
wip
awaelchli Oct 24, 2022
c349ec7
messaging
awaelchli Oct 24, 2022
0e6f867
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2022
cc333c6
checkpint path
awaelchli Oct 24, 2022
c1afc87
Merge remote-tracking branch 'origin/feature/migration-utils' into fe…
awaelchli Oct 24, 2022
3734981
type
awaelchli Oct 24, 2022
da53672
warn
awaelchli Oct 24, 2022
f745259
tests
awaelchli Oct 24, 2022
372dfaf
tests
awaelchli Oct 24, 2022
ec9b0f8
changelog
awaelchli Oct 24, 2022
978b0f0
add commit info
awaelchli Oct 24, 2022
815f7f6
mark index as protected
awaelchli Oct 24, 2022
d890069
test fix
awaelchli Oct 24, 2022
62ae611
extend test to check inplace modification
awaelchli Oct 24, 2022
7124604
Merge branch 'master' into feature/migration-utils
awaelchli Oct 25, 2022
9ff9350
set legacy version if upgrade happened
awaelchli Oct 25, 2022
947bbbe
Merge branch 'master' into feature/migration-utils
awaelchli Oct 26, 2022
3e6fd84
Merge branch 'master' into feature/migration-utils
awaelchli Oct 27, 2022
bc032c5
Merge branch 'master' into feature/migration-utils
awaelchli Nov 1, 2022
7419eca
changelog
awaelchli Nov 1, 2022
ccc9ba3
Merge branch 'master' into feature/migration-utils
awaelchli Nov 1, 2022
173c2da
fix typing
awaelchli Nov 1, 2022
70e1110
Merge branch 'master' into feature/migration-utils
awaelchli Nov 1, 2022
e3debf9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2022
47d56ca
notebook
awaelchli Nov 1, 2022
4d2767e
notebook
awaelchli Nov 1, 2022
4967134
Merge branch 'master' into feature/migration-utils
awaelchli Nov 2, 2022
15dbfdf
fix mypy error
awaelchli Nov 2, 2022
59e05b0
Update src/pytorch_lightning/utilities/migration/utils.py
awaelchli Nov 2, 2022
1743f64
rename
awaelchli Nov 2, 2022
82bb977
Merge branch 'feature/migration-utils' of github.com:Lightning-AI/lig…
awaelchli Nov 2, 2022
40353c3
add test for legacy migration twice
awaelchli Nov 2, 2022
a73e89b
clarify the instructions are for Lightning developers
awaelchli Nov 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

-

Expand All @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

-

Expand Down Expand Up @@ -55,7 +55,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [1.8.0] - 2022-11-01


### Added

- Added support for requeueing slurm array jobs ([#15040](https://github.com/Lightning-AI/lightning/pull/15040))
Expand Down
9 changes: 9 additions & 0 deletions src/pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from argparse import Namespace
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
from warnings import warn

Expand All @@ -32,6 +33,7 @@
from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

Expand Down Expand Up @@ -156,6 +158,11 @@ def _load_from_checkpoint(
with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)

# convert legacy checkpoints to the new format
checkpoint = _pl_migrate_checkpoint(
checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
)

if hparams_file is not None:
extension = str(hparams_file).split(".")[-1]
if extension.lower() == "csv":
Expand All @@ -168,6 +175,7 @@ def _load_from_checkpoint(
# overwrite hparams by the given file
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

# TODO: make this a migration:
# for past checkpoint need to add the new key
checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
# override the hparams with values that were passed in
Expand Down Expand Up @@ -198,6 +206,7 @@ def _load_state(
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:

if issubclass(cls, pl.LightningModule):
# TODO: make this a migration:
# 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _OMEGACONF_AVAILABLE:
from omegaconf import Container
Expand Down Expand Up @@ -81,19 +81,9 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
return

rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
with pl_legacy_patch():
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
raise ValueError(
"The checkpoint you're attempting to load follows an"
" outdated schema. You can upgrade to the current schema by running"
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
" where `model.ckpt` is your checkpoint file."
)
return loaded_checkpoint
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)

def _set_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool
Expand Down
57 changes: 0 additions & 57 deletions src/pytorch_lightning/utilities/migration.py

This file was deleted.

16 changes: 16 additions & 0 deletions src/pytorch_lightning/utilities/migration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities.migration.utils import migrate_checkpoint # noqa: F401
from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401
69 changes: 69 additions & 0 deletions src/pytorch_lightning/utilities/migration/migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright The PyTorch Lightning team.
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.

When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`.

How to add a new migration?

1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include
version information as well as the specific commit or PR where the breaking change happened.
2. Add the function to the `_migration_index()` below. The key in the index is the version of Lightning in which the
change happened. Any checkpoint with a version greater or equal to that version will apply the given function.
Multiple migrations per version get executed in the provided list order.
3. You can test the migration on a checkpoint (backup your files first) by running:

cp model.ckpt model.ckpt.backup
python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt
"""

from typing import Any, Callable, Dict, List

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

_CHECKPOINT = Dict[str, Any]


def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""Migration functions returned here will get executed in the order they are listed."""
return {
"0.10.0": [_migrate_model_checkpoint_early_stopping],
}


def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
"""The checkpoint and early stopping keys were renamed.

Version: 0.10.0
Commit: a5d1176
"""
keys_mapping = {
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
"checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
"checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
"early_stop_callback_wait": (EarlyStopping, "wait_count"),
"early_stop_callback_patience": (EarlyStopping, "patience"),
}
checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

for key, new_path in keys_mapping.items():
if key in checkpoint:
value = checkpoint[key]
callback_type, callback_key = new_path
checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
checkpoint["callbacks"][callback_type][callback_key] = value
del checkpoint[key]
return checkpoint
140 changes: 140 additions & 0 deletions src/pytorch_lightning/utilities/migration/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
from distutils.version import LooseVersion
from types import ModuleType, TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type

from lightning_utilities.core.rank_zero import rank_zero_warn

import pytorch_lightning as pl
from lightning_lite.utilities.types import _PATH
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.utilities.migration.migrations import _migration_index

_log = logging.getLogger(__name__)
_CHECKPOINT = Dict[str, Any]


def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]:
"""Applies Lightning version migrations to a checkpoint dictionary.

Note:
The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large
checkpoints and objects that do not support being deep-copied.
"""
ckpt_version = _get_version(checkpoint)
if LooseVersion(ckpt_version) > LooseVersion(pl.__version__):
rank_zero_warn(
f"The loaded checkpoint was produced with Lightning v{ckpt_version}, which is newer than your current"
f" Lightning version: v{pl.__version__}",
category=PossibleUserWarning,
)
return checkpoint, {}

index = _migration_index()
applied_migrations = {}
for migration_version, migration_functions in index.items():
if not _should_upgrade(checkpoint, migration_version):
continue
for migration_function in migration_functions:
checkpoint = migration_function(checkpoint)

applied_migrations[migration_version] = [fn.__name__ for fn in migration_functions]

if ckpt_version != pl.__version__:
_set_legacy_version(checkpoint, _get_version(checkpoint))
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
_set_version(checkpoint, pl.__version__)
return checkpoint, applied_migrations


class pl_legacy_patch:
"""Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
unpickling old checkpoints. The following patches apply.

1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
but still needs to be available for import for legacy checkpoints.

Example:

with pl_legacy_patch():
torch.load("path/to/legacy/checkpoint.ckpt")
"""

def __enter__(self) -> "pl_legacy_patch":
# `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module

# `_gpus_arg_default` used to be imported from these locations
legacy_argparse_module._gpus_arg_default = lambda x: x
pl.utilities.argparse._gpus_arg_default = lambda x: x
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
) -> None:
if hasattr(pl.utilities.argparse, "_gpus_arg_default"):
delattr(pl.utilities.argparse, "_gpus_arg_default")
del sys.modules["pytorch_lightning.utilities.argparse_utils"]


def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT:
"""Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user.

This function is used by the Lightning Trainer when resuming from a checkpoint.
"""
old_version = _get_version(checkpoint)
checkpoint, migrations = migrate_checkpoint(checkpoint)
new_version = _get_version(checkpoint)
if not migrations or checkpoint_path is None:
# the checkpoint was already a new one, no migrations were needed
return checkpoint

# include the full upgrade command, including the path to the loaded file in the error message,
# so user can copy-paste and run if they want
path_hint = os.path.relpath(checkpoint_path, os.getcwd())
_log.info(
f"Lightning automatically upgraded your loaded checkpoint from v{old_version} to v{new_version}."
" To apply the upgrade to your files permanently, run"
f" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file {str(path_hint)}`"
)
return checkpoint


def _get_version(checkpoint: _CHECKPOINT) -> str:
"""Get the version of a Lightning checkpoint."""
return checkpoint["pytorch-lightning_version"]


def _set_version(checkpoint: _CHECKPOINT, version: str) -> None:
"""Set the version of a Lightning checkpoint."""
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
checkpoint["pytorch-lightning_version"] = version


def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None:
"""Set the legacy version of a Lightning checkpoint."""
checkpoint["legacy_pytorch-lightning_version"] = version
carmocca marked this conversation as resolved.
Show resolved Hide resolved


def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
"""Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
return LooseVersion(_get_version(checkpoint)) < LooseVersion(target)
Loading