diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index f7a8942f503bd..d766e4fdb7519 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -27,6 +27,13 @@
 from torchmetrics import Metric
 from typing_extensions import Protocol, runtime_checkable
 
+try:
+    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
+except ImportError:
+    # For torch <= 1.13.x
+    # TODO: Remove once minimum torch version is 1.14 (or 2.0)
+    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler
+
 from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
 
 _NUMBER = Union[int, float]
@@ -111,9 +118,9 @@ def no_sync(self) -> Generator:
 
 
 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
-LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
 LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
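
For context only (not part of the diff): a minimal, self-contained sketch of the same version-guarded import pattern, showing how the resulting TorchLRScheduler alias keeps isinstance checks against LRSchedulerTypeTuple working on both torch <= 1.13 and torch >= 2.0. The toy model, optimizer, and scheduler instances below are illustrative assumptions, not code from the PR.

# Sketch of the version-guarded alias introduced by the diff.
import torch

try:
    # torch >= 2.0 exposes the public LRScheduler base class.
    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
except ImportError:
    # torch <= 1.13.x only has the private _LRScheduler base class.
    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler

from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# Same shape as LRSchedulerTypeTuple in types.py: ReduceLROnPlateau is listed
# separately because it does not subclass the (private or public) base class
# on older torch versions.
LRSchedulerTypeTuple = (TorchLRScheduler, ReduceLROnPlateau)

# Illustrative model/optimizer just to construct real scheduler instances.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

step_lr = StepLR(optimizer, step_size=10)
plateau = ReduceLROnPlateau(optimizer)

# Both scheduler flavours pass a single isinstance check,
# regardless of which torch version provided the base class.
assert isinstance(step_lr, LRSchedulerTypeTuple)
assert isinstance(plateau, LRSchedulerTypeTuple)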