Skip to content

Commit

Permalink
Use LRScheduler for torch >= 1.14, otherwise use _LRScheduler (#15768)
Browse files: browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

(cherry picked from commit 2577285)
  • Loading branch information
qmaruf authored and Borda committed Dec 14, 2022
1 parent 6fddd82 commit fbb5f5f
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 36 deletions.
2 changes: 1 addition & 1 deletion docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ If the optimizer you want needs other arguments, add them via the CLI (no need t
********************
Custom LR schedulers
********************
Any subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
Any subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:

.. code:: python
Expand Down
9 changes: 7 additions & 2 deletions src/lightning_lite/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch.optim import Optimizer
from typing_extensions import Protocol, runtime_checkable

from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14

_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
Expand Down Expand Up @@ -63,7 +63,7 @@ def rank(self) -> int:
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class _LRScheduler(_Stateful[str], Protocol):
class LRScheduler(_Stateful[str], Protocol):
optimizer: Optimizer
base_lrs: List[float]

Expand All @@ -74,6 +74,11 @@ def step(self, epoch: Optional[int] = None) -> None:
...


_TORCH_LRSCHEDULER = (
torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler
)


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from lightning_lite.utilities.types import _LRScheduler
from lightning_lite.utilities.types import LRScheduler
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
self._model_contains_batch_norm: Optional[bool] = None
self._average_model: Optional["pl.LightningModule"] = None
self._initialized = False
self._swa_scheduler: Optional[_LRScheduler] = None
self._swa_scheduler: Optional[LRScheduler] = None
self._scheduler_state: Optional[Dict] = None
self._init_n_averaged = 0
self._latest_update_epoch = -1
Expand Down Expand Up @@ -192,7 +192,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

assert trainer.max_epochs is not None
self._swa_scheduler = cast(
_LRScheduler,
LRScheduler,
SWALR(
optimizer,
swa_lr=self._swa_lrs, # type: ignore[arg-type]
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def _configure_optimizers(
" Output from `model.configure_optimizers()` should be one of:\n"
" * `Optimizer`\n"
" * [`Optimizer`]\n"
" * ([`Optimizer`], [`_LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
" * ([`Optimizer`], [`LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/demos/boring_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> No
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["y"] for x in outputs]).mean()

def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.optimizer import _optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
Expand Down Expand Up @@ -426,7 +426,7 @@ def _setup_model_and_optimizer(
self,
model: Module,
optimizer: Optional[Optimizer],
lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/hivemind.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
from lightning_lite.utilities.types import LRScheduler, ReduceLROnPlateau
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -312,7 +312,7 @@ class HiveMindScheduler:

base_lrs: List[float]

def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None:
# copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
# implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
Expand Down
14 changes: 7 additions & 7 deletions src/pytorch_lightning/tuner/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
Expand Down Expand Up @@ -124,7 +124,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:

args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
scheduler = cast(LRScheduler, scheduler)

trainer.strategy.optimizers = [optimizer]
trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
Expand Down Expand Up @@ -394,7 +394,7 @@ def on_train_batch_end(
self.losses.append(smoothed_loss)


class _LinearLR(_LRScheduler):
class _LinearLR(_TORCH_LRSCHEDULER):
"""Linearly increases the learning rate between two boundaries over a number of iterations.
Args:
Expand All @@ -413,7 +413,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)

def get_lr(self) -> List[float]: # type: ignore[override]
def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter

Expand All @@ -429,7 +429,7 @@ def lr(self) -> Union[float, List[float]]:
return self._lr


class _ExponentialLR(_LRScheduler):
class _ExponentialLR(_TORCH_LRSCHEDULER):
"""Exponentially increases the learning rate between two boundaries over a number of iterations.
Arguments:
Expand All @@ -448,7 +448,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)

def get_lr(self) -> List[float]: # type: ignore[override]
def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter

Expand Down
19 changes: 6 additions & 13 deletions src/pytorch_lightning/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,7 @@
from torchmetrics import Metric
from typing_extensions import Protocol, runtime_checkable

try:
from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
except ImportError:
# For torch <= 1.13.x
# TODO: Remove once minimum torch version is 1.14 (or 2.0)
from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler

from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau

_NUMBER = Union[int, float]
_METRIC = Union[Metric, Tensor, _NUMBER]
Expand Down Expand Up @@ -118,15 +111,15 @@ def no_sync(self) -> Generator:


# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau]


@dataclass
class LRSchedulerConfig:
scheduler: Union[_LRScheduler, ReduceLROnPlateau]
scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
# no custom name
name: Optional[str] = None
# after epoch is over
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def training_step(self, batch, batch_idx):
with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
trainer.fit(model)

# If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
# If a lr scheduler inherits `torch.optim.lr_scheduler.LRScheduler`,
# `.step()` is called once during its instantiation.
# Thus, the call count should be 1, not 0.
assert lr_step.call_count == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def configure_optimizers(self):
trainer.fit(model)

assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)]
# first step is called by PyTorch _LRScheduler
# first step is called by PyTorch LRScheduler
assert mock_method_step.call_count == max_epochs * limit_train_batches + 1


Expand Down

0 comments on commit fbb5f5f

Please sign in to comment.