diff --git a/flash/core/model.py b/flash/core/model.py index 1020663340..233c9fa5fb 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -14,7 +14,7 @@ import functools import inspect from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import torch import torchmetrics @@ -27,6 +27,8 @@ from flash.core.utils import get_callable_dict from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess +SchedulerType = Union[Tuple[torch.optim.lr_scheduler._LRScheduler], torch.optim.lr_scheduler._LRScheduler] + def predict_context(func: Callable) -> Callable: """ @@ -63,18 +65,20 @@ class Task(LightningModule): """ def __init__( - self, - model: Optional[nn.Module] = None, - loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 5e-5, + self, + model: Optional[nn.Module] = None, + loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, + optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 5e-5, ): super().__init__() if model is not None: self.model = model self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn) self.optimizer_cls = optimizer + self.scheduler = scheduler self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.learning_rate = learning_rate # TODO: should we save more?
Bug on some regarding yaml if we save metrics @@ -126,9 +130,9 @@ def test_step(self, batch: Any, batch_idx: int) -> None: @predict_context def predict( - self, - x: Any, - data_pipeline: Optional[DataPipeline] = None, + self, + x: Any, + data_pipeline: Optional[DataPipeline] = None, ) -> Any: """ Predict function for raw data or processed data @@ -159,8 +163,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A batch = torch.stack(batch) return self(batch) - def configure_optimizers(self) -> torch.optim.Optimizer: - return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate) + def configure_optimizers( + self + ) -> Union[Tuple[torch.optim.Optimizer, Any], torch.optim.Optimizer]: + optimizers = self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate) + if self.scheduler: + return optimizers, self.scheduler(optimizer=optimizers) + return optimizers def configure_finetune_callback(self) -> List[Callback]: return []