
Commit

working on dynamic loading of tokenizers from hf
mmoskal committed Feb 3, 2024
1 parent abd701c commit d374332
Showing 4 changed files with 207 additions and 10 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -95,6 +95,7 @@
"logprobs",
"lstrip",
"maxtol",
+"mixtral",
"mmap",
"nheads",
"NOSYS",
204 changes: 196 additions & 8 deletions aicirt/src/bintokens.rs
@@ -1,8 +1,10 @@
use crate::HashMap;
use aici_abi::bytes::TokRxInfo;
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
-use std::collections::BTreeMap;
+use std::{collections::BTreeMap, fmt::Debug};
+use tokenizers::Tokenizer;

#[derive(Serialize, Deserialize)]
pub struct TokenInfo {
@@ -14,7 +16,7 @@ pub struct TokenInfo {
    pub text: BTreeMap<String, u32>,
}

-pub struct Tokenizer {
+pub struct BinTokenizer {
    pub name: String,
    pub description: String,
    info: Option<TokenInfo>,
@@ -27,7 +29,7 @@ pub struct Tokenizer {

macro_rules! tok {
    ($name:literal, $desc:literal, $models:literal) => {
-        Tokenizer {
+        BinTokenizer {
            name: $name.into(),
            description: $desc.into(),
            base_tokenizer: None,
@@ -39,7 +41,7 @@ macro_rules! tok {
        }
    };
    ($username:literal, $name:literal, $desc:literal, $models:literal, $add:expr) => {
-        Tokenizer {
+        BinTokenizer {
            name: $username.into(),
            description: $desc.into(),
            base_tokenizer: Some($name),
@@ -52,7 +54,7 @@
    };
}

-pub fn tokenizers() -> Vec<Tokenizer> {
+pub fn tokenizers() -> Vec<BinTokenizer> {
    vec![
        tok!("gpt4", "cl100k_base, used by GPT-4 and GPT-3.5", "gpt-4"),
        tok!("llama", "used by Llama, CodeLlama, etc.", ""),
@@ -99,6 +101,192 @@ pub fn tokenizers() -> Vec<Tokenizer> {
    ]
}

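// Characters that GPT-2-style byte-level BPE maps to themselves: printable
// ASCII plus most of printable Latin-1 (space, NBSP, and soft hyphen excluded).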
fn is_self_mapped(c: char) -> bool {
match c {
'!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}' => true,
_ => false,
}
}

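// Inverse of the byte-level encoding: self-mapped characters decode to their
// own byte value; all remaining bytes are represented by codepoints from U+0100 up.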
fn build_char_map() -> HashMap<char, u8> {
let mut res = HashMap::default();
let mut k = 0x100u32;
for byte in 0..=255u8 {
let c = byte as char;
if is_self_mapped(c) {
res.insert(c, byte);
} else {
res.insert(char::from_u32(k).unwrap(), byte);
k += 1;
}
}
res
}

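// Build a TokenInfo table directly from a HF tokenizer by sniffing its decoder:
// ByteLevel (GPT-2 style) or a Sequence containing ByteFallback (SentencePiece style).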
fn build_bintok(hft: Tokenizer) -> Result<TokenInfo> {
let mut is_byte_level = false;
let mut is_byte_fallback = false;
let mut space_ch = ' ';

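    // The decoder config is inspected via its JSON serialization; a Replace
    // decoder tells us which character (typically '▁') stands in for space.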
if let Some(d) = hft.get_decoder() {
let v = serde_json::to_value(d).unwrap();
if v["type"].as_str() == Some("ByteLevel") {
is_byte_level = true;
} else if v["type"].as_str() == Some("Sequence") {
if let Some(decoders) = v["decoders"].as_array() {
for decoder in decoders {
if decoder["type"].as_str() == Some("ByteFallback") {
is_byte_fallback = true;
} else if decoder["type"].as_str() == Some("Replace")
&& decoder["content"].as_str() == Some(" ")
{
if let Some(s) = decoder["pattern"]["String"].as_str() {
let s: Vec<char> = s.chars().collect();
if s.len() == 1 {
space_ch = s[0];
}
}
}
}
}
}
}

if !is_byte_fallback && !is_byte_level {
bail!("can't determine decoder type: {:?}", hft.get_decoder());
}

let vocab_size = hft.get_vocab_size(true) as u32;
let mut res = TokenInfo {
hf_model: "foobar".to_string(),
eos_token: 0,
vocab_size: Some(vocab_size),
special: BTreeMap::new(),
binary: BTreeMap::new(),
text: BTreeMap::new(),
};

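    // Added tokens: special ones (including the EOS marker) go into `special`,
    // all others into `text`.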
for (id, info) in hft.get_added_tokens_decoder().iter() {
if info.special {
match info.content.as_str() {
"</s>" | "<|endoftext|>" => res.eos_token = *id,
_ => {}
}
res.special.insert(info.content.clone(), *id);
} else {
res.text.insert(info.content.clone(), *id);
}
}

let added = hft.get_added_tokens_decoder();
let char_map = build_char_map();

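    // Classify every regular token: byte-fallback tokens like <0x0A> become raw
    // bytes, byte-level tokens are decoded through the char map; whatever decodes
    // to valid UTF-8 goes into `text`, the rest into `binary` as hex strings.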
for tok_id in 0..vocab_size {
if added.contains_key(&tok_id) {
continue;
}
if let Some(tok_name) = hft.id_to_token(tok_id) {
if is_byte_fallback {
if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">") {
// parse hex number from tok_name
let hex_str = &tok_name[3..5];
let byte = u8::from_str_radix(hex_str, 16).unwrap();
if byte >= 0x80 {
let s = format!("{:02x}", byte);
res.binary.insert(s, tok_id);
} else {
let s = format!("{}", byte as char);
res.text.insert(s, tok_id);
}
} else {
assert!(!tok_name.starts_with("<0x"));
let tok_name = tok_name.replace(space_ch, " ");
res.text.insert(tok_name, tok_id);
}
} else if is_byte_level {
let bytes: Result<Vec<u8>> = tok_name
.chars()
.map(|c| {
char_map
.get(&c)
.map(|c| *c)
.ok_or_else(|| anyhow!("missing char: {}", c))
})
.collect();
let bytes = match bytes {
Ok(b) => b,
Err(e) => {
println!("error: {} for {:?}", e, tok_name);
continue;
}
};

if let Ok(s) = String::from_utf8(bytes.clone()) {
res.text.insert(s, tok_id);
} else {
let hexstr = String::from_iter(bytes.iter().map(|b| format!("{:02x}", b)));
res.binary.insert(hexstr, tok_id);
}
} else {
panic!();
}
} else {
println!("missing token: {}", tok_id);
}
}

Ok(res)
}

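// Report keys whose values differ between the two maps, in both directions.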
fn cmp_maps<T: Ord + Eq + Debug, U: Eq + Debug>(a: &BTreeMap<T, U>, b: &BTreeMap<T, U>) {
for (k, v) in a {
if b.get(k) != Some(v) {
println!("{:?}: {:?} != {:?}", k, v, b.get(k));
}
}
for (k, v) in b {
if a.get(k) == None {
println!("{:?}: None != {:?}", k, v);
}
}
}

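// Debug helper: rebuild token info from the bundled HF tokenizer.json for every
// base tokenizer and report mismatches against the baked-in data.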
pub fn convert_tokenizers() {
for bintok in tokenizers() {
if bintok.base_tokenizer.is_some() {
continue;
}

if bintok.name != "llama" && bintok.name != "phi" {
//continue;
}

//

let mut info = serde_json::from_slice::<TokenInfo>(bintok.get_info_bytes()).unwrap();
let max = vec![
info.binary.values().max(),
info.special.values().max(),
info.text.values().max(),
]
.iter()
.filter_map(|x| *x)
.max()
.unwrap();
info.vocab_size = Some(max + 1);

println!("{}: {}", bintok.name, max + 1);
let hft = Tokenizer::from_bytes(bintok.hf_bytes.unwrap()).unwrap();
let info2 = build_bintok(hft).unwrap();

assert!(info.eos_token == info2.eos_token);
assert!(info.vocab_size == info2.vocab_size);
cmp_maps(&info.special, &info2.special);
cmp_maps(&info.binary, &info2.binary);
cmp_maps(&info.text, &info2.text);
}
}

pub fn list_tokenizers() -> String {
    format!(
        "Available tokenizers for -t or --tokenizer:\n{}",
@@ -125,7 +313,7 @@ pub fn guess_tokenizer(model_name: &str) -> Option<String> {
        .map(|t| t.name.clone())
}

-pub fn find_tokenizer(name: &str) -> Result<Tokenizer> {
+pub fn find_tokenizer(name: &str) -> Result<BinTokenizer> {
    for mut t in tokenizers() {
        if t.name == name {
            t.load();
@@ -146,7 +334,7 @@ fn from_hex(hex_str: &str) -> Result<Vec<u8>> {
    Ok(bytes)
}

-impl Tokenizer {
+impl BinTokenizer {
    fn load(&mut self) {
        if self.info.is_none() {
            let mut info = serde_json::from_slice::<TokenInfo>(self.get_info_bytes()).unwrap();
8 changes: 8 additions & 0 deletions aicirt/src/main.rs
@@ -53,6 +53,9 @@ struct Cli {
    #[arg(short, long, default_value = "llama")]
    tokenizer: String,

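    /// Rebuild token info from the bundled HF tokenizers and compare against
    /// the baked-in data (debugging aid for the work-in-progress conversion)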
    #[arg(long, default_value = "false")]
    convert_tokenizers: bool,

    /// Save the --tokenizer=... to specified file
    #[arg(long)]
    save_tokenizer: Option<String>,
@@ -1014,6 +1017,11 @@ fn main() -> () {

    let cli = Cli::parse();

    if cli.convert_tokenizers {
        bintokens::convert_tokenizers();
        return;
    }

    if !cli.name.starts_with("/") {
        eprintln!("--name must start with /");
        std::process::exit(1);
4 changes: 2 additions & 2 deletions aicirt/src/moduleinstance.rs
@@ -14,7 +14,7 @@ use aici_abi::{
use aicirt::{
    api::{AiciMidProcessResultInner, AiciPostProcessResultInner, SequenceResult},
    bail_user,
-    bintokens::Tokenizer,
+    bintokens::BinTokenizer,
    user_error,
};
use anyhow::{anyhow, bail, ensure, Result};
@@ -36,7 +36,7 @@ impl WasmContext {
        unsafe { wasmtime::Module::deserialize_file(&self.engine, path) }
    }

-    pub fn new(limits: AiciLimits, tokenizer: Tokenizer) -> Result<Self> {
+    pub fn new(limits: AiciLimits, tokenizer: BinTokenizer) -> Result<Self> {
        let mut cfg = wasmtime::Config::default();
        // these are defaults as of 13.0.0, but we specify them anyways for stability
        cfg.debug_info(false)
