From 6fc30d4e31830fd407057c0a3ca00e3366b76fb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yolan=20Honor=C3=A9-Roug=C3=A9?= Date: Fri, 22 Nov 2024 23:07:48 +0100 Subject: [PATCH] specify signature at runtime with parameters --- kedro_mlflow/framework/hooks/mlflow_hook.py | 13 ++++++- kedro_mlflow/mlflow/kedro_pipeline_model.py | 43 +++++++++------------ 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/kedro_mlflow/framework/hooks/mlflow_hook.py b/kedro_mlflow/framework/hooks/mlflow_hook.py index 7aed8ccb..a6d66639 100644 --- a/kedro_mlflow/framework/hooks/mlflow_hook.py +++ b/kedro_mlflow/framework/hooks/mlflow_hook.py @@ -389,7 +389,18 @@ def after_pipeline_run( if isinstance(model_signature, str): if model_signature == "auto": input_data = catalog.load(pipeline.input_name) - model_signature = infer_signature(model_input=input_data) + + # all pipeline params will be overridable at predict time: https://mlflow.org/docs/latest/model/signatures.html#model-signatures-with-inference-params + # I add the special "runner" parameter to be able to choose it at runtime + pipeline_params = { + ds_name[7:]: catalog.load(ds_name) + for ds_name in pipeline.inputs() + if ds_name.startswith("params:") + } | {"runner": "SequentialRunner"} + model_signature = infer_signature( + model_input=input_data, + params=pipeline_params, + ) mlflow.pyfunc.log_model( python_model=kedro_pipeline_model, diff --git a/kedro_mlflow/mlflow/kedro_pipeline_model.py b/kedro_mlflow/mlflow/kedro_pipeline_model.py index f7e2c760..0358bfc6 100644 --- a/kedro_mlflow/mlflow/kedro_pipeline_model.py +++ b/kedro_mlflow/mlflow/kedro_pipeline_model.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from kedro.framework.hooks import _create_hook_manager from kedro.io import DataCatalog, MemoryDataset @@ -8,10 +8,17 @@ from kedro.runner import AbstractRunner, SequentialRunner from kedro_datasets.pickle import 
PickleDataset from mlflow.pyfunc import PythonModel +from pydantic import BaseModel from kedro_mlflow.pipeline.pipeline_ml import PipelineML +class PredictParamsSchema(BaseModel): + parameters: Optional[dict[str, Any]] = {} + # runner: AbstractRunner + # hooks: Iterable[Any] # cf. _register_hooks + + class KedroPipelineModel(PythonModel): def __init__( self, @@ -207,25 +214,20 @@ def predict(self, context, model_input, params=None): # TODO hooks # TODO runner - hook_manager = _create_hook_manager() - # _register_hooks(hook_manager, params.hooks) + params = params or {} + runner_class = params.pop("runner", "SequentialRunner") + runner = ( + self.runner + ) # runner="build it dynamically from runner class" or self.runner - runner = self.runner # params.runner or self.runner + hook_manager = _create_hook_manager() + # _register_hooks(hook_manager, predict_params.hooks) - for name, value in params.parameters.items(): + for name, value in params.items(): + # no need to check if params are in the catalog, because mlflow already checks that the params match the signature param = f"params:{name}" - if param in self.loaded_catalog._datasets: - self._logger.info(f"Use {param}={value}") - self.loaded_catalog.save(name=param, data=value, replace=True) - else: - params_set = { - ds[7:] - for ds in self.loaded_catalog._datasets - if ds.startswith("params:") - } - self._logger.info( - f"{name} is not a valid parameter. Use one of '{','.join(params_set)}'. " - ) + self._logger.info(f"Using {param}={value} for the prediction") + self.loaded_catalog.save(name=param, data=value) self.loaded_catalog.save( name=self.input_name, @@ -247,10 +249,3 @@ def predict(self, context, model_input, params=None): class KedroPipelineModelError(Exception): """Error raised when the KedroPipelineModel construction fails""" - - -# from pydantic import BaseModel -# class PredictParamsSchema(BaseModel): -# parameters: dict[str, Any] -# runner: AbstractRunner -# hooks: Iterable[Any] # cf. _register_hooks