diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index f1478ecbf9cbe..367b702375eb4 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -62,7 +62,9 @@ class _LogOptions(TypedDict): "on_train_start": _LogOptions( allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True ), - "on_train_end": None, + "on_train_end": _LogOptions( + allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True + ), "on_validation_start": _LogOptions( allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True ), diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index 1ffe7ffe9defb..173319c9ef134 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -109,7 +109,7 @@ def test_fx_validator(): is_stage = "train" in func_name or "test" in func_name or "validation" in func_name is_start = "start" in func_name or "batch" in func_name is_epoch = "epoch" in func_name - on_step = is_stage and not is_start and not is_epoch + on_step = is_stage and not is_start and not is_epoch and func_name not in ["on_train_end"] on_epoch = True # creating allowed condition allowed = ( @@ -124,7 +124,7 @@ def test_fx_validator(): allowed and "pretrain" not in func_name and "predict" not in func_name - and func_name not in ["on_train_end", "on_test_end", "on_validation_end"] + and func_name not in ["on_test_end", "on_validation_end"] ) if allowed: validator.check_logging_levels(fx_name=func_name, on_step=on_step, on_epoch=on_epoch) @@ -198,7 +198,6 @@ def test_fx_validator_integration(tmpdir): "transfer_batch_to_device": "You can't", "on_after_batch_transfer": "You can't", "on_validation_end": "You can't", - "on_train_end": "You can't", "on_fit_end": "You can't", "teardown": "You can't", "on_sanity_check_start": "You can't", diff --git a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py index 3251d4d2aa5ef..d08b6c13432f9 100644 --- a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py @@ -68,6 +68,7 @@ def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs): hooks = [ "on_train_start", + "on_train_end", "on_train_epoch_start", "on_train_epoch_end", "training_epoch_end",