Skip to content

Commit

Permalink
feat: Add support for deserialization callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
shadeMe committed Apr 9, 2024
1 parent 988c360 commit a592bf2
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 15 deletions.
56 changes: 54 additions & 2 deletions haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@
import inspect
import sys
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable

from haystack import logging
from haystack.core.errors import ComponentError
Expand All @@ -84,6 +86,30 @@
logger = logging.getLogger(__name__)


# Context-local pre-init hook for components. It defaults to `None` (no hook).
# When a callback is set, `ComponentMeta.__call__` invokes it before constructing
# the component instance.
# Callback inputs: component class and init parameters (as keyword arguments).
_COMPONENT_PRE_INIT_CALLBACK: ContextVar[Optional[Callable[[Type, Dict[str, Any]], None]]] = ContextVar(
"component_pre_init_callback", default=None
)


@contextmanager
def _hook_component_init(callback: Callable[[Type, Dict[str, Any]], None]):
    """
    Installs a pre-init callback for the duration of a `with` block.

    While the block is active, the callback is invoked before a component's
    constructor is called. It receives the component class and the init
    parameters (as keyword arguments) and may modify those parameters in place.
    The value that was in effect before entering the block is restored on exit,
    even if the block raises.

    :param callback:
        Callback function to invoke.
    """
    previous_state = _COMPONENT_PRE_INIT_CALLBACK.set(callback)
    try:
        yield
    finally:
        # Undo the `set` above so the hook does not leak outside the block.
        _COMPONENT_PRE_INIT_CALLBACK.reset(previous_state)


@runtime_checkable
class Component(Protocol):
"""
Expand Down Expand Up @@ -123,13 +149,39 @@ def run(self, **kwargs):


class ComponentMeta(type):
@staticmethod
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
    """
    Maps positional constructor arguments to the parameter names of `cls_type.__init__`.

    :param cls_type:
        Class whose `__init__` signature is used to name the arguments.
    :param args:
        Positional arguments as passed by the caller (without `self`).
    :returns:
        Dictionary mapping parameter names to the corresponding positional values.
    :raises ComponentError:
        If a positional argument lands on a variadic positional parameter (`*args`).
    :raises TypeError:
        If a positional argument cannot be bound to a positional parameter, i.e. there
        are more arguments than parameters or the argument lands on a keyword-only
        or `**kwargs` parameter.
    """
    init_signature = inspect.signature(cls_type.__init__)
    init_params = [(name, info) for name, info in init_signature.parameters.items() if name != "self"]
    positional_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)

    out: Dict[str, Any] = {}
    for index, arg in enumerate(args):
        name, info = init_params[index] if index < len(init_params) else (None, None)

        if info is not None and info.kind == inspect.Parameter.VAR_POSITIONAL:
            raise ComponentError(
                "Pre-init hooks do not support components with variadic positional args in their init method"
            )

        # Surplus arguments must fail loudly (as a regular call would) instead of being
        # silently dropped or rebound to a keyword-only / `**kwargs` parameter name.
        # An explicit raise is used because `assert` is stripped under `python -O`.
        if info is None or info.kind not in positional_kinds:
            max_positional = sum(1 for _, param in init_params if param.kind in positional_kinds)
            raise TypeError(
                f"{cls_type.__name__}.__init__() accepts {max_positional} positional "
                f"argument(s) but {len(args)} were given"
            )

        out[name] = arg
    return out

def __call__(cls, *args, **kwargs):
"""
This method is called when clients instantiate a Component and
runs before __new__ and __init__.
"""
# This will call __new__ then __init__, giving us back the Component instance
instance = super().__call__(*args, **kwargs)
pre_init_hook = _COMPONENT_PRE_INIT_CALLBACK.get()
if pre_init_hook is None:
instance = super().__call__(*args, **kwargs)
else:
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
kwargs.update(named_positional_args)
pre_init_hook(cls, kwargs)
instance = super().__call__(**kwargs)

# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets
Expand Down
32 changes: 25 additions & 7 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PipelineUnmarshalError,
PipelineValidationError,
)
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
from haystack.telemetry import pipeline_running
Expand Down Expand Up @@ -130,12 +130,16 @@ def to_dict(self) -> Dict[str, Any]:
}

@classmethod
def from_dict(cls: Type[T], data: Dict[str, Any], **kwargs) -> T:
def from_dict(
cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
) -> T:
"""
Deserializes the pipeline from a dictionary.
:param data:
Dictionary to deserialize from.
:param callbacks:
Callbacks to invoke during deserialization.
:param kwargs:
`components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones.
:returns:
Expand Down Expand Up @@ -171,7 +175,7 @@ def from_dict(cls: Type[T], data: Dict[str, Any], **kwargs) -> T:

# Create a new one
component_class = component.registry[component_data["type"]]
instance = component_from_dict(component_class, component_data)
instance = component_from_dict(component_class, component_data, name, callbacks)
pipe.add_component(name=name, instance=instance)

for connection in data.get("connections", []):
Expand Down Expand Up @@ -208,21 +212,33 @@ def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER):
fp.write(marshaller.marshal(self.to_dict()))

@classmethod
def loads(cls, data: Union[str, bytes, bytearray], marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipeline":
def loads(
cls,
data: Union[str, bytes, bytearray],
marshaller: Marshaller = DEFAULT_MARSHALLER,
callbacks: Optional[DeserializationCallbacks] = None,
) -> "Pipeline":
"""
Creates a `Pipeline` object from the string representation passed in the `data` argument.
:param data:
The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
:param marshaller:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:param callbacks:
Callbacks to invoke during deserialization.
:returns:
A `Pipeline` object.
"""
return cls.from_dict(marshaller.unmarshal(data))
return cls.from_dict(marshaller.unmarshal(data), callbacks)

@classmethod
def load(cls, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipeline":
def load(
cls,
fp: TextIO,
marshaller: Marshaller = DEFAULT_MARSHALLER,
callbacks: Optional[DeserializationCallbacks] = None,
) -> "Pipeline":
"""
Creates a `Pipeline` object from the string representation read from the file-like
object passed in the `fp` argument.
Expand All @@ -233,10 +249,12 @@ def load(cls, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipel
A file-like object ready to be read from.
:param marshaller:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:param callbacks:
Callbacks to invoke during deserialization.
:returns:
A `Pipeline` object.
"""
return cls.from_dict(marshaller.unmarshal(fp.read()))
return cls.from_dict(marshaller.unmarshal(fp.read()), callbacks)

def add_component(self, name: str, instance: Component) -> None:
"""
Expand Down
50 changes: 45 additions & 5 deletions haystack/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,32 @@
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Any, Dict, Type
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Type

from haystack.core.component.component import _hook_component_init
from haystack.core.errors import DeserializationError, SerializationError


@dataclass(frozen=True)
class DeserializationCallbacks:
    """
    Immutable bundle of callbacks that are invoked at specific
    stages of the pipeline deserialization process.

    :param component_pre_init:
        Optional callback invoked just before a component instance is
        initialized. It is called with three inputs, in this order:
        `component_name`, `component_class`, `init_params`.
        `init_params` is a dictionary containing all the parameters
        that are passed to the component's constructor; the callback
        is allowed to modify it.
    """

    component_pre_init: Optional[Callable[[str, Type, Dict[str, Any]], None]] = None


def component_to_dict(obj: Any) -> Dict[str, Any]:
"""
Converts a component instance into a dictionary. If a `to_dict` method is present in the
Expand Down Expand Up @@ -59,7 +80,9 @@ def generate_qualified_class_name(cls: Type[object]) -> str:
return f"{cls.__module__}.{cls.__name__}"


def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
def component_from_dict(
cls: Type[object], data: Dict[str, Any], name: str, callbacks: Optional[DeserializationCallbacks] = None
) -> Any:
"""
Creates a component instance from a dictionary. If a `from_dict` method is present in the
component class, that will be used instead of the default method.
Expand All @@ -68,13 +91,30 @@ def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
The class to be used for deserialization.
:param data:
The serialized data.
:param name:
The name of the component.
:param callbacks:
Callbacks to invoke during deserialization.
:returns:
The deserialized component.
"""
if hasattr(cls, "from_dict"):
return cls.from_dict(data)

return default_from_dict(cls, data)
def component_pre_init_callback(component_cls, init_params):
assert callbacks is not None
assert callbacks.component_pre_init is not None
callbacks.component_pre_init(name, component_cls, init_params)

def do_from_dict():
if hasattr(cls, "from_dict"):
return cls.from_dict(data)

return default_from_dict(cls, data)

if callbacks is None or callbacks.component_pre_init is None:
return do_from_dict()

with _hook_component_init(component_pre_init_callback):
return do_from_dict()


def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add support for callbacks during pipeline deserialization. Currently supports a pre-init hook for components that can be used to inspect and modify the initialization parameters
before the invocation of the component's `__init__` method.
119 changes: 119 additions & 0 deletions test/core/component/test_component.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from functools import partial
from typing import Any

import pytest

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.component.component import _hook_component_init
from haystack.core.component.types import Variadic
from haystack.core.errors import ComponentError
from haystack.core.pipeline import Pipeline
Expand Down Expand Up @@ -271,3 +273,120 @@ def run(self, value: int):
"Component 'MockComponent' has no variadic input, but it's marked as greedy."
" This is not supported and can lead to unexpected behavior.\n" in caplog.text
)


def test_pre_init_hooking():
    """
    The pre-init hook sees positional and keyword init arguments as one dict
    and any change it makes to that dict reaches the component's `__init__`.
    """

    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2, pos_arg3=None, *, kwarg1=1, kwarg2="string"):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.pos_arg3 = pos_arg3
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    def pre_init_hook(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    def pre_init_hook_modify(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

        # Overwrite some parameters and inject one that the caller did not pass.
        init_params.update(pos_arg1=2, pos_arg2=0, pos_arg3="modified", kwarg2="modified string")

    # Mixed positional and keyword arguments.
    mixed_hook = partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "kwarg1": None})
    with _hook_component_init(mixed_hook):
        _ = MockComponent(1, 2, kwarg1=None)

    # Everything passed by keyword.
    keyword_hook = partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "pos_arg3": 0.01})
    with _hook_component_init(keyword_hook):
        _ = MockComponent(pos_arg1=1, pos_arg2=2, pos_arg3=0.01)

    # A hook that rewrites the parameters before construction.
    modifying_hook = partial(
        pre_init_hook_modify, expected_params={"pos_arg1": 0, "pos_arg2": 1, "pos_arg3": 0.01, "kwarg1": 0}
    )
    with _hook_component_init(modifying_hook):
        c = MockComponent(0, 1, pos_arg3=0.01, kwarg1=0)

    assert c.pos_arg1 == 2
    assert c.pos_arg2 == 0
    assert c.pos_arg3 == "modified"
    assert c.kwarg1 == 0
    assert c.kwarg2 == "modified string"


def test_pre_init_hooking_variadic_positional_args():
    """
    Components with `*args` in `__init__` work without a hook, but
    instantiating them while a pre-init hook is active raises `ComponentError`.
    """

    @component
    class MockComponent:
        def __init__(self, *args, kwarg1=1, kwarg2="string"):
            self.args = args
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    def pre_init_hook(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    # Without a hook, variadic positional arguments are passed through untouched.
    unhooked = MockComponent(1, 2, 3, kwarg1=None)
    assert unhooked.args == (1, 2, 3)
    assert unhooked.kwarg1 is None
    assert unhooked.kwarg2 == "string"

    with pytest.raises(ComponentError):
        hook = partial(pre_init_hook, expected_params={"args": (1, 2), "kwarg1": None})
        with _hook_component_init(hook):
            _ = MockComponent(1, 2, kwarg1=None)


def test_pre_init_hooking_variadic_kwargs():
    """
    Components with `**kwargs` in `__init__` are supported by the pre-init hook:
    extra keyword arguments show up in the init parameters and can be modified.
    """

    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2=None, **kwargs):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.kwargs = kwargs

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    def pre_init_hook(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    def pre_init_hook_modify(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

        init_params.update(pos_arg1=2, pos_arg2=0, some_kwarg="modified string")

    # Read-only hook: the parameters reach the constructor unchanged.
    inspecting_hook = partial(
        pre_init_hook, expected_params={"pos_arg1": 1, "kwarg1": None, "kwarg2": 10, "kwarg3": "string"}
    )
    with _hook_component_init(inspecting_hook):
        c = MockComponent(1, kwarg1=None, kwarg2=10, kwarg3="string")
        assert c.pos_arg1 == 1
        assert c.pos_arg2 is None
        assert c.kwargs == {"kwarg1": None, "kwarg2": 10, "kwarg3": "string"}

    # Modifying hook: named parameters and one of the extra kwargs are overwritten.
    modifying_hook = partial(
        pre_init_hook_modify,
        expected_params={"pos_arg1": 0, "pos_arg2": 1, "kwarg1": 999, "some_kwarg": "some_value"},
    )
    with _hook_component_init(modifying_hook):
        c = MockComponent(0, 1, kwarg1=999, some_kwarg="some_value")

    assert c.pos_arg1 == 2
    assert c.pos_arg2 == 0
    assert c.kwargs == {"kwarg1": 999, "some_kwarg": "modified string"}
Loading

0 comments on commit a592bf2

Please sign in to comment.