diff --git a/CHANGELOG.md b/CHANGELOG.md index 50bd9fcaaa783..e1fb57ece2e14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,6 +100,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed issue with pickling `CSVLogger` after a call to `CSVLogger.save` ([#10388](https://github.com/PyTorchLightning/pytorch-lightning/pull/10388)) + +- Fixed the logging with `on_step=True` in epoch-level hooks causing unintended side-effects. Logging with `on_step=True` in epoch-level hooks will now correctly raise an error ([#10409](https://github.com/PyTorchLightning/pytorch-lightning/pull/10409)) + + - diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index a928122a2053a..cc91476518565 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -46,15 +46,15 @@ class _LogOptions(TypedDict): "on_predict_end": None, "on_pretrain_routine_start": None, "on_pretrain_routine_end": None, - "on_train_epoch_start": _LogOptions(on_step=(False, True), on_epoch=(True,)), + "on_train_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)), "on_train_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)), - "on_validation_epoch_start": _LogOptions(on_step=(False, True), on_epoch=(True,)), + "on_validation_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)), "on_validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)), - "on_test_epoch_start": _LogOptions(on_step=(False, True), on_epoch=(True,)), + "on_test_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)), "on_test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)), "on_predict_epoch_start": None, "on_predict_epoch_end": None, - "on_epoch_start": _LogOptions(on_step=(False, True), on_epoch=(True,)), + "on_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)), "on_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)), "on_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)), "on_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)), diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index c3cc4f972b06b..6ed40b5f03082 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -423,6 +423,12 @@ def make_logging(self, pl_module, func_name, on_steps, on_epochs, prob_bars): def on_test_start(self, _, pl_module): self.make_logging(pl_module, "on_test_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices) + def on_epoch_start(self, trainer, pl_module): + if trainer.testing: + self.make_logging( + pl_module, "on_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices + ) + def on_test_epoch_start(self, _, pl_module): self.make_logging( pl_module, "on_test_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 6cad94017177e..5b775b9968d99 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -272,13 +272,11 @@ def on_train_start(self, _, pl_module): self.make_logging(pl_module, "on_train_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices) def on_epoch_start(self, _, pl_module): - self.make_logging( - pl_module, "on_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices - ) + self.make_logging(pl_module, "on_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices) def on_train_epoch_start(self, _, pl_module): self.make_logging( - pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices + pl_module, "on_train_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_batch_start(self, _, pl_module, *__):