This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Fix testing loop in Active Learning (#879)
Co-authored-by: fr.branchaud-charron <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
5 people authored Oct 29, 2021
1 parent e63a9a1 commit 194c8d9
Showing 4 changed files with 67 additions and 16 deletions.
CHANGELOG.md (2 additions, 0 deletions)

@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Fixed a bug where test metrics were not logged correctly with active learning ([#879](https://github.com/PyTorchLightning/lightning-flash/pull/879))
+
## [0.5.1] - 2021-10-26

### Added
flash/image/classification/integrations/baal/data.py (1 addition, 4 deletions)

@@ -165,10 +165,7 @@ def label(self, probabilities: List[torch.Tensor] = None, indices=None):
        uncertainties = self.heuristic.get_uncertainties(torch.cat(probabilities, dim=0))
        indices = np.argsort(uncertainties)
        if self._dataset is not None:
-            unlabelled_mask = self._dataset.labelled == False  # noqa E712
-            unlabelled = self._dataset.labelled[unlabelled_mask]
-            unlabelled[indices[-self.query_size :]] = True
-            self._dataset.labelled[unlabelled_mask] = unlabelled
+            self._dataset.label(indices[-self.query_size :])

    def state_dict(self) -> Dict[str, torch.Tensor]:
        return self._dataset.state_dict()
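
The one-line replacement works because baal's ActiveLearningDataset.label already expects pool-relative indices and maintains the labelled mask itself, so the manual mask bookkeeping above was redundant. A minimal, self-contained sketch of the same query pattern using baal directly (the pool, the fake predictions, and query_size below are illustrative and not part of this commit):

import numpy as np
import torch
from torch.utils.data import TensorDataset
from baal.active import ActiveLearningDataset
from baal.active.heuristics import BALD

# Illustrative pool: 100 samples, 10 classes, nothing labelled yet.
pool = ActiveLearningDataset(TensorDataset(torch.randn(100, 8), torch.randint(0, 10, (100,))))
heuristic = BALD()
query_size = 5

# Fake MC-Dropout predictions over the pool, shaped [samples, classes, iterations].
probabilities = torch.softmax(torch.randn(100, 10, 20), dim=1)

# Rank by uncertainty (ascending), then label the query_size most uncertain samples.
uncertainties = heuristic.get_uncertainties(probabilities)
indices = np.argsort(uncertainties)
pool.label(indices[-query_size:])

print(len(pool))  # number of labelled samples, expected 5
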
flash/image/classification/integrations/baal/loop.py (25 additions, 9 deletions)

@@ -19,7 +19,7 @@
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.progress import Progress
-from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus

import flash
from flash.core.data.utils import _STAGES_PREFIX
@@ -83,6 +83,8 @@ def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        if self.trainer.datamodule.has_labelled_data:
            self._reset_dataloader_for_stage(RunningStage.TRAINING)
            self._reset_dataloader_for_stage(RunningStage.VALIDATING)
+        if self.trainer.datamodule.has_test:
+            self._reset_dataloader_for_stage(RunningStage.TESTING)
        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_dataloader_for_stage(RunningStage.PREDICTING)
        self.progress.increment_ready()
@@ -94,7 +96,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
        self.fit_loop.run()

        if self.trainer.datamodule.has_test:
-            self.trainer.test_loop.run()
+            self._reset_testing()
+            metrics = self.trainer.test_loop.run()
+            if metrics:
+                self.trainer.logger.log_metrics(metrics[0], step=self.trainer.global_step)

        if self.trainer.datamodule.has_unlabelled_data:
            self._reset_predicting()
@@ -133,19 +138,30 @@ def _reset_fitting(self):
        self.trainer.training = True
        self.trainer.lightning_module.on_train_dataloader()
        self.trainer.accelerator.connect(self._lightning_module)
+        self.fit_loop.epoch_progress = Progress()

    def _reset_predicting(self):
        self.trainer.state.fn = TrainerFn.PREDICTING
        self.trainer.predicting = True
        self.trainer.lightning_module.on_predict_dataloader()
        self.trainer.accelerator.connect(self.inference_model)

+    def _reset_testing(self):
+        self.trainer.state.fn = TrainerFn.TESTING
+        self.trainer.state.status = TrainerStatus.RUNNING
+        self.trainer.testing = True
+        self.trainer.lightning_module.on_test_dataloader()
+        self.trainer.accelerator.connect(self._lightning_module)
+
    def _reset_dataloader_for_stage(self, running_state: RunningStage):
        dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader"
-        setattr(
-            self.trainer.lightning_module,
-            dataloader_name,
-            _PatchDataLoader(getattr(self.trainer.datamodule, dataloader_name)(), running_state),
-        )
-        setattr(self.trainer, dataloader_name, None)
-        getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
+        # If the dataloader exists, we reset it.
+        dataloader = getattr(self.trainer.datamodule, dataloader_name, None)
+        if dataloader:
+            setattr(
+                self.trainer.lightning_module,
+                dataloader_name,
+                _PatchDataLoader(dataloader(), running_state),
+            )
+            setattr(self.trainer, dataloader_name, None)
+            getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module)
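
Taken together, the loop changes mean that each active-learning iteration now runs fitting, then a properly reset test loop whose metrics are logged, then prediction on the unlabelled pool, with dataloader resets skipped for stages the datamodule does not provide. A rough usage sketch of wiring the loop into a Flash 0.5.x trainer, mirroring the tests below (base_datamodule stands in for any existing image-classification datamodule and is not defined in this commit):

import flash
from torch import nn
from flash.core.classification import Probabilities
from flash.image import ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop

# base_datamodule is assumed: an existing labelled image-classification datamodule.
active_learning_dm = ActiveLearningDataModule(base_datamodule, initial_num_labels=10, query_size=5)

head = nn.Sequential(
    nn.Dropout(p=0.1),  # dropout is needed so MC-Dropout inference can sample
    nn.Linear(512, active_learning_dm.num_classes),
)
model = ImageClassifier(
    backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities()
)

trainer = flash.Trainer(max_epochs=3)
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop

# Each cycle: fit, then test (metrics now logged), then predict on the pool and label queries.
trainer.finetune(model, datamodule=active_learning_dm, strategy="no_freeze")
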
tests/image/classification/test_active_learning.py (39 additions, 3 deletions)

@@ -16,6 +16,7 @@

import numpy as np
import pytest
+import torch
from pytorch_lightning import seed_everything
from torch import nn
from torch.utils.data import SequentialSampler
@@ -94,12 +95,21 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s
backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities()
)
trainer = flash.Trainer(max_epochs=3)

active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop

trainer.finetune(model, datamodule=active_learning_dm, strategy="freeze")
trainer.finetune(model, datamodule=active_learning_dm, strategy="no_freeze")
# Check that all metrics are logged
assert all(
any(m in log_met for log_met in active_learning_loop.trainer.logged_metrics) for m in ("train", "val", "test")
)

# Check that the weights has changed for both module.
classifier = active_learning_loop._lightning_module.adapter.parameters()
mc_inference = active_learning_loop.inference_model.parent_module.parameters()
assert all(torch.equal(p1, p2) for p1, p2 in zip(classifier, mc_inference))

if initial_num_labels == 0:
assert len(active_learning_dm._dataset) == 15
else:
@@ -117,3 +127,29 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s
    else:
        # in the second scenario we have more labelled data!
        assert len(active_learning_dm.val_dataloader()) == 5
+
+
+@pytest.mark.skipif(not (_IMAGE_TESTING and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed.")
+def test_no_validation_loop(simple_datamodule):
+    active_learning_dm = ActiveLearningDataModule(
+        simple_datamodule,
+        initial_num_labels=2,
+        query_size=100,
+        val_split=0.0,
+    )
+    assert active_learning_dm.val_dataloader is None
+    head = nn.Sequential(
+        nn.Dropout(p=0.1),
+        nn.Linear(512, active_learning_dm.num_classes),
+    )
+
+    model = ImageClassifier(
+        backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, serializer=Probabilities()
+    )
+    trainer = flash.Trainer(max_epochs=3)
+    active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3)
+    active_learning_loop.connect(trainer.fit_loop)
+    trainer.fit_loop = active_learning_loop
+
+    # Check that we can finetune without val_set
+    trainer.finetune(model, datamodule=active_learning_dm, strategy="no_freeze")
