From 946b5ba41f902af558541d2fe317e54089e8c85c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 15:59:06 +0200 Subject: [PATCH 01/18] update docs --- docs/source-pytorch/common/checkpointing.rst | 26 +++++++---- .../common/checkpointing_advanced.rst | 6 +-- .../common/checkpointing_basic.rst | 6 +-- .../common/checkpointing_expert.rst | 6 +-- .../common/checkpointing_intermediate.rst | 6 +-- .../common/checkpointing_migration.rst | 45 +++++++++++++++++++ 6 files changed, 74 insertions(+), 21 deletions(-) create mode 100644 docs/source-pytorch/common/checkpointing_migration.rst diff --git a/docs/source-pytorch/common/checkpointing.rst b/docs/source-pytorch/common/checkpointing.rst index 765d5cce6c36e..c1552fa153263 100644 --- a/docs/source-pytorch/common/checkpointing.rst +++ b/docs/source-pytorch/common/checkpointing.rst @@ -12,33 +12,41 @@ Checkpointing .. Add callout items below this line .. displayitem:: - :header: Basic + :header: Saving and loading checkpoints :description: Learn to save and load checkpoints - :col_css: col-md-3 + :col_css: col-md-4 :button_link: checkpointing_basic.html :height: 150 :tag: basic .. displayitem:: - :header: Intermediate - :description: Customize checkpointing behavior - :col_css: col-md-3 + :header: Customize checkpointing behavior + :description: Learn how to change the behavior of checkpointing + :col_css: col-md-4 :button_link: checkpointing_intermediate.html :height: 150 :tag: intermediate .. displayitem:: - :header: Advanced + :header: Upgrading checkpoints + :description: Learn how to upgrade old checkpoints to the newest Lightning version + :col_css: col-md-4 + :button_link: checkpointing_migration.html + :height: 150 + :tag: intermediate + +.. displayitem:: + :header: Cloud-based checkpoints :description: Enable cloud-based checkpointing and composable checkpoints. - :col_css: col-md-3 + :col_css: col-md-4 :button_link: checkpointing_advanced.html :height: 150 :tag: advanced .. displayitem:: - :header: Expert + :header: Distributed checkpoints :description: Customize checkpointing for custom distributed strategies and accelerators. - :col_css: col-md-3 + :col_css: col-md-4 :button_link: checkpointing_expert.html :height: 150 :tag: expert diff --git a/docs/source-pytorch/common/checkpointing_advanced.rst b/docs/source-pytorch/common/checkpointing_advanced.rst index eb4ff4dc0e607..3ef5bf6b778f1 100644 --- a/docs/source-pytorch/common/checkpointing_advanced.rst +++ b/docs/source-pytorch/common/checkpointing_advanced.rst @@ -1,8 +1,8 @@ .. _checkpointing_advanced: -######################## -Checkpointing (advanced) -######################## +################################## +Cloud-based checkpoints (advanced) +################################## ***************** diff --git a/docs/source-pytorch/common/checkpointing_basic.rst b/docs/source-pytorch/common/checkpointing_basic.rst index 8a4834096c44d..85292b0a7085d 100644 --- a/docs/source-pytorch/common/checkpointing_basic.rst +++ b/docs/source-pytorch/common/checkpointing_basic.rst @@ -2,9 +2,9 @@ .. _checkpointing_basic: -##################### -Checkpointing (basic) -##################### +###################################### +Saving and loading checkpoints (basic) +###################################### **Audience:** All users ---- diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst index f800d822aa1c5..20511d3a3c97c 100644 --- a/docs/source-pytorch/common/checkpointing_expert.rst +++ b/docs/source-pytorch/common/checkpointing_expert.rst @@ -2,9 +2,9 @@ .. _checkpointing_expert: -###################### -Checkpointing (expert) -###################### +################################ +Distributed checkpoints (expert) +################################ ********************************* Writing your own Checkpoint class diff --git a/docs/source-pytorch/common/checkpointing_intermediate.rst b/docs/source-pytorch/common/checkpointing_intermediate.rst index 7e17d8abc808f..02293effcd767 100644 --- a/docs/source-pytorch/common/checkpointing_intermediate.rst +++ b/docs/source-pytorch/common/checkpointing_intermediate.rst @@ -2,9 +2,9 @@ .. _checkpointing_intermediate: -############################ -Checkpointing (intermediate) -############################ +############################################### +Customize checkpointing behavior (intermediate) +############################################### **Audience:** Users looking to customize the checkpointing behavior ---- diff --git a/docs/source-pytorch/common/checkpointing_migration.rst b/docs/source-pytorch/common/checkpointing_migration.rst new file mode 100644 index 0000000000000..22812aa900c53 --- /dev/null +++ b/docs/source-pytorch/common/checkpointing_migration.rst @@ -0,0 +1,45 @@ +:orphan: + +.. _checkpointing_basic: + +#################################### +Upgrading checkpoints (intermediate) +#################################### +**Audience:** Users who are upgrading Lightning and their code and want to reuse their old checkpoints. + +---- + +************************************** +Resume training from an old checkpoint +************************************** + +Next to the model weights and trainer state, a Lightning checkpoint contains the version number of Lightning with which the checkpoint was saved. +When you load a checkpoint file, either by resuming training + +.. code-block:: python + + trainer = Trainer(...) + trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt") + +or by loading the state directly into your model, + +.. code-block:: python + + model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt") + +Lightning will automatically recognize that it is from an older version and migrates the internal structure so it can be loaded properly. +This is done without any action required by the user. + +---- + +************************************ +Upgrade checkpoint files permanently +************************************ + +When Lightning loads a checkpoint, it applies the version migration on-the-fly as explained above, but it does not modify your checkpoint files. +You can upgrade checkpoint files permanently with the following command: + +.. code-block:: python + + python -m lightning.pytorch.utilities.upgrade_checkpoint --file model.ckpt + From d12c1fe6da6fe95c1327d612c2749c46c2e998eb Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 17:34:03 +0200 Subject: [PATCH 02/18] upgrade checkpoint --- .../utilities/upgrade_checkpoint.py | 59 ++++++++++++++++--- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 6f4dd5ca938dd..a06a721952ddd 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import glob import logging +from pathlib import Path from shutil import copyfile +from typing import List import torch +from tqdm import tqdm from lightning_lite.utilities.types import _PATH from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint @@ -48,17 +52,58 @@ def upgrade_checkpoint(filepath: _PATH) -> None: if __name__ == "__main__": - parser = argparse.ArgumentParser( description=( - "Upgrade an old checkpoint to the current schema. This will also save a backup of the original file." + "A utility to upgrade old checkpoints to the format of the current Lightning version." + " By default, this will also save a backup of the original file." ) ) - parser.add_argument("--file", help="filepath for a checkpoint to upgrade") + parser.add_argument("path", "--path", "--file", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade") + parser.add_argument("--recursive", "-r", type=bool, action="store_true", help="If the specified path is a directory, recursively search for checkpoint files to upgrade") + parser.add_argument("--extension", "-e", type=str, default=".ckpt", help="The file extension to look for when searching for checkpoint files in a directory.") + parser.add_argument("--no-backup", type=bool, action="store_true", help="Do not save backup files. This will overwrite your existing files") args = parser.parse_args() + path = Path(args.path).absolute() + extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}" + files: List[Path] = [] + + if not path.exists(): + raise FileNotFoundError() + + if path.is_file(): + files = [path] + if path.is_dir(): + files = [Path(p) for p in glob.glob(f"{str(path)}**{extension}", recursive=True)] + if not files: + raise FileNotFoundError( + f"No files were found in {path}." + f" HINT: Try setting the `--extension` option to specify the right file extension to look for." + ) + + # existing_backup_files = [file.relative_to(path.parent) for file in files if file.suffix == ".bak"] + # existing_backup_files_str = "\n".join((str(f) for f in existing_backup_files)) + # if existing_backup_files: + # answer = input( + # "It looks like there was already an upgrade on these files. The following backup files already exist:\n\n" + # f"{existing_backup_files_str}\n\n" + # f"Are you sure you want to continue and overwrite the file(s)? The backup files are the original" + # f" checkpoints before the upgrade" + # ) + + + # if not args.no_backup: + # input("") + + log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.") + for file in files: + backup_file = file.with_suffix(".bak") + if backup_file.exists(): + # never overwrite backup files - they are the original, untouched checkpoints + continue + copyfile(file, backup_file) - log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") - copyfile(args.file, args.file + ".bak") - with pl_legacy_patch(): - upgrade_checkpoint(args.file) + log.info("Upgrading checkpionts ...") + for file in tqdm(files): + with pl_legacy_patch(): + upgrade_checkpoint(file) From 96dc2bce6ad6cf539a74f1b7b4403a3e240af0e8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 17:49:30 +0200 Subject: [PATCH 03/18] parse --- .../utilities/upgrade_checkpoint.py | 68 ++++++++++--------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index a06a721952ddd..0ba8322ddfe78 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse import glob import logging +from argparse import ArgumentParser, Namespace from pathlib import Path from shutil import copyfile from typing import List @@ -51,49 +51,28 @@ def upgrade_checkpoint(filepath: _PATH) -> None: torch.save(checkpoint, filepath) -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=( - "A utility to upgrade old checkpoints to the format of the current Lightning version." - " By default, this will also save a backup of the original file." - ) - ) - parser.add_argument("path", "--path", "--file", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade") - parser.add_argument("--recursive", "-r", type=bool, action="store_true", help="If the specified path is a directory, recursively search for checkpoint files to upgrade") - parser.add_argument("--extension", "-e", type=str, default=".ckpt", help="The file extension to look for when searching for checkpoint files in a directory.") - parser.add_argument("--no-backup", type=bool, action="store_true", help="Do not save backup files. This will overwrite your existing files") - - args = parser.parse_args() +def main(args: Namespace) -> None: path = Path(args.path).absolute() extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}" files: List[Path] = [] if not path.exists(): - raise FileNotFoundError() + log.error( + f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory" + " containing checkpoints." + ) + exit(1) if path.is_file(): files = [path] if path.is_dir(): - files = [Path(p) for p in glob.glob(f"{str(path)}**{extension}", recursive=True)] + files = [Path(p) for p in glob.glob(str(path / "**" / f"*{extension}"), recursive=True)] if not files: - raise FileNotFoundError( - f"No files were found in {path}." + log.error( + f"No checkpoint files with extension {extension} were found in {path}." f" HINT: Try setting the `--extension` option to specify the right file extension to look for." ) - - # existing_backup_files = [file.relative_to(path.parent) for file in files if file.suffix == ".bak"] - # existing_backup_files_str = "\n".join((str(f) for f in existing_backup_files)) - # if existing_backup_files: - # answer = input( - # "It looks like there was already an upgrade on these files. The following backup files already exist:\n\n" - # f"{existing_backup_files_str}\n\n" - # f"Are you sure you want to continue and overwrite the file(s)? The backup files are the original" - # f" checkpoints before the upgrade" - # ) - - - # if not args.no_backup: - # input("") + exit(1) log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.") for file in files: @@ -107,3 +86,28 @@ def upgrade_checkpoint(filepath: _PATH) -> None: for file in tqdm(files): with pl_legacy_patch(): upgrade_checkpoint(file) + + +if __name__ == "__main__": + parser = ArgumentParser( + description=( + "A utility to upgrade old checkpoints to the format of the current Lightning version." + " By default, this will also save a backup of the original file." + ) + ) + parser.add_argument("path", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade") + parser.add_argument( + "--recursive", + "-r", + action="store_true", + help="If the specified path is a directory, recursively search for checkpoint files to upgrade", + ) + parser.add_argument( + "--extension", + "-e", + type=str, + default=".ckpt", + help="The file extension to look for when searching for checkpoint files in a directory.", + ) + args = parser.parse_args() + main(args) From beb6bd3e966f66b2fa6c6346d2955eb99d9b7eab Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Oct 2022 18:01:21 +0200 Subject: [PATCH 04/18] done --- src/pytorch_lightning/utilities/upgrade_checkpoint.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 0ba8322ddfe78..4e440e8dba5fe 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -82,11 +82,14 @@ def main(args: Namespace) -> None: continue copyfile(file, backup_file) - log.info("Upgrading checkpionts ...") - for file in tqdm(files): + log.info("Upgrading checkpoints ...") + progress = tqdm if len(files) > 1 else lambda x: x + for file in progress(files): with pl_legacy_patch(): upgrade_checkpoint(file) + log.info("Done.") + if __name__ == "__main__": parser = ArgumentParser( From b079dd4790e5504352bdf11712e163ef3b3e95a0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 27 Oct 2022 13:31:56 +0200 Subject: [PATCH 05/18] docs --- .../common/checkpointing_migration.rst | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/source-pytorch/common/checkpointing_migration.rst b/docs/source-pytorch/common/checkpointing_migration.rst index 22812aa900c53..2b0ee82950402 100644 --- a/docs/source-pytorch/common/checkpointing_migration.rst +++ b/docs/source-pytorch/common/checkpointing_migration.rst @@ -37,9 +37,15 @@ Upgrade checkpoint files permanently ************************************ When Lightning loads a checkpoint, it applies the version migration on-the-fly as explained above, but it does not modify your checkpoint files. -You can upgrade checkpoint files permanently with the following command: +You can upgrade checkpoint files permanently with the following command -.. code-block:: python +.. code-block:: + + python -m lightning.pytorch.utilities.upgrade_checkpoint path/to/model.ckpt + + +or a folder with multiple files: - python -m lightning.pytorch.utilities.upgrade_checkpoint --file model.ckpt +.. code-block:: + python -m lightning.pytorch.utilities.upgrade_checkpoint /path/to/checkpoints/folder From e60bf5c64b3face7ec2a522ebb0b7e9e0b32d1bc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 18:27:57 +0100 Subject: [PATCH 06/18] fix --- .../utilities/upgrade_checkpoint.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 8e38eecb6943f..f38bc2129510b 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -23,22 +23,7 @@ from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch -log = logging.getLogger(__name__) - - -def upgrade_checkpoint(filepath: _PATH) -> None: - checkpoint = torch.load(filepath) - checkpoint["callbacks"] = checkpoint.get("callbacks") or {} - - for key, new_path in KEYS_MAPPING.items(): - if key in checkpoint: - value = checkpoint[key] - callback_type, callback_key = new_path - checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} - checkpoint["callbacks"][callback_type][callback_key] = value - del checkpoint[key] - - torch.save(checkpoint, filepath) +_log = logging.getLogger(__name__) def main(args: Namespace) -> None: @@ -47,7 +32,7 @@ def main(args: Namespace) -> None: files: List[Path] = [] if not path.exists(): - log.error( + _log.error( f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory" " containing checkpoints." ) @@ -58,13 +43,13 @@ def main(args: Namespace) -> None: if path.is_dir(): files = [Path(p) for p in glob.glob(str(path / "**" / f"*{extension}"), recursive=True)] if not files: - log.error( + _log.error( f"No checkpoint files with extension {extension} were found in {path}." f" HINT: Try setting the `--extension` option to specify the right file extension to look for." ) exit(1) - log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.") + _log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.") for file in files: backup_file = file.with_suffix(".bak") if backup_file.exists(): @@ -72,13 +57,15 @@ def main(args: Namespace) -> None: continue copyfile(file, backup_file) - log.info("Upgrading checkpoints ...") + _log.info("Upgrading checkpoints ...") progress = tqdm if len(files) > 1 else lambda x: x for file in progress(files): with pl_legacy_patch(): - upgrade_checkpoint(file) + checkpoint = torch.load(file) + migrate_checkpoint(checkpoint) + torch.save(checkpoint, file) - log.info("Done.") + _log.info("Done.") if __name__ == "__main__": From f72eb2f50e3b8a0f60a10f6cff3349fe6d567523 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 18:30:00 +0100 Subject: [PATCH 07/18] update description --- src/pytorch_lightning/utilities/upgrade_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index f38bc2129510b..24ff405ee41a7 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -72,7 +72,7 @@ def main(args: Namespace) -> None: parser = ArgumentParser( description=( "A utility to upgrade old checkpoints to the format of the current Lightning version." - " By default, this will also save a backup of the original file." + " This will also save a backup of the original files." ) ) parser.add_argument("path", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade") From c5236a1889329889c48e86debcd32dde3eab3ea7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 19:12:23 +0100 Subject: [PATCH 08/18] test --- .../utilities/upgrade_checkpoint.py | 18 ++--- .../utilities/test_upgrade_checkpoint.py | 75 ++++++++++++++++++- 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 24ff405ee41a7..11d2e0254ecc9 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -26,7 +26,7 @@ _log = logging.getLogger(__name__) -def main(args: Namespace) -> None: +def _upgrade(args: Namespace) -> None: path = Path(args.path).absolute() extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}" files: List[Path] = [] @@ -34,7 +34,7 @@ def main(args: Namespace) -> None: if not path.exists(): _log.error( f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory" - " containing checkpoints." + f" containing checkpoints ending in {extension}." ) exit(1) @@ -68,7 +68,7 @@ def main(args: Namespace) -> None: _log.info("Done.") -if __name__ == "__main__": +def main() -> None: parser = ArgumentParser( description=( "A utility to upgrade old checkpoints to the format of the current Lightning version." @@ -76,12 +76,6 @@ def main(args: Namespace) -> None: ) ) parser.add_argument("path", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade") - parser.add_argument( - "--recursive", - "-r", - action="store_true", - help="If the specified path is a directory, recursively search for checkpoint files to upgrade", - ) parser.add_argument( "--extension", "-e", @@ -90,4 +84,8 @@ def main(args: Namespace) -> None: help="The file extension to look for when searching for checkpoint files in a directory.", ) args = parser.parse_args() - main(args) + _upgrade(args) + + +if __name__ == "__main__": + main() diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index c8866829581cb..fab3e19fdcb6d 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from pathlib import Path +from unittest import mock +from unittest.mock import ANY, call + import pytest import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities.migration import migrate_checkpoint from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version - +from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main @pytest.mark.parametrize( "old_checkpoint, new_checkpoint", @@ -47,3 +52,71 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == old_checkpoint == new_checkpoint assert _get_version(updated_checkpoint) == pl.__version__ + + +def test_upgrade_checkpoint_file_missing(tmp_path, caplog): + # path to single file (missing) + file = tmp_path / "checkpoint.ckpt" + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]): + with caplog.at_level(logging.ERROR): + with pytest.raises(SystemExit): + upgrade_main() + assert f"The path {file} does not exist" in caplog.text + + caplog.clear() + + # path to non-empty directory, but no checkpoints with matching extension + file.touch() + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]): + with caplog.at_level(logging.ERROR): + with pytest.raises(SystemExit): + upgrade_main() + assert "No checkpoint files with extension .other were found" in caplog.text + + +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint") +def test_upgrade_checkpoint_single_file(migrate_mock, load_mock, save_mock, tmp_path): + file = tmp_path / "checkpoint.ckpt" + file.touch() + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]): + upgrade_main() + + load_mock.assert_called_once_with(Path(file)) + migrate_mock.assert_called_once() + save_mock.assert_called_once_with(ANY, Path(file)) + + +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load") +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint") +def test_upgrade_checkpoint_directory(migrate_mock, load_mock, save_mock, tmp_path): + top_files = [tmp_path / "top0.ckpt", tmp_path / "top1.ckpt"] + nested_files = [ + tmp_path / "subdir0" / "nested0.ckpt", + tmp_path / "subdir0" / "nested1.other", + tmp_path / "subdir1" / "nested2.ckpt", + ] + + for file in top_files + nested_files: + file.parent.mkdir(exist_ok=True, parents=True) + file.touch() + + # directory with recursion + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path)]): + upgrade_main() + + assert load_mock.call_args_list == [ + call(tmp_path / "top1.ckpt"), + call(tmp_path / "top0.ckpt"), + call(tmp_path / "subdir0" / "nested0.ckpt"), + call(tmp_path / "subdir1" / "nested2.ckpt"), + ] + assert migrate_mock.call_count == 4 + assert save_mock.call_args_list == [ + call(ANY, tmp_path / "top1.ckpt"), + call(ANY, tmp_path / "top0.ckpt"), + call(ANY, tmp_path / "subdir0" / "nested0.ckpt"), + call(ANY, tmp_path / "subdir1" / "nested2.ckpt"), + ] From b75ff6eb88cfa08dc54a8397ad69d90e794f16fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Nov 2022 18:14:36 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/utilities/test_upgrade_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index fab3e19fdcb6d..5a5ff3433d157 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -24,6 +24,7 @@ from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main + @pytest.mark.parametrize( "old_checkpoint, new_checkpoint", [ From 7da658fa89ef514d74f34f75835a2406ea021488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 2 Nov 2022 19:14:48 +0100 Subject: [PATCH 10/18] notebook --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 6d5634b794218..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d5634b7942180e6ba4a30bfbd74926d1c22f1eb From d116bb8f7eb202e41ea24286fa900c69bf6a5090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 2 Nov 2022 19:14:54 +0100 Subject: [PATCH 11/18] notebook --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..0ad097a6fec2b --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 0ad097a6fec2b2c3f8ddd5d2263e178c41d614f5 From 8a328ece859e6b91bdf091da57f71cb9acabd079 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 19:20:01 +0100 Subject: [PATCH 12/18] changelog --- src/pytorch_lightning/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 9c03c4cf6715c..b4a6415b148cf 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) -- +- Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333)) - From aaf82685aaa4aef1e9ebd9aa0011bf93e0cd3cd0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 19:25:58 +0100 Subject: [PATCH 13/18] use tqdm regardless of how many files --- src/pytorch_lightning/utilities/upgrade_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 11d2e0254ecc9..e02b89457608e 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -58,8 +58,7 @@ def _upgrade(args: Namespace) -> None: copyfile(file, backup_file) _log.info("Upgrading checkpoints ...") - progress = tqdm if len(files) > 1 else lambda x: x - for file in progress(files): + for file in tqdm(files): with pl_legacy_patch(): checkpoint = torch.load(file) migrate_checkpoint(checkpoint) From 4ba1c22e0ea3ded48308b0f36a2665768ec2d654 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 19:31:36 +0100 Subject: [PATCH 14/18] update lightnig.pytorch to pytorch_lightning --- docs/source-pytorch/common/checkpointing_migration.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/checkpointing_migration.rst b/docs/source-pytorch/common/checkpointing_migration.rst index 2b0ee82950402..5487e6f34aab0 100644 --- a/docs/source-pytorch/common/checkpointing_migration.rst +++ b/docs/source-pytorch/common/checkpointing_migration.rst @@ -41,11 +41,11 @@ You can upgrade checkpoint files permanently with the following command .. code-block:: - python -m lightning.pytorch.utilities.upgrade_checkpoint path/to/model.ckpt + python -m pytorch_lightning.utilities.upgrade_checkpoint path/to/model.ckpt or a folder with multiple files: .. code-block:: - python -m lightning.pytorch.utilities.upgrade_checkpoint /path/to/checkpoints/folder + python -m pytorch_lightning.utilities.upgrade_checkpoint /path/to/checkpoints/folder From 70e6461cf0aadc13f93cde13ae73de75cd693d7e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 20:02:09 +0100 Subject: [PATCH 15/18] fix duplicate label --- docs/source-pytorch/common/checkpointing_migration.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/checkpointing_migration.rst b/docs/source-pytorch/common/checkpointing_migration.rst index 5487e6f34aab0..2926e583a3305 100644 --- a/docs/source-pytorch/common/checkpointing_migration.rst +++ b/docs/source-pytorch/common/checkpointing_migration.rst @@ -1,6 +1,6 @@ :orphan: -.. _checkpointing_basic: +.. _checkpointing_intermediate: #################################### Upgrading checkpoints (intermediate) From 953d6fbfa0e85f1ee492a4502ca726c2854e2ca4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 20:06:46 +0100 Subject: [PATCH 16/18] fix deterministic order in file paths --- .../utilities/test_upgrade_checkpoint.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 5a5ff3433d157..8ecd1b689ffa9 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -108,16 +108,16 @@ def test_upgrade_checkpoint_directory(migrate_mock, load_mock, save_mock, tmp_pa with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path)]): upgrade_main() - assert load_mock.call_args_list == [ - call(tmp_path / "top1.ckpt"), - call(tmp_path / "top0.ckpt"), - call(tmp_path / "subdir0" / "nested0.ckpt"), - call(tmp_path / "subdir1" / "nested2.ckpt"), - ] + assert {c[0][0] for c in load_mock.call_args_list} == { + tmp_path / "top0.ckpt", + tmp_path / "top1.ckpt", + tmp_path / "subdir0" / "nested0.ckpt", + tmp_path / "subdir1" / "nested2.ckpt", + } assert migrate_mock.call_count == 4 - assert save_mock.call_args_list == [ - call(ANY, tmp_path / "top1.ckpt"), - call(ANY, tmp_path / "top0.ckpt"), - call(ANY, tmp_path / "subdir0" / "nested0.ckpt"), - call(ANY, tmp_path / "subdir1" / "nested2.ckpt"), - ] + assert {c[0][1] for c in save_mock.call_args_list} == { + tmp_path / "top0.ckpt", + tmp_path / "top1.ckpt", + tmp_path / "subdir0" / "nested0.ckpt", + tmp_path / "subdir1" / "nested2.ckpt", + } From 3a6ee0c580e052b447ee9cde64b8a325ec5558bf Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 2 Nov 2022 21:22:10 +0100 Subject: [PATCH 17/18] unused import --- tests/tests_pytorch/utilities/test_upgrade_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 8ecd1b689ffa9..2a53448f5189c 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -14,7 +14,7 @@ import logging from pathlib import Path from unittest import mock -from unittest.mock import ANY, call +from unittest.mock import ANY import pytest From 80a5205265485b916299f919121353b5b8757667 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 4 Nov 2022 22:01:41 +0100 Subject: [PATCH 18/18] fix duplicate labels --- docs/source-pytorch/common/checkpointing_intermediate.rst | 2 +- docs/source-pytorch/common/checkpointing_migration.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/checkpointing_intermediate.rst b/docs/source-pytorch/common/checkpointing_intermediate.rst index 02293effcd767..e20ad6884618c 100644 --- a/docs/source-pytorch/common/checkpointing_intermediate.rst +++ b/docs/source-pytorch/common/checkpointing_intermediate.rst @@ -1,6 +1,6 @@ :orphan: -.. _checkpointing_intermediate: +.. _checkpointing_intermediate_1: ############################################### Customize checkpointing behavior (intermediate) diff --git a/docs/source-pytorch/common/checkpointing_migration.rst b/docs/source-pytorch/common/checkpointing_migration.rst index 2926e583a3305..d04b24f4c60b9 100644 --- a/docs/source-pytorch/common/checkpointing_migration.rst +++ b/docs/source-pytorch/common/checkpointing_migration.rst @@ -1,6 +1,6 @@ :orphan: -.. _checkpointing_intermediate: +.. _checkpointing_intermediate_2: #################################### Upgrading checkpoints (intermediate)