Fix multithreading checkpoint loading (#17678)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: awaelchli <[email protected]>
3 people authored May 31, 2023
1 parent fd296e0 commit 1307b60
Showing 3 changed files with 46 additions and 5 deletions.
src/lightning/pytorch/utilities/migration/utils.py (4 additions, 0 deletions)

@@ -14,6 +14,7 @@
 import logging
 import os
 import sys
+import threading
 from types import ModuleType, TracebackType
 from typing import Any, Dict, List, Optional, Tuple, Type

@@ -28,6 +29,7 @@

 _log = logging.getLogger(__name__)
 _CHECKPOINT = Dict[str, Any]
+_lock = threading.Lock()


 def migrate_checkpoint(
@@ -85,6 +87,7 @@ class pl_legacy_patch:
     """

     def __enter__(self) -> "pl_legacy_patch":
+        _lock.acquire()
         # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
         legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils")
         sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module
@@ -103,6 +106,7 @@ def __exit__(
         if hasattr(pl.utilities.argparse, "_gpus_arg_default"):
             delattr(pl.utilities.argparse, "_gpus_arg_default")
         del sys.modules["lightning.pytorch.utilities.argparse_utils"]
+        _lock.release()


 def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT:
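
Why the lock: `pl_legacy_patch` temporarily registers a stand-in module in the process-wide `sys.modules` on `__enter__` and unconditionally deletes it again on `__exit__`. When several threads enter and leave the context at the same time, one thread's exit can remove the alias while another thread still relies on it, or the second `del` can fail outright; the new module-level `_lock` serializes the patch window. Below is a minimal sketch of the concurrent-loading pattern the lock makes safe; the checkpoint path and thread count are illustrative only.

    import threading

    import torch

    from lightning.pytorch.utilities.migration import pl_legacy_patch

    def _load(path: str) -> None:
        # Each thread patches sys.modules, loads the legacy checkpoint, and restores
        # sys.modules; _lock inside pl_legacy_patch keeps these windows from overlapping.
        with pl_legacy_patch():
            torch.load(path)

    # "legacy.ckpt" is a placeholder for a checkpoint written by an older release.
    threads = [threading.Thread(target=_load, args=("legacy.ckpt",)) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
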
tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py (9 additions, 5 deletions)

@@ -14,7 +14,6 @@
 import glob
 import os
 import sys
-import threading
 from unittest.mock import patch

 import pytest
@@ -26,6 +25,7 @@
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
+from tests_pytorch.helpers.threading import ThreadExceptionHandler

 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -68,18 +68,22 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
 @RunIf(sklearn=True)
 def test_legacy_ckpt_threading(tmpdir, pl_version: str):
+    PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
+    path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
+    assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
+    path_ckpt = path_ckpts[-1]
+
     def load_model():
         import torch

         from lightning.pytorch.utilities.migration import pl_legacy_patch

         with pl_legacy_patch():
-            _ = torch.load(PATH_LEGACY)
+            _ = torch.load(path_ckpt)

-    PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
     with patch("sys.path", [PATH_LEGACY] + sys.path):
-        t1 = threading.Thread(target=load_model)
-        t2 = threading.Thread(target=load_model)
+        t1 = ThreadExceptionHandler(target=load_model)
+        t2 = ThreadExceptionHandler(target=load_model)

         t1.start()
         t2.start()
tests/tests_pytorch/helpers/threading.py (new file, 33 additions)

@@ -0,0 +1,33 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from threading import Thread
+
+
+class ThreadExceptionHandler(Thread):
+    """Adopted from https://stackoverflow.com/a/67022927."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.exception = None
+
+    def run(self):
+        try:
+            super().run()
+        except Exception as e:
+            self.exception = e
+
+    def join(self):
+        super().join()
+        if self.exception:
+            raise self.exception

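A note on the new helper: a plain `threading.Thread` does not propagate an exception raised in its target back to the thread that calls `join()`, so a crash inside `load_model` would not fail the test. `ThreadExceptionHandler` stores the exception in `run()` and re-raises it from `join()`, which is why the test switches to it. A minimal usage sketch; the failing target is made up for illustration.

    from tests_pytorch.helpers.threading import ThreadExceptionHandler

    def _boom() -> None:
        raise RuntimeError("worker failed")

    t = ThreadExceptionHandler(target=_boom)
    t.start()
    t.join()  # re-raises RuntimeError("worker failed") in the joining thread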