Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

PoC: Revamp optimizer and scheduler experience using registries #777

Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
2f46b93
Change optimizer Callables alone and scheduler to support Callables a…
karthikrangasai Sep 15, 2021
caefe68
Add Optimizer Registry and Update __init__ for all tasks.
karthikrangasai Sep 15, 2021
93bc1b5
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 20, 2021
7ea53a2
Revamp scheduler parameter to use str, Callable, str with params.
karthikrangasai Sep 22, 2021
e95a209
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 26, 2021
4cf6cdd
Updated _instantiate_scheduler method to handle providers. Added supp…
karthikrangasai Sep 26, 2021
440aef2
wip
tchaton Sep 27, 2021
094b690
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 29, 2021
06e7722
Updated scheduler parameter to take input as type Tuple[str, Dict[str…
karthikrangasai Sep 29, 2021
8ab54bd
Update naming of scheduler parameter to lr_scheduler.
karthikrangasai Sep 29, 2021
617e53a
Update optimizer and lr_scheduler parameter across all tasks.
karthikrangasai Sep 29, 2021
dd5615e
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 29, 2021
7a3029b
Updated optimizer registration code to compare with optimizer types a…
karthikrangasai Sep 29, 2021
d36c451
Added tests for Errors and Exceptions.
karthikrangasai Sep 29, 2021
061454b
Update README with examples on using the API.
karthikrangasai Sep 30, 2021
c611aa8
Update skipif condition only to check for transformers library instea…
karthikrangasai Sep 30, 2021
64cedf3
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 1, 2021
e158802
Update newly added Face Detection Task.
karthikrangasai Oct 1, 2021
c8cb598
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 4, 2021
fcb3916
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 7, 2021
eda81ae
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 13, 2021
20eacaf
Changes from code review, Add new input method to lr_scheduler parame…
karthikrangasai Oct 13, 2021
87cf563
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 13, 2021
ddb5d1f
Fix pre-commit ci review.
karthikrangasai Oct 13, 2021
eb3aaec
Add documentation for using the modified API and update CHANGELOG.
karthikrangasai Oct 14, 2021
50c936a
Update docstrings for all tasks.
karthikrangasai Oct 14, 2021
5dfbeae
Fix mistake in my CHANGELOG update.
karthikrangasai Oct 14, 2021
93dbe67
Removed optimizer old that was commented code.
karthikrangasai Oct 14, 2021
42e3bf4
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 14, 2021
5e76ea3
Fix dependency version for failing tests on text type data, module - …
karthikrangasai Oct 14, 2021
ec348bf
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 15, 2021
c49d70a
Changes from review - Fix docs, Add test, Clean up certian parts of t…
karthikrangasai Oct 15, 2021
66c30bc
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 18, 2021
35f3834
Remove debug print statements.
karthikrangasai Oct 18, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))


- Changed `PreprocessTransform` to `InputTransform` ([#868](https://github.com/PyTorchLightning/lightning-flash/pull/868))

- Optimizer and LR Scheduler registry are used to get the respective inputs to the Task using a string (or a callable). ([#777](https://github.com/PyTorchLightning/lightning-flash/pull/777))


### Fixed

Expand Down
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,53 @@ In detail, the following methods are currently implemented:
* **[metaoptnet](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_metaoptnet.py)** : from Lee *et al.* 2019, [Meta-Learning with Differentiable Convex Optimization](https://arxiv.org/abs/1904.03758)
* **[anil](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_anil.py)** : from Raghu *et al.* 2020, [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML](https://arxiv.org/abs/1909.09157)


### Flash Optimizers / Schedulers

With Flash, swapping among 40+ optimizers and 15 + schedulers recipes are simple. Find the list of available optimizers, schedulers as follows:

```py
ImageClassifier.available_optimizers()
# ['A2GradExp', ..., 'Yogi']

ImageClassifier.available_schedulers()
# ['CosineAnnealingLR', 'CosineAnnealingWarmRestarts', ..., 'polynomial_decay_schedule_with_warmup']
```

Once you've chosen, create the model:

```py
#### The optimizer of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), lr_scheduler=None)

# - Tuple[string, dict]: (The dict takes in the optimizer kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("Adadelta", {"epa": 0.5}), lr_scheduler=None)

#### The scheduler of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="constant_schedule")

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=functools.partial(CyclicLR, step_size_up=1500, mode='exp_range', gamma=0.5))

# - Tuple[string, dict]: (The dict takes in the scheduler kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=("StepLR", {"step_size": 10]))
```

You can also register you own custom scheduler recipes beforeahand and use them shown as above:

```py
@ImageClassifier.lr_schedulers
def my_steplr_recipe(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_recipe")
```

### Flash Transforms


Expand Down
197 changes: 197 additions & 0 deletions docs/source/general/optimization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@

.. _optimization:

########################################
Optimization (Optimizers and Schedulers)
########################################

Using optimizers and learning rate schedulers with Flash has become easier and cleaner than ever.

With the use of :ref:`registry`, instantiation of an optimzer or a learning rate scheduler can done with just a string.

Setting an optimizer to a task
==============================

Each task has a built-in method :func:`~flash.core.model.Task.available_optimizers` which will list all the optimizers
registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_optimizers()
['adadelta', ..., 'sgd']

To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass on a string.

.. code-block:: python
from flash.image import ImageClassifier
model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)
In order to customize specific parameters of the Optimizer, pass along a dictionary of kwargs with the string as a tuple.

.. code-block:: python
from flash.image import ImageClassifier
model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)
An alternative to customizing an optimizer using a tuple is to pass it as a callable.

.. code-block:: python
from functools import partial
from torch.optim import Adam
from flash.image import ImageClassifier
model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)
Setting a Learning Rate Scheduler
=================================

Each task has a built-in method :func:`~flash.core.model.Task.available_lr_schedulers` which will list all the learning
rate schedulers registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_lr_schedulers()
['lambdalr', ..., 'cosineannealingwarmrestarts']

To train / finetune a :class:`~flash.core.model.Task` of your choice, just pass on a string.

.. code-block:: python
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
tchaton marked this conversation as resolved.
Show resolved Hide resolved
)
.. note:: ``"constant_schedule"`` and a few other lr schedulers will be available only if you have installed the ``transformers`` library from Hugging Face.


In order to customize specific parameters of the LR Scheduler, pass along a dictionary of kwargs with the string as a tuple.

.. code-block:: python
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=("StepLR", {"step_size": 10}),
)
An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable.

.. code-block:: python
from functools import partial
from torch.optim.lr_scheduler import CyclicLR
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5),
)
Additionally, the ``lr_scheduler`` parameter also accepts the Lightning Scheduler configuration which can be passed on using a tuple.

The Lightning Scheduler configuration is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

.. code-block:: python
lr_scheduler_config = {
# REQUIRED: The scheduler instance
"scheduler": lr_scheduler,
# The unit of the scheduler's step size, could also be 'step'.
# 'epoch' updates the scheduler on epoch end whereas 'step'
# updates it after a optimizer update.
"interval": "epoch",
# How many epochs/steps should pass between calls to
# `scheduler.step()`. 1 corresponds to updating the learning
# rate after every epoch/step.
"frequency": 1,
# Metric to to monitor for schedulers like `ReduceLROnPlateau`
"monitor": "val_loss",
# If set to `True`, will enforce that the value specified 'monitor'
# is available when the scheduler is updated, thus stopping
# training if not found. If set to `False`, it will only produce a warning
"strict": True,
# If using the `LearningRateMonitor` callback to monitor the
# learning rate progress, this keyword can be used to specify
# a custom logged name
"name": None,
}
When there are schedulers in which the ``.step()`` method is conditioned on a value, such as the ``torch.optim.lr_scheduler.ReduceLROnPlateau`` scheduler,
Flash requires that the Lightning Scheduler configuration contains the keyword ``"monitor"`` set to the metric name that the scheduler should be conditioned on.
Below is an example for this:

.. code-block:: python
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
)
.. note:: Do not set the ``"scheduler"`` key in the Lightning Scheduler configuration, it will overriden with an instance of the provided scheduler key.


Pre-Registering optimizers and scheduler recipes
================================================

Flash registry also provides the flexiblty of registering functions. This feature is also provided in the Optimizer and Scheduler registry.

Using the ``optimizers`` and ``lr_schedulers`` decorator pertaining to each :class:`~flash.core.model.Task`, custom optimizer and LR scheduler recipes can be pre-registered.

.. code-block:: python
import torch
from flash.image import ImageClassifier
@ImageClassifier.lr_schedulers
def my_flash_steplr_recipe(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe")
Provider specific requirements
==============================

Schedulers
**********

Certain LR Schedulers provided by Hugging Face require both ``num_training_steps`` and ``num_warmup_steps``.

In order to use them in Flash, just provide ``num_warmup_steps`` as float between 0 and 1 which indicates the fraction of the training steps
that will be used as warmup steps. Flash's :class:`~flash.core.trainer.Trainer` will take care of computing the number of training steps and
number of warmup steps based on the flags that are set in the Trainer.

.. code-block:: python
from flash.image import ImageClassifier
model = ImageClassifier(
backbone="resnet18",
num_classes=2,
optimizer="Adam",
lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}),
)
5 changes: 3 additions & 2 deletions docs/source/general/registry.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@

.. _registry:

########
Registry
########

.. _registry:

********************
Available Registries
********************
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Lightning Flash
general/registry
general/serve
general/backbones
general/optimization

.. toctree::
:maxdepth: 1
Expand Down
23 changes: 8 additions & 15 deletions flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,19 @@
# limitations under the License.
import os
import warnings
from typing import Any, Dict, Mapping, Optional, Type, Union
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler

from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState
from flash.core.data.process import Serializer
from flash.core.data.states import CollateFn
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE

if _AUDIO_AVAILABLE:
from transformers import Wav2Vec2Processor
Expand All @@ -39,11 +38,9 @@ class SpeechRecognition(Task):
Args:
backbone: Any speech recognition model from `HuggingFace/transformers
<https://huggingface.co/models?pipeline_tag=automatic-speech-recognition>`_.
learning_rate: Learning rate to use for training, defaults to ``1e-5``.
optimizer: Optimizer to use for training.
optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
scheduler: The scheduler or scheduler class to use.
scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
lr_scheduler: The LR scheduler to use during training.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
"""

Expand All @@ -54,12 +51,10 @@ class SpeechRecognition(Task):
def __init__(
self,
backbone: str = "facebook/wav2vec2-base-960h",
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
learning_rate: float = 1e-5,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
serializer: SERIALIZER_TYPE = None,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
Expand All @@ -71,9 +66,7 @@ def __init__(
super().__init__(
model=model,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
lr_scheduler=lr_scheduler,
learning_rate=learning_rate,
serializer=serializer,
)
Expand Down
Loading