From 898a236681def29aec9ce77e02fc5a15dd5d9cb3 Mon Sep 17 00:00:00 2001 From: Adrien B Date: Thu, 12 Jan 2023 16:04:14 +0100 Subject: [PATCH] Fix "finished" status code in MLFlowLogger (#16340) Co-authored-by: awaelchli --- src/pytorch_lightning/CHANGELOG.md | 3 +++ src/pytorch_lightning/loggers/mlflow.py | 2 ++ tests/tests_pytorch/loggers/test_mlflow.py | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 15b855dcd9776..c6ec5b9f8f2bc 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed bug where the ``interval`` key of the scheduler would be ignored during manual optimization, making the LearningRateMonitor callback fail to log the learning rate ([#16308](https://github.com/Lightning-AI/lightning/pull/16308)) +- Fixed an issue with `MLFlowLogger` not finalizing correctly when status code 'finished' was passed ([#16340](https://github.com/Lightning-AI/lightning/pull/16340)) + + ## [1.8.6] - 2022-12-21 - minor cleaning diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 87f0d707e111b..bbed562283326 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -284,6 +284,8 @@ def finalize(self, status: str = "success") -> None: status = "FINISHED" elif status == "failed": status = "FAILED" + elif status == "finished": + status = "FINISHED" # log checkpoints as artifacts if self._checkpoint_callback: diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 17bd5389f4f96..d6828901a9961 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -268,6 +268,27 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir): ) +@pytest.mark.parametrize( + "status,expected", + [ + ("success", "FINISHED"), + ("failed", "FAILED"), + ("finished", "FINISHED"), + ], +) +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) +@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_logger_finalize(_, __, status, expected): + logger = MLFlowLogger("test") + + # Pretend we are in a worker process and finalizing + _ = logger.experiment + assert logger._initialized + + logger.finalize(status) + logger.experiment.set_terminated.assert_called_once_with(logger.run_id, expected) + + @mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_finalize_when_exception(*_):