From e1bf7acff6ec9946ad9f581413199495b5740838 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Sat, 2 Mar 2024 09:35:02 +0100 Subject: [PATCH] Add markdown benchmark code (#107) --- .github/workflows/ci.yml | 14 +++++++ benches/chunk_size.rs | 79 +++++++++++++++++++++++++++++++++------- 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 38c2251b..cf2f61ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,6 +52,20 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true + benchmarks: + name: Test that benchmarks run + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + - uses: Swatinem/rust-cache@v2 + + - name: Run cargo test + run: cargo test --benches --all-features + lints: name: Lints runs-on: ubuntu-latest diff --git a/benches/chunk_size.rs b/benches/chunk_size.rs index 01f7d694..bdbb5d03 100644 --- a/benches/chunk_size.rs +++ b/benches/chunk_size.rs @@ -3,23 +3,23 @@ use std::fs; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use text_splitter::{Characters, TextSplitter}; +use text_splitter::{unstable_markdown::MarkdownSplitter, Characters, TextSplitter}; use tiktoken_rs::{cl100k_base, CoreBPE}; use tokenizers::Tokenizer; #[allow(clippy::large_enum_variant)] -enum Splitter { +enum TextSplitterImpl { Characters(TextSplitter), Huggingface(TextSplitter), Tiktoken(TextSplitter), } -impl Splitter { +impl TextSplitterImpl { fn name(&self) -> &str { match self { - Splitter::Characters(_) => "Characters", - Splitter::Huggingface(_) => "Huggingface", - Splitter::Tiktoken(_) => "Tiktoken", + TextSplitterImpl::Characters(_) => "Characters", + TextSplitterImpl::Huggingface(_) => "Huggingface", + TextSplitterImpl::Tiktoken(_) => "Tiktoken", } } @@ -35,20 +35,73 @@ impl Splitter { fn chunks<'text>(&self, text: &'text str, chunk_size: usize) -> Vec<&'text str> { match self { - Splitter::Characters(splitter) => splitter.chunks(text, chunk_size).collect(), - Splitter::Huggingface(splitter) => splitter.chunks(text, chunk_size).collect(), - Splitter::Tiktoken(splitter) => splitter.chunks(text, chunk_size).collect(), + Self::Characters(splitter) => splitter.chunks(text, chunk_size).collect(), + Self::Huggingface(splitter) => splitter.chunks(text, chunk_size).collect(), + Self::Tiktoken(splitter) => splitter.chunks(text, chunk_size).collect(), } } } -fn criterion_benchmark(c: &mut Criterion) { +#[allow(clippy::large_enum_variant)] +enum MarkdownSplitterImpl { + Characters(MarkdownSplitter), + Huggingface(MarkdownSplitter), + Tiktoken(MarkdownSplitter), +} + +impl MarkdownSplitterImpl { + fn name(&self) -> &str { + match self { + MarkdownSplitterImpl::Characters(_) => "Characters", + MarkdownSplitterImpl::Huggingface(_) => "Huggingface", + MarkdownSplitterImpl::Tiktoken(_) => "Tiktoken", + } + } + + fn iter() -> [Self; 3] { + [ + Self::Characters(MarkdownSplitter::default()), + Self::Huggingface(MarkdownSplitter::new( + Tokenizer::from_pretrained("bert-base-cased", None).unwrap(), + )), + Self::Tiktoken(MarkdownSplitter::new(cl100k_base().unwrap())), + ] + } + + fn chunks<'text>(&self, text: &'text str, chunk_size: usize) -> Vec<&'text str> { + match self { + Self::Characters(splitter) => splitter.chunks(text, chunk_size).collect(), + Self::Huggingface(splitter) => splitter.chunks(text, chunk_size).collect(), + Self::Tiktoken(splitter) => splitter.chunks(text, chunk_size).collect(), + } + } +} + +fn text_benchmark(c: &mut Criterion) { for filename in ["romeo_and_juliet", "room_with_a_view"] { let mut group = c.benchmark_group(filename); let text = fs::read_to_string(format!("tests/inputs/text/{filename}.txt")).unwrap(); - for splitter in Splitter::iter() { - for chunk_size in (5..17).map(|n| 2usize.pow(n)) { + for splitter in TextSplitterImpl::iter() { + for chunk_size in (2..9).map(|n| 4usize.pow(n)) { + group.bench_with_input( + BenchmarkId::new(splitter.name(), chunk_size), + &chunk_size, + |b, &chunk_size| b.iter(|| splitter.chunks(&text, chunk_size)), + ); + } + } + group.finish(); + } +} + +fn markdown_benchmark(c: &mut Criterion) { + for filename in ["commonmark_spec"] { + let mut group = c.benchmark_group(filename); + let text = fs::read_to_string(format!("tests/inputs/markdown/{filename}.md")).unwrap(); + + for splitter in MarkdownSplitterImpl::iter() { + for chunk_size in (2..9).map(|n| 4usize.pow(n)) { group.bench_with_input( BenchmarkId::new(splitter.name(), chunk_size), &chunk_size, @@ -60,5 +113,5 @@ fn criterion_benchmark(c: &mut Criterion) { } } -criterion_group!(benches, criterion_benchmark); +criterion_group!(benches, text_benchmark, markdown_benchmark); criterion_main!(benches);