diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst index 8e312b7233a6d0..0917289e059f14 100644 --- a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst +++ b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst @@ -178,7 +178,7 @@ If the optimizer you want needs other arguments, add them via the CLI (no need t ******************** Custom LR schedulers ******************** -Any subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler: +Any subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler: .. code:: python diff --git a/src/lightning_lite/utilities/types.py b/src/lightning_lite/utilities/types.py index a3ee70ea68ea1f..de834212ecebd9 100644 --- a/src/lightning_lite/utilities/types.py +++ b/src/lightning_lite/utilities/types.py @@ -19,7 +19,7 @@ from torch.optim import Optimizer from typing_extensions import Protocol, runtime_checkable -from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13 +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14 _PATH = Union[str, Path] _DEVICE = Union[torch.device, str, int] @@ -63,7 +63,7 @@ def rank(self) -> int: # Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing @runtime_checkable -class _LRScheduler(_Stateful[str], Protocol): +class LRScheduler(_Stateful[str], Protocol): optimizer: Optimizer base_lrs: List[float] @@ -74,6 +74,11 @@ def step(self, epoch: Optional[int] = None) -> None: ... +_TORCH_LRSCHEDULER = ( + torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler +) + + # Inferred from `torch.optim.lr_scheduler.pyi` # Missing attributes were added to improve typing @runtime_checkable diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index ccf5051d5fd398..3476de66b37bd2 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -23,7 +23,7 @@ from torch.optim.swa_utils import SWALR import pytorch_lightning as pl -from lightning_lite.utilities.types import _LRScheduler +from lightning_lite.utilities.types import LRScheduler from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy @@ -125,7 +125,7 @@ def __init__( self._model_contains_batch_norm: Optional[bool] = None self._average_model: Optional["pl.LightningModule"] = None self._initialized = False - self._swa_scheduler: Optional[_LRScheduler] = None + self._swa_scheduler: Optional[LRScheduler] = None self._scheduler_state: Optional[Dict] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 @@ -192,7 +192,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo assert trainer.max_epochs is not None self._swa_scheduler = cast( - _LRScheduler, + LRScheduler, SWALR( optimizer, swa_lr=self._swa_lrs, # type: ignore[arg-type] diff --git a/src/pytorch_lightning/core/optimizer.py b/src/pytorch_lightning/core/optimizer.py index e1a834f8c87ef3..c18c1e2697b030 100644 --- a/src/pytorch_lightning/core/optimizer.py +++ b/src/pytorch_lightning/core/optimizer.py @@ -253,8 +253,8 @@ def _configure_optimizers( " Output from `model.configure_optimizers()` should be one of:\n" " * `Optimizer`\n" " * [`Optimizer`]\n" - " * ([`Optimizer`], [`_LRScheduler`])\n" - ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n' + " * ([`Optimizer`], [`LRScheduler`])\n" + ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n' ' * A list of the previously described dict format, with an optional "frequency" key (int)' ) return optimizers, lr_schedulers, optimizer_frequencies, monitor diff --git a/src/pytorch_lightning/demos/boring_classes.py b/src/pytorch_lightning/demos/boring_classes.py index 31483db9bd53b6..a67048c670e9b6 100644 --- a/src/pytorch_lightning/demos/boring_classes.py +++ b/src/pytorch_lightning/demos/boring_classes.py @@ -18,9 +18,9 @@ import torch.nn.functional as F from torch import Tensor from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset +from lightning_lite.utilities.types import _TORCH_LRSCHEDULER from pytorch_lightning import LightningDataModule, LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -137,7 +137,7 @@ def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> No outputs = cast(List[Dict[str, Tensor]], outputs) torch.stack([x["y"] for x in outputs]).mean() - def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]: + def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]: optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) return [optimizer], [lr_scheduler] diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index e7df64c2acc10d..d62df864ee9513 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -34,7 +34,7 @@ from lightning_lite.utilities.enums import AMPType, PrecisionType from lightning_lite.utilities.optimizer import _optimizers_to_device from lightning_lite.utilities.seed import reset_seed -from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau +from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau from pytorch_lightning.accelerators.cuda import CUDAAccelerator from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase @@ -426,7 +426,7 @@ def _setup_model_and_optimizer( self, model: Module, optimizer: Optional[Optimizer], - lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None, + lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. diff --git a/src/pytorch_lightning/strategies/hivemind.py b/src/pytorch_lightning/strategies/hivemind.py index 7cad027ac6aefb..61d39367d0563d 100644 --- a/src/pytorch_lightning/strategies/hivemind.py +++ b/src/pytorch_lightning/strategies/hivemind.py @@ -9,7 +9,7 @@ import pytorch_lightning as pl from lightning_lite.utilities.enums import PrecisionType -from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau +from lightning_lite.utilities.types import LRScheduler, ReduceLROnPlateau from pytorch_lightning.strategies.strategy import Strategy, TBroadcast from pytorch_lightning.utilities.data import extract_batch_size from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -312,7 +312,7 @@ class HiveMindScheduler: base_lrs: List[float] - def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None: + def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None: # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler` self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")} diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index fa55d062320bbc..996940eb9c79b5 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -21,14 +21,14 @@ import numpy as np import torch from lightning_utilities.core.imports import RequirementCache -from torch.optim.lr_scheduler import _LRScheduler import pytorch_lightning as pl +from lightning_lite.utilities.types import _TORCH_LRSCHEDULER from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT +from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed @@ -124,7 +124,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: args = (optimizer, self.lr_max, self.num_training) scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args) - scheduler = cast(pl.utilities.types._LRScheduler, scheduler) + scheduler = cast(LRScheduler, scheduler) trainer.strategy.optimizers = [optimizer] trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] @@ -394,7 +394,7 @@ def on_train_batch_end( self.losses.append(smoothed_loss) -class _LinearLR(_LRScheduler): +class _LinearLR(_TORCH_LRSCHEDULER): """Linearly increases the learning rate between two boundaries over a number of iterations. Args: @@ -413,7 +413,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in self.num_iter = num_iter super().__init__(optimizer, last_epoch) - def get_lr(self) -> List[float]: # type: ignore[override] + def get_lr(self) -> List[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -429,7 +429,7 @@ def lr(self) -> Union[float, List[float]]: return self._lr -class _ExponentialLR(_LRScheduler): +class _ExponentialLR(_TORCH_LRSCHEDULER): """Exponentially increases the learning rate between two boundaries over a number of iterations. Arguments: @@ -448,7 +448,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in self.num_iter = num_iter super().__init__(optimizer, last_epoch) - def get_lr(self) -> List[float]: # type: ignore[override] + def get_lr(self) -> List[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index d766e4fdb7519a..db736e9cc2f9b4 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -27,14 +27,7 @@ from torchmetrics import Metric from typing_extensions import Protocol, runtime_checkable -try: - from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler -except ImportError: - # For torch <= 1.13.x - # TODO: Remove once minimum torch version is 1.14 (or 2.0) - from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler - -from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau +from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] @@ -118,15 +111,15 @@ def no_sync(self) -> Generator: # todo: improve LRSchedulerType naming/typing -LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] -LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] -LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau] +LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau] +LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]] +LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau] @dataclass class LRSchedulerConfig: - scheduler: Union[_LRScheduler, ReduceLROnPlateau] + scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau] # no custom name name: Optional[str] = None # after epoch is over diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index 0fcacf080a4d7c..2224ed85697091 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -965,7 +965,7 @@ def training_step(self, batch, batch_idx): with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step: trainer.fit(model) - # If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`, + # If a lr scheduler inherits `torch.optim.lr_scheduler.LRScheduler`, # `.step()` is called once during its instantiation. # Thus, the call count should be 1, not 0. assert lr_step.call_count == 1 diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index 52fb6ba5028ae5..ed821b0d6ff4ce 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -761,7 +761,7 @@ def configure_optimizers(self): trainer.fit(model) assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)] - # first step is called by PyTorch _LRScheduler + # first step is called by PyTorch LRScheduler assert mock_method_step.call_count == max_epochs * limit_train_batches + 1