diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py
index 20add9b11f..a2bd345dc6 100644
--- a/monai/handlers/mlflow_handler.py
+++ b/monai/handlers/mlflow_handler.py
@@ -13,20 +13,24 @@
 
 import os
 import time
-from collections.abc import Callable, Sequence
+import warnings
+from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 import torch
+from torch.utils.data import Dataset
 
 from monai.config import IgniteInfo
-from monai.utils import ensure_tuple, min_version, optional_import
+from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
 mlflow.entities, _ = optional_import(
     "mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
 )
+pandas, _ = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
+tqdm, _ = optional_import("tqdm", "4.47.0", min_version, "tqdm")
 
 if TYPE_CHECKING:
     from ignite.engine import Engine
@@ -72,6 +76,14 @@ class MLFlowHandler:
             Must accept parameter "engine", use default logger if None.
         iteration_logger: customized callable logger for iteration level logging with MLFlow.
             Must accept parameter "engine", use default logger if None.
+        dataset_logger: customized callable logger to log the dataset information with MLFlow.
+            Must accept parameter "dataset_dict", use default logger if None.
+        dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
+            dataset that needs to be recorded. This argument is only useful when MLFlow version >= 2.4.0.
+            For more details about how to log data with MLFlow, please go to the website:
+            https://mlflow.org/docs/latest/python_api/mlflow.data.html.
+        dataset_keys: a key or a collection of keys to indicate contents in the dataset that
+            need to be stored by MLFlow.
         output_transform: a callable that is used to transform the
             ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}.
            By default this value logging happens when every iteration completed.
@@ -111,6 +123,9 @@ def __init__(
         epoch_log: bool | Callable[[Engine, int], bool] = True,
         epoch_logger: Callable[[Engine], Any] | None = None,
         iteration_logger: Callable[[Engine], Any] | None = None,
+        dataset_logger: Callable[[Mapping[str, Dataset]], Any] | None = None,
+        dataset_dict: Mapping[str, Dataset] | None = None,
+        dataset_keys: str = CommonKeys.IMAGE,
         output_transform: Callable = lambda x: x[0],
         global_epoch_transform: Callable = lambda x: x,
         state_attributes: Sequence[str] | None = None,
@@ -126,6 +141,7 @@ def __init__(
         self.epoch_log = epoch_log
         self.epoch_logger = epoch_logger
         self.iteration_logger = iteration_logger
+        self.dataset_logger = dataset_logger
         self.output_transform = output_transform
         self.global_epoch_transform = global_epoch_transform
         self.state_attributes = state_attributes
@@ -140,6 +156,8 @@ def __init__(
         self.close_on_complete = close_on_complete
         self.experiment = None
         self.cur_run = None
+        self.dataset_dict = dataset_dict
+        self.dataset_keys = ensure_tuple(dataset_keys)
 
     def _delete_exist_param_in_dict(self, param_dict: dict) -> None:
         """
@@ -210,6 +228,11 @@ def start(self, engine: Engine) -> None:
         self._delete_exist_param_in_dict(attrs)
         self._log_params(attrs)
 
+        if self.dataset_logger:
+            self.dataset_logger(self.dataset_dict)
+        else:
+            self._default_dataset_log(self.dataset_dict)
+
     def _set_experiment(self):
         experiment = self.experiment
         if not experiment:
@@ -222,6 +245,36 @@ def _set_experiment(self):
             raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
         self.experiment = experiment
 
+    @staticmethod
+    def _get_pandas_dataset_info(pandas_dataset):
+        dataset_name = pandas_dataset.name
+        return {
+            f"{dataset_name}_digest": pandas_dataset.digest,
+            f"{dataset_name}_samples": pandas_dataset.profile["num_rows"],
+        }
+
+    def _log_dataset(self, sample_dict: dict[str, Any], context: str = "train") -> None:
+        if not self.cur_run:
+            raise ValueError("Current Run is not Active to log the dataset")
+
+        # Need to update self.cur_run to sync the dataset log, otherwise the `inputs` info will be out-of-date.
+        self.cur_run = self.client.get_run(self.cur_run.info.run_id)
+        logged_set = [x for x in self.cur_run.inputs.dataset_inputs if x.dataset.name.startswith(context)]
+        # In case there are datasets with the same name.
+        dataset_count = str(len(logged_set))
+        dataset_name = f"{context}_dataset_{dataset_count}"
+        sample_df = pandas.DataFrame(sample_dict)
+        dataset = mlflow.data.from_pandas(sample_df, name=dataset_name)
+        exist_dataset_list = list(
+            filter(lambda x: x.dataset.digest == dataset.digest, self.cur_run.inputs.dataset_inputs)
+        )
+
+        if not len(exist_dataset_list):
+            datasets = [mlflow.entities.DatasetInput(dataset._to_mlflow_entity())]
+            self.client.log_inputs(run_id=self.cur_run.info.run_id, datasets=datasets)
+            dataset_info = MLFlowHandler._get_pandas_dataset_info(dataset)
+            self._log_params(dataset_info)
+
     def _log_params(self, params: dict[str, Any]) -> None:
         if not self.cur_run:
             raise ValueError("Current Run is not Active to log params")
@@ -352,3 +405,61 @@ def _default_iteration_log(self, engine: Engine) -> None:
                 for i, param_group in enumerate(cur_optimizer.param_groups)
             }
             self._log_metrics(params, step=engine.state.iteration)
+
+    def _default_dataset_log(self, dataset_dict: Mapping[str, Dataset] | None) -> None:
+        """
+        Execute dataset logging based on the input dataset_dict. The dataset_dict should have a format
+        like:
+            {
+                "dataset_name0": dataset0,
+                "dataset_name1": dataset1,
+                ......
+            }
+        The keys stand for names of datasets, which will be logged as prefixes of dataset names in MLFlow.
+        The values are PyTorch datasets, from which sample names are extracted to build a Pandas DataFrame.
+        If the input dataset_dict is None, this function will return directly and do nothing.
+
+        To use this function, every sample in the input datasets must contain keys specified by the
+        `dataset_keys` parameter.
+        This function will log a PandasDataset to MLFlow inputs, generated from the Pandas DataFrame.
+        For more details about PandasDataset, please refer to this link:
+        https://mlflow.org/docs/latest/python_api/mlflow.data.html#mlflow.data.pandas_dataset.PandasDataset
+
+        Please note that it may take a while to record the dataset if it has many samples.
+
+        Args:
+            dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch
+                dataset that needs to be recorded.
+
+        """
+
+        if dataset_dict is None:
+            return
+        elif len(dataset_dict) == 0:
+            warnings.warn("There is no dataset to log!")
+
+        # Log datasets to MLFlow one by one.
+        for dataset_type, dataset in dataset_dict.items():
+            if dataset is None:
+                raise AttributeError(f"The {dataset_type} dataset is None. Cannot record it with MLFlow.")
+
+            sample_dict: dict[str, list[str]] = {}
+            dataset_samples = getattr(dataset, "data", [])
+            for sample in tqdm(dataset_samples, f"Recording the {dataset_type} dataset"):
+                for key in self.dataset_keys:
+                    if key not in sample_dict:
+                        sample_dict[key] = []
+
+                    if key in sample:
+                        value_to_log = sample[key]
+                    else:
+                        raise KeyError(f"Unexpected key '{key}' in the sample.")
+
+                    if not isinstance(value_to_log, str):
+                        warnings.warn(
+                            f"Expected type string, got type {type(value_to_log)} of the {key} name. "
+                            "This may log an empty dataset in MLFlow."
+                        )
+                    else:
+                        sample_dict[key].append(value_to_log)
+            self._log_dataset(sample_dict, dataset_type)
diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py
index f09f9b93d5..d5578c01bc 100644
--- a/tests/test_handler_mlflow.py
+++ b/tests/test_handler_mlflow.py
@@ -23,8 +23,13 @@
 from ignite.engine import Engine, Events
 from parameterized import parameterized
 
+from monai.apps import download_and_extract
+from monai.bundle import ConfigWorkflow, download
 from monai.handlers import MLFlowHandler
-from monai.utils import path_to_uri
+from monai.utils import optional_import, path_to_uri
+from tests.utils import skip_if_downloading_fails, skip_if_quick
+
+_, has_dataset_tracking = optional_import("mlflow", "2.4.0")
 
 
 def get_event_filter(e):
@@ -230,6 +235,55 @@ def test_multi_thread(self):
             self.tmpdir_list.append(res)
             self.assertTrue(len(glob.glob(res)) > 0)
 
+    @skip_if_quick
+    @unittest.skipUnless(has_dataset_tracking, reason="Requires mlflow version >= 2.4.0.")
+    def test_dataset_tracking(self):
+        test_bundle_name = "endoscopic_tool_segmentation"
+        with tempfile.TemporaryDirectory() as tempdir:
+            resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/endoscopic_tool_dataset.zip"
+            md5 = "f82da47259c0a617202fb54624798a55"
+            compressed_file = os.path.join(tempdir, "endoscopic_tool_segmentation.zip")
+            data_dir = os.path.join(tempdir, "endoscopic_tool_dataset")
+            with skip_if_downloading_fails():
+                if not os.path.exists(data_dir):
+                    download_and_extract(resource, compressed_file, tempdir, md5)
+
+                download(test_bundle_name, bundle_dir=tempdir)
+
+            bundle_root = os.path.join(tempdir, test_bundle_name)
+            config_file = os.path.join(bundle_root, "configs/inference.json")
+            meta_file = os.path.join(bundle_root, "configs/metadata.json")
+            logging_file = os.path.join(bundle_root, "configs/logging.conf")
+            workflow = ConfigWorkflow(
+                workflow="infer",
+                config_file=config_file,
+                meta_file=meta_file,
+                logging_file=logging_file,
+                init_id="initialize",
+                run_id="run",
+                final_id="finalize",
+            )
+
+            tracking_path = os.path.join(bundle_root, "eval")
+            workflow.bundle_root = bundle_root
+            workflow.dataset_dir = data_dir
+            workflow.initialize()
+            infer_dataset = workflow.dataset
+            mlflow_handler = MLFlowHandler(
+                iteration_log=False,
+                epoch_log=False,
+                dataset_dict={"test": infer_dataset},
+                tracking_uri=path_to_uri(tracking_path),
+            )
+            mlflow_handler.attach(workflow.evaluator)
+            workflow.run()
+            workflow.finalize()
+
+            cur_run = mlflow_handler.client.get_run(mlflow_handler.cur_run.info.run_id)
+            logged_nontrain_set = [x for x in cur_run.inputs.dataset_inputs if x.dataset.name.startswith("test")]
+            self.assertEqual(len(logged_nontrain_set), 1)
+            mlflow_handler.close()
+
 
 if __name__ == "__main__":
     unittest.main()
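Usage sketch for the new arguments introduced by this patch: it assumes mlflow >= 2.4.0 and a dataset whose `data` attribute holds per-sample dicts (the structure `_default_dataset_log` iterates over). The `ToyDataset` class, the file names, and the local tracking URI are hypothetical and only serve to illustrate `dataset_dict` / `dataset_keys`.

    from torch.utils.data import Dataset

    from monai.handlers import MLFlowHandler
    from monai.utils import CommonKeys


    class ToyDataset(Dataset):
        """Hypothetical dataset exposing a ``data`` list of dicts, as read by the default dataset logger."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]


    # every sample maps CommonKeys.IMAGE ("image") to a string, so it becomes one column of the logged PandasDataset
    train_ds = ToyDataset([{CommonKeys.IMAGE: f"image_{i}.nii.gz"} for i in range(4)])

    handler = MLFlowHandler(
        tracking_uri="file:./mlruns",  # assumption: a local file-based tracking store
        iteration_log=False,
        epoch_log=False,
        dataset_dict={"train": train_ds},  # logged with a name like "train_dataset_0"
        dataset_keys=CommonKeys.IMAGE,
    )
    # handler.attach(trainer)  # attaching to an ignite Engine logs the dataset when the run starts (Events.STARTED)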