Skip to content

Commit

Permalink
Adding test for split by function
Browse files Browse the repository at this point in the history
  • Loading branch information
Giovanni-Alzetta authored and GivAlz committed Sep 10, 2024
1 parent 1cc238e commit 387330e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
10 changes: 5 additions & 5 deletions haystack/components/preprocessors/document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Dict, List, Literal, Tuple, Optional, Callable
from typing import Callable, Dict, List, Literal, Optional, Tuple

from more_itertools import windowed

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
split_length: int = 200,
split_overlap: int = 0,
split_threshold: int = 0,
splitting_function: Optional[Callable[[str], List[str]]] = None
splitting_function: Optional[Callable[[str], List[str]]] = None,
):
"""
Initialize DocumentSplitter.
Expand Down Expand Up @@ -81,7 +81,6 @@ def __init__(
self.split_threshold = split_threshold
self.splitting_function = splitting_function


@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
Expand Down Expand Up @@ -122,8 +121,9 @@ def run(self, documents: List[Document]):
)
return {"documents": split_docs}

def _split_into_units(self, text: str,
split_by: Literal["function", "page", "passage", "sentence", "word"]) -> List[str]:
def _split_into_units(
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word"]
) -> List[str]:
if split_by == "page":
self.split_at = "\f"
elif split_by == "passage":
Expand Down
23 changes: 23 additions & 0 deletions test/components/preprocessors/test_document_splitter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import re

import pytest

from haystack import Document
Expand Down Expand Up @@ -165,6 +167,27 @@ def test_split_by_page(self):
assert docs[2].meta["split_idx_start"] == text.index(docs[2].content)
assert docs[2].meta["page_number"] == 3

def test_split_by_function(self):
splitting_function = lambda input_str: input_str.split(".")
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
text = "This.Is.A.Test"
result = splitter.run(documents=[Document(content=text)])
docs = result["documents"]

word_list = ["This", "Is", "A", "Test"]
assert len(docs) == 4
for w_target, w_split in zip(word_list, docs):
assert w_split.content == w_target

splitting_function = lambda input_str: re.split("[\s]{2,}", input_str)
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
text = "This Is\n A Test"
result = splitter.run(documents=[Document(content=text)])
docs = result["documents"]
assert len(docs) == 4
for w_target, w_split in zip(word_list, docs):
assert w_split.content == w_target

def test_split_by_word_with_overlap(self):
splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2)
text = "This is a text with some words. There is a second sentence. And there is a third sentence."
Expand Down

0 comments on commit 387330e

Please sign in to comment.