
Commit

LightningCLI support for optimizers and schedulers via dependency injection (#15869)

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored Dec 12, 2022
1 parent 38acba0 commit ed52823
Showing 5 changed files with 140 additions and 46 deletions.
96 changes: 63 additions & 33 deletions docs/source-pytorch/cli/lightning_cli_advanced_3.rst
@@ -217,8 +217,8 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.


Optimizers
^^^^^^^^^^
Fixed optimizer and scheduler
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some cases it is desirable to fix the optimizer and/or learning rate scheduler class instead of allowing any to be
selected. For this, you can manually add the arguments for specific classes by subclassing the CLI. The following code
snippet shows how to
@@ -251,58 +251,88 @@ where the arguments can be passed directly through the command line without spec
$ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
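A minimal sketch of the subclassing pattern described above (the snippet itself is collapsed in this view), assuming
``BoringModel`` as a stand-in model and fixing ``Adam`` and ``ExponentialLR``:

.. code-block:: python

    import torch
    from pytorch_lightning.cli import LightningCLI
    from pytorch_lightning.demos.boring_classes import BoringModel


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # fix the optimizer and scheduler classes; only their init arguments remain configurable
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)


    cli = MyLightningCLI(BoringModel)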
The automatic implementation of ``configure_optimizers`` can be disabled by linking the configuration group. An example
can be when someone wants to add support for multiple optimizers:
Multiple optimizers and schedulers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, the CLIs let any optimizer and/or learning rate scheduler class be selected, and automatically implement
``configure_optimizers`` for a single optimizer and an optional scheduler. This behavior can be disabled by providing
``auto_configure_optimizers=False`` on instantiation of :class:`~pytorch_lightning.cli.LightningCLI`. Disabling it is
required, for example, to support multiple optimizers, selecting a particular optimizer class for each. Similar to
multiple submodules, this can be done via `dependency injection <https://en.wikipedia.org/wiki/Dependency_injection>`__.
Unlike with submodules, it is not possible to expect an instance of a class, because an optimizer requires the module's
parameters to optimize, and these are only available after the module has been instantiated. Learning rate schedulers
are in a similar situation, since they require an optimizer instance. For these cases, dependency injection means
providing a function that instantiates the respective class when called.

An example of a model that uses two optimizers is the following:

.. code-block:: python
from pytorch_lightning.cli import instantiate_class
from typing import Callable, Iterable
from torch.optim import Optimizer
OptimizerCallable = Callable[[Iterable], Optimizer]
class MyModel(LightningModule):
def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
super().__init__()
self.optimizer1_init = optimizer1_init
self.optimizer2_init = optimizer2_init
self.optimizer1 = optimizer1
self.optimizer2 = optimizer2
def configure_optimizers(self):
optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init)
optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init)
optimizer1 = self.optimizer1(self.parameters())
optimizer2 = self.optimizer2(self.parameters())
return [optimizer1, optimizer2]
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
Note the type ``Callable[[Iterable], Optimizer]``, which denotes a function that receives a single argument, some
learnable parameters, and returns an optimizer instance. With this, it is possible to select the class and init
arguments for each of the optimizers from the command line, as follows:

cli = MyLightningCLI(MyModel)
.. code-block:: bash
The value given to ``optimizer*_init`` will always be a dictionary including ``class_path`` and ``init_args`` entries.
The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in
``class_path`` and instantiating it using some positional arguments, in this case ``self.parameters()``, and the
``init_args``. Any number of optimizers and learning rate schedulers can be added when using ``link_to``.
$ python trainer.py fit \
--model.optimizer1=Adam \
--model.optimizer1.lr=0.01 \
--model.optimizer2=AdamW \
--model.optimizer2.lr=0.0001
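Equivalently, the same selection can be made programmatically by passing ``args`` on instantiation (a sketch, assuming
the ``MyModel`` defined above; ``run=False`` only parses and instantiates without running a subcommand):

.. code-block:: python

    from pytorch_lightning.cli import LightningCLI

    cli = LightningCLI(
        MyModel,
        run=False,
        auto_configure_optimizers=False,
        args=[
            "--model.optimizer1=Adam",
            "--model.optimizer1.lr=0.01",
            "--model.optimizer2=AdamW",
            "--model.optimizer2.lr=0.0001",
        ],
    )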
With shorthand notation:
In the example above, the ``OptimizerCallable`` type alias was created to illustrate what the type hint means. For
convenience, this type alias and one for learning rate schedulers are available in the ``cli`` module. An example of a
model that uses dependency injection for an optimizer and a learning rate scheduler is:

.. code-block:: bash
.. code-block:: python
$ python trainer.py fit \
--optimizer1=Adam \
--optimizer1.lr=0.01 \
--optimizer2=AdamW \
--optimizer2.lr=0.0001
from pytorch_lightning.cli import OptimizerCallable, LRSchedulerCallable, LightningCLI
You can also pass the class path directly, for example, if the optimizer hasn't been imported:
.. code-block:: bash
class MyModel(LightningModule):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()
self.optimizer = optimizer
self.scheduler = scheduler
$ python trainer.py fit \
--optimizer1=torch.optim.Adam \
--optimizer1.lr=0.01 \
--optimizer2=torch.optim.AdamW \
--optimizer2.lr=0.0001
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
scheduler = self.scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": scheduler}
cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
Note that in this example classes are used as defaults. This is compatible with the type hints, since a class is itself
a callable that receives the same first argument and returns an instance of the class. Classes that have more than one
required argument will not work as defaults, though. In such cases a lambda can be used instead, e.g. ``optimizer:
OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01)``, as shown in the sketch below.
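A minimal sketch of this pattern (an illustration, not part of the diff), with lambdas as defaults; ``StepLR`` is chosen
here only because its required ``step_size`` makes the bare class unusable as a default:

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule
    from pytorch_lightning.cli import LightningCLI, LRSchedulerCallable, OptimizerCallable


    class MyModel(LightningModule):
        def __init__(
            self,
            # lambdas as defaults so that required init arguments (lr, step_size) receive values
            optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
            scheduler: LRSchedulerCallable = lambda o: torch.optim.lr_scheduler.StepLR(o, step_size=10),
        ):
            super().__init__()
            self.optimizer = optimizer
            self.scheduler = scheduler

        def configure_optimizers(self):
            optimizer = self.optimizer(self.parameters())
            scheduler = self.scheduler(optimizer)
            return {"optimizer": optimizer, "lr_scheduler": scheduler}


    cli = LightningCLI(MyModel, auto_configure_optimizers=False)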


Run from Python
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -5,5 +5,5 @@
matplotlib>3.1, <3.6.2
omegaconf>=2.0.5, <2.3.0
hydra-core>=1.0.5, <1.3.0
jsonargparse[signatures]>=4.17.0, <4.18.0
jsonargparse[signatures]>=4.18.0, <4.19.0
rich>=10.14.0, !=10.15.0.a, <13.0.0
4 changes: 3 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -30,8 +30,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))


- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
- Added `LightningCLI` support for optimizers and learning rate schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869))


- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))

- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configuration object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))

32 changes: 21 additions & 11 deletions src/pytorch_lightning/cli.py
@@ -15,7 +15,7 @@
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import torch
from lightning_utilities.core.imports import RequirementCache
@@ -24,6 +24,7 @@

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand All @@ -49,19 +50,22 @@
locals()["Namespace"] = object


ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
super().__init__(optimizer, *args, **kwargs)
self.monitor = monitor


# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]


# Type aliases intended for convenience of CLI developers
ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]]
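For context (an illustrative note, not part of the diff): the ``monitor`` argument on the CLI-specific
``ReduceLROnPlateau`` above lets the automatically implemented ``configure_optimizers`` tell Lightning which logged
metric drives the plateau detection. Roughly, what it produces is equivalent to writing the following by hand in a
hypothetical model:

    import torch
    from pytorch_lightning import LightningModule
    from pytorch_lightning.cli import ReduceLROnPlateau


    class PlateauModel(LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
            scheduler = ReduceLROnPlateau(optimizer, monitor="val_loss")
            # Lightning reads the monitored metric from the lr_scheduler config
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler, "monitor": scheduler.monitor},
            }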


class LightningArgumentParser(ArgumentParser):
@@ -274,6 +278,7 @@ def __init__(
subclass_mode_data: bool = False,
args: ArgsType = None,
run: bool = True,
auto_configure_optimizers: bool = True,
auto_registry: bool = False,
**kwargs: Any, # Remove with deprecations of v1.10
) -> None:
@@ -326,6 +331,7 @@ def __init__(
self.trainer_defaults = trainer_defaults or {}
self.seed_everything_default = seed_everything_default
self.parser_kwargs = parser_kwargs or {} # type: ignore[var-annotated] # github.com/python/mypy/issues/6463
self.auto_configure_optimizers = auto_configure_optimizers

self._handle_deprecated_params(kwargs)

@@ -447,10 +453,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None:
self.add_core_arguments_to_parser(parser)
self.add_arguments_to_parser(parser)
# add default optimizer args if necessary
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args((Optimizer,))
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
if self.auto_configure_optimizers:
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args((Optimizer,))
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
self.link_optimizers_and_lr_schedulers(parser)

def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
@@ -602,6 +609,9 @@ def configure_optimizers(
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
"""Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
if not self.auto_configure_optimizers:
return

parser = self._parser(subcommand)

def get_automatic(
Expand Down
52 changes: 52 additions & 0 deletions tests/tests_pytorch/test_cli.py
@@ -36,7 +36,9 @@
instantiate_class,
LightningArgumentParser,
LightningCLI,
LRSchedulerCallable,
LRSchedulerTypeTuple,
OptimizerCallable,
SaveConfigCallback,
)
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
@@ -706,6 +708,56 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
class TestModel(BoringModel):
def __init__(
self,
optim1: OptimizerCallable = torch.optim.Adam,
optim2: OptimizerCallable = torch.optim.Adagrad,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()
self.optim1 = optim1
self.optim2 = optim2
self.scheduler = scheduler

def configure_optimizers(self):
optim1 = self.optim1(self.parameters())
optim2 = self.optim2(self.parameters())
scheduler = self.scheduler(optim2)
return (
{"optimizer": optim1},
{"optimizer": optim2, "lr_scheduler": scheduler},
)

out = StringIO()
with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
out = out.getvalue()
assert "--optimizer" not in out
assert "--lr_scheduler" not in out
assert "--model.optim1" in out
assert "--model.optim2" in out
assert "--model.scheduler" in out

cli_args = [
"--model.optim1=Adagrad",
"--model.optim2=SGD",
"--model.optim2.lr=0.007",
"--model.scheduler=ExponentialLR",
"--model.scheduler.gamma=0.3",
]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)

init = cli.model.configure_optimizers()
assert isinstance(init[0]["optimizer"], torch.optim.Adagrad)
assert isinstance(init[1]["optimizer"], torch.optim.SGD)
assert isinstance(init[1]["lr_scheduler"], torch.optim.lr_scheduler.ExponentialLR)
assert init[1]["optimizer"].param_groups[0]["lr"] == 0.007
assert init[1]["lr_scheduler"].gamma == 0.3


@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):
