From 3db5d608ae2589c0dfc4ca9d43f863e743c1aa56 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Mon, 5 Dec 2022 16:31:35 -0700 Subject: [PATCH 01/16] use `log_batch` --- src/pytorch_lightning/loggers/mlflow.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 35f9b8396dd0f..f67a0354732a8 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -22,7 +22,7 @@ from argparse import Namespace from pathlib import Path from time import time -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union import torch import yaml @@ -39,6 +39,7 @@ _MLFLOW_AVAILABLE = module_available("mlflow") try: import mlflow + from mlflow.entities import Metric, Param from mlflow.tracking import context, MlflowClient from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME # todo: there seems to be still some remaining import error with Conda env @@ -240,20 +241,26 @@ def experiment_id(self) -> Optional[str]: def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) + # params = [Param(k, v) for k, v in params.items()] + params_list: List[Param] = [] + for k, v in params.items(): + # FIXME: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 if len(str(v)) > 250: rank_zero_warn( f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning ) continue + params_list.append(Param(key=k, value=v)) - self.experiment.log_param(self.run_id, k, v) + self.experiment.log_batch(run_id=self.run_id, params=params_list) @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) + metrics_list: List[Metric] = [] timestamp_ms = int(time() * 1000) for k, v in metrics.items(): @@ -269,8 +276,9 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) category=RuntimeWarning, ) k = new_k + metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0)) - self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step) + self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list) @rank_zero_only def finalize(self, status: str = "success") -> None: From 62db2b980b4c2deddc0ad5541ca9485fb13f667f Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Mon, 5 Dec 2022 16:31:49 -0700 Subject: [PATCH 02/16] fix tests --- tests/tests_pytorch/loggers/test_all.py | 2 +- tests/tests_pytorch/loggers/test_mlflow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 4477b13b5b2a9..a91313ab9b43a 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -287,7 +287,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): ): logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0) + logger.experiment.log_batch.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0) # Neptune with mock.patch("pytorch_lightning.loggers.neptune.neptune"), 
mock.patch( diff --git a/tests/tests_pytorch/loggers/test_mlflow.py index a395a980612ed..18c959b6cd4c5 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -249,7 +249,7 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): params = {"test": "test_param"} logger.log_hyperparams(params) - logger.experiment.log_param.assert_called_once_with(logger.run_id, "test", "test_param") + logger.experiment.log_batch.assert_called_once_with(logger.run_id, "test", "test_param") metrics = {"some_metric": 10} logger.log_metrics(metrics) From 93448704e55479f77fdb1d784fe7e31525646051 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Dec 2022 23:34:27 +0000 Subject: [PATCH 03/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/loggers/mlflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index f67a0354732a8..24912bd618745 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -243,7 +243,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _flatten_dict(params) # params = [Param(k, v) for k, v in params.items()] params_list: List[Param] = [] - + for k, v in params.items(): # FIXME: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 if len(str(v)) > 250: rank_zero_warn( f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning ) continue From f9bd6a19fed1b8368a887602b8b2923f4b1815a7 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Mon, 5 Dec 2022 17:36:22 -0700 Subject: [PATCH 04/16] fix test --- tests/tests_pytorch/loggers/test_mlflow.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 18c959b6cd4c5..99c6fd3e48f47 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock import pytest +from mlflow.entities import Metric, Param from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel @@ -249,12 +250,16 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): params = {"test": "test_param"} logger.log_hyperparams(params) - logger.experiment.log_batch.assert_called_once_with(logger.run_id, "test", "test_param") + logger.experiment.log_batch.assert_called_once_with( + run_id=logger.run_id, params=[Param(key="test", value="test_param")] + ) metrics = {"some_metric": 10} logger.log_metrics(metrics) - logger.experiment.log_metric.assert_called_once_with(logger.run_id, "some_metric", 10, 1000, None) + logger.experiment.log_batch.assert_called_with( + run_id=logger.run_id, metrics=[Metric(key="some_metric", value=10, timestamp=1000, step=0)] + ) logger._mlflow_client.create_experiment.assert_called_once_with( name="test", artifact_location="my_artifact_location" From cbd34c85e364023cc1aa3e9cfd2d63b9f5f49538 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Mon, 5 Dec 2022 19:44:24 -0700 Subject: [PATCH 05/16] fix test --- tests/tests_pytorch/loggers/test_all.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 
a91313ab9b43a..044e105d25b91 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -19,6 +19,7 @@ import pytest import torch +from mlflow.entities import Metric from pytorch_lightning import Callback, Trainer from pytorch_lightning.demos.boring_classes import BoringModel @@ -287,7 +288,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): ): logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.log_batch.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0) + logger.experiment.log_batch.assert_called_once_with( + run_id=ANY, metrics=[Metric(key="tmp-test", value=1.0, timestamp=ANY, step=0)] + ) # Neptune with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch( From c1d5ee3ed413449b45e7f0441c5d53b5f8f7e69b Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Tue, 6 Dec 2022 08:24:15 -0700 Subject: [PATCH 06/16] hide imports --- tests/tests_pytorch/loggers/test_all.py | 3 ++- tests/tests_pytorch/loggers/test_mlflow.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 044e105d25b91..4a28e4f7d07c9 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -19,7 +19,6 @@ import pytest import torch -from mlflow.entities import Metric from pytorch_lightning import Callback, Trainer from pytorch_lightning.demos.boring_classes import BoringModel @@ -286,6 +285,8 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch( "pytorch_lightning.loggers.mlflow.MlflowClient" ): + from mlflow.entities import Metric + logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) logger.experiment.log_batch.assert_called_once_with( diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 99c6fd3e48f47..cdaf17edf53e2 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -16,7 +16,6 @@ from unittest.mock import MagicMock import pytest -from mlflow.entities import Metric, Param from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel @@ -242,6 +241,8 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): """Test that the logger calls methods on the mlflow experiment correctly.""" + from mlflow.entities import Metric, Param + time.return_value = 1 logger = MLFlowLogger("test", save_dir=tmpdir, artifact_location="my_artifact_location") From 29ae0bbbe187d80f29d9796d272c3e6c12fe3b4f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 7 Dec 2022 20:49:54 +0100 Subject: [PATCH 07/16] test mocking --- src/pytorch_lightning/loggers/mlflow.py | 1 + tests/tests_pytorch/loggers/test_mlflow.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 24912bd618745..ca4b5c5402c7b 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -46,6 +46,7 @@ except ModuleNotFoundError: _MLFLOW_AVAILABLE = False mlflow, MlflowClient, context = None, None, None + Metric, 
Param = None, None MLFLOW_RUN_NAME = "mlflow.runName" # before v1.1.0 diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index cdaf17edf53e2..f74187a7ff655 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -212,9 +212,10 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir): assert logger.experiment.get_experiment_by_name.call_count == 1 +@mock.patch("pytorch_lightning.loggers.mlflow.Metric") @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir): +def test_mlflow_logger_with_unexpected_characters(client, mlflow, _, tmpdir): """Test that the logger raises warning with special characters not accepted by MLFlow.""" logger = MLFlowLogger("test", save_dir=tmpdir) metrics = {"[some_metric]": 10} @@ -236,13 +237,13 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): logger.log_hyperparams(params) +@mock.patch("pytorch_lightning.loggers.mlflow.Metric") +@mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow.time") @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): +def test_mlflow_logger_experiment_calls(client, mlflow, time, param, metric, tmpdir): """Test that the logger calls methods on the mlflow experiment correctly.""" - from mlflow.entities import Metric, Param - time.return_value = 1 logger = MLFlowLogger("test", save_dir=tmpdir, artifact_location="my_artifact_location") @@ -252,14 +253,14 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): logger.log_hyperparams(params) logger.experiment.log_batch.assert_called_once_with( - run_id=logger.run_id, params=[Param(key="test", value="test_param")] + run_id=logger.run_id, params=[param(key="test", value="test_param")] ) metrics = {"some_metric": 10} logger.log_metrics(metrics) logger.experiment.log_batch.assert_called_with( - run_id=logger.run_id, metrics=[Metric(key="some_metric", value=10, timestamp=1000, step=0)] + run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)] ) logger._mlflow_client.create_experiment.assert_called_once_with( From 9c7207185570e19a4e274cb218199ac9f121f3a0 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Wed, 7 Dec 2022 13:35:42 -0700 Subject: [PATCH 08/16] use mock in test --- tests/tests_pytorch/loggers/test_all.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 4a28e4f7d07c9..574ef067d907a 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -283,10 +283,8 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): # MLflow with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch( - "pytorch_lightning.loggers.mlflow.MlflowClient" - ): - from mlflow.entities import Metric - + "pytorch_lightning.loggers.mlflow.Metric" + ) as Metric, mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"): logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) logger.experiment.log_batch.assert_called_once_with( From ccd08777f615c3c0b4ab2754ee973fe266532d8f Mon Sep 17 
00:00:00 2001 From: Jake Schmidt Date: Wed, 7 Dec 2022 13:35:52 -0700 Subject: [PATCH 09/16] rm commented code --- src/pytorch_lightning/loggers/mlflow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index ca4b5c5402c7b..ee28bd53b2cca 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -242,7 +242,6 @@ def experiment_id(self) -> Optional[str]: def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) - # params = [Param(k, v) for k, v in params.items()] params_list: List[Param] = [] for k, v in params.items(): From 4a81f13216c724162d0dbb970082f7d873466008 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Wed, 7 Dec 2022 13:47:36 -0700 Subject: [PATCH 10/16] update `CHANGELOG.md` --- src/pytorch_lightning/CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 67197271d07ca..14c916cb6cd63 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -33,7 +33,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826)) - - Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832)) @@ -57,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The Trainer now raises an error if it is given multiple stateful callbacks of the same time with colliding state keys ([#15634](https://github.com/Lightning-AI/lightning/pull/15634)) +- `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915)) + + ### Deprecated - Deprecated `description`, `env_prefix` and `env_parse` parameters in `LightningCLI.__init__` in favour of giving them through `parser_kwargs` ([#15651](https://github.com/Lightning-AI/lightning/pull/15651)) From 09533d345dd918a5c801cc6449d5cf730733a589 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Wed, 7 Dec 2022 13:48:01 -0700 Subject: [PATCH 11/16] add `mlflow` version floor to reference file --- requirements/pytorch/loggers.info | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/loggers.info b/requirements/pytorch/loggers.info index 590d3babcc8d2..247b1eb1792c9 100644 --- a/requirements/pytorch/loggers.info +++ b/requirements/pytorch/loggers.info @@ -1,5 +1,5 @@ # all supported loggers. 
this list is here as a reference, but they are not installed in CI neptune-client comet-ml -mlflow +mlflow>=1 wandb From 3428e5dbd6b2554dd2506b0d67c54f94569f5aa8 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Wed, 7 Dec 2022 17:12:12 -0700 Subject: [PATCH 12/16] add mlflow version guard --- src/pytorch_lightning/loggers/mlflow.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index ee28bd53b2cca..aec54f3aa2c14 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -31,6 +31,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment +from pytorch_lightning.utilities.imports import RequirementCache from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _scan_checkpoints from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn @@ -153,6 +154,9 @@ def __init__( raise ModuleNotFoundError( "You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`." ) + if not RequirementCache("mlflow>=1.0.0"): + # we require the log_batch APIs that were introduced in mlflow 1.0.0 + raise RuntimeError("Incompatible mlflow version") super().__init__() if not tracking_uri: tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" From 65fa6591bcf351a6cbf2c02562db6aefa653609f Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Wed, 7 Dec 2022 18:14:15 -0600 Subject: [PATCH 13/16] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- requirements/pytorch/loggers.info | 2 +- src/pytorch_lightning/loggers/mlflow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/pytorch/loggers.info b/requirements/pytorch/loggers.info index 247b1eb1792c9..39d593df9c87f 100644 --- a/requirements/pytorch/loggers.info +++ b/requirements/pytorch/loggers.info @@ -1,5 +1,5 @@ # all supported loggers. this list is here as a reference, but they are not installed in CI neptune-client comet-ml -mlflow>=1 +mlflow>=1.0.0 wandb diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index aec54f3aa2c14..14f35c381bab3 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -249,7 +249,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params_list: List[Param] = [] for k, v in params.items(): - # FIXME: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 + # TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 if len(str(v)) > 250: rank_zero_warn( f"Mlflow only allows parameters with up to 250 characters. 
Discard {k}={v}", category=RuntimeWarning From be9b6facb42599c6540c136c644272ef727a7135 Mon Sep 17 00:00:00 2001 From: Jake Schmidt Date: Fri, 9 Dec 2022 09:10:01 -0600 Subject: [PATCH 14/16] Update src/pytorch_lightning/loggers/mlflow.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/loggers/mlflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 968ee9a556f1f..0f26e9f9200db 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -156,7 +156,7 @@ def __init__( ) if not RequirementCache("mlflow>=1.0.0"): # we require the log_batch APIs that were introduced in mlflow 1.0.0 - raise RuntimeError("Incompatible mlflow version") + raise RuntimeError("`MLFlowLogger` requires mlflow >= 1.0.0. Hint: Run `pip install -U mlflow`") super().__init__() if not tracking_uri: tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" From 556265471ae53fefff5403c682b6dc9e196541cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 12 Dec 2022 13:38:28 +0100 Subject: [PATCH 15/16] RequirementCache --- src/pytorch_lightning/loggers/mlflow.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 0f26e9f9200db..44abe2cda5335 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -25,27 +25,24 @@ from typing import Any, Dict, List, Mapping, Optional, Union import yaml -from lightning_utilities.core.imports import module_available +from lightning_utilities.core.imports import RequirementCache from torch import Tensor from typing_extensions import Literal from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment -from pytorch_lightning.utilities.imports import RequirementCache from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _scan_checkpoints from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) LOCAL_FILE_URI_PREFIX = "file:" -_MLFLOW_AVAILABLE = module_available("mlflow") -try: +_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0") +if _MLFLOW_AVAILABLE: import mlflow from mlflow.entities import Metric, Param from mlflow.tracking import context, MlflowClient from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME -# todo: there seems to be still some remaining import error with Conda env -except ModuleNotFoundError: - _MLFLOW_AVAILABLE = False +else: mlflow, MlflowClient, context = None, None, None Metric, Param = None, None MLFLOW_RUN_NAME = "mlflow.runName" @@ -150,13 +147,8 @@ def __init__( artifact_location: Optional[str] = None, run_id: Optional[str] = None, ): - if mlflow is None: - raise ModuleNotFoundError( - "You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`." - ) - if not RequirementCache("mlflow>=1.0.0"): - # we require the log_batch APIs that were introduced in mlflow 1.0.0 - raise RuntimeError("`MLFlowLogger` requires mlflow >= 1.0.0. 
Hint: Run `pip install -U mlflow`") + if not _MLFLOW_AVAILABLE: + raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE)) super().__init__() if not tracking_uri: tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" From 183b7b4cf85914df13e37e7bafa26616f63e3c6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 12 Dec 2022 14:07:34 +0100 Subject: [PATCH 16/16] Fix mocks --- src/pytorch_lightning/loggers/mlflow.py | 3 +- tests/tests_pytorch/loggers/test_all.py | 7 ++-- tests/tests_pytorch/loggers/test_mlflow.py | 38 +++++++++++----------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 44abe2cda5335..70767eca5c1a6 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -38,12 +38,11 @@ LOCAL_FILE_URI_PREFIX = "file:" _MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0") if _MLFLOW_AVAILABLE: - import mlflow from mlflow.entities import Metric, Param from mlflow.tracking import context, MlflowClient from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME else: - mlflow, MlflowClient, context = None, None, None + MlflowClient, context = None, None Metric, Param = None, None MLFLOW_RUN_NAME = "mlflow.runName" diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 574ef067d907a..3c8dbe844b526 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -40,8 +40,9 @@ LOGGER_CTX_MANAGERS = ( mock.patch("pytorch_lightning.loggers.comet.comet_ml"), mock.patch("pytorch_lightning.loggers.comet.CometOfflineExperiment"), - mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), + mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"), + mock.patch("pytorch_lightning.loggers.mlflow.Metric"), mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock), mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True), mock.patch("pytorch_lightning.loggers.wandb.wandb"), @@ -282,7 +283,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0) # MLflow - with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch( + with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch( "pytorch_lightning.loggers.mlflow.Metric" ) as Metric, mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"): logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix) @@ -342,7 +343,7 @@ def test_logger_default_name(tmpdir, monkeypatch): assert logger.name == "lightning_logs" # MLflow - with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch( + with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch( "pytorch_lightning.loggers.mlflow.MlflowClient" ) as mlflow_client: mlflow_client().get_experiment_by_name.return_value = None diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index f74187a7ff655..17bd5389f4f96 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -33,9 +33,9 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r return logger -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") 
+@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_exists(client, mlflow, tmpdir): +def test_mlflow_logger_exists(client, _, tmpdir): """Test launching three independent loggers with either same or different experiment name.""" run1 = MagicMock() @@ -87,9 +87,9 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir): assert logger3.run_id == "run-id-3" -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_run_name_setting(client, mlflow, tmpdir): +def test_mlflow_run_name_setting(client, _, tmpdir): """Test that the run_name argument makes the MLFLOW_RUN_NAME tag.""" tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"}) @@ -114,9 +114,9 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir): client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags) -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_run_id_setting(client, mlflow, tmpdir): +def test_mlflow_run_id_setting(client, _, tmpdir): """Test that the run_id argument uses the provided run_id.""" run = MagicMock() @@ -135,9 +135,9 @@ def test_mlflow_run_id_setting(client, mlflow, tmpdir): client.reset_mock(return_value=True) -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_log_dir(client, mlflow, tmpdir): +def test_mlflow_log_dir(client, _, tmpdir): """Test that the trainer saves checkpoints in the logger's save dir.""" # simulate experiment creation with mlflow client mock @@ -165,7 +165,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir): def test_mlflow_logger_dirs_creation(tmpdir): """Test that the logger creates the folders and files in the right place.""" if not _MLFLOW_AVAILABLE: - pytest.xfail("test for explicit file creation requires mlflow dependency to be installed.") + pytest.skip("test for explicit file creation requires mlflow dependency to be installed.") assert not os.listdir(tmpdir) logger = MLFlowLogger("test", save_dir=tmpdir) @@ -201,9 +201,9 @@ def on_train_epoch_end(self, *args, **kwargs): assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"] -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir): +def test_mlflow_experiment_id_retrieved_once(client, tmpdir): """Test that the logger experiment_id retrieved only once.""" logger = MLFlowLogger("test", save_dir=tmpdir) _ = logger.experiment @@ -213,9 +213,9 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.Metric") -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_unexpected_characters(client, mlflow, _, 
tmpdir): +def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir): """Test that the logger raises warning with special characters not accepted by MLFlow.""" logger = MLFlowLogger("test", save_dir=tmpdir) metrics = {"[some_metric]": 10} @@ -224,9 +224,9 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, _, tmpdir): logger.log_metrics(metrics) -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): +def test_mlflow_logger_with_long_param_value(client, _, tmpdir): """Test that the logger raises warning with special characters not accepted by MLFlow.""" logger = MLFlowLogger("test", save_dir=tmpdir) value = "test" * 100 @@ -240,9 +240,9 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.Metric") @mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow.time") -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_experiment_calls(client, mlflow, time, param, metric, tmpdir): +def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir): """Test that the logger calls methods on the mlflow experiment correctly.""" time.return_value = 1 @@ -268,7 +268,7 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, param, metric, tmp ) -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_finalize_when_exception(*_): logger = MLFlowLogger("test") @@ -286,7 +286,7 @@ def test_mlflow_logger_finalize_when_exception(*_): logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED") -@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") @pytest.mark.parametrize("log_model", ["all", True, False]) def test_mlflow_log_model(client, _, tmpdir, log_model):
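For reference, a minimal standalone sketch of the batched calls this series switches the logger to, written directly against the MlflowClient API (the tracking URI, experiment name, and the logged keys and values below are illustrative placeholders; assumes mlflow>=1.0.0 is installed):

    from time import time

    from mlflow.entities import Metric, Param
    from mlflow.tracking import MlflowClient

    client = MlflowClient(tracking_uri="file:./ml-runs")
    experiment_id = client.create_experiment("demo")
    run_id = client.create_run(experiment_id).info.run_id

    # One request for all hyperparameters instead of one log_param call per key.
    client.log_batch(run_id=run_id, params=[Param(key="lr", value="0.01")])

    # Likewise for metrics: each entry carries its own timestamp (ms) and step.
    timestamp_ms = int(time() * 1000)
    client.log_batch(
        run_id=run_id,
        metrics=[Metric(key="loss", value=0.42, timestamp=timestamp_ms, step=0)],
    )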