Skip to content

Commit

Permalink
Add markdown benchmark code (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
benbrandt authored Mar 2, 2024
1 parent 9e427fa commit e1bf7ac
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 13 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true

# CI job that checks the benchmark targets still run (see the job name below).
benchmarks:
name: Test that benchmarks run
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
# Installs the stable Rust toolchain.
- uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2

# `--benches` selects the bench targets; `--all-features` turns on every Cargo feature.
- name: Run cargo test
run: cargo test --benches --all-features

lints:
name: Lints
runs-on: ubuntu-latest
Expand Down
79 changes: 66 additions & 13 deletions benches/chunk_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
use std::fs;

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use text_splitter::{Characters, TextSplitter};
use text_splitter::{unstable_markdown::MarkdownSplitter, Characters, TextSplitter};
use tiktoken_rs::{cl100k_base, CoreBPE};
use tokenizers::Tokenizer;

#[allow(clippy::large_enum_variant)]
enum Splitter {
enum TextSplitterImpl {
Characters(TextSplitter<Characters>),
Huggingface(TextSplitter<Tokenizer>),
Tiktoken(TextSplitter<CoreBPE>),
}

impl Splitter {
impl TextSplitterImpl {
fn name(&self) -> &str {
match self {
Splitter::Characters(_) => "Characters",
Splitter::Huggingface(_) => "Huggingface",
Splitter::Tiktoken(_) => "Tiktoken",
TextSplitterImpl::Characters(_) => "Characters",
TextSplitterImpl::Huggingface(_) => "Huggingface",
TextSplitterImpl::Tiktoken(_) => "Tiktoken",
}
}

Expand All @@ -35,20 +35,73 @@ impl Splitter {

fn chunks<'text>(&self, text: &'text str, chunk_size: usize) -> Vec<&'text str> {
match self {
Splitter::Characters(splitter) => splitter.chunks(text, chunk_size).collect(),
Splitter::Huggingface(splitter) => splitter.chunks(text, chunk_size).collect(),
Splitter::Tiktoken(splitter) => splitter.chunks(text, chunk_size).collect(),
Self::Characters(splitter) => splitter.chunks(text, chunk_size).collect(),
Self::Huggingface(splitter) => splitter.chunks(text, chunk_size).collect(),
Self::Tiktoken(splitter) => splitter.chunks(text, chunk_size).collect(),
}
}
}

fn criterion_benchmark(c: &mut Criterion) {
/// A `MarkdownSplitter` paired with one of the chunk-sizing backends that the
/// benchmarks compare.
// The variants wrap differently sized splitters; the lint is silenced instead
// of boxing them (presumably fine because only a few values are ever built in
// this benchmark — TODO confirm).
#[allow(clippy::large_enum_variant)]
enum MarkdownSplitterImpl {
/// Splitter parameterized with `Characters`.
Characters(MarkdownSplitter<Characters>),
/// Splitter parameterized with a `tokenizers::Tokenizer`.
Huggingface(MarkdownSplitter<Tokenizer>),
/// Splitter parameterized with a `tiktoken_rs::CoreBPE`.
Tiktoken(MarkdownSplitter<CoreBPE>),
}

impl MarkdownSplitterImpl {
fn name(&self) -> &str {
match self {
MarkdownSplitterImpl::Characters(_) => "Characters",
MarkdownSplitterImpl::Huggingface(_) => "Huggingface",
MarkdownSplitterImpl::Tiktoken(_) => "Tiktoken",
}
}

fn iter() -> [Self; 3] {
[
Self::Characters(MarkdownSplitter::default()),
Self::Huggingface(MarkdownSplitter::new(
Tokenizer::from_pretrained("bert-base-cased", None).unwrap(),
)),
Self::Tiktoken(MarkdownSplitter::new(cl100k_base().unwrap())),
]
}

fn chunks<'text>(&self, text: &'text str, chunk_size: usize) -> Vec<&'text str> {
match self {
Self::Characters(splitter) => splitter.chunks(text, chunk_size).collect(),
Self::Huggingface(splitter) => splitter.chunks(text, chunk_size).collect(),
Self::Tiktoken(splitter) => splitter.chunks(text, chunk_size).collect(),
}
}
}

fn text_benchmark(c: &mut Criterion) {
for filename in ["romeo_and_juliet", "room_with_a_view"] {
let mut group = c.benchmark_group(filename);
let text = fs::read_to_string(format!("tests/inputs/text/{filename}.txt")).unwrap();

for splitter in Splitter::iter() {
for chunk_size in (5..17).map(|n| 2usize.pow(n)) {
for splitter in TextSplitterImpl::iter() {
for chunk_size in (2..9).map(|n| 4usize.pow(n)) {
group.bench_with_input(
BenchmarkId::new(splitter.name(), chunk_size),
&chunk_size,
|b, &chunk_size| b.iter(|| splitter.chunks(&text, chunk_size)),
);
}
}
group.finish();
}
}

fn markdown_benchmark(c: &mut Criterion) {
for filename in ["commonmark_spec"] {
let mut group = c.benchmark_group(filename);
let text = fs::read_to_string(format!("tests/inputs/markdown/{filename}.md")).unwrap();

for splitter in MarkdownSplitterImpl::iter() {
for chunk_size in (2..9).map(|n| 4usize.pow(n)) {
group.bench_with_input(
BenchmarkId::new(splitter.name(), chunk_size),
&chunk_size,
Expand All @@ -60,5 +113,5 @@ fn criterion_benchmark(c: &mut Criterion) {
}
}

criterion_group!(benches, criterion_benchmark);
criterion_group!(benches, text_benchmark, markdown_benchmark);
criterion_main!(benches);

0 comments on commit e1bf7ac

Please sign in to comment.