Skip to content

Commit

Permalink
feat: Add chunk overlap setting
Browse files Browse the repository at this point in the history
Allows for overlapping chunks. Will still use the semantic range sections to determine a good splitting point for the overlap as well.
  • Loading branch information
benbrandt committed Apr 26, 2024
1 parent 0d6b722 commit 49015e9
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 13 deletions.
65 changes: 63 additions & 2 deletions src/chunk_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,14 @@ impl ChunkCapacity {
/// # Errors
///
/// If the `max` size is less than the `desired` size, an error is returned.
pub fn with_max(self, max: usize) -> Result<Self, ChunkCapacityError> {
pub fn with_max(mut self, max: usize) -> Result<Self, ChunkCapacityError> {
if max < self.desired {
Err(ChunkCapacityError(
ChunkCapacityErrorRepr::MaxLessThanDesired,
))
} else {
Ok(Self { max, ..self })
self.max = max;
Ok(self)
}
}

Expand Down Expand Up @@ -256,6 +257,20 @@ pub trait ChunkSizer {
fn chunk_size(&self, chunk: &str, capacity: &ChunkCapacity) -> ChunkSize;
}

/// Indicates there was an error with the chunk configuration.
/// The `Display` implementation will provide a human-readable error message to
/// help debug the issue that caused the error.
///
/// Returned by `ChunkConfig::with_overlap` when the requested overlap is
/// rejected. The underlying cause is held in a private representation type.
#[derive(Error, Debug)]
#[error(transparent)]
pub struct ChunkConfigError(#[from] ChunkConfigErrorRepr);

/// Private error representation; free to change across minor versions of the crate.
#[derive(Error, Debug)]
enum ChunkConfigErrorRepr {
/// The requested overlap was greater than or equal to the maximum chunk capacity.
#[error("The overlap is larger than or equal to the chunk capacity")]
OverlapLargerThanCapacity,
}

/// Configuration for how chunks should be created
#[derive(Debug)]
pub struct ChunkConfig<Sizer>
Expand All @@ -264,6 +279,8 @@ where
{
/// The chunk capacity to use for filling chunks
capacity: ChunkCapacity,
/// The amount of overlap between chunks. Defaults to 0.
overlap: usize,
/// The chunk sizer to use for determining the size of each chunk
sizer: Sizer,
/// Whether whitespace will be trimmed from the beginning and end of each chunk
Expand All @@ -282,6 +299,7 @@ impl ChunkConfig<Characters> {
pub fn new(capacity: impl Into<ChunkCapacity>) -> Self {
Self {
capacity: capacity.into(),
overlap: 0,
sizer: Characters,
trim: true,
}
Expand All @@ -297,6 +315,27 @@ where
&self.capacity
}

/// Retrieve the amount of overlap between chunks.
///
/// This is the value set through `with_overlap`; a configuration created
/// with `new` starts with an overlap of `0`.
pub fn overlap(&self) -> usize {
self.overlap
}

/// Set the amount of overlap between chunks.
///
/// Consumes the configuration and returns it with the new overlap applied,
/// so it can be chained with the other `with_*` builder methods.
///
/// # Errors
///
/// Will return an error if the overlap is larger than or equal to the chunk capacity.
pub fn with_overlap(mut self, overlap: usize) -> Result<Self, ChunkConfigError> {
    // Reject an overlap that would fill (or exceed) an entire chunk.
    if overlap >= self.capacity.max {
        return Err(ChunkConfigErrorRepr::OverlapLargerThanCapacity.into());
    }

    self.overlap = overlap;
    Ok(self)
}

/// Retrieve a reference to the chunk sizer for this configuration.
pub fn sizer(&self) -> &Sizer {
&self.sizer
Expand All @@ -313,6 +352,7 @@ where
pub fn with_sizer<S: ChunkSizer>(self, sizer: S) -> ChunkConfig<S> {
ChunkConfig {
capacity: self.capacity,
overlap: self.overlap,
sizer,
trim: self.trim,
}
Expand Down Expand Up @@ -340,6 +380,11 @@ where
self.trim = trim;
self
}

/// Generate a memoized chunk sizer from this config to cache the size of chunks.
///
/// The returned sizer borrows this config's capacity and sizer, so it
/// cannot outlive the config it was created from.
pub(crate) fn memoized_sizer(&self) -> MemoizedChunkSizer<'_, Sizer> {
MemoizedChunkSizer::new(&self.capacity, &self.sizer)
}
}

impl<T> From<T> for ChunkConfig<Characters>
Expand Down Expand Up @@ -704,4 +749,20 @@ mod tests {
assert_eq!(capacity.desired(), 10);
assert_eq!(capacity.max(), 10);
}

#[test]
fn set_chunk_overlap() {
    // An overlap below the capacity is accepted and readable back via the getter.
    let result = ChunkConfig::new(10).with_overlap(5);
    let configured = result.unwrap();
    assert_eq!(configured.overlap(), 5);
}

#[test]
fn cant_set_overlap_larger_than_capacity() {
    const EXPECTED: &str = "The overlap is larger than or equal to the chunk capacity";

    // An overlap strictly larger than the capacity is rejected...
    let err = ChunkConfig::new(5).with_overlap(10).unwrap_err();
    assert_eq!(err.to_string(), EXPECTED);

    // ...and so is the boundary case where the overlap equals the capacity.
    let err = ChunkConfig::new(5).with_overlap(5).unwrap_err();
    assert_eq!(err.to_string(), EXPECTED);
}
}
85 changes: 74 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,20 @@ where
Sizer: ChunkSizer,
Level: SemanticLevel,
{
/// Chunk configuration for this iterator
chunk_config: &'sizer ChunkConfig<Sizer>,
/// How to validate chunk sizes
chunk_sizer: MemoizedChunkSizer<'sizer, Sizer>,
/// Current byte offset in the `text`
cursor: usize,
/// Reusable container for next sections to avoid extra allocations
next_sections: Vec<(usize, &'text str)>,
/// Previous item's byte range
prev_item_range: Option<Range<usize>>,
/// Splitter used for determining semantic levels.
semantic_split: SemanticSplitRanges<Level>,
/// Original text to iterate over and generate chunks from
text: &'text str,
/// Whether or not chunks should be trimmed
trim_chunks: bool,
}

impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunks<'text, 'sizer, Sizer, Level>
Expand All @@ -173,18 +175,19 @@ where
/// Starts with an offset of 0
fn new(chunk_config: &'sizer ChunkConfig<Sizer>, text: &'text str) -> Self {
    Self {
        chunk_config,
        // Built from the config so chunk size measurements are cached.
        chunk_sizer: chunk_config.memoized_sizer(),
        cursor: 0,
        next_sections: Vec::new(),
        // Nothing has been emitted yet, so there is no previous range.
        prev_item_range: None,
        semantic_split: SemanticSplitRanges::new(Level::offsets(text).collect()),
        text,
    }
}

/// If trim chunks is on, trim the str and adjust the offset
fn trim_chunk(&self, offset: usize, chunk: &'text str) -> (usize, &'text str) {
if self.trim_chunks {
if self.chunk_config.trim() {
Level::trim_chunk(offset, chunk)
} else {
(offset, chunk)
Expand All @@ -200,6 +203,18 @@ where
self.semantic_split.update_ranges(self.cursor);
self.update_next_sections();

let (start, end) = self.binary_search_next_chunk()?;

// Optionally move cursor back if overlap is desired
self.update_cursor(end);

let chunk = self.text.get(start..end)?;
// Trim whitespace if user requested it
Some(self.trim_chunk(start, chunk))
}

/// Use binary search to find the next chunk that fits within the chunk size
fn binary_search_next_chunk(&mut self) -> Option<(usize, usize)> {
let start = self.cursor;
let mut end = self.cursor;
let mut equals_found = false;
Expand Down Expand Up @@ -256,7 +271,6 @@ where
}
}

// Sometimes with tokenization, we can get a bigger chunk for the same amount of tokens.
if let (Some(successful_index), Some(chunk_size)) =
(successful_index, successful_chunk_size)
{
Expand All @@ -281,12 +295,53 @@ where
}
}

self.cursor = end;
Some((start, end))
}

/// Use binary search to find the sections that fit within the overlap size.
/// If no overlap deisired, return end.
fn update_cursor(&mut self, end: usize) {
if self.chunk_config.overlap() == 0 {
self.cursor = end;
return;
}

let chunk = self.text.get(start..self.cursor)?;
// Binary search for overlap
let mut start = end;
let mut low = 0;
// Find closest index that would work
let binary_search_by_key = dbg!(self
.next_sections
.binary_search_by_key(&end, |(offset, str)| offset + str.len()));
let mut high = match binary_search_by_key {
Ok(i) | Err(i) => i,
};
dbg!(&self.next_sections);

// Trim whitespace if user requested it
Some(self.trim_chunk(start, chunk))
while low <= high {
let mid = low + (high - low) / 2;
let (offset, _) = self.next_sections[mid];
let (_, chunk) =
self.trim_chunk(offset, self.text.get(offset..end).expect("Invalid range"));
let chunk_size = self
.chunk_config
.sizer()
.chunk_size(chunk, &self.chunk_config.overlap().into());

// We got further than the last one, so update start
if chunk_size.fits().is_le() && offset < start && offset > self.cursor {
start = offset;
}

// Adjust search area
if chunk_size.fits().is_lt() && mid > 0 {
high = mid - 1;
} else {
low = mid + 1;
}
}

self.cursor = start;
}

/// Find the ideal next sections, breaking it up until we find the largest chunk.
Expand Down Expand Up @@ -367,7 +422,15 @@ where
// Make sure we didn't get an empty chunk. Should only happen in
// cases where we trim.
(_, "") => continue,
c => return Some(c),
c => {
let item_range = Some(c.0..c.0 + c.1.len());
// Skip because we've emitted a duplicate chunk
if item_range == self.prev_item_range {
continue;
}
self.prev_item_range = item_range;
return Some(c);
}
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions tests/text_splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,38 @@ fn huggingface_tokenizer_with_padding() {
]
);
}

#[test]
fn chunk_overlap_characters() {
    // Capacity of 4 characters with 2 characters shared between neighbours.
    let config = ChunkConfig::new(4).with_overlap(2).unwrap();
    let splitter = TextSplitter::new(config);

    let chunks: Vec<_> = splitter.chunks("1234567890").collect();

    assert_eq!(chunks, ["1234", "3456", "5678", "7890"]);
}

#[test]
fn chunk_overlap_words() {
    // Trimming is disabled, so whitespace is kept in the emitted chunks.
    let config = ChunkConfig::new(4)
        .with_overlap(3)
        .unwrap()
        .with_trim(false);
    let splitter = TextSplitter::new(config);

    let chunks: Vec<_> = splitter.chunks("An apple a day").collect();

    assert_eq!(chunks, ["An ", " ", "appl", "pple", " a ", "a ", " day"]);
}

#[test]
fn chunk_overlap_words_trim() {
    // Same input as the untrimmed case, using the config's default trimming.
    let config = ChunkConfig::new(4).with_overlap(3).unwrap();
    let splitter = TextSplitter::new(config);

    let chunks: Vec<_> = splitter.chunks("An apple a day").collect();

    assert_eq!(chunks, ["An", "appl", "pple", "a", "day"]);
}

0 comments on commit 49015e9

Please sign in to comment.