Skip to content

Commit

Permalink
feat : DocumentSplitter, adding the option to split_by function (#8336)
Browse files Browse the repository at this point in the history
* Adding splitting function

* Adding test for split by function

* Adding release note for feat adding split by function

* Fixing release note for split_by_function

* Fixing issue with splitting_function non callable

* nit: fixing value error in documentsplitter for split_by

* Add custom serde

---------

Co-authored-by: Giovanni Alzetta <[email protected]>
Co-authored-by: Vladimir Blagojevic <[email protected]>
  • Loading branch information
3 people authored Sep 12, 2024
1 parent 7e9f153 commit 4106e7e
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 5 deletions.
51 changes: 46 additions & 5 deletions haystack/components/preprocessors/document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Dict, List, Literal, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from more_itertools import windowed

from haystack import Document, component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.utils import deserialize_callable, serialize_callable


@component
Expand Down Expand Up @@ -46,10 +48,11 @@ class DocumentSplitter:

def __init__(
self,
split_by: Literal["word", "sentence", "page", "passage"] = "word",
split_by: Literal["function", "page", "passage", "sentence", "word"] = "word",
split_length: int = 200,
split_overlap: int = 0,
split_threshold: int = 0,
splitting_function: Optional[Callable[[str], List[str]]] = None,
):
"""
Initialize DocumentSplitter.
Expand All @@ -61,18 +64,24 @@ def __init__(
:param split_overlap: The number of overlapping units for each split.
:param split_threshold: The minimum number of units per split. If a split has fewer units
than the threshold, it's attached to the previous split.
:param splitting_function: Necessary when `split_by` is set to "function".
This is a function which must accept a single `str` as input and return a `list` of `str` as output,
representing the chunks after splitting.
"""

self.split_by = split_by
if split_by not in ["word", "sentence", "page", "passage"]:
if split_by not in ["function", "page", "passage", "sentence", "word"]:
raise ValueError("split_by must be one of 'word', 'sentence', 'page' or 'passage'.")
if split_by == "function" and splitting_function is None:
raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.")
if split_length <= 0:
raise ValueError("split_length must be greater than 0.")
self.split_length = split_length
if split_overlap < 0:
raise ValueError("split_overlap must be greater than or equal to 0.")
self.split_overlap = split_overlap
self.split_threshold = split_threshold
self.splitting_function = splitting_function

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
Expand Down Expand Up @@ -114,7 +123,9 @@ def run(self, documents: List[Document]):
)
return {"documents": split_docs}

def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage", "page"]) -> List[str]:
def _split_into_units(
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word"]
) -> List[str]:
if split_by == "page":
self.split_at = "\f"
elif split_by == "passage":
Expand All @@ -123,9 +134,11 @@ def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "pa
self.split_at = "."
elif split_by == "word":
self.split_at = " "
elif split_by == "function" and self.splitting_function is not None:
return self.splitting_function(text)
else:
raise NotImplementedError(
"DocumentSplitter only supports 'word', 'sentence', 'page' or 'passage' split_by options."
"DocumentSplitter only supports 'function', 'page', 'passage', 'sentence' or 'word' split_by options."
)
units = text.split(self.split_at)
# Add the delimiter back to all units except the last one
Expand Down Expand Up @@ -232,3 +245,31 @@ def _add_split_overlap_information(
# add split overlap information to previous Document regarding this Document
overlapping_range = (0, overlapping_range[1] - overlapping_range[0])
previous_doc.meta["_split_overlap"].append({"doc_id": current_doc.id, "range": overlapping_range})

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
"""
serialized = default_to_dict(
self,
split_by=self.split_by,
split_length=self.split_length,
split_overlap=self.split_overlap,
split_threshold=self.split_threshold,
)
if self.splitting_function:
serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function)
return serialized

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentSplitter":
"""
Deserializes the component from a dictionary.
"""
init_params = data.get("init_parameters", {})

splitting_function = init_params.get("splitting_function", None)
if splitting_function:
init_params["splitting_function"] = deserialize_callable(splitting_function)

return default_from_dict(cls, data)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added the option to use a custom splitting function in DocumentSplitter. The function must accept a string as
input and return a list of strings, representing the split units. To use the feature initialise `DocumentSplitter`
with `split_by="function"` providing the custom splitting function as `splitting_function=custom_function`.
116 changes: 116 additions & 0 deletions test/components/preprocessors/test_document_splitter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import re

import pytest

from haystack import Document
from haystack.components.preprocessors import DocumentSplitter
from haystack.utils import deserialize_callable, serialize_callable


# custom split function for testing
def custom_split(text):
return text.split(".")


def merge_documents(documents):
Expand Down Expand Up @@ -165,6 +173,27 @@ def test_split_by_page(self):
assert docs[2].meta["split_idx_start"] == text.index(docs[2].content)
assert docs[2].meta["page_number"] == 3

def test_split_by_function(self):
splitting_function = lambda input_str: input_str.split(".")
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
text = "This.Is.A.Test"
result = splitter.run(documents=[Document(content=text)])
docs = result["documents"]

word_list = ["This", "Is", "A", "Test"]
assert len(docs) == 4
for w_target, w_split in zip(word_list, docs):
assert w_split.content == w_target

splitting_function = lambda input_str: re.split("[\s]{2,}", input_str)
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
text = "This Is\n A Test"
result = splitter.run(documents=[Document(content=text)])
docs = result["documents"]
assert len(docs) == 4
for w_target, w_split in zip(word_list, docs):
assert w_split.content == w_target

def test_split_by_word_with_overlap(self):
splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2)
text = "This is a text with some words. There is a second sentence. And there is a third sentence."
Expand Down Expand Up @@ -329,3 +358,90 @@ def test_add_split_overlap_information(self):

# reconstruct the original document content from the split documents
assert doc.content == merge_documents(docs)

def test_to_dict(self):
"""
Test the to_dict method of the DocumentSplitter class.
"""
splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
serialized = splitter.to_dict()

assert serialized["type"] == "haystack.components.preprocessors.document_splitter.DocumentSplitter"
assert serialized["init_parameters"]["split_by"] == "word"
assert serialized["init_parameters"]["split_length"] == 10
assert serialized["init_parameters"]["split_overlap"] == 2
assert serialized["init_parameters"]["split_threshold"] == 5
assert "splitting_function" not in serialized["init_parameters"]

def test_to_dict_with_splitting_function(self):
"""
Test the to_dict method of the DocumentSplitter class when a custom splitting function is provided.
"""

splitter = DocumentSplitter(split_by="function", splitting_function=custom_split)
serialized = splitter.to_dict()

assert serialized["type"] == "haystack.components.preprocessors.document_splitter.DocumentSplitter"
assert serialized["init_parameters"]["split_by"] == "function"
assert "splitting_function" in serialized["init_parameters"]
assert callable(deserialize_callable(serialized["init_parameters"]["splitting_function"]))

def test_from_dict(self):
"""
Test the from_dict class method of the DocumentSplitter class.
"""
data = {
"type": "haystack.components.preprocessors.document_splitter.DocumentSplitter",
"init_parameters": {"split_by": "word", "split_length": 10, "split_overlap": 2, "split_threshold": 5},
}
splitter = DocumentSplitter.from_dict(data)

assert splitter.split_by == "word"
assert splitter.split_length == 10
assert splitter.split_overlap == 2
assert splitter.split_threshold == 5
assert splitter.splitting_function is None

def test_from_dict_with_splitting_function(self):
"""
Test the from_dict class method of the DocumentSplitter class when a custom splitting function is provided.
"""

def custom_split(text):
return text.split(".")

data = {
"type": "haystack.components.preprocessors.document_splitter.DocumentSplitter",
"init_parameters": {"split_by": "function", "splitting_function": serialize_callable(custom_split)},
}
splitter = DocumentSplitter.from_dict(data)

assert splitter.split_by == "function"
assert callable(splitter.splitting_function)
assert splitter.splitting_function("a.b.c") == ["a", "b", "c"]

def test_roundtrip_serialization(self):
"""
Test the round-trip serialization of the DocumentSplitter class.
"""
original_splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
serialized = original_splitter.to_dict()
deserialized_splitter = DocumentSplitter.from_dict(serialized)

assert original_splitter.split_by == deserialized_splitter.split_by
assert original_splitter.split_length == deserialized_splitter.split_length
assert original_splitter.split_overlap == deserialized_splitter.split_overlap
assert original_splitter.split_threshold == deserialized_splitter.split_threshold

def test_roundtrip_serialization_with_splitting_function(self):
"""
Test the round-trip serialization of the DocumentSplitter class when a custom splitting function is provided.
"""

original_splitter = DocumentSplitter(split_by="function", splitting_function=custom_split)
serialized = original_splitter.to_dict()
deserialized_splitter = DocumentSplitter.from_dict(serialized)

assert original_splitter.split_by == deserialized_splitter.split_by
assert callable(deserialized_splitter.splitting_function)
assert deserialized_splitter.splitting_function("a.b.c") == ["a", "b", "c"]

0 comments on commit 4106e7e

Please sign in to comment.