diff --git a/haystack/components/preprocessors/document_splitter.py b/haystack/components/preprocessors/document_splitter.py
index 4eb15aeec2..556878a965 100644
--- a/haystack/components/preprocessors/document_splitter.py
+++ b/haystack/components/preprocessors/document_splitter.py
@@ -3,11 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from copy import deepcopy
-from typing import Callable, Dict, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
 
 from more_itertools import windowed
 
 from haystack import Document, component
+from haystack.core.serialization import default_from_dict, default_to_dict
+from haystack.utils import deserialize_callable, serialize_callable
 
 
 @component
@@ -243,3 +245,40 @@ def _add_split_overlap_information(
                 # add split overlap information to previous Document regarding this Document
                 overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
                 previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data. `splitting_function` is included only when one is set.
+        """
+        serialized = default_to_dict(
+            self,
+            split_by=self.split_by,
+            split_length=self.split_length,
+            split_overlap=self.split_overlap,
+            split_threshold=self.split_threshold,
+        )
+        if self.splitting_function is not None:
+            serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function)
+        return serialized
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "DocumentSplitter":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from. A serialized `splitting_function` entry under `init_parameters`
+            is replaced in place by the deserialized callable.
+        :returns:
+            The deserialized component.
+        """
+        init_params = data.get("init_parameters", {})
+
+        splitting_function = init_params.get("splitting_function")
+        if splitting_function:
+            init_params["splitting_function"] = deserialize_callable(splitting_function)
+
+        return default_from_dict(cls, data)
diff --git a/test/components/preprocessors/test_document_splitter.py b/test/components/preprocessors/test_document_splitter.py
index edad6ddb72..7c942ab4cc 100644
--- a/test/components/preprocessors/test_document_splitter.py
+++ b/test/components/preprocessors/test_document_splitter.py
@@ -7,6 +7,12 @@
 
 from haystack import Document
 from haystack.components.preprocessors import DocumentSplitter
+from haystack.utils import deserialize_callable, serialize_callable
+
+
+# custom split function for testing
+def custom_split(text):
+    return text.split(".")
 
 
 def merge_documents(documents):
@@ -352,3 +358,87 @@ def test_add_split_overlap_information(self):
 
         # reconstruct the original document content from the split documents
         assert doc.content == merge_documents(docs)
+
+    def test_to_dict(self):
+        """
+        Test the to_dict method of the DocumentSplitter class.
+        """
+        splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
+        serialized = splitter.to_dict()
+
+        assert serialized["type"] == "haystack.components.preprocessors.document_splitter.DocumentSplitter"
+        assert serialized["init_parameters"]["split_by"] == "word"
+        assert serialized["init_parameters"]["split_length"] == 10
+        assert serialized["init_parameters"]["split_overlap"] == 2
+        assert serialized["init_parameters"]["split_threshold"] == 5
+        assert "splitting_function" not in serialized["init_parameters"]
+
+    def test_to_dict_with_splitting_function(self):
+        """
+        Test the to_dict method of the DocumentSplitter class when a custom splitting function is provided.
+        """
+
+        splitter = DocumentSplitter(split_by="function", splitting_function=custom_split)
+        serialized = splitter.to_dict()
+
+        assert serialized["type"] == "haystack.components.preprocessors.document_splitter.DocumentSplitter"
+        assert serialized["init_parameters"]["split_by"] == "function"
+        assert "splitting_function" in serialized["init_parameters"]
+        assert callable(deserialize_callable(serialized["init_parameters"]["splitting_function"]))
+
+    def test_from_dict(self):
+        """
+        Test the from_dict class method of the DocumentSplitter class.
+        """
+        data = {
+            "type": "haystack.components.preprocessors.document_splitter.DocumentSplitter",
+            "init_parameters": {"split_by": "word", "split_length": 10, "split_overlap": 2, "split_threshold": 5},
+        }
+        splitter = DocumentSplitter.from_dict(data)
+
+        assert splitter.split_by == "word"
+        assert splitter.split_length == 10
+        assert splitter.split_overlap == 2
+        assert splitter.split_threshold == 5
+        assert splitter.splitting_function is None
+
+    def test_from_dict_with_splitting_function(self):
+        """
+        Test the from_dict class method of the DocumentSplitter class when a custom splitting function is provided.
+        """
+
+        data = {
+            "type": "haystack.components.preprocessors.document_splitter.DocumentSplitter",
+            "init_parameters": {"split_by": "function", "splitting_function": serialize_callable(custom_split)},
+        }
+        splitter = DocumentSplitter.from_dict(data)
+
+        assert splitter.split_by == "function"
+        assert callable(splitter.splitting_function)
+        assert splitter.splitting_function("a.b.c") == ["a", "b", "c"]
+
+    def test_roundtrip_serialization(self):
+        """
+        Test the round-trip serialization of the DocumentSplitter class.
+        """
+        original_splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
+        serialized = original_splitter.to_dict()
+        deserialized_splitter = DocumentSplitter.from_dict(serialized)
+
+        assert original_splitter.split_by == deserialized_splitter.split_by
+        assert original_splitter.split_length == deserialized_splitter.split_length
+        assert original_splitter.split_overlap == deserialized_splitter.split_overlap
+        assert original_splitter.split_threshold == deserialized_splitter.split_threshold
+
+    def test_roundtrip_serialization_with_splitting_function(self):
+        """
+        Test the round-trip serialization of the DocumentSplitter class when a custom splitting function is provided.
+        """
+
+        original_splitter = DocumentSplitter(split_by="function", splitting_function=custom_split)
+        serialized = original_splitter.to_dict()
+        deserialized_splitter = DocumentSplitter.from_dict(serialized)
+
+        assert original_splitter.split_by == deserialized_splitter.split_by
+        assert callable(deserialized_splitter.splitting_function)
+        assert deserialized_splitter.splitting_function("a.b.c") == ["a", "b", "c"]