migration
awaelchli committed Oct 21, 2022
1 parent 3da62ff commit 277b0b8
Showing 5 changed files with 169 additions and 72 deletions.
7 changes: 6 additions & 1 deletion src/pytorch_lightning/core/saving.py
@@ -31,7 +31,7 @@
from lightning_lite.utilities.cloud_io import load as pl_load
from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

@@ -156,6 +156,9 @@ def _load_from_checkpoint(
with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)

# convert legacy checkpoints to the new format
checkpoint = migrate_checkpoint(checkpoint)

if hparams_file is not None:
extension = str(hparams_file).split(".")[-1]
if extension.lower() == "csv":
@@ -168,6 +171,7 @@
# overwrite hparams by the given file
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

# TODO: make this a migration:
# for past checkpoint need to add the new key
checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
# override the hparams with values that were passed in
@@ -197,6 +201,7 @@ def _load_state(
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:

if issubclass(cls, pl.LightningModule):
# TODO: make this a migration:
# 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
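Taken together, the new loading path in `_load_from_checkpoint` boils down to the following pattern. This is a minimal sketch for illustration only (the checkpoint path and map_location are placeholders, and `torch.load` stands in for `pl_load`):

import torch

from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

# Re-register legacy symbols so unpickling of old checkpoints does not fail.
with pl_legacy_patch():
    checkpoint = torch.load("path/to/legacy.ckpt", map_location="cpu")

# Upgrade outdated keys to the current schema; the file on disk is left untouched.
checkpoint = migrate_checkpoint(checkpoint)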
28 changes: 2 additions & 26 deletions src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -32,9 +32,8 @@
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _OMEGACONF_AVAILABLE:
from omegaconf import Container
@@ -86,13 +85,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
with pl_legacy_patch():
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
raise ValueError(
"The checkpoint you're attempting to load follows an"
" outdated schema. You can upgrade to the current schema by running"
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
" where `model.ckpt` is your checkpoint file."
)
loaded_checkpoint = migrate_checkpoint(loaded_checkpoint)
return loaded_checkpoint

def _set_ckpt_path(
@@ -348,23 +341,6 @@ def restore_loops(self) -> None:
return

fit_loop = self.trainer.fit_loop
pl_module = self.trainer.lightning_module
assert pl_module is not None

# set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
batch_loop = fit_loop.epoch_loop.batch_loop
if pl_module.automatic_optimization:
batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
"global_step"
]
else:
batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]

# set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
# it will be overwritten by the loop's state if it was also saved
fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]

assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None:
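The `global_step`/`epoch` backfill that previously lived in `restore_loops` now happens on the checkpoint dictionary itself, via the loop-progress keys introduced in `migration.py` below. A hedged sketch of the effect (version string and counter values are illustrative placeholders):

from pytorch_lightning.utilities.migration import migrate_checkpoint

# A stripped-down pre-1.6 checkpoint without progress-tracking state (illustrative values).
old = {"pytorch-lightning_version": "1.5.0", "global_step": 100, "epoch": 3, "state_dict": {}}
new = migrate_checkpoint(old)

fit_loop = new["loops"]["fit_loop"]
optim = fit_loop["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
assert optim["optimizer"]["step"]["total"]["completed"] == 100
assert fit_loop["epoch_progress"]["current"]["completed"] == 3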
159 changes: 150 additions & 9 deletions src/pytorch_lightning/utilities/migration.py
@@ -11,24 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.
When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
see :func:`migrate_checkpoint`.
"""

import sys
import threading
from distutils.version import LooseVersion
from types import ModuleType, TracebackType
from typing import Any, Dict, Optional, Type

import pytorch_lightning as pl
import pytorch_lightning.utilities.argparse

# Create a global lock to ensure no race condition with deleting sys modules
_lock = threading.Lock()
_CHECKPOINT = Dict[str, Any]


class pl_legacy_patch:
"""Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
unpickling old checkpoints. The following patches apply.
1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898
version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
but still needs to be available for import for legacy checkpoints.
@@ -38,20 +43,156 @@ class pl_legacy_patch:
torch.load("path/to/legacy/checkpoint.ckpt")
"""

def __enter__(self) -> None:
_lock.acquire()
def __enter__(self) -> "pl_legacy_patch":
# `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module

# `_gpus_arg_default` used to be imported from these locations
legacy_argparse_module._gpus_arg_default = lambda x: x
pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
return self

def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
) -> None:
if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
del sys.modules["pytorch_lightning.utilities.argparse_utils"]
_lock.release()


def get_version(checkpoint: _CHECKPOINT) -> str:
"""Get the version of a Lightning checkpoint."""
return checkpoint["pytorch-lightning_version"]


def set_version(checkpoint: _CHECKPOINT, version: str) -> None:
"""Set the version of a Lightning checkpoint."""
checkpoint["pytorch-lightning_version"] = version


def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
"""Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
return LooseVersion(get_version(checkpoint)) < LooseVersion(target)


def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Applies all migrations below in order."""
if should_upgrade(checkpoint, "0.10.0"):
_migrate_model_checkpoint_early_stopping(checkpoint)
if should_upgrade(checkpoint, "1.6.0"):
_migrate_loop_global_step_to_progress_tracking(checkpoint)
_migrate_loop_current_epoch_to_progress_tracking(checkpoint)

set_version(checkpoint, pl.__version__)

# TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert
# checkpoints permanently
return checkpoint


def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""The checkpoint and early stopping keys were renamed.
Version: 0.10.0
Commit:
"""
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

keys_mapping = {
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
"checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
"checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
"early_stop_callback_wait": (EarlyStopping, "wait_count"),
"early_stop_callback_patience": (EarlyStopping, "patience"),
}
checkpoint["callbacks"] = checkpoint.get("callbacks") or {}

for key, new_path in keys_mapping.items():
if key in checkpoint:
value = checkpoint[key]
callback_type, callback_key = new_path
checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
checkpoint["callbacks"][callback_type][callback_key] = value
del checkpoint[key]
return checkpoint


def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
It will be overwritten by the loop's state if it was also saved.
Version: 1.6.0
Commit:
"""
global_step = checkpoint["global_step"]
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
# for automatic optimization
optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
# for manual optimization
optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
optim_step_progress["total"]["completed"] = global_step
return checkpoint


def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
It will be overwritten by the loop's state if it was also saved.
Version: 1.6.0
Commit:
"""
epoch = checkpoint["epoch"]
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch


_FIT_LOOP_INITIAL_STATE_1_6_0 = {
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"current": {"completed": 0, "ready": 0},
"total": {"completed": 0, "ready": 0},
},
"epoch_loop.batch_loop.manual_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
"zero_grad": {
"current": {"completed": 0, "ready": 0, "started": 0},
"total": {"completed": 0, "ready": 0, "started": 0},
},
},
"optimizer_position": 0,
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"is_last_batch": False,
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
"epoch_loop.state_dict": {"_batches_that_stepped": 0},
"epoch_loop.val_loop.dataloader_progress": {
"current": {"completed": 0, "ready": 0},
"total": {"completed": 0, "ready": 0},
},
"epoch_loop.val_loop.epoch_loop.batch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"is_last_batch": False,
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"epoch_loop.val_loop.epoch_loop.state_dict": {},
"epoch_loop.val_loop.state_dict": {},
"epoch_progress": {
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
},
"state_dict": {},
}
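The version helpers make the upgrade checks explicit. A small sketch of how they compose, assuming `should_upgrade` is importable alongside the other helpers shown in this module (checkpoint contents are placeholders):

import pytorch_lightning as pl

from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, should_upgrade

# Illustrative 0.9-era checkpoint; values are placeholders.
checkpoint = {"pytorch-lightning_version": "0.9.0", "global_step": 0, "epoch": 0}
assert should_upgrade(checkpoint, "0.10.0")   # the callback-key rename applies
assert should_upgrade(checkpoint, "1.6.0")    # so does the progress-tracking migration

migrate_checkpoint(checkpoint)                # applies both and stamps the current version
assert get_version(checkpoint) == pl.__version__
assert not should_upgrade(checkpoint, "1.6.0")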
33 changes: 5 additions & 28 deletions src/pytorch_lightning/utilities/upgrade_checkpoint.py
@@ -17,36 +17,11 @@

import torch

from lightning_lite.utilities.types import _PATH
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.migration import pl_legacy_patch

KEYS_MAPPING = {
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
"checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
"checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
"early_stop_callback_wait": (EarlyStopping, "wait_count"),
"early_stop_callback_patience": (EarlyStopping, "patience"),
}
from pytorch_lightning.utilities.migration.base import pl_legacy_patch
from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint

log = logging.getLogger(__name__)


def upgrade_checkpoint(filepath: _PATH) -> None:
checkpoint = torch.load(filepath)
checkpoint["callbacks"] = checkpoint.get("callbacks") or {}

for key, new_path in KEYS_MAPPING.items():
if key in checkpoint:
value = checkpoint[key]
callback_type, callback_key = new_path
checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
checkpoint["callbacks"][callback_type][callback_key] = value
del checkpoint[key]

torch.save(checkpoint, filepath)


if __name__ == "__main__":

parser = argparse.ArgumentParser(
@@ -61,4 +36,6 @@ def upgrade_checkpoint(filepath: _PATH) -> None:
log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
copyfile(args.file, args.file + ".bak")
with pl_legacy_patch():
upgrade_checkpoint(args.file)
checkpoint = torch.load(args.file)
migrate_checkpoint(checkpoint)
torch.save(checkpoint, args.file)
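The script body now reduces to the same migration call. A minimal programmatic equivalent for permanently upgrading a file on disk (the path is a placeholder; the import paths follow the rest of the commit, and a backup is written first, as the script does):

import shutil

import torch

from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

path = "model.ckpt"                       # placeholder
shutil.copyfile(path, path + ".bak")      # keep a backup before overwriting, like the script
with pl_legacy_patch():
    checkpoint = torch.load(path)
migrate_checkpoint(checkpoint)            # mutates and returns the same dict
torch.save(checkpoint, path)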
14 changes: 6 additions & 8 deletions tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, set_version


@pytest.mark.parametrize(
@@ -42,8 +40,8 @@
],
)
def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
filepath = os.path.join(tmpdir, "model.ckpt")
torch.save(old_checkpoint, filepath)
upgrade_checkpoint(filepath)
updated_checkpoint = torch.load(filepath)
set_version(old_checkpoint, "0.9.0")
set_version(new_checkpoint, pl.__version__)
updated_checkpoint = migrate_checkpoint(old_checkpoint)
assert updated_checkpoint == new_checkpoint
assert get_version(updated_checkpoint) == pl.__version__
