Disable unproven paragraph flattening in split_text()
Pwuts committed Aug 29, 2023
1 parent 6fac238 commit d2cc22c
Showing 1 changed file with 10 additions and 8 deletions.
autogpt/processing/text.py (18 changes: 10 additions & 8 deletions)
@@ -1,7 +1,7 @@
 """Text processing functions"""
 import logging
 from math import ceil
-from typing import Iterator, Optional, Sequence
+from typing import Iterator, Optional, Sequence, TypeVar
 
 import spacy
 import tiktoken

@@ -13,14 +13,18 @@
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
+
 
-def batch(iterable: Sequence, max_batch_length: int, overlap: int = 0):
+def batch(
+    sequence: Sequence[T], max_batch_length: int, overlap: int = 0
+) -> Iterator[Sequence[T]]:
     """Batch data from iterable into slices of length N. The last batch may be shorter."""
     # batched('ABCDEFG', 3) --> ABC DEF G
     if max_batch_length < 1:
         raise ValueError("n must be at least one")
-    for i in range(0, len(iterable), max_batch_length - overlap):
-        yield iterable[i : i + max_batch_length]
+    for i in range(0, len(sequence), max_batch_length - overlap):
+        yield sequence[i : i + max_batch_length]
 
 
 def _max_chunk_length(model: str, max: Optional[int] = None) -> int:
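
As context for the reworked signature, a minimal self-contained sketch of how batch() behaves with and without the overlap parameter; the function body is copied from the diff above, the example string is illustrative and not from the repository:

# Illustrative sketch only; batch() as it appears after this commit.
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def batch(
    sequence: Sequence[T], max_batch_length: int, overlap: int = 0
) -> Iterator[Sequence[T]]:
    """Batch data from iterable into slices of length N. The last batch may be shorter."""
    if max_batch_length < 1:
        raise ValueError("n must be at least one")
    # With overlap > 0, the step shrinks, so consecutive slices share items.
    for i in range(0, len(sequence), max_batch_length - overlap):
        yield sequence[i : i + max_batch_length]


print(list(batch("ABCDEFG", 3)))             # ['ABC', 'DEF', 'G']
print(list(batch("ABCDEFG", 3, overlap=1)))  # ['ABC', 'CDE', 'EFG', 'G']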

@@ -42,7 +46,7 @@ def chunk_content(
     content: str,
     for_model: str,
     max_chunk_length: Optional[int] = None,
-    with_overlap=True,
+    with_overlap: bool = True,
 ) -> Iterator[tuple[str, int]]:
     """Split content into chunks of approximately equal token length."""
 

@@ -155,7 +159,7 @@ def split_text(
     text: str,
     for_model: str,
     config: Config,
-    with_overlap=True,
+    with_overlap: bool = True,
     max_chunk_length: Optional[int] = None,
 ) -> Iterator[tuple[str, int]]:
     """Split text into chunks of sentences, with each chunk not exceeding the maximum length

@@ -176,8 +180,6 @@
 
     max_length = _max_chunk_length(for_model, max_chunk_length)
 
-    # flatten paragraphs to improve performance
-    text = text.replace("\n", " ")
     text_length = count_string_tokens(text, for_model)
 
     if text_length < max_length:
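
The two deleted lines in the last hunk are the substance of the commit: split_text() no longer collapses newlines into spaces before sentence splitting. A minimal sketch of why that flattening was unproven, with illustrative strings not from the repository:

# Illustrative sketch only: collapsing newlines erases paragraph boundaries
# before sentence splitting, so text without terminal punctuation (such as a
# heading) gets glued onto the sentence that follows it.
text = "Chapter 1\nIt was a dark and stormy night."
flattened = text.replace("\n", " ")
print(flattened)  # Chapter 1 It was a dark and stormy night.
# A sentence splitter now has no boundary cue between the heading and the
# sentence, whereas the newline kept them separable.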
