Skip to content

Commit

Permalink
Batch MLFlowLogger requests (#15915)
Browse files Browse the repository at this point in the history
Co-authored-by: Jake Schmidt <[email protected]>
Co-authored-by: awaelchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
4 people authored Dec 12, 2022
1 parent 2577285 commit 38acba0
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 43 deletions.
2 changes: 1 addition & 1 deletion requirements/pytorch/loggers.info
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# all supported loggers. this list is here as a reference, but they are not installed in CI
neptune-client
comet-ml
mlflow
mlflow>=1.0.0
wandb
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The Trainer now raises an error if it is given multiple stateful callbacks of the same type with colliding state keys ([#15634](https://github.com/Lightning-AI/lightning/pull/15634))


- `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))


### Deprecated

- Deprecated `description`, `env_prefix` and `env_parse` parameters in `LightningCLI.__init__` in favour of giving them through `parser_kwargs` ([#15651](https://github.com/Lightning-AI/lightning/pull/15651))
Expand Down
33 changes: 18 additions & 15 deletions src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from argparse import Namespace
from pathlib import Path
from time import time
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import yaml
from lightning_utilities.core.imports import module_available
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import Literal

Expand All @@ -36,15 +36,14 @@

log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = module_available("mlflow")
try:
import mlflow
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0")
if _MLFLOW_AVAILABLE:
from mlflow.entities import Metric, Param
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
# todo: there seems to be still some remaining import error with Conda env
except ModuleNotFoundError:
_MLFLOW_AVAILABLE = False
mlflow, MlflowClient, context = None, None, None
else:
MlflowClient, context = None, None
Metric, Param = None, None
MLFLOW_RUN_NAME = "mlflow.runName"

# before v1.1.0
Expand Down Expand Up @@ -147,10 +146,8 @@ def __init__(
artifact_location: Optional[str] = None,
run_id: Optional[str] = None,
):
if mlflow is None:
raise ModuleNotFoundError(
"You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`."
)
if not _MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
super().__init__()
if not tracking_uri:
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
Expand Down Expand Up @@ -240,20 +237,25 @@ def experiment_id(self) -> Optional[str]:
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
    """Log hyperparameters to the current MLflow run using batched API calls.

    Args:
        params: Hyperparameters as a dict or ``argparse.Namespace``. They are passed
            through ``_convert_params`` and ``_flatten_dict`` before logging
            (presumably normalizing to a flat ``{name: value}`` mapping -- confirm
            against those helpers).

    Any parameter whose string representation is longer than 250 characters is
    skipped and a ``RuntimeWarning`` is emitted instead of logging it.
    """
    params = _convert_params(params)
    params = _flatten_dict(params)
    params_list: List[Param] = []

    for k, v in params.items():
        # TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
        if len(str(v)) > 250:
            rank_zero_warn(
                f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
            )
            continue
        # The key must be the parameter *name* (k); the value is logged under it.
        params_list.append(Param(key=k, value=v))

    # MLflow's log-batch endpoint accepts at most 100 params per request, so send
    # the collected params in chunks rather than in one unbounded call.
    max_params_per_batch = 100
    for start in range(0, len(params_list), max_params_per_batch):
        self.experiment.log_batch(
            run_id=self.run_id, params=params_list[start : start + max_params_per_batch]
        )

@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
metrics_list: List[Metric] = []

timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
Expand All @@ -269,8 +271,9 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
category=RuntimeWarning,
)
k = new_k
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))

self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)

@rank_zero_only
def finalize(self, status: str = "success") -> None:
Expand Down
15 changes: 9 additions & 6 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
LOGGER_CTX_MANAGERS = (
mock.patch("pytorch_lightning.loggers.comet.comet_ml"),
mock.patch("pytorch_lightning.loggers.comet.CometOfflineExperiment"),
mock.patch("pytorch_lightning.loggers.mlflow.mlflow"),
mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
mock.patch("pytorch_lightning.loggers.mlflow.Metric"),
mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
mock.patch("pytorch_lightning.loggers.wandb.wandb"),
Expand Down Expand Up @@ -282,12 +283,14 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0)

# MLflow
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
"pytorch_lightning.loggers.mlflow.MlflowClient"
):
with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"pytorch_lightning.loggers.mlflow.Metric"
) as Metric, mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"):
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0)
logger.experiment.log_batch.assert_called_once_with(
run_id=ANY, metrics=[Metric(key="tmp-test", value=1.0, timestamp=ANY, step=0)]
)

# Neptune
with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch(
Expand Down Expand Up @@ -340,7 +343,7 @@ def test_logger_default_name(tmpdir, monkeypatch):
assert logger.name == "lightning_logs"

# MLflow
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"pytorch_lightning.loggers.mlflow.MlflowClient"
) as mlflow_client:
mlflow_client().get_experiment_by_name.return_value = None
Expand Down
49 changes: 28 additions & 21 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
return logger


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_exists(client, mlflow, tmpdir):
def test_mlflow_logger_exists(client, _, tmpdir):
"""Test launching three independent loggers with either same or different experiment name."""

run1 = MagicMock()
Expand Down Expand Up @@ -87,9 +87,9 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
assert logger3.run_id == "run-id-3"


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_run_name_setting(client, mlflow, tmpdir):
def test_mlflow_run_name_setting(client, _, tmpdir):
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""

tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
Expand All @@ -114,9 +114,9 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_run_id_setting(client, mlflow, tmpdir):
def test_mlflow_run_id_setting(client, _, tmpdir):
"""Test that the run_id argument uses the provided run_id."""

run = MagicMock()
Expand All @@ -135,9 +135,9 @@ def test_mlflow_run_id_setting(client, mlflow, tmpdir):
client.reset_mock(return_value=True)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, mlflow, tmpdir):
def test_mlflow_log_dir(client, _, tmpdir):
"""Test that the trainer saves checkpoints in the logger's save dir."""

# simulate experiment creation with mlflow client mock
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
def test_mlflow_logger_dirs_creation(tmpdir):
"""Test that the logger creates the folders and files in the right place."""
if not _MLFLOW_AVAILABLE:
pytest.xfail("test for explicit file creation requires mlflow dependency to be installed.")
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")

assert not os.listdir(tmpdir)
logger = MLFlowLogger("test", save_dir=tmpdir)
Expand Down Expand Up @@ -201,9 +201,9 @@ def on_train_epoch_end(self, *args, **kwargs):
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
"""Test that the logger experiment_id is retrieved only once."""
logger = MLFlowLogger("test", save_dir=tmpdir)
_ = logger.experiment
Expand All @@ -212,9 +212,10 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
assert logger.experiment.get_experiment_by_name.call_count == 1


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
metrics = {"[some_metric]": 10}
Expand All @@ -223,9 +224,9 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
logger.log_metrics(metrics)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
"""Test that the logger raises a warning when a parameter value is too long to be accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
value = "test" * 100
Expand All @@ -236,10 +237,12 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
logger.log_hyperparams(params)


@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
@mock.patch("pytorch_lightning.loggers.mlflow.time")
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
"""Test that the logger calls methods on the mlflow experiment correctly."""
time.return_value = 1

Expand All @@ -249,19 +252,23 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
params = {"test": "test_param"}
logger.log_hyperparams(params)

logger.experiment.log_param.assert_called_once_with(logger.run_id, "test", "test_param")
logger.experiment.log_batch.assert_called_once_with(
run_id=logger.run_id, params=[param(key="test_param", value="test_param")]
)

metrics = {"some_metric": 10}
logger.log_metrics(metrics)

logger.experiment.log_metric.assert_called_once_with(logger.run_id, "some_metric", 10, 1000, None)
logger.experiment.log_batch.assert_called_with(
run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)]
)

logger._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location"
)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_finalize_when_exception(*_):
logger = MLFlowLogger("test")
Expand All @@ -279,7 +286,7 @@ def test_mlflow_logger_finalize_when_exception(*_):
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
@pytest.mark.parametrize("log_model", ["all", True, False])
def test_mlflow_log_model(client, _, tmpdir, log_model):
Expand Down

0 comments on commit 38acba0

Please sign in to comment.