Skip to content

Commit

Permalink
Improve MLFlowLogger (#164)
Browse files Browse the repository at this point in the history
* Improve MLFlowLogger

* Add doc and fix PyLint

* Fix tests

* Deprecate experiment_name as positional argument

* Update CHANGELOG

* Fix docstring

* Add tests
  • Loading branch information
freud14-tm authored Apr 26, 2023
1 parent 1a9072c commit c56df09
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 40 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

-

# v1.16

- Add `run_id` and `terminate_on_end` arguments to [MLFlowLogger](https://poutyne.org/callbacks.html#poutyne.MLFlowLogger).

Breaking change:

- In [MLFlowLogger](https://poutyne.org/callbacks.html#poutyne.MLFlowLogger), except for `experiment_name`, all
arguments must now be passed as keyword arguments. Passing `experiment_name` as a positional argument is also
deprecated and will be removed in future versions.

# v1.15

- Remove support for Python 3.7
Expand Down
107 changes: 83 additions & 24 deletions poutyne/framework/callbacks/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# pylint: disable=line-too-long, pointless-string-statement
import os
import warnings
from typing import Dict, Mapping, Sequence, Union
from typing import Any, Dict, Mapping, Optional, Sequence, Union

from poutyne.framework.callbacks.logger import Logger

Expand All @@ -43,13 +43,18 @@ class MLFlowLogger(Logger):
logger will log all runs into the same experiment.
Args:
experiment_name (str): The name of the experiment. The name must be unique and are case-sensitive.
tracking_uri (Union[str, None]): Either the URI tracking path (for server tracking) of the absolute path to
experiment_name (Optional[str]): The name of the experiment. The name is case-sensitive. An `experiment_id` must
not be passed if this is passed.
experiment_id (Optional[str]): The id of the experiment. An `experiment_name` must not be passed if this is
passed.
run_id (Optional[str]): The id of the run. An experiment name/id must not be passed if this is passed.
tracking_uri (Optional[str]): Either the URI tracking path (for server tracking) or the absolute path to
the directory to save the files (for file store). For example: ``http://<ip address>:<port>``
(remote server) or ``/home/<user>/mlflow-server`` (local server).
If None, will use the default MLflow file tracking URI ``"./mlruns"``.
batch_granularity (bool): Whether to also output the result of each batch in addition to the epochs.
(Default value = False)
terminate_on_end (bool): Whether to end the run at the end of the training or testing. (Default value = True)
Example:
Using file store::
Expand Down Expand Up @@ -82,35 +87,83 @@ class MLFlowLogger(Logger):
"""

def __init__(
self, experiment_name: str, tracking_uri: Union[str, None] = None, batch_granularity: bool = False
self,
deprecated_experiment_name: Optional[str] = None,
*,
experiment_name: Optional[str] = None,
experiment_id: Optional[str] = None,
run_id: Optional[str] = None,
tracking_uri: Optional[str] = None,
batch_granularity: bool = False,
terminate_on_end=True,
) -> None:
super().__init__(batch_granularity=batch_granularity)
if mlflow is None:
raise ImportError("Mlflow needs to be installed to use this callback.")

if deprecated_experiment_name is not None and experiment_name is not None:
raise ValueError(
"`experiment_name` was passed as positional and keyword arguments. Make sure to only pass it once as a "
"keyword argument."
)

if deprecated_experiment_name is not None:
warnings.warn(
'Positional argument `experiment_name` is deprecated and will be removed in future versions. Please '
'use it as a keyword argument, i.e. experiment_name="my-experiment-name"'
)
experiment_name = deprecated_experiment_name

self.tracking = tracking_uri

self._working_directory = os.getcwd() # For Git hash monitoring.

self.ml_flow_client = MlflowClient(tracking_uri=self.tracking)

self._handle_experiment_id(experiment_name)
self.run_id = self.ml_flow_client.create_run(experiment_id=self.experiment_id).info.run_id
if run_id is not None and (experiment_name is not None or experiment_id is not None):
raise ValueError("Either provide an experiment name/id or a run id, not both.")

if run_id is None:
experiment_id = self._handle_experiment_id(experiment_name, experiment_id)
self.run_id = self.ml_flow_client.create_run(experiment_id=experiment_id).info.run_id
else:
self.run_id = run_id

self._log_git_version()

self.terminate_on_end = terminate_on_end
self._status = "FAILED" # Base case is a failure.

def log_config_params(self, config_params: Mapping) -> None:
def log_config_params(self, config_params: Mapping, **kwargs: Any) -> None:
"""
Args:
config_params (Mapping):
The config parameters of the training to log, such as number of epoch, loss function, optimizer etc.
"""
for param_name, element in config_params.items():
self._log_config_write(param_name, element)
self._log_config_write(param_name, element, **kwargs)

def log_params(self, params: Dict[str, Any], **kwargs: Any):
"""
Log the values of the parameters into the experiment.
Each entry is logged individually through ``log_param``.
Args:
params (Dict[str, Any]): Dictionary mapping each parameter name to the value to log.
**kwargs: Extra keyword arguments forwarded unchanged to ``log_param`` for every entry.
"""
for k, v in params.items():
self.log_param(k, v, **kwargs)

def log_param(self, param_name: str, value: Union[str, float]) -> None:
def log_metrics(self, metrics: Dict[str, float], **kwargs: Any):
"""
Log the values of the metrics into the experiment.
Each entry is logged individually through ``log_metric``.
Args:
metrics (Dict[str, float]): Dictionary mapping each metric name to the value to log.
**kwargs: Extra keyword arguments forwarded unchanged to ``log_metric`` for every entry.
"""
for k, v in metrics.items():
self.log_metric(k, v, **kwargs)

def log_param(self, param_name: str, value: Union[str, float], **kwargs: Any) -> None:
"""
Log the value of a parameter into the experiment.
Expand All @@ -119,9 +172,9 @@ def log_param(self, param_name: str, value: Union[str, float]) -> None:
value (Union[str, float]): The value of the parameter.
"""
self.ml_flow_client.log_param(run_id=self.run_id, key=param_name, value=value)
self.ml_flow_client.log_param(run_id=self.run_id, key=param_name, value=value, **kwargs)

def log_metric(self, metric_name: str, value: float, step: Union[int, None] = None) -> None:
def log_metric(self, metric_name: str, value: float, **kwargs: Any) -> None:
"""
Log the value of a metric into the experiment.
Expand All @@ -130,22 +183,24 @@ def log_metric(self, metric_name: str, value: float, step: Union[int, None] = No
value (float): The value of the metric.
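step (Union[int, None]): The step when the metric was computed. It is no longer a named parameter; pass it as a keyword argument and it is forwarded through ``**kwargs`` (Default = None).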
"""
self.ml_flow_client.log_metric(run_id=self.run_id, key=metric_name, value=value, step=step)
self.ml_flow_client.log_metric(run_id=self.run_id, key=metric_name, value=value, **kwargs)

def _log_config_write(self, parent_name: str, element: Union[int, float, str, Mapping, Sequence]) -> None:
def _log_config_write(
self, parent_name: str, element: Union[int, float, str, Mapping, Sequence], **kwargs: Any
) -> None:
"""
Log the config parameters when it's a mapping or a sequence of elements.
"""
if isinstance(element, Mapping):
for key, value in element.items():
# We recursively open the element (Dict format type).
self._log_config_write(f"{parent_name}.{key}", value)
self._log_config_write(f"{parent_name}.{key}", value, **kwargs)
elif isinstance(element, Sequence) and not isinstance(element, str):
# Since str are sequence we negate it to be logged in the else.
for idx, value in enumerate(element):
self._log_config_write(f"{parent_name}.{idx}", value)
self._log_config_write(f"{parent_name}.{idx}", value, **kwargs)
else:
self.log_param(parent_name, element)
self.log_param(parent_name, element, **kwargs)

def _on_train_batch_end_write(self, batch_number: int, logs: Dict) -> None:
"""
Expand All @@ -168,8 +223,6 @@ def on_train_end(self, logs: Dict):
"""
self._on_train_end_write(logs)
self._status = "FINISHED"

mlflow.end_run()
self._status_handling()

def _on_train_end_write(self, logs) -> None:
Expand All @@ -191,25 +244,31 @@ def on_test_end(self, logs: Dict):
self._on_test_end_write(logs)
self._status = "FINISHED"

mlflow.end_run()
self._status_handling()

def _on_test_end_write(self, logs: Dict) -> None:
for key, value in logs.items():
self.log_metric(key, value)

def _status_handling(self):
# We set_terminated the run to get the finishing status (FINISHED or FAILED)
self.ml_flow_client.set_terminated(self.run_id, status=self._status)
if self.terminate_on_end:
# We set_terminated the run to get the finishing status (FINISHED or FAILED)
self.ml_flow_client.set_terminated(self.run_id, status=self._status)

def _handle_experiment_id(self, experiment_name):
def _handle_experiment_id(self, experiment_name, experiment_id):
"""
Handle the existing experiment name to grab the id and append a new experiment to it.
"""
if experiment_name is not None and experiment_id is not None:
raise ValueError("Either provide the experiment name or experiment id, not both.")

if experiment_id is not None:
return experiment_id

try:
self.experiment_id = self.ml_flow_client.create_experiment(experiment_name, self.tracking)
return self.ml_flow_client.create_experiment(experiment_name, self.tracking)
except MlflowException:
self.experiment_id = self.ml_flow_client.get_experiment_by_name(experiment_name).experiment_id
return self.ml_flow_client.get_experiment_by_name(experiment_name).experiment_id

def _log_git_version(self):
"""
Expand Down
55 changes: 41 additions & 14 deletions tests/framework/callbacks/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,20 @@

import os
from typing import Dict, List, Mapping, Sequence
from unittest import TestCase
from unittest import TestCase, skipIf
from unittest.mock import MagicMock, call, patch

import git
import torch
import torch.nn as nn
from mlflow.exceptions import MlflowException
from omegaconf import DictConfig

try:
from mlflow.exceptions import MlflowException
from omegaconf import DictConfig

mlflow_available = True
except ImportError:
mlflow_available = False

from poutyne import Model
from poutyne.framework.callbacks.mlflow_logger import MLFlowLogger, _get_git_commit
Expand All @@ -36,6 +42,7 @@
mlflow_default_git_commit_tag = "mlflow.source.git.commit"


@skipIf(not mlflow_available, "imports for MLFlowLogger not available")
class MLFlowLoggerTest(TestCase):
def setUp(self) -> None:
self.a_experiment_name = "a_name"
Expand Down Expand Up @@ -118,10 +125,6 @@ def test_whenCorrectSettings_givenAMLFlowInstantiation_thenMLFlowClientIsProperl
]
ml_flow_client_patch.assert_has_calls(settings_calls)

actual_experiment_id = mlflow_logger.experiment_id
expected_experiment_id = self.a_experiment_id
self.assertEqual(expected_experiment_id, actual_experiment_id)

actual_run_id = mlflow_logger.run_id
expected_run_id = self.a_run_id
self.assertEqual(expected_run_id, actual_run_id)
Expand Down Expand Up @@ -149,7 +152,21 @@ def test_whenLogMetric_givenAMLFlowCallback_thenLogMetric(self):
ml_flow_client_calls = []
for key, value in self.a_log.items():
mlflow_logger.log_metric(key, value)
ml_flow_client_calls.append(call().log_metric(run_id=self.a_run_id, key=key, value=value, step=None))
ml_flow_client_calls.append(call().log_metric(run_id=self.a_run_id, key=key, value=value))
ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)

@patch("poutyne.framework.mlflow_logger._get_git_commit", MagicMock())
def test_whenLogMetrics_givenAMLFlowCallback_thenLogEachMetric(self):
# log_metrics must fan out into one client log_metric call per dictionary entry.
with patch("poutyne.framework.mlflow_logger.MlflowClient") as ml_flow_client_patch:
# Route experiment and run creation on the patched client to the fixture mocks.
ml_flow_client_patch.return_value.create_experiment = self.experiment_mock
ml_flow_client_patch.return_value.create_run = self.run_mock

mlflow_logger = MLFlowLogger(self.a_experiment_name)
mlflow_logger.log_metrics(self.a_log)

# Expected calls carry only run_id/key/value: no extra keyword (such as ``step``) is added.
ml_flow_client_calls = []
for key, value in self.a_log.items():
ml_flow_client_calls.append(call().log_metric(run_id=self.a_run_id, key=key, value=value))
ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)

@patch("poutyne.framework.mlflow_logger._get_git_commit", MagicMock())
Expand Down Expand Up @@ -204,6 +221,20 @@ def test_whenLogConfigParamsAConfigDictWithSequence_givenAMLFlowCallback_thenLog
ml_flow_client_calls = self._populate_calls_from_dict(self.settings_in_dict_config_with_sequence)
ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)

@patch("poutyne.framework.mlflow_logger._get_git_commit", MagicMock())
def test_whenLogParams_givenAMLFlowCallback_thenLogEachParam(self):
# log_params must fan out into one client log_param call per dictionary entry.
with patch("poutyne.framework.mlflow_logger.MlflowClient") as ml_flow_client_patch:
# Route experiment and run creation on the patched client to the fixture mocks.
ml_flow_client_patch.return_value.create_experiment = self.experiment_mock
ml_flow_client_patch.return_value.create_run = self.run_mock

mlflow_logger = MLFlowLogger(self.a_experiment_name)
mlflow_logger.log_params(self.settings_in_dict)

# Expected calls use the top-level keys and values as-is, with only run_id/key/value.
ml_flow_client_calls = []
for key, value in self.settings_in_dict.items():
ml_flow_client_calls.append(call().log_param(run_id=self.a_run_id, key=key, value=value))
ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)

@patch("poutyne.framework.mlflow_logger._get_git_commit", MagicMock())
def test_whenOnTrainEndSuccess_givenAMLFlowCallback_thenLogLastEpochNumber(self):
with patch("poutyne.framework.mlflow_logger.MlflowClient") as ml_flow_client_patch:
Expand All @@ -214,9 +245,7 @@ def test_whenOnTrainEndSuccess_givenAMLFlowCallback_thenLogLastEpochNumber(self)
mlflow_logger.set_params({"epochs": self.num_epochs})
mlflow_logger.on_train_end(self.a_log)

ml_flow_client_calls = [
call().log_metric(run_id=self.a_run_id, key='last-epoch', value=self.num_epochs, step=None)
]
ml_flow_client_calls = [call().log_metric(run_id=self.a_run_id, key='last-epoch', value=self.num_epochs)]
ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)

@patch("poutyne.framework.mlflow_logger._get_git_commit", MagicMock())
Expand Down Expand Up @@ -339,9 +368,7 @@ def _populate_calls_from_logs(self, logs: Dict) -> List:
ml_flow_client_calls.append(
call().log_metric(run_id=self.a_run_id, key=key, value=value, step=epoch_num + 1)
) # +1 for enumerate
ml_flow_client_calls.append(
call().log_metric(run_id=self.a_run_id, key="last-epoch", value=self.num_epochs, step=None)
)
ml_flow_client_calls.append(call().log_metric(run_id=self.a_run_id, key="last-epoch", value=self.num_epochs))
ml_flow_client_calls.append(call().set_terminated(self.a_run_id, status='FINISHED'))
return ml_flow_client_calls

Expand Down
9 changes: 7 additions & 2 deletions tests/framework/callbacks/test_wandb_logger.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import warnings
from tempfile import TemporaryDirectory, TemporaryFile
from unittest import TestCase, main
from unittest import TestCase, main, skipIf
from unittest.mock import MagicMock, call, patch

import torch
import torch.nn as nn
import wandb

try:
import wandb
except ImportError:
wandb = None

from poutyne import Callback, Model, ModelCheckpoint, WandBLogger
from tests.framework.tools import some_data_generator
Expand All @@ -25,6 +29,7 @@ def on_train_begin(self, logs):
self.history = []


@skipIf(wandb is None, "imports for WandBLogger not available")
class WandBLoggerTest(TestCase):
def setUp(self):
torch.manual_seed(42)
Expand Down

0 comments on commit c56df09

Please sign in to comment.