Fixing malformed rust tokenizers
Narsil committed Jun 27, 2024
1 parent b53b21c commit aa87939
Showing 4 changed files with 40 additions and 9 deletions.
10 changes: 4 additions & 6 deletions Cargo.lock

(Cargo.lock is a generated file; its diff is not rendered by default.)

1 change: 1 addition & 0 deletions router/src/config.rs
@@ -158,6 +158,7 @@ pub enum Config {
Baichuan,
Paligemma(Paligemma),
Gemma,
Gemma2,
Cohere,
Drbx,
Falcon,
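For context, a minimal sketch (not the router's actual code) of how the new Gemma2 variant gets picked up: assuming the Config enum is internally tagged on model_type with snake_case renaming, as the attributes outside this hunk suggest, a config.json declaring "model_type": "gemma2" deserializes into the new variant.

use serde::Deserialize;

// Simplified stand-in for the router's Config enum; the serde attributes are
// an assumption about code outside this hunk.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(tag = "model_type", rename_all = "snake_case")]
enum ConfigSketch {
    Gemma,
    Gemma2,
}

fn main() {
    // A model's config.json with "model_type": "gemma2" now maps to the new variant.
    let cfg: ConfigSketch = serde_json::from_str(r#"{"model_type": "gemma2"}"#).unwrap();
    assert_eq!(cfg, ConfigSketch::Gemma2);
}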
3 changes: 3 additions & 0 deletions router/src/lib.rs
@@ -61,6 +61,9 @@ pub struct HubTokenizerConfig {
pub bos_token: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")]
pub eos_token: Option<String>,
pub tokenizer_class: Option<String>,
pub add_bos_token: Option<bool>,
pub add_eos_token: Option<bool>,
}

impl HubTokenizerConfig {
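The three new optional fields come straight from a model's tokenizer_config.json. A hedged sketch of that mapping using a simplified struct (the real HubTokenizerConfig also routes bos_token/eos_token through the custom token_serde deserializer shown above; the JSON values here are illustrative of a llama-style checkpoint, not taken from any specific model):

use serde::Deserialize;

// Simplified illustration only; field names match the hunk above.
#[derive(Debug, Default, Deserialize)]
struct TokenizerConfigSketch {
    bos_token: Option<String>,
    eos_token: Option<String>,
    tokenizer_class: Option<String>,
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
}

fn main() {
    let raw = r#"{
        "bos_token": "<s>",
        "eos_token": "</s>",
        "tokenizer_class": "LlamaTokenizer",
        "add_bos_token": true,
        "add_eos_token": false
    }"#;
    let cfg: TokenizerConfigSketch = serde_json::from_str(raw).unwrap();
    assert_eq!(cfg.tokenizer_class.as_deref(), Some("LlamaTokenizer"));
    assert_eq!(cfg.add_bos_token, Some(true));
}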
35 changes: 32 additions & 3 deletions router/src/main.rs
@@ -15,7 +15,7 @@ use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
@@ -268,8 +268,6 @@ async fn main() -> Result<(), RouterError> {
)
}
};
let tokenizer: Option<Tokenizer> =
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
@@ -299,6 +297,37 @@ async fn main() -> Result<(), RouterError> {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
    let mut tokenizer = Tokenizer::from_file(filename).ok();
    if let Some(tokenizer) = &mut tokenizer {
        if let Some(class) = &tokenizer_config.tokenizer_class {
            if class == "LlamaTokenizer" {
                tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                // Rebuild the post-processor template so bos/eos handling matches
                // the python fast tokenizer's behaviour.
                let mut single = vec![];
                let mut special_tokens = vec![];
                if let Some(true) = tokenizer_config.add_bos_token {
                    if let Some(bos_token) = &tokenizer_config.bos_token {
                        let bos_token_id = tokenizer
                            .token_to_id(bos_token)
                            .expect("Should have found the bos token id");
                        special_tokens.push((bos_token.clone(), bos_token_id));
                        single.push(bos_token.to_string());
                    }
                }
                // "$0" is the placeholder for the input sequence itself.
                single.push("$0".to_string());
                if let Some(true) = tokenizer_config.add_eos_token {
                    if let Some(eos_token) = &tokenizer_config.eos_token {
                        let eos_token_id = tokenizer
                            .token_to_id(eos_token)
                            .expect("Should have found the eos token id");
                        special_tokens.push((eos_token.clone(), eos_token_id));
                        single.push(eos_token.to_string());
                    }
                }
                let post_processor = TemplateProcessing::builder()
                    .try_single(single)
                    .unwrap()
                    .special_tokens(special_tokens)
                    .build()
                    .unwrap();
                tokenizer.with_post_processor(post_processor);
            }
        }
    }
    tokenizer
});

let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
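To see what the override produces, here is a minimal self-contained sketch of the TemplateProcessing post-processor the new block builds for the case add_bos_token = true, add_eos_token = false. The token string and id are illustrative; in the real code they come from tokenizer_config.json and a token_to_id lookup on the loaded tokenizer.

use tokenizers::processors::template::TemplateProcessing;

fn main() {
    // Illustrative values; the router resolves these from the loaded tokenizer's vocabulary.
    let bos_token = "<s>".to_string();
    let bos_token_id = 1u32;

    // "$0" stands for the input sequence, so the template is "<s> $0":
    // prepend bos, no eos, mirroring the python LlamaTokenizerFast override.
    let post_processor = TemplateProcessing::builder()
        .try_single(vec![bos_token.clone(), "$0".to_string()])
        .unwrap()
        .special_tokens(vec![(bos_token, bos_token_id)])
        .build()
        .unwrap();

    // A loaded Tokenizer would then carry it via:
    // tokenizer.with_post_processor(post_processor);
    let _ = post_processor;
}

With this post-processor attached, encoding an input yields the bos id followed by the sequence ids, matching what the python Llama tokenizers emit when add_bos_token is true.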
