From 81c808e97b68899511f3b2fe528ca3f4a6444856 Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Sat, 20 May 2023 15:29:48 +0900 Subject: [PATCH 01/13] Avoid deleting old argparse_utils module multiple times --- src/lightning/pytorch/utilities/migration/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index f24ad3d035b53..e5da3178f1387 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -102,7 +102,8 @@ def __exit__( ) -> None: if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") - del sys.modules["lightning.pytorch.utilities.argparse_utils"] + if "lightning.pytorch.utilities.argparse_utils" in sys.modules: + del sys.modules["lightning.pytorch.utilities.argparse_utils"] def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: From 9773fe862922817c949dfbe0e13f751b3d01f3e5 Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Tue, 23 May 2023 00:23:43 +0900 Subject: [PATCH 02/13] Fix test test_legacy_ckpt_threading `test_legacy_ckpt_threading` test is actually failing, but it always passes because exceptions from threads are not handled correctly. 
--- .../checkpointing/test_legacy_checkpoints.py | 15 ++++++--- tests/tests_pytorch/helpers/threading.py | 33 +++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) create mode 100644 tests/tests_pytorch/helpers/threading.py diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index a171e92e8bd59..3df6ad0dac583 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -26,6 +26,7 @@ from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel +from tests_pytorch.helpers.threading import ThreadExceptionHandler LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints") CHECKPOINT_EXTENSION = ".ckpt" @@ -68,18 +69,22 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) def test_legacy_ckpt_threading(tmpdir, pl_version: str): + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) + assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' + path_ckpt = path_ckpts[-1] + def load_model(): import torch from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(PATH_LEGACY) - - PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + _ = torch.load(path_ckpt) + with patch("sys.path", [PATH_LEGACY] + sys.path): - t1 = threading.Thread(target=load_model) - t2 = threading.Thread(target=load_model) + t1 = ThreadExceptionHandler(target=load_model) + t2 = ThreadExceptionHandler(target=load_model) t1.start() t2.start() diff --git a/tests/tests_pytorch/helpers/threading.py 
b/tests/tests_pytorch/helpers/threading.py new file mode 100644 index 0000000000000..f943c50d01b94 --- /dev/null +++ b/tests/tests_pytorch/helpers/threading.py @@ -0,0 +1,33 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from threading import Thread + +class ThreadExceptionHandler(Thread): + """ + Adopted from https://stackoverflow.com/a/67022927 + """ + def __init__(self, target, args=(), kwargs={}): + Thread.__init__(self, target=target, args=args, kwargs=kwargs) + self.exception = None + + def run(self): + try: + self._target(*self._args, **self._kwargs) + except Exception as e: + self.exception = e + + def join(self): + super().join() + if self.exception: + raise self.exception \ No newline at end of file From 093bc11fffc9b1935b9447b394ceb407cbcd7c88 Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Tue, 23 May 2023 00:29:29 +0900 Subject: [PATCH 03/13] Add `load_legacy_checkpoint` function --- .../pytorch/utilities/migration/pickle.py | 44 +++++++++++++++++++ .../pytorch/utilities/migration/utils.py | 10 ++++- .../checkpointing/test_legacy_checkpoints.py | 24 ++++++++++ tests/tests_pytorch/helpers/threading.py | 10 ++--- 4 files changed, 81 insertions(+), 7 deletions(-) create mode 100644 src/lightning/pytorch/utilities/migration/pickle.py diff --git a/src/lightning/pytorch/utilities/migration/pickle.py b/src/lightning/pytorch/utilities/migration/pickle.py new file mode 100644 index 0000000000000..97657aa4f0557 --- /dev/null 
+++ b/src/lightning/pytorch/utilities/migration/pickle.py @@ -0,0 +1,44 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pickle +from types import TracebackType +from typing import Any, Optional, Type + +import lightning.pytorch as pl + + +class LegacyModulePatcher: + def __enter__(self) -> "LegacyModulePatcher": + pl.utilities.argparse._gpus_arg_default = lambda x: x + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ) -> None: + if hasattr(pl.utilities.argparse, "_gpus_arg_default"): + delattr(pl.utilities.argparse, "_gpus_arg_default") + + +class Unpickler(pickle.Unpickler): + def load(self) -> Any: + with LegacyModulePatcher(): + return super().load() + + def find_class(self, module: str, name: str): + if module == "lightning.pytorch.utilities.argparse_utils": + module = "lightning.pytorch.utilities.argparse" + return super().find_class(module, name) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index e5da3178f1387..6bb05658e99a2 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -17,12 +17,15 @@ from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type +import torch from packaging.version import Version +from 
torch.serialization import FILE_LIKE import lightning.pytorch as pl from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.types import _PATH from lightning.fabric.utilities.warnings import PossibleUserWarning +from lightning.pytorch.utilities.migration import pickle as legacy_checkpoint_pickle from lightning.pytorch.utilities.migration.migration import _migration_index from lightning.pytorch.utilities.rank_zero import rank_zero_warn @@ -69,6 +72,10 @@ def migrate_checkpoint( return checkpoint, applied_migrations +def load_legacy_checkpoint(f: FILE_LIKE, **torch_load_kwargs): + return torch.load(f, pickle_module=legacy_checkpoint_pickle, **torch_load_kwargs) + + class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for unpickling old checkpoints. The following patches apply. @@ -102,8 +109,7 @@ def __exit__( ) -> None: if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") - if "lightning.pytorch.utilities.argparse_utils" in sys.modules: - del sys.modules["lightning.pytorch.utilities.argparse_utils"] + del sys.modules["lightning.pytorch.utilities.argparse_utils"] def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 3df6ad0dac583..36dddf42f7d87 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -93,6 +93,30 @@ def load_model(): t2.join() +@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) +@RunIf(sklearn=True) +def test_load_legacy_checkpoint_threading(tmpdir, pl_version: str): + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + path_ckpts = 
sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) + assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' + path_ckpt = path_ckpts[-1] + + def load_model(): + from lightning.pytorch.utilities.migration.utils import load_legacy_checkpoint + + _ = load_legacy_checkpoint(path_ckpt) + + with patch("sys.path", [PATH_LEGACY] + sys.path): + t1 = ThreadExceptionHandler(target=load_model) + t2 = ThreadExceptionHandler(target=load_model) + + t1.start() + t2.start() + + t1.join() + t2.join() + + @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) def test_resume_legacy_checkpoints(tmpdir, pl_version: str): diff --git a/tests/tests_pytorch/helpers/threading.py b/tests/tests_pytorch/helpers/threading.py index f943c50d01b94..38c43c858807c 100644 --- a/tests/tests_pytorch/helpers/threading.py +++ b/tests/tests_pytorch/helpers/threading.py @@ -13,14 +13,14 @@ # limitations under the License. from threading import Thread + class ThreadExceptionHandler(Thread): - """ - Adopted from https://stackoverflow.com/a/67022927 - """ + """Adopted from https://stackoverflow.com/a/67022927.""" + def __init__(self, target, args=(), kwargs={}): Thread.__init__(self, target=target, args=args, kwargs=kwargs) self.exception = None - + def run(self): try: self._target(*self._args, **self._kwargs) @@ -30,4 +30,4 @@ def run(self): def join(self): super().join() if self.exception: - raise self.exception \ No newline at end of file + raise self.exception From 6749f4102f991a0d7e3e0f77102a278e28755698 Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Tue, 23 May 2023 00:29:29 +0900 Subject: [PATCH 04/13] Add `load_legacy_checkpoint` function --- .../checkpointing/test_legacy_checkpoints.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 
36dddf42f7d87..1061778843a4d 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -14,7 +14,6 @@ import glob import os import sys -import threading from unittest.mock import patch import pytest @@ -80,8 +79,32 @@ def load_model(): from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt) - + _ = torch.load(path_ckpt) + + with patch("sys.path", [PATH_LEGACY] + sys.path): + t1 = ThreadExceptionHandler(target=load_model) + t2 = ThreadExceptionHandler(target=load_model) + + t1.start() + t2.start() + + t1.join() + t2.join() + + +@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) +@RunIf(sklearn=True) +def test_load_legacy_checkpoint_threading(tmpdir, pl_version: str): + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) + assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' + path_ckpt = path_ckpts[-1] + + def load_model(): + from lightning.pytorch.utilities.migration.utils import load_legacy_checkpoint + + _ = load_legacy_checkpoint(path_ckpt) + with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = ThreadExceptionHandler(target=load_model) t2 = ThreadExceptionHandler(target=load_model) From 2fdef0df4804655d7a2f3c12d4779f54cb321d4f Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Tue, 23 May 2023 01:02:57 +0900 Subject: [PATCH 05/13] Revert "Add `load_legacy_checkpoint` function" This reverts commit 6749f4102f991a0d7e3e0f77102a278e28755698. 
--- .../checkpointing/test_legacy_checkpoints.py | 29 ++----------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 1061778843a4d..36dddf42f7d87 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -14,6 +14,7 @@ import glob import os import sys +import threading from unittest.mock import patch import pytest @@ -79,32 +80,8 @@ def load_model(): from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt) - - with patch("sys.path", [PATH_LEGACY] + sys.path): - t1 = ThreadExceptionHandler(target=load_model) - t2 = ThreadExceptionHandler(target=load_model) - - t1.start() - t2.start() - - t1.join() - t2.join() - - -@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) -@RunIf(sklearn=True) -def test_load_legacy_checkpoint_threading(tmpdir, pl_version: str): - PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) - path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) - assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' - path_ckpt = path_ckpts[-1] - - def load_model(): - from lightning.pytorch.utilities.migration.utils import load_legacy_checkpoint - - _ = load_legacy_checkpoint(path_ckpt) - + _ = torch.load(path_ckpt) + with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = ThreadExceptionHandler(target=load_model) t2 = ThreadExceptionHandler(target=load_model) From 21e221c5b5445eccbf1abef1f7f32deb449f496e Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Tue, 23 May 2023 01:03:02 +0900 Subject: [PATCH 06/13] Revert "Add `load_legacy_checkpoint` function" This reverts commit 093bc11fffc9b1935b9447b394ceb407cbcd7c88. 
--- .../pytorch/utilities/migration/pickle.py | 44 ------------------- .../pytorch/utilities/migration/utils.py | 10 +---- .../checkpointing/test_legacy_checkpoints.py | 24 ---------- tests/tests_pytorch/helpers/threading.py | 10 ++--- 4 files changed, 7 insertions(+), 81 deletions(-) delete mode 100644 src/lightning/pytorch/utilities/migration/pickle.py diff --git a/src/lightning/pytorch/utilities/migration/pickle.py b/src/lightning/pytorch/utilities/migration/pickle.py deleted file mode 100644 index 97657aa4f0557..0000000000000 --- a/src/lightning/pytorch/utilities/migration/pickle.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import pickle -from types import TracebackType -from typing import Any, Optional, Type - -import lightning.pytorch as pl - - -class LegacyModulePatcher: - def __enter__(self) -> "LegacyModulePatcher": - pl.utilities.argparse._gpus_arg_default = lambda x: x - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], - ) -> None: - if hasattr(pl.utilities.argparse, "_gpus_arg_default"): - delattr(pl.utilities.argparse, "_gpus_arg_default") - - -class Unpickler(pickle.Unpickler): - def load(self) -> Any: - with LegacyModulePatcher(): - return super().load() - - def find_class(self, module: str, name: str): - if module == "lightning.pytorch.utilities.argparse_utils": - module = "lightning.pytorch.utilities.argparse" - return super().find_class(module, name) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 6bb05658e99a2..e5da3178f1387 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -17,15 +17,12 @@ from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type -import torch from packaging.version import Version -from torch.serialization import FILE_LIKE import lightning.pytorch as pl from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.types import _PATH from lightning.fabric.utilities.warnings import PossibleUserWarning -from lightning.pytorch.utilities.migration import pickle as legacy_checkpoint_pickle from lightning.pytorch.utilities.migration.migration import _migration_index from lightning.pytorch.utilities.rank_zero import rank_zero_warn @@ -72,10 +69,6 @@ def migrate_checkpoint( return checkpoint, applied_migrations -def load_legacy_checkpoint(f: FILE_LIKE, **torch_load_kwargs): - return torch.load(f, 
pickle_module=legacy_checkpoint_pickle, **torch_load_kwargs) - - class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for unpickling old checkpoints. The following patches apply. @@ -109,7 +102,8 @@ def __exit__( ) -> None: if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") - del sys.modules["lightning.pytorch.utilities.argparse_utils"] + if "lightning.pytorch.utilities.argparse_utils" in sys.modules: + del sys.modules["lightning.pytorch.utilities.argparse_utils"] def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 36dddf42f7d87..3df6ad0dac583 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -93,30 +93,6 @@ def load_model(): t2.join() -@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) -@RunIf(sklearn=True) -def test_load_legacy_checkpoint_threading(tmpdir, pl_version: str): - PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) - path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) - assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' - path_ckpt = path_ckpts[-1] - - def load_model(): - from lightning.pytorch.utilities.migration.utils import load_legacy_checkpoint - - _ = load_legacy_checkpoint(path_ckpt) - - with patch("sys.path", [PATH_LEGACY] + sys.path): - t1 = ThreadExceptionHandler(target=load_model) - t2 = ThreadExceptionHandler(target=load_model) - - t1.start() - t2.start() - - t1.join() - t2.join() - - @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) def test_resume_legacy_checkpoints(tmpdir, pl_version: 
str): diff --git a/tests/tests_pytorch/helpers/threading.py b/tests/tests_pytorch/helpers/threading.py index 38c43c858807c..f943c50d01b94 100644 --- a/tests/tests_pytorch/helpers/threading.py +++ b/tests/tests_pytorch/helpers/threading.py @@ -13,14 +13,14 @@ # limitations under the License. from threading import Thread - class ThreadExceptionHandler(Thread): - """Adopted from https://stackoverflow.com/a/67022927.""" - + """ + Adopted from https://stackoverflow.com/a/67022927 + """ def __init__(self, target, args=(), kwargs={}): Thread.__init__(self, target=target, args=args, kwargs=kwargs) self.exception = None - + def run(self): try: self._target(*self._args, **self._kwargs) @@ -30,4 +30,4 @@ def run(self): def join(self): super().join() if self.exception: - raise self.exception + raise self.exception \ No newline at end of file From 28d904394637e01cd93da0b1b45f7f88bf536cc5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 May 2023 16:17:55 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../checkpointing/test_legacy_checkpoints.py | 5 ++--- tests/tests_pytorch/helpers/threading.py | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 3df6ad0dac583..3e69659461ac0 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -14,7 +14,6 @@ import glob import os import sys -import threading from unittest.mock import patch import pytest @@ -80,8 +79,8 @@ def load_model(): from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt) - + _ = torch.load(path_ckpt) + with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = 
ThreadExceptionHandler(target=load_model) t2 = ThreadExceptionHandler(target=load_model) diff --git a/tests/tests_pytorch/helpers/threading.py b/tests/tests_pytorch/helpers/threading.py index f943c50d01b94..38c43c858807c 100644 --- a/tests/tests_pytorch/helpers/threading.py +++ b/tests/tests_pytorch/helpers/threading.py @@ -13,14 +13,14 @@ # limitations under the License. from threading import Thread + class ThreadExceptionHandler(Thread): - """ - Adopted from https://stackoverflow.com/a/67022927 - """ + """Adopted from https://stackoverflow.com/a/67022927.""" + def __init__(self, target, args=(), kwargs={}): Thread.__init__(self, target=target, args=args, kwargs=kwargs) self.exception = None - + def run(self): try: self._target(*self._args, **self._kwargs) @@ -30,4 +30,4 @@ def run(self): def join(self): super().join() if self.exception: - raise self.exception \ No newline at end of file + raise self.exception From 7a402ad53ff78310adba8a619348d913919d634f Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Tue, 23 May 2023 01:21:23 +0900 Subject: [PATCH 08/13] Remove unnecessary arguments from ThreadExceptionHandler --- tests/tests_pytorch/helpers/threading.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/helpers/threading.py b/tests/tests_pytorch/helpers/threading.py index 38c43c858807c..c80468dff8d20 100644 --- a/tests/tests_pytorch/helpers/threading.py +++ b/tests/tests_pytorch/helpers/threading.py @@ -17,13 +17,13 @@ class ThreadExceptionHandler(Thread): """Adopted from https://stackoverflow.com/a/67022927.""" - def __init__(self, target, args=(), kwargs={}): - Thread.__init__(self, target=target, args=args, kwargs=kwargs) + def __init__(self, target): + super().__init__(target=target) self.exception = None def run(self): try: - self._target(*self._args, **self._kwargs) + self._target() except Exception as e: self.exception = e From 6227c43629b8bf39f0c755ca6ad09fa828d98090 Mon Sep 17 00:00:00 2001 From: Quasar-Kim 
Date: Tue, 30 May 2023 22:35:09 +0900 Subject: [PATCH 09/13] Re-introduce thread locking --- src/lightning/pytorch/utilities/migration/utils.py | 9 +++++++-- tests/tests_pytorch/helpers/threading.py | 6 +++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index e5da3178f1387..a5da66fd8de2c 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -14,6 +14,7 @@ import logging import os import sys +import threading from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type @@ -69,6 +70,9 @@ def migrate_checkpoint( return checkpoint, applied_migrations +_lock = threading.Lock() + + class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for unpickling old checkpoints. The following patches apply. 
@@ -85,6 +89,7 @@ class pl_legacy_patch: """ def __enter__(self) -> "pl_legacy_patch": + _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils") sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module @@ -102,8 +107,8 @@ def __exit__( ) -> None: if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") - if "lightning.pytorch.utilities.argparse_utils" in sys.modules: - del sys.modules["lightning.pytorch.utilities.argparse_utils"] + del sys.modules["lightning.pytorch.utilities.argparse_utils"] + _lock.release() def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: diff --git a/tests/tests_pytorch/helpers/threading.py b/tests/tests_pytorch/helpers/threading.py index c80468dff8d20..6447bec303c91 100644 --- a/tests/tests_pytorch/helpers/threading.py +++ b/tests/tests_pytorch/helpers/threading.py @@ -17,13 +17,13 @@ class ThreadExceptionHandler(Thread): """Adopted from https://stackoverflow.com/a/67022927.""" - def __init__(self, target): - super().__init__(target=target) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.exception = None def run(self): try: - self._target() + super().run() except Exception as e: self.exception = e From c38fcb17e9f1e95ae0d8e182105dcfb071ff7df3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 31 May 2023 14:58:13 +0200 Subject: [PATCH 10/13] check if test reproduces issue if fix is removed --- src/lightning/pytorch/utilities/migration/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index a5da66fd8de2c..385cc6b266fc0 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ 
b/src/lightning/pytorch/utilities/migration/utils.py @@ -70,7 +70,7 @@ def migrate_checkpoint( return checkpoint, applied_migrations -_lock = threading.Lock() +# _lock = threading.Lock() class pl_legacy_patch: @@ -89,7 +89,7 @@ class pl_legacy_patch: """ def __enter__(self) -> "pl_legacy_patch": - _lock.acquire() + # _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils") sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module @@ -108,7 +108,7 @@ def __exit__( if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") del sys.modules["lightning.pytorch.utilities.argparse_utils"] - _lock.release() + # _lock.release() def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: From fdb6a70119ac9d9f9c76be91470173684d8b68a3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 May 2023 13:00:20 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/utilities/migration/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 385cc6b266fc0..2a5e693ae94d8 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -14,7 +14,6 @@ import logging import os import sys -import threading from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type From 587f38aea201aceb6a257a064e430148cdebdcc1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 31 May 2023 16:16:37 +0200 Subject: [PATCH 12/13] Revert "check if test reproduces issue if fix is removed" This reverts 
commit c38fcb17e9f1e95ae0d8e182105dcfb071ff7df3. --- src/lightning/pytorch/utilities/migration/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 2a5e693ae94d8..3e6dd3ba45f6a 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -16,6 +16,7 @@ import sys from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type +import threading from packaging.version import Version @@ -28,6 +29,7 @@ _log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] +_lock = threading.Lock() def migrate_checkpoint( @@ -69,9 +71,6 @@ def migrate_checkpoint( return checkpoint, applied_migrations -# _lock = threading.Lock() - - class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for unpickling old checkpoints. The following patches apply. 
@@ -88,7 +87,7 @@ class pl_legacy_patch: """ def __enter__(self) -> "pl_legacy_patch": - # _lock.acquire() + _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils") sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module @@ -107,7 +106,7 @@ def __exit__( if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") del sys.modules["lightning.pytorch.utilities.argparse_utils"] - # _lock.release() + _lock.release() def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: From 06f90354f0a2ac3a5eab598fd2be5423a5c16d62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 May 2023 14:18:07 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/utilities/migration/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 3e6dd3ba45f6a..56f018a9fe900 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -14,9 +14,9 @@ import logging import os import sys +import threading from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type -import threading from packaging.version import Version