From 277b0b811fb1419d6c06e7953941d6f6076eaf6d Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Fri, 21 Oct 2022 13:44:35 +0200
Subject: [PATCH] migration

---
 src/pytorch_lightning/core/saving.py          |   7 +-
 .../connectors/checkpoint_connector.py        |  28 +--
 src/pytorch_lightning/utilities/migration.py  | 159 +++++++++++++++++-
 .../utilities/upgrade_checkpoint.py           |  33 +---
 .../utilities/test_upgrade_checkpoint.py      |  14 +-
 5 files changed, 169 insertions(+), 72 deletions(-)

diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py
index 46f1663ed705c..247d864438d87 100644
--- a/src/pytorch_lightning/core/saving.py
+++ b/src/pytorch_lightning/core/saving.py
@@ -31,7 +31,7 @@
 from lightning_lite.utilities.cloud_io import load as pl_load
 from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
-from pytorch_lightning.utilities.migration import pl_legacy_patch
+from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
 from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
@@ -156,6 +156,9 @@ def _load_from_checkpoint(
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)
 
+    # convert legacy checkpoints to the new format
+    checkpoint = migrate_checkpoint(checkpoint)
+
     if hparams_file is not None:
         extension = str(hparams_file).split(".")[-1]
         if extension.lower() == "csv":
@@ -168,6 +171,7 @@
         # overwrite hparams by the given file
         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
 
+    # TODO: make this a migration:
     # for past checkpoint need to add the new key
     checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
     # override the hparams with values that were passed in
@@ -197,6 +201,7 @@ def _load_state(
 
     if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
         if issubclass(cls, pl.LightningModule):
+            # TODO: make this a migration:
             # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
             for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                 cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 50480e769b4ce..01d5a6a7e14ed 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -32,9 +32,8 @@
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.migration import pl_legacy_patch
+from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container
@@ -86,13 +85,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
 
     def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         with pl_legacy_patch():
             loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
-        if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
-            raise ValueError(
-                "The checkpoint you're attempting to load follows an"
-                " outdated schema. You can upgrade to the current schema by running"
-                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
-                " where `model.ckpt` is your checkpoint file."
-            )
+        loaded_checkpoint = migrate_checkpoint(loaded_checkpoint)
         return loaded_checkpoint
 
@@ -348,23 +341,6 @@ def restore_loops(self) -> None:
             return
 
         fit_loop = self.trainer.fit_loop
-        pl_module = self.trainer.lightning_module
-        assert pl_module is not None
-
-        # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        batch_loop = fit_loop.epoch_loop.batch_loop
-        if pl_module.automatic_optimization:
-            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
-                "global_step"
-            ]
-        else:
-            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
-
-        # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
-
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
         if state_dict is not None:
diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration.py
index ed71f25a571f7..7f400fc56efa4 100644
--- a/src/pytorch_lightning/utilities/migration.py
+++ b/src/pytorch_lightning/utilities/migration.py
@@ -11,16 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
+"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.
+
+When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
+see :func:`migrate_checkpoint`.
+"""
 
 import sys
-import threading
+from distutils.version import LooseVersion
 from types import ModuleType, TracebackType
+from typing import Any, Dict, Optional, Type
 
+import pytorch_lightning as pl
 import pytorch_lightning.utilities.argparse
 
-# Create a global lock to ensure no race condition with deleting sys modules
-_lock = threading.Lock()
+_CHECKPOINT = Dict[str, Any]
 
 
 class pl_legacy_patch:
@@ -28,7 +33,7 @@ class pl_legacy_patch:
     unpickling old checkpoints. The following patches apply.
 
     1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
-       version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898
+       version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
     2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
        but still needs to be available for import for legacy checkpoints.
 
@@ -38,8 +43,7 @@ class pl_legacy_patch:
         torch.load("path/to/legacy/checkpoint.ckpt")
     """
 
-    def __enter__(self) -> None:
-        _lock.acquire()
+    def __enter__(self) -> "pl_legacy_patch":
         # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
         legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
         sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module
@@ -47,11 +51,148 @@ def __enter__(self) -> None:
         # `_gpus_arg_default` used to be imported from these locations
         legacy_argparse_module._gpus_arg_default = lambda x: x
         pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
+        return self
 
     def __exit__(
-        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
     ) -> None:
         if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
             delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
         del sys.modules["pytorch_lightning.utilities.argparse_utils"]
-        _lock.release()
+
+
+def get_version(checkpoint: _CHECKPOINT) -> str:
+    """Get the version of a Lightning checkpoint."""
+    return checkpoint["pytorch-lightning_version"]
+
+
+def set_version(checkpoint: _CHECKPOINT, version: str) -> None:
+    """Set the version of a Lightning checkpoint."""
+    checkpoint["pytorch-lightning_version"] = version
+
+
+def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
+    """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
+    return LooseVersion(get_version(checkpoint)) < LooseVersion(target)
+
+
+def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Applies all migrations below in order."""
+    if should_upgrade(checkpoint, "0.10.0"):
+        _migrate_model_checkpoint_early_stopping(checkpoint)
+    if should_upgrade(checkpoint, "1.6.0"):
+        _migrate_loop_global_step_to_progress_tracking(checkpoint)
+        _migrate_loop_current_epoch_to_progress_tracking(checkpoint)
+
+    set_version(checkpoint, pl.__version__)
+
+    # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert
+    #  checkpoints permanently
+    return checkpoint
+
+
+def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """The checkpoint and early stopping keys were renamed.
+
+    Version: 0.10.0
+    Commit:
+    """
+    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+
+    keys_mapping = {
+        "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
+        "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
+        "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
+        "early_stop_callback_wait": (EarlyStopping, "wait_count"),
+        "early_stop_callback_patience": (EarlyStopping, "patience"),
+    }
+    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
+
+    for key, new_path in keys_mapping.items():
+        if key in checkpoint:
+            value = checkpoint[key]
+            callback_type, callback_key = new_path
+            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
+            checkpoint["callbacks"][callback_type][callback_key] = value
+            del checkpoint[key]
+    return checkpoint
+
+
+def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
+
+    It will be overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit:
+    """
+    global_step = checkpoint["global_step"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    # for automatic optimization
+    optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
+    optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
+    # for manual optimization
+    optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
+    optim_step_progress["total"]["completed"] = global_step
+    return checkpoint
+
+
+def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
+
+    It will be overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit:
+    """
+    epoch = checkpoint["epoch"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
+    return checkpoint
+
+
+_FIT_LOOP_INITIAL_STATE_1_6_0 = {
+    "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.batch_loop.manual_loop.state_dict": {},
+    "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
+        "optimizer": {
+            "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+            "zero_grad": {
+                "current": {"completed": 0, "ready": 0, "started": 0},
+                "total": {"completed": 0, "ready": 0, "started": 0},
+            },
+        },
+        "optimizer_position": 0,
+    },
+    "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
+    "epoch_loop.batch_loop.state_dict": {},
+    "epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+    "epoch_loop.state_dict": {"_batches_that_stepped": 0},
+    "epoch_loop.val_loop.dataloader_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.state_dict": {},
+    "epoch_loop.val_loop.state_dict": {},
+    "epoch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "state_dict": {},
+}
diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py
index 6f4dd5ca938dd..46038701b1e52 100644
--- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py
+++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py
@@ -17,36 +17,11 @@
 
 import torch
 
-from lightning_lite.utilities.types import _PATH
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.migration import pl_legacy_patch
-
-KEYS_MAPPING = {
-    "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
-    "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
-    "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
-    "early_stop_callback_wait": (EarlyStopping, "wait_count"),
-    "early_stop_callback_patience": (EarlyStopping, "patience"),
-}
+from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
 
 log = logging.getLogger(__name__)
 
-
-def upgrade_checkpoint(filepath: _PATH) -> None:
-    checkpoint = torch.load(filepath)
-    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
-
-    for key, new_path in KEYS_MAPPING.items():
-        if key in checkpoint:
-            value = checkpoint[key]
-            callback_type, callback_key = new_path
-            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
-            checkpoint["callbacks"][callback_type][callback_key] = value
-            del checkpoint[key]
-
-    torch.save(checkpoint, filepath)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -61,4 +36,6 @@ def upgrade_checkpoint(filepath: _PATH) -> None:
     log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
     copyfile(args.file, args.file + ".bak")
     with pl_legacy_patch():
-        upgrade_checkpoint(args.file)
+        checkpoint = torch.load(args.file)
+        migrate_checkpoint(checkpoint)
+        torch.save(checkpoint, args.file)
diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
index a58bdb5721bc7..c01fcf7eb249d 100644
--- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
+++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -11,13 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-
 import pytest
-import torch
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
+from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, set_version
 
 
 @pytest.mark.parametrize(
@@ -42,8 +40,8 @@
     ],
 )
 def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
-    filepath = os.path.join(tmpdir, "model.ckpt")
-    torch.save(old_checkpoint, filepath)
-    upgrade_checkpoint(filepath)
-    updated_checkpoint = torch.load(filepath)
+    set_version(old_checkpoint, "0.9.0")
+    set_version(new_checkpoint, pl.__version__)
+    updated_checkpoint = migrate_checkpoint(old_checkpoint)
     assert updated_checkpoint == new_checkpoint
+    assert get_version(updated_checkpoint) == pl.__version__
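
Usage note: with this patch, `migrate_checkpoint` runs automatically whenever a checkpoint is loaded (see the `saving.py` and `checkpoint_connector.py` hunks above), and the `upgrade_checkpoint` script reuses it to rewrite files on disk. A minimal sketch of the equivalent in-code flow, assuming a legacy file named `model.ckpt` (the path is illustrative only):

    import torch

    from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

    # Load under the legacy patch so modules pickled by old Lightning versions still resolve.
    with pl_legacy_patch():
        checkpoint = torch.load("model.ckpt")

    # Upgrade the in-memory dict to the current schema, then persist it permanently.
    migrate_checkpoint(checkpoint)
    torch.save(checkpoint, "model.ckpt")

This mirrors the `__main__` block of `upgrade_checkpoint.py` after the patch; running `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt` does the same and additionally keeps a `.bak` backup of the original file.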
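The version gate that decides whether a migration applies can also be exercised directly through the helpers added in `migration.py`. A small sketch with a made-up checkpoint dict (the `"0.9.0"` string matches the legacy version pinned by the updated test):

    from pytorch_lightning.utilities.migration import get_version, set_version, should_upgrade

    ckpt = {"pytorch-lightning_version": "0.9.0"}  # hypothetical minimal checkpoint dict
    assert get_version(ckpt) == "0.9.0"
    assert should_upgrade(ckpt, "1.6.0")      # older than the 1.6.0 loop-progress migrations
    set_version(ckpt, "1.6.0")                # migrate_checkpoint() stamps pl.__version__ the same way
    assert not should_upgrade(ckpt, "1.6.0")  # an equal version no longer qualifies

`should_upgrade` compares `LooseVersion` objects, so the check is a real version comparison rather than a plain string comparison.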