Commit
Convert entirely to encoded offsets
benbrandt committed Dec 16, 2023
1 parent 5ab3f93 commit fb103d3
Showing 6 changed files with 74 additions and 129 deletions.
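
In short, the ChunkSizer trait no longer has a separate chunk_size method: every sizer returns the byte ranges of its encoded units, and the chunk size falls out as the number of ranges. A minimal sketch of that idea, using simplified stand-in types rather than the crate's actual definitions:

// Sketch only: stand-in definitions, not the crate's real types.
use std::ops::Range;

struct EncodedOffsets(Vec<Range<usize>>);

trait ChunkSizer {
    // One exclusive byte range per counted unit (char, token, ...).
    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets;
}

// The size of a chunk is now derived instead of implemented per sizer.
fn chunk_size(sizer: &impl ChunkSizer, chunk: &str) -> usize {
    sizer.encoded_offsets(chunk).0.len()
}
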
1 change: 1 addition & 0 deletions Cargo.toml
@@ -19,6 +19,7 @@ rustdoc-args = ["--cfg", "docsrs"]

[dependencies]
auto_enums = "0.8.3"
derive_more = { version = "0.99.17", default-features = false, features = ["deref", "from"] }
either = "1.9.0"
itertools = "0.12.0"
once_cell = "1.18.0"
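The new dependency backs the reworked EncodedOffsets newtype in src/lib.rs below. A small illustration of what the Deref and From derives buy, shown on a standalone copy of the type rather than the crate's own file:

use derive_more::{Deref, From};
use std::ops::Range;

#[derive(Debug, Default, Deref, Eq, From, PartialEq)]
pub struct EncodedOffsets(Vec<Range<usize>>);

fn demo() {
    // From: build the newtype straight from a Vec via .into()
    let offsets: EncodedOffsets = vec![0..1, 1..3].into();
    // Deref: Vec methods such as len() are available directly
    assert_eq!(offsets.len(), 2);
}
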
23 changes: 6 additions & 17 deletions src/characters.rs
@@ -15,22 +15,11 @@ impl ChunkSizer for Characters {
/// Return offsets for each unit of text used to calculate chunk size.
/// Should return an exclusive byte range for each element counted.
fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
EncodedOffsets::new(
chunk
.char_indices()
.map(|(i, c)| i..(i + c.len_utf8()))
.collect(),
)
}

/// Determine the size of a given chunk to use for validation.
///
/// ```
/// use text_splitter::{Characters, ChunkSizer};
///
/// assert_eq!(Characters.chunk_size("hello"), 5);
fn chunk_size(&self, chunk: &str) -> usize {
chunk.chars().count()
chunk
.char_indices()
.map(|(i, c)| i..(i + c.len_utf8()))
.collect::<Vec<_>>()
.into()
}
}

@@ -41,6 +30,6 @@ mod tests {
#[test]
fn returns_offsets() {
let offsets = Characters.encoded_offsets("eé");
assert_eq!(offsets, EncodedOffsets::new(vec![0..1, 1..3]));
assert_eq!(offsets, vec![0..1, 1..3].into());
}
}
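
Assuming the public API stays as shown in this diff (the Characters sizer and ChunkSizer trait exported from text_splitter), character counting now reads as "number of per-character byte ranges":

use text_splitter::{Characters, ChunkSizer};

fn main() {
    // 'e' is 1 byte in UTF-8 and 'é' is 2 bytes, so the ranges are 0..1 and 1..3.
    let offsets = Characters.encoded_offsets("eé");
    assert_eq!(offsets, vec![0..1, 1..3].into());
    // The old chunk_size("eé") == 2 is now simply the number of ranges.
    assert_eq!(offsets.len(), 2);
}
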
32 changes: 3 additions & 29 deletions src/huggingface.rs
@@ -8,16 +8,6 @@ impl ChunkSizer for Tokenizer {
fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
encoded_offsets(self, chunk)
}

/// Returns the number of tokens in a given text after tokenization.
///
/// # Panics
///
/// Will panic if you don't have a byte-level tokenizer and the splitter
/// encounters text it can't tokenize.
fn chunk_size(&self, chunk: &str) -> usize {
chunk_size(self, chunk)
}
}

impl ChunkSizer for &Tokenizer {
@@ -26,16 +16,6 @@ impl ChunkSizer for &Tokenizer {
fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
encoded_offsets(self, chunk)
}

/// Returns the number of tokens in a given text after tokenization.
///
/// # Panics
///
/// Will panic if you don't have a byte-level tokenizer and the splitter
/// encounters text it can't tokenize.
fn chunk_size(&self, chunk: &str) -> usize {
chunk_size(self, chunk)
}
}

fn encoded_offsets(tokenizer: &Tokenizer, chunk: &str) -> EncodedOffsets {
@@ -64,14 +44,8 @@ fn encoded_offsets(tokenizer: &Tokenizer, chunk: &str) -> EncodedOffsets {
range.end -= 1;
}
}
EncodedOffsets::new(offsets)
}

fn chunk_size(tokenizer: &Tokenizer, chunk: &str) -> usize {
tokenizer
.encode(chunk, false)
.map(|enc| enc.len())
.expect("Unable to tokenize the following string {chunk}")
offsets.into()
}

#[cfg(test)]
@@ -82,13 +56,13 @@ mod tests {
fn returns_offsets() {
let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
let offsets = tokenizer.encoded_offsets(" An apple a");
assert_eq!(offsets, EncodedOffsets::new(vec![0..3, 3..9, 9..11]));
assert_eq!(offsets, vec![0..3, 3..9, 9..11].into());
}

#[test]
fn returns_offsets_handles_prefix() {
let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
let offsets = tokenizer.encoded_offsets("An apple a");
assert_eq!(offsets, EncodedOffsets::new(vec![0..2, 2..8, 8..10]));
assert_eq!(offsets, vec![0..2, 2..8, 8..10].into());
}
}
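
A hedged usage sketch mirroring the tests above; it assumes the tokenizers support is enabled in text_splitter and that the tokenizers crate can fetch bert-base-cased (its from_pretrained helper needs the http feature):

use text_splitter::ChunkSizer;
use tokenizers::Tokenizer;

fn main() {
    let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
    // One byte range per token, so the range count is what chunk_size used to return.
    let offsets = tokenizer.encoded_offsets(" An apple a");
    assert_eq!(offsets.len(), 3);
}
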
107 changes: 48 additions & 59 deletions src/lib.rs
@@ -138,6 +138,7 @@ use core::{
};

use auto_enums::auto_enum;
use derive_more::{Deref, From};
use either::Either;
use itertools::Itertools;
use once_cell::sync::Lazy;
@@ -156,27 +157,14 @@ pub use characters::Characters;
/// a [`ChunkSizer`].
///
/// Each offset is an exclusive range.
#[derive(Debug, Eq, PartialEq)]
pub struct EncodedOffsets {
offsets: Vec<Range<usize>>,
}

impl EncodedOffsets {
/// Generate new encoded offsets with the offsets for the given encoding implementation.
#[must_use]
pub fn new(offsets: Vec<Range<usize>>) -> Self {
Self { offsets }
}
}
#[derive(Debug, Default, Deref, Eq, From, PartialEq)]
pub struct EncodedOffsets(Vec<Range<usize>>);

/// Determines the size of a given chunk.
pub trait ChunkSizer {
/// Return offsets for each unit of text used to calculate chunk size.
/// Should return an exclusive byte range for each element counted.
fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets;

/// Determine the size of a given chunk to use for validation
fn chunk_size(&self, chunk: &str) -> usize;
}

/// Describes the largest valid chunk size(s) that can be generated.
@@ -207,8 +195,9 @@ pub trait ChunkCapacity {
/// - `Ordering::Less` indicates more could be added
/// - `Ordering::Equal` indicates the chunk is within the capacity range
/// - `Ordering::Greater` indicates the chunk is larger than the capacity
fn fits(&self, chunk_size: usize) -> Ordering {
fn fits(&self, offsets: &EncodedOffsets) -> Ordering {
let end = self.end();
let chunk_size = offsets.len();

match self.start() {
Some(start) => {
@@ -578,13 +567,14 @@ where
}

/// Is the given text within the chunk size?
fn check_capacity(&self, chunk: &str) -> (usize, Ordering) {
let chunk_size = self.chunk_sizer.chunk_size(if self.trim_chunks {
fn check_capacity(&self, chunk: &str) -> (EncodedOffsets, Ordering) {
let offsets = self.chunk_sizer.encoded_offsets(if self.trim_chunks {
chunk.trim()
} else {
chunk
});
(chunk_size, self.chunk_capacity.fits(chunk_size))
let fits = self.chunk_capacity.fits(&offsets);
(offsets, fits)
}

/// Generate the next chunk, applying trimming settings.
@@ -599,20 +589,20 @@

let mut end = self.cursor;
// Track change in chunk size
let (mut chunk_size, mut fits) = (0, Ordering::Less);
let (mut encoded_offsets, mut fits) = (EncodedOffsets::default(), Ordering::Less);
// Consume as many as we can fit
for (offset, str) in self.next_section()?.split() {
let chunk = self.text.get(start..offset + str.len())?;
// Cache prev chunk size before replacing
let (prev_chunk_size, prev_fits) = (chunk_size, fits);
(chunk_size, fits) = self.check_capacity(chunk);
let (prev_encoded_offsets, prev_fits) = (encoded_offsets, fits);
(encoded_offsets, fits) = self.check_capacity(chunk);

// If we are now beyond the first item, and it is too large, end here.
if start != end
&& (fits.is_gt()
// For tokenizers, it is possible that the next string still may be the same amount of tokens.
// Check if both are equal, but we added to the chunk size, which we don't want for ranges.
|| (fits.is_eq() && prev_fits.is_eq() && chunk_size > prev_chunk_size))
|| (fits.is_eq() && prev_fits.is_eq() && encoded_offsets.len() > prev_encoded_offsets.len()))
{
break;
}
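
Read as a whole, the loop above is a greedy fill: keep extending the candidate chunk while the capacity check does not report overflow, and stop early when the verdict stays Equal but the encoded size still grows (the tokenizer range case mentioned in the comment). A simplified sketch of that control flow with stand-in types, not the crate's internals:

use std::cmp::Ordering;

// sections: candidate chunk end offsets; check: size and capacity verdict for a candidate.
fn greedy_end(sections: &[usize], check: impl Fn(usize) -> (usize, Ordering)) -> Option<usize> {
    let mut end = None;
    let (mut size, mut fits) = (0, Ordering::Less);
    for &candidate in sections {
        let (prev_size, prev_fits) = (size, fits);
        (size, fits) = check(candidate);
        // Past the first section, stop when the chunk no longer fits, or when
        // both verdicts are Equal but the encoded size still grew.
        if end.is_some()
            && (fits.is_gt() || (fits.is_eq() && prev_fits.is_eq() && size > prev_size))
        {
            break;
        }
        end = Some(candidate);
    }
    end
}
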
@@ -853,19 +843,14 @@ mod tests {
struct Str;

impl ChunkSizer for Str {
fn chunk_size(&self, chunk: &str) -> usize {
chunk.len()
}

fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
EncodedOffsets::new(
chunk
.as_bytes()
.iter()
.enumerate()
.map(|(i, _)| (i..i))
.collect(),
)
chunk
.as_bytes()
.iter()
.enumerate()
.map(|(i, _)| (i..i))
.collect::<Vec<_>>()
.into()
}
}

@@ -1028,66 +1013,70 @@ mod tests {
#[test]
fn check_chunk_capacity() {
let chunk = "12345";
let offsets = Characters.encoded_offsets(chunk);

assert_eq!(4.fits(Characters.chunk_size(chunk)), Ordering::Greater);
assert_eq!(5.fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!(6.fits(Characters.chunk_size(chunk)), Ordering::Less);
assert_eq!(4.fits(&offsets), Ordering::Greater);
assert_eq!(5.fits(&offsets), Ordering::Equal);
assert_eq!(6.fits(&offsets), Ordering::Less);
}

#[test]
fn check_chunk_capacity_for_range() {
let chunk = "12345";
let offsets = Characters.encoded_offsets(chunk);

assert_eq!((0..0).fits(Characters.chunk_size(chunk)), Ordering::Greater);
assert_eq!((0..5).fits(Characters.chunk_size(chunk)), Ordering::Greater);
assert_eq!((5..6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((6..100).fits(Characters.chunk_size(chunk)), Ordering::Less);
assert_eq!((0..0).fits(&offsets), Ordering::Greater);
assert_eq!((0..5).fits(&offsets), Ordering::Greater);
assert_eq!((5..6).fits(&offsets), Ordering::Equal);
assert_eq!((6..100).fits(&offsets), Ordering::Less);
}

#[test]
fn check_chunk_capacity_for_range_from() {
let chunk = "12345";
let offsets = Characters.encoded_offsets(chunk);

assert_eq!((0..).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((5..).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((6..).fits(Characters.chunk_size(chunk)), Ordering::Less);
assert_eq!((0..).fits(&offsets), Ordering::Equal);
assert_eq!((5..).fits(&offsets), Ordering::Equal);
assert_eq!((6..).fits(&offsets), Ordering::Less);
}

#[test]
fn check_chunk_capacity_for_range_full() {
let chunk = "12345";
let offsets = Characters.encoded_offsets(chunk);

assert_eq!((..).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((..).fits(&offsets), Ordering::Equal);
}

#[test]
fn check_chunk_capacity_for_range_inclusive() {
let chunk = "12345";
let offsets = Characters.encoded_offsets(chunk);

assert_eq!(
(0..=4).fits(Characters.chunk_size(chunk)),
Ordering::Greater
);
assert_eq!((5..=6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((4..=5).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((6..=100).fits(Characters.chunk_size(chunk)), Ordering::Less);
assert_eq!((0..=4).fits(&offsets), Ordering::Greater);
assert_eq!((5..=6).fits(&offsets), Ordering::Equal);
assert_eq!((4..=5).fits(&offsets), Ordering::Equal);
assert_eq!((6..=100).fits(&offsets), Ordering::Less);
}

#[test]
fn check_chunk_capacity_for_range_to() {
let chunk = "12345";
let offsets = Characters.encoded_offsets(chunk);

assert_eq!((..0).fits(Characters.chunk_size(chunk)), Ordering::Greater);
assert_eq!((..5).fits(Characters.chunk_size(chunk)), Ordering::Greater);
assert_eq!((..6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((..0).fits(&offsets), Ordering::Greater);
assert_eq!((..5).fits(&offsets), Ordering::Greater);
assert_eq!((..6).fits(&offsets), Ordering::Equal);
}

#[test]
fn check_chunk_capacity_for_range_to_inclusive() {
let chunk = "12345";
let offsets = Characters.encoded_offsets(chunk);

assert_eq!((..=4).fits(Characters.chunk_size(chunk)), Ordering::Greater);
assert_eq!((..=5).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((..=6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
assert_eq!((..=4).fits(&offsets), Ordering::Greater);
assert_eq!((..=5).fits(&offsets), Ordering::Equal);
assert_eq!((..=6).fits(&offsets), Ordering::Equal);
}
}
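
For a plain half-open range, the capacity tests above follow one rule: too few encoded units means Less (more could still be added), a count inside the range means Equal, and reaching or passing the end means Greater. A sketch of that rule under those assumptions (the crate's real ChunkCapacity trait also covers plain usize capacities, inclusive ends, and the other range types):

use std::cmp::Ordering;
use std::ops::Range;

// Simplified: capacity as a half-open range, size as the number of encoded offsets.
fn fits(capacity: &Range<usize>, chunk_size: usize) -> Ordering {
    if chunk_size < capacity.start {
        Ordering::Less // more could still be added
    } else if chunk_size < capacity.end {
        Ordering::Equal // within the capacity range
    } else {
        Ordering::Greater // larger than the capacity
    }
}

// Matches the tests above for "12345" (5 characters):
// fits(&(0..5), 5) == Greater, fits(&(5..6), 5) == Equal, fits(&(6..100), 5) == Less.
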
18 changes: 2 additions & 16 deletions src/tiktoken.rs
@@ -8,11 +8,6 @@ impl ChunkSizer for CoreBPE {
fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
encoded_offsets(self, chunk)
}

/// Returns the number of tokens in a given text after tokenization.
fn chunk_size(&self, text: &str) -> usize {
chunk_size(self, text)
}
}

impl ChunkSizer for &CoreBPE {
@@ -21,11 +16,6 @@ impl ChunkSizer for &CoreBPE {
fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
encoded_offsets(self, chunk)
}

/// Returns the number of tokens in a given text after tokenization.
fn chunk_size(&self, text: &str) -> usize {
chunk_size(self, text)
}
}

fn encoded_offsets(bpe: &CoreBPE, chunk: &str) -> EncodedOffsets {
@@ -38,11 +28,7 @@ fn encoded_offsets(bpe: &CoreBPE, chunk: &str) -> EncodedOffsets {
*offset = end;
Some(item)
});
EncodedOffsets::new(decoded.collect())
}

fn chunk_size(bpe: &CoreBPE, text: &str) -> usize {
bpe.encode_ordinary(text).len()
decoded.collect::<Vec<_>>().into()
}

#[cfg(test)]
@@ -54,6 +40,6 @@ mod tests {
#[test]
fn returns_offsets() {
let offsets = cl100k_base().unwrap().encoded_offsets("An apple a");
assert_eq!(offsets, EncodedOffsets::new(vec![0..2, 2..8, 8..10]));
assert_eq!(offsets, vec![0..2, 2..8, 8..10].into());
}
}
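
And the equivalent hedged sketch for the tiktoken side, mirroring the test above (assumes the tiktoken-rs integration is enabled in text_splitter):

use text_splitter::ChunkSizer;
use tiktoken_rs::cl100k_base;

fn main() {
    let bpe = cl100k_base().unwrap();
    // cl100k_base splits "An apple a" into 3 tokens: "An", " apple", " a".
    let offsets = bpe.encoded_offsets("An apple a");
    assert_eq!(offsets, vec![0..2, 2..8, 8..10].into());
    assert_eq!(offsets.len(), 3);
}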