Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

PoC: Revamp optimizer and scheduler experience using registries #777

Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
2f46b93
Change optimizer Callables alone and scheduler to support Callables a…
karthikrangasai Sep 15, 2021
caefe68
Add Optimizer Registry and Update __init__ for all tasks.
karthikrangasai Sep 15, 2021
93bc1b5
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 20, 2021
7ea53a2
Revamp scheduler parameter to use str, Callable, str with params.
karthikrangasai Sep 22, 2021
e95a209
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 26, 2021
4cf6cdd
Updated _instantiate_scheduler method to handle providers. Added supp…
karthikrangasai Sep 26, 2021
440aef2
wip
tchaton Sep 27, 2021
094b690
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 29, 2021
06e7722
Updated scheduler parameter to take input as type Tuple[str, Dict[str…
karthikrangasai Sep 29, 2021
8ab54bd
Update naming of scheduler parameter to lr_scheduler.
karthikrangasai Sep 29, 2021
617e53a
Update optimizer and lr_scheduler parameter across all tasks.
karthikrangasai Sep 29, 2021
dd5615e
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Sep 29, 2021
7a3029b
Updated optimizer registration code to compare with optimizer types a…
karthikrangasai Sep 29, 2021
d36c451
Added tests for Errors and Exceptions.
karthikrangasai Sep 29, 2021
061454b
Update README with examples on using the API.
karthikrangasai Sep 30, 2021
c611aa8
Update skipif condition only to check for transformers library instea…
karthikrangasai Sep 30, 2021
64cedf3
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 1, 2021
e158802
Update newly added Face Detection Task.
karthikrangasai Oct 1, 2021
c8cb598
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 4, 2021
fcb3916
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 7, 2021
eda81ae
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 13, 2021
20eacaf
Changes from code review, Add new input method to lr_scheduler parame…
karthikrangasai Oct 13, 2021
87cf563
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 13, 2021
ddb5d1f
Fix pre-commit ci review.
karthikrangasai Oct 13, 2021
eb3aaec
Add documentation for using the modified API and update CHANGELOG.
karthikrangasai Oct 14, 2021
50c936a
Update docstrings for all tasks.
karthikrangasai Oct 14, 2021
5dfbeae
Fix mistake in my CHANGELOG update.
karthikrangasai Oct 14, 2021
93dbe67
Removed optimizer old that was commented code.
karthikrangasai Oct 14, 2021
42e3bf4
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 14, 2021
5e76ea3
Fix dependency version for failing tests on text type data, module - …
karthikrangasai Oct 14, 2021
ec348bf
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 15, 2021
c49d70a
Changes from review - Fix docs, Add test, Clean up certian parts of t…
karthikrangasai Oct 15, 2021
66c30bc
Merge branch 'master' into refactor/revamp_optimizer_and_scheduler
karthikrangasai Oct 18, 2021
35f3834
Remove debug print statements.
karthikrangasai Oct 18, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,53 @@ In detail, the following methods are currently implemented:
* **[metaoptnet](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_metaoptnet.py)** : from Lee *et al.* 2019, [Meta-Learning with Differentiable Convex Optimization](https://arxiv.org/abs/1904.03758)
* **[anil](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/lightning/lightning_anil.py)** : from Raghu *et al.* 2020, [Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML](https://arxiv.org/abs/1909.09157)


### Flash Optimizers / Schedulers

With Flash, swapping among 40+ optimizers and 15 + schedulers recipes are simple. Find the list of available optimizers, schedulers as follows:

```py
ImageClassifier.available_optimizers()
# ['A2GradExp', ..., 'Yogi']

ImageClassifier.available_schedulers()
# ['CosineAnnealingLR', 'CosineAnnealingWarmRestarts', ..., 'polynomial_decay_schedule_with_warmup']
```

Once you've chosen, create the model:

```py
#### The optimizer of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.AdaDelta, eps=0.5), lr_scheduler=None)
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved

# - Tuple[string, dict]: (The dict takes in the optimizer kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("AdaDelta", {"epa": 0.5}), lr_scheduler=None)

#### The scheduler of choice can be passed as a
# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="constant_schedule")

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=functools.partial(CyclicLR, step_size_up=1500, mode='exp_range', gamma=0.5))

# - Tuple[string, dict]: (The dict takes in the scheduler kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=("StepLR", {"step_size": 10]))
```

You can also register you own custom scheduler recipes beforeahand and use them shown as above:

```py
@ImageClassifier.lr_schedulers
def my_steplr_recipe(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_recipe")
```

### Flash Transforms


Expand Down
17 changes: 5 additions & 12 deletions flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.
import os
import warnings
from typing import Any, Dict, Mapping, Optional, Type, Union
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler

from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES
from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding
Expand All @@ -40,9 +39,7 @@ class SpeechRecognition(Task):
backbone: Any speech recognition model from `HuggingFace/transformers
<https://huggingface.co/models?pipeline_tag=automatic-speech-recognition>`_.
optimizer: Optimizer to use for training.
optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
scheduler: The scheduler or scheduler class to use.
scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
lr_scheduler: The scheduler or scheduler class to use.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the learning rate which is obselete now.

serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
"""
Expand All @@ -54,10 +51,8 @@ class SpeechRecognition(Task):
def __init__(
self,
backbone: str = "facebook/wav2vec2-base-960h",
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam",
lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None,
learning_rate: float = 1e-5,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
):
Expand All @@ -71,9 +66,7 @@ def __init__(
super().__init__(
model=model,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
lr_scheduler=lr_scheduler,
learning_rate=learning_rate,
serializer=serializer,
)
Expand Down
172 changes: 137 additions & 35 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torchmetrics
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.enums import LightningEnum
Expand All @@ -46,9 +47,10 @@
SerializerMapping,
)
from flash.core.data.properties import ProcessState
from flash.core.optimizers import _OPTIMIZERS_REGISTRY, _SCHEDULERS_REGISTRY
from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.serve import Composition
from flash.core.utilities import providers
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires

Expand Down Expand Up @@ -297,26 +299,26 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=Check
model: Model to use for the task.
loss_fn: Loss function for training
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.Adam`.
lr_scheduler: The scheduler or scheduler class to use.
metrics: Metrics to compute for training and evaluation.
learning_rate: Learning rate to use for training, defaults to ``5e-5``.
preprocess: :class:`~flash.core.data.process.Preprocess` to use as the default for this task.
postprocess: :class:`~flash.core.data.process.Postprocess` to use as the default for this task.
"""

schedulers: FlashRegistry = _SCHEDULERS_REGISTRY
optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY
lr_schedulers: FlashRegistry = _SCHEDULERS_REGISTRY

required_extras: Optional[Union[str, List[str]]] = None

def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
scheduler_kwargs: Optional[Dict[str, Any]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
optimizer: Union[str, Callable, Tuple[str, Dict[str, Any]]] = "Adam",
lr_scheduler: Optional[Union[str, Callable, Tuple[str, Dict[str, Any]]]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
deserializer: Optional[Union[Deserializer, Mapping[str, Deserializer]]] = None,
preprocess: Optional[Preprocess] = None,
postprocess: Optional[Postprocess] = None,
Expand All @@ -327,9 +329,7 @@ def __init__(
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer = optimizer
self.scheduler = scheduler
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}
self.lr_scheduler = lr_scheduler

self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
Expand Down Expand Up @@ -474,12 +474,38 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
return self(batch)

def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]:
optimizer = self.optimizer
if not isinstance(self.optimizer, Optimizer):
self.optimizer_kwargs["lr"] = self.learning_rate
optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_kwargs)
if self.scheduler:
return [optimizer], [self._instantiate_scheduler(optimizer)]
if isinstance(self.optimizer, str):
if self.optimizer.lower() not in self.available_optimizers():
raise KeyError(
f"""Please provide a valid optimizer name and make sure it is registerd with the Optimizer registry.
Use `{self.__class__.__name__}.available_optimizers`."""
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
)
optimizer_fn = self.optimizers.get(self.optimizer.lower())
_optimizers_kwargs: Dict[str, Any] = {}
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(self.optimizer, Callable):
optimizer_fn = self.optimizer
_optimizers_kwargs: Dict[str, Any] = {}
elif isinstance(self.optimizer, Tuple):
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
optimizer_fn: Callable = None
optimizer_key: str = self.optimizer[0]

if not isinstance(optimizer_key, str):
raise TypeError(
f"PThe first value in scheduler argument tuple should be a string but got {type(optimizer_key)}."
)

optimizer_fn = self.optimizers.get(optimizer_key.lower())
_optimizers_kwargs: Dict[str, Any] = self.optimizer[1]
else:
raise TypeError(
f"""Optimizer should be of type string or callable or tuple(string, dictionary)
but got {type(self.optimizer)}."""
)

model_parameters = filter(lambda p: p.requires_grad, self.parameters())
optimizer: Optimizer = optimizer_fn(model_parameters, lr=self.learning_rate, **_optimizers_kwargs)
if self.lr_scheduler is not None:
return [optimizer], [self._instantiate_lr_scheduler(optimizer)]
return optimizer

@staticmethod
Expand Down Expand Up @@ -775,8 +801,15 @@ def get_backbone_details(cls, key) -> List[str]:
return list(inspect.signature(registry.get(key)).parameters.items())

@classmethod
def available_schedulers(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "schedulers", None)
def available_optimizers(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "optimizers", None)
if registry is None:
return []
return registry.available_keys()

@classmethod
def available_lr_schedulers(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "lr_schedulers", None)
if registry is None:
return []
return registry.available_keys()
Expand Down Expand Up @@ -816,24 +849,93 @@ def _compute_warmup(num_training_steps: int, num_warmup_steps: Union[int, float]
num_warmup_steps *= num_training_steps
return round(num_warmup_steps)

def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler:
scheduler = self.scheduler
if isinstance(scheduler, _LRScheduler):
return scheduler
if isinstance(scheduler, str):
scheduler_fn = self.schedulers.get(self.scheduler)
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
num_training_steps=num_training_steps,
num_warmup_steps=self.scheduler_kwargs.get("num_warmup_steps"),
def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
if isinstance(self.lr_scheduler, str) or isinstance(self.lr_scheduler, Callable):
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(self.lr_scheduler, str) and self.lr_scheduler.lower() not in self.available_lr_schedulers():
raise KeyError(
f"""Please provide a valid key and make sure it is registerd with the Scheduler registry.
Use `{self.__class__.__name__}.available_schedulers`."""
)

# Get values based in type.
if isinstance(self.lr_scheduler, str):
_lr_scheduler = self.lr_schedulers.get(self.lr_scheduler.lower(), with_metadata=True)
lr_scheduler_fn: Callable = _lr_scheduler["fn"]
lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"]
else:
lr_scheduler_fn: Callable = self.lr_scheduler

# Generate the output: could be a lr_scheduler object or a lr_scheduler config.
sched_output: Union[_LRScheduler, Dict[str, Any]] = lr_scheduler_fn(optimizer)

# Create and/or update a lr_scheduler configuration
lr_scheduler_config = _get_default_scheduler_config()
if isinstance(sched_output, _LRScheduler):
lr_scheduler_config["scheduler"] = sched_output
if isinstance(self.lr_scheduler, str) and "interval" in lr_scheduler_metadata.keys():
lr_scheduler_config["interval"] = lr_scheduler_metadata["interval"]
elif isinstance(sched_output, dict):
for key, value in sched_output.items():
lr_scheduler_config[key] = value
else:
if isinstance(self.lr_scheduler, str):
message = "register a custom callable"
else:
message = "provide a callable"
raise MisconfigurationException(
f"Please {message} that outputs either an LR Scheduler or a scheduler condifguration."
)

return lr_scheduler_config

if not isinstance(self.lr_scheduler, Tuple):
raise TypeError("The scheduler arguments should be provided as a tuple.")

if not isinstance(self.lr_scheduler[0], str):
raise TypeError(
f"""The first value in scheduler argument tuple should be a string but got
{type(self.lr_scheduler[0])}."""
)
return scheduler_fn(optimizer, num_warmup_steps, num_training_steps)
if issubclass(scheduler, _LRScheduler):
return scheduler(optimizer, **self.scheduler_kwargs)
raise MisconfigurationException(
"scheduler can be a scheduler, a scheduler type with `scheduler_kwargs` "
f"or a built-in scheduler in {self.available_schedulers()}"
)

# Separate the key and the kwargs.
lr_scheduler_key: str = self.lr_scheduler[0]
lr_scheduler_kwargs_and_config: Dict[str, Any] = self.lr_scheduler[1]
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved

# Get the default scheduler config.
lr_scheduler_config: Dict[str, Any] = _get_default_scheduler_config()
lr_scheduler_config["interval"] = None

# Update scheduler config from the kwargs and pop the keys from the kwargs at the same time.
for config_key, config_value in lr_scheduler_config.items():
lr_scheduler_config[config_key] = lr_scheduler_kwargs_and_config.pop(config_key, None) or config_value

# Create a new copy of the kwargs.
lr_scheduler_kwargs = deepcopy(lr_scheduler_kwargs_and_config)
assert all(config_key not in lr_scheduler_kwargs.keys() for config_key in lr_scheduler_config.keys())

# Retreive the scheduler callable with metadata from the registry.
_lr_scheduler = self.lr_schedulers.get(lr_scheduler_key.lower(), with_metadata=True)
lr_scheduler_fn: Callable = _lr_scheduler["fn"]
lr_scheduler_metadata: Dict[str, Any] = _lr_scheduler["metadata"]

# Make necessary adjustment to the kwargs based on the provider of the scheduler.
if "providers" in lr_scheduler_metadata.keys():
if lr_scheduler_metadata["providers"] == providers._HUGGINGFACE:
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
num_training_steps=num_training_steps,
num_warmup_steps=lr_scheduler_kwargs["num_warmup_steps"],
)
lr_scheduler_kwargs["num_warmup_steps"] = num_warmup_steps
lr_scheduler_kwargs["num_training_steps"] = num_training_steps

# Set the scheduler in the config.
lr_scheduler_config["scheduler"] = lr_scheduler_fn(optimizer, **lr_scheduler_kwargs)

# Update the interval in sched config just in case it has NoneType.
if "interval" in lr_scheduler_metadata.keys():
lr_scheduler_config["interval"] = lr_scheduler_config["interval"] or lr_scheduler_metadata["interval"]
return lr_scheduler_config

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
Expand Down
2 changes: 2 additions & 0 deletions flash/core/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from flash.core.optimizers.lamb import LAMB # noqa: F401
from flash.core.optimizers.lars import LARS # noqa: F401
from flash.core.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR # noqa: F401
from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY # noqa: F401
from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY # noqa: F401
35 changes: 35 additions & 0 deletions flash/core/optimizers/optimizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from inspect import isclass
from typing import Callable, List

from torch import optim

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE

_OPTIMIZERS_REGISTRY = FlashRegistry("optimizer")

_optimizers: List[Callable] = []
for n in dir(optim):
_optimizer = getattr(optim, n)

if isclass(_optimizer) and _optimizer != optim.Optimizer and issubclass(_optimizer, optim.Optimizer):
_optimizers.append(_optimizer)

for fn in _optimizers:
_OPTIMIZERS_REGISTRY(fn, name=fn.__name__.lower())


if _TORCH_OPTIMIZER_AVAILABLE:
import torch_optimizer

_torch_optimizers: List[Callable] = []
for n in dir(torch_optimizer):
_optimizer = getattr(torch_optimizer, n)

if isclass(_optimizer) and issubclass(_optimizer, optim.Optimizer):
_torch_optimizers.append(_optimizer)

for fn in _torch_optimizers:
name = fn.__name__.lower()
if name not in _OPTIMIZERS_REGISTRY:
_OPTIMIZERS_REGISTRY(fn, name=name)
Loading