
Commit

Merge pull request #17 from aurelio-labs/simonas/regex-chunker
feat: regex chunker
simjak authored Jul 19, 2024
2 parents 21e8571 + 259440e commit 6f9f4c6
Showing 16 changed files with 441 additions and 292 deletions.
148 changes: 140 additions & 8 deletions docs/00-chunkers-intro.ipynb

Large diffs are not rendered by default.

256 changes: 40 additions & 216 deletions docs/02-chunkers-async.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "semantic-chunkers"
-version = "0.0.8"
+version = "0.0.9"
 description = "Super advanced chunking methods for AI"
 authors = ["Aurelio AI <[email protected]>"]
 readme = "README.md"
6 changes: 4 additions & 2 deletions semantic_chunkers/__init__.py
@@ -2,6 +2,7 @@
     BaseChunker,
     ConsecutiveChunker,
     CumulativeChunker,
+    RegexChunker,
     StatisticalChunker,
 )
 from semantic_chunkers.splitters import BaseSplitter, RegexSplitter
@@ -11,8 +12,9 @@
     "ConsecutiveChunker",
     "CumulativeChunker",
     "StatisticalChunker",
-    "BaseSplitter",
     "RegexSplitter",
+    "BaseSplitter",
+    "RegexChunker",
 ]

-__version__ = "0.0.8"
+__version__ = "0.0.9"
2 changes: 2 additions & 0 deletions semantic_chunkers/chunkers/__init__.py
@@ -1,11 +1,13 @@
 from semantic_chunkers.chunkers.base import BaseChunker
 from semantic_chunkers.chunkers.consecutive import ConsecutiveChunker
 from semantic_chunkers.chunkers.cumulative import CumulativeChunker
+from semantic_chunkers.chunkers.regex import RegexChunker
 from semantic_chunkers.chunkers.statistical import StatisticalChunker

 __all__ = [
     "BaseChunker",
     "ConsecutiveChunker",
     "CumulativeChunker",
     "StatisticalChunker",
+    "RegexChunker",
 ]
4 changes: 2 additions & 2 deletions semantic_chunkers/chunkers/base.py
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional

 from colorama import Fore, Style
 from pydantic.v1 import BaseModel, Extra
@@ -10,7 +10,7 @@

 class BaseChunker(BaseModel):
     name: str
-    encoder: BaseEncoder
+    encoder: Optional[BaseEncoder]
     splitter: BaseSplitter

     class Config:
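A note on this change: making encoder optional on BaseChunker is presumably what lets encoder-free chunkers such as the new RegexChunker pass encoder=None, while ConsecutiveChunker, CumulativeChunker, and StatisticalChunker re-declare encoder: BaseEncoder below to keep it required for the embedding-based chunkers.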
4 changes: 3 additions & 1 deletion semantic_chunkers/chunkers/consecutive.py
@@ -7,14 +7,16 @@
 from semantic_chunkers.chunkers.base import BaseChunker
 from semantic_chunkers.schema import Chunk
 from semantic_chunkers.splitters.base import BaseSplitter
-from semantic_chunkers.splitters.sentence import RegexSplitter
+from semantic_chunkers.splitters.regex import RegexSplitter


 class ConsecutiveChunker(BaseChunker):
     """
     Called "consecutive sim chunker" because we check the similarities of consecutive document embeddings (compare ith to i+1th document embedding).
     """

+    encoder: BaseEncoder
+
     def __init__(
         self,
         encoder: BaseEncoder,
4 changes: 3 additions & 1 deletion semantic_chunkers/chunkers/cumulative.py
@@ -7,7 +7,7 @@
 from semantic_chunkers.chunkers.base import BaseChunker
 from semantic_chunkers.schema import Chunk
 from semantic_chunkers.splitters.base import BaseSplitter
-from semantic_chunkers.splitters.sentence import RegexSplitter
+from semantic_chunkers.splitters.regex import RegexSplitter


 class CumulativeChunker(BaseChunker):
@@ -16,6 +16,8 @@ class CumulativeChunker(BaseChunker):
     embeddings of cumulative concatenated documents with the next document.
     """

+    encoder: BaseEncoder
+
     def __init__(
         self,
         encoder: BaseEncoder,
58 changes: 58 additions & 0 deletions semantic_chunkers/chunkers/regex.py
@@ -0,0 +1,58 @@
import asyncio
from typing import List, Union

import regex

from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters import RegexSplitter
from semantic_chunkers.utils import text


class RegexChunker(BaseChunker):
    def __init__(
        self,
        splitter: RegexSplitter = RegexSplitter(),
        max_chunk_tokens: int = 300,
        delimiters: List[Union[str, regex.Pattern]] = [],
    ):
        super().__init__(name="regex_chunker", encoder=None, splitter=splitter)
        self.splitter: RegexSplitter = splitter
        self.max_chunk_tokens = max_chunk_tokens
        self.delimiters = delimiters

    def __call__(self, docs: list[str]) -> List[List[Chunk]]:
        chunks = []
        current_chunk = Chunk(
            splits=[],
            metadata={},
        )
        current_chunk.token_count = 0

        for doc in docs:
            sentences = self.splitter(doc, delimiters=self.delimiters)
            for sentence in sentences:
                sentence_token_count = text.tiktoken_length(sentence)
                if (
                    current_chunk.token_count + sentence_token_count
                    > self.max_chunk_tokens
                ):
                    if current_chunk.splits:
                        chunks.append(current_chunk)
                    current_chunk = Chunk(splits=[])
                    current_chunk.token_count = 0

                current_chunk.splits.append(sentence)
                if current_chunk.token_count is None:
                    current_chunk.token_count = 0
                current_chunk.token_count += sentence_token_count

        # Last chunk
        if current_chunk.splits:
            chunks.append(current_chunk)

        return [chunks]

    async def acall(self, docs: list[str]) -> List[List[Chunk]]:
        chunks = await asyncio.to_thread(self.__call__, docs)
        return chunks
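
For orientation, a minimal usage sketch of the new chunker (not part of the commit; the sample document, token budget, and print loop are illustrative assumptions):

from semantic_chunkers import RegexChunker

# RegexChunker needs no encoder: it splits each document into sentences with
# RegexSplitter and packs them into chunks of at most max_chunk_tokens
# tiktoken tokens.
chunker = RegexChunker(max_chunk_tokens=50)

docs = ["First sentence. Second sentence. A third, slightly longer sentence."]
chunk_lists = chunker(docs)  # List[List[Chunk]]; the chunks are in chunk_lists[0]

for chunk in chunk_lists[0]:
    print(chunk.splits, chunk.token_count)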
8 changes: 5 additions & 3 deletions semantic_chunkers/chunkers/statistical.py
@@ -1,6 +1,6 @@
 import asyncio
 from dataclasses import dataclass
-from typing import Any, List
+from typing import Any, List, Optional

 import numpy as np
 from semantic_router.encoders.base import BaseEncoder
@@ -9,7 +9,7 @@
 from semantic_chunkers.chunkers.base import BaseChunker
 from semantic_chunkers.schema import Chunk
 from semantic_chunkers.splitters.base import BaseSplitter
-from semantic_chunkers.splitters.sentence import RegexSplitter
+from semantic_chunkers.splitters.regex import RegexSplitter
 from semantic_chunkers.utils.logger import logger
 from semantic_chunkers.utils.text import (
     async_retry_with_timeout,
@@ -44,6 +44,8 @@ def __str__(self):


 class StatisticalChunker(BaseChunker):
+    encoder: BaseEncoder
+
     def __init__(
         self,
         encoder: BaseEncoder,
@@ -104,7 +106,7 @@ def _chunk(
         splits = [split for split in new_splits if split and split.strip()]

         chunks = []
-        last_chunk: Chunk | None = None
+        last_chunk: Optional[Chunk] = None
         for i in tqdm(range(0, len(splits), batch_size)):
             batch_splits = splits[i : i + batch_size]
             if last_chunk is not None:
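A plausible reason for the Chunk | None to Optional[Chunk] change (an inference, not stated in the commit): the PEP 604 "X | Y" union syntax is only available from Python 3.10 wherever the annotation is actually evaluated, whereas typing.Optional works on earlier interpreters as well. The two spellings name the same type:

from typing import Optional

from semantic_chunkers.schema import Chunk

# Equivalent to "Chunk | None" on Python >= 3.10, but also valid on earlier versions.
last_chunk: Optional[Chunk] = None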
2 changes: 1 addition & 1 deletion semantic_chunkers/splitters/__init__.py
@@ -1,5 +1,5 @@
 from semantic_chunkers.splitters.base import BaseSplitter
-from semantic_chunkers.splitters.sentence import RegexSplitter
+from semantic_chunkers.splitters.regex import RegexSplitter

 __all__ = [
     "BaseSplitter",
79 changes: 79 additions & 0 deletions semantic_chunkers/splitters/regex.py
@@ -0,0 +1,79 @@
from typing import List, Union

import regex

from semantic_chunkers.splitters.base import BaseSplitter


class RegexSplitter(BaseSplitter):
    """
    Enhanced regex pattern to split a given text into sentences more accurately.
    """

    regex_pattern = r"""
        # Negative lookbehind for word boundary, word char, dot, word char
        (?<!\b\w\.\w.)
        # Negative lookbehind for single uppercase initials like "A."
        (?<!\b[A-Z][a-z]\.)
        # Negative lookbehind for abbreviations like "U.S."
        (?<!\b[A-Z]\.)
        # Negative lookbehind for abbreviations with uppercase letters and dots
        (?<!\b\p{Lu}\.\p{Lu}.)
        # Negative lookbehind for numbers, to avoid splitting decimals
        (?<!\b\p{N}\.)
        # Positive lookbehind for punctuation followed by whitespace
        (?<=\.|\?|!|:|\.\.\.)\s+
        # Positive lookahead for uppercase letter or opening quote at word boundary
        (?="?(?=[A-Z])|"\b)
        # OR
        |
        # Splits after punctuation that follows closing punctuation, followed by
        # whitespace
        (?<=[\"\'\]\)\}][\.!?])\s+(?=[\"\'\(A-Z])
        # OR
        |
        # Splits after punctuation if not preceded by a period
        (?<=[^\.][\.!?])\s+(?=[A-Z])
        # OR
        |
        # Handles splitting after ellipses
        (?<=\.\.\.)\s+(?=[A-Z])
        # OR
        |
        # Matches and removes control characters and format characters
        [\p{Cc}\p{Cf}]+
        # OR
        |
        # Splits after punctuation marks followed by another punctuation mark
        (?<=[\.!?])(?=[\.!?])
        # OR
        |
        # Splits after exclamation or question marks followed by whitespace or end of string
        (?<=[!?])(?=\s|$)
    """

    def __call__(
        self, doc: str, delimiters: List[Union[str, regex.Pattern]] = []
    ) -> List[str]:
        if not delimiters:
            compiled_pattern = regex.compile(self.regex_pattern)
            delimiters.append(compiled_pattern)
        sentences = [doc]
        for delimiter in delimiters:
            sentences_for_next_delimiter = []
            for sentence in sentences:
                if isinstance(delimiter, regex.Pattern):
                    sub_sentences = regex.split(
                        self.regex_pattern, doc, flags=regex.VERBOSE
                    )
                    split_char = ""  # No single character to append for regex pattern
                else:
                    sub_sentences = sentence.split(delimiter)
                    split_char = delimiter
                for i, sub_sentence in enumerate(sub_sentences):
                    if i < len(sub_sentences) - 1:
                        sub_sentence += split_char  # Append delimiter to sub_sentence
                    if sub_sentence.strip():
                        sentences_for_next_delimiter.append(sub_sentence.strip())
            sentences = sentences_for_next_delimiter
        return sentences
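
For orientation, a minimal sketch of calling the splitter directly (not part of the commit; the sample strings are illustrative):

from semantic_chunkers.splitters import RegexSplitter

splitter = RegexSplitter()

# Default behaviour: split on the built-in sentence-boundary pattern.
print(splitter("Dr. Smith arrived at 3.5 km. Then he left! Was that it?"))

# A plain-string delimiter can also be supplied; it is re-appended to every
# piece except the last.
print(splitter("alpha|beta|gamma", delimiters=["|"]))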
57 changes: 0 additions & 57 deletions semantic_chunkers/splitters/sentence.py

This file was deleted.

File renamed without changes.
48 changes: 48 additions & 0 deletions tests/unit/test_regex_chunker.py
@@ -0,0 +1,48 @@
import asyncio
import unittest

from semantic_chunkers.chunkers.regex import RegexChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.utils import text


class TestRegexChunker(unittest.TestCase):
    def setUp(self):
        self.chunker = RegexChunker(max_chunk_tokens=10)

    def test_call(self):
        docs = ["This is a test. This is only a test."]
        chunks_list = self.chunker(docs)
        chunks = chunks_list[0]

        self.assertIsInstance(chunks, list)
        self.assertTrue(all(isinstance(chunk, Chunk) for chunk in chunks))
        self.assertGreater(len(chunks), 0)
        self.assertTrue(
            all(
                text.tiktoken_length(chunk.content) <= self.chunker.max_chunk_tokens
                for chunk in chunks
            )
        )

    def test_acall(self):
        docs = ["This is a test. This is only a test."]

        async def run_test():
            chunks_list = await self.chunker.acall(docs)
            chunks = chunks_list[0]
            self.assertIsInstance(chunks, list)
            self.assertTrue(all(isinstance(chunk, Chunk) for chunk in chunks))
            self.assertGreater(len(chunks), 0)
            self.assertTrue(
                all(
                    text.tiktoken_length(chunk.content) <= self.chunker.max_chunk_tokens
                    for chunk in chunks
                )
            )

        asyncio.run(run_test())


if __name__ == "__main__":
    unittest.main()
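
Assuming pytest is available in the dev environment, the new module can be run on its own with python -m pytest tests/unit/test_regex_chunker.py, or executed directly as a script, since it calls unittest.main().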
