remove space-prepending from tokenizers
mmoskal committed Feb 3, 2024
1 parent ec4c172 commit e7af50a
Showing 1 changed file with 19 additions and 2 deletions: aicirt/src/bintokens.rs
--- a/aicirt/src/bintokens.rs
+++ b/aicirt/src/bintokens.rs
@@ -3,7 +3,7 @@ use aici_abi::bytes::TokRxInfo;
 use anyhow::{anyhow, bail, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::BTreeMap;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
+use tokenizers::{normalizers::Sequence, FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 
 #[derive(Serialize, Deserialize)]
 pub struct ByteTokenizer {
@@ -190,11 +190,28 @@ fn from_hex(hex_str: &str) -> Result<Vec<u8>> {
 }
 
 impl ByteTokenizer {
-    pub fn from_tokenizer(hft: Tokenizer) -> Result<ByteTokenizer> {
+    pub fn from_tokenizer(mut hft: Tokenizer) -> Result<ByteTokenizer> {
         let mut is_byte_level = false;
         let mut is_byte_fallback = false;
         let mut space_ch = ' ';
 
+        // remove the "Prepend space"
+        if let Some(n) = hft.get_normalizer() {
+            let n = match n {
+                NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new(
+                    x.get_normalizers()
+                        .iter()
+                        .filter_map(|n| match n {
+                            NormalizerWrapper::Prepend(_) => None,
+                            _ => Some(n.clone()),
+                        })
+                        .collect(),
+                )),
+                _ => n.clone(),
+            };
+            hft.with_normalizer(n);
+        }
+
         if let Some(d) = hft.get_decoder() {
             let v = serde_json::to_value(d).unwrap();
             if v["type"].as_str() == Some("ByteLevel") {
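For context: the added block filters `NormalizerWrapper::Prepend(..)` out of the tokenizer's normalizer chain, so byte-fallback (SentencePiece-style) tokenizers such as Llama's no longer inject a phantom leading "▁"/space when text is encoded. Below is a minimal standalone sketch of the same idea, not part of the commit; the model id and the helper name `strip_prepend` are illustrative assumptions, and `Tokenizer::from_pretrained` requires the `tokenizers` crate's `http` feature plus network access.

```rust
// Sketch only: strip Prepend normalizers from a HF tokenizer and compare encodings.
use tokenizers::{normalizers::Sequence, NormalizerWrapper, Tokenizer};

// Drop any Prepend normalizer (e.g. Prepend("▁") on Llama-style tokenizers)
// from a normalizer Sequence, mirroring the filter_map in the commit.
fn strip_prepend(hft: &mut Tokenizer) {
    if let Some(n) = hft.get_normalizer() {
        let n = match n {
            NormalizerWrapper::Sequence(x) => NormalizerWrapper::Sequence(Sequence::new(
                x.get_normalizers()
                    .iter()
                    .filter_map(|n| match n {
                        NormalizerWrapper::Prepend(_) => None,
                        _ => Some(n.clone()),
                    })
                    .collect(),
            )),
            _ => n.clone(),
        };
        hft.with_normalizer(n);
    }
}

fn main() -> tokenizers::Result<()> {
    // Assumed tokenizer-only repo on the Hugging Face Hub; any byte-fallback
    // tokenizer with a Prepend normalizer in its chain would do.
    let mut hft = Tokenizer::from_pretrained("hf-internal-testing/llama-tokenizer", None)?;

    // With Prepend("▁") active, encoding "foo" behaves like encoding " foo".
    let before = hft.encode("foo", false)?;
    strip_prepend(&mut hft);
    // After stripping, no leading space is injected into the normalized input.
    let after = hft.encode("foo", false)?;

    println!("with Prepend:    {:?}", before.get_ids());
    println!("without Prepend: {:?}", after.get_ids());
    Ok(())
}
```

The expected difference is that the first encoding starts with a "▁foo"-style token while the second tokenizes the literal "foo", which is the space-prepending the commit message refers to removing.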
