Use LRScheduler for torch >= 1.14 otherwise use _LRScheduler #15768

Status: Merged (15 commits) on Dec 12, 2022

Changes from 4 commits
2 changes: 1 addition & 1 deletion docs/source-pytorch/cli/lightning_cli_intermediate_2.rst
@@ -201,7 +201,7 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t

python main.py fit --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch

Furthermore, any custom subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:

.. code:: python

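(The RST code block following that sentence is collapsed in this view. As a rough illustration, a custom subclass might look like the sketch below; the class name and warm-up logic are hypothetical, and the base class is spelled `_LRScheduler` on torch releases older than 1.14/2.0.)

```python
import torch
from torch.optim.lr_scheduler import LRScheduler  # `_LRScheduler` on older torch releases


class ConstantWarmupScheduler(LRScheduler):
    """Hypothetical scheduler: scale the base LR down for the first `warmup_steps` steps."""

    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int = 100, last_epoch: int = -1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # keep 10% of the base LR during warm-up, then switch to the full base LR
        scale = 0.1 if self.last_epoch < self.warmup_steps else 1.0
        return [base_lr * scale for base_lr in self.base_lrs]
```

With the CLI, such a class would then be selected by class path, e.g. `python main.py fit --lr_scheduler=my_project.schedulers.ConstantWarmupScheduler` (module path hypothetical).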
5 changes: 3 additions & 2 deletions src/lightning_lite/strategies/deepspeed.py
@@ -35,7 +35,7 @@
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.rank_zero import rank_zero_info
from lightning_lite.utilities.seed import reset_seed
from lightning_lite.utilities.types import _PATH
from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
@@ -418,7 +418,8 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
def _initialize_engine(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
optimizer: Optional[Optimizer],
lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.

7 changes: 4 additions & 3 deletions src/lightning_lite/utilities/types.py
@@ -19,7 +19,7 @@
from torch.optim import Optimizer
from typing_extensions import Protocol, runtime_checkable

from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14

_PATH = Union[str, Path]
_DEVICE = Union[torch.device, str, int]
@@ -60,8 +60,6 @@ def rank(self) -> int:
...


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class _LRScheduler(_Stateful[str], Protocol):
optimizer: Optimizer
@@ -74,6 +72,9 @@ def step(self, epoch: Optional[int] = None) -> None:
...


_TORCH_LRSCHEDULER = torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else _LRScheduler


# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
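The `_TORCH_LRSCHEDULER` alias added above is the core of this PR. A minimal, self-contained sketch of the same version gate, with illustrative helper names rather than the exact Lightning ones:

```python
import torch
from packaging.version import Version

# torch exposes the scheduler base class publicly as `LRScheduler` starting with
# 1.14 (released as 2.0); earlier releases only have the private `_LRScheduler`.
_TORCH_GREATER_EQUAL_1_14 = Version(torch.__version__.split("+")[0]) >= Version("1.14.0")

if _TORCH_GREATER_EQUAL_1_14:
    TORCH_LRSCHEDULER = torch.optim.lr_scheduler.LRScheduler
else:
    TORCH_LRSCHEDULER = torch.optim.lr_scheduler._LRScheduler

# downstream code can now type-check against a single symbol on any torch version
assert issubclass(torch.optim.lr_scheduler.StepLR, TORCH_LRSCHEDULER)
```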
6 changes: 3 additions & 3 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -23,7 +23,7 @@
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from lightning_lite.utilities.types import _LRScheduler
from lightning_lite.utilities.types import LRScheduler
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
@@ -125,7 +125,7 @@ def __init__(
self._model_contains_batch_norm: Optional[bool] = None
self._average_model: Optional["pl.LightningModule"] = None
self._initialized = False
self._swa_scheduler: Optional[_LRScheduler] = None
self._swa_scheduler: Optional[LRScheduler] = None
self._scheduler_state: Optional[Dict] = None
self._init_n_averaged = 0
self._latest_update_epoch = -1
@@ -192,7 +192,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

assert trainer.max_epochs is not None
self._swa_scheduler = cast(
_LRScheduler,
LRScheduler,
SWALR(
optimizer,
swa_lr=self._swa_lrs, # type: ignore[arg-type]
9 changes: 5 additions & 4 deletions src/pytorch_lightning/cli.py
@@ -24,6 +24,7 @@

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import LRScheduler
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -59,9 +60,9 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any


# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau]
LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]]


class LightningArgumentParser(ArgumentParser):
@@ -162,7 +163,7 @@ def add_lr_scheduler_args(
"""Adds arguments from a learning rate scheduler class to a nested key of the parser.

Args:
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{LRScheduler, ReduceLROnPlateau}``. Use
tuple to allow subclasses.
nested_key: Name of the nested namespace to store arguments.
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
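Based on that docstring, wiring a specific scheduler class through the parser might look like the sketch below; this is an assumption about typical usage, and the exact defaults of `add_lr_scheduler_args` may differ:

```python
import torch
from pytorch_lightning.cli import LightningArgumentParser

parser = LightningArgumentParser()
# a single class: its __init__ arguments (minus the optimizer) become CLI options
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.CosineAnnealingLR)
# a tuple such as (LRScheduler, ReduceLROnPlateau) would instead allow any subclass
# to be chosen on the command line via `--lr_scheduler=<class path>`
```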
4 changes: 2 additions & 2 deletions src/pytorch_lightning/core/optimizer.py
@@ -253,8 +253,8 @@ def _configure_optimizers(
" Output from `model.configure_optimizers()` should be one of:\n"
" * `Optimizer`\n"
" * [`Optimizer`]\n"
" * ([`Optimizer`], [`_LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
" * ([`Optimizer`], [`LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor
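For reference, a `configure_optimizers` that matches the dict shape listed in that error message could look like this minimal sketch (model architecture is illustrative):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # one of the accepted return shapes: a dict with the optimizer and an
        # optional "lr_scheduler" entry
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
```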
5 changes: 2 additions & 3 deletions src/pytorch_lightning/demos/boring_classes.py
@@ -18,12 +18,11 @@
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, LRScheduler, STEP_OUTPUT


class RandomDictDataset(Dataset):
@@ -137,7 +136,7 @@ def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> No
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["y"] for x in outputs]).mean()

def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]:
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/deepspeed.py
@@ -34,7 +34,7 @@
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.optimizer import _optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@@ -426,7 +426,7 @@ def _setup_model_and_optimizer(
self,
model: Module,
optimizer: Optional[Optimizer],
lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
"""Initialize one model and one optimizer with an optional learning rate scheduler.

4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/hivemind.py
@@ -9,7 +9,7 @@

import pytorch_lightning as pl
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
from lightning_lite.utilities.types import LRScheduler, ReduceLROnPlateau
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -312,7 +312,7 @@ class HiveMindScheduler:

base_lrs: List[float]

def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None:
# copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
# implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
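The comment above describes a copy-and-delegate pattern. A stripped-down sketch of the idea, not the exact `HiveMindScheduler` implementation:

```python
class _SchedulerProxy:
    """Wrap a scheduler by copying its attributes while keeping our own `step`."""

    def __init__(self, scheduler):
        # copy the wrapped scheduler's state, but never inherit its `step` or a
        # custom `__del__` that we do not want triggered when the proxy is destroyed
        self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}

    def step(self, epoch=None):
        # the proxy decides when (or whether) to advance the wrapped schedule
        pass
```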
9 changes: 4 additions & 5 deletions src/pytorch_lightning/tuner/lr_finder.py
@@ -21,14 +21,13 @@
import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
@@ -124,7 +123,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:

args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
scheduler = cast(LRScheduler, scheduler)

trainer.strategy.optimizers = [optimizer]
trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
@@ -401,7 +400,7 @@ def on_train_batch_end(
self.losses.append(smoothed_loss)


class _LinearLR(_LRScheduler):
class _LinearLR(LRScheduler):
"""Linearly increases the learning rate between two boundaries over a number of iterations.

Args:
Expand Down Expand Up @@ -436,7 +435,7 @@ def lr(self) -> Union[float, List[float]]:
return self._lr


class _ExponentialLR(_LRScheduler):
class _ExponentialLR(LRScheduler):
"""Exponentially increases the learning rate between two boundaries over a number of iterations.

Arguments:
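Both helper classes above only override `get_lr`. A hedged sketch of the exponential sweep their docstrings describe; attribute names are illustrative and may differ from Lightning's internals:

```python
from torch.optim.lr_scheduler import LRScheduler  # `_LRScheduler` on older torch releases


class ExponentialSweep(LRScheduler):
    """Grow each parameter group's LR geometrically from its base LR towards `end_lr`."""

    def __init__(self, optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        r = self.last_epoch / self.num_iter  # fraction of the sweep completed
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
```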
9 changes: 4 additions & 5 deletions src/pytorch_lightning/utilities/cli.py
@@ -24,6 +24,7 @@
import pytorch_lightning as pl
import pytorch_lightning.cli as new_cli
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.types import LRScheduler

_deprecate_registry_message = (
"`LightningCLI`'s registries were deprecated in v1.7 and will be removed "
@@ -110,7 +111,7 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
# this will register any subclasses from all loaded modules including userland
for cls in get_all_subclasses(torch.optim.Optimizer):
OPTIMIZER_REGISTRY(cls, show_deprecation=False)
for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
for cls in get_all_subclasses(LRScheduler):
LR_SCHEDULER_REGISTRY(cls, show_deprecation=False)
for cls in get_all_subclasses(pl.Callback):
CALLBACK_REGISTRY(cls, show_deprecation=False)
@@ -123,12 +124,10 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
else:
# manually register torch's subclasses and our subclasses
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer, show_deprecation=False)
LR_SCHEDULER_REGISTRY.register_classes(
torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler, show_deprecation=False
)
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, LRScheduler, show_deprecation=False)
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False)
LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False)
# `ReduceLROnPlateau` does not subclass `_LRScheduler`
# `ReduceLROnPlateau` does not subclass `LRScheduler`
LR_SCHEDULER_REGISTRY(cls=new_cli.ReduceLROnPlateau, show_deprecation=False)


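That last comment is why `ReduceLROnPlateau` is registered by hand: it does not inherit from the common scheduler base class, so subclass-based registration misses it. A quick check, where the result holds for the torch releases this PR targets (later torch versions may differ):

```python
import torch

# resolve the public name if available, otherwise fall back to the private one
base = getattr(torch.optim.lr_scheduler, "LRScheduler", torch.optim.lr_scheduler._LRScheduler)
# `ReduceLROnPlateau` is not picked up by subclass-based registration
print(issubclass(torch.optim.lr_scheduler.ReduceLROnPlateau, base))  # False at the time of this PR
```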
12 changes: 6 additions & 6 deletions src/pytorch_lightning/utilities/types.py
@@ -27,7 +27,7 @@
from torchmetrics import Metric
from typing_extensions import Protocol, runtime_checkable

from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
from lightning_lite.utilities.types import LRScheduler, ProcessGroup, ReduceLROnPlateau

_NUMBER = Union[int, float]
_METRIC = Union[Metric, Tensor, _NUMBER]
@@ -111,15 +111,15 @@ def no_sync(self) -> Generator:


# todo: improve LRSchedulerType naming/typing
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
LRSchedulerTypeTuple = (LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau]


@dataclass
class LRSchedulerConfig:
scheduler: Union[_LRScheduler, ReduceLROnPlateau]
scheduler: Union[LRScheduler, ReduceLROnPlateau]
# no custom name
name: Optional[str] = None
# after epoch is over
@@ -965,7 +965,7 @@ def training_step(self, batch, batch_idx):
with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
trainer.fit(model)

# If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
# If a lr scheduler inherits `torch.optim.lr_scheduler.LRScheduler`,
# `.step()` is called once during its instantiation.
# Thus, the call count should be 1, not 0.
assert lr_step.call_count == 1
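The comment relies on the base scheduler calling `step()` once inside its constructor. That behaviour can be seen in isolation with plain torch; a minimal sketch, independent of the Lightning test:

```python
import torch
from unittest.mock import patch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
    torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# the scheduler base class performs an initial `step()` during __init__,
# so the mock records exactly one call before any training happens
assert lr_step.call_count == 1
```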
@@ -761,7 +761,7 @@ def configure_optimizers(self):
trainer.fit(model)

assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)]
# first step is called by PyTorch _LRScheduler
# first step is called by PyTorch LRScheduler
assert mock_method_step.call_count == max_epochs * limit_train_batches + 1

