Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix metric computation #559

Merged
merged 5 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493))

- Fixed a bug where train and validation metrics weren't being correctly computed ([#559](https://github.com/PyTorchLightning/lightning-flash/pull/559))

## [0.4.0] - 2021-06-22

Expand Down
13 changes: 7 additions & 6 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def __init__(
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}

self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
self.save_hyperparameters("learning_rate", "optimizer")
Expand All @@ -157,7 +158,7 @@ def __init__(
self.deserializer = deserializer
self.serializer = serializer

def step(self, batch: Any, batch_idx: int) -> Any:
def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
"""
The training/validation/test step. Override for custom behavior.
"""
Expand All @@ -168,7 +169,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
y_hat = self.to_metrics_format(output["y_hat"])
for name, metric in self.metrics.items():
for name, metric in metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
logs[name] = metric # log the metric itself if it is of type Metric
Expand All @@ -195,16 +196,16 @@ def forward(self, x: Any) -> Any:
return self.model(x)

def training_step(self, batch: Any, batch_idx: int) -> Any:
output = self.step(batch, batch_idx)
output = self.step(batch, batch_idx, self.train_metrics)
self.log_dict({f"train_{k}": v for k, v in output["logs"].items()}, on_step=True, on_epoch=True, prog_bar=True)
return output["loss"]

def validation_step(self, batch: Any, batch_idx: int) -> None:
output = self.step(batch, batch_idx)
output = self.step(batch, batch_idx, self.val_metrics)
self.log_dict({f"val_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

def test_step(self, batch: Any, batch_idx: int) -> None:
output = self.step(batch, batch_idx)
output = self.step(batch, batch_idx, self.val_metrics)
self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

@predict_context
Expand Down
4 changes: 2 additions & 2 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def to_metrics_format(self, x) -> torch.Tensor:
x = x.logits
return super().to_metrics_format(x)

def step(self, batch, batch_idx) -> dict:
def step(self, batch, batch_idx, metrics) -> dict:
target = batch.pop("labels")
batch = (batch, target)
return super().step(batch, batch_idx)
return super().step(batch, batch_idx, metrics)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(batch)
Expand Down
4 changes: 2 additions & 2 deletions flash/video/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def on_train_epoch_start(self) -> None:
encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch)
super().on_train_epoch_start()

def step(self, batch: Any, batch_idx: int) -> Any:
return super().step((batch["video"], batch["label"]), batch_idx)
def step(self, batch: Any, batch_idx: int, metrics) -> Any:
return super().step((batch["video"], batch["label"]), batch_idx, metrics)

def forward(self, x: Any) -> Any:
x = self.backbone(x)
Expand Down
46 changes: 46 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from numbers import Number
from pathlib import Path
from typing import Any, Tuple
Expand All @@ -20,6 +21,7 @@
import pytest
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn, Tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -68,6 +70,34 @@ class DummyPostprocess(Postprocess):
pass


class FixedDataset(torch.utils.data.Dataset):

def __init__(self, targets):
super().__init__()

self.targets = targets

def __getitem__(self, index: int) -> Tuple[Tensor, Number]:
return torch.rand(1), self.targets[index]

def __len__(self) -> int:
return len(self.targets)


class OnesModel(nn.Module):

def __init__(self):
super().__init__()

self.layer = nn.Linear(1, 2)
self.register_buffer('zeros', torch.zeros(2))
self.register_buffer('zero_one', torch.tensor([0.0, 1.0]))

def forward(self, x):
x = self.layer(x)
return x * self.zeros + self.zero_one


# ================================


Expand Down Expand Up @@ -249,3 +279,19 @@ def test_optimization(tmpdir):
assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR)
expected = get_linear_schedule_with_warmup.__name__
assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected


def test_classification_task_metrics():
train_dataset = FixedDataset([0, 1])
val_dataset = FixedDataset([1, 1])

model = OnesModel()

class CheckAccuracy(Callback):

def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
assert math.isclose(trainer.callback_metrics['train_accuracy_epoch'], 0.5)

task = ClassificationTask(model)
trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy())
trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))