Skip to content

Commit

Permalink
Fix LRScheduler import for PyTorch 2.0 (#15940)
Browse files Browse the repository at this point in the history
* Fix LRScheduler import for PyTorch 2.0
* Add comment for posterity
  • Loading branch information
lantiga authored Dec 7, 2022
1 parent 2041908 commit de93167
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/pytorch_lightning/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
from torchmetrics import Metric
from typing_extensions import Protocol, runtime_checkable

try:
from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
except ImportError:
# For torch <= 1.13.x
# TODO: Remove once minimum torch version is 1.14 (or 2.0)
from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler

from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau

_NUMBER = Union[int, float]
Expand Down Expand Up @@ -111,9 +118,9 @@ def no_sync(self) -> Generator:


# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]


Expand Down

0 comments on commit de93167

Please sign in to comment.