From de9316760fa450742c3a1584ba392dcde78a8f51 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 7 Dec 2022 17:11:07 +0100 Subject: [PATCH] Fix LRScheduler import for PyTorch 2.0 (#15940) * Fix LRScheduler import for PyTorch 2.0 * Add comment for posterity --- src/pytorch_lightning/utilities/types.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index f7a8942f503bd..d766e4fdb7519 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -27,6 +27,13 @@ from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable +try: + from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler +except ImportError: + # For torch <= 1.13.x + # TODO: Remove once minimum torch version is 1.14 (or 2.0) + from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler + from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau _NUMBER = Union[int, float] @@ -111,9 +118,9 @@ def no_sync(self) -> Generator: # todo: improve LRSchedulerType naming/typing -LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] -LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] +LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] +LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]