Skip to content

Commit

Permalink
Introduce checkpoint migration (#15237)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
awaelchli and carmocca authored Nov 2, 2022
1 parent 6aa6423 commit 94f7d23
Show file tree
Hide file tree
Showing 15 changed files with 399 additions and 150 deletions.
4 changes: 2 additions & 2 deletions src/lightning_lite/utilities/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union
from typing import Any, List, Optional, Union

import torch
from torch.nn import Module
from typing_extensions import Self


class _DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
__jit_unused_properties__: List[str] = ["device", "dtype"]

def __init__(self) -> None:
super().__init__()
Expand Down
5 changes: 2 additions & 3 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

-

Expand All @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

-

Expand Down Expand Up @@ -57,7 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [1.8.0] - 2022-11-01


### Added

- Added support for requeueing slurm array jobs ([#15040](https://github.com/Lightning-AI/lightning/pull/15040))
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/core/mixins/hparams_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import inspect
import types
from argparse import Namespace
from typing import Any, MutableMapping, Optional, Sequence, Union
from typing import Any, List, MutableMapping, Optional, Sequence, Union

from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities.parsing import AttributeDict, save_hyperparameters


class HyperparametersMixin:

__jit_unused_properties__ = ["hparams", "hparams_initial"]
__jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]

def __init__(self) -> None:
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class LightningModule(
):
# Below is for property support of JIT
# since none of these are important when using JIT, we are going to ignore them.
__jit_unused_properties__ = (
__jit_unused_properties__: List[str] = (
[
"example_input_array",
"on_gpu",
Expand Down
9 changes: 9 additions & 0 deletions src/pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from argparse import Namespace
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
from warnings import warn

Expand All @@ -32,6 +33,7 @@
from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

Expand Down Expand Up @@ -156,6 +158,11 @@ def _load_from_checkpoint(
with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)

# convert legacy checkpoints to the new format
checkpoint = _pl_migrate_checkpoint(
checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
)

if hparams_file is not None:
extension = str(hparams_file).split(".")[-1]
if extension.lower() == "csv":
Expand All @@ -168,6 +175,7 @@ def _load_from_checkpoint(
# overwrite hparams by the given file
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

# TODO: make this a migration:
# for past checkpoint need to add the new key
checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
# override the hparams with values that were passed in
Expand Down Expand Up @@ -198,6 +206,7 @@ def _load_state(
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:

if issubclass(cls, pl.LightningModule):
# TODO: make this a migration:
# 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
Expand Down
14 changes: 2 additions & 12 deletions src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _OMEGACONF_AVAILABLE:
from omegaconf import Container
Expand Down Expand Up @@ -81,19 +81,9 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
return

rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
with pl_legacy_patch():
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
raise ValueError(
"The checkpoint you're attempting to load follows an"
" outdated schema. You can upgrade to the current schema by running"
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
" where `model.ckpt` is your checkpoint file."
)
return loaded_checkpoint
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)

def _set_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool
Expand Down
57 changes: 0 additions & 57 deletions src/pytorch_lightning/utilities/migration.py

This file was deleted.

16 changes: 16 additions & 0 deletions src/pytorch_lightning/utilities/migration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities.migration.utils import migrate_checkpoint # noqa: F401
from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401
69 changes: 69 additions & 0 deletions src/pytorch_lightning/utilities/migration/migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.
When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`.
For the Lightning developer: How to add a new migration?
1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include
version information as well as the specific commit or PR where the breaking change happened.
2. Add the function to the `_migration_index()` below. The key in the index is the version of Lightning in which the
change happened. Any checkpoint with a version greater or equal to that version will apply the given function.
Multiple migrations per version get executed in the provided list order.
3. You can test the migration on a checkpoint (backup your files first) by running:
cp model.ckpt model.ckpt.backup
python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt
"""

from typing import Any, Callable, Dict, List

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

_CHECKPOINT = Dict[str, Any]


def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
"""Migration functions returned here will get executed in the order they are listed."""
return {
"0.10.0": [_migrate_model_checkpoint_early_stopping],
}


def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""The checkpoint and early stopping keys were renamed.
Version: 0.10.0
Commit: a5d1176
"""
keys_mapping = {
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
"checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
"checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
"early_stop_callback_wait": (EarlyStopping, "wait_count"),
"early_stop_callback_patience": (EarlyStopping, "patience"),
}
checkpoint["callbacks"] = checkpoint.get("callbacks") or {}

for key, new_path in keys_mapping.items():
if key in checkpoint:
value = checkpoint[key]
callback_type, callback_key = new_path
checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
checkpoint["callbacks"][callback_type][callback_key] = value
del checkpoint[key]
return checkpoint
Loading

0 comments on commit 94f7d23

Please sign in to comment.