From d0698433758818f0e6f15d6c9fd682f613fc3ae2 Mon Sep 17 00:00:00 2001 From: Peutlefaire <42559439+BrianPulfer@users.noreply.github.com> Date: Fri, 20 Jan 2023 01:15:32 +0100 Subject: [PATCH] Solved minor bug with MLFlow logger (#16418) Resolves https://github.com/Lightning-AI/lightning/issues/16411 (cherry picked from commit 6fd914f40b12c8a133e015df1f9aa520f9082cb0) --- src/pytorch_lightning/CHANGELOG.md | 3 +++ src/pytorch_lightning/loggers/mlflow.py | 2 +- tests/tests_pytorch/loggers/test_mlflow.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 0da62aafc4334..84dad03f6744a 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369)) +- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418)) + + ## [1.9.0] - 2023-01-17 diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index bbed562283326..980d4e4bccb9e 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -247,7 +247,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning ) continue - params_list.append(Param(key=v, value=v)) + params_list.append(Param(key=k, value=v)) self.experiment.log_batch(run_id=self.run_id, params=params_list) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index d6828901a9961..23de563270cfe 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -253,8 +253,9 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir): logger.log_hyperparams(params) logger.experiment.log_batch.assert_called_once_with( - run_id=logger.run_id, params=[param(key="test_param", value="test_param")] + run_id=logger.run_id, params=[param(key="test", value="test_param")] ) + param.assert_called_with(key="test", value="test_param") metrics = {"some_metric": 10} logger.log_metrics(metrics) @@ -262,6 +263,7 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir): logger.experiment.log_batch.assert_called_with( run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)] ) + metric.assert_called_with(key="some_metric", value=10, timestamp=1000, step=0) logger._mlflow_client.create_experiment.assert_called_once_with( name="test", artifact_location="my_artifact_location"