diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 2a7f866a..78b407c9 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -4,7 +4,7 @@ use std::io::BufRead; use anyhow::{anyhow, Result}; use byteorder::{LittleEndian, ReadBytesExt}; -use crate::model::Model; +use crate::model::{DictWeight, Model}; struct KyteaConfig { _model_tag: String, @@ -430,13 +430,13 @@ impl TryFrom for Model { if let Some(kytea_dict) = model.dict { for (w, data) in kytea_dict.dump_items() { let word_len = std::cmp::min(w.len(), config.dict_n as usize) - 1; - let mut weights = [0i32; 3]; + let mut weights = DictWeight::default(); for j in 0..kytea_dict.n_dicts as usize { if data.in_dict >> j & 1 == 1 { let offset = 3 * config.dict_n as usize * j + 3 * word_len; - weights[0] += feature_lookup.dict_vec[offset] as i32; - weights[1] += feature_lookup.dict_vec[offset + 1] as i32; - weights[2] += feature_lookup.dict_vec[offset + 2] as i32; + weights.right += feature_lookup.dict_vec[offset] as i32; + weights.inner += feature_lookup.dict_vec[offset + 1] as i32; + weights.left += feature_lookup.dict_vec[offset + 2] as i32; } } dict_weights.push(weights); diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index 357a85af..b2465346 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -23,6 +23,13 @@ pub type ScoreValue = f64; #[cfg(feature = "model-quantize")] pub type ScoreValue = i32; +#[derive(Clone, Copy, Default, Serialize, Deserialize)] +pub struct DictWeight { + pub right: ScoreValue, + pub inner: ScoreValue, + pub left: ScoreValue, +} + /// Model data. #[derive(Serialize, Deserialize)] pub struct Model { @@ -32,7 +39,7 @@ pub struct Model { pub(crate) word_weights: Vec>, pub(crate) type_weights: Vec>, - pub(crate) dict_weights: Vec<[ScoreValue; 3]>, + pub(crate) dict_weights: Vec, #[cfg(feature = "model-quantize")] pub(crate) quantize_multiplier: f64, @@ -102,9 +109,7 @@ impl Model { let mut types = vec![]; let mut word_weights = vec![]; let mut type_weights = vec![]; - let mut dict_weights: Vec<[_; 3]> = (0..dict_word_max_size) - .map(|_| [ScoreValue::default(); 3]) - .collect(); + let mut dict_weights = vec![DictWeight::default(); dict_word_max_size]; let mut word_ids = StringIdManager::new(); let mut type_ids = StringIdManager::new(); @@ -155,9 +160,12 @@ impl Model { } type_weights[id][feature.rel_position] = weight as WeightValue; } - FeatureContent::DictionaryWord(size) => { - dict_weights[size - 1][feature.rel_position] = weight as ScoreValue; - } + FeatureContent::DictionaryWord(size) => match feature.rel_position { + 0 => dict_weights[size - 1].right = weight as ScoreValue, + 1 => dict_weights[size - 1].inner = weight as ScoreValue, + 2 => dict_weights[size - 1].left = weight as ScoreValue, + _ => panic!("Invalid rel_position"), + }, }; } Self { diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index b8ea4f1a..a2a6f245 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -11,7 +11,7 @@ use std::thread; #[cfg(feature = "multithreading")] use crossbeam_channel::{Receiver, Sender}; -use crate::model::{Model, ScoreValue, WeightValue}; +use crate::model::{DictWeight, Model, ScoreValue}; use crate::sentence::{BoundaryType, Sentence}; use crate::type_scorer::TypeScorer; @@ -22,7 +22,7 @@ pub struct Predictor { word_pma: DoubleArrayAhoCorasick, dict_pma: DoubleArrayAhoCorasick, word_weights: Vec>, - dict_weights: Vec<[ScoreValue; 3]>, + dict_weights: Vec, dict_word_wise: bool, bias: ScoreValue, char_window_size: usize, @@ -46,21 +46,40 @@ impl Predictor { /// A new predictor. pub fn new(model: Model) -> Self { let bias = model.bias; - let word_weights = Self::merge_weights(&model.words, &model.word_weights); - let type_weights = Self::merge_weights(&model.types, &model.type_weights); + + let words = model.words; + let dict = model.dict; let dict_weights = model.dict_weights; + let mut word_weights: Vec<_> = model + .word_weights + .into_iter() + .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) + .collect(); + let type_weights: Vec<_> = model + .type_weights + .into_iter() + .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) + .collect(); + + let (dict, dict_weights) = Self::merge_dict_weights( + dict, + dict_weights, + &words, + &mut word_weights, + model.char_window_size, + model.dict_word_wise, + ); + + let word_weights = Self::merge_weights(&words, &word_weights); + let type_weights = Self::merge_weights(&model.types, &type_weights); + #[cfg(feature = "model-quantize")] let bias = bias as i32; - #[cfg(feature = "model-quantize")] - let dict_weights = dict_weights - .iter() - .map(|ws| [ws[0] as i32, ws[1] as i32, ws[2] as i32]) - .collect(); - let word_pma = DoubleArrayAhoCorasick::new(model.words).unwrap(); + let word_pma = DoubleArrayAhoCorasick::new(words).unwrap(); let type_pma = DoubleArrayAhoCorasick::new(model.types).unwrap(); - let dict_pma = DoubleArrayAhoCorasick::new(model.dict).unwrap(); + let dict_pma = DoubleArrayAhoCorasick::new(dict).unwrap(); let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size); @@ -81,7 +100,63 @@ impl Predictor { } } - fn merge_weights(words: &[Vec], weights: &[Vec]) -> Vec> { + fn merge_dict_weights( + dict: Vec>, + dict_weights: Vec, + words: &[Vec], + word_weights: &mut Vec>, + char_window_size: usize, + dict_word_wise: bool, + ) -> (Vec>, Vec) { + let mut word_map = HashMap::new(); + for (i, word) in words.iter().cloned().enumerate() { + word_map.insert(word, i); + } + let mut new_dict = vec![]; + if dict_word_wise { + let mut new_dict_weights = vec![]; + for (word, weight) in dict.into_iter().zip(dict_weights) { + let word_size = std::str::from_utf8(&word).unwrap().chars().count(); + match word_map.get(&word) { + Some(&idx) if char_window_size >= word_size => { + let start = char_window_size - word_size; + let end = start + word_size; + word_weights[idx][start] += weight.right; + for i in start + 1..end { + word_weights[idx][i] += weight.inner; + } + word_weights[idx][end] += weight.left; + } + _ => { + new_dict.push(word); + new_dict_weights.push(weight); + } + } + } + (new_dict, new_dict_weights) + } else { + for word in dict { + let word_size = std::str::from_utf8(&word).unwrap().chars().count(); + match word_map.get(&word) { + Some(&idx) if char_window_size >= word_size => { + let start = char_window_size - word_size; + let end = start + word_size; + let word_size_idx = std::cmp::min(word_size, dict_weights.len()) - 1; + let weight = &dict_weights[word_size_idx]; + word_weights[idx][start] += weight.right; + for i in start + 1..end { + word_weights[idx][i] += weight.inner; + } + word_weights[idx][end] += weight.left; + } + _ => new_dict.push(word), + } + } + (new_dict, dict_weights) + } + } + + fn merge_weights(words: &[Vec], weights: &[Vec]) -> Vec> { let mut result = vec![]; let word_ids = words .iter() @@ -95,11 +170,10 @@ impl Predictor { if let Some(&idx) = word_ids.get(&seq[st..]) { if let Some(new_weights) = new_weights.as_mut() { for (w_new, w) in new_weights.iter_mut().zip(&weights[idx]) { - *w_new += *w as ScoreValue; + *w_new += *w; } } else { - new_weights - .replace(weights[idx].iter().map(|&w| w as ScoreValue).collect()); + new_weights.replace(weights[idx].clone()); } } } @@ -160,19 +234,19 @@ impl Predictor { } else { std::cmp::min(m_end - m_start, self.dict_weights.len()) - 1 }; - let [w_right, w_center, w_left] = self.dict_weights[idx]; + let dict_weight = self.dict_weights[idx]; if m_start >= padding && m_start < padding + ys.len() { - ys[m_start - padding] += w_right; + ys[m_start - padding] += dict_weight.right; } let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1); let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize); if range_start < range_end { for y in &mut ys[range_start as usize..range_end as usize] { - *y += w_center; + *y += dict_weight.inner; } } if m_end >= padding && m_end < ys.len() + padding { - ys[m_end - padding] += w_left; + ys[m_end - padding] += dict_weight.left; } } } @@ -553,9 +627,31 @@ mod tests { vec![37, 38, 39], ], #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![[20.0, 20.5, 21.0], [21.5, 22.0, 22.5]], + dict_weights: vec![ + DictWeight { + right: 20.0, + inner: 20.5, + left: 21.0, + }, + DictWeight { + right: 21.5, + inner: 22.0, + left: 22.5, + }, + ], #[cfg(feature = "model-quantize")] - dict_weights: vec![[40, 41, 42], [43, 44, 45]], + dict_weights: vec![ + DictWeight { + right: 40, + inner: 41, + left: 42, + }, + DictWeight { + right: 43, + inner: 44, + left: 45, + }, + ], #[cfg(feature = "model-quantize")] quantize_multiplier: 0.5, dict_word_wise: false, @@ -640,9 +736,41 @@ mod tests { vec![33, 34, 35, 36, 37], ], #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![[9.5, 9.75, 10.0], [10.25, 10.5, 10.75], [11.0, 11.25, 11.5]], + dict_weights: vec![ + DictWeight { + right: 9.5, + inner: 9.75, + left: 10.0, + }, + DictWeight { + right: 10.25, + inner: 10.5, + left: 10.75, + }, + DictWeight { + right: 11.0, + inner: 11.25, + left: 11.5, + }, + ], #[cfg(feature = "model-quantize")] - dict_weights: vec![[38, 39, 40], [41, 42, 43], [44, 45, 46]], + dict_weights: vec![ + DictWeight { + right: 38, + inner: 39, + left: 40, + }, + DictWeight { + right: 41, + inner: 42, + left: 43, + }, + DictWeight { + right: 44, + inner: 45, + left: 46, + }, + ], #[cfg(feature = "model-quantize")] quantize_multiplier: 0.25, dict_word_wise: false, @@ -727,9 +855,41 @@ mod tests { vec![33, 34, 35, 36, 37], ], #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![[9.5, 9.75, 11.0], [10.25, 10.5, 10.75], [11.0, 11.25, 11.5]], + dict_weights: vec![ + DictWeight { + right: 9.5, + inner: 9.75, + left: 11.0, + }, + DictWeight { + right: 10.25, + inner: 10.5, + left: 10.75, + }, + DictWeight { + right: 11.0, + inner: 11.25, + left: 11.5, + }, + ], #[cfg(feature = "model-quantize")] - dict_weights: vec![[38, 39, 40], [41, 42, 43], [44, 45, 46]], + dict_weights: vec![ + DictWeight { + right: 38, + inner: 39, + left: 40, + }, + DictWeight { + right: 41, + inner: 42, + left: 43, + }, + DictWeight { + right: 44, + inner: 45, + left: 46, + }, + ], #[cfg(feature = "model-quantize")] quantize_multiplier: 0.25, dict_word_wise: true, diff --git a/vaporetto/src/utils.rs b/vaporetto/src/utils.rs index b3392ccb..47b51b80 100644 --- a/vaporetto/src/utils.rs +++ b/vaporetto/src/utils.rs @@ -56,6 +56,7 @@ impl StringIdManager { } #[cfg(test)] +#[allow(unused_macros)] macro_rules! ct2u8 { ( $( $v:path ),* ) => { ct2u8!( $( $v, )* )