diff --git a/Cargo.toml b/Cargo.toml
index 0e1856fb..f22f01e1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,7 @@ rustdoc-args = ["--cfg", "docsrs"]
 
 [dependencies]
 auto_enums = "0.8.3"
+derive_more = { version = "0.99.17", default-features = false, features = ["deref", "from"] }
 either = "1.9.0"
 itertools = "0.12.0"
 once_cell = "1.18.0"
diff --git a/src/characters.rs b/src/characters.rs
index 2541a3f1..f009fd54 100644
--- a/src/characters.rs
+++ b/src/characters.rs
@@ -15,22 +15,11 @@ impl ChunkSizer for Characters {
     /// Return offsets for each unit of text used to calculate chunk size.
     /// Should return an exclusive byte range for each element counted.
     fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
-        EncodedOffsets::new(
-            chunk
-                .char_indices()
-                .map(|(i, c)| i..(i + c.len_utf8()))
-                .collect(),
-        )
-    }
-
-    /// Determine the size of a given chunk to use for validation.
-    ///
-    /// ```
-    /// use text_splitter::{Characters, ChunkSizer};
-    ///
-    /// assert_eq!(Characters.chunk_size("hello"), 5);
-    fn chunk_size(&self, chunk: &str) -> usize {
-        chunk.chars().count()
+        chunk
+            .char_indices()
+            .map(|(i, c)| i..(i + c.len_utf8()))
+            .collect::<Vec<_>>()
+            .into()
     }
 }
 
@@ -41,6 +30,6 @@ mod tests {
     #[test]
     fn returns_offsets() {
         let offsets = Characters.encoded_offsets("eé");
-        assert_eq!(offsets, EncodedOffsets::new(vec![0..1, 1..3]));
+        assert_eq!(offsets, vec![0..1, 1..3].into());
     }
 }
diff --git a/src/huggingface.rs b/src/huggingface.rs
index 1663e601..f4c92412 100644
--- a/src/huggingface.rs
+++ b/src/huggingface.rs
@@ -8,16 +8,6 @@ impl ChunkSizer for Tokenizer {
     fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
         encoded_offsets(self, chunk)
     }
-
-    /// Returns the number of tokens in a given text after tokenization.
-    ///
-    /// # Panics
-    ///
-    /// Will panic if you don't have a byte-level tokenizer and the splitter
-    /// encounters text it can't tokenize.
-    fn chunk_size(&self, chunk: &str) -> usize {
-        chunk_size(self, chunk)
-    }
 }
 
 impl ChunkSizer for &Tokenizer {
@@ -26,16 +16,6 @@
     fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
         encoded_offsets(self, chunk)
     }
-
-    /// Returns the number of tokens in a given text after tokenization.
-    ///
-    /// # Panics
-    ///
-    /// Will panic if you don't have a byte-level tokenizer and the splitter
-    /// encounters text it can't tokenize.
-    fn chunk_size(&self, chunk: &str) -> usize {
-        chunk_size(self, chunk)
-    }
 }
 
 fn encoded_offsets(tokenizer: &Tokenizer, chunk: &str) -> EncodedOffsets {
@@ -64,14 +44,8 @@ fn encoded_offsets(tokenizer: &Tokenizer, chunk: &str) -> EncodedOffsets {
             range.end -= 1;
         }
     }
-    EncodedOffsets::new(offsets)
-}
-
-fn chunk_size(tokenizer: &Tokenizer, chunk: &str) -> usize {
-    tokenizer
-        .encode(chunk, false)
-        .map(|enc| enc.len())
-        .expect("Unable to tokenize the following string {chunk}")
+
+    offsets.into()
 }
 
 #[cfg(test)]
@@ -82,13 +56,13 @@ mod tests {
     fn returns_offsets() {
         let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
         let offsets = tokenizer.encoded_offsets(" An apple a");
-        assert_eq!(offsets, EncodedOffsets::new(vec![0..3, 3..9, 9..11]));
+        assert_eq!(offsets, vec![0..3, 3..9, 9..11].into());
     }
 
     #[test]
     fn returns_offsets_handles_prefix() {
         let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
         let offsets = tokenizer.encoded_offsets("An apple a");
-        assert_eq!(offsets, EncodedOffsets::new(vec![0..2, 2..8, 8..10]));
+        assert_eq!(offsets, vec![0..2, 2..8, 8..10].into());
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 3c142245..b3ffc511 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -138,6 +138,7 @@ use core::{
 };
 
 use auto_enums::auto_enum;
+use derive_more::{Deref, From};
 use either::Either;
 use itertools::Itertools;
 use once_cell::sync::Lazy;
@@ -156,27 +157,14 @@ pub use characters::Characters;
 /// a [`ChunkSizer`].
 ///
 /// Each offset is an exclusive range.
-#[derive(Debug, Eq, PartialEq)]
-pub struct EncodedOffsets {
-    offsets: Vec<Range<usize>>,
-}
-
-impl EncodedOffsets {
-    /// Generate new encoded offsets with the offsets for the given encoding implementation.
-    #[must_use]
-    pub fn new(offsets: Vec<Range<usize>>) -> Self {
-        Self { offsets }
-    }
-}
+#[derive(Debug, Default, Deref, Eq, From, PartialEq)]
+pub struct EncodedOffsets(Vec<Range<usize>>);
 
 /// Determines the size of a given chunk.
 pub trait ChunkSizer {
     /// Return offsets for each unit of text used to calculate chunk size.
     /// Should return an exclusive byte range for each element counted.
     fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets;
-
-    /// Determine the size of a given chunk to use for validation
-    fn chunk_size(&self, chunk: &str) -> usize;
 }
 
 /// Describes the largest valid chunk size(s) that can be generated.
@@ -207,8 +195,9 @@ pub trait ChunkCapacity {
     /// - `Ordering::Less` indicates more could be added
     /// - `Ordering::Equal` indicates the chunk is within the capacity range
     /// - `Ordering::Greater` indicates the chunk is larger than the capacity
-    fn fits(&self, chunk_size: usize) -> Ordering {
+    fn fits(&self, offsets: &EncodedOffsets) -> Ordering {
         let end = self.end();
+        let chunk_size = offsets.len();
 
         match self.start() {
             Some(start) => {
@@ -578,13 +567,14 @@ where
     }
 
     /// Is the given text within the chunk size?
-    fn check_capacity(&self, chunk: &str) -> (usize, Ordering) {
-        let chunk_size = self.chunk_sizer.chunk_size(if self.trim_chunks {
+    fn check_capacity(&self, chunk: &str) -> (EncodedOffsets, Ordering) {
+        let offsets = self.chunk_sizer.encoded_offsets(if self.trim_chunks {
             chunk.trim()
         } else {
             chunk
         });
-        (chunk_size, self.chunk_capacity.fits(chunk_size))
+        let fits = self.chunk_capacity.fits(&offsets);
+        (offsets, fits)
     }
 
     /// Generate the next chunk, applying trimming settings.
@@ -599,20 +589,20 @@ where
         let mut end = self.cursor;
 
         // Track change in chunk size
-        let (mut chunk_size, mut fits) = (0, Ordering::Less);
+        let (mut encoded_offsets, mut fits) = (EncodedOffsets::default(), Ordering::Less);
         // Consume as many as we can fit
         for (offset, str) in self.next_section()?.split() {
             let chunk = self.text.get(start..offset + str.len())?;
 
             // Cache prev chunk size before replacing
-            let (prev_chunk_size, prev_fits) = (chunk_size, fits);
-            (chunk_size, fits) = self.check_capacity(chunk);
+            let (prev_encoded_offsets, prev_fits) = (encoded_offsets, fits);
+            (encoded_offsets, fits) = self.check_capacity(chunk);
             // If we are now beyond the first item, and it is too large, end here.
             if start != end
                 && (fits.is_gt()
                     // For tokenizers, it is possible that the next string still may be the same amount of tokens.
                    // Check if both are equal, but we added to the chunk size, which we don't want for ranges.
-                    || (fits.is_eq() && prev_fits.is_eq() && chunk_size > prev_chunk_size))
+                    || (fits.is_eq() && prev_fits.is_eq() && encoded_offsets.len() > prev_encoded_offsets.len()))
             {
                 break;
             }
@@ -853,19 +843,14 @@ mod tests {
     struct Str;
 
     impl ChunkSizer for Str {
-        fn chunk_size(&self, chunk: &str) -> usize {
-            chunk.len()
-        }
-
         fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
-            EncodedOffsets::new(
-                chunk
-                    .as_bytes()
-                    .iter()
-                    .enumerate()
-                    .map(|(i, _)| (i..i))
-                    .collect(),
-            )
+            chunk
+                .as_bytes()
+                .iter()
+                .enumerate()
+                .map(|(i, _)| (i..i))
+                .collect::<Vec<_>>()
+                .into()
         }
     }
 
@@ -1028,66 +1013,70 @@ mod tests {
     #[test]
     fn check_chunk_capacity() {
         let chunk = "12345";
+        let offsets = Characters.encoded_offsets(chunk);
 
-        assert_eq!(4.fits(Characters.chunk_size(chunk)), Ordering::Greater);
-        assert_eq!(5.fits(Characters.chunk_size(chunk)), Ordering::Equal);
-        assert_eq!(6.fits(Characters.chunk_size(chunk)), Ordering::Less);
+        assert_eq!(4.fits(&offsets), Ordering::Greater);
+        assert_eq!(5.fits(&offsets), Ordering::Equal);
+        assert_eq!(6.fits(&offsets), Ordering::Less);
     }
 
     #[test]
     fn check_chunk_capacity_for_range() {
         let chunk = "12345";
+        let offsets = Characters.encoded_offsets(chunk);
 
-        assert_eq!((0..0).fits(Characters.chunk_size(chunk)), Ordering::Greater);
-        assert_eq!((0..5).fits(Characters.chunk_size(chunk)), Ordering::Greater);
-        assert_eq!((5..6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
-        assert_eq!((6..100).fits(Characters.chunk_size(chunk)), Ordering::Less);
+        assert_eq!((0..0).fits(&offsets), Ordering::Greater);
+        assert_eq!((0..5).fits(&offsets), Ordering::Greater);
+        assert_eq!((5..6).fits(&offsets), Ordering::Equal);
+        assert_eq!((6..100).fits(&offsets), Ordering::Less);
     }
 
     #[test]
    fn check_chunk_capacity_for_range_from() {
         let chunk = "12345";
+        let offsets = Characters.encoded_offsets(chunk);
 
-        assert_eq!((0..).fits(Characters.chunk_size(chunk)), Ordering::Equal);
-        assert_eq!((5..).fits(Characters.chunk_size(chunk)), Ordering::Equal);
-        assert_eq!((6..).fits(Characters.chunk_size(chunk)), Ordering::Less);
+        assert_eq!((0..).fits(&offsets), Ordering::Equal);
+        assert_eq!((5..).fits(&offsets), Ordering::Equal);
+        assert_eq!((6..).fits(&offsets), Ordering::Less);
     }
 
     #[test]
     fn check_chunk_capacity_for_range_full() {
         let chunk = "12345";
+        let offsets = Characters.encoded_offsets(chunk);
 
-        assert_eq!((..).fits(Characters.chunk_size(chunk)), Ordering::Equal);
+        assert_eq!((..).fits(&offsets), Ordering::Equal);
     }
 
     #[test]
     fn check_chunk_capacity_for_range_inclusive() {
         let chunk = "12345";
+        let offsets = Characters.encoded_offsets(chunk);
 
-        assert_eq!(
-            (0..=4).fits(Characters.chunk_size(chunk)),
-            Ordering::Greater
-        );
-        assert_eq!((5..=6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
-        assert_eq!((4..=5).fits(Characters.chunk_size(chunk)), Ordering::Equal);
-        assert_eq!((6..=100).fits(Characters.chunk_size(chunk)), Ordering::Less);
+        assert_eq!((0..=4).fits(&offsets), Ordering::Greater);
+        assert_eq!((5..=6).fits(&offsets), Ordering::Equal);
+        assert_eq!((4..=5).fits(&offsets), Ordering::Equal);
+        assert_eq!((6..=100).fits(&offsets), Ordering::Less);
     }
 
     #[test]
     fn check_chunk_capacity_for_range_to() {
         let chunk = "12345";
+        let offsets = Characters.encoded_offsets(chunk);
 
-        assert_eq!((..0).fits(Characters.chunk_size(chunk)), Ordering::Greater);
-        assert_eq!((..5).fits(Characters.chunk_size(chunk)), Ordering::Greater);
-        assert_eq!((..6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
+        assert_eq!((..0).fits(&offsets), Ordering::Greater);
+        assert_eq!((..5).fits(&offsets), Ordering::Greater);
+        assert_eq!((..6).fits(&offsets), Ordering::Equal);
     }
 
     #[test]
     fn check_chunk_capacity_for_range_to_inclusive() {
         let chunk = "12345";
+        let offsets = Characters.encoded_offsets(chunk);
 
-        assert_eq!((..=4).fits(Characters.chunk_size(chunk)), Ordering::Greater);
-        assert_eq!((..=5).fits(Characters.chunk_size(chunk)), Ordering::Equal);
-        assert_eq!((..=6).fits(Characters.chunk_size(chunk)), Ordering::Equal);
+        assert_eq!((..=4).fits(&offsets), Ordering::Greater);
+        assert_eq!((..=5).fits(&offsets), Ordering::Equal);
+        assert_eq!((..=6).fits(&offsets), Ordering::Equal);
     }
 }
diff --git a/src/tiktoken.rs b/src/tiktoken.rs
index 32de733e..7e708b8a 100644
--- a/src/tiktoken.rs
+++ b/src/tiktoken.rs
@@ -8,11 +8,6 @@ impl ChunkSizer for CoreBPE {
     fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
         encoded_offsets(self, chunk)
     }
-
-    /// Returns the number of tokens in a given text after tokenization.
-    fn chunk_size(&self, text: &str) -> usize {
-        chunk_size(self, text)
-    }
 }
 
 impl ChunkSizer for &CoreBPE {
@@ -21,11 +16,6 @@
     fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
         encoded_offsets(self, chunk)
     }
-
-    /// Returns the number of tokens in a given text after tokenization.
-    fn chunk_size(&self, text: &str) -> usize {
-        chunk_size(self, text)
-    }
 }
 
 fn encoded_offsets(bpe: &CoreBPE, chunk: &str) -> EncodedOffsets {
@@ -38,11 +28,7 @@ fn encoded_offsets(bpe: &CoreBPE, chunk: &str) -> EncodedOffsets {
         *offset = end;
         Some(item)
     });
-    EncodedOffsets::new(decoded.collect())
-}
-
-fn chunk_size(bpe: &CoreBPE, text: &str) -> usize {
-    bpe.encode_ordinary(text).len()
+    decoded.collect::<Vec<_>>().into()
 }
 
 #[cfg(test)]
@@ -54,6 +40,6 @@ mod tests {
     #[test]
     fn returns_offsets() {
         let offsets = cl100k_base().unwrap().encoded_offsets("An apple a");
-        assert_eq!(offsets, EncodedOffsets::new(vec![0..2, 2..8, 8..10]));
+        assert_eq!(offsets, vec![0..2, 2..8, 8..10].into());
     }
 }
diff --git a/tests/text_splitter_snapshots.rs b/tests/text_splitter_snapshots.rs
index 8d8e2eb7..5f1d4ca1 100644
--- a/tests/text_splitter_snapshots.rs
+++ b/tests/text_splitter_snapshots.rs
@@ -17,7 +17,7 @@ fn characters_default() {
     assert_eq!(chunks.join(""), text);
 
     for chunk in chunks.iter() {
-        assert_le!(Characters.chunk_size(chunk), chunk_size);
+        assert_le!(Characters.encoded_offsets(chunk).len(), chunk_size);
     }
     insta::assert_yaml_snapshot!(chunks);
 }
@@ -34,7 +34,7 @@ fn characters_trim() {
     let chunks = splitter.chunks(&text, chunk_size).collect::<Vec<_>>();
 
     for chunk in chunks.iter() {
-        assert_le!(Characters.chunk_size(chunk), chunk_size);
+        assert_le!(Characters.encoded_offsets(chunk).len(), chunk_size);
     }
     insta::assert_yaml_snapshot!(chunks);
 }
@@ -52,7 +52,7 @@ fn characters_range() {
     assert_eq!(chunks.join(""), text);
 
     for chunk in chunks.iter() {
-        assert_le!(Characters.chunk_size(chunk), range.end());
+        assert_le!(Characters.encoded_offsets(chunk).len(), range.end());
     }
     insta::assert_yaml_snapshot!(chunks);
 }
@@ -69,7 +69,7 @@ fn characters_range_trim() {
     let chunks = splitter.chunks(&text, range.clone()).collect::<Vec<_>>();
 
     for chunk in chunks.iter() {
-        assert_le!(Characters.chunk_size(chunk), range.end());
+        assert_le!(Characters.encoded_offsets(chunk).len(), range.end());
     }
     insta::assert_yaml_snapshot!(chunks);
 }
@@ -90,7 +90,10 @@ fn huggingface_default() {
     assert_eq!(chunks.join(""), text);
 
     for chunk in chunks.iter() {
-        assert_le!(HUGGINGFACE_TOKENIZER.chunk_size(chunk), chunk_size);
+        assert_le!(
+            HUGGINGFACE_TOKENIZER.encoded_offsets(chunk).len(),
+            chunk_size
+        );
     }
     insta::assert_yaml_snapshot!(chunks);
 }
@@ -107,7 +110,10 @@ fn huggingface_trim() {
     let chunks = splitter.chunks(&text, chunk_size).collect::<Vec<_>>();
 
     for chunk in chunks.iter() {
-        assert_le!(HUGGINGFACE_TOKENIZER.chunk_size(chunk), chunk_size);
+        assert_le!(
+            HUGGINGFACE_TOKENIZER.encoded_offsets(chunk).len(),
+            chunk_size
+        );
     }
     insta::assert_yaml_snapshot!(chunks);
 }
@@ -127,7 +133,7 @@ fn tiktoken_default() {
     assert_eq!(chunks.join(""), text);
 
     for chunk in chunks.iter() {
-        assert_le!(TIKTOKEN_TOKENIZER.chunk_size(chunk), chunk_size);
+        assert_le!(TIKTOKEN_TOKENIZER.encoded_offsets(chunk).len(), chunk_size);
     }
     insta::assert_yaml_snapshot!(chunks);
 }
@@ -144,7 +150,7 @@ fn tiktoken_trim() {
     let chunks = splitter.chunks(&text, chunk_size).collect::<Vec<_>>();
 
     for chunk in chunks.iter() {
-        assert_le!(TIKTOKEN_TOKENIZER.chunk_size(chunk), chunk_size);
+        assert_le!(TIKTOKEN_TOKENIZER.encoded_offsets(chunk).len(), chunk_size);
     }
     insta::assert_yaml_snapshot!(chunks);
 }
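
The following is a minimal sketch, not part of the patch, of how a caller or a custom sizer works once `ChunkSizer::chunk_size` is gone: the chunk size is simply the number of encoded offsets, since the `Deref` derive lets `EncodedOffsets` be used like a `Vec<Range<usize>>` and the `From` derive lets one be built with `.into()`. The `Words` sizer below is hypothetical and purely illustrative; only `Characters`, `ChunkSizer`, and `EncodedOffsets` come from the refactored crate above.

```rust
use std::ops::Range;

use text_splitter::{Characters, ChunkSizer, EncodedOffsets};

// Hypothetical sizer that counts whitespace-separated words (not part of the crate).
struct Words;

impl ChunkSizer for Words {
    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
        chunk
            .split_whitespace()
            .map(|word| {
                // Exclusive byte range of each word within the chunk.
                let start = word.as_ptr() as usize - chunk.as_ptr() as usize;
                start..start + word.len()
            })
            .collect::<Vec<Range<usize>>>()
            .into()
    }
}

fn main() {
    // Chunk size is now derived from the offsets themselves.
    let offsets = Characters.encoded_offsets("eé");
    assert_eq!(offsets.len(), 2); // two characters, even though "é" is two bytes

    let offsets = Words.encoded_offsets("an apple a day");
    assert_eq!(offsets.len(), 4);
    assert_eq!(*offsets, vec![0..2, 3..8, 9..10, 11..14]);
}
```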