From fbc9b1922ac08f84fddf43f13c8e53f678b16f12 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 7 Dec 2022 14:34:32 +0100 Subject: [PATCH 1/2] Fix LRScheduler import for PyTorch 2.0 --- src/pytorch_lightning/utilities/types.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index f7a8942f503bd..f2869ca7bc91a 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -27,6 +27,11 @@ from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable +try: + from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler + from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau _NUMBER = Union[int, float] @@ -111,9 +116,9 @@ def no_sync(self) -> Generator: # todo: improve LRSchedulerType naming/typing -LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] -LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] +LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] +LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau] From 13df8c2ac47d24c5c8213a99aba060dbc9287930 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 7 Dec 2022 16:14:16 +0100 Subject: [PATCH 2/2] Add comment for posterity --- src/pytorch_lightning/utilities/types.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index f2869ca7bc91a..d766e4fdb7519 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -30,6 +30,8 @@ try: from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler except ImportError: + # For torch <= 1.13.x + # TODO: Remove once minimum torch version is 1.14 (or 2.0) from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau