Track the evaluation loop outputs in the loop #10928

Merged
15 commits merged on Dec 17, 2021
60 changes: 52 additions & 8 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -13,12 +13,14 @@
# limitations under the License.
from typing import Any, List, Sequence, Union

import torch
from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


Expand All @@ -31,6 +33,7 @@ def __init__(self) -> None:

self._results = ResultCollection(training=False)
self._outputs: List[EPOCH_OUTPUT] = []
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[int] = []
self._has_run: bool = False

@@ -75,6 +78,7 @@ def reset(self) -> None:
self._max_batches = self._get_max_batches()
# bookkeeping
self._outputs = []
self._logged_outputs = []

if isinstance(self._max_batches, int):
self._max_batches = [self._max_batches] * len(self.dataloaders)
@@ -117,26 +121,52 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
# indicate the loop has run
self._has_run = True

def on_advance_end(self) -> None:
self.trainer.logger_connector.epoch_end_reached()

self._logged_outputs.append(self.trainer.logger_connector.update_eval_epoch_metrics())

super().on_advance_end()

def on_run_end(self) -> List[_OUT_DICT]:
"""Runs the ``_on_evaluation_epoch_end`` hook."""
outputs, self._outputs = self._outputs, [] # free memory
# if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
self.trainer.logger_connector.epoch_end_reached()

# lightning module method
self._evaluation_epoch_end(outputs)
# hook
self._evaluation_epoch_end(self._outputs)
self._outputs = [] # free memory

# hook
self._on_evaluation_epoch_end()

# log epoch metrics
eval_loop_results = self.trainer.logger_connector.update_eval_epoch_metrics()
logged_outputs, self._logged_outputs = self._logged_outputs, [] # free memory
# include any logged outputs on epoch_end
if self.num_dataloaders < 2: # TODO: remove this check
epoch_end_logged_outputs = self.trainer.logger_connector.update_eval_epoch_metrics()
for dl_outputs in logged_outputs:
dl_outputs.update(epoch_end_logged_outputs)

# log metrics
self.trainer.logger_connector.log_eval_end_metrics()

# hook
self._on_evaluation_end()

# enable train mode again
self._on_evaluation_model_train()

return eval_loop_results
if (
self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
and not self.trainer.sanity_checking
and self.trainer.is_global_zero
# TODO: this should be defined in this loop, not the Trainer
and self.trainer.verbose_evaluate
):
assert self.trainer.state.stage is not None
self._print_results(logged_outputs, self.trainer.state.stage)

return logged_outputs

def teardown(self) -> None:
self._results.cpu()
@@ -220,8 +250,7 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:

def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
"""Runs ``{validation/test}_epoch_end``"""
# inform logger the batch loop has finished
self.trainer.logger_connector.epoch_end_reached()
self.trainer.logger_connector._evaluation_epoch_end()

# with a single dataloader don't pass a 2D list
output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
@@ -243,3 +272,18 @@ def _on_evaluation_epoch_end(self) -> None:
self.trainer._call_callback_hooks("on_epoch_end")
self.trainer._call_lightning_module_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

def _print_results(self, results: List[_OUT_DICT], stage: RunningStage) -> None:
# TODO: this could be updated to look nicer
from pprint import pprint

print("-" * 80)
for i, metrics_dict in enumerate(results):
print(f"DATALOADER:{i} {stage.upper()} RESULTS")
pprint(
{
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
for k, v in metrics_dict.items()
}
)
print("-" * 80)
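For illustration only (hypothetical metric name and values, and assuming the trailing separator is printed per dataloader), the new `_print_results` output for two test dataloaders would look roughly like:

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'foo/dataloader_idx_0': 12.0}
--------------------------------------------------------------------------------
DATALOADER:1 TEST RESULTS
{'foo/dataloader_idx_1': 12.0}
--------------------------------------------------------------------------------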
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pprint import pprint
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, Optional, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _AcceleratorType, memory
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.metrics import metrics_to_scalars
@@ -36,7 +35,6 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
"Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
)
self.log_gpu_memory = log_gpu_memory
self.eval_loop_results: List[_OUT_DICT] = []
self._val_log_step: int = 0
self._test_log_step: int = 0
self._progress_bar_metrics: _PBAR_DICT = {}
@@ -139,6 +137,11 @@ def _increment_eval_log_step(self) -> None:
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def _evaluation_epoch_end(self) -> None:
results = self.trainer._results
assert results is not None
results.dataloader_idx = None

def update_eval_step_metrics(self) -> None:
assert not self._epoch_end_reached
if self.trainer.sanity_checking:
@@ -150,58 +153,19 @@ def update_eval_step_metrics(self) -> None:
# increment the step even if nothing was logged
self._increment_eval_log_step()

def _prepare_eval_loop_results(self) -> None:
def update_eval_epoch_metrics(self) -> _OUT_DICT:
assert self._epoch_end_reached
if self.trainer.sanity_checking:
return

on_step = not self._epoch_end_reached
num_dataloaders = self.trainer._evaluation_loop.num_dataloaders
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
assert self.trainer._evaluation_loop._results is not None
for dl_idx in range(num_dataloaders):
metrics = self.trainer._evaluation_loop._results.metrics(
on_step, dataloader_idx=dl_idx if num_dataloaders > 1 else None
)
callback_metrics = metrics["callback"]

if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
self.eval_loop_results.append(callback_metrics)
return {}
return self.metrics["callback"]

def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
def log_eval_end_metrics(self) -> None:
assert self._epoch_end_reached
metrics = self.metrics

if not self.trainer.sanity_checking:
# log all the metrics as a single dict
self.log_metrics(metrics["log"])

self._prepare_eval_loop_results()

# log results of evaluation
if (
self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
and self.trainer.evaluating
and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print("-" * 80)
assert self.trainer.state.stage is not None
for i, metrics_dict in enumerate(self.eval_loop_results):
print(f"DATALOADER:{i} {self.trainer.state.stage.upper()} RESULTS")
pprint(
{
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
for k, v in metrics_dict.items()
}
)
print("-" * 80)
if self.trainer.sanity_checking:
return

results = self.eval_loop_results
# clear mem
self.eval_loop_results = []
return results
# log all the metrics as a single dict
self.log_metrics(self.metrics["log"])

"""
Train metric updates
@@ -272,11 +236,6 @@ def epoch_end_reached(self) -> None:

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
results = self.trainer._results
assert results is not None
# we need to reset this index before the `self.metrics` call below
results.dataloader_idx = None

metrics = self.metrics
self._progress_bar_metrics.update(metrics["pbar"])
self._callback_metrics.update(metrics["callback"])
pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -525,12 +525,13 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
return cache.detach()
return cache

def valid_items(self, dataloader_idx: Optional[int] = None) -> Generator:
def valid_items(self) -> Generator:
"""This function is used to iterate over current valid metrics."""
return (
(k, v)
for k, v in self.items()
if not (isinstance(v, ResultMetric) and v.has_reset) and (dataloader_idx in (None, v.meta.dataloader_idx))
if not (isinstance(v, ResultMetric) and v.has_reset)
and self.dataloader_idx in (None, v.meta.dataloader_idx)
)

def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
@@ -544,10 +545,10 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
forked_name += dataloader_suffix
return name, forked_name

def metrics(self, on_step: bool, dataloader_idx: Optional[int] = None) -> _METRICS:
def metrics(self, on_step: bool) -> _METRICS:
metrics = _METRICS(callback={}, log={}, pbar={})

for _, result_metric in self.valid_items(dataloader_idx):
for _, result_metric in self.valid_items():

# extract forward_cache or computed from the ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
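As context for the `valid_items` change above, here is a minimal, self-contained sketch of the filtering pattern (not the real `ResultCollection`; all names are invented for illustration): the collection stores the active `dataloader_idx` as an attribute and filters on it, instead of receiving the index as an argument on every `metrics()` call.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _Meta:
    dataloader_idx: Optional[int] = None


@dataclass
class _Metric:
    value: float
    meta: _Meta


class _Collection(dict):
    # set before metrics are computed; `None` means "no filter"
    dataloader_idx: Optional[int] = None

    def valid_items(self):
        # keep a metric when no filter is active or when it belongs to the active dataloader
        return ((k, v) for k, v in self.items() if self.dataloader_idx in (None, v.meta.dataloader_idx))


results = _Collection()
results["loss/dataloader_idx_0"] = _Metric(0.1, _Meta(0))
results["loss/dataloader_idx_1"] = _Metric(0.2, _Meta(1))

results.dataloader_idx = 1
assert [k for k, _ in results.valid_items()] == ["loss/dataloader_idx_1"]

results.dataloader_idx = None  # mirrors the reset done before computing epoch-end metrics
assert len(list(results.valid_items())) == 2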
4 changes: 2 additions & 2 deletions tests/loops/test_evaluation_loop.py
@@ -43,10 +43,10 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):


@mock.patch(
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.update_eval_epoch_metrics"
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.log_eval_end_metrics"
)
def test_log_epoch_metrics_before_on_evaluation_end(update_eval_epoch_metrics_mock, tmpdir):
"""Test that the epoch metrics are logged before the `on_evalutaion_end` hook is fired."""
"""Test that the epoch metrics are logged before the `on_evaluation_end` hook is fired."""
order = []
update_eval_epoch_metrics_mock.side_effect = lambda: order.append("log_epoch_metrics")

18 changes: 18 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -738,3 +738,21 @@ def test_dataloader(self):
"test_log_no_dl_idx_1": 321 * 2,
"test_log_b_class": 456.0,
}


def test_logging_multi_dataloader_on_epoch_end(tmpdir):
class CustomBoringModel(BoringModel):
def test_step(self, batch, batch_idx, dataloader_idx):
self.log("foo", 12.0)

def test_epoch_end(self, outputs) -> None:
self.log("foobar", 23.0)

def test_dataloader(self):
return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)]

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
logged_results = trainer.test(model)
# TODO: what's logged in `test_epoch_end` should be included in the results of each dataloader
assert logged_results == [{"foo/dataloader_idx_0": 12.0}, {"foo/dataloader_idx_1": 12.0}]