diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index f77b4567aa263..6b31ebb58057a 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -13,12 +13,14 @@
 # limitations under the License.

 from typing import Any, List, Sequence, Union

+import torch
 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader

 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -31,6 +33,7 @@ def __init__(self) -> None:

         self._results = ResultCollection(training=False)
         self._outputs: List[EPOCH_OUTPUT] = []
+        self._logged_outputs: List[_OUT_DICT] = []
         self._max_batches: List[int] = []
         self._has_run: bool = False

@@ -75,6 +78,7 @@ def reset(self) -> None:
         self._max_batches = self._get_max_batches()
         # bookkeeping
         self._outputs = []
+        self._logged_outputs = []

         if isinstance(self._max_batches, int):
             self._max_batches = [self._max_batches] * len(self.dataloaders)
@@ -117,18 +121,34 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         # indicate the loop has run
         self._has_run = True

+    def on_advance_end(self) -> None:
+        self.trainer.logger_connector.epoch_end_reached()
+
+        self._logged_outputs.append(self.trainer.logger_connector.update_eval_epoch_metrics())
+
+        super().on_advance_end()
+
     def on_run_end(self) -> List[_OUT_DICT]:
         """Runs the ``_on_evaluation_epoch_end`` hook."""
-        outputs, self._outputs = self._outputs, []  # free memory
+        # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
+        self.trainer.logger_connector.epoch_end_reached()

-        # lightning module method
-        self._evaluation_epoch_end(outputs)
+        # hook
+        self._evaluation_epoch_end(self._outputs)
+        self._outputs = []  # free memory

         # hook
         self._on_evaluation_epoch_end()

-        # log epoch metrics
-        eval_loop_results = self.trainer.logger_connector.update_eval_epoch_metrics()
+        logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
+        # include any logged outputs on epoch_end
+        if self.num_dataloaders < 2:  # TODO: remove this check
+            epoch_end_logged_outputs = self.trainer.logger_connector.update_eval_epoch_metrics()
+            for dl_outputs in logged_outputs:
+                dl_outputs.update(epoch_end_logged_outputs)
+
+        # log metrics
+        self.trainer.logger_connector.log_eval_end_metrics()

         # hook
         self._on_evaluation_end()
@@ -136,7 +156,17 @@ def on_run_end(self) -> List[_OUT_DICT]:
         # enable train mode again
         self._on_evaluation_model_train()

-        return eval_loop_results
+        if (
+            self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
+            and not self.trainer.sanity_checking
+            and self.trainer.is_global_zero
+            # TODO: this should be defined in this loop, not the Trainer
+            and self.trainer.verbose_evaluate
+        ):
+            assert self.trainer.state.stage is not None
+            self._print_results(logged_outputs, self.trainer.state.stage)
+
+        return logged_outputs

     def teardown(self) -> None:
         self._results.cpu()
@@ -220,8 +250,7 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:

     def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
         """Runs ``{validation/test}_epoch_end``"""
-        # inform logger the batch loop has finished
-        self.trainer.logger_connector.epoch_end_reached()
+        self.trainer.logger_connector._evaluation_epoch_end()

         # with a single dataloader don't pass a 2D list
         output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
@@ -243,3 +272,18 @@ def _on_evaluation_epoch_end(self) -> None:
         self.trainer._call_callback_hooks("on_epoch_end")
         self.trainer._call_lightning_module_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
+
+    def _print_results(self, results: List[_OUT_DICT], stage: RunningStage) -> None:
+        # TODO: this could be updated to look nicer
+        from pprint import pprint
+
+        print("-" * 80)
+        for i, metrics_dict in enumerate(results):
+            print(f"DATALOADER:{i} {stage.upper()} RESULTS")
+            pprint(
+                {
+                    k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
+                    for k, v in metrics_dict.items()
+                }
+            )
+            print("-" * 80)
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 20a59fb440357..487188662e2dc 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pprint import pprint
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, Optional, Union

 import torch

@@ -20,7 +19,7 @@
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
-from pytorch_lightning.trainer.states import RunningStage, TrainerFn
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import _AcceleratorType, memory
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
@@ -36,7 +35,6 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
                 "Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
             )
         self.log_gpu_memory = log_gpu_memory
-        self.eval_loop_results: List[_OUT_DICT] = []
         self._val_log_step: int = 0
         self._test_log_step: int = 0
         self._progress_bar_metrics: _PBAR_DICT = {}
@@ -139,6 +137,11 @@ def _increment_eval_log_step(self) -> None:
         elif self.trainer.state.stage is RunningStage.TESTING:
             self._test_log_step += 1

+    def _evaluation_epoch_end(self) -> None:
+        results = self.trainer._results
+        assert results is not None
+        results.dataloader_idx = None
+
     def update_eval_step_metrics(self) -> None:
         assert not self._epoch_end_reached
         if self.trainer.sanity_checking:
@@ -150,58 +153,19 @@ def update_eval_step_metrics(self) -> None:
         # increment the step even if nothing was logged
         self._increment_eval_log_step()

-    def _prepare_eval_loop_results(self) -> None:
+    def update_eval_epoch_metrics(self) -> _OUT_DICT:
+        assert self._epoch_end_reached
         if self.trainer.sanity_checking:
-            return
-
-        on_step = not self._epoch_end_reached
-        num_dataloaders = self.trainer._evaluation_loop.num_dataloaders
-        has_been_initialized = len(self.eval_loop_results) == num_dataloaders
-        assert self.trainer._evaluation_loop._results is not None
-        for dl_idx in range(num_dataloaders):
-            metrics = self.trainer._evaluation_loop._results.metrics(
-                on_step, dataloader_idx=dl_idx if num_dataloaders > 1 else None
-            )
-            callback_metrics = metrics["callback"]
-
-            if has_been_initialized:
-                self.eval_loop_results[dl_idx].update(callback_metrics)
-            else:
-                self.eval_loop_results.append(callback_metrics)
+            return {}
+        return self.metrics["callback"]

-    def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
+    def log_eval_end_metrics(self) -> None:
         assert self._epoch_end_reached
-        metrics = self.metrics
-
-        if not self.trainer.sanity_checking:
-            # log all the metrics as a single dict
-            self.log_metrics(metrics["log"])
-
-        self._prepare_eval_loop_results()
-
-        # log results of evaluation
-        if (
-            self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
-            and self.trainer.evaluating
-            and self.trainer.is_global_zero
-            and self.trainer.verbose_evaluate
-        ):
-            print("-" * 80)
-            assert self.trainer.state.stage is not None
-            for i, metrics_dict in enumerate(self.eval_loop_results):
-                print(f"DATALOADER:{i} {self.trainer.state.stage.upper()} RESULTS")
-                pprint(
-                    {
-                        k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
-                        for k, v in metrics_dict.items()
-                    }
-                )
-                print("-" * 80)
+        if self.trainer.sanity_checking:
+            return

-        results = self.eval_loop_results
-        # clear mem
-        self.eval_loop_results = []
-        return results
+        # log all the metrics as a single dict
+        self.log_metrics(self.metrics["log"])

     """
     Train metric updates
@@ -272,11 +236,6 @@ def epoch_end_reached(self) -> None:

     def on_epoch_end(self) -> None:
         assert self._epoch_end_reached
-        results = self.trainer._results
-        assert results is not None
-        # we need to reset this index before the `self.metrics` call below
-        results.dataloader_idx = None
-
         metrics = self.metrics
         self._progress_bar_metrics.update(metrics["pbar"])
         self._callback_metrics.update(metrics["callback"])
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 4878099afc524..644bc6f855f3c 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -525,12 +525,13 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten
             return cache.detach()
         return cache

-    def valid_items(self, dataloader_idx: Optional[int] = None) -> Generator:
+    def valid_items(self) -> Generator:
         """This function is used to iterate over current valid metrics."""
         return (
             (k, v)
             for k, v in self.items()
-            if not (isinstance(v, ResultMetric) and v.has_reset) and (dataloader_idx in (None, v.meta.dataloader_idx))
+            if not (isinstance(v, ResultMetric) and v.has_reset)
+            and self.dataloader_idx in (None, v.meta.dataloader_idx)
         )

     def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
@@ -544,10 +545,10 @@ def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str,
             forked_name += dataloader_suffix
         return name, forked_name

-    def metrics(self, on_step: bool, dataloader_idx: Optional[int] = None) -> _METRICS:
+    def metrics(self, on_step: bool) -> _METRICS:
         metrics = _METRICS(callback={}, log={}, pbar={})

-        for _, result_metric in self.valid_items(dataloader_idx):
+        for _, result_metric in self.valid_items():
             # extract forward_cache or computed from the ResultMetric. ignore when the output is None
             value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)

diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py
index 1d4b0ea4cd6b1..30095beca228b 100644
--- a/tests/loops/test_evaluation_loop.py
+++ b/tests/loops/test_evaluation_loop.py
@@ -43,10 +43,10 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):


 @mock.patch(
-    "pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.update_eval_epoch_metrics"
+    "pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.log_eval_end_metrics"
 )
 def test_log_epoch_metrics_before_on_evaluation_end(update_eval_epoch_metrics_mock, tmpdir):
-    """Test that the epoch metrics are logged before the `on_evalutaion_end` hook is fired."""
+    """Test that the epoch metrics are logged before the `on_evaluation_end` hook is fired."""
     order = []
     update_eval_epoch_metrics_mock.side_effect = lambda: order.append("log_epoch_metrics")

diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index 04ef7568e13ca..566db08f58a72 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -738,3 +738,21 @@ def test_dataloader(self):
         "test_log_no_dl_idx_1": 321 * 2,
         "test_log_b_class": 456.0,
     }
+
+
+def test_logging_multi_dataloader_on_epoch_end(tmpdir):
+    class CustomBoringModel(BoringModel):
+        def test_step(self, batch, batch_idx, dataloader_idx):
+            self.log("foo", 12.0)
+
+        def test_epoch_end(self, outputs) -> None:
+            self.log("foobar", 23.0)
+
+        def test_dataloader(self):
+            return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(2)]
+
+    model = CustomBoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    logged_results = trainer.test(model)
+    # TODO: what's logged in `test_epoch_end` should be included in the results of each dataloader
+    assert logged_results == [{"foo/dataloader_idx_0": 12.0}, {"foo/dataloader_idx_1": 12.0}]