diff --git a/aicirt/src/bintokens.rs b/aicirt/src/bintokens.rs index 53661409..677e00f0 100644 --- a/aicirt/src/bintokens.rs +++ b/aicirt/src/bintokens.rs @@ -3,7 +3,7 @@ use aici_abi::bytes::TokRxInfo; use anyhow::{anyhow, bail, Result}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; -use tokenizers::{FromPretrainedParameters, Tokenizer}; +use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer}; #[derive(Serialize, Deserialize)] pub struct ByteTokenizer { @@ -190,11 +190,28 @@ fn from_hex(hex_str: &str) -> Result> { } impl ByteTokenizer { - pub fn from_tokenizer(hft: Tokenizer) -> Result { + pub fn from_tokenizer(mut hft: Tokenizer) -> Result { let mut is_byte_level = false; let mut is_byte_fallback = false; let mut space_ch = ' '; + // remove the "Prepend space" + if let Some(n) = hft.get_normalizer() { + let n = match n { + NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new( + x.get_normalizers() + .iter() + .filter_map(|n| match n { + NormalizerWrapper::Prepend(_) => None, + _ => Some(n.clone()), + }) + .collect(), + )), + _ => n.clone(), + }; + hft.with_normalizer(n); + } + if let Some(d) = hft.get_decoder() { let v = serde_json::to_value(d).unwrap(); if v["type"].as_str() == Some("ByteLevel") {