Skip to content

Commit

Permalink
specify signature at runtime with parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Nov 22, 2024
1 parent bda1e6e commit 6fc30d4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
13 changes: 12 additions & 1 deletion kedro_mlflow/framework/hooks/mlflow_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,18 @@ def after_pipeline_run(
if isinstance(model_signature, str):
if model_signature == "auto":
input_data = catalog.load(pipeline.input_name)
model_signature = infer_signature(model_input=input_data)

# all pipeline params will be overridable at predict time: https://mlflow.org/docs/latest/model/signatures.html#model-signatures-with-inference-params
# I add the special "runner" parameter to be able to choose it at runtime
pipeline_params = {
ds_name[7:]: catalog.load(ds_name)
for ds_name in pipeline.inputs()
if ds_name.startswith("params:")
} | {"runner": "SequentialRunner"}
model_signature = infer_signature(
model_input=input_data,
params=pipeline_params,
)

mlflow.pyfunc.log_model(
python_model=kedro_pipeline_model,
Expand Down
43 changes: 19 additions & 24 deletions kedro_mlflow/mlflow/kedro_pipeline_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import logging
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

from kedro.framework.hooks import _create_hook_manager
from kedro.io import DataCatalog, MemoryDataset
from kedro.pipeline import Pipeline
from kedro.runner import AbstractRunner, SequentialRunner
from kedro_datasets.pickle import PickleDataset
from mlflow.pyfunc import PythonModel
from pydantic import BaseModel

from kedro_mlflow.pipeline.pipeline_ml import PipelineML


class PredictParamsSchema(BaseModel):
parameters: Optional[dict[str, Any]] = {}
# runner: AbstractRunner
# hooks: Iterable[Any] # cf. _register_hooks


class KedroPipelineModel(PythonModel):
def __init__(
self,
Expand Down Expand Up @@ -207,25 +214,20 @@ def predict(self, context, model_input, params=None):
# TODO hooks
# TODO runner

hook_manager = _create_hook_manager()
# _register_hooks(hook_manager, params.hooks)
params = params or {}
runner_class = params.pop("runner", "SequentialRunner")
runner = (
self.runner
) # runner="build it dynamically from runner class" or self.runner

runner = self.runner # params.runner or self.runner
hook_manager = _create_hook_manager()
# _register_hooks(hook_manager, predict_params.hooks)

for name, value in params.parameters.items():
for name, value in params.items():
# no need to check if params are ni the catalog, because mlflow already checks that the params mathc the signature
param = f"params:{name}"
if param in self.loaded_catalog._datasets:
self._logger.info(f"Use {param}={value}")
self.loaded_catalog.save(name=param, data=value, replace=True)
else:
params_set = {
ds[7:]
for ds in self.loaded_catalog._datasets
if ds.startswith("params:")
}
self._logger.info(
f"{name} is not a valid parameter. Use one of '{','.join(params_set)}'. "
)
self._logger.info(f"Using {param}={value} for the prediction")
self.loaded_catalog.save(name=param, data=value)

self.loaded_catalog.save(
name=self.input_name,
Expand All @@ -247,10 +249,3 @@ def predict(self, context, model_input, params=None):

class KedroPipelineModelError(Exception):
"""Error raised when the KedroPipelineModel construction fails"""


# from pydantic import BaseModel
# class PredictParamsSchema(BaseModel):
# parameters: dict[str, Any]
# runner: AbstractRunner
# hooks: Iterable[Any] # cf. _register_hooks

0 comments on commit 6fc30d4

Please sign in to comment.