Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into feature/pytorch_forecasting
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Oct 29, 2021
2 parents c71bff7 + ddd4c94 commit d005bb1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where test metrics were not logged correctly with active learning ([#879](https://github.com/PyTorchLightning/lightning-flash/pull/879))


- Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900))


## [0.5.1] - 2021-10-26

### Added
Expand Down
3 changes: 2 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def __init__(

self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
self.test_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics)))
self.learning_rate = learning_rate
# TODO: should we save more? Bug on some regarding yaml if we save metrics
self.save_hyperparameters("learning_rate", "optimizer")
Expand Down Expand Up @@ -454,7 +455,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
)

def test_step(self, batch: Any, batch_idx: int) -> None:
output = self.step(batch, batch_idx, self.val_metrics)
output = self.step(batch, batch_idx, self.test_metrics)
self.log_dict(
{f"test_{k}": v for k, v in output[OutputKeys.LOGS].items()},
on_step=False,
Expand Down
8 changes: 8 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,21 @@ def i_will_create_a_misconfiguration_exception(optimizer):
def test_classification_task_metrics():
train_dataset = FixedDataset([0, 1])
val_dataset = FixedDataset([1, 1])
test_dataset = FixedDataset([0, 0])

model = OnesModel()

class CheckAccuracy(Callback):
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
assert math.isclose(trainer.callback_metrics["val_accuracy"], 1.0)

def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
assert math.isclose(trainer.callback_metrics["test_accuracy"], 0.0)

task = ClassificationTask(model)
trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count())
trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))
trainer.test(task, dataloaders=DataLoader(test_dataset))

0 comments on commit d005bb1

Please sign in to comment.