diff --git a/docs/source-pytorch/common/checkpointing.rst b/docs/source-pytorch/common/checkpointing.rst index 765d5cce6c36e..c1552fa153263 100644 --- a/docs/source-pytorch/common/checkpointing.rst +++ b/docs/source-pytorch/common/checkpointing.rst @@ -12,33 +12,41 @@ Checkpointing .. Add callout items below this line .. displayitem:: - :header: Basic + :header: Saving and loading checkpoints :description: Learn to save and load checkpoints - :col_css: col-md-3 + :col_css: col-md-4 :button_link: checkpointing_basic.html :height: 150 :tag: basic .. displayitem:: - :header: Intermediate - :description: Customize checkpointing behavior - :col_css: col-md-3 + :header: Customize checkpointing behavior + :description: Learn how to change the behavior of checkpointing + :col_css: col-md-4 :button_link: checkpointing_intermediate.html :height: 150 :tag: intermediate .. displayitem:: - :header: Advanced + :header: Upgrading checkpoints + :description: Learn how to upgrade old checkpoints to the newest Lightning version + :col_css: col-md-4 + :button_link: checkpointing_migration.html + :height: 150 + :tag: intermediate + +.. displayitem:: + :header: Cloud-based checkpoints :description: Enable cloud-based checkpointing and composable checkpoints. - :col_css: col-md-3 + :col_css: col-md-4 :button_link: checkpointing_advanced.html :height: 150 :tag: advanced .. displayitem:: - :header: Expert + :header: Distributed checkpoints :description: Customize checkpointing for custom distributed strategies and accelerators. - :col_css: col-md-3 + :col_css: col-md-4 :button_link: checkpointing_expert.html :height: 150 :tag: expert diff --git a/docs/source-pytorch/common/checkpointing_advanced.rst b/docs/source-pytorch/common/checkpointing_advanced.rst index eb4ff4dc0e607..3ef5bf6b778f1 100644 --- a/docs/source-pytorch/common/checkpointing_advanced.rst +++ b/docs/source-pytorch/common/checkpointing_advanced.rst @@ -1,8 +1,8 @@ .. 
_checkpointing_advanced: -######################## -Checkpointing (advanced) -######################## +################################## +Cloud-based checkpoints (advanced) +################################## ***************** diff --git a/docs/source-pytorch/common/checkpointing_basic.rst b/docs/source-pytorch/common/checkpointing_basic.rst index 8a4834096c44d..85292b0a7085d 100644 --- a/docs/source-pytorch/common/checkpointing_basic.rst +++ b/docs/source-pytorch/common/checkpointing_basic.rst @@ -2,9 +2,9 @@ .. _checkpointing_basic: -##################### -Checkpointing (basic) -##################### +###################################### +Saving and loading checkpoints (basic) +###################################### **Audience:** All users ---- diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst index f800d822aa1c5..20511d3a3c97c 100644 --- a/docs/source-pytorch/common/checkpointing_expert.rst +++ b/docs/source-pytorch/common/checkpointing_expert.rst @@ -2,9 +2,9 @@ .. _checkpointing_expert: -###################### -Checkpointing (expert) -###################### +################################ +Distributed checkpoints (expert) +################################ ********************************* Writing your own Checkpoint class diff --git a/docs/source-pytorch/common/checkpointing_intermediate.rst b/docs/source-pytorch/common/checkpointing_intermediate.rst index 7e17d8abc808f..e20ad6884618c 100644 --- a/docs/source-pytorch/common/checkpointing_intermediate.rst +++ b/docs/source-pytorch/common/checkpointing_intermediate.rst @@ -1,10 +1,10 @@ :orphan: -.. _checkpointing_intermediate: +.. 
_checkpointing_intermediate_1: -############################ -Checkpointing (intermediate) -############################ +############################################### +Customize checkpointing behavior (intermediate) +############################################### **Audience:** Users looking to customize the checkpointing behavior ---- diff --git a/docs/source-pytorch/common/checkpointing_migration.rst b/docs/source-pytorch/common/checkpointing_migration.rst new file mode 100644 index 0000000000000..d04b24f4c60b9 --- /dev/null +++ b/docs/source-pytorch/common/checkpointing_migration.rst @@ -0,0 +1,51 @@ +:orphan: + +.. _checkpointing_intermediate_2: + +#################################### +Upgrading checkpoints (intermediate) +#################################### +**Audience:** Users who are upgrading Lightning and their code and want to reuse their old checkpoints. + +---- + +************************************** +Resume training from an old checkpoint +************************************** + +In addition to the model weights and trainer state, a Lightning checkpoint contains the version number of Lightning with which the checkpoint was saved. +When you load a checkpoint file, either by resuming training + +.. code-block:: python + + trainer = Trainer(...) + trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt") + +or by loading the state directly into your model, + +.. code-block:: python + + model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt") + +Lightning will automatically recognize that it is from an older version and migrate the internal structure so it can be loaded properly. +This is done without any action required by the user. + +---- + +************************************ +Upgrade checkpoint files permanently +************************************ + +When Lightning loads a checkpoint, it applies the version migration on-the-fly as explained above, but it does not modify your checkpoint files. 
+You can upgrade a single checkpoint file permanently with the following command: + +.. code-block:: + + python -m pytorch_lightning.utilities.upgrade_checkpoint path/to/model.ckpt + + +or upgrade all checkpoint files in a folder at once: + +.. code-block:: + + python -m pytorch_lightning.utilities.upgrade_checkpoint /path/to/checkpoints/folder diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 6a600a058a90d..5bd34c2864f51 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) -- +- Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333)) - diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 4bcfb4a86f5bd..e02b89457608e 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -11,30 +11,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse +import glob import logging +from argparse import ArgumentParser, Namespace +from pathlib import Path from shutil import copyfile +from typing import List import torch +from tqdm import tqdm from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch -log = logging.getLogger(__name__) +_log = logging.getLogger(__name__) -if __name__ == "__main__": - parser = argparse.ArgumentParser( +def _upgrade(args: Namespace) -> None: + path = Path(args.path).absolute() + extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}" + files: List[Path] = [] + + if not path.exists(): + _log.error( + f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory" + f" containing checkpoints ending in {extension}." + ) + raise SystemExit(1) # avoid the site-only exit() helper + + if path.is_file(): + files = [path] + if path.is_dir(): + files = [Path(p) for p in glob.glob(str(path / "**" / f"*{extension}"), recursive=True)] + if not files: + _log.error( + f"No checkpoint files with extension {extension} were found in {path}." + f" HINT: Try setting the `--extension` option to specify the right file extension to look for." + ) + raise SystemExit(1) + + _log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.") + for file in files: + backup_file = file.with_suffix(".bak") + if backup_file.exists(): + # never overwrite backup files - they are the original, untouched checkpoints + continue + copyfile(file, backup_file) + + _log.info("Upgrading checkpoints ...") + for file in tqdm(files): + with pl_legacy_patch(): + checkpoint = torch.load(file) + migrate_checkpoint(checkpoint) + torch.save(checkpoint, file) + + _log.info("Done.") + + +def main() -> None: + parser = ArgumentParser( description=( - "Upgrade an old checkpoint to the current schema. This will also save a backup of the original file." + "A utility to upgrade old checkpoints to the format of the current Lightning version." 
+ " This will also save a backup of the original files." ) ) - parser.add_argument("--file", help="filepath for a checkpoint to upgrade") - + parser.add_argument("path", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade") + parser.add_argument( + "--extension", + "-e", + type=str, + default=".ckpt", + help="The file extension to look for when searching for checkpoint files in a directory.", + ) args = parser.parse_args() + _upgrade(args) + - log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") - copyfile(args.file, args.file + ".bak") - with pl_legacy_patch(): - checkpoint = torch.load(args.file) - migrate_checkpoint(checkpoint) - torch.save(checkpoint, args.file) +if __name__ == "__main__": + main() diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index c8866829581cb..2a53448f5189c 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -11,12 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +from pathlib import Path +from unittest import mock +from unittest.mock import ANY + import pytest import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities.migration import migrate_checkpoint from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version +from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main @pytest.mark.parametrize( @@ -47,3 +53,71 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == old_checkpoint == new_checkpoint assert _get_version(updated_checkpoint) == pl.__version__ + + +def test_upgrade_checkpoint_file_missing(tmp_path, caplog): + # path to single file (missing) + file = tmp_path / "checkpoint.ckpt" + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]): + with caplog.at_level(logging.ERROR): + with pytest.raises(SystemExit): + upgrade_main() + assert f"The path {file} does not exist" in caplog.text + + caplog.clear() + + # path to non-empty directory, but no checkpoints with matching extension + file.touch() + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]): + with caplog.at_level(logging.ERROR): + with pytest.raises(SystemExit): + upgrade_main() + assert "No checkpoint files with extension .other were found" in caplog.text + + +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint") +def test_upgrade_checkpoint_single_file(migrate_mock, load_mock, save_mock, tmp_path): + file = tmp_path / "checkpoint.ckpt" + file.touch() + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]): + upgrade_main() + + 
load_mock.assert_called_once_with(Path(file)) + migrate_mock.assert_called_once() + save_mock.assert_called_once_with(ANY, Path(file)) + + +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint") +def test_upgrade_checkpoint_directory(migrate_mock, load_mock, save_mock, tmp_path): + top_files = [tmp_path / "top0.ckpt", tmp_path / "top1.ckpt"] + nested_files = [ + tmp_path / "subdir0" / "nested0.ckpt", + tmp_path / "subdir0" / "nested1.other", + tmp_path / "subdir1" / "nested2.ckpt", + ] + + for file in top_files + nested_files: + file.parent.mkdir(exist_ok=True, parents=True) + file.touch() + + # directory with recursion + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path)]): + upgrade_main() + + assert {c[0][0] for c in load_mock.call_args_list} == { + tmp_path / "top0.ckpt", + tmp_path / "top1.ckpt", + tmp_path / "subdir0" / "nested0.ckpt", + tmp_path / "subdir1" / "nested2.ckpt", + } + assert migrate_mock.call_count == 4 + assert {c[0][1] for c in save_mock.call_args_list} == { + tmp_path / "top0.ckpt", + tmp_path / "top1.ckpt", + tmp_path / "subdir0" / "nested0.ckpt", + tmp_path / "subdir1" / "nested2.ckpt", + }