Improve the checkpoint upgrade utility script #15333

Merged · 23 commits · Nov 4, 2022

Changes from 15 commits
26 changes: 17 additions & 9 deletions docs/source-pytorch/common/checkpointing.rst
@@ -12,33 +12,41 @@ Checkpointing
.. Add callout items below this line

.. displayitem::
   :header: Basic
   :header: Saving and loading checkpoints
   :description: Learn to save and load checkpoints
   :col_css: col-md-3
   :col_css: col-md-4
   :button_link: checkpointing_basic.html
   :height: 150
   :tag: basic

.. displayitem::
   :header: Intermediate
   :description: Customize checkpointing behavior
   :col_css: col-md-3
   :header: Customize checkpointing behavior
   :description: Learn how to change the behavior of checkpointing
   :col_css: col-md-4
   :button_link: checkpointing_intermediate.html
   :height: 150
   :tag: intermediate

.. displayitem::
   :header: Advanced
   :header: Upgrading checkpoints
   :description: Learn how to upgrade old checkpoints to the newest Lightning version
   :col_css: col-md-4
   :button_link: checkpointing_migration.html
   :height: 150
   :tag: intermediate

.. displayitem::
   :header: Cloud-based checkpoints
   :description: Enable cloud-based checkpointing and composable checkpoints.
   :col_css: col-md-3
   :col_css: col-md-4
   :button_link: checkpointing_advanced.html
   :height: 150
   :tag: advanced

.. displayitem::
   :header: Expert
   :header: Distributed checkpoints
   :description: Customize checkpointing for custom distributed strategies and accelerators.
   :col_css: col-md-3
   :col_css: col-md-4
   :button_link: checkpointing_expert.html
   :height: 150
   :tag: expert
6 changes: 3 additions & 3 deletions docs/source-pytorch/common/checkpointing_advanced.rst
@@ -1,8 +1,8 @@
.. _checkpointing_advanced:

########################
Checkpointing (advanced)
########################
##################################
Cloud-based checkpoints (advanced)
##################################


*****************
6 changes: 3 additions & 3 deletions docs/source-pytorch/common/checkpointing_basic.rst
@@ -2,9 +2,9 @@

.. _checkpointing_basic:

#####################
Checkpointing (basic)
#####################
######################################
Saving and loading checkpoints (basic)
######################################
**Audience:** All users

----
6 changes: 3 additions & 3 deletions docs/source-pytorch/common/checkpointing_expert.rst
@@ -2,9 +2,9 @@

.. _checkpointing_expert:

######################
Checkpointing (expert)
######################
################################
Distributed checkpoints (expert)
################################

*********************************
Writing your own Checkpoint class
6 changes: 3 additions & 3 deletions docs/source-pytorch/common/checkpointing_intermediate.rst
@@ -2,9 +2,9 @@

.. _checkpointing_intermediate:

############################
Checkpointing (intermediate)
############################
###############################################
Customize checkpointing behavior (intermediate)
###############################################
**Audience:** Users looking to customize the checkpointing behavior

----
51 changes: 51 additions & 0 deletions docs/source-pytorch/common/checkpointing_migration.rst
@@ -0,0 +1,51 @@
:orphan:

.. _checkpointing_migration:

####################################
Upgrading checkpoints (intermediate)
####################################
**Audience:** Users who are upgrading Lightning and their code and want to reuse their old checkpoints.

----

**************************************
Resume training from an old checkpoint
**************************************

In addition to the model weights and trainer state, a Lightning checkpoint contains the version of Lightning with which the checkpoint was saved.
When you load a checkpoint file, either by resuming training

.. code-block:: python

    trainer = Trainer(...)
    trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")

or by loading the state directly into your model,

.. code-block:: python

    model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt")

Lightning automatically recognizes that the checkpoint comes from an older version and migrates its internal structure on the fly so it can be loaded properly.
No action is required from the user.
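
Under the hood, this automatic upgrade relies on the same migration utilities as the command-line script described below. A minimal sketch of the equivalent manual steps, assuming the ``pytorch_lightning.utilities.migration`` import path used by that script:

.. code-block:: python

    import torch

    from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

    # pl_legacy_patch keeps legacy import paths resolvable while unpickling old checkpoints
    with pl_legacy_patch():
        checkpoint = torch.load("path/to/checkpoint.ckpt")

    # rewrite the loaded checkpoint dict in place to the schema of the installed Lightning version
    migrate_checkpoint(checkpoint)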

----

************************************
Upgrade checkpoint files permanently
************************************

When Lightning loads a checkpoint, it applies the version migration on the fly as explained above, but it does not modify your checkpoint files on disk.
You can upgrade checkpoint files permanently with the following command:

.. code-block::

    python -m lightning.pytorch.utilities.upgrade_checkpoint path/to/model.ckpt


or a folder with multiple files:

.. code-block::

    python -m lightning.pytorch.utilities.upgrade_checkpoint /path/to/checkpoints/folder
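
The script looks for files ending in ``.ckpt`` by default. If your checkpoints use a different file extension, pass the ``--extension`` option added in this PR; for example, assuming files that end in ``.pt``:

.. code-block::

    python -m lightning.pytorch.utilities.upgrade_checkpoint /path/to/checkpoints/folder --extension .pt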
2 changes: 1 addition & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

-
- Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333))

-

77 changes: 64 additions & 13 deletions src/pytorch_lightning/utilities/upgrade_checkpoint.py
@@ -11,30 +11,81 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path
from shutil import copyfile
from typing import List

import torch
from tqdm import tqdm

from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

log = logging.getLogger(__name__)
_log = logging.getLogger(__name__)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
def _upgrade(args: Namespace) -> None:
    path = Path(args.path).absolute()
    extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}"
    files: List[Path] = []

    if not path.exists():
        _log.error(
            f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory"
            f" containing checkpoints ending in {extension}."
        )
        exit(1)

    if path.is_file():
        files = [path]
    if path.is_dir():
        files = [Path(p) for p in glob.glob(str(path / "**" / f"*{extension}"), recursive=True)]
    if not files:
        _log.error(
            f"No checkpoint files with extension {extension} were found in {path}."
            f" HINT: Try setting the `--extension` option to specify the right file extension to look for."
        )
        exit(1)

    _log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.")
    for file in files:
        backup_file = file.with_suffix(".bak")
        if backup_file.exists():
            # never overwrite backup files - they are the original, untouched checkpoints
            continue
        copyfile(file, backup_file)

    _log.info("Upgrading checkpoints ...")
    progress = tqdm if len(files) > 1 else lambda x: x
    for file in progress(files):
        with pl_legacy_patch():
            checkpoint = torch.load(file)
        migrate_checkpoint(checkpoint)
        torch.save(checkpoint, file)

    _log.info("Done.")


def main() -> None:
    parser = ArgumentParser(
        description=(
            "Upgrade an old checkpoint to the current schema. This will also save a backup of the original file."
            "A utility to upgrade old checkpoints to the format of the current Lightning version."
            " This will also save a backup of the original files."
        )
    )
    parser.add_argument("--file", help="filepath for a checkpoint to upgrade")

    parser.add_argument("path", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade")
    parser.add_argument(
        "--extension",
        "-e",
        type=str,
        default=".ckpt",
        help="The file extension to look for when searching for checkpoint files in a directory.",
    )
    args = parser.parse_args()
    _upgrade(args)


    log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
    copyfile(args.file, args.file + ".bak")
    with pl_legacy_patch():
        checkpoint = torch.load(args.file)
    migrate_checkpoint(checkpoint)
    torch.save(checkpoint, args.file)
if __name__ == "__main__":
    main()
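
For completeness, a sketch of driving the new entry point without the command line, assuming the private `_upgrade` helper and its `Namespace(path=..., extension=...)` signature stay exactly as shown in the diff above (the documented CLI invocation remains the supported interface):

    from argparse import Namespace

    from pytorch_lightning.utilities.upgrade_checkpoint import _upgrade

    # Sketch only: backs up and upgrades every *.ckpt file found under the given folder,
    # mirroring what `python -m pytorch_lightning.utilities.upgrade_checkpoint <path>` does.
    _upgrade(Namespace(path="/path/to/checkpoints/folder", extension=".ckpt"))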
74 changes: 74 additions & 0 deletions tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, call

import pytest

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.migration import migrate_checkpoint
from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main


@pytest.mark.parametrize(
@@ -47,3 +53,71 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
    assert updated_checkpoint == old_checkpoint == new_checkpoint
    assert _get_version(updated_checkpoint) == pl.__version__


def test_upgrade_checkpoint_file_missing(tmp_path, caplog):
    # path to single file (missing)
    file = tmp_path / "checkpoint.ckpt"
    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]):
        with caplog.at_level(logging.ERROR):
            with pytest.raises(SystemExit):
                upgrade_main()
    assert f"The path {file} does not exist" in caplog.text

    caplog.clear()

    # path to non-empty directory, but no checkpoints with matching extension
    file.touch()
    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]):
        with caplog.at_level(logging.ERROR):
            with pytest.raises(SystemExit):
                upgrade_main()
    assert "No checkpoint files with extension .other were found" in caplog.text


@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save")
@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load")
@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint")
def test_upgrade_checkpoint_single_file(migrate_mock, load_mock, save_mock, tmp_path):
    file = tmp_path / "checkpoint.ckpt"
    file.touch()
    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]):
        upgrade_main()

    load_mock.assert_called_once_with(Path(file))
    migrate_mock.assert_called_once()
    save_mock.assert_called_once_with(ANY, Path(file))


@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save")
@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load")
@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint")
def test_upgrade_checkpoint_directory(migrate_mock, load_mock, save_mock, tmp_path):
    top_files = [tmp_path / "top0.ckpt", tmp_path / "top1.ckpt"]
    nested_files = [
        tmp_path / "subdir0" / "nested0.ckpt",
        tmp_path / "subdir0" / "nested1.other",
        tmp_path / "subdir1" / "nested2.ckpt",
    ]

    for file in top_files + nested_files:
        file.parent.mkdir(exist_ok=True, parents=True)
        file.touch()

    # directory with recursion
    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path)]):
        upgrade_main()

    assert load_mock.call_args_list == [
        call(tmp_path / "top1.ckpt"),
        call(tmp_path / "top0.ckpt"),
        call(tmp_path / "subdir0" / "nested0.ckpt"),
        call(tmp_path / "subdir1" / "nested2.ckpt"),
    ]
    assert migrate_mock.call_count == 4
    assert save_mock.call_args_list == [
        call(ANY, tmp_path / "top1.ckpt"),
        call(ANY, tmp_path / "top0.ckpt"),
        call(ANY, tmp_path / "subdir0" / "nested0.ckpt"),
        call(ANY, tmp_path / "subdir1" / "nested2.ckpt"),
    ]