
expose scheduler to Task #189

Closed
wants to merge 12 commits
flash/core/model.py (21 additions, 12 deletions)
@@ -14,7 +14,7 @@
 import functools
 import inspect
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import torch
 import torchmetrics
@@ -27,6 +27,8 @@
 from flash.core.utils import get_callable_dict
 from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
 
+SchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, Tuple[torch.optim.lr_scheduler._LRScheduler, ...]]
+
 
 def predict_context(func: Callable) -> Callable:
     """
@@ -63,18 +65,20 @@ class Task(LightningModule):
"""

def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
):
super().__init__()
if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer_cls = optimizer
self.scheduler = scheduler
self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
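With scheduler exposed, a Task could be constructed roughly as below. A minimal sketch assuming this branch: configure_optimizers calls the scheduler class with only the optimizer, so functools.partial binds StepLR's remaining step_size argument up front (a partial is not strictly a Type[_LRScheduler], but it satisfies the call).

import functools

import torch
from flash.core.model import Task

# Sketch under this PR's API: bind step_size so the scheduler is
# callable with the optimizer alone, as configure_optimizers expects.
task = Task(
    model=torch.nn.Linear(16, 2),
    loss_fn=torch.nn.functional.cross_entropy,
    optimizer=torch.optim.Adam,
    scheduler=functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10),
    learning_rate=1e-3,
)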
@@ -126,9 +130,9 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
 
     @predict_context
     def predict(
         self,
         x: Any,
         data_pipeline: Optional[DataPipeline] = None,
     ) -> Any:
         """
         Predict function for raw data or processed data
@@ -159,8 +163,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
         batch = torch.stack(batch)
         return self(batch)
 
-    def configure_optimizers(self) -> torch.optim.Optimizer:
-        return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+    def configure_optimizers(
+        self,
+    ) -> Union[torch.optim.Optimizer, Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]]:
+        optimizer = self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+        if self.scheduler is not None:
+            return optimizer, self.scheduler(optimizer=optimizer)
+        return optimizer
 
     def configure_finetune_callback(self) -> List[Callback]:
         return []
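If the flat (optimizer, scheduler) tuple proves ambiguous to Lightning, the two-list return format that Lightning documents is an alternative; a hedged sketch of the same method in that form, not this PR's code:

    def configure_optimizers(self):
        # Sketch: Lightning's documented two-list format,
        # ([optimizers], [schedulers]), instead of a flat tuple.
        optimizer = self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
        if self.scheduler is not None:
            return [optimizer], [self.scheduler(optimizer=optimizer)]
        return optimizer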