PoC: Revamp optimizer and scheduler experience using registries (#777)
Co-authored-by: tchaton <[email protected]>
karthikrangasai and tchaton authored Oct 18, 2021
1 parent a94ed6c commit b41722a
Showing 35 changed files with 890 additions and 442 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -32,9 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))


- Changed `PreprocessTransform` to `InputTransform` ([#868](https://github.com/PyTorchLightning/lightning-flash/pull/868))

- Optimizer and LR Scheduler registries are used to get the respective inputs to the Task using a string (or a callable). ([#777](https://github.com/PyTorchLightning/lightning-flash/pull/777))


### Fixed

47 changes: 47 additions & 0 deletions README.md
@@ -175,6 +175,53 @@ In detail, the following methods are currently implemented:
* **[metaoptnet](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_metaoptnet.py)** : from Lee *et al.* 2019, [Meta-Learning with Differentiable Convex Optimization](https://arxiv.org/abs/1904.03758)
* **[anil](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_anil.py)** : from Raghu *et al.* 2020, [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML](https://arxiv.org/abs/1909.09157)


### Flash Optimizers / Schedulers

With Flash, swapping among 40+ optimizers and 15+ scheduler recipes is simple. Find the list of available optimizers and schedulers as follows:

```py
ImageClassifier.available_optimizers()
# ['A2GradExp', ..., 'Yogi']

ImageClassifier.available_schedulers()
# ['CosineAnnealingLR', 'CosineAnnealingWarmRestarts', ..., 'polynomial_decay_schedule_with_warmup']
```

Once you've chosen, create the model:

```py
import functools

import torch
from torch.optim.lr_scheduler import CyclicLR

from flash.image import ImageClassifier

#### The optimizer of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), lr_scheduler=None)

# - Tuple[string, dict]: (The dict takes in the optimizer kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("Adadelta", {"eps": 0.5}), lr_scheduler=None)

#### The scheduler of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="constant_schedule")

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=functools.partial(CyclicLR, step_size_up=1500, mode='exp_range', gamma=0.5))

# - Tuple[string, dict]: (The dict takes in the scheduler kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=("StepLR", {"step_size": 10}))
```

You can also register your own custom scheduler recipes beforehand and use them as shown above:

```py
@ImageClassifier.lr_schedulers
def my_steplr_recipe(optimizer):
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_recipe")
```

### Flash Transforms


197 changes: 197 additions & 0 deletions docs/source/general/optimization.rst
@@ -0,0 +1,197 @@

.. _optimization:

########################################
Optimization (Optimizers and Schedulers)
########################################

Using optimizers and learning rate schedulers with Flash has become easier and cleaner than ever.

With the use of the :ref:`registry`, an optimizer or a learning rate scheduler can be instantiated with just a string.

Setting an optimizer to a task
==============================

Each task has a built-in method :func:`~flash.core.model.Task.available_optimizers` which will list all the optimizers
registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_optimizers()
['adadelta', ..., 'sgd']

To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass in a string.

.. code-block:: python

    from flash.image import ImageClassifier

    model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)

In order to customize specific parameters of the Optimizer, pass along a dictionary of kwargs with the string as a tuple.

.. code-block:: python

    from flash.image import ImageClassifier

    model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)

An alternative to customizing an optimizer using a tuple is to pass it as a callable.

.. code-block:: python

    from functools import partial
    from torch.optim import Adam
    from flash.image import ImageClassifier

    model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)

Setting a Learning Rate Scheduler
=================================

Each task has a built-in method :func:`~flash.core.model.Task.available_lr_schedulers` which will list all the learning
rate schedulers registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_lr_schedulers()
['lambdalr', ..., 'cosineannealingwarmrestarts']

To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass in a string.

.. code-block:: python

    from flash.image import ImageClassifier

    model = ImageClassifier(
        num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
    )

.. note:: ``"constant_schedule"`` and a few other lr schedulers will be available only if you have installed the ``transformers`` library from Hugging Face.


In order to customize specific parameters of the LR Scheduler, pass along a dictionary of kwargs with the string as a tuple.

.. code-block:: python

    from flash.image import ImageClassifier

    model = ImageClassifier(
        num_classes=10,
        backbone="resnet18",
        optimizer="Adam",
        learning_rate=1e-4,
        lr_scheduler=("StepLR", {"step_size": 10}),
    )

An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable.

.. code-block:: python

    from functools import partial
    from torch.optim.lr_scheduler import CyclicLR
    from flash.image import ImageClassifier

    model = ImageClassifier(
        num_classes=10,
        backbone="resnet18",
        optimizer="Adam",
        learning_rate=1e-4,
        lr_scheduler=partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5),
    )

Additionally, the ``lr_scheduler`` parameter also accepts the Lightning Scheduler configuration which can be passed on using a tuple.

The Lightning Scheduler configuration is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

.. code-block:: python

    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified in 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }

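Parts of this configuration can be overridden through the third element of the tuple. Below is a minimal sketch (assuming the remaining keys keep the defaults shown above):

.. code-block:: python

    from flash.image import ImageClassifier

    model = ImageClassifier(
        num_classes=10,
        backbone="resnet18",
        optimizer="Adam",
        learning_rate=1e-4,
        # Step the scheduler after every 10 optimizer updates instead of once per epoch.
        lr_scheduler=("StepLR", {"step_size": 10}, {"interval": "step", "frequency": 10}),
    )
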
For schedulers whose ``.step()`` method is conditioned on a value, such as ``torch.optim.lr_scheduler.ReduceLROnPlateau``,
Flash requires that the Lightning Scheduler configuration contains the keyword ``"monitor"`` set to the name of the metric that the scheduler should be conditioned on.
Below is an example:

.. code-block:: python

    from flash.image import ImageClassifier

    model = ImageClassifier(
        num_classes=10,
        backbone="resnet18",
        optimizer="Adam",
        learning_rate=1e-4,
        lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
    )

.. note:: Do not set the ``"scheduler"`` key in the Lightning Scheduler configuration; it will be overridden with an instance of the provided scheduler.


Pre-Registering optimizers and scheduler recipes
================================================

The Flash registry also provides the flexibility of registering functions. This feature is also available in the optimizer and LR scheduler registries.

Using the ``optimizers`` and ``lr_schedulers`` decorators of each :class:`~flash.core.model.Task`, custom optimizer and LR scheduler recipes can be pre-registered.

.. code-block:: python

    import torch

    from flash.image import ImageClassifier


    @ImageClassifier.lr_schedulers
    def my_flash_steplr_recipe(optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)


    model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe")

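The ``optimizers`` decorator can be used in the same way to register a custom optimizer recipe. The sketch below assumes that a registered optimizer recipe receives the model parameters and the learning rate; this call convention is an assumption, so check the registry documentation of your Flash version before relying on it.

.. code-block:: python

    import torch

    from flash.image import ImageClassifier


    @ImageClassifier.optimizers
    def my_adam_recipe(parameters, lr):
        # Assumed call convention: Flash passes the model parameters and the learning rate.
        return torch.optim.Adam(parameters, lr=lr, amsgrad=True)


    model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="my_adam_recipe")
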
Provider specific requirements
==============================

Schedulers
**********

Certain LR Schedulers provided by Hugging Face require both ``num_training_steps`` and ``num_warmup_steps``.

In order to use them in Flash, just provide ``num_warmup_steps`` as a float between 0 and 1, which indicates the fraction of the training steps
that will be used as warmup steps. Flash's :class:`~flash.core.trainer.Trainer` will take care of computing the number of training steps and the
number of warmup steps based on the flags that are set in the Trainer.

.. code-block:: python

    from flash.image import ImageClassifier

    model = ImageClassifier(
        backbone="resnet18",
        num_classes=2,
        optimizer="Adam",
        lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}),
    )

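For intuition, here is a rough sketch of the arithmetic involved (an assumption about the bookkeeping, not the exact Trainer implementation):

.. code-block:: python

    # With e.g. 10 epochs of 100 batches each, the Trainer sees 1000 training steps,
    # so ``num_warmup_steps=0.1`` translates to roughly 100 warmup steps.
    total_training_steps = 10 * 100
    num_warmup_steps = int(0.1 * total_training_steps)  # -> 100
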
5 changes: 3 additions & 2 deletions docs/source/general/registry.rst
@@ -1,9 +1,10 @@

.. _registry:

########
Registry
########

.. _registry:

********************
Available Registries
********************
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -33,6 +33,7 @@ Lightning Flash
general/registry
general/serve
general/backbones
general/optimization

.. toctree::
:maxdepth: 1
23 changes: 8 additions & 15 deletions flash/audio/speech_recognition/model.py
@@ -13,20 +13,19 @@
# limitations under the License.
import os
import warnings
from typing import Any, Dict, Mapping, Optional, Type, Union
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler

from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState
from flash.core.data.process import Serializer
from flash.core.data.states import CollateFn
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE

if _AUDIO_AVAILABLE:
from transformers import Wav2Vec2Processor
@@ -39,11 +38,9 @@ class SpeechRecognition(Task):
Args:
backbone: Any speech recognition model from `HuggingFace/transformers
<https://huggingface.co/models?pipeline_tag=automatic-speech-recognition>`_.
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
scheduler: The scheduler or scheduler class to use.
scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
lr_scheduler: The LR scheduler to use during training.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
"""

@@ -54,12 +51,10 @@
def __init__(
self,
backbone: str = "facebook/wav2vec2-base-960h",
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
learning_rate: float = 1e-5,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
serializer: SERIALIZER_TYPE = None,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
@@ -71,9 +66,7 @@ def __init__(
super().__init__(
model=model,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
lr_scheduler=lr_scheduler,
learning_rate=learning_rate,
serializer=serializer,
)