
[fix] Make multiprocessing optional for inherited chunkers
bhavnicksm committed Jan 6, 2025
1 parent e0c67f9 commit ba7caae
Showing 5 changed files with 91 additions and 16 deletions.
47 changes: 41 additions & 6 deletions src/chonkie/chunker/base.py
@@ -34,6 +34,9 @@ def __init__(
self.tokenizer = tokenizer_or_token_counter
self._tokenizer_backend = self._get_tokenizer_backend()
self.token_counter = self._get_tokenizer_counter()

# Set whether to use multiprocessing or not
self._use_multiprocessing = True

def _get_tokenizer_backend(self):
"""Return the backend tokenizer object."""
@@ -235,33 +238,65 @@ def _determine_optimal_workers(self) -> int:
f"Error determining optimal workers: {e}. Using single process."
)
return 1


def _process_batch_sequential(self,
texts: List[str],
show_progress_bar: bool = True) -> List[List[Chunk]]:
"""Process a batch of texts sequentially."""
return [
self.chunk(t) for t in tqdm(
texts,
desc="πŸ¦› CHONKING",
disable=not show_progress_bar,
unit="texts",
bar_format="{desc}: |{bar:20}| {percentage:3.0f}% β€’ {n_fmt}/{total_fmt} texts chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱",
ascii=' β–β–Žβ–β–Œβ–‹β–Šβ–‰'
)
]

def _process_batch_multiprocessing(self,
texts: List[str],
show_progress_bar: bool = True) -> List[List[Chunk]]:
"""Process a batch of texts using multiprocessing."""
num_workers = self._determine_optimal_workers()
with Pool(processes=num_workers) as pool:
return list(tqdm(pool.imap(self.chunk, texts),
desc="πŸ¦› CHONKING",
disable=not show_progress_bar,
unit="texts",
bar_format="{desc}: |{bar:20}| {percentage:3.0f}% β€’ {n_fmt}/{total_fmt} texts chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱",
ascii=' β–β–Žβ–β–Œβ–‹β–Šβ–‰'))

def chunk_batch(
self,
text: List[str],
texts: List[str],
show_progress_bar: bool = True,
) -> List[List[Chunk]]:
"""Split a List of texts into their respective chunks.
By default, this method uses multiprocessing to parallelize the chunking process.
Args:
text: List of input texts to be chunked.
texts: List of input texts to be chunked.
show_progress_bar: Whether to show a progress bar.
Returns:
List of lists of Chunk objects containing the chunked text and metadata
"""
return [self.chunk(t) for t in tqdm(text, desc="Chunking Texts", disable=not show_progress_bar)]
if self._use_multiprocessing:
return self._process_batch_multiprocessing(texts, show_progress_bar)
else:
return self._process_batch_sequential(texts, show_progress_bar)

def __call__(
self, text: Union[str, List[str]]
self, text: Union[str, List[str]], show_progress_bar: bool = True
) -> Union[List[Chunk], List[List[Chunk]]]:
"""Make the chunker callable directly.
Args:
text: Input text or list of texts to be chunked
show_progress_bar: Whether to show a progress bar (for batch chunking)
Returns:
List of Chunk objects or list of lists of Chunk
@@ -270,7 +305,7 @@ def __call__(
if isinstance(text, str):
return self.chunk(text)
elif isinstance(text, list):
return self.chunk_batch(text)
return self.chunk_batch(text, show_progress_bar)
else:
raise ValueError("Input must be a string or a list of strings.")

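For context, the core of this change is a dispatch in BaseChunker.chunk_batch between a sequential path and a multiprocessing path, controlled by the _use_multiprocessing flag set in __init__. Below is a minimal, hedged sketch of that pattern — ToyChunker and its whitespace "chunking" are invented for illustration and are not the chonkie API; only the flag and the sequential-vs-Pool dispatch mirror the diff above.

    # Minimal sketch of the dispatch pattern; ToyChunker is hypothetical.
    from multiprocessing import Pool
    from typing import List


    class ToyChunker:
        def __init__(self) -> None:
            # Subclasses holding state that does not travel well across
            # processes can flip this to False to force sequential batching.
            self._use_multiprocessing = True

        def chunk(self, text: str) -> List[str]:
            # Stand-in for the real chunking logic.
            return text.split()

        def _process_batch_sequential(self, texts: List[str]) -> List[List[str]]:
            return [self.chunk(t) for t in texts]

        def _process_batch_multiprocessing(self, texts: List[str]) -> List[List[str]]:
            # Pool.imap sends each text (and the bound chunk method) to worker processes.
            with Pool(processes=2) as pool:
                return list(pool.imap(self.chunk, texts))

        def chunk_batch(self, texts: List[str]) -> List[List[str]]:
            if self._use_multiprocessing:
                return self._process_batch_multiprocessing(texts)
            return self._process_batch_sequential(texts)


    if __name__ == "__main__":
        print(ToyChunker().chunk_batch(["hello world", "chonk chonk"]))

The real methods additionally wrap both paths in tqdm progress bars and pick a worker count via _determine_optimal_workers; the sketch drops that to stay self-contained.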
3 changes: 3 additions & 0 deletions src/chonkie/chunker/late.py
@@ -102,6 +102,9 @@ def __init__(self,
# for the semantic meaning to be calculated properly
super().__init__(self.embedding_model.get_tokenizer_or_token_counter())

# Disable the multiprocessing flag inherited from the base class
self._use_multiprocessing = False

def _create_token_chunks(self,
chunk_texts: List[str],
token_counts: List[int],
3 changes: 3 additions & 0 deletions src/chonkie/chunker/sdpm.py
@@ -76,6 +76,9 @@ def __init__(
)
self.skip_window = skip_window

# Disable the multiprocessing flag inherited from the base class
self._use_multiprocessing = False

def _merge_groups(self, groups: List[List[Sentence]]) -> List[Sentence]:
"""Merge the groups together."""
merged_group = []
3 changes: 3 additions & 0 deletions src/chonkie/chunker/semantic.py
@@ -129,6 +129,9 @@ def __init__(
# for the group semantic meaning to be calculated properly
super().__init__(self.embedding_model.get_tokenizer_or_token_counter())

# Disable the multiprocessing flag inherited from the base class
self._use_multiprocessing = False

def _split_sentences(
self,
text: str,
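The three files above (late.py, sdpm.py, semantic.py) apply the identical opt-out: immediately after super().__init__, the subclass turns the inherited flag off. A hedged sketch of that pattern, reusing the hypothetical ToyChunker from the earlier sketch rather than the real LateChunker, SDPMChunker, or SemanticChunker:

    # Sketch of the per-subclass opt-out; EmbeddingChunker is hypothetical.
    class EmbeddingChunker(ToyChunker):
        def __init__(self) -> None:
            super().__init__()
            # Turn off the multiprocessing path inherited from the base class,
            # mirroring the one-line change made in the three diffs above.
            self._use_multiprocessing = False

With this in place, chunk_batch on such a chunker takes the sequential path with no further changes to the subclass; a plausible motivation (not stated in the commit) is that chunkers holding an embedding model on self do not fan out cleanly across worker processes.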
51 changes: 41 additions & 10 deletions src/chonkie/chunker/token.py
@@ -6,7 +6,7 @@

from .base import BaseChunker


from tqdm import trange
class TokenChunker(BaseChunker):
"""Chunker that splits text into chunks of a specified token size.
@@ -48,6 +48,8 @@ def __init__(
if isinstance(chunk_overlap, int)
else int(chunk_overlap * chunk_size)
)

self._use_multiprocessing = False

def _create_chunks(
self,
@@ -169,27 +171,56 @@ def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
return result

def chunk_batch(
self, texts: List[str], batch_size: Union[int, None] = None
self,
texts: List[str],
batch_size: int = 1,
show_progress_bar: bool = True
) -> List[List[Chunk]]:
"""Split a batch of texts into their respective chunks.
Args:
texts: List of input texts to be chunked
batch_size: Number of texts to process in a single batch
show_progress_bar: Whether to show a progress bar
Returns:
List of lists of Chunk objects containing the chunked text and metadata
"""
# if batch_size is not None, we process the texts in mini-batches to avoid memory issues
if batch_size is not None:
chunks = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i : min(i + batch_size, len(texts))]
chunks.extend(self._process_text_batch(batch_texts))
return chunks
chunks = []
for i in trange(0,
len(texts),
batch_size,
desc="πŸ¦› CHONKING",
disable=not show_progress_bar,
unit="batches",
ascii=" β–β–Žβ–β–Œβ–‹β–Šβ–‰",
bar_format="{desc}: |{bar:20}| {percentage:3.0f}% β€’ {n_fmt}/{total_fmt} batches chunked [{elapsed}<{remaining}, {rate_fmt}] 🌱"):
batch_texts = texts[i : min(i + batch_size, len(texts))]
chunks.extend(self._process_text_batch(batch_texts))
return chunks

def __call__(self,
text: Union[str, List[str]],
batch_size: int = 1,
show_progress_bar: bool = True) -> Union[List[Chunk], List[List[Chunk]]]:
"""Make the TokenChunker callable directly.
Args:
text: Input text or list of texts to be chunked
batch_size: Number of texts to process in a single batch
show_progress_bar: Whether to show a progress bar (for batch chunking)
Returns:
List of Chunk objects or list of lists of Chunk
"""
if isinstance(text, str):
return self.chunk(text)
elif isinstance(text, list) and isinstance(text[0], str):
return self.chunk_batch(text, batch_size, show_progress_bar)
else:
return self._process_text_batch(texts)
raise ValueError("Invalid input type. Expected a string or a list of strings.")

def __repr__(self) -> str:
"""Return a string representation of the TokenChunker."""

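Finally, a hedged usage sketch of the updated call path: batch_size and show_progress_bar come from the TokenChunker diff above, while the constructor arguments ("gpt2", chunk_size, chunk_overlap) are assumptions about TokenChunker.__init__ rather than something shown in this commit.

    # Usage sketch; the constructor arguments are assumed, not taken from this diff.
    from chonkie import TokenChunker

    chunker = TokenChunker("gpt2", chunk_size=512, chunk_overlap=0.1)

    # Single text -> List[Chunk]
    single = chunker("one short document")

    # List of texts -> List[List[Chunk]], processed in mini-batches of 8;
    # pass show_progress_bar=False to silence the progress bar.
    batched = chunker(
        ["first document", "second document", "third document"],
        batch_size=8,
        show_progress_bar=True,
    )

Since TokenChunker sets _use_multiprocessing = False in this commit, its batch path stays sequential (mini-batched with trange) rather than going through the Pool-based path added to the base class.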