diff --git a/src/characters.rs b/src/characters.rs
index 15efdb61..2541a3f1 100644
--- a/src/characters.rs
+++ b/src/characters.rs
@@ -1,4 +1,4 @@
-use crate::ChunkSizer;
+use crate::{ChunkSizer, EncodedOffsets};
 
 /// Used for splitting a piece of text into chunks based on the number of
 /// characters in each chunk.
@@ -12,6 +12,17 @@ use crate::ChunkSizer;
 pub struct Characters;
 
 impl ChunkSizer for Characters {
+    /// Return offsets for each unit of text used to calculate chunk size.
+    /// Should return an exclusive byte range for each element counted.
+    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
+        EncodedOffsets::new(
+            chunk
+                .char_indices()
+                .map(|(i, c)| i..(i + c.len_utf8()))
+                .collect(),
+        )
+    }
+
     /// Determine the size of a given chunk to use for validation.
     ///
     /// ```
@@ -22,3 +33,14 @@ impl ChunkSizer for Characters {
         chunk.chars().count()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn returns_offsets() {
+        let offsets = Characters.encoded_offsets("eé");
+        assert_eq!(offsets, EncodedOffsets::new(vec![0..1, 1..3]));
+    }
+}
diff --git a/src/huggingface.rs b/src/huggingface.rs
index 6e06f906..1663e601 100644
--- a/src/huggingface.rs
+++ b/src/huggingface.rs
@@ -1,8 +1,14 @@
 use tokenizers::Tokenizer;
 
-use crate::ChunkSizer;
+use crate::{ChunkSizer, EncodedOffsets};
 
 impl ChunkSizer for Tokenizer {
+    /// Return offsets for each unit of text used to calculate chunk size.
+    /// Should return an exclusive byte range for each element counted.
+    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
+        encoded_offsets(self, chunk)
+    }
+
     /// Returns the number of tokens in a given text after tokenization.
     ///
     /// # Panics
@@ -15,6 +21,12 @@ impl ChunkSizer for Tokenizer {
 }
 
 impl ChunkSizer for &Tokenizer {
+    /// Return offsets for each unit of text used to calculate chunk size.
+    /// Should return an exclusive byte range for each element counted.
+    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
+        encoded_offsets(self, chunk)
+    }
+
     /// Returns the number of tokens in a given text after tokenization.
     ///
     /// # Panics
@@ -26,9 +38,57 @@ impl ChunkSizer for &Tokenizer {
     }
 }
 
+fn encoded_offsets(tokenizer: &Tokenizer, chunk: &str) -> EncodedOffsets {
+    let encoding = tokenizer
+        .encode(chunk, false)
+        .expect("Unable to tokenize the following string {chunk}");
+    let mut offsets = encoding
+        .get_offsets()
+        .iter()
+        .map(|(start, end)| {
+            let end = *end + 1;
+            *start..end
+        })
+        .collect::<Vec<_>>();
+    // Sometimes the offsets are off by one because of whitespace prefixing
+    let prefixed = offsets
+        .last()
+        .map(|r| r.end != chunk.len())
+        .unwrap_or_default();
+
+    if prefixed {
+        for range in &mut offsets {
+            if range.start != 0 {
+                range.start -= 1;
+            }
+            range.end -= 1;
+        }
+    }
+    EncodedOffsets::new(offsets)
+}
+
 fn chunk_size(tokenizer: &Tokenizer, chunk: &str) -> usize {
     tokenizer
         .encode(chunk, false)
         .map(|enc| enc.len())
-        .expect("Unable to tokenize the following string {str}")
+        .expect("Unable to tokenize the following string {chunk}")
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn returns_offsets() {
+        let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
+        let offsets = tokenizer.encoded_offsets(" An apple a");
+        assert_eq!(offsets, EncodedOffsets::new(vec![0..3, 3..9, 9..11]));
+    }
+
+    #[test]
+    fn returns_offsets_handles_prefix() {
+        let tokenizer = &Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
+        let offsets = tokenizer.encoded_offsets("An apple a");
+        assert_eq!(offsets, EncodedOffsets::new(vec![0..2, 2..8, 8..10]));
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index aa8d55cc..3c142245 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -152,8 +152,29 @@ mod tiktoken;
 
 pub use characters::Characters;
 
+/// Contains start and end byte offsets for an encoded unit of text in a chunk, calculated by
+/// a [`ChunkSizer`].
+///
+/// Each offset is an exclusive range.
+#[derive(Debug, Eq, PartialEq)]
+pub struct EncodedOffsets {
+    offsets: Vec<Range<usize>>,
+}
+
+impl EncodedOffsets {
+    /// Generate new encoded offsets with the offsets for the given encoding implementation.
+    #[must_use]
+    pub fn new(offsets: Vec<Range<usize>>) -> Self {
+        Self { offsets }
+    }
+}
+
 /// Determines the size of a given chunk.
 pub trait ChunkSizer {
+    /// Return offsets for each unit of text used to calculate chunk size.
+    /// Should return an exclusive byte range for each element counted.
+    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets;
+
     /// Determine the size of a given chunk to use for validation
     fn chunk_size(&self, chunk: &str) -> usize;
 }
@@ -835,6 +856,17 @@ mod tests {
         fn chunk_size(&self, chunk: &str) -> usize {
            chunk.len()
        }
+
+        fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
+            EncodedOffsets::new(
+                chunk
+                    .as_bytes()
+                    .iter()
+                    .enumerate()
+                    .map(|(i, _)| (i..i + 1))
+                    .collect(),
+            )
+        }
     }
 
     #[test]
diff --git a/src/tiktoken.rs b/src/tiktoken.rs
index cf90170b..32de733e 100644
--- a/src/tiktoken.rs
+++ b/src/tiktoken.rs
@@ -1,31 +1,59 @@
 use tiktoken_rs::CoreBPE;
 
-use crate::ChunkSizer;
+use crate::{ChunkSizer, EncodedOffsets};
 
 impl ChunkSizer for CoreBPE {
+    /// Return offsets for each unit of text used to calculate chunk size.
+    /// Should return an exclusive byte range for each element counted.
+    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
+        encoded_offsets(self, chunk)
+    }
+
     /// Returns the number of tokens in a given text after tokenization.
-    ///
-    /// # Panics
-    ///
-    /// Will panic if you don't have a byte-level tokenizer and the splitter
-    /// encounters text it can't tokenize.
     fn chunk_size(&self, text: &str) -> usize {
         chunk_size(self, text)
     }
 }
 
 impl ChunkSizer for &CoreBPE {
+    /// Return offsets for each unit of text used to calculate chunk size.
+    /// Should return an exclusive byte range for each element counted.
+    fn encoded_offsets(&self, chunk: &str) -> EncodedOffsets {
+        encoded_offsets(self, chunk)
+    }
+
     /// Returns the number of tokens in a given text after tokenization.
-    ///
-    /// # Panics
-    ///
-    /// Will panic if you don't have a byte-level tokenizer and the splitter
-    /// encounters text it can't tokenize.
     fn chunk_size(&self, text: &str) -> usize {
         chunk_size(self, text)
     }
 }
 
+fn encoded_offsets(bpe: &CoreBPE, chunk: &str) -> EncodedOffsets {
+    let tokens = bpe.encode_ordinary(chunk);
+    let decoded = bpe
+        ._decode_native_and_split(tokens)
+        .scan(0usize, |offset, bytes| {
+            let end = *offset + bytes.len();
+            let item = *offset..end;
+            *offset = end;
+            Some(item)
+        });
+    EncodedOffsets::new(decoded.collect())
+}
+
 fn chunk_size(bpe: &CoreBPE, text: &str) -> usize {
     bpe.encode_ordinary(text).len()
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use tiktoken_rs::cl100k_base;
+
+    #[test]
+    fn returns_offsets() {
+        let offsets = cl100k_base().unwrap().encoded_offsets("An apple a");
+        assert_eq!(offsets, EncodedOffsets::new(vec![0..2, 2..8, 8..10]));
+    }
+}
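
A minimal usage sketch of the new encoded_offsets API introduced by this patch. It is illustrative only, not part of the diff: the crate name text_splitter and the main-based harness are assumptions; Characters, ChunkSizer, and EncodedOffsets come from the lib.rs and characters.rs hunks above.

    // Hypothetical consumer of the new API; assumes the crate is imported as `text_splitter`.
    use text_splitter::{Characters, ChunkSizer, EncodedOffsets};

    fn main() {
        // "eé" is three bytes: 'e' occupies byte 0 and 'é' occupies bytes 1..3,
        // so each character yields one exclusive byte range.
        let offsets = Characters.encoded_offsets("eé");
        assert_eq!(offsets, EncodedOffsets::new(vec![0..1, 1..3]));

        // chunk_size still counts characters, not bytes.
        assert_eq!(Characters.chunk_size("eé"), 2);
    }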