This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

expose scheduler to Task #189

Closed
wants to merge 12 commits
39 changes: 24 additions & 15 deletions
flash/core/model.py
@@ -13,7 +13,7 @@
# limitations under the License.
import functools
import os
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union, Tuple

import pytorch_lightning as pl
import torch
@@ -22,6 +22,8 @@
from flash.core.data import DataModule, DataPipeline
from flash.core.utils import get_callable_dict

SchedulerType = Union[Tuple[torch.optim.lr_scheduler._LRScheduler], torch.optim.lr_scheduler._LRScheduler]


def predict_context(func: Callable) -> Callable:
"""
@@ -55,18 +57,20 @@ class Task(pl.LightningModule):
"""

def __init__(
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
self,
model: Optional[nn.Module] = None,
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
schedulers: Optional[SchedulerType] = None,
metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
):
super().__init__()
if model is not None:
self.model = model
self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
self.optimizer_cls = optimizer
self.schedulers = schedulers
self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
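Note: below is a minimal, hypothetical sketch of how the new `schedulers` argument and the `SchedulerType` alias could be used. Since `Tuple[torch.optim.lr_scheduler._LRScheduler]` denotes a one-element tuple, either a bare scheduler instance or a tuple of schedulers is accepted. The names (`backbone`, `optimizer`, `scheduler`) are illustrative and not part of this diff, and building the scheduler up front requires an optimizer constructed outside the `Task`, separate from the one the `Task` creates in `configure_optimizers`.

```python
# Hypothetical usage sketch, not part of this diff.
import torch
from torch import nn

from flash.core.model import Task  # the class modified in this PR

backbone = nn.Linear(128, 2)
optimizer = torch.optim.Adam(backbone.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Either form satisfies SchedulerType:
task = Task(model=backbone, loss_fn=nn.functional.cross_entropy, schedulers=scheduler)
task = Task(model=backbone, loss_fn=nn.functional.cross_entropy, schedulers=(scheduler,))
```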
@@ -115,12 +119,12 @@ def test_step(self, batch: Any, batch_idx: int) -> None:

@predict_context
def predict(
self,
x: Any,
batch_idx: Optional[int] = None,
skip_collate_fn: bool = False,
dataloader_idx: Optional[int] = None,
data_pipeline: Optional[DataPipeline] = None,
self,
x: Any,
batch_idx: Optional[int] = None,
skip_collate_fn: bool = False,
dataloader_idx: Optional[int] = None,
data_pipeline: Optional[DataPipeline] = None,
) -> Any:
"""
Predict function for raw data or processed data
@@ -156,8 +160,13 @@ def predict(
output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x
return output

def configure_optimizers(self) -> torch.optim.Optimizer:
return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
def configure_optimizers(
self
) -> Union[Tuple[Tuple[torch.optim.Optimizer], SchedulerType], Tuple[torch.optim.Optimizer]]:
optimizers = self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate),
if self.schedulers:
return optimizers, self.schedulers
return optimizers

@property
def data_pipeline(self) -> DataPipeline:
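For reference, a hedged sketch of the two return shapes produced by the new `configure_optimizers`: a one-element optimizer tuple (note the trailing comma in the assignment above), or an `(optimizers, schedulers)` pair, which lines up with PyTorch Lightning's "optimizers first, schedulers second" return format when both elements are sequences. All names below are placeholders, not part of the diff.

```python
# Illustrative sketch of the return shapes, mirroring the logic above.
import torch
from torch import nn

params = nn.Linear(128, 2).parameters()
optimizers = (torch.optim.Adam(params, lr=5e-5),)  # one-element tuple via trailing comma

# Without schedulers, only the optimizer tuple is returned:
no_scheduler_return = optimizers

# With schedulers, an (optimizers, schedulers) pair is returned:
schedulers = (torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=10),)
with_scheduler_return = (optimizers, schedulers)
```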