Use iterator-based approach for chunk size to avoid extra allocations
benbrandt committed Dec 27, 2023
1 parent 12c6cac commit 294c1bd
Showing 6 changed files with 217 additions and 185 deletions.
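
The change in a nutshell: instead of collecting every encoded offset into the `EncodedOffsets` newtype (built from a `Vec`) just to measure a chunk, each sizer now returns a lazy iterator of byte ranges that the new `ChunkSize::from_offsets` consumes directly. A minimal sketch of the before/after pattern (simplified names, not the crate's exact code):

use std::ops::Range;

// Before: every size check materialized a Vec of byte ranges.
fn offsets_eager(chunk: &str) -> Vec<Range<usize>> {
    chunk
        .char_indices()
        .map(|(i, c)| i..(i + c.len_utf8()))
        .collect()
}

// After: the same ranges are produced on demand, so a caller can count
// them (or stop early) without allocating an intermediate Vec.
fn offsets_lazy(chunk: &str) -> impl Iterator<Item = Range<usize>> + '_ {
    chunk.char_indices().map(|(i, c)| i..(i + c.len_utf8()))
}

fn main() {
    assert_eq!(offsets_eager("eé"), vec![0..1, 1..3]);
    assert_eq!(offsets_lazy("eé").count(), 2); // no Vec allocated
}
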
1 change: 0 additions & 1 deletion Cargo.toml
@@ -19,7 +19,6 @@ rustdoc-args = ["--cfg", "docsrs"]
 
 [dependencies]
 auto_enums = "0.8.3"
-derive_more = { version = "0.99.17", default-features = false, features = ["deref", "deref_mut", "from"] }
 either = "1.9.0"
 itertools = "0.12.0"
 once_cell = "1.18.0"
31 changes: 12 additions & 19 deletions src/characters.rs
@@ -1,4 +1,6 @@
-use crate::{ChunkSizer, EncodedOffsets};
+use std::ops::Range;
+
+use crate::{ChunkCapacity, ChunkSize, ChunkSizer};
 
 /// Used for splitting a piece of text into chunks based on the number of
 /// characters in each chunk.
@@ -11,25 +13,16 @@ use crate::{ChunkSizer, EncodedOffsets};
 #[derive(Debug)]
 pub struct Characters;
 
-impl ChunkSizer for Characters {
-    /// Return offsets for each unit of text used to calculate chunk size.
-    /// Should return an exclusive byte range for each element counted.
-    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
-        chunk
-            .char_indices()
-            .map(|(i, c)| i..(i + c.len_utf8()))
-            .collect::<Vec<_>>()
-            .into()
+impl Characters {
+    fn encoded_offsets(chunk: &str) -> impl Iterator<Item = Range<usize>> + '_ {
+        chunk.char_indices().map(|(i, c)| i..(i + c.len_utf8()))
     }
+}
 
+impl ChunkSizer for Characters {
     /// Determine the size of a given chunk to use for validation.
-    ///
-    /// ```
-    /// use text_splitter::{Characters, ChunkSizer};
-    ///
-    /// assert_eq!(Characters.chunk_size("hello"), 5);
-    fn chunk_size(&self, chunk: &str) -> usize {
-        chunk.chars().count()
+    fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize {
+        ChunkSize::from_offsets(Self::encoded_offsets(chunk), capacity)
     }
 }

@@ -39,7 +32,7 @@ mod tests {

     #[test]
     fn returns_offsets() {
-        let offsets = Characters.encoded_offsets("eé");
-        assert_eq!(offsets, vec![0..1, 1..3].into());
+        let offsets = Characters::encoded_offsets("eé").collect::<Vec<_>>();
+        assert_eq!(offsets, vec![0..1, 1..3]);
     }
 }
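
With `chunk_size` now taking a `ChunkCapacity` and returning a `ChunkSize`, the offsets iterator above is consumed by `ChunkSize::from_offsets`. That function's body isn't shown in this commit, so the following is only an illustrative stand-in for how an offsets iterator can be checked against a capacity without any intermediate allocation:

use std::ops::Range;

// Illustrative stand-in only: the crate's real ChunkCapacity/ChunkSize
// types are richer than this single maximum size.
struct MaxUnits(usize);

fn char_offsets(chunk: &str) -> impl Iterator<Item = Range<usize>> + '_ {
    chunk.char_indices().map(|(i, c)| i..(i + c.len_utf8()))
}

// Consume the iterator directly: count the offsets and compare against
// the capacity, never building a Vec.
fn fits(offsets: impl Iterator<Item = Range<usize>>, capacity: &MaxUnits) -> bool {
    offsets.count() <= capacity.0
}

fn main() {
    assert!(fits(char_offsets("hello"), &MaxUnits(5)));
    assert!(!fits(char_offsets("hello, world"), &MaxUnits(5)));
}
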
50 changes: 18 additions & 32 deletions src/huggingface.rs
@@ -1,51 +1,37 @@
+use std::ops::Range;
+
 use tokenizers::Tokenizer;
 
-use crate::{ChunkSizer, EncodedOffsets};
+use crate::{ChunkCapacity, ChunkSize, ChunkSizer};
 
 impl ChunkSizer for Tokenizer {
-    /// Return offsets for each unit of text used to calculate chunk size.
-    /// Should return an exclusive byte range for each element counted.
-    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
-        encoded_offsets(self, chunk)
-    }
-
     /// Returns the number of tokens in a given text after tokenization.
     ///
     /// # Panics
     ///
     /// Will panic if you don't have a byte-level tokenizer and the splitter
     /// encounters text it can't tokenize.
-    fn chunk_size(&self, chunk: &str) -> usize {
-        chunk_size(self, chunk)
+    fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize {
+        ChunkSize::from_offsets(encoded_offsets(self, chunk), capacity)
     }
 }
 
 impl ChunkSizer for &Tokenizer {
-    /// Return offsets for each unit of text used to calculate chunk size.
-    /// Should return an exclusive byte range for each element counted.
-    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
-        encoded_offsets(self, chunk)
-    }
-
     /// Returns the number of tokens in a given text after tokenization.
     ///
     /// # Panics
     ///
     /// Will panic if you don't have a byte-level tokenizer and the splitter
     /// encounters text it can't tokenize.
-    fn chunk_size(&self, chunk: &str) -> usize {
-        chunk_size(self, chunk)
+    fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize {
+        ChunkSize::from_offsets(encoded_offsets(self, chunk), capacity)
     }
 }
 
-fn chunk_size(tokenizer: &Tokenizer, chunk: &str) -> usize {
-    tokenizer
-        .encode(chunk, false)
-        .map(|enc| enc.len())
-        .expect("Unable to tokenize the following string {str}")
-}
-
-fn encoded_offsets(tokenizer: &Tokenizer, chunk: &str) -> EncodedOffsets {
+fn encoded_offsets<'text>(
+    tokenizer: &Tokenizer,
+    chunk: &'text str,
+) -> impl Iterator<Item = Range<usize>> + 'text {
     let encoding = tokenizer
         .encode(chunk, false)
         .expect("Unable to tokenize the following string {chunk}");
@@ -72,7 +58,7 @@ fn encoded_offsets(tokenizer: &Tokenizer, chunk: &str) -> EncodedOffsets {
         }
     }
 
-    offsets.into()
+    offsets.into_iter()
 }
 
 #[cfg(test)]
@@ -81,15 +67,15 @@ mod tests {

     #[test]
     fn returns_offsets() {
-        let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
-        let offsets = tokenizer.encoded_offsets(" An apple a");
-        assert_eq!(offsets, vec![0..3, 3..9, 9..11].into());
+        let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
+        let offsets = encoded_offsets(&tokenizer, " An apple a").collect::<Vec<_>>();
+        assert_eq!(offsets, vec![0..3, 3..9, 9..11]);
     }
 
     #[test]
     fn returns_offsets_handles_prefix() {
-        let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
-        let offsets = tokenizer.encoded_offsets("An apple a");
-        assert_eq!(offsets, vec![0..2, 2..8, 8..10].into());
+        let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
+        let offsets = encoded_offsets(&tokenizer, "An apple a").collect::<Vec<_>>();
+        assert_eq!(offsets, vec![0..2, 2..8, 8..10]);
     }
 }
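
Downstream `ChunkSizer` implementations follow the same shape after this commit: produce an offsets iterator and hand it to `ChunkSize::from_offsets` together with the capacity. A hypothetical custom sizer, assuming `ChunkSizer`, `ChunkSize`, and `ChunkCapacity` are all exported from `text_splitter` (only the trait signature is visible in this diff):

use text_splitter::{ChunkCapacity, ChunkSize, ChunkSizer};

/// Hypothetical sizer that ignores whitespace when measuring a chunk.
struct NonWhitespaceCharacters;

impl ChunkSizer for NonWhitespaceCharacters {
    fn chunk_size(&self, chunk: &str, capacity: &impl ChunkCapacity) -> ChunkSize {
        // Same pattern as the built-in sizers: a lazy iterator of byte
        // ranges, one per counted unit, passed straight to from_offsets.
        let offsets = chunk
            .char_indices()
            .filter(|(_, c)| !c.is_whitespace())
            .map(|(i, c)| i..(i + c.len_utf8()));
        ChunkSize::from_offsets(offsets, capacity)
    }
}
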
