diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index 644cb4f867..e6f41fec34 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -71,9 +71,11 @@ import inspect import sys from collections.abc import Callable +from contextlib import contextmanager +from contextvars import ContextVar from copy import deepcopy from types import new_class -from typing import Any, Dict, Optional, Protocol, runtime_checkable +from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable from haystack import logging from haystack.core.errors import ComponentError @@ -84,6 +86,30 @@ logger = logging.getLogger(__name__) +# Callback inputs: component class and init parameters (as keyword arguments). +_COMPONENT_PRE_INIT_CALLBACK: ContextVar[Optional[Callable[[Type, Dict[str, Any]], None]]] = ContextVar( + "component_pre_init_callback", default=None +) + + +@contextmanager +def _hook_component_init(callback: Callable[[Type, Dict[str, Any]], None]): + """ + Context manager to set a callback that will be invoked + before a component's constructor is called. The callback + receives the component class and the init parameters (as keyword + arguments) and can modify the init parameters in place. + + :param callback: + Callback function to invoke. + """ + token = _COMPONENT_PRE_INIT_CALLBACK.set(callback) + try: + yield + finally: + _COMPONENT_PRE_INIT_CALLBACK.reset(token) + + @runtime_checkable class Component(Protocol): """ @@ -123,13 +149,39 @@ def run(self, **kwargs): class ComponentMeta(type): + @staticmethod + def positional_to_kwargs(cls_type, args) -> Dict[str, Any]: + init_signature = inspect.signature(cls_type.__init__) + init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"} + + out = {} + for arg, (name, info) in zip(args, init_params.items()): + if info.kind == inspect.Parameter.VAR_POSITIONAL: + raise ComponentError( + "Pre-init hooks do not support components with variadic positional args in their init method" + ) + + assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY) + out[name] = arg + return out + def __call__(cls, *args, **kwargs): """ This method is called when clients instantiate a Component and runs before __new__ and __init__. """ # This will call __new__ then __init__, giving us back the Component instance - instance = super().__call__(*args, **kwargs) + pre_init_hook = _COMPONENT_PRE_INIT_CALLBACK.get() + if pre_init_hook is None: + instance = super().__call__(*args, **kwargs) + else: + named_positional_args = ComponentMeta.positional_to_kwargs(cls, args) + assert ( + set(named_positional_args.keys()).intersection(kwargs.keys()) == set() + ), "positional and keyword arguments overlap" + kwargs.update(named_positional_args) + pre_init_hook(cls, kwargs) + instance = super().__call__(**kwargs) # Before returning, we have the chance to modify the newly created # Component instance, so we take the chance and set up the I/O sockets diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 498daa5c90..5939f9b7cb 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -22,7 +22,7 @@ PipelineUnmarshalError, PipelineValidationError, ) -from haystack.core.serialization import component_from_dict, component_to_dict +from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict from haystack.core.type_utils import _type_name, _types_are_compatible from haystack.marshal import Marshaller, YamlMarshaller from haystack.telemetry import pipeline_running @@ -130,12 +130,16 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls: Type[T], data: Dict[str, Any], **kwargs) -> T: + def from_dict( + cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs + ) -> T: """ Deserializes the pipeline from a dictionary. :param data: Dictionary to deserialize from. + :param callbacks: + Callbacks to invoke during deserialization. :param kwargs: `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones. :returns: @@ -171,7 +175,7 @@ def from_dict(cls: Type[T], data: Dict[str, Any], **kwargs) -> T: # Create a new one component_class = component.registry[component_data["type"]] - instance = component_from_dict(component_class, component_data) + instance = component_from_dict(component_class, component_data, name, callbacks) pipe.add_component(name=name, instance=instance) for connection in data.get("connections", []): @@ -208,7 +212,12 @@ def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER): fp.write(marshaller.marshal(self.to_dict())) @classmethod - def loads(cls, data: Union[str, bytes, bytearray], marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipeline": + def loads( + cls, + data: Union[str, bytes, bytearray], + marshaller: Marshaller = DEFAULT_MARSHALLER, + callbacks: Optional[DeserializationCallbacks] = None, + ) -> "Pipeline": """ Creates a `Pipeline` object from the string representation passed in the `data` argument. @@ -216,13 +225,20 @@ def loads(cls, data: Union[str, bytes, bytearray], marshaller: Marshaller = DEFA The string representation of the pipeline, can be `str`, `bytes` or `bytearray`. :param marshaller: The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. + :param callbacks: + Callbacks to invoke during deserialization. :returns: A `Pipeline` object. """ - return cls.from_dict(marshaller.unmarshal(data)) + return cls.from_dict(marshaller.unmarshal(data), callbacks) @classmethod - def load(cls, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipeline": + def load( + cls, + fp: TextIO, + marshaller: Marshaller = DEFAULT_MARSHALLER, + callbacks: Optional[DeserializationCallbacks] = None, + ) -> "Pipeline": """ Creates a `Pipeline` object from the string representation read from the file-like object passed in the `fp` argument. @@ -233,10 +249,12 @@ def load(cls, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipel A file-like object ready to be read from. :param marshaller: The Marshaller used to create the string representation. Defaults to `YamlMarshaller`. + :param callbacks: + Callbacks to invoke during deserialization. :returns: A `Pipeline` object. """ - return cls.from_dict(marshaller.unmarshal(fp.read())) + return cls.from_dict(marshaller.unmarshal(fp.read()), callbacks) def add_component(self, name: str, instance: Component) -> None: """ diff --git a/haystack/core/serialization.py b/haystack/core/serialization.py index 8ed7fe5dfe..08f2b50cef 100644 --- a/haystack/core/serialization.py +++ b/haystack/core/serialization.py @@ -2,11 +2,32 @@ # # SPDX-License-Identifier: Apache-2.0 import inspect -from typing import Any, Dict, Type +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Type +from haystack.core.component.component import _hook_component_init from haystack.core.errors import DeserializationError, SerializationError +@dataclass(frozen=True) +class DeserializationCallbacks: + """ + Callback functions that are invoked in specific + stages of the pipeline deserialization process. + + :param component_pre_init: + Invoked just before a component instance is + initialized. Receives the following inputs: + `component_name`, `component_class`, `init_params`. + + The callback is allowed to modify the `init_params` + dictionary, which contains all the parameters that + are passed to the component's constructor. + """ + + component_pre_init: Optional[Callable[[str, Type, Dict[str, Any]], None]] = None + + def component_to_dict(obj: Any) -> Dict[str, Any]: """ Converts a component instance into a dictionary. If a `to_dict` method is present in the @@ -59,7 +80,9 @@ def generate_qualified_class_name(cls: Type[object]) -> str: return f"{cls.__module__}.{cls.__name__}" -def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any: +def component_from_dict( + cls: Type[object], data: Dict[str, Any], name: str, callbacks: Optional[DeserializationCallbacks] = None +) -> Any: """ Creates a component instance from a dictionary. If a `from_dict` method is present in the component class, that will be used instead of the default method. @@ -68,13 +91,30 @@ def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any: The class to be used for deserialization. :param data: The serialized data. + :param name: + The name of the component. + :param callbacks: + Callbacks to invoke during deserialization. :returns: The deserialized component. """ - if hasattr(cls, "from_dict"): - return cls.from_dict(data) - return default_from_dict(cls, data) + def component_pre_init_callback(component_cls, init_params): + assert callbacks is not None + assert callbacks.component_pre_init is not None + callbacks.component_pre_init(name, component_cls, init_params) + + def do_from_dict(): + if hasattr(cls, "from_dict"): + return cls.from_dict(data) + + return default_from_dict(cls, data) + + if callbacks is None or callbacks.component_pre_init is None: + return do_from_dict() + + with _hook_component_init(component_pre_init_callback): + return do_from_dict() def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]: diff --git a/releasenotes/notes/pipeline-deserialization-callbacks-0642248725918684.yaml b/releasenotes/notes/pipeline-deserialization-callbacks-0642248725918684.yaml new file mode 100644 index 0000000000..68d1368750 --- /dev/null +++ b/releasenotes/notes/pipeline-deserialization-callbacks-0642248725918684.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Add support for callbacks during pipeline deserialization. Currently supports a pre-init hook for components that can be used to inspect and modify the initialization parameters + before the invocation of the component's `__init__` method. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index 13843d6de2..a4305c09df 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -1,9 +1,11 @@ import logging +from functools import partial from typing import Any import pytest from haystack.core.component import Component, InputSocket, OutputSocket, component +from haystack.core.component.component import _hook_component_init from haystack.core.component.types import Variadic from haystack.core.errors import ComponentError from haystack.core.pipeline import Pipeline @@ -271,3 +273,120 @@ def run(self, value: int): "Component 'MockComponent' has no variadic input, but it's marked as greedy." " This is not supported and can lead to unexpected behavior.\n" in caplog.text ) + + +def test_pre_init_hooking(): + @component + class MockComponent: + def __init__(self, pos_arg1, pos_arg2, pos_arg3=None, *, kwarg1=1, kwarg2="string"): + self.pos_arg1 = pos_arg1 + self.pos_arg2 = pos_arg2 + self.pos_arg3 = pos_arg3 + self.kwarg1 = kwarg1 + self.kwarg2 = kwarg2 + + @component.output_types(output_value=int) + def run(self, input_value: int): + return {"output_value": input_value} + + def pre_init_hook(component_class, init_params, expected_params): + assert component_class == MockComponent + assert init_params == expected_params + + def pre_init_hook_modify(component_class, init_params, expected_params): + assert component_class == MockComponent + assert init_params == expected_params + + init_params["pos_arg1"] = 2 + init_params["pos_arg2"] = 0 + init_params["pos_arg3"] = "modified" + init_params["kwarg2"] = "modified string" + + with _hook_component_init(partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "kwarg1": None})): + _ = MockComponent(1, 2, kwarg1=None) + + with _hook_component_init(partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "pos_arg3": 0.01})): + _ = MockComponent(pos_arg1=1, pos_arg2=2, pos_arg3=0.01) + + with _hook_component_init( + partial(pre_init_hook_modify, expected_params={"pos_arg1": 0, "pos_arg2": 1, "pos_arg3": 0.01, "kwarg1": 0}) + ): + c = MockComponent(0, 1, pos_arg3=0.01, kwarg1=0) + + assert c.pos_arg1 == 2 + assert c.pos_arg2 == 0 + assert c.pos_arg3 == "modified" + assert c.kwarg1 == 0 + assert c.kwarg2 == "modified string" + + +def test_pre_init_hooking_variadic_positional_args(): + @component + class MockComponent: + def __init__(self, *args, kwarg1=1, kwarg2="string"): + self.args = args + self.kwarg1 = kwarg1 + self.kwarg2 = kwarg2 + + @component.output_types(output_value=int) + def run(self, input_value: int): + return {"output_value": input_value} + + def pre_init_hook(component_class, init_params, expected_params): + assert component_class == MockComponent + assert init_params == expected_params + + c = MockComponent(1, 2, 3, kwarg1=None) + assert c.args == (1, 2, 3) + assert c.kwarg1 is None + assert c.kwarg2 == "string" + + with pytest.raises(ComponentError), _hook_component_init( + partial(pre_init_hook, expected_params={"args": (1, 2), "kwarg1": None}) + ): + _ = MockComponent(1, 2, kwarg1=None) + + +def test_pre_init_hooking_variadic_kwargs(): + @component + class MockComponent: + def __init__(self, pos_arg1, pos_arg2=None, **kwargs): + self.pos_arg1 = pos_arg1 + self.pos_arg2 = pos_arg2 + self.kwargs = kwargs + + @component.output_types(output_value=int) + def run(self, input_value: int): + return {"output_value": input_value} + + def pre_init_hook(component_class, init_params, expected_params): + assert component_class == MockComponent + assert init_params == expected_params + + with _hook_component_init( + partial(pre_init_hook, expected_params={"pos_arg1": 1, "kwarg1": None, "kwarg2": 10, "kwarg3": "string"}) + ): + c = MockComponent(1, kwarg1=None, kwarg2=10, kwarg3="string") + assert c.pos_arg1 == 1 + assert c.pos_arg2 is None + assert c.kwargs == {"kwarg1": None, "kwarg2": 10, "kwarg3": "string"} + + def pre_init_hook_modify(component_class, init_params, expected_params): + assert component_class == MockComponent + assert init_params == expected_params + + init_params["pos_arg1"] = 2 + init_params["pos_arg2"] = 0 + init_params["some_kwarg"] = "modified string" + + with _hook_component_init( + partial( + pre_init_hook_modify, + expected_params={"pos_arg1": 0, "pos_arg2": 1, "kwarg1": 999, "some_kwarg": "some_value"}, + ) + ): + c = MockComponent(0, 1, kwarg1=999, some_kwarg="some_value") + + assert c.pos_arg1 == 2 + assert c.pos_arg2 == 0 + assert c.kwargs == {"kwarg1": 999, "some_kwarg": "modified string"} diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 18bfe63b04..95dae81518 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -16,9 +16,10 @@ from haystack.core.component.types import InputSocket, OutputSocket from haystack.core.errors import PipelineDrawingError, PipelineError, PipelineMaxLoops, PipelineRuntimeError from haystack.core.pipeline import Pipeline, PredefinedPipeline +from haystack.core.serialization import DeserializationCallbacks from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.testing.factory import component_class -from haystack.testing.sample_components import AddFixedValue, Double +from haystack.testing.sample_components import AddFixedValue, Double, Greet logging.basicConfig(level=logging.DEBUG) @@ -446,6 +447,77 @@ def test_from_dict(): ) +def test_from_dict_with_callbacks(): + data = { + "metadata": {"test": "test"}, + "max_loops_allowed": 101, + "components": { + "add_two": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 2}, + }, + "add_default": { + "type": "haystack.testing.sample_components.add_value.AddFixedValue", + "init_parameters": {"add": 1}, + }, + "double": {"type": "haystack.testing.sample_components.double.Double", "init_parameters": {}}, + "greet": {"type": "haystack.testing.sample_components.greet.Greet", "init_parameters": {"message": "test"}}, + }, + "connections": [ + {"sender": "add_two.result", "receiver": "double.value"}, + {"sender": "double.value", "receiver": "add_default.value"}, + ], + } + + components_seen_in_callback = [] + + def component_pre_init_callback(name, component_cls, init_params): + assert name in ["add_two", "add_default", "double", "greet"] + assert component_cls in [AddFixedValue, Double, Greet] + + if name == "add_two": + assert init_params == {"add": 2} + elif name == "add_default": + assert init_params == {"add": 1} + elif name == "greet": + assert init_params == {"message": "test"} + + components_seen_in_callback.append(name) + + pipe = Pipeline.from_dict(data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback)) + assert components_seen_in_callback == ["add_two", "add_default", "double", "greet"] + add_two = pipe.graph.nodes["add_two"]["instance"] + assert add_two.add == 2 + add_default = pipe.graph.nodes["add_default"]["instance"] + assert add_default.add == 1 + greet = pipe.graph.nodes["greet"]["instance"] + assert greet.message == "test" + assert greet.log_level == "INFO" + + def component_pre_init_callback_modify(name, component_cls, init_params): + assert name in ["add_two", "add_default", "double", "greet"] + assert component_cls in [AddFixedValue, Double, Greet] + + if name == "add_two": + init_params["add"] = 3 + elif name == "add_default": + init_params["add"] = 0 + elif name == "greet": + init_params["message"] = "modified test" + init_params["log_level"] = "DEBUG" + + pipe = Pipeline.from_dict( + data, callbacks=DeserializationCallbacks(component_pre_init=component_pre_init_callback_modify) + ) + add_two = pipe.graph.nodes["add_two"]["instance"] + assert add_two.add == 3 + add_default = pipe.graph.nodes["add_default"]["instance"] + assert add_default.add == 0 + greet = pipe.graph.nodes["greet"]["instance"] + assert greet.message == "modified test" + assert greet.log_level == "DEBUG" + + def test_from_dict_with_empty_dict(): assert Pipeline() == Pipeline.from_dict({})