Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix metric computation #559

Merged
merged 5 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493))

- Fixed a bug where train and validation metrics weren't being correctly computed ([#559](https://github.com/PyTorchLightning/lightning-flash/pull/559))

## [0.4.0] - 2021-06-22

Expand Down
13 changes: 7 additions & 6 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def __init__(
self.optimizer_kwargs = optimizer_kwargs or {}
self.scheduler_kwargs = scheduler_kwargs or {}

self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
self.save_hyperparameters("learning_rate", "optimizer")
Expand All @@ -157,7 +158,7 @@ def __init__(
self.deserializer = deserializer
self.serializer = serializer

def step(self, batch: Any, batch_idx: int) -> Any:
def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
"""
The training/validation/test step. Override for custom behavior.
"""
Expand All @@ -168,7 +169,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
y_hat = self.to_metrics_format(output["y_hat"])
for name, metric in self.metrics.items():
for name, metric in metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
logs[name] = metric # log the metric itself if it is of type Metric
Expand All @@ -195,16 +196,16 @@ def forward(self, x: Any) -> Any:
return self.model(x)

def training_step(self, batch: Any, batch_idx: int) -> Any:
output = self.step(batch, batch_idx)
output = self.step(batch, batch_idx, self.train_metrics)
self.log_dict({f"train_{k}": v for k, v in output["logs"].items()}, on_step=True, on_epoch=True, prog_bar=True)
return output["loss"]

def validation_step(self, batch: Any, batch_idx: int) -> None:
output = self.step(batch, batch_idx)
output = self.step(batch, batch_idx, self.val_metrics)
self.log_dict({f"val_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

def test_step(self, batch: Any, batch_idx: int) -> None:
output = self.step(batch, batch_idx)
output = self.step(batch, batch_idx, self.val_metrics)
self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

@predict_context
Expand Down
4 changes: 2 additions & 2 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def to_metrics_format(self, x) -> torch.Tensor:
x = x.logits
return super().to_metrics_format(x)

def step(self, batch, batch_idx) -> dict:
def step(self, batch, batch_idx, metrics) -> dict:
target = batch.pop("labels")
batch = (batch, target)
return super().step(batch, batch_idx)
return super().step(batch, batch_idx, metrics)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(batch)
Expand Down
46 changes: 46 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from numbers import Number
from pathlib import Path
from typing import Any, Tuple
Expand All @@ -20,6 +21,7 @@
import pytest
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn, Tensor
from torch.nn import functional as F
Expand Down Expand Up @@ -68,6 +70,34 @@ class DummyPostprocess(Postprocess):
pass


class FixedDataset(torch.utils.data.Dataset):

def __init__(self, targets):
super().__init__()

self.targets = targets

def __getitem__(self, index: int) -> Tuple[Tensor, Number]:
return torch.rand(1), self.targets[index]

def __len__(self) -> int:
return len(self.targets)


class OnesModel(nn.Module):

def __init__(self):
super().__init__()

self.layer = nn.Linear(1, 2)
self.register_buffer('zeros', torch.zeros(2))
self.register_buffer('zero_one', torch.tensor([0.0, 1.0]))

def forward(self, x):
x = self.layer(x)
return x * self.zeros + self.zero_one


# ================================


Expand Down Expand Up @@ -249,3 +279,19 @@ def test_optimization(tmpdir):
assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR)
expected = get_linear_schedule_with_warmup.__name__
assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected


def test_classification_task_metrics():
train_dataset = FixedDataset([0, 1])
val_dataset = FixedDataset([1, 1])

model = OnesModel()

class CheckAccuracy(Callback):

def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
assert math.isclose(trainer.callback_metrics['train_accuracy_epoch'], 0.5)

task = ClassificationTask(model)
trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy())
trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))