Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6612-support-mlflow-data-tracking #6616

Merged
merged 36 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
093a254
add data tracking to MLFlowHandler
binliunls Jun 14, 2023
e22a8d0
fix format
binliunls Jun 14, 2023
10765e7
Merge branch 'Project-MONAI:dev' into 6612-support-mlflow-data-tracking
binliunls Jun 19, 2023
8075f34
log the dataset to the mlflow ui
binliunls Jun 20, 2023
5f6e158
fix the import order issue
binliunls Jun 20, 2023
6a5e395
remove the dataset log from the complete method.
binliunls Jun 20, 2023
20868d0
Merge branch 'Project-MONAI:dev' into 6612-support-mlflow-data-tracking
binliunls Jun 25, 2023
def7cc9
fix the unit test issue
binliunls Jun 25, 2023
b2a7482
fix the type hint issue
binliunls Jun 25, 2023
5219a6c
fix the format issue
binliunls Jun 25, 2023
a2d857c
add the type hint for the input parameter
binliunls Jun 25, 2023
a5a3801
fix the format issue
binliunls Jun 25, 2023
8a3186e
fix the codeformat issue
binliunls Jun 25, 2023
d0ec3bb
fix the fl monai algo unit test issue
binliunls Jun 26, 2023
dc0d24f
split the warning info to shorter strings
binliunls Jun 26, 2023
903e442
Merge branch 'dev' into 6612-support-mlflow-data-tracking
binliunls Jun 27, 2023
acd75f5
Update the comments and warning messages in the code.
binliunls Jun 28, 2023
f16d1ef
Add the link to PandasDataset.
binliunls Jun 28, 2023
93777e6
add the unit test case
binliunls Jun 29, 2023
657cb0f
fix the format issue in the unit test
binliunls Jun 29, 2023
d514256
add the mlflow version requirement.
binliunls Jun 29, 2023
41f4c78
Merge branch 'dev' into 6612-support-mlflow-data-tracking
binliunls Jun 29, 2023
55e69f6
Merge branch 'dev' into 6612-support-mlflow-data-tracking
wyli Jul 4, 2023
58d975b
add the skip_if_downloading_fails context manager to the test case.
binliunls Jul 5, 2023
08e31c8
update mlflow handler according to the reviewers opinions.
binliunls Jul 7, 2023
87edbb8
update the test case
binliunls Jul 7, 2023
0ffd97c
fix the format issue
binliunls Jul 7, 2023
556618b
make tqdm as optional import
binliunls Jul 7, 2023
9709303
fix the mypy issue
binliunls Jul 7, 2023
856f377
set the dataset name type to string
binliunls Jul 10, 2023
4d5ca1a
Merge branch 'dev' into 6612-support-mlflow-data-tracking
binliunls Jul 10, 2023
83d8d6f
Remove the dataset_log parameter
binliunls Jul 10, 2023
cf83a7e
Remove the dataset log settings from the default MLFlow setting.
binliunls Jul 10, 2023
83ca9ff
fix the format issue
binliunls Jul 10, 2023
616c3ba
Merge branch 'dev' into 6612-support-mlflow-data-tracking
binliunls Jul 18, 2023
e854563
Merge branch 'dev' into 6612-support-mlflow-data-tracking
wyli Jul 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 113 additions & 2 deletions monai/handlers/mlflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,24 @@

import os
import time
from collections.abc import Callable, Sequence
import warnings
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
from torch.utils.data import Dataset

from monai.config import IgniteInfo
from monai.utils import ensure_tuple, min_version, optional_import
from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
mlflow.entities, _ = optional_import(
"mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
)
pandas, _ = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
tqdm, _ = optional_import("tqdm", "4.47.0", min_version, "tqdm")

if TYPE_CHECKING:
from ignite.engine import Engine
Expand Down Expand Up @@ -72,6 +76,14 @@ class MLFlowHandler:
Must accept parameter "engine", use default logger if None.
iteration_logger: customized callable logger for iteration level logging with MLFlow.
Must accept parameter "engine", use default logger if None.
dataset_logger: customized callable logger to log the dataset information with MLFlow.
Must accept parameter "dataset_dict", use default logger if None.
dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
dataset, that needs to be recorded. This arg is only useful when MLFlow version >= 2.4.0.
For more details about how to log data with MLFlow, please go to the website:
https://mlflow.org/docs/latest/python_api/mlflow.data.html.
dataset_keys: a key or a collection of keys to indicate contents in the dataset that
need to be stored by MLFlow.
binliunls marked this conversation as resolved.
Show resolved Hide resolved
output_transform: a callable that is used to transform the
``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}.
By default this value logging happens when every iteration completed.
Expand Down Expand Up @@ -111,6 +123,9 @@ def __init__(
epoch_log: bool | Callable[[Engine, int], bool] = True,
epoch_logger: Callable[[Engine], Any] | None = None,
iteration_logger: Callable[[Engine], Any] | None = None,
dataset_logger: Callable[[Mapping[str, Dataset]], Any] | None = None,
dataset_dict: Mapping[str, Dataset] | None = None,
dataset_keys: str = CommonKeys.IMAGE,
output_transform: Callable = lambda x: x[0],
global_epoch_transform: Callable = lambda x: x,
state_attributes: Sequence[str] | None = None,
Expand All @@ -126,6 +141,7 @@ def __init__(
self.epoch_log = epoch_log
self.epoch_logger = epoch_logger
self.iteration_logger = iteration_logger
self.dataset_logger = dataset_logger
self.output_transform = output_transform
self.global_epoch_transform = global_epoch_transform
self.state_attributes = state_attributes
Expand All @@ -140,6 +156,8 @@ def __init__(
self.close_on_complete = close_on_complete
self.experiment = None
self.cur_run = None
self.dataset_dict = dataset_dict
self.dataset_keys = ensure_tuple(dataset_keys)

def _delete_exist_param_in_dict(self, param_dict: dict) -> None:
"""
Expand Down Expand Up @@ -210,6 +228,11 @@ def start(self, engine: Engine) -> None:
self._delete_exist_param_in_dict(attrs)
self._log_params(attrs)

if self.dataset_logger:
self.dataset_logger(self.dataset_dict)
else:
self._default_dataset_log(self.dataset_dict)

def _set_experiment(self):
experiment = self.experiment
if not experiment:
Expand All @@ -222,6 +245,36 @@ def _set_experiment(self):
raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
self.experiment = experiment

@staticmethod
wyli marked this conversation as resolved.
Show resolved Hide resolved
def _get_pandas_dataset_info(pandas_dataset):
dataset_name = pandas_dataset.name
return {
f"{dataset_name}_digest": pandas_dataset.digest,
f"{dataset_name}_samples": pandas_dataset.profile["num_rows"],
}

def _log_dataset(self, sample_dict: dict[str, Any], context: str = "train") -> None:
if not self.cur_run:
raise ValueError("Current Run is not Active to log the dataset")

# Need to update the self.cur_run to sync the dataset log, otherwise the `inputs` info will be out-of-date.
self.cur_run = self.client.get_run(self.cur_run.info.run_id)
logged_set = [x for x in self.cur_run.inputs.dataset_inputs if x.dataset.name.startswith(context)]
# In case there are datasets with the same name.
dataset_count = str(len(logged_set))
dataset_name = f"{context}_dataset_{dataset_count}"
sample_df = pandas.DataFrame(sample_dict)
dataset = mlflow.data.from_pandas(sample_df, name=dataset_name)
exist_dataset_list = list(
filter(lambda x: x.dataset.digest == dataset.digest, self.cur_run.inputs.dataset_inputs)
)

if not len(exist_dataset_list):
datasets = [mlflow.entities.DatasetInput(dataset._to_mlflow_entity())]
self.client.log_inputs(run_id=self.cur_run.info.run_id, datasets=datasets)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
dataset_info = MLFlowHandler._get_pandas_dataset_info(dataset)
self._log_params(dataset_info)

def _log_params(self, params: dict[str, Any]) -> None:
if not self.cur_run:
raise ValueError("Current Run is not Active to log params")
Expand Down Expand Up @@ -352,3 +405,61 @@ def _default_iteration_log(self, engine: Engine) -> None:
for i, param_group in enumerate(cur_optimizer.param_groups)
}
self._log_metrics(params, step=engine.state.iteration)

def _default_dataset_log(self, dataset_dict: Mapping[str, Dataset] | None) -> None:
"""
Execute dataset log operation based on the input dataset_dict. The dataset_dict should have a format
like:
{
"dataset_name0": dataset0,
"dataset_name1": dataset1,
......
}
The keys stand for names of datasets, which will be logged as prefixes of dataset names in MLFlow.
The values are PyTorch datasets from which sample names are abstracted to build a Pandas DataFrame.
If the input dataset_dict is None, this function will directly return and do nothing.

To use this function, every sample in the input datasets must contain keys specified by the `dataset_keys`
parameter.
This function will log a PandasDataset to MLFlow inputs, generated from the Pandas DataFrame.
For more details about PandasDataset, please refer to this link:
https://mlflow.org/docs/latest/python_api/mlflow.data.html#mlflow.data.pandas_dataset.PandasDataset

Please note that it may take a while to record the dataset if it has too many samples.

Args:
dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
dataset, that needs to be recorded.

"""

if dataset_dict is None:
return
elif len(dataset_dict) == 0:
warnings.warn("There is no dataset to log!")

# Log datasets to MLFlow one by one.
for dataset_type, dataset in dataset_dict.items():
if dataset is None:
raise AttributeError(f"The {dataset_type} dataset of is None. Cannot record it by MLFlow.")

sample_dict: dict[str, list[str]] = {}
dataset_samples = getattr(dataset, "data", [])
for sample in tqdm(dataset_samples, f"Recording the {dataset_type} dataset"):
for key in self.dataset_keys:
if key not in sample_dict:
sample_dict[key] = []

if key in sample:
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
value_to_log = sample[key]
else:
raise KeyError(f"Unexpect key '{key}' in the sample.")

if not isinstance(value_to_log, str):
warnings.warn(
f"Expected type string, got type {type(value_to_log)} of the {key} name."
"May log an empty dataset in MLFlow"
)
else:
sample_dict[key].append(value_to_log)
self._log_dataset(sample_dict, dataset_type)
56 changes: 55 additions & 1 deletion tests/test_handler_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@
from ignite.engine import Engine, Events
from parameterized import parameterized

from monai.apps import download_and_extract
from monai.bundle import ConfigWorkflow, download
from monai.handlers import MLFlowHandler
from monai.utils import path_to_uri
from monai.utils import optional_import, path_to_uri
from tests.utils import skip_if_downloading_fails, skip_if_quick

_, has_dataset_tracking = optional_import("mlflow", "2.4.0")


def get_event_filter(e):
Expand Down Expand Up @@ -230,6 +235,55 @@ def test_multi_thread(self):
self.tmpdir_list.append(res)
self.assertTrue(len(glob.glob(res)) > 0)

@skip_if_quick
@unittest.skipUnless(has_dataset_tracking, reason="Requires mlflow version >= 2.4.0.")
def test_dataset_tracking(self):
test_bundle_name = "endoscopic_tool_segmentation"
with tempfile.TemporaryDirectory() as tempdir:
binliunls marked this conversation as resolved.
Show resolved Hide resolved
resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/endoscopic_tool_dataset.zip"
md5 = "f82da47259c0a617202fb54624798a55"
compressed_file = os.path.join(tempdir, "endoscopic_tool_segmentation.zip")
data_dir = os.path.join(tempdir, "endoscopic_tool_dataset")
with skip_if_downloading_fails():
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, tempdir, md5)

download(test_bundle_name, bundle_dir=tempdir)

bundle_root = os.path.join(tempdir, test_bundle_name)
config_file = os.path.join(bundle_root, "configs/inference.json")
meta_file = os.path.join(bundle_root, "configs/metadata.json")
logging_file = os.path.join(bundle_root, "configs/logging.conf")
workflow = ConfigWorkflow(
workflow="infer",
config_file=config_file,
meta_file=meta_file,
logging_file=logging_file,
init_id="initialize",
run_id="run",
final_id="finalize",
)

tracking_path = os.path.join(bundle_root, "eval")
workflow.bundle_root = bundle_root
workflow.dataset_dir = data_dir
workflow.initialize()
infer_dataset = workflow.dataset
mlflow_handler = MLFlowHandler(
iteration_log=False,
epoch_log=False,
dataset_dict={"test": infer_dataset},
tracking_uri=path_to_uri(tracking_path),
)
mlflow_handler.attach(workflow.evaluator)
workflow.run()
workflow.finalize()

cur_run = mlflow_handler.client.get_run(mlflow_handler.cur_run.info.run_id)
logged_nontrain_set = [x for x in cur_run.inputs.dataset_inputs if x.dataset.name.startswith("test")]
self.assertEqual(len(logged_nontrain_set), 1)
mlflow_handler.close()


if __name__ == "__main__":
unittest.main()