Skip to content

Commit

Permalink
Add custom serde
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Sep 12, 2024
1 parent 3bd3e3b commit 6a59250
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 1 deletion.
32 changes: 31 additions & 1 deletion haystack/components/preprocessors/document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Callable, Dict, List, Literal, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from more_itertools import windowed

from haystack import Document, component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.utils import deserialize_callable, serialize_callable


@component
Expand Down Expand Up @@ -243,3 +245,31 @@ def _add_split_overlap_information(
# add split overlap information to previous Document regarding this Document
overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})

def to_dict(self) -> Dict[str, Any]:
    """
    Serializes the component to a dictionary.

    The four split settings are always included. A custom splitting function, when
    one is set, is stored in serialized form under `init_parameters`.

    :returns:
        Dictionary with the serialized component data.
    """
    payload = default_to_dict(
        self,
        split_by=self.split_by,
        split_length=self.split_length,
        split_overlap=self.split_overlap,
        split_threshold=self.split_threshold,
    )
    custom_function = self.splitting_function
    if not custom_function:
        # Nothing callable to persist; the key is left out entirely.
        return payload

    payload["init_parameters"]["splitting_function"] = serialize_callable(custom_function)
    return payload

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentSplitter":
    """
    Deserializes the component from a dictionary.

    The input dictionary is not modified, so the same `data` can be passed to
    `from_dict` more than once.

    :param data:
        Dictionary to deserialize from. A serialized `splitting_function` under
        `init_parameters`, if present, is turned back into a callable.
    :returns:
        The deserialized component.
    """
    init_params = data.get("init_parameters", {})

    serialized_function = init_params.get("splitting_function")
    if not serialized_function:
        return default_from_dict(cls, data)

    # Build copies instead of writing the callable back into the caller's dict:
    # otherwise a second from_dict() call on the same data would pass a function
    # object, not its serialized form, to deserialize_callable.
    resolved_params = {**init_params, "splitting_function": deserialize_callable(serialized_function)}
    return default_from_dict(cls, {**data, "init_parameters": resolved_params})
93 changes: 93 additions & 0 deletions test/components/preprocessors/test_document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

from haystack import Document
from haystack.components.preprocessors import DocumentSplitter
from haystack.utils import deserialize_callable, serialize_callable


# custom split function for testing
def custom_split(text):
    """Split ``text`` on every period and return the resulting pieces as a list."""
    delimiter = "."
    return text.split(delimiter)


def merge_documents(documents):
Expand Down Expand Up @@ -352,3 +358,90 @@ def test_add_split_overlap_information(self):

# reconstruct the original document content from the split documents
assert doc.content == merge_documents(docs)

def test_to_dict(self):
    """
    Check that to_dict reports the component type and all split settings, and omits
    the splitting function when none was provided.
    """
    component = DocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
    payload = component.to_dict()

    assert payload["type"] == "haystack.components.preprocessors.document_splitter.DocumentSplitter"

    params = payload["init_parameters"]
    assert params["split_by"] == "word"
    assert params["split_length"] == 10
    assert params["split_overlap"] == 2
    assert params["split_threshold"] == 5
    assert "splitting_function" not in params

def test_to_dict_with_splitting_function(self):
    """
    Check that to_dict stores a custom splitting function in a form that can be
    deserialized back into a callable.
    """
    component = DocumentSplitter(split_by="function", splitting_function=custom_split)
    payload = component.to_dict()

    assert payload["type"] == "haystack.components.preprocessors.document_splitter.DocumentSplitter"

    params = payload["init_parameters"]
    assert params["split_by"] == "function"
    assert "splitting_function" in params

    restored_function = deserialize_callable(params["splitting_function"])
    assert callable(restored_function)

def test_from_dict(self):
    """
    Check that from_dict rebuilds a DocumentSplitter from its serialized form when no
    custom splitting function is involved.
    """
    init_parameters = {"split_by": "word", "split_length": 10, "split_overlap": 2, "split_threshold": 5}
    restored = DocumentSplitter.from_dict(
        {
            "type": "haystack.components.preprocessors.document_splitter.DocumentSplitter",
            "init_parameters": init_parameters,
        }
    )

    assert restored.split_by == "word"
    assert restored.split_length == 10
    assert restored.split_overlap == 2
    assert restored.split_threshold == 5
    assert restored.splitting_function is None

def test_from_dict_with_splitting_function(self):
    """
    Test the from_dict class method of the DocumentSplitter class when a custom splitting function is provided.
    """
    # Serialize the module-level custom_split rather than a function defined inside
    # this test: a nested function has no importable dotted path, so a local copy
    # would only shadow the helper that the serialized name has to resolve to.
    data = {
        "type": "haystack.components.preprocessors.document_splitter.DocumentSplitter",
        "init_parameters": {"split_by": "function", "splitting_function": serialize_callable(custom_split)},
    }
    splitter = DocumentSplitter.from_dict(data)

    assert splitter.split_by == "function"
    assert callable(splitter.splitting_function)
    assert splitter.splitting_function("a.b.c") == ["a", "b", "c"]

def test_roundtrip_serialization(self):
    """
    Check that a splitter survives to_dict followed by from_dict with its split
    settings unchanged.
    """
    source = DocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
    restored = DocumentSplitter.from_dict(source.to_dict())

    # Compare in the same order as the settings are declared.
    for attribute in ("split_by", "split_length", "split_overlap", "split_threshold"):
        assert getattr(source, attribute) == getattr(restored, attribute)

def test_roundtrip_serialization_with_splitting_function(self):
    """
    Check that a splitter configured with a custom splitting function survives
    to_dict followed by from_dict and that the restored function still splits text.
    """
    source = DocumentSplitter(split_by="function", splitting_function=custom_split)
    restored = DocumentSplitter.from_dict(source.to_dict())

    assert source.split_by == restored.split_by

    restored_function = restored.splitting_function
    assert callable(restored_function)
    assert restored_function("a.b.c") == ["a", "b", "c"]

0 comments on commit 6a59250

Please sign in to comment.