From d2cc22c69827174e09ec1870cb7d10b5879177b9 Mon Sep 17 00:00:00 2001
From: Reinier van der Leer
Date: Tue, 29 Aug 2023 02:06:47 +0200
Subject: [PATCH] Disable unproven paragraph flattening in `split_text()`

---
 autogpt/processing/text.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/autogpt/processing/text.py b/autogpt/processing/text.py
index f0a47e53bc4d..7f90c2c85444 100644
--- a/autogpt/processing/text.py
+++ b/autogpt/processing/text.py
@@ -1,7 +1,7 @@
 """Text processing functions"""
 import logging
 from math import ceil
-from typing import Iterator, Optional, Sequence
+from typing import Iterator, Optional, Sequence, TypeVar
 
 import spacy
 import tiktoken
@@ -13,14 +13,18 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
 
-def batch(iterable: Sequence, max_batch_length: int, overlap: int = 0):
+
+def batch(
+    sequence: Sequence[T], max_batch_length: int, overlap: int = 0
+) -> Iterator[Sequence[T]]:
     """Batch data from iterable into slices of length N. The last batch may be shorter."""
     # batched('ABCDEFG', 3) --> ABC DEF G
     if max_batch_length < 1:
         raise ValueError("n must be at least one")
-    for i in range(0, len(iterable), max_batch_length - overlap):
-        yield iterable[i : i + max_batch_length]
+    for i in range(0, len(sequence), max_batch_length - overlap):
+        yield sequence[i : i + max_batch_length]
 
 
 def _max_chunk_length(model: str, max: Optional[int] = None) -> int:
@@ -42,7 +46,7 @@ def chunk_content(
     content: str,
     for_model: str,
     max_chunk_length: Optional[int] = None,
-    with_overlap=True,
+    with_overlap: bool = True,
 ) -> Iterator[tuple[str, int]]:
     """Split content into chunks of approximately equal token length."""
 
@@ -155,7 +159,7 @@ def split_text(
     text: str,
     for_model: str,
     config: Config,
-    with_overlap=True,
+    with_overlap: bool = True,
     max_chunk_length: Optional[int] = None,
 ) -> Iterator[tuple[str, int]]:
     """Split text into chunks of sentences, with each chunk not exceeding the maximum length
@@ -176,8 +180,6 @@
 
     max_length = _max_chunk_length(for_model, max_chunk_length)
 
-    # flatten paragraphs to improve performance
-    text = text.replace("\n", " ")
     text_length = count_string_tokens(text, for_model)
 
     if text_length < max_length:
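
Note (illustration, not part of the patch): a minimal sketch of how the retyped batch() generator from the hunk above behaves, including the overlap parameter; the function body is copied from the patch, the print calls are only an assumed usage example.

    # Standalone copy of batch() as defined in the patch, for demonstration only.
    from typing import Iterator, Sequence, TypeVar

    T = TypeVar("T")

    def batch(
        sequence: Sequence[T], max_batch_length: int, overlap: int = 0
    ) -> Iterator[Sequence[T]]:
        """Batch data from iterable into slices of length N. The last batch may be shorter."""
        if max_batch_length < 1:
            raise ValueError("n must be at least one")
        # Step by (max_batch_length - overlap) so consecutive slices share `overlap` items.
        for i in range(0, len(sequence), max_batch_length - overlap):
            yield sequence[i : i + max_batch_length]

    print(list(batch("ABCDEFG", 3)))             # ['ABC', 'DEF', 'G']
    print(list(batch("ABCDEFG", 3, overlap=1)))  # ['ABC', 'CDE', 'EFG', 'G']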