Track the evaluation loop outputs in the loop (#10928)
carmocca authored Dec 17, 2021
1 parent 210ff84 commit 5956a07
Showing 5 changed files with 93 additions and 71 deletions.
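
In short, this commit moves the per-dataloader bookkeeping of logged evaluation metrics out of the `LoggerConnector` and into the `EvaluationLoop` itself: the loop now accumulates one dict of logged metrics per dataloader in `_logged_outputs` as it advances, merges in whatever the `*_epoch_end` hooks log, and takes over the verbose console printing through a new `_print_results` method.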
60 changes: 52 additions & 8 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -13,12 +13,14 @@
# limitations under the License.
from typing import Any, List, Sequence, Union

import torch
from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


@@ -31,6 +33,7 @@ def __init__(self) -> None:

self._results = ResultCollection(training=False)
self._outputs: List[EPOCH_OUTPUT] = []
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[int] = []
self._has_run: bool = False

@@ -75,6 +78,7 @@ def reset(self) -> None:
self._max_batches = self._get_max_batches()
# bookkeeping
self._outputs = []
self._logged_outputs = []

if isinstance(self._max_batches, int):
self._max_batches = [self._max_batches] * len(self.dataloaders)
@@ -117,26 +121,52 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
# indicate the loop has run
self._has_run = True

def on_advance_end(self) -> None:
self.trainer.logger_connector.epoch_end_reached()

self._logged_outputs.append(self.trainer.logger_connector.update_eval_epoch_metrics())

super().on_advance_end()

def on_run_end(self) -> List[_OUT_DICT]:
"""Runs the ``_on_evaluation_epoch_end`` hook."""
outputs, self._outputs = self._outputs, [] # free memory
# if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
self.trainer.logger_connector.epoch_end_reached()

# lightning module method
self._evaluation_epoch_end(outputs)
# hook
self._evaluation_epoch_end(self._outputs)
self._outputs = [] # free memory

# hook
self._on_evaluation_epoch_end()

# log epoch metrics
eval_loop_results = self.trainer.logger_connector.update_eval_epoch_metrics()
logged_outputs, self._logged_outputs = self._logged_outputs, [] # free memory
# include any logged outputs on epoch_end
if self.num_dataloaders < 2: # TODO: remove this check
epoch_end_logged_outputs = self.trainer.logger_connector.update_eval_epoch_metrics()
for dl_outputs in logged_outputs:
dl_outputs.update(epoch_end_logged_outputs)

# log metrics
self.trainer.logger_connector.log_eval_end_metrics()

# hook
self._on_evaluation_end()

# enable train mode again
self._on_evaluation_model_train()

return eval_loop_results
if (
self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
and not self.trainer.sanity_checking
and self.trainer.is_global_zero
# TODO: this should be defined in this loop, not the Trainer
and self.trainer.verbose_evaluate
):
assert self.trainer.state.stage is not None
self._print_results(logged_outputs, self.trainer.state.stage)

return logged_outputs

def teardown(self) -> None:
self._results.cpu()
@@ -220,8 +250,7 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:

def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
"""Runs ``{validation/test}_epoch_end``"""
# inform logger the batch loop has finished
self.trainer.logger_connector.epoch_end_reached()
self.trainer.logger_connector._evaluation_epoch_end()

# with a single dataloader don't pass a 2D list
output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
@@ -243,3 +272,18 @@ def _on_evaluation_epoch_end(self) -> None:
self.trainer._call_callback_hooks("on_epoch_end")
self.trainer._call_lightning_module_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

def _print_results(self, results: List[_OUT_DICT], stage: RunningStage) -> None:
# TODO: this could be updated to look nicer
from pprint import pprint

print("-" * 80)
for i, metrics_dict in enumerate(results):
print(f"DATALOADER:{i} {stage.upper()} RESULTS")
pprint(
{
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
for k, v in metrics_dict.items()
}
)
print("-" * 80)
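
To make the new control flow concrete, here is a minimal, self-contained sketch of the bookkeeping this file now performs (a toy stand-in, not the real `EvaluationLoop`; the single-dataloader merge mirrors the TODO-guarded check above): one dict of callback metrics is appended per dataloader in `on_advance_end`, and `on_run_end` hands back the accumulated list after merging in what the `*_epoch_end` hooks logged.

# Toy stand-in for the new `EvaluationLoop` bookkeeping (sketch only).
from typing import Dict, List


class SketchEvaluationLoop:
    def __init__(self, num_dataloaders: int) -> None:
        self.num_dataloaders = num_dataloaders
        self._logged_outputs: List[Dict[str, float]] = []

    def on_advance_end(self, dataloader_metrics: Dict[str, float]) -> None:
        # one dict of callback metrics per dataloader, captured as the loop advances
        self._logged_outputs.append(dataloader_metrics)

    def on_run_end(self, epoch_end_metrics: Dict[str, float]) -> List[Dict[str, float]]:
        logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
        # merge metrics logged in `*_epoch_end`; as in the diff, only for a single dataloader
        if self.num_dataloaders < 2:
            for dl_outputs in logged_outputs:
                dl_outputs.update(epoch_end_metrics)
        return logged_outputs


loop = SketchEvaluationLoop(num_dataloaders=1)
loop.on_advance_end({"foo": 12.0})
assert loop.on_run_end({"foobar": 23.0}) == [{"foo": 12.0, "foobar": 23.0}]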
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pprint import pprint
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, Optional, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _AcceleratorType, memory
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.metrics import metrics_to_scalars
@@ -36,7 +35,6 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
"Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
)
self.log_gpu_memory = log_gpu_memory
self.eval_loop_results: List[_OUT_DICT] = []
self._val_log_step: int = 0
self._test_log_step: int = 0
self._progress_bar_metrics: _PBAR_DICT = {}
@@ -139,6 +137,11 @@ def _increment_eval_log_step(self) -> None:
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def _evaluation_epoch_end(self) -> None:
results = self.trainer._results
assert results is not None
results.dataloader_idx = None

def update_eval_step_metrics(self) -> None:
assert not self._epoch_end_reached
if self.trainer.sanity_checking:
@@ -150,58 +153,19 @@ def update_eval_step_metrics(self) -> None:
# increment the step even if nothing was logged
self._increment_eval_log_step()

def _prepare_eval_loop_results(self) -> None:
def update_eval_epoch_metrics(self) -> _OUT_DICT:
assert self._epoch_end_reached
if self.trainer.sanity_checking:
return

on_step = not self._epoch_end_reached
num_dataloaders = self.trainer._evaluation_loop.num_dataloaders
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
assert self.trainer._evaluation_loop._results is not None
for dl_idx in range(num_dataloaders):
metrics = self.trainer._evaluation_loop._results.metrics(
on_step, dataloader_idx=dl_idx if num_dataloaders > 1 else None
)
callback_metrics = metrics["callback"]

if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
self.eval_loop_results.append(callback_metrics)
return {}
return self.metrics["callback"]

def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
def log_eval_end_metrics(self) -> None:
assert self._epoch_end_reached
metrics = self.metrics

if not self.trainer.sanity_checking:
# log all the metrics as a single dict
self.log_metrics(metrics["log"])

self._prepare_eval_loop_results()

# log results of evaluation
if (
self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
and self.trainer.evaluating
and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print("-" * 80)
assert self.trainer.state.stage is not None
for i, metrics_dict in enumerate(self.eval_loop_results):
print(f"DATALOADER:{i} {self.trainer.state.stage.upper()} RESULTS")
pprint(
{
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
for k, v in metrics_dict.items()
}
)
print("-" * 80)
if self.trainer.sanity_checking:
return

results = self.eval_loop_results
# clear mem
self.eval_loop_results = []
return results
# log all the metrics as a single dict
self.log_metrics(self.metrics["log"])

"""
Train metric updates
@@ -272,11 +236,6 @@ def epoch_end_reached(self) -> None:

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
results = self.trainer._results
assert results is not None
# we need to reset this index before the `self.metrics` call below
results.dataloader_idx = None

metrics = self.metrics
self._progress_bar_metrics.update(metrics["pbar"])
self._callback_metrics.update(metrics["callback"])
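On the connector side, the change splits returning metrics from writing them: `update_eval_epoch_metrics` now only hands back the current callback metrics (and an empty dict during sanity checking), while the single write to the loggers happens in the new `log_eval_end_metrics`. A hedged, toy illustration of that contract (not the real `LoggerConnector`; `log_metrics` is stubbed with a counter):

# Toy stand-in for the new LoggerConnector contract (sketch only).
from typing import Dict


class SketchLoggerConnector:
    def __init__(self, sanity_checking: bool = False) -> None:
        self.sanity_checking = sanity_checking
        self._epoch_end_reached = False
        self.callback_metrics: Dict[str, float] = {}
        self.log_metrics_calls = 0

    def epoch_end_reached(self) -> None:
        self._epoch_end_reached = True

    def update_eval_epoch_metrics(self) -> Dict[str, float]:
        # called from the loop (per dataloader, and once more for epoch-end hooks)
        assert self._epoch_end_reached
        if self.sanity_checking:
            return {}
        return dict(self.callback_metrics)

    def log_eval_end_metrics(self) -> None:
        # called exactly once from `on_run_end`, after all dataloaders have run
        assert self._epoch_end_reached
        if self.sanity_checking:
            return
        self.log_metrics_calls += 1  # stands in for `self.log_metrics(self.metrics["log"])`


connector = SketchLoggerConnector()
connector.epoch_end_reached()
connector.callback_metrics = {"val_acc": 0.9}
assert connector.update_eval_epoch_metrics() == {"val_acc": 0.9}
connector.log_eval_end_metrics()
assert connector.log_metrics_calls == 1
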
pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -525,12 +525,13 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
return cache.detach()
return cache

def valid_items(self, dataloader_idx: Optional[int] = None) -> Generator:
def valid_items(self) -> Generator:
"""This function is used to iterate over current valid metrics."""
return (
(k, v)
for k, v in self.items()
if not (isinstance(v, ResultMetric) and v.has_reset) and (dataloader_idx in (None, v.meta.dataloader_idx))
if not (isinstance(v, ResultMetric) and v.has_reset)
and self.dataloader_idx in (None, v.meta.dataloader_idx)
)

def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
@@ -544,10 +545,10 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
forked_name += dataloader_suffix
return name, forked_name

def metrics(self, on_step: bool, dataloader_idx: Optional[int] = None) -> _METRICS:
def metrics(self, on_step: bool) -> _METRICS:
metrics = _METRICS(callback={}, log={}, pbar={})

for _, result_metric in self.valid_items(dataloader_idx):
for _, result_metric in self.valid_items():

# extract forward_cache or computed from the ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
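The `ResultCollection` change is behaviour-preserving but moves the dataloader filter into collection state: instead of passing `dataloader_idx` into `valid_items`/`metrics`, the collection's own `dataloader_idx` attribute drives the filtering, which is what lets `_evaluation_epoch_end` reset it to `None` before epoch-end metrics are read. A runnable toy version (stand-ins, not the real `ResultCollection`/`ResultMetric`):

# Toy stand-ins showing the stateful dataloader filter (sketch only).
from dataclasses import dataclass
from typing import Dict, Iterator, Optional, Tuple


@dataclass
class SketchMetric:
    dataloader_idx: Optional[int]  # stands in for `ResultMetric.meta.dataloader_idx`


class SketchResultCollection(Dict[str, SketchMetric]):
    dataloader_idx: Optional[int] = None  # set by the loop, reset at epoch end

    def valid_items(self) -> Iterator[Tuple[str, SketchMetric]]:
        # keep an item when no filter is active or it belongs to the selected dataloader
        return ((k, v) for k, v in self.items() if self.dataloader_idx in (None, v.dataloader_idx))


results = SketchResultCollection(a=SketchMetric(0), b=SketchMetric(1))
results.dataloader_idx = 1
assert [k for k, _ in results.valid_items()] == ["b"]
results.dataloader_idx = None  # what the new `_evaluation_epoch_end` does
assert [k for k, _ in results.valid_items()] == ["a", "b"]
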
4 changes: 2 additions & 2 deletions tests/loops/test_evaluation_loop.py
@@ -43,10 +43,10 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):


@mock.patch(
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.update_eval_epoch_metrics"
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.log_eval_end_metrics"
)
def test_log_epoch_metrics_before_on_evaluation_end(update_eval_epoch_metrics_mock, tmpdir):
"""Test that the epoch metrics are logged before the `on_evalutaion_end` hook is fired."""
"""Test that the epoch metrics are logged before the `on_evaluation_end` hook is fired."""
order = []
update_eval_epoch_metrics_mock.side_effect = lambda: order.append("log_epoch_metrics")

18 changes: 18 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -738,3 +738,21 @@ def test_dataloader(self):
"test_log_no_dl_idx_1": 321 * 2,
"test_log_b_class": 456.0,
}


def test_logging_multi_dataloader_on_epoch_end(tmpdir):
class CustomBoringModel(BoringModel):
def test_step(self, batch, batch_idx, dataloader_idx):
self.log("foo", 12.0)

def test_epoch_end(self, outputs) -> None:
self.log("foobar", 23.0)

def test_dataloader(self):
return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)]

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
logged_results = trainer.test(model)
# TODO: what's logged in `test_epoch_end` should be included in the results of each dataloader
assert logged_results == [{"foo/dataloader_idx_0": 12.0}, {"foo/dataloader_idx_1": 12.0}]
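
As the assertion above shows, `trainer.test` (and `trainer.validate`) now return one dict of logged metrics per dataloader. When verbose evaluation is enabled, the new `_print_results` renders those dicts roughly as follows (illustrative console output for the two-dataloader test above; exact formatting comes from `pprint`):

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'foo/dataloader_idx_0': 12.0}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'foo/dataloader_idx_1': 12.0}
--------------------------------------------------------------------------------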
