From ddd4c941113c41e0795891f09e0434c25458230a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 16:52:48 +0100 Subject: [PATCH] Fixed a bug where validation metrics could be aggregated together with test metrics (#900) --- CHANGELOG.md | 4 ++++ flash/core/model.py | 3 ++- tests/core/test_model.py | 8 ++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8b177eb11..989bfcc60e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where test metrics were not logged correctly with active learning ([#879](https://github.com/PyTorchLightning/lightning-flash/pull/879)) + +- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900)) + + ## [0.5.1] - 2021-10-26 ### Added diff --git a/flash/core/model.py b/flash/core/model.py index 20da95d285..6f87bcb4c3 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -354,6 +354,7 @@ def __init__( self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) + self.test_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) self.learning_rate = learning_rate # TODO: should we save more? Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") @@ -454,7 +455,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None: ) def test_step(self, batch: Any, batch_idx: int) -> None: - output = self.step(batch, batch_idx, self.val_metrics) + output = self.step(batch, batch_idx, self.test_metrics) self.log_dict( {f"test_{k}": v for k, v in output[OutputKeys.LOGS].items()}, on_step=False, diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 0e68344bb5..f31bba3e70 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -437,6 +437,7 @@ def i_will_create_a_misconfiguration_exception(optimizer): def test_classification_task_metrics(): train_dataset = FixedDataset([0, 1]) val_dataset = FixedDataset([1, 1]) + test_dataset = FixedDataset([0, 0]) model = OnesModel() @@ -444,6 +445,13 @@ class CheckAccuracy(Callback): def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5) + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert math.isclose(trainer.callback_metrics["val_accuracy"], 1.0) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert math.isclose(trainer.callback_metrics["test_accuracy"], 0.0) + task = ClassificationTask(model) trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count()) trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset)) + trainer.test(task, dataloaders=DataLoader(test_dataset))