Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch MLFlowLogger requests #15915

Merged
merged 28 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3db5d60
use `log_batch`
Dec 5, 2022
62db2b9
fix tests
Dec 5, 2022
5feb2e9
Merge branch 'master' into batch-mlflowlogger-requests
Dec 5, 2022
9344870
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2022
f9bd6a1
fix test
Dec 6, 2022
cbd34c8
fix test
Dec 6, 2022
229fb69
Merge branch 'master' into batch-mlflowlogger-requests
Dec 6, 2022
c1d5ee3
hide imports
Dec 6, 2022
b0fe898
Merge branch 'master' into batch-mlflowlogger-requests
Dec 6, 2022
316cb93
Merge branch 'master' into batch-mlflowlogger-requests
Dec 7, 2022
29ae0bb
test mocking
awaelchli Dec 7, 2022
9c72071
use mock in test
Dec 7, 2022
ccd0877
rm commented code
Dec 7, 2022
4a81f13
update `CHANGELOG.md`
Dec 7, 2022
09533d3
add `mlflow` version floor to reference file
Dec 7, 2022
3428e5d
add mlflow version guard
Dec 8, 2022
65fa659
Apply suggestions from code review
Dec 8, 2022
ceae948
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
98e061d
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
93c7715
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
d82f03b
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
b132a93
Merge branch 'master' into batch-mlflowlogger-requests
Dec 8, 2022
be9b6fa
Update src/pytorch_lightning/loggers/mlflow.py
Dec 9, 2022
188e359
Merge branch 'master' into batch-mlflowlogger-requests
Dec 9, 2022
5562654
RequirementCache
carmocca Dec 12, 2022
8131209
Merge branch 'master' into batch-mlflowlogger-requests
carmocca Dec 12, 2022
183b7b4
Fix mocks
carmocca Dec 12, 2022
907b67a
Merge branch 'master' into batch-mlflowlogger-requests
carmocca Dec 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/pytorch/loggers.info
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# all supported loggers. this list is here as a reference, but they are not installed in CI
neptune-client
comet-ml
mlflow
mlflow>=1.0.0
wandb
4 changes: 3 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))



- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))


Expand All @@ -60,6 +59,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The Trainer now raises an error if it is given multiple stateful callbacks of the same time with colliding state keys ([#15634](https://github.com/Lightning-AI/lightning/pull/15634))


- `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))


### Deprecated

- Deprecated `description`, `env_prefix` and `env_parse` parameters in `LightningCLI.__init__` in favour of giving them through `parser_kwargs` ([#15651](https://github.com/Lightning-AI/lightning/pull/15651))
Expand Down
18 changes: 15 additions & 3 deletions src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from argparse import Namespace
from pathlib import Path
from time import time
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import yaml
from lightning_utilities.core.imports import module_available
Expand All @@ -31,6 +31,7 @@

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.imports import RequirementCache
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _scan_checkpoints
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

Expand All @@ -39,12 +40,14 @@
_MLFLOW_AVAILABLE = module_available("mlflow")
try:
import mlflow
from mlflow.entities import Metric, Param
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
# todo: there seems to be still some remaining import error with Conda env
except ModuleNotFoundError:
_MLFLOW_AVAILABLE = False
mlflow, MlflowClient, context = None, None, None
Metric, Param = None, None
MLFLOW_RUN_NAME = "mlflow.runName"

# before v1.1.0
Expand Down Expand Up @@ -151,6 +154,9 @@ def __init__(
raise ModuleNotFoundError(
"You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`."
)
if not RequirementCache("mlflow>=1.0.0"):
# we require the log_batch APIs that were introduced in mlflow 1.0.0
raise RuntimeError("Incompatible mlflow version")
schmidt-jake marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
if not tracking_uri:
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
Expand Down Expand Up @@ -240,20 +246,25 @@ def experiment_id(self) -> Optional[str]:
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = _convert_params(params)
params = _flatten_dict(params)
params_list: List[Param] = []

for k, v in params.items():
# TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
if len(str(v)) > 250:
rank_zero_warn(
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
)
continue
params_list.append(Param(key=v, value=v))

self.experiment.log_param(self.run_id, k, v)
self.experiment.log_batch(run_id=self.run_id, params=params_list)

@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
metrics_list: List[Metric] = []

timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
Expand All @@ -269,8 +280,9 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
category=RuntimeWarning,
)
k = new_k
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))

self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
schmidt-jake marked this conversation as resolved.
Show resolved Hide resolved

@rank_zero_only
def finalize(self, status: str = "success") -> None:
Expand Down
8 changes: 5 additions & 3 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,13 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):

# MLflow
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
"pytorch_lightning.loggers.mlflow.MlflowClient"
):
"pytorch_lightning.loggers.mlflow.Metric"
) as Metric, mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"):
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0)
logger.experiment.log_batch.assert_called_once_with(
run_id=ANY, metrics=[Metric(key="tmp-test", value=1.0, timestamp=ANY, step=0)]
)

# Neptune
with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch(
Expand Down
15 changes: 11 additions & 4 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,10 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
assert logger.experiment.get_experiment_by_name.call_count == 1


@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
def test_mlflow_logger_with_unexpected_characters(client, mlflow, _, tmpdir):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
metrics = {"[some_metric]": 10}
Expand All @@ -236,10 +237,12 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
logger.log_hyperparams(params)


@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
@mock.patch("pytorch_lightning.loggers.mlflow.time")
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
def test_mlflow_logger_experiment_calls(client, mlflow, time, param, metric, tmpdir):
"""Test that the logger calls methods on the mlflow experiment correctly."""
time.return_value = 1

Expand All @@ -249,12 +252,16 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
params = {"test": "test_param"}
logger.log_hyperparams(params)

logger.experiment.log_param.assert_called_once_with(logger.run_id, "test", "test_param")
logger.experiment.log_batch.assert_called_once_with(
run_id=logger.run_id, params=[param(key="test_param", value="test_param")]
)

metrics = {"some_metric": 10}
logger.log_metrics(metrics)

logger.experiment.log_metric.assert_called_once_with(logger.run_id, "some_metric", 10, 1000, None)
logger.experiment.log_batch.assert_called_with(
run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)]
)

logger._mlflow_client.create_experiment.assert_called_once_with(
name="test", artifact_location="my_artifact_location"
Expand Down