Skip to content

Commit

Permalink
Merge branch 'main' into paplorinc/cl100k-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
l0rinc authored Apr 6, 2024
2 parents 91be802 + 1b9faf2 commit 92a320c
Showing 1 changed file with 74 additions and 86 deletions.
160 changes: 74 additions & 86 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,110 +8,85 @@ use std::thread;
use fancy_regex::Regex;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::pyclass;
use pyo3::PyResult;
use pyo3::types::{PyBytes, PyList, PyTuple};
use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

fn _byte_pair_merge<T>(
piece: &[u8],
ranks: &HashMap<Vec<u8>, Rank>,
f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the byte pair starting at position start.
// The rank of the last item in the vector is not a valid value.
let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
// The rank is of the pair starting at position start.
let mut parts = Vec::with_capacity(piece.len() + 1);

// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
// the way we currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
for i in 0..piece.len() - 1 {
let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
if rank < min_rank.0 {
min_rank = (rank, i);
}
parts.push((i, rank));
}
parts.push((piece.len() - 1, Rank::MAX));
parts.push((piece.len(), Rank::MAX));

let get_rank = {
#[inline(always)]
|parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
if (start_idx + skip + 2) < parts.len() {
ranks
.get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
.copied()
|parts: &Vec<(usize, Rank)>, i: usize| {
if (i + 3) < parts.len() {
// Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
// parts[i + 1], see comment in the main loop.
*ranks
.get(&piece[parts[i].0..parts[i + 3].0])
.unwrap_or(&Rank::MAX)
} else {
None
Rank::MAX
}
}
};

// We look up the ranks once in the beginning and iteratively update
// them during each merge, which reduces the number of rank lookups.
for i in 0..parts.len() - 2 {
match get_rank(&parts, i, 0) {
Some(rank) => {
// Rank::MAX is a sentinel value and cannot be a valid rank
debug_assert!(rank != Rank::MAX);
parts[i].1 = rank;
}
None => {
continue;
}
};
}

// If you have n parts and m merges, this does O(mn) work.
// We could do something with a heap and do O(m log n) work.
// It is important to consider that n is often small (<100), and as such
// the cache-locality benefits outweigh the algorithmic complexity downsides
// of the `parts` vector data structure above.

// Note that we hash bytes, not token pairs. As long as we train BPE the way we
// currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
loop {
if parts.len() == 1 {
break;
// n is often very small so considerations like cache-locality outweigh the algorithmic
// complexity downsides of the `parts` vector.
while min_rank.0 != Rank::MAX {
let i = min_rank.1;
// Update parts[i] and parts[i - 1] before removing parts[i + 1], since
// `parts.remove(i + 1)` will thrash the cache.
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1);
}
parts[i].1 = get_rank(&parts, i);
parts.remove(i + 1);

// Rank::MAX is a sentinel rank value allowing us to
// take the min more quickly
let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
min_rank = (Rank::MAX, usize::MAX);
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
if rank < min_rank.0 {
min_rank = (rank, i);
}
}

if min_rank.0 != Rank::MAX {
let i = min_rank.1;

// NOTE: We are about to remove parts[i + 1]. We do not do it
// yet because there are cache-locality benefits to updating
// parts[i] and parts[i-1] before removing, which could thrash
// the cache. Thus, we update the rank calculation by skipping over
// parts[i + 1], by invoking `get_rank!` with `skip = 1`.
parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
}

parts.remove(i + 1);
} else {
break;
}
}
let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
for i in 0..parts.len() - 1 {
out.push(f(parts[i].0..parts[i + 1].0));
}
out
parts
}

/// Encodes `piece` into BPE token ranks using the merge table `ranks`.
///
/// # Panics
/// Panics if `piece.len() <= 1` (single bytes are expected to be handled by
/// the caller via a direct `ranks` lookup), or if a merged span is missing
/// from `ranks` (cannot happen with a well-formed BPE vocabulary, where every
/// merge result is itself a token).
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    assert!(piece.len() > 1);
    // `ranks` and `piece` are already references; no extra borrow needed.
    _byte_pair_merge(ranks, piece)
        // Adjacent boundary pairs delimit each final token's byte range.
        .windows(2)
        .map(|part| ranks[&piece[part[0].0..part[1].0]])
        .collect()
}

/// Splits `piece` into the byte slices of its BPE tokens (same segmentation
/// as [`byte_pair_encode`], but returning the raw slices instead of ranks).
///
/// # Panics
/// Panics if `piece.len() <= 1`; callers handle single-byte pieces directly.
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
    assert!(piece.len() > 1);
    // `ranks` and `piece` are already references; no extra borrow needed.
    _byte_pair_merge(ranks, piece)
        // Adjacent boundary pairs delimit each final token's byte range.
        .windows(2)
        .map(|part| &piece[part[0].0..part[1].0])
        .collect()
}

// Various performance notes:
Expand Down Expand Up @@ -162,10 +137,10 @@ fn hash_current_thread() -> usize {
// that works great for our use case of avoiding collisions in our array. Unfortunately,
// it's private. However, there are only so many ways you can layout a u64, so just transmute
// https://github.com/rust-lang/rust/issues/67939
const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x = unsafe {
std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0
};
u64::from(x) as usize
}
Expand Down Expand Up @@ -214,11 +189,10 @@ impl CoreBPE {
let mut ret = vec![];
for mat in regex.find_iter(text) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
ret.push(*token);
continue;
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
ret.extend(&byte_pair_encode(piece, &self.encoder));
}
ret
}
Expand Down Expand Up @@ -516,7 +490,10 @@ impl CoreBPE {
unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);

tokens.truncate(tokens.len() - last_piece_token_len);
tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
match self.encoder.get(&unstable_bytes) {
Some(token) => tokens.push(*token),
None => tokens.extend(&byte_pair_encode(&unstable_bytes, &self.encoder)),
}
}
tokens
}
Expand Down Expand Up @@ -597,15 +574,26 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
mod tests {
use rustc_hash::FxHashMap as HashMap;

use crate::byte_pair_split;
use crate::{byte_pair_split, Rank};

#[test]
fn very_simple_test() {
let mut ranks = HashMap::default();
ranks.insert(b"ab".to_vec(), 1);
ranks.insert(b"cd".to_vec(), 2);
fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
HashMap::from_iter([
(b"ab".to_vec(), 0),
(b"cd".to_vec(), 1),
])
}

#[test]
fn test_simple_characters() {
let ranks = setup_ranks();
let res = byte_pair_split(b"abcd", &ranks);
assert_eq!(res, vec![b"ab", b"cd"]);
}

#[test]
fn test_repeated_characters() {
let ranks = setup_ranks();
let res = byte_pair_split(b"abab", &ranks);
assert_eq!(res, vec![b"ab", b"ab"]);
}
}

0 comments on commit 92a320c

Please sign in to comment.