Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch MLFlowLogger requests #15915

Merged
merged 28 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3db5d60
use `log_batch`
Dec 5, 2022
62db2b9
fix tests
Dec 5, 2022
5feb2e9
Merge branch 'master' into batch-mlflowlogger-requests
Dec 5, 2022
9344870
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2022
f9bd6a1
fix test
Dec 6, 2022
cbd34c8
fix test
Dec 6, 2022
229fb69
Merge branch 'master' into batch-mlflowlogger-requests
Dec 6, 2022
c1d5ee3
hide imports
Dec 6, 2022
b0fe898
Merge branch 'master' into batch-mlflowlogger-requests
Dec 6, 2022
316cb93
Merge branch 'master' into batch-mlflowlogger-requests
Dec 7, 2022
29ae0bb
test mocking
awaelchli Dec 7, 2022
9c72071
use mock in test
Dec 7, 2022
ccd0877
rm commented code
Dec 7, 2022
4a81f13
update `CHANGELOG.md`
Dec 7, 2022
09533d3
add `mlflow` version floor to reference file
Dec 7, 2022
3428e5d
add mlflow version guard
Dec 8, 2022
65fa659
Apply suggestions from code review
Dec 8, 2022
ceae948
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
98e061d
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
93c7715
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
d82f03b
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
b132a93
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
be9b6fa
Update src/pytorch_lightning/loggers/mlflow.py
Dec 9, 2022
188e359
Merge branch 'master' into batch-mlflowlogger-requests
Dec 9, 2022
5562654
RequirementCache
carmocca Dec 12, 2022
8131209
Merge branch 'master' into batch-mlflowlogger-requests
carmocca Dec 12, 2022
183b7b4
Fix mocks
carmocca Dec 12, 2022
907b67a
Merge branch 'master' into batch-mlflowlogger-requests
carmocca Dec 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/pytorch/loggers.info
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# all supported loggers. this list is here as a reference, but they are not installed in CI
neptune-client
comet-ml
mlflow
mlflow>=1.0.0
wandb
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The Trainer now raises an error if it is given multiple stateful callbacks of the same time with colliding state keys ([#15634](https://github.com/Lightning-AI/lightning/pull/15634))


- `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))


### Deprecated

- Deprecated `description`, `env_prefix` and `env_parse` parameters in `LightningCLI.__init__` in favour of giving them through `parser_kwargs` ([#15651](https://github.com/Lightning-AI/lightning/pull/15651))
Expand Down
33 changes: 18 additions & 15 deletions src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from argparse import Namespace
from pathlib import Path
from time import time
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import yaml
from lightning_utilities.core.imports import module_available
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import Literal

Expand All @@ -36,15 +36,14 @@

log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = module_available("mlflow")
try:
import mlflow
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0")
if _MLFLOW_AVAILABLE:
from mlflow.entities import Metric, Param
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
# todo: there seems to be still some remaining import error with Conda env
except ModuleNotFoundError:
_MLFLOW_AVAILABLE = False
mlflow, MlflowClient, context = None, None, None
else:
MlflowClient, context = None, None
Metric, Param = None, None
MLFLOW_RUN_NAME = "mlflow.runName"

# before v1.1.0
Expand Down Expand Up @@ -147,10 +146,8 @@ def __init__(
artifact_location: Optional[str] = None,
run_id: Optional[str] = None,
):
if mlflow is None:
raise ModuleNotFoundError(
"You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`."
)
if not _MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
super().__init__()
if not tracking_uri:
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
Expand Down Expand Up @@ -240,20 +237,25 @@ def experiment_id(self) -> Optional[str]:
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = _convert_params(params)
params = _flatten_dict(params)
params_list: List[Param] = []

for k, v in params.items():
# TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
if len(str(v)) > 250:
rank_zero_warn(
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
)
continue
params_list.append(Param(key=v, value=v))

self.experiment.log_param(self.run_id, k, v)
self.experiment.log_batch(run_id=self.run_id, params=params_list)

@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
metrics_list: List[Metric] = []

timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
Expand All @@ -269,8 +271,9 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
category=RuntimeWarning,
)
k = new_k
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))

self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
schmidt-jake marked this conversation as resolved.
Show resolved Hide resolved

@rank_zero_only
def finalize(self, status: str = "success") -> None:
Expand Down
15 changes: 9 additions & 6 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
LOGGER_CTX_MANAGERS = (
mock.patch("pytorch_lightning.loggers.comet.comet_ml"),
mock.patch("pytorch_lightning.loggers.comet.CometOfflineExperiment"),
mock.patch("pytorch_lightning.loggers.mlflow.mlflow"),
mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
mock.patch("pytorch_lightning.loggers.mlflow.Metric"),
mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
mock.patch("pytorch_lightning.loggers.wandb.wandb"),
Expand Down Expand Up @@ -282,12 +283,14 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0)

# MLflow
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
"pytorch_lightning.loggers.mlflow.MlflowClient"
):
with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"pytorch_lightning.loggers.mlflow.Metric"
) as Metric, mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"):
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0)
logger.experiment.log_batch.assert_called_once_with(
run_id=ANY, metrics=[Metric(key="tmp-test", value=1.0, timestamp=ANY, step=0)]
)

# Neptune
with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch(
Expand Down Expand Up @@ -340,7 +343,7 @@ def test_logger_default_name(tmpdir, monkeypatch):
assert logger.name == "lightning_logs"

# MLflow
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"pytorch_lightning.loggers.mlflow.MlflowClient"
) as mlflow_client:
mlflow_client().get_experiment_by_name.return_value = None
Expand Down
49 changes: 28 additions & 21 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
return logger


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_exists(client, mlflow, tmpdir):
def test_mlflow_logger_exists(client, _, tmpdir):
"""Test launching three independent loggers with either same or different experiment name."""

run1 = MagicMock()
Expand Down Expand Up @@ -87,9 +87,9 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
assert logger3.run_id == "run-id-3"


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_run_name_setting(client, mlflow, tmpdir):
def test_mlflow_run_name_setting(client, _, tmpdir):
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""

tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
Expand All @@ -114,9 +114,9 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_run_id_setting(client, mlflow, tmpdir):
def test_mlflow_run_id_setting(client, _, tmpdir):
"""Test that the run_id argument uses the provided run_id."""

run = MagicMock()
Expand All @@ -135,9 +135,9 @@ def test_mlflow_run_id_setting(client, mlflow, tmpdir):
client.reset_mock(return_value=True)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, mlflow, tmpdir):
def test_mlflow_log_dir(client, _, tmpdir):
"""Test that the trainer saves checkpoints in the logger's save dir."""

# simulate experiment creation with mlflow client mock
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
def test_mlflow_logger_dirs_creation(tmpdir):
"""Test that the logger creates the folders and files in the right place."""
if not _MLFLOW_AVAILABLE:
pytest.xfail("test for explicit file creation requires mlflow dependency to be installed.")
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")

assert not os.listdir(tmpdir)
logger = MLFlowLogger("test", save_dir=tmpdir)
Expand Down Expand Up @@ -201,9 +201,9 @@ def on_train_epoch_end(self, *args, **kwargs):
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
"""Test that the logger experiment_id retrieved only once."""
logger = MLFlowLogger("test", save_dir=tmpdir)
_ = logger.experiment
Expand All @@ -212,9 +212,10 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
assert logger.experiment.get_experiment_by_name.call_count == 1


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
metrics = {"[some_metric]": 10}
Expand All @@ -223,9 +224,9 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
logger.log_metrics(metrics)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
value = "test" * 100
Expand All @@ -236,10 +237,12 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
logger.log_hyperparams(params)


@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
@mock.patch("pytorch_lightning.loggers.mlflow.time")
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
"""Test that the logger calls methods on the mlflow experiment correctly."""
time.return_value = 1

Expand All @@ -249,19 +252,23 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
params = {"test": "test_param"}
logger.log_hyperparams(params)

logger.experiment.log_param.assert_called_once_with(logger.run_id, "test", "test_param")
logger.experiment.log_batch.assert_called_once_with(
run_id=logger.run_id, params=[param(key="test_param", value="test_param")]
)

metrics = {"some_metric": 10}
logger.log_metrics(metrics)

logger.experiment.log_metric.assert_called_once_with(logger.run_id, "some_metric", 10, 1000, None)
logger.experiment.log_batch.assert_called_with(
run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)]
)

logger._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location"
)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_finalize_when_exception(*_):
logger = MLFlowLogger("test")
Expand All @@ -279,7 +286,7 @@ def test_mlflow_logger_finalize_when_exception(*_):
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
@pytest.mark.parametrize("log_model", ["all", True, False])
def test_mlflow_log_model(client, _, tmpdir, log_model):
Expand Down