From 3509c8a4b2403b94437d27bdec7958bf05cb5129 Mon Sep 17 00:00:00 2001
From: Koichi Akabe
Date: Fri, 19 Nov 2021 19:36:58 +0900
Subject: [PATCH 01/60] Use Option for dict_pma

---
 vaporetto/src/predictor.rs | 72 +++++++++++++++++++++-----------------
 1 file changed, 39 insertions(+), 33 deletions(-)

diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index a2a6f245..c50cf54e 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -20,7 +20,7 @@ use daachorse::DoubleArrayAhoCorasick;
 /// Predictor.
 pub struct Predictor {
     word_pma: DoubleArrayAhoCorasick,
-    dict_pma: DoubleArrayAhoCorasick,
+    dict_pma: Option<DoubleArrayAhoCorasick>,
     word_weights: Vec<Vec<ScoreValue>>,
     dict_weights: Vec<DictWeight>,
     dict_word_wise: bool,
@@ -79,7 +79,11 @@ impl Predictor {
 
         let word_pma = DoubleArrayAhoCorasick::new(words).unwrap();
         let type_pma = DoubleArrayAhoCorasick::new(model.types).unwrap();
-        let dict_pma = DoubleArrayAhoCorasick::new(dict).unwrap();
+        let dict_pma = if dict.is_empty() {
+            None
+        } else {
+            Some(DoubleArrayAhoCorasick::new(dict).unwrap())
+        };
 
         let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size);
 
@@ -213,40 +217,42 @@ impl Predictor {
     }
 
     fn add_dict_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
-        let char_start = if start >= self.dict_window_size {
-            start + 1 - self.dict_window_size
-        } else {
-            0
-        };
-        let text_start = sentence.char_to_str_pos[char_start];
-        let char_end = std::cmp::min(
-            start + ys.len() + self.dict_window_size,
-            sentence.char_to_str_pos.len() - 1,
-        );
-        let text_end = sentence.char_to_str_pos[char_end];
-        let text = &sentence.text[text_start..text_end];
-        let padding = start - char_start + 1;
-        for m in self.dict_pma.find_overlapping_iter(&text) {
-            let m_start = sentence.str_to_char_pos[m.start() + text_start] - char_start;
-            let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start;
-            let idx = if self.dict_word_wise {
-                m.pattern()
+        if let Some(dict_pma) = self.dict_pma.as_ref() {
+            let char_start = if start >= self.dict_window_size {
+                start + 1 - self.dict_window_size
             } else {
-                std::cmp::min(m_end - m_start, self.dict_weights.len()) - 1
+                0
             };
-            let dict_weight = self.dict_weights[idx];
-            if m_start >= padding && m_start < padding + ys.len() {
-                ys[m_start - padding] += dict_weight.right;
-            }
-            let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1);
-            let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize);
-            if range_start < range_end {
-                for y in &mut ys[range_start as usize..range_end as usize] {
-                    *y += dict_weight.inner;
+            let text_start = sentence.char_to_str_pos[char_start];
+            let char_end = std::cmp::min(
+                start + ys.len() + self.dict_window_size,
+                sentence.char_to_str_pos.len() - 1,
+            );
+            let text_end = sentence.char_to_str_pos[char_end];
+            let text = &sentence.text[text_start..text_end];
+            let padding = start - char_start + 1;
+            for m in dict_pma.find_overlapping_iter(&text) {
+                let m_start = sentence.str_to_char_pos[m.start() + text_start] - char_start;
+                let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start;
+                let idx = if self.dict_word_wise {
+                    m.pattern()
+                } else {
+                    std::cmp::min(m_end - m_start, self.dict_weights.len()) - 1
+                };
+                let dict_weight = self.dict_weights[idx];
+                if m_start >= padding && m_start < padding + ys.len() {
+                    ys[m_start - padding] += dict_weight.right;
+                }
+                let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1);
+                let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize);
+                if range_start < range_end {
+                    for y in &mut ys[range_start as usize..range_end as usize] {
+                        *y += dict_weight.inner;
+                    }
+                }
+                if m_end >= padding && m_end < ys.len() + padding {
+                    ys[m_end - padding] += dict_weight.left;
                 }
-            }
-            if m_end >= padding && m_end < ys.len() + padding {
-                ys[m_end - padding] += dict_weight.left;
             }
         }
     }
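Editorial note: wrapping `dict_pma` in `Option` lets a model without dictionary entries skip automaton construction entirely, presumably because daachorse cannot build an automaton from an empty pattern set. A minimal sketch of the pattern this patch applies, assuming daachorse 0.2 as pinned in Cargo.toml (the helper name is ours, not part of the crate):

```rust
use daachorse::DoubleArrayAhoCorasick;

// Hypothetical helper mirroring the patch: build the automaton only when
// patterns exist; scoring code then guards on the Option and becomes a no-op.
fn build_optional_pma(patterns: Vec<Vec<u8>>) -> Option<DoubleArrayAhoCorasick> {
    if patterns.is_empty() {
        None
    } else {
        Some(DoubleArrayAhoCorasick::new(patterns).unwrap())
    }
}
```
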
From d7439b05412147f21452ececeaca0ef368d75058 Mon Sep 17 00:00:00 2001
From: Koichi Akabe
Date: Thu, 25 Nov 2021 10:57:36 +0900
Subject: [PATCH 02/60] Remove multi-threading feature

---
 vaporetto/Cargo.toml       |   2 -
 vaporetto/src/lib.rs       |   3 -
 vaporetto/src/predictor.rs | 182 -------------------------------------
 3 files changed, 187 deletions(-)

diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml
index 18f1baf0..ca1531c6 100644
--- a/vaporetto/Cargo.toml
+++ b/vaporetto/Cargo.toml
@@ -19,14 +19,12 @@ daachorse = "0.2.0" # MIT or Apache-2.0
 serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0
 
 byteorder = { version = "1.4", optional = true } # Unlicense or MIT
-crossbeam-channel = { version = "0.5", optional = true } # MIT or Apache-2.0
 liblinear = { version = "1", optional = true } # MIT
 
 [features]
 default = ["model-quantize"]
 kytea = ["byteorder"]
 model-quantize = []
-multithreading = ["crossbeam-channel"]
 train = ["liblinear"]
 
 [package.metadata.docs.rs]
diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs
index d7107e66..662f7b8c 100644
--- a/vaporetto/src/lib.rs
+++ b/vaporetto/src/lib.rs
@@ -46,9 +46,6 @@ pub use model::Model;
 pub use predictor::Predictor;
 pub use sentence::{BoundaryType, CharacterType, Sentence};
 
-#[cfg(feature = "multithreading")]
-pub use predictor::MultithreadPredictor;
-
 #[cfg(feature = "train")]
 pub use trainer::{Dataset, SolverType, Trainer};
 
diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index a2a6f245..91aadd6a 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -1,16 +1,6 @@
 use std::collections::HashMap;
 use std::ops::Range;
 
-#[cfg(feature = "multithreading")]
-use std::cell::RefCell;
-#[cfg(feature = "multithreading")]
-use std::sync::Arc;
-#[cfg(feature = "multithreading")]
-use std::thread;
-
-#[cfg(feature = "multithreading")]
-use crossbeam_channel::{Receiver, Sender};
-
 use crate::model::{DictWeight, Model, ScoreValue};
 use crate::sentence::{BoundaryType, Sentence};
 use crate::type_scorer::TypeScorer;
@@ -377,178 +367,6 @@ impl Predictor {
         self.dict_window_size = std::cmp::max(size, 1);
         self
     }
-
-    /// Creates a multithreading predictor. This function is the alias of
-    /// [`MultithreadPredictor::new()`].
-    ///
-    /// # Arguments
-    ///
-    /// * `n_threads` - The number of threads.
-    /// * `chunk_size` - The chunk size of each thread.
-    ///
-    /// # Returns
-    ///
-    /// A multithread predictor.
-    #[cfg(feature = "multithreading")]
-    #[cfg_attr(docsrs, doc(cfg(feature = "multithreading")))]
-    pub fn multithreading(self, n_threads: usize, chunk_size: usize) -> MultithreadPredictor {
-        MultithreadPredictor::new(self, n_threads, chunk_size)
-    }
-}
-
-/// Predictor for multithreading.
-#[cfg(feature = "multithreading")]
-#[cfg_attr(docsrs, doc(cfg(feature = "multithreading")))]
-pub struct MultithreadPredictor {
-    task_tx: Sender<(Arc<Sentence>, Range<usize>, Vec<ScoreValue>)>,
-    result_rx: Receiver<(Vec<ScoreValue>, Range<usize>)>,
-    chunk_size: usize,
-    ys_pool: RefCell<Vec<Vec<ScoreValue>>>,
-
-    #[cfg(feature = "model-quantize")]
-    quantize_multiplier: f64,
-}
-
-#[cfg(feature = "multithreading")]
-impl MultithreadPredictor {
-    /// Creates a multithreading predictor.
-    ///
-    /// # Arguments
-    ///
-    /// * `predictor` - A normal predictor.
-    /// * `n_threads` - The number of threads.
-    /// * `chunk_size` - The chunk size of each thread.
-    ///
-    /// # Returns
-    ///
-    /// A multithread predictor.
-    pub fn new(predictor: Predictor, n_threads: usize, chunk_size: usize) -> Self {
-        let predictor = Arc::new(predictor);
-
-        let (result_tx, result_rx) = crossbeam_channel::unbounded();
-        let (task_tx, task_rx) =
-            crossbeam_channel::unbounded::<(Arc<Sentence>, Range<usize>, Vec<ScoreValue>)>();
-        for _ in 0..n_threads {
-            let predictor = Arc::clone(&predictor);
-            let result_tx = result_tx.clone();
-            let task_rx = task_rx.clone();
-            thread::spawn(move || {
-                for (sentence, range, mut ys) in task_rx {
-                    predictor.predict_partial_impl(
-                        &sentence,
-                        range.clone(),
-                        &mut ys[..range.len()],
-                    );
-                    std::mem::drop(sentence);
-                    result_tx.send((ys, range)).unwrap();
-                }
-            });
-        }
-
-        Self {
-            task_tx,
-            result_rx,
-            chunk_size,
-            ys_pool: RefCell::new(vec![]),
-
-            #[cfg(feature = "model-quantize")]
-            quantize_multiplier: predictor.quantize_multiplier,
-        }
-    }
-
-    /// Predicts word boundaries.
-    ///
-    /// # Arguments
-    ///
-    /// * `sentence` - A sentence.
-    ///
-    /// # Returns
-    ///
-    /// A sentence with predicted boundary information.
-    pub fn predict(&self, sentence: Sentence) -> Sentence {
-        let sentence = Arc::new(sentence);
-
-        let mut n_chunks = 0;
-        let mut ys_pool = self.ys_pool.borrow_mut();
-        for start in (0..sentence.boundaries.len()).step_by(self.chunk_size) {
-            let ys = ys_pool
-                .pop()
-                .unwrap_or_else(|| vec![ScoreValue::default(); self.chunk_size]);
-            let sentence = Arc::clone(&sentence);
-            let end = std::cmp::min(start + self.chunk_size, sentence.boundaries.len());
-            self.task_tx.send((sentence, start..end, ys)).unwrap();
-            n_chunks += 1;
-        }
-        let mut boundaries = vec![BoundaryType::Unknown; sentence.boundaries.len()];
-        for _ in 0..n_chunks {
-            let (ys, range) = self.result_rx.recv().unwrap();
-            for (&y, b) in ys.iter().zip(&mut boundaries[range]) {
-                *b = if y >= ScoreValue::default() {
-                    BoundaryType::WordBoundary
-                } else {
-                    BoundaryType::NotWordBoundary
-                };
-            }
-            ys_pool.push(ys);
-        }
-
-        let mut sentence = Arc::try_unwrap(sentence).unwrap();
-        sentence.boundaries = boundaries;
-        sentence
-    }
-
-    /// Predicts word boundaries. This function inserts scores.
-    ///
-    /// # Arguments
-    ///
-    /// * `sentence` - A sentence.
-    ///
-    /// # Returns
-    ///
-    /// A sentence with predicted boundary information.
-    pub fn predict_with_score(&self, mut sentence: Sentence) -> Sentence {
-        let mut scores = sentence
-            .boundary_scores
-            .take()
-            .unwrap_or_else(|| vec![0.; sentence.boundaries.len()]);
-        let sentence = Arc::new(sentence);
-        let mut n_chunks = 0;
-        let mut ys_pool = self.ys_pool.borrow_mut();
-        for start in (0..sentence.boundaries.len()).step_by(self.chunk_size) {
-            let ys = ys_pool
-                .pop()
-                .unwrap_or_else(|| vec![ScoreValue::default(); self.chunk_size]);
-            let sentence = Arc::clone(&sentence);
-            let end = std::cmp::min(start + self.chunk_size, sentence.boundaries.len());
-            self.task_tx.send((sentence, start..end, ys)).unwrap();
-            n_chunks += 1;
-        }
-        let mut boundaries = vec![BoundaryType::Unknown; sentence.boundaries.len()];
-        for _ in 0..n_chunks {
-            let (ys, range) = self.result_rx.recv().unwrap();
-            for (&y, (b, s)) in ys
-                .iter()
-                .zip(boundaries[range.clone()].iter_mut().zip(&mut scores[range]))
-            {
-                *b = if y >= ScoreValue::default() {
-                    BoundaryType::WordBoundary
-                } else {
-                    BoundaryType::NotWordBoundary
-                };
-
-                #[cfg(feature = "model-quantize")]
-                let y = y as f64 * self.quantize_multiplier;
-
-                *s = y;
-            }
-            ys_pool.push(ys);
-        }
-
-        let mut sentence = Arc::try_unwrap(sentence).unwrap();
-        sentence.boundaries = boundaries;
-        sentence.boundary_scores.replace(scores);
-        sentence
-    }
 }
 
 #[cfg(test)]
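Editorial note: with the built-in wrapper gone, callers that want parallel prediction can reproduce the removed pattern themselves; the deleted code shared one `Predictor` across worker threads via `Arc`, so it relied on `Predictor` being `Send + Sync`. A minimal sketch under that same assumption (the model path and per-line threading are placeholders, not the crate's API):

```rust
use std::sync::Arc;
use std::thread;

use vaporetto::{Model, Predictor, Sentence};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // "model.bin" is a placeholder path for a serialized vaporetto model.
    let model = Model::read(&mut std::fs::File::open("model.bin")?)?;
    let predictor = Arc::new(Predictor::new(model));

    let lines = vec!["まぁ社長は火星猫だ".to_string(), "火星猫の生態".to_string()];
    let handles: Vec<_> = lines
        .into_iter()
        .map(|line| {
            // Share the predictor across threads, as the removed code did.
            let predictor = Arc::clone(&predictor);
            thread::spawn(move || predictor.predict(Sentence::from_raw(line).unwrap()))
        })
        .collect();
    for handle in handles {
        let s = handle.join().unwrap();
        println!("{}", s.to_tokenized_vec().unwrap().join(" "));
    }
    Ok(())
}
```
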
From d4f06e71ce67ee96f7f42d5ce5c32077141786f6 Mon Sep 17 00:00:00 2001
From: Koichi Akabe
Date: Thu, 25 Nov 2021 13:03:17 +0900
Subject: [PATCH 03/60] Add scorer modules for char_ngrams and dict

---
 vaporetto/src/char_scorer.rs |  73 +++++++++++++++++++++
 vaporetto/src/dict_scorer.rs |  94 ++++++++++++++++++++++++++
 vaporetto/src/lib.rs         |   2 +
 vaporetto/src/predictor.rs   | 123 +++++++----------------------------
 4 files changed, 194 insertions(+), 98 deletions(-)
 create mode 100644 vaporetto/src/char_scorer.rs
 create mode 100644 vaporetto/src/dict_scorer.rs

diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs
new file mode 100644
index 00000000..4256fa34
--- /dev/null
+++ b/vaporetto/src/char_scorer.rs
@@ -0,0 +1,73 @@
+use crate::model::ScoreValue;
+use crate::sentence::Sentence;
+use daachorse::DoubleArrayAhoCorasick;
+
+pub enum CharScorer {
+    Pma(CharScorerPma),
+}
+
+impl CharScorer {
+    pub fn new(
+        pma: DoubleArrayAhoCorasick,
+        weights: Vec<Vec<ScoreValue>>,
+        window_size: usize,
+    ) -> Self {
+        Self::Pma(CharScorerPma::new(pma, weights, window_size))
+    }
+
+    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
+        match self {
+            CharScorer::Pma(pma) => pma.add_scores(sentence, start, ys),
+        }
+    }
+}
+
+pub struct CharScorerPma {
+    pma: DoubleArrayAhoCorasick,
+    weights: Vec<Vec<ScoreValue>>,
+    window_size: usize,
+}
+
+impl CharScorerPma {
+    pub fn new(
+        pma: DoubleArrayAhoCorasick,
+        weights: Vec<Vec<ScoreValue>>,
+        window_size: usize,
+    ) -> Self {
+        Self {
+            pma,
+            weights,
+            window_size,
+        }
+    }
+
+    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
+        let char_start = if start >= self.window_size {
+            start + 1 - self.window_size
+        } else {
+            0
+        };
+        let text_start = sentence.char_to_str_pos[char_start];
+        let char_end = std::cmp::min(
+            start + ys.len() + self.window_size,
+            sentence.char_to_str_pos.len() - 1,
+        );
+        let text_end = sentence.char_to_str_pos[char_end];
+        let text = &sentence.text[text_start..text_end];
+        let padding = start - char_start + 1;
+        for m in self.pma.find_overlapping_no_suffix_iter(&text) {
+            let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start;
+            let offset = m_end as isize - self.window_size as isize - padding as isize;
+            let weights = &self.weights[m.pattern()];
+            if offset >= 0 {
+                for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) {
+                    *y += w;
+                }
+            } else {
+                for (w, y) in weights[-offset as usize..].iter().zip(ys.iter_mut()) {
+                    *y += w;
+                }
+            }
+        }
+    }
+}
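Editorial note: the offset arithmetic in `add_scores` is easiest to see on a toy example. Each n-gram's weight vector covers the boundaries starting `window_size` characters before the match end, and `padding` re-bases character positions onto the `ys` slice; a negative offset means the leading weights fall off the left edge and are clipped. A standalone illustration (our own helper on plain integers, not crate code):

```rust
// Mirrors the loop body of CharScorerPma::add_scores.
fn apply_weights(ys: &mut [i32], weights: &[i32], m_end: usize, window_size: usize, padding: usize) {
    let offset = m_end as isize - window_size as isize - padding as isize;
    if offset >= 0 {
        for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) {
            *y += w;
        }
    } else {
        for (w, y) in weights[-offset as usize..].iter().zip(ys.iter_mut()) {
            *y += w;
        }
    }
}

fn main() {
    // A match ending at char position 2 with window 2 and padding 1 starts one
    // slot left of the slice, so its first weight is clipped away.
    let mut ys = [0; 4];
    apply_weights(&mut ys, &[10, 20, 30], 2, 2, 1);
    assert_eq!(ys, [20, 30, 0, 0]);

    // A match ending further right lands fully inside the slice.
    let mut ys = [0; 4];
    apply_weights(&mut ys, &[10, 20, 30], 4, 2, 1);
    assert_eq!(ys, [0, 10, 20, 30]);
}
```
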
diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs
new file mode 100644
index 00000000..40dae9cb
--- /dev/null
+++ b/vaporetto/src/dict_scorer.rs
@@ -0,0 +1,94 @@
+use crate::model::{DictWeight, ScoreValue};
+use crate::sentence::Sentence;
+use daachorse::DoubleArrayAhoCorasick;
+
+pub enum DictScorer {
+    Pma(DictScorerPma),
+}
+
+impl DictScorer {
+    pub fn new(
+        pma: DoubleArrayAhoCorasick,
+        weights: Vec<DictWeight>,
+        word_wise_score: bool,
+    ) -> Self {
+        Self::Pma(DictScorerPma::new(pma, weights, word_wise_score))
+    }
+
+    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
+        match self {
+            DictScorer::Pma(pma) => pma.add_scores(sentence, start, ys),
+        }
+    }
+
+    pub fn window_size(&mut self, size: usize) {
+        match self {
+            DictScorer::Pma(pma) => pma.window_size(size),
+        }
+    }
+}
+
+pub struct DictScorerPma {
+    pma: DoubleArrayAhoCorasick,
+    weights: Vec<DictWeight>,
+    window_size: usize,
+    word_wise_score: bool,
+}
+
+impl DictScorerPma {
+    pub fn new(
+        pma: DoubleArrayAhoCorasick,
+        weights: Vec<DictWeight>,
+        word_wise_score: bool,
+    ) -> Self {
+        Self {
+            pma,
+            weights,
+            window_size: 1,
+            word_wise_score,
+        }
+    }
+
+    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
+        let char_start = if start >= self.window_size {
+            start + 1 - self.window_size
+        } else {
+            0
+        };
+        let text_start = sentence.char_to_str_pos[char_start];
+        let char_end = std::cmp::min(
+            start + ys.len() + self.window_size,
+            sentence.char_to_str_pos.len() - 1,
+        );
+        let text_end = sentence.char_to_str_pos[char_end];
+        let text = &sentence.text[text_start..text_end];
+        let padding = start - char_start + 1;
+        for m in self.pma.find_overlapping_iter(&text) {
+            let m_start = sentence.str_to_char_pos[m.start() + text_start] - char_start;
+            let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start;
+            let idx = if self.word_wise_score {
+                m.pattern()
+            } else {
+                std::cmp::min(m_end - m_start, self.weights.len()) - 1
+            };
+            let dict_weight = self.weights[idx];
+            if m_start >= padding && m_start < padding + ys.len() {
+                ys[m_start - padding] += dict_weight.right;
+            }
+            let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1);
+            let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize);
+            if range_start < range_end {
+                for y in &mut ys[range_start as usize..range_end as usize] {
+                    *y += dict_weight.inner;
+                }
+            }
+            if m_end >= padding && m_end < ys.len() + padding {
+                ys[m_end - padding] += dict_weight.left;
+            }
+        }
+    }
+
+    pub fn window_size(&mut self, size: usize) {
+        self.window_size = std::cmp::max(size, 1);
+    }
+}
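Editorial note: in the non-word-wise mode above, a dictionary match selects its `DictWeight` by character length, and every word at least as long as the weight table falls into the last bucket. A quick standalone check of that indexing (hypothetical helper, not crate code):

```rust
// min(len, n_buckets) - 1, as in DictScorerPma::add_scores.
fn weight_index(match_len: usize, n_buckets: usize) -> usize {
    std::cmp::min(match_len, n_buckets) - 1
}

fn main() {
    assert_eq!(weight_index(1, 4), 0); // single-character words use bucket 0
    assert_eq!(weight_index(4, 4), 3);
    assert_eq!(weight_index(9, 4), 3); // longer words share the last bucket
}
```
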
diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs
index d7107e66..7b545359 100644
--- a/vaporetto/src/lib.rs
+++ b/vaporetto/src/lib.rs
@@ -33,6 +33,8 @@ mod model;
 mod predictor;
 mod sentence;
 mod type_scorer;
+mod char_scorer;
+mod dict_scorer;
 
 #[cfg(feature = "train")]
 mod feature;
diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index c50cf54e..0480c888 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -13,22 +13,19 @@ use crossbeam_channel::{Receiver, Sender};
 
 use crate::model::{DictWeight, Model, ScoreValue};
 use crate::sentence::{BoundaryType, Sentence};
+use crate::char_scorer::CharScorer;
 use crate::type_scorer::TypeScorer;
+use crate::dict_scorer::DictScorer;
 
 use daachorse::DoubleArrayAhoCorasick;
 
 /// Predictor.
 pub struct Predictor {
-    word_pma: DoubleArrayAhoCorasick,
-    dict_pma: Option<DoubleArrayAhoCorasick>,
-    word_weights: Vec<Vec<ScoreValue>>,
-    dict_weights: Vec<DictWeight>,
-    dict_word_wise: bool,
     bias: ScoreValue,
 
-    char_window_size: usize,
-    dict_window_size: usize,
+    char_scorer: CharScorer,
     type_scorer: TypeScorer,
+    dict_scorer: Option<DictScorer>,
 
     #[cfg(feature = "model-quantize")]
     quantize_multiplier: f64,
@@ -47,11 +44,11 @@ impl Predictor {
     pub fn new(model: Model) -> Self {
         let bias = model.bias;
 
-        let words = model.words;
+        let chars = model.words;
         let dict = model.dict;
         let dict_weights = model.dict_weights;
 
-        let mut word_weights: Vec<_> = model
+        let mut char_weights: Vec<_> = model
             .word_weights
             .into_iter()
             .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect())
             .collect();
@@ -65,39 +62,36 @@ impl Predictor {
         let (dict, dict_weights) = Self::merge_dict_weights(
             dict,
             dict_weights,
-            &words,
-            &mut word_weights,
+            &chars,
+            &mut char_weights,
             model.char_window_size,
             model.dict_word_wise,
         );
 
-        let word_weights = Self::merge_weights(&words, &word_weights);
+        let char_weights = Self::merge_weights(&chars, &char_weights);
         let type_weights = Self::merge_weights(&model.types, &type_weights);
 
         #[cfg(feature = "model-quantize")]
         let bias = bias as i32;
 
-        let word_pma = DoubleArrayAhoCorasick::new(words).unwrap();
+        let char_pma = DoubleArrayAhoCorasick::new(chars).unwrap();
         let type_pma = DoubleArrayAhoCorasick::new(model.types).unwrap();
-        let dict_pma = if dict.is_empty() {
+
+        let char_scorer = CharScorer::new(char_pma, char_weights, model.char_window_size);
+        let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size);
+        let dict_scorer = if dict.is_empty() {
             None
         } else {
-            Some(DoubleArrayAhoCorasick::new(dict).unwrap())
+            let dict_pma = DoubleArrayAhoCorasick::new(dict).unwrap();
+            Some(DictScorer::new(dict_pma, dict_weights, model.dict_word_wise))
        };
 
-        let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size);
-
         Self {
-            word_pma,
-            dict_pma,
-            word_weights,
-            dict_weights,
-            dict_word_wise: model.dict_word_wise,
             bias,
-            char_window_size: model.char_window_size,
-            dict_window_size: 1,
+            char_scorer,
             type_scorer,
+            dict_scorer,
 
             #[cfg(feature = "model-quantize")]
             quantize_multiplier: model.quantize_multiplier,
@@ -186,77 +180,6 @@ impl Predictor {
         result
     }
 
-    fn add_word_ngram_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
-        let char_start = if start >= self.char_window_size {
-            start + 1 - self.char_window_size
-        } else {
-            0
-        };
-        let text_start = sentence.char_to_str_pos[char_start];
-        let char_end = std::cmp::min(
-            start + ys.len() + self.char_window_size,
-            sentence.char_to_str_pos.len() - 1,
-        );
-        let text_end = sentence.char_to_str_pos[char_end];
-        let text = &sentence.text[text_start..text_end];
-        let padding = start - char_start + 1;
-        for m in self.word_pma.find_overlapping_no_suffix_iter(&text) {
-            let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start;
-            let offset = m_end as isize - self.char_window_size as isize - padding as isize;
-            let weights = &self.word_weights[m.pattern()];
-            if offset >= 0 {
-                for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) {
-                    *y += w;
-                }
-            } else {
-                for (w, y) in weights[-offset as usize..].iter().zip(ys.iter_mut()) {
-                    *y += w;
-                }
-            }
-        }
-    }
-
-    fn add_dict_scores(&self, sentence: &Sentence,
start: usize, ys: &mut [ScoreValue]) { - if let Some(dict_pma) = self.dict_pma.as_ref() { - let char_start = if start >= self.dict_window_size { - start + 1 - self.dict_window_size - } else { - 0 - }; - let text_start = sentence.char_to_str_pos[char_start]; - let char_end = std::cmp::min( - start + ys.len() + self.dict_window_size, - sentence.char_to_str_pos.len() - 1, - ); - let text_end = sentence.char_to_str_pos[char_end]; - let text = &sentence.text[text_start..text_end]; - let padding = start - char_start + 1; - for m in dict_pma.find_overlapping_iter(&text) { - let m_start = sentence.str_to_char_pos[m.start() + text_start] - char_start; - let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start; - let idx = if self.dict_word_wise { - m.pattern() - } else { - std::cmp::min(m_end - m_start, self.dict_weights.len()) - 1 - }; - let dict_weight = self.dict_weights[idx]; - if m_start >= padding && m_start < padding + ys.len() { - ys[m_start - padding] += dict_weight.right; - } - let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1); - let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize); - if range_start < range_end { - for y in &mut ys[range_start as usize..range_end as usize] { - *y += dict_weight.inner; - } - } - if m_end >= padding && m_end < ys.len() + padding { - ys[m_end - padding] += dict_weight.left; - } - } - } - } - fn predict_partial_impl( &self, sentence: &Sentence, @@ -264,9 +187,11 @@ impl Predictor { ys: &mut [ScoreValue], ) { ys.fill(self.bias); - self.add_word_ngram_scores(sentence, range.start, ys); + self.char_scorer.add_scores(sentence, range.start, ys); self.type_scorer.add_scores(sentence, range.start, ys); - self.add_dict_scores(sentence, range.start, ys); + if let Some(dict_scorer) = self.dict_scorer.as_ref() { + dict_scorer.add_scores(sentence, range.start, ys); + } } /// Predicts word boundaries of the specified range of a sentence. @@ -380,7 +305,9 @@ impl Predictor { /// /// A predictor with the specified window size. 
pub fn dict_window_size(mut self, size: usize) -> Self {
-        self.dict_window_size = std::cmp::max(size, 1);
+        if let Some(dict_scorer) = self.dict_scorer.as_mut() {
+            dict_scorer.window_size(size);
+        }
         self
     }
 
From 1c4370bb6078c83dd5b3bad5801b3a06b8b9f5c8 Mon Sep 17 00:00:00 2001
From: Koichi Akabe
Date: Thu, 25 Nov 2021 13:05:07 +0900
Subject: [PATCH 04/60] Rename some variables

---
 vaporetto/src/kytea_model.rs | 44 +++++++++++++++++----------
 vaporetto/src/lib.rs         |  4 +--
 vaporetto/src/model.rs       | 58 +++++++++++++++++++-----------------
 vaporetto/src/predictor.rs   | 36 ++++++++++++----------
 4 files changed, 80 insertions(+), 62 deletions(-)

diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs
index 78b407c9..60585291 100644
--- a/vaporetto/src/kytea_model.rs
+++ b/vaporetto/src/kytea_model.rs
@@ -409,20 +409,32 @@ impl TryFrom<KyteaModel> for Model {
             .type_dict
             .ok_or_else(|| anyhow!("no type dictionary."))?;
 
-        let mut words: Vec<Vec<u8>> = vec![];
-        let mut word_weights = vec![];
-        for (word, v) in char_dict.dump_items() {
-            let weight_size = config.char_w as usize * 2 - word.len() + 1;
-            words.push(word.into_iter().collect::<String>().as_bytes().to_vec());
-            word_weights.push(v[..weight_size].to_vec());
+        let mut char_ngrams: Vec<Vec<u8>> = vec![];
+        let mut char_ngram_weights = vec![];
+        for (char_ngram, v) in char_dict.dump_items() {
+            let weight_size = config.char_w as usize * 2 - char_ngram.len() + 1;
+            char_ngrams.push(
+                char_ngram
+                    .into_iter()
+                    .collect::<String>()
+                    .as_bytes()
+                    .to_vec(),
+            );
+            char_ngram_weights.push(v[..weight_size].to_vec());
         }
 
-        let mut types: Vec<Vec<u8>> = vec![];
-        let mut type_weights = vec![];
-        for (word, v) in type_dict.dump_items() {
-            let weight_size = config.type_w as usize * 2 - word.len() + 1;
-            types.push(word.into_iter().collect::<String>().as_bytes().to_vec());
-            type_weights.push(v[..weight_size].to_vec());
+        let mut type_ngrams: Vec<Vec<u8>> = vec![];
+        let mut type_ngram_weights = vec![];
+        for (type_ngram, v) in type_dict.dump_items() {
+            let weight_size = config.type_w as usize * 2 - type_ngram.len() + 1;
+            type_ngrams.push(
+                type_ngram
+                    .into_iter()
+                    .collect::<String>()
+                    .as_bytes()
+                    .to_vec(),
+            );
+            type_ngram_weights.push(v[..weight_size].to_vec());
         }
 
         let mut dict: Vec<Vec<u8>> = vec![];
@@ -445,15 +457,15 @@ impl TryFrom<KyteaModel> for Model {
         }
 
         Ok(Self {
-            words,
-            types,
+            char_ngrams,
+            type_ngrams,
             dict,
 
             #[cfg(feature = "model-quantize")]
             quantize_multiplier,
 
-            word_weights,
-            type_weights,
+            char_ngram_weights,
+            type_ngram_weights,
             dict_weights,
             dict_word_wise: true,
             bias,
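Editorial note: the `weight_size` expression in the KyTea conversion above encodes how many boundary positions an n-gram can influence: with `w` characters of context on each side, an n-gram of length `len` overlaps `w * 2 - len + 1` distinct boundaries. A tiny standalone check (our helper, not crate code):

```rust
// Number of relative positions at which an n-gram affects a boundary.
fn weight_size(w: usize, ngram_len: usize) -> usize {
    w * 2 - ngram_len + 1
}

fn main() {
    // With a 3-character window, a unigram touches 6 boundary positions,
    // while a trigram only touches 4.
    assert_eq!(weight_size(3, 1), 6);
    assert_eq!(weight_size(3, 3), 4);
}
```
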
diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs
index 7b545359..8daa5986 100644
--- a/vaporetto/src/lib.rs
+++ b/vaporetto/src/lib.rs
@@ -29,12 +29,12 @@
 #[macro_use]
 mod utils;
 
+mod char_scorer;
+mod dict_scorer;
 mod model;
 mod predictor;
 mod sentence;
 mod type_scorer;
-mod char_scorer;
-mod dict_scorer;
 
 #[cfg(feature = "train")]
 mod feature;
diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs
index b2465346..35747b0e 100644
--- a/vaporetto/src/model.rs
+++ b/vaporetto/src/model.rs
@@ -33,12 +33,12 @@ pub struct DictWeight {
 /// Model data.
 #[derive(Serialize, Deserialize)]
 pub struct Model {
-    pub(crate) words: Vec<Vec<u8>>,
-    pub(crate) types: Vec<Vec<u8>>,
+    pub(crate) char_ngrams: Vec<Vec<u8>>,
+    pub(crate) type_ngrams: Vec<Vec<u8>>,
     pub(crate) dict: Vec<Vec<u8>>,
 
-    pub(crate) word_weights: Vec<Vec<WeightValue>>,
-    pub(crate) type_weights: Vec<Vec<WeightValue>>,
+    pub(crate) char_ngram_weights: Vec<Vec<WeightValue>>,
+    pub(crate) type_ngram_weights: Vec<Vec<WeightValue>>,
     pub(crate) dict_weights: Vec<DictWeight>,
 
     #[cfg(feature = "model-quantize")]
@@ -105,13 +105,13 @@ impl Model {
             .unwrap() as i32;
         let bias = model.label_bias(wb_idx);
 
-        let mut words = vec![];
-        let mut types = vec![];
-        let mut word_weights = vec![];
-        let mut type_weights = vec![];
+        let mut char_ngrams = vec![];
+        let mut type_ngrams = vec![];
+        let mut char_ngram_weights = vec![];
+        let mut type_ngram_weights = vec![];
         let mut dict_weights = vec![DictWeight::default(); dict_word_max_size];
-        let mut word_ids = StringIdManager::new();
-        let mut type_ids = StringIdManager::new();
+        let mut char_ngram_ids = StringIdManager::new();
+        let mut type_ngram_ids = StringIdManager::new();
 
         #[cfg(feature = "model-quantize")]
         let quantize_multiplier = {
@@ -138,27 +138,29 @@ impl Model {
                 let weight = weight / quantize_multiplier;
 
                 match feature.feature {
-                    FeatureContent::CharacterNgram(word) => {
-                        let id = word_ids.get_id(word.as_bytes());
-                        if id == word_weights.len() {
-                            words.push(word.as_bytes().to_vec());
-                            word_weights.push(vec![
+                    FeatureContent::CharacterNgram(char_ngram) => {
+                        let id = char_ngram_ids.get_id(char_ngram.as_bytes());
+                        if id == char_ngram_weights.len() {
+                            char_ngrams.push(char_ngram.as_bytes().to_vec());
+                            char_ngram_weights.push(vec![
                                 WeightValue::default();
-                                char_window_size * 2 - word.chars().count() + 1
+                                char_window_size * 2
+                                    - char_ngram.chars().count()
+                                    + 1
                             ]);
                         }
-                        word_weights[id][feature.rel_position] = weight as WeightValue;
+                        char_ngram_weights[id][feature.rel_position] = weight as WeightValue;
                     }
-                    FeatureContent::CharacterTypeNgram(word) => {
-                        let id = type_ids.get_id(word) as usize;
-                        if id == type_weights.len() {
-                            types.push(word.to_vec());
-                            type_weights.push(vec![
+                    FeatureContent::CharacterTypeNgram(type_ngram) => {
+                        let id = type_ngram_ids.get_id(type_ngram) as usize;
+                        if id == type_ngram_weights.len() {
+                            type_ngrams.push(type_ngram.to_vec());
+                            type_ngram_weights.push(vec![
                                 WeightValue::default();
-                                type_window_size * 2 - word.len() + 1
+                                type_window_size * 2 - type_ngram.len() + 1
                             ]);
                         }
-                        type_weights[id][feature.rel_position] = weight as WeightValue;
+                        type_ngram_weights[id][feature.rel_position] = weight as WeightValue;
                     }
                     FeatureContent::DictionaryWord(size) => match feature.rel_position {
                         0 => dict_weights[size - 1].right = weight as ScoreValue,
...
             };
         }
         Self {
-            words,
-            types,
+            char_ngrams,
+            type_ngrams,
             dict,
 
             #[cfg(feature = "model-quantize")]
             quantize_multiplier,
 
-            word_weights,
-            type_weights,
+            char_ngram_weights,
+            type_ngram_weights,
             dict_weights,
             dict_word_wise: false,
             bias,
diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index 0480c888..61169c0c 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -11,11 +11,11 @@ use std::thread;
 #[cfg(feature = "multithreading")]
 use crossbeam_channel::{Receiver, Sender};
 
+use crate::char_scorer::CharScorer;
+use crate::dict_scorer::DictScorer;
 use crate::model::{DictWeight, Model, ScoreValue};
 use crate::sentence::{BoundaryType, Sentence};
-use crate::char_scorer::CharScorer;
 use crate::type_scorer::TypeScorer;
-use crate::dict_scorer::DictScorer;
 
 use daachorse::DoubleArrayAhoCorasick;
 
 /// Predictor.
 pub struct Predictor {
@@ -44,17 +44,17 
@@ impl Predictor { pub fn new(model: Model) -> Self { let bias = model.bias; - let chars = model.words; + let char_ngrams = model.char_ngrams; let dict = model.dict; let dict_weights = model.dict_weights; - let mut char_weights: Vec<_> = model - .word_weights + let mut char_ngram_weights: Vec<_> = model + .char_ngram_weights .into_iter() .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) .collect(); - let type_weights: Vec<_> = model - .type_weights + let type_ngram_weights: Vec<_> = model + .type_ngram_weights .into_iter() .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) .collect(); @@ -62,28 +62,32 @@ impl Predictor { let (dict, dict_weights) = Self::merge_dict_weights( dict, dict_weights, - &chars, - &mut char_weights, + &char_ngrams, + &mut char_ngram_weights, model.char_window_size, model.dict_word_wise, ); - let char_weights = Self::merge_weights(&chars, &char_weights); - let type_weights = Self::merge_weights(&model.types, &type_weights); + let char_ngram_weights = Self::merge_weights(&char_ngrams, &char_ngram_weights); + let type_ngram_weights = Self::merge_weights(&model.type_ngrams, &type_ngram_weights); #[cfg(feature = "model-quantize")] let bias = bias as i32; - let char_pma = DoubleArrayAhoCorasick::new(chars).unwrap(); - let type_pma = DoubleArrayAhoCorasick::new(model.types).unwrap(); + let char_pma = DoubleArrayAhoCorasick::new(char_ngrams).unwrap(); + let type_pma = DoubleArrayAhoCorasick::new(model.type_ngrams).unwrap(); - let char_scorer = CharScorer::new(char_pma, char_weights, model.char_window_size); - let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size); + let char_scorer = CharScorer::new(char_pma, char_ngram_weights, model.char_window_size); + let type_scorer = TypeScorer::new(type_pma, type_ngram_weights, model.type_window_size); let dict_scorer = if dict.is_empty() { None } else { let dict_pma = DoubleArrayAhoCorasick::new(dict).unwrap(); - Some(DictScorer::new(dict_pma, dict_weights, model.dict_word_wise)) + Some(DictScorer::new( + dict_pma, + dict_weights, + model.dict_word_wise, + )) }; Self { From e8b664d46fbbe681be1aab323523940e20ff2f6e Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 25 Nov 2021 14:37:00 +0900 Subject: [PATCH 05/60] Fix var names --- vaporetto/src/predictor.rs | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 61169c0c..3d6b5bc0 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -520,21 +520,21 @@ mod tests { /// 世: 40 42 fn generate_model_1() -> Model { Model { - words: vec![ + char_ngrams: vec![ "我ら".as_bytes().to_vec(), "全世界".as_bytes().to_vec(), "国民".as_bytes().to_vec(), "世界".as_bytes().to_vec(), "界".as_bytes().to_vec(), ], - types: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], + type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], dict: vec![ "全世界".as_bytes().to_vec(), "世界".as_bytes().to_vec(), "世".as_bytes().to_vec(), ], #[cfg(not(feature = "model-quantize"))] - word_weights: vec![ + char_ngram_weights: vec![ vec![0.5, 1.0, 1.5, 2.0, 2.5], vec![3.0, 3.5, 4.0, 4.5], vec![5.0, 5.5, 6.0, 6.5, 7.0], @@ -542,7 +542,7 @@ mod tests { vec![10.0, 10.5, 11.0, 11.5, 12.0, 12.5], ], #[cfg(feature = "model-quantize")] - word_weights: vec![ + char_ngram_weights: vec![ vec![1, 2, 3, 4, 5], vec![6, 7, 8, 9], vec![10, 11, 12, 13, 14], @@ -550,14 +550,14 @@ mod tests { vec![20, 21, 22, 23, 24, 25], ], 
#[cfg(not(feature = "model-quantize"))] - type_weights: vec![ + type_ngram_weights: vec![ vec![13.0, 13.5, 14.0, 14.5], vec![15.0, 15.5, 16.0, 16.5], vec![17.0, 17.5, 18.0], vec![18.5, 19.0, 19.5], ], #[cfg(feature = "model-quantize")] - type_weights: vec![ + type_ngram_weights: vec![ vec![26, 27, 28, 29], vec![30, 31, 32, 33], vec![34, 35, 36], @@ -629,21 +629,21 @@ mod tests { /// 世: 38 40 fn generate_model_2() -> Model { Model { - words: vec![ + char_ngrams: vec![ "我ら".as_bytes().to_vec(), "全世界".as_bytes().to_vec(), "国民".as_bytes().to_vec(), "世界".as_bytes().to_vec(), "界".as_bytes().to_vec(), ], - types: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], + type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], dict: vec![ "全世界".as_bytes().to_vec(), "世界".as_bytes().to_vec(), "世".as_bytes().to_vec(), ], #[cfg(not(feature = "model-quantize"))] - word_weights: vec![ + char_ngram_weights: vec![ vec![0.25, 0.5, 0.75], vec![1.0, 1.25], vec![1.5, 1.75, 2.0], @@ -651,7 +651,7 @@ mod tests { vec![3.0, 3.25, 3.5, 3.75], ], #[cfg(feature = "model-quantize")] - word_weights: vec![ + char_ngram_weights: vec![ vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], @@ -659,14 +659,14 @@ mod tests { vec![12, 13, 14, 15], ], #[cfg(not(feature = "model-quantize"))] - type_weights: vec![ + type_ngram_weights: vec![ vec![4.0, 4.25, 4.5, 4.75, 5.0, 5.25], vec![5.5, 5.75, 6.0, 6.25, 6.5, 6.75], vec![7.0, 7.25, 7.5, 7.75, 8.0], vec![8.25, 8.5, 8.75, 9.0, 9.25], ], #[cfg(feature = "model-quantize")] - type_weights: vec![ + type_ngram_weights: vec![ vec![16, 17, 18, 19, 20, 21], vec![22, 23, 24, 25, 26, 27], vec![28, 29, 30, 31, 32], @@ -748,21 +748,21 @@ mod tests { /// 世: 44 46 fn generate_model_3() -> Model { Model { - words: vec![ + char_ngrams: vec![ "我ら".as_bytes().to_vec(), "全世界".as_bytes().to_vec(), "国民".as_bytes().to_vec(), "世界".as_bytes().to_vec(), "界".as_bytes().to_vec(), ], - types: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], + type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], dict: vec![ "国民".as_bytes().to_vec(), "世界".as_bytes().to_vec(), "世".as_bytes().to_vec(), ], #[cfg(not(feature = "model-quantize"))] - word_weights: vec![ + char_ngram_weights: vec![ vec![0.25, 0.5, 0.75], vec![1.0, 1.25], vec![1.5, 1.75, 2.0], @@ -770,7 +770,7 @@ mod tests { vec![3.0, 3.25, 3.5, 3.75], ], #[cfg(feature = "model-quantize")] - word_weights: vec![ + char_ngram_weights: vec![ vec![1, 2, 3], vec![4, 5], vec![6, 7, 8], @@ -778,14 +778,14 @@ mod tests { vec![12, 13, 14, 15], ], #[cfg(not(feature = "model-quantize"))] - type_weights: vec![ + type_ngram_weights: vec![ vec![4.0, 4.25, 4.5, 4.75, 5.0, 5.25], vec![5.5, 5.75, 6.0, 6.25, 6.5, 6.75], vec![7.0, 7.25, 7.5, 7.75, 8.0], vec![8.25, 8.5, 8.75, 9.0, 9.25], ], #[cfg(feature = "model-quantize")] - type_weights: vec![ + type_ngram_weights: vec![ vec![16, 17, 18, 19, 20, 21], vec![22, 23, 24, 25, 26, 27], vec![28, 29, 30, 31, 32], From da9b5bfbf58825360b7ce1f5b8f773624586eba3 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 25 Nov 2021 19:15:08 +0900 Subject: [PATCH 06/60] Expand enum --- vaporetto/src/char_scorer.rs | 24 ++---------------------- vaporetto/src/dict_scorer.rs | 30 ++---------------------------- 2 files changed, 4 insertions(+), 50 deletions(-) diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 4256fa34..31cc0911 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -2,33 +2,13 @@ use crate::model::ScoreValue; use 
crate::sentence::Sentence;
 use daachorse::DoubleArrayAhoCorasick;
 
-pub enum CharScorer {
-    Pma(CharScorerPma),
-}
-
-impl CharScorer {
-    pub fn new(
-        pma: DoubleArrayAhoCorasick,
-        weights: Vec<Vec<ScoreValue>>,
-        window_size: usize,
-    ) -> Self {
-        Self::Pma(CharScorerPma::new(pma, weights, window_size))
-    }
-
-    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
-        match self {
-            CharScorer::Pma(pma) => pma.add_scores(sentence, start, ys),
-        }
-    }
-}
-
-pub struct CharScorerPma {
+pub struct CharScorer {
     pma: DoubleArrayAhoCorasick,
     weights: Vec<Vec<ScoreValue>>,
     window_size: usize,
 }
 
-impl CharScorerPma {
+impl CharScorer {
     pub fn new(
         pma: DoubleArrayAhoCorasick,
         weights: Vec<Vec<ScoreValue>>,
diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs
index 40dae9cb..2c9e9326 100644
--- a/vaporetto/src/dict_scorer.rs
+++ b/vaporetto/src/dict_scorer.rs
@@ -2,40 +2,14 @@ use crate::model::{DictWeight, ScoreValue};
 use crate::sentence::Sentence;
 use daachorse::DoubleArrayAhoCorasick;
 
-pub enum DictScorer {
-    Pma(DictScorerPma),
-}
-
-impl DictScorer {
-    pub fn new(
-        pma: DoubleArrayAhoCorasick,
-        weights: Vec<DictWeight>,
-        word_wise_score: bool,
-    ) -> Self {
-        Self::Pma(DictScorerPma::new(pma, weights, word_wise_score))
-    }
-
-    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
-        match self {
-            DictScorer::Pma(pma) => pma.add_scores(sentence, start, ys),
-        }
-    }
-
-    pub fn window_size(&mut self, size: usize) {
-        match self {
-            DictScorer::Pma(pma) => pma.window_size(size),
-        }
-    }
-}
-
-pub struct DictScorerPma {
+pub struct DictScorer {
     pma: DoubleArrayAhoCorasick,
     weights: Vec<DictWeight>,
     window_size: usize,
     word_wise_score: bool,
 }
 
-impl DictScorerPma {
+impl DictScorer {
     pub fn new(
         pma: DoubleArrayAhoCorasick,
         weights: Vec<DictWeight>,

From 058fa971856e2ceccf7fa2a335560b676dd4c054 Mon Sep 17 00:00:00 2001
From: Koichi Akabe
Date: Fri, 26 Nov 2021 10:34:59 +0900
Subject: [PATCH 07/60] Remove bench directory (#4)

* Remove bench directory

* Update README
---
 .gitmodules                                   |  12 --
 README.md                                     |   2 +
 bench/README.md                               |  16 ---
 bench/compile_all.sh                          |  54 ---------
 bench/download_resources.sh                   |  23 ----
 bench/elapsed_time.patch                      | 114 ------------------
 bench/kuromoji/pom.xml                        |  72 -----------
 .../src/main/java/kuromoji_bench/App.java     |  28 -----
 bench/kytea                                   |   1 -
 bench/lindera                                 |   1 -
 bench/mecab                                   |   1 -
 bench/run_all.sh                              |  31 -----
 bench/stats.py                                |  46 -------
 bench/sudachi.rs                              |   1 -
 bench/sudachi/pom.xml                         |  72 -----------
 .../src/main/java/sudachi_bench/App.java      |  36 ------
 bench/sudachi/sudachi.json                    |  25 ----
 17 files changed, 2 insertions(+), 533 deletions(-)
 delete mode 100644 .gitmodules
 delete mode 100644 bench/README.md
 delete mode 100755 bench/compile_all.sh
 delete mode 100755 bench/download_resources.sh
 delete mode 100644 bench/elapsed_time.patch
 delete mode 100644 bench/kuromoji/pom.xml
 delete mode 100644 bench/kuromoji/src/main/java/kuromoji_bench/App.java
 delete mode 160000 bench/kytea
 delete mode 160000 bench/lindera
 delete mode 160000 bench/mecab
 delete mode 100755 bench/run_all.sh
 delete mode 100755 bench/stats.py
 delete mode 160000 bench/sudachi.rs
 delete mode 100644 bench/sudachi/pom.xml
 delete mode 100644 bench/sudachi/src/main/java/sudachi_bench/App.java
 delete mode 100644 bench/sudachi/sudachi.json

diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index c2e05dab..00000000
--- a/.gitmodules
+++ /dev/null
@@ -1,12 +0,0 @@
-[submodule "bench/kytea"]
-	path = bench/kytea
-	url = https://github.com/neubig/kytea.git
-[submodule "bench/lindera"]
-	path = bench/lindera
-	url = 
https://github.com/lindera-morphology/lindera.git -[submodule "bench/mecab"] - path = bench/mecab - url = https://github.com/taku910/mecab.git -[submodule "bench/sudachi.rs"] - path = bench/sudachi.rs - url = https://github.com/WorksApplications/sudachi.rs.git diff --git a/README.md b/README.md index c4f49838..97a925d1 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,8 @@ You can specify all arguments above multiple times. ## Speed Comparison of Various Tokenizers +You can find the comparison script at [here](https://github.com/legalforce-research/tokenizer-speed-bench). + ### Experimental Setup * Document: Japanese training data of Kyoto Free Translation Task diff --git a/bench/README.md b/bench/README.md deleted file mode 100644 index f66549b2..00000000 --- a/bench/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Benchmarking of various tokenizers - -## Preparation - -``` -% git submodule update --init -% ./download_resources.sh -% ./compile_all.sh -``` - -## Measurement - -``` -% ./run_all.sh 2>&1 | tee ./results -% ./stats.py < ./results -``` diff --git a/bench/compile_all.sh b/bench/compile_all.sh deleted file mode 100755 index d9e3fabe..00000000 --- a/bench/compile_all.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -set -eux - -which patch -which cargo -which autoreconf -which libtool -which make -which mvn - -set +e - -patch -p1 -N < ./elapsed_time.patch - -set -e - -pushd .. -cargo build --release -./target/release/convert_kytea_model --model-in "./bench/kytea/jp-0.4.7-6.mod" --model-out "./jp-0.4.7-6.tokenize.mod" -popd - -pushd ./kytea -autoreconf -i -./configure -make -popd - -pushd ./mecab/mecab -./configure --prefix=$(cd .. && pwd)/tmpusr -make -make install -popd -pushd ./mecab/mecab-ipadic -./configure --with-charset=utf8 --prefix=$(cd .. 
&& pwd)/tmpusr --with-mecab-config=../mecab/mecab-config -make -make install -popd - -pushd ./kuromoji -mvn compile -popd - -pushd ./lindera -cargo build --release -popd - -pushd ./sudachi -mvn compile -popd - -pushd ./sudachi.rs -cargo build --release -popd diff --git a/bench/download_resources.sh b/bench/download_resources.sh deleted file mode 100755 index 4f5e91df..00000000 --- a/bench/download_resources.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -set -eux - -which wget -which gunzip -which unzip -which tar - -pushd ./kytea -wget "http://www.phontron.com/kytea/download/model/jp-0.4.7-6.mod.gz" -gunzip "./jp-0.4.7-6.mod.gz" -popd -pushd ./sudachi -wget "http://sudachi.s3-website-ap-northeast-1.amazonaws.com/sudachidict/sudachi-dictionary-20210802-core.zip" -unzip "./sudachi-dictionary-20210802-core.zip" -popd -pushd ./sudachi.rs -./fetch_dictionary.sh -popd - -wget "http://www.phontron.com/kftt/download/kftt-data-1.0.tar.gz" -tar xf "./kftt-data-1.0.tar.gz" diff --git a/bench/elapsed_time.patch b/bench/elapsed_time.patch deleted file mode 100644 index 9f5211a1..00000000 --- a/bench/elapsed_time.patch +++ /dev/null @@ -1,114 +0,0 @@ ---- a/kytea/src/lib/kytea.cpp -+++ b/kytea/src/lib/kytea.cpp -@@ -19,6 +19,7 @@ - #include - #include - #include -+#include - #include - #include - #include -@@ -1206,6 +1207,8 @@ void Kytea::analyze() { - for(int i = 0; i < config_->getNumTags(); i++) - out->setDoTag(i,config_->getDoTag(i)); - -+ chrono::steady_clock::time_point begin = chrono::steady_clock::now(); -+ - KyteaSentence* next; - while((next = in->readSentence()) != 0) { - if(config_->getDoWS()) -@@ -1218,6 +1221,9 @@ void Kytea::analyze() { - delete next; - } - -+ chrono::steady_clock::time_point end = chrono::steady_clock::now(); -+ cerr << "Elapsed-kytea: " << (double) chrono::duration_cast(end - begin).count() / 1000 << " [sec]" << endl; -+ - delete in; - delete out; - if(inStr) delete inStr; ---- a/mecab/mecab/src/tagger.cpp -+++ b/mecab/mecab/src/tagger.cpp -@@ -6,6 +6,7 @@ - #include - #include - #include -+#include - #include "common.h" - #include "connector.h" - #include "mecab.h" -@@ -1229,6 +1230,8 @@ int mecab_do(int argc, char **argv) { - WHAT_ERROR("cannot create tagger"); - } - -+ std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); -+ - for (size_t i = 0; i < rest.size(); ++i) { - MeCab::istream_wrapper ifs(rest[i].c_str()); - if (!*ifs) { -@@ -1255,6 +1258,8 @@ int mecab_do(int argc, char **argv) { - std::strncpy(ibuf, sentence.c_str(), ibufsize); - } - if (ifs->eof() && !ibuf[0]) { -+ std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); -+ std::cerr << "Elapsed-mecab: " << (double) std::chrono::duration_cast(end - begin).count() / 1000 << " [sec]" << std::endl; - return false; - } - if (ifs->fail()) { ---- a/lindera/lindera-cli/src/main.rs -+++ b/lindera/lindera-cli/src/main.rs -@@ -2,6 +2,7 @@ use std::fs; - use std::io; - use std::io::{BufRead, BufReader}; - use std::path::Path; -+use std::time::Instant; - - use clap::{crate_authors, crate_description, crate_version, App, AppSettings, Arg}; - -@@ -123,6 +124,8 @@ fn main() -> LinderaResult<()> { - Box::new(BufReader::new(io::stdin())) - }; - -+ let start = Instant::now(); -+ - loop { - // read the text to be tokenized from stdin - let mut text = String::new(); -@@ -145,5 +148,8 @@ fn main() -> LinderaResult<()> { - }; - } - -+ let duration = start.elapsed(); -+ eprintln!("Elapsed-lindera: {} [sec]", duration.as_secs_f64()); -+ - Ok(()) - } ---- 
a/sudachi.rs/sudachi-cli/src/main.rs -+++ b/sudachi.rs/sudachi-cli/src/main.rs -@@ -20,6 +20,7 @@ use std::fs::File; - use std::io::{self, BufRead, BufReader, BufWriter, Write}; - use std::path::PathBuf; - use std::process; -+use std::time::Instant; - - use structopt::StructOpt; - -@@ -132,6 +133,8 @@ fn main() { - - let format = make_output::<&JapaneseDictionary>(&args); - -+ let start = Instant::now(); -+ - // tokenize and output results - for line in reader.lines() { - let input = line.expect("Failed to read line"); -@@ -157,6 +160,9 @@ fn main() { - } - // it is recommended to call write before dropping BufWriter - writer.flush().expect("flush failed"); -+ -+ let duration = start.elapsed(); -+ eprintln!("Elapsed-sudachi.rs: {} [sec]", duration.as_secs_f64()); - } - - fn make_output(cli: &Cli) -> Box> { diff --git a/bench/kuromoji/pom.xml b/bench/kuromoji/pom.xml deleted file mode 100644 index 5f88e1a3..00000000 --- a/bench/kuromoji/pom.xml +++ /dev/null @@ -1,72 +0,0 @@ - - - - 4.0.0 - - kuromoji_bench - kuromoji_bench - 1.0-SNAPSHOT - - kuromoji_bench - - - UTF-8 - 1.7 - 1.7 - - - - - com.atilika.kuromoji - kuromoji-ipadic - 0.9.0 - - - - - - - - - maven-clean-plugin - 3.1.0 - - - - maven-resources-plugin - 3.0.2 - - - maven-compiler-plugin - 3.8.0 - - - maven-surefire-plugin - 2.22.1 - - - maven-jar-plugin - 3.0.2 - - - maven-install-plugin - 2.5.2 - - - maven-deploy-plugin - 2.8.2 - - - - maven-site-plugin - 3.7.1 - - - maven-project-info-reports-plugin - 3.0.0 - - - - - diff --git a/bench/kuromoji/src/main/java/kuromoji_bench/App.java b/bench/kuromoji/src/main/java/kuromoji_bench/App.java deleted file mode 100644 index 7f347d38..00000000 --- a/bench/kuromoji/src/main/java/kuromoji_bench/App.java +++ /dev/null @@ -1,28 +0,0 @@ -package kuromoji_bench; - -import com.atilika.kuromoji.ipadic.Token; -import com.atilika.kuromoji.ipadic.Tokenizer; -import java.util.List; -import java.util.ArrayList; -import java.util.Scanner; -import java.time.Instant; -import java.time.Duration; - -public class App { - public static void main(String[] args) { - Tokenizer tokenizer = new Tokenizer(); - Scanner input = new Scanner(System.in); - Instant start = Instant.now(); - while (input.hasNext()) { - List tokens = tokenizer.tokenize(input.nextLine()); - List words = new ArrayList(); - for (Token token : tokens) { - words.add(token.getSurface()); - } - System.out.println(String.join(" ", words)); - } - Instant finish = Instant.now(); - double timeElapsed = (double) Duration.between(start, finish).toMillis() / 1000; - System.err.println("Elapsed-kuromoji: " + timeElapsed + " [sec]"); - } -} diff --git a/bench/kytea b/bench/kytea deleted file mode 160000 index 73a94c4a..00000000 --- a/bench/kytea +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 73a94c4a3045087a7e90f27700f3b870a72625e7 diff --git a/bench/lindera b/bench/lindera deleted file mode 160000 index 0f500336..00000000 --- a/bench/lindera +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0f50033653631261a290ae4ac94cc16bfe63f3bb diff --git a/bench/mecab b/bench/mecab deleted file mode 160000 index 046fa78b..00000000 --- a/bench/mecab +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 046fa78b2ed56fbd4fac312040f6d62fc1bc31e3 diff --git a/bench/run_all.sh b/bench/run_all.sh deleted file mode 100755 index 6b6f2365..00000000 --- a/bench/run_all.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -set -eux - -INPUT_DATA="./kftt-data-1.0/data/orig/kyoto-train.ja" - -for i in 0 1 2 3 4 5 6 7 8 9 -do - for j in 0 1 2 3 4 5 6 7 8 9 - do - echo "iter" $i $j - - 
./kytea/src/bin/kytea -model "./kytea/jp-0.4.7-6.mod" -notags < $INPUT_DATA > /dev/null - - ../target/release/predict --model "../jp-0.4.7-6.tokenize.mod" < $INPUT_DATA > /dev/null - - ./mecab/tmpusr/bin/mecab -Owakati < $INPUT_DATA > /dev/null - - pushd ./kuromoji - mvn exec:java -Dexec.mainClass=kuromoji_bench.App < ../$INPUT_DATA > /dev/null - popd - - ./lindera/target/release/lindera -O wakati < $INPUT_DATA > /dev/null - - pushd ./sudachi - mvn exec:java -Dexec.mainClass=sudachi_bench.App < ../$INPUT_DATA > /dev/null - popd - - ./sudachi.rs/target/release/sudachi -w -m C < $INPUT_DATA > /dev/null - done -done diff --git a/bench/stats.py b/bench/stats.py deleted file mode 100755 index 9004493f..00000000 --- a/bench/stats.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python3 - -from __future__ import annotations - -import collections -import math -import re -import sys - - -RE_DICT = [ - ('kytea', re.compile(r'Elapsed-kytea: ([0-9\.]+) \[sec\]')), - ('vaporetto', re.compile(r'Elapsed: ([0-9\.]+) \[sec\]')), - ('mecab', re.compile(r'Elapsed-mecab: ([0-9\.]+) \[sec\]')), - ('kuromoji', re.compile(r'Elapsed-kuromoji: ([0-9\.]+) \[sec\]')), - ('lindera', re.compile(r'Elapsed-lindera: ([0-9\.]+) \[sec\]')), - ('sudachi', re.compile(r'Elapsed-sudachi: ([0-9\.]+) \[sec\]')), - ('sudachi.rs', re.compile(r'Elapsed-sudachi.rs: ([0-9\.]+) \[sec\]')), -] - -N_CHARS = 16318893 - - -def mean_std(times: list[float]) -> (float, float): - speeds = [N_CHARS / time for time in times] - mean = sum(speeds) / len(speeds) - dist = sum((speed - mean) ** 2 for speed in speeds) / len(speeds) - return mean, math.sqrt(dist) - - -def _main(): - times = collections.defaultdict(list) - for line in sys.stdin: - for name, r in RE_DICT: - m = r.match(line) - if m is not None: - times[name].append(float(m.group(1))) - break - - for name, _ in RE_DICT: - mean, std = mean_std(times[name]) - print(f'{name} {mean} {std}') - - -if __name__ == '__main__': - _main() diff --git a/bench/sudachi.rs b/bench/sudachi.rs deleted file mode 160000 index 1cf62ec2..00000000 --- a/bench/sudachi.rs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1cf62ec2d6949db76e5aa2625c9b76f747960ac1 diff --git a/bench/sudachi/pom.xml b/bench/sudachi/pom.xml deleted file mode 100644 index 26d7f26d..00000000 --- a/bench/sudachi/pom.xml +++ /dev/null @@ -1,72 +0,0 @@ - - - - 4.0.0 - - sudachi_bench - sudachi_bench - 1.0-SNAPSHOT - - sudachi_bench - - - UTF-8 - 1.7 - 1.7 - - - - - com.worksap.nlp - sudachi - 0.5.2 - - - - - - - - - maven-clean-plugin - 3.1.0 - - - - maven-resources-plugin - 3.0.2 - - - maven-compiler-plugin - 3.8.0 - - - maven-surefire-plugin - 2.22.1 - - - maven-jar-plugin - 3.0.2 - - - maven-install-plugin - 2.5.2 - - - maven-deploy-plugin - 2.8.2 - - - - maven-site-plugin - 3.7.1 - - - maven-project-info-reports-plugin - 3.0.0 - - - - - diff --git a/bench/sudachi/src/main/java/sudachi_bench/App.java b/bench/sudachi/src/main/java/sudachi_bench/App.java deleted file mode 100644 index ac249c98..00000000 --- a/bench/sudachi/src/main/java/sudachi_bench/App.java +++ /dev/null @@ -1,36 +0,0 @@ -package sudachi_bench; - -import java.io.IOException; -import com.worksap.nlp.sudachi.Tokenizer; -import com.worksap.nlp.sudachi.Dictionary; -import com.worksap.nlp.sudachi.DictionaryFactory; -import com.worksap.nlp.sudachi.Morpheme; -import java.util.List; -import java.util.ArrayList; -import java.util.Scanner; -import java.time.Instant; -import java.time.Duration; -import java.nio.file.Paths; -import java.nio.file.Files; - -public class App { - public 
static void main(String[] args) throws IOException { - String settings = Files.readString(Paths.get("sudachi.json")); - Scanner input = new Scanner(System.in); - try (Dictionary dict = new DictionaryFactory().create(settings)) { - Tokenizer tokenizer = dict.create(); - Instant start = Instant.now(); - while (input.hasNext()) { - List tokens = tokenizer.tokenize(Tokenizer.SplitMode.C, input.nextLine()); - List words = new ArrayList(); - for (Morpheme token : tokens) { - words.add(token.surface()); - } - System.out.println(String.join(" ", words)); - } - Instant finish = Instant.now(); - double timeElapsed = (double) Duration.between(start, finish).toMillis() / 1000; - System.err.println("Elapsed-sudachi: " + timeElapsed + " [sec]"); - } - } -} diff --git a/bench/sudachi/sudachi.json b/bench/sudachi/sudachi.json deleted file mode 100644 index 9a94c67c..00000000 --- a/bench/sudachi/sudachi.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "systemDict" : "sudachi-dictionary-20210802/system_core.dic", - "inputTextPlugin" : [ - { "class" : "com.worksap.nlp.sudachi.DefaultInputTextPlugin" }, - { "class" : "com.worksap.nlp.sudachi.ProlongedSoundMarkInputTextPlugin", - "prolongedSoundMarks": ["ー", "-", "⁓", "〜", "〰"], - "replacementSymbol": "ー"} - ], - "oovProviderPlugin" : [ - { "class" : "com.worksap.nlp.sudachi.MeCabOovProviderPlugin" }, - { "class" : "com.worksap.nlp.sudachi.SimpleOovProviderPlugin", - "oovPOS" : [ "補助記号", "一般", "*", "*", "*", "*" ], - "leftId" : 5968, - "rightId" : 5968, - "cost" : 3857 } - ], - "pathRewritePlugin" : [ - { "class" : "com.worksap.nlp.sudachi.JoinNumericPlugin", - "joinKanjiNumeric" : true }, - { "class" : "com.worksap.nlp.sudachi.JoinKatakanaOovPlugin", - "oovPOS" : [ "名詞", "普通名詞", "一般", "*", "*", "*" ], - "minLength" : 3 - } - ] -} From 1fcf6c04b6871177695d450e86ae8b6b981472ad Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Fri, 26 Nov 2021 10:56:02 +0900 Subject: [PATCH 08/60] Add JS file generator and simplify example script (#5) * Add JS file generator and simplify example script * Update README.md * Update build_portable_js.sh * Update build_portable_js.sh --- Cargo.toml | 3 + model/model.zstd | Bin 258 -> 0 bytes vaporetto_wasm/Cargo.toml | 6 + vaporetto_wasm/README.md | 38 +++++-- vaporetto_wasm/build_portable_js.sh | 16 +++ vaporetto_wasm/src/lib.rs | 109 ++++++++++++++---- vaporetto_wasm/www/index.html | 8 +- vaporetto_wasm/www/index.js | 170 ++++------------------------ 8 files changed, 163 insertions(+), 187 deletions(-) delete mode 100644 model/model.zstd create mode 100755 vaporetto_wasm/build_portable_js.sh diff --git a/Cargo.toml b/Cargo.toml index a3ba8aa0..09c8e826 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,5 +7,8 @@ members = [ "train", "evaluate", "convert_kytea_model", +] + +exclude = [ "vaporetto_wasm", ] diff --git a/model/model.zstd b/model/model.zstd deleted file mode 100644 index 8d409665268e8889a50d295df613be9091e92485..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 258 zcmdPcs{fZE<19PF6m|xF=EqGt*&a8yHSK)d+|jrm$eR4Psi$%MWX{Kp>wAFWJ)A%> zAZIcsqlbq_NJxkW6Jv-6gWYjK=@*~0E7+V@S*sq1+`p26m36T*8v_uqumUmXvv*2D z_8Pm@>n{ARZ(w9=U}R-sSR?|H2LfiB1rx4x2E7RAy0GFbqpnclgAR-9=L7niS(z?Z zOf@;d`;xQy_A2XlD}|=_9~y2fbZX}NFvsP)<@vQY4;Dmi z()L(ls1kSoD(jTLkGfSDJR^-O9vs@twe4N>m1hsEA8q=kr}4o5pfYdQu^+o ``` -4. Open http://localhost:8000/www +3. 
You can use the generated JS file like the following code:
+   ```html
+
+
+
+
+
+
+
+
+
+   ```
diff --git a/vaporetto_wasm/build_portable_js.sh b/vaporetto_wasm/build_portable_js.sh
new file mode 100755
index 00000000..d4584f11
--- /dev/null
+++ b/vaporetto_wasm/build_portable_js.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+set -eu
+
+DIRNAME="$(dirname $0)"
+MODEL="$(realpath $1)"
+IDENT="$2"
+OUTPUT="$3"
+pushd "$DIRNAME"
+VAPORETTO_MODEL_PATH="$MODEL" wasm-pack build --release --target no-modules
+popd
+encoded_wasm=$(base64 < "${DIRNAME}/pkg/vaporetto_wasm_bg.wasm")
+cat \
+    <(sed "s/wasm_bindgen/__vaporetto_${IDENT}_wbg/g" < "${DIRNAME}/pkg/vaporetto_wasm.js") \
+    <(echo "async function vaporetto_${IDENT}(){await __vaporetto_${IDENT}_wbg(fetch('data:application/wasm;base64,${encoded_wasm}'));return __vaporetto_${IDENT}_wbg.Vaporetto;}") \
+    > "$OUTPUT"
diff --git a/vaporetto_wasm/src/lib.rs b/vaporetto_wasm/src/lib.rs
index e7a9189d..3d75ef72 100644
--- a/vaporetto_wasm/src/lib.rs
+++ b/vaporetto_wasm/src/lib.rs
@@ -4,68 +4,127 @@ use js_sys::{Array, Object};
 use vaporetto::{BoundaryType, CharacterType, Model, Predictor, Sentence};
 use vaporetto_rules::{
     sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter},
-    SentenceFilter,
+    string_filters::KyteaFullwidthFilter,
+    SentenceFilter, StringFilter,
 };
 use wasm_bindgen::{prelude::*, JsValue};
 
+#[global_allocator]
+static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
+
 #[wasm_bindgen]
 pub struct Vaporetto {
     predictor: Predictor,
+    fullwidth_filter: KyteaFullwidthFilter,
     post_filters: Vec<Box<dyn SentenceFilter>>,
 }
 
 #[wasm_bindgen]
 impl Vaporetto {
     #[wasm_bindgen]
-    pub fn new() -> Self {
-        let mut f = Cursor::new(include_bytes!("../../model/model.zstd"));
+    pub fn new(filters: &str) -> Self {
+        let mut f = Cursor::new(include_bytes!(env!("VAPORETTO_MODEL_PATH")));
         let mut decoder = ruzstd::StreamingDecoder::new(&mut f).unwrap();
         let mut buff = vec![];
         decoder.read_to_end(&mut buff).unwrap();
         let model = Model::read(&mut buff.as_slice()).unwrap();
         let predictor = Predictor::new(model);
-        let post_filters: Vec<Box<dyn SentenceFilter>> = vec![
-            Box::new(ConcatGraphemeClustersFilter::new()),
-            Box::new(KyteaWsConstFilter::new(CharacterType::Digit)),
-        ];
+        let post_filters: Vec<_> = filters
+            .chars()
+            .map(|c| {
+                let b: Box<dyn SentenceFilter> = match c {
+                    'D' => Box::new(KyteaWsConstFilter::new(CharacterType::Digit)),
+                    'R' => Box::new(KyteaWsConstFilter::new(CharacterType::Roman)),
+                    'H' => Box::new(KyteaWsConstFilter::new(CharacterType::Hiragana)),
+                    'T' => Box::new(KyteaWsConstFilter::new(CharacterType::Katakana)),
+                    'K' => Box::new(KyteaWsConstFilter::new(CharacterType::Kanji)),
+                    'O' => Box::new(KyteaWsConstFilter::new(CharacterType::Other)),
+                    'G' => Box::new(ConcatGraphemeClustersFilter::new()),
+                    _ => panic!("invalid filter: {}", c),
+                };
+                b
+            })
+            .collect();
         Self {
             predictor,
+            fullwidth_filter: KyteaFullwidthFilter::new(),
             post_filters,
         }
     }
 
     #[wasm_bindgen]
-    pub fn predict_partial(&self, text: &str, start: usize, end: usize) -> Object {
-        let s = if let Ok(s) = Sentence::from_raw(text) {
+    pub fn tokenize(&self, text: &str) -> Object {
+        let result = Array::new();
+        let mut s = if let Ok(s) = Sentence::from_raw(text) {
             s
         } else {
-            return JsValue::NULL.into();
+            return result.into();
         };
-        if start >= end {
-            return JsValue::NULL.into();
-        }
-        let s = self.predictor.predict_partial_with_score(s, start..end);
+        let norm = self.fullwidth_filter.filter(text);
+        let s_norm = if let Ok(s) = Sentence::from_raw(norm) {
+            s
+        } else {
+            return result.into();
+        };
+        let s_norm = 
self.predictor.predict(s_norm); + s.boundaries_mut().clone_from_slice(s_norm.boundaries()); let s = self .post_filters .iter() .fold(s, |s, filter| filter.filter(s)); + if let Ok(words) = s.to_tokenized_vec() { + for word in words { + result.push(&JsValue::from_str(word)); + } + } + result.into() + } + + #[wasm_bindgen] + pub fn predict(&self, text: &str) -> Object { let result = Array::new(); - for (&score, &b) in s.boundary_scores().unwrap()[start..end] + let text = self.fullwidth_filter.filter(text); + let s = if let Ok(s) = Sentence::from_raw(text) { + s + } else { + return result.into(); + }; + let s = self.predictor.predict(s); + let s = self + .post_filters .iter() - .zip(&s.boundaries()[start..end]) - { - let boundary = Array::new(); - boundary.push(&JsValue::from_bool(b == BoundaryType::WordBoundary)); - boundary.push(&JsValue::from_f64(score)); - result.push(&boundary); + .fold(s, |s, filter| filter.filter(s)); + + for &b in s.boundaries() { + result.push(&JsValue::from_bool(b == BoundaryType::WordBoundary)); } result.into() } -} -impl Default for Vaporetto { - fn default() -> Self { - Self::new() + #[wasm_bindgen] + pub fn predict_with_score(&self, text: &str) -> Object { + let result = Array::new(); + let text = self.fullwidth_filter.filter(text); + let s = if let Ok(s) = Sentence::from_raw(text) { + s + } else { + return result.into(); + }; + let s = self.predictor.predict_with_score(s); + let s = self + .post_filters + .iter() + .fold(s, |s, filter| filter.filter(s)); + + if let Some(boundaries) = s.boundary_scores() { + for (&score, &b) in boundaries.iter().zip(s.boundaries()) { + let boundary = Array::new(); + boundary.push(&JsValue::from_bool(b == BoundaryType::WordBoundary)); + boundary.push(&JsValue::from_f64(score)); + result.push(&boundary); + } + } + result.into() } } diff --git a/vaporetto_wasm/www/index.html b/vaporetto_wasm/www/index.html index a6a43b8c..8434232c 100644 --- a/vaporetto_wasm/www/index.html +++ b/vaporetto_wasm/www/index.html @@ -2,9 +2,10 @@ - Vaporetto Real-time Tokenization + Vaporetto Demo - + +
@@ -17,9 +18,8 @@
Output:
-
+

             
-
Loading...
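The `tokenize` method added to `lib.rs` above predicts on a fullwidth-normalized copy of the input and then copies the predicted boundaries back onto the original sentence, so the returned tokens keep the user's original characters. Below is a minimal native-Rust sketch of that flow, assuming only the `vaporetto` and `vaporetto_rules` APIs shown in this patch; the helper name `tokenize_str` is illustrative, not part of the crate:

```rust
use vaporetto::{Predictor, Sentence};
use vaporetto_rules::{string_filters::KyteaFullwidthFilter, StringFilter};

// Sketch of the normalize-then-restore flow: predict boundaries on the
// fullwidth-normalized text, then transfer them to the original sentence.
fn tokenize_str(predictor: &Predictor, text: &str) -> Vec<String> {
    let mut s = Sentence::from_raw(text).unwrap();
    // Normalization preserves the character count, so the boundary
    // slices of the two sentences line up one-to-one.
    let norm = KyteaFullwidthFilter::new().filter(text);
    let s_norm = predictor.predict(Sentence::from_raw(norm).unwrap());
    s.boundaries_mut().clone_from_slice(s_norm.boundaries());
    s.to_tokenized_vec()
        .unwrap()
        .into_iter()
        .map(|w| w.to_string())
        .collect()
}
```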
diff --git a/vaporetto_wasm/www/index.js b/vaporetto_wasm/www/index.js index e7f0deb2..212e5985 100644 --- a/vaporetto_wasm/www/index.js +++ b/vaporetto_wasm/www/index.js @@ -1,152 +1,30 @@ -import init from '../pkg/vaporetto_wasm.js'; -import * as wasm from '../pkg/vaporetto_wasm.js'; - -const loading = document.getElementById("loading"); -loading.style.display = "block"; - -function run() { - const predictor = wasm.Vaporetto.new(); - - loading.style.display = "none"; - - function createTextSpan(text) { - const span = document.createElement("span"); - const textnode = document.createTextNode(text); - span.appendChild(textnode); - return span; +function createTextSpan(text, isBoundary, score) { + const span = document.createElement("span"); + const textnode = document.createTextNode(text); + span.appendChild(textnode); + if (isBoundary) { + span.style.borderLeft = "5pt solid rgba(0, 0, 0, " + Math.atan(score / 2) + ")"; } + return span; +} - function replace_text(elem, prev_text, text, range_from, range_to, boundaries, window_size) { - const prev_boundary_start = Math.max(range_from[0] - window_size, 0); - const prev_boundary_end = Math.min(range_from[1] + window_size - 1, prev_text.length - 1); - const node_end_idx = prev_boundary_end + 1; - let node_end = elem.childNodes[0]; - if (prev_text.length != 0) { - node_end = elem.childNodes[node_end_idx]; - if (range_from[0] == 0) { - node_end.previousSibling.remove(); - } - for (let i = prev_boundary_end - prev_boundary_start; i > 0; --i) { - node_end.previousSibling.remove(); - } - } - const next_boundary_start = Math.max(range_to[0] - window_size, 0); - const next_boundary_end = Math.min(range_to[1] + window_size - 1, text.length - 1); - if (text.length != 0) { - if (range_to[0] == 0) { - node_end.before(createTextSpan(text[next_boundary_start])); - } - for (let i = 0; i < next_boundary_end - next_boundary_start; ++i) { - const elem = createTextSpan(text[next_boundary_start + i + 1]); - if (boundaries[i][0]) { - elem.style.borderLeft = '5pt solid rgba(0, 0, 0, ' + Math.atan(boundaries[i][1] / 2) + ')'; - } - node_end.before(elem); - } - } - } - - const input_text = document.getElementById('input_text'); - input_text.value = ""; - - const window_size = 3; - - let input_data = null; - let prev_range = [0, 0]; - let prev_chars = []; - let chars_pos_map = [0]; - - let composition_start = null; - input_text.addEventListener('compositionstart', function (e) { - composition_start = chars_pos_map[e.target.selectionStart]; - }); - - input_text.addEventListener('compositionend', function (e) { - composition_start = null; - }); - - input_text.addEventListener('beforeinput', function (e) { - input_data = e.data; - if (composition_start != null) { - prev_range = [composition_start, chars_pos_map[e.target.selectionEnd]]; - } else { - prev_range = [chars_pos_map[e.target.selectionStart], chars_pos_map[e.target.selectionEnd]]; - } - }); - - input_text.addEventListener('input', function (e) { - const t0 = performance.now(); +vaporetto_bccwj_suw_small().then((Vaporetto) => { + const vaporetto_suw = Vaporetto.new("DG"); - const cur_text = e.target.value; - const cur_chars = Array.from(cur_text); - chars_pos_map = new Array(cur_text.length); - let utf16_pos = 0; - for (let i = 0; i < cur_chars.length; ++i) { - chars_pos_map[utf16_pos] = i; - utf16_pos += cur_chars[i].length; + input_text.addEventListener("input", (e) => { + const text = input_text.value; + const scores = vaporetto_suw.predict_with_score(text); + let i = -1; + while (tokenized.firstChild) { + 
tokenized.removeChild(tokenized.firstChild); } - chars_pos_map.push(cur_chars.length); - - let range_from = null; - let range_to = null; - switch (e.inputType) { - case 'insertText': - case 'insertLineBreak': - case 'insertParagraph': - case 'insertFromPaste': - case 'insertCompositionText': - range_from = prev_range; - range_to = [prev_range[0], prev_range[1] + cur_chars.length - prev_chars.length]; - break; - case 'deleteWordBackward': - case 'deleteWordForward': - case 'deleteSoftLineBackward': - case 'deleteSoftLineForward': - case 'deleteEntireSoftLine': - case 'deleteHardLineBackward': - case 'deleteHardLineForward': - case 'deleteByCut': - case 'deleteContent': - case 'deleteContentBackward': - case 'deleteContentForward': - const start = chars_pos_map[e.target.selectionStart]; - const right_length = cur_chars.length - start; - const prev_end = prev_chars.length - right_length; - range_from = [start, prev_end]; - range_to = [start, start]; - break; - default: - range_from = [0, prev_chars.length]; - range_to = [0, cur_chars.length]; + for (let c of text) { + if (i >= 0) { + tokenized.appendChild(createTextSpan(c, scores[i][0], scores[i][1])); + } else { + tokenized.appendChild(createTextSpan(c, false, 0)); + } + ++i; } - - const tokenized = document.getElementById("tokenized"); - - const predict_chars_start = Math.max(range_to[0] - window_size * 2 + 1, 0); - const predict_chars_end = Math.min(range_to[1] + window_size * 2 - 1, cur_chars.length); - const predict_chars = cur_chars.slice(predict_chars_start, predict_chars_end); - - const boundary_start = Math.max(range_to[0] - window_size, 0); - const boundary_end = Math.min(range_to[1] + window_size - 1, cur_chars.length - 1); - - const predict_boundary_start = boundary_start - predict_chars_start; - const predict_boundary_end = boundary_end - predict_chars_start; - - const boundaries = predictor.predict_partial(predict_chars.join(""), predict_boundary_start, predict_boundary_end); - - console.log("input with window:", predict_chars); - console.log("prediction range:", [predict_boundary_start, predict_boundary_end]); - console.log("boundaries:", boundaries); - - replace_text(tokenized, prev_chars, cur_chars, range_from, range_to, boundaries, window_size); - - const t1 = performance.now(); - - console.log("Elapsed:", t1 - t0, "[ms]"); - console.log("-----"); - - prev_chars = cur_chars; }); -} - -init().then(run); +}); From fb79a8081bc2c4618e33592d80e5a2b3fa18c6ea Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Fri, 26 Nov 2021 16:15:04 +0900 Subject: [PATCH 09/60] Remove Predictor::predict_partial() (#6) * Remove Predictor::predict_partial() * Format * Remove unnecessary dict_window_size() * Update vaporetto/src/dict_scorer.rs Co-authored-by: Shunsuke Kanda Co-authored-by: Shunsuke Kanda --- vaporetto/src/char_scorer.rs | 21 +-- vaporetto/src/dict_scorer.rs | 43 ++---- vaporetto/src/predictor.rs | 280 +++++------------------------------ vaporetto/src/type_scorer.rs | 45 ++---- 4 files changed, 68 insertions(+), 321 deletions(-) diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 31cc0911..da28bb1a 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -21,23 +21,10 @@ impl CharScorer { } } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let char_start = if start >= self.window_size { - start + 1 - self.window_size - } else { - 0 - }; - let text_start = sentence.char_to_str_pos[char_start]; - let char_end = std::cmp::min( - start + ys.len() + 
self.window_size, - sentence.char_to_str_pos.len() - 1, - ); - let text_end = sentence.char_to_str_pos[char_end]; - let text = &sentence.text[text_start..text_end]; - let padding = start - char_start + 1; - for m in self.pma.find_overlapping_no_suffix_iter(&text) { - let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start; - let offset = m_end as isize - self.window_size as isize - padding as isize; + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { + let m_end = sentence.str_to_char_pos[m.end()]; + let offset = m_end as isize - self.window_size as isize - 1; let weights = &self.weights[m.pattern()]; if offset >= 0 { for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) { diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs index 2c9e9326..668a9d30 100644 --- a/vaporetto/src/dict_scorer.rs +++ b/vaporetto/src/dict_scorer.rs @@ -5,7 +5,6 @@ use daachorse::DoubleArrayAhoCorasick; pub struct DictScorer { pma: DoubleArrayAhoCorasick, weights: Vec, - window_size: usize, word_wise_score: bool, } @@ -18,51 +17,29 @@ impl DictScorer { Self { pma, weights, - window_size: 1, word_wise_score, } } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let char_start = if start >= self.window_size { - start + 1 - self.window_size - } else { - 0 - }; - let text_start = sentence.char_to_str_pos[char_start]; - let char_end = std::cmp::min( - start + ys.len() + self.window_size, - sentence.char_to_str_pos.len() - 1, - ); - let text_end = sentence.char_to_str_pos[char_end]; - let text = &sentence.text[text_start..text_end]; - let padding = start - char_start + 1; - for m in self.pma.find_overlapping_iter(&text) { - let m_start = sentence.str_to_char_pos[m.start() + text_start] - char_start; - let m_end = sentence.str_to_char_pos[m.end() + text_start] - char_start; + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + for m in self.pma.find_overlapping_iter(&sentence.text) { + let m_start = sentence.str_to_char_pos[m.start()]; + let m_end = sentence.str_to_char_pos[m.end()]; let idx = if self.word_wise_score { m.pattern() } else { std::cmp::min(m_end - m_start, self.weights.len()) - 1 }; let dict_weight = self.weights[idx]; - if m_start >= padding && m_start < padding + ys.len() { - ys[m_start - padding] += dict_weight.right; + if m_start != 0 { + ys[m_start - 1] += dict_weight.right; } - let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1); - let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize); - if range_start < range_end { - for y in &mut ys[range_start as usize..range_end as usize] { - *y += dict_weight.inner; - } + for y in &mut ys[m_start..m_end - 1] { + *y += dict_weight.inner; } - if m_end >= padding && m_end < ys.len() + padding { - ys[m_end - padding] += dict_weight.left; + if m_end <= ys.len() { + ys[m_end - 1] += dict_weight.left; } } } - - pub fn window_size(&mut self, size: usize) { - self.window_size = std::cmp::max(size, 1); - } } diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 8829140f..ee4d7af8 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::ops::Range; use crate::char_scorer::CharScorer; use crate::dict_scorer::DictScorer; @@ -174,85 +173,15 @@ impl Predictor { result } - fn predict_partial_impl( - &self, - sentence: &Sentence, - range: 
Range, - ys: &mut [ScoreValue], - ) { + fn predict_impl(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { ys.fill(self.bias); - self.char_scorer.add_scores(sentence, range.start, ys); - self.type_scorer.add_scores(sentence, range.start, ys); + self.char_scorer.add_scores(sentence, ys); + self.type_scorer.add_scores(sentence, ys); if let Some(dict_scorer) = self.dict_scorer.as_ref() { - dict_scorer.add_scores(sentence, range.start, ys); + dict_scorer.add_scores(sentence, ys); } } - /// Predicts word boundaries of the specified range of a sentence. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// * `range` - The range of the sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. - pub fn predict_partial(&self, mut sentence: Sentence, range: Range) -> Sentence { - let mut ys = vec![ScoreValue::default(); range.len()]; - self.predict_partial_impl(&sentence, range.clone(), &mut ys); - for (y, b) in ys.into_iter().zip(sentence.boundaries[range].iter_mut()) { - *b = if y >= ScoreValue::default() { - BoundaryType::WordBoundary - } else { - BoundaryType::NotWordBoundary - }; - } - sentence - } - - /// Predicts word boundaries of the specified range of a sentence. This function inserts - /// scores. - /// - /// # Arguments - /// - /// * `sentence` - A sentence. - /// * `range` - The range of the sentence. - /// - /// # Returns - /// - /// A sentence with predicted boundary information. - pub fn predict_partial_with_score( - &self, - mut sentence: Sentence, - range: Range, - ) -> Sentence { - let mut ys = vec![ScoreValue::default(); range.len()]; - self.predict_partial_impl(&sentence, range.clone(), &mut ys); - let mut scores = sentence - .boundary_scores - .take() - .unwrap_or_else(|| vec![0.; sentence.boundaries.len()]); - for (y, (b, s)) in ys.into_iter().zip( - sentence.boundaries[range.clone()] - .iter_mut() - .zip(scores[range].iter_mut()), - ) { - *b = if y >= ScoreValue::default() { - BoundaryType::WordBoundary - } else { - BoundaryType::NotWordBoundary - }; - - #[cfg(feature = "model-quantize")] - let y = y as f64 * self.quantize_multiplier; - - *s = y; - } - sentence.boundary_scores.replace(scores); - sentence - } - /// Predicts word boundaries. /// /// # Arguments @@ -262,13 +191,20 @@ impl Predictor { /// # Returns /// /// A sentence with predicted boundary information. - pub fn predict(&self, sentence: Sentence) -> Sentence { + pub fn predict(&self, mut sentence: Sentence) -> Sentence { let boundaries_size = sentence.boundaries.len(); - if boundaries_size == 0 { - sentence - } else { - self.predict_partial(sentence, 0..boundaries_size) + if boundaries_size != 0 { + let mut ys = vec![ScoreValue::default(); boundaries_size]; + self.predict_impl(&sentence, &mut ys); + for (y, b) in ys.into_iter().zip(sentence.boundaries.iter_mut()) { + *b = if y >= ScoreValue::default() { + BoundaryType::WordBoundary + } else { + BoundaryType::NotWordBoundary + }; + } } + sentence } /// Predicts word boundaries. This function inserts scores. @@ -280,29 +216,33 @@ impl Predictor { /// # Returns /// /// A sentence with predicted boundary information. - pub fn predict_with_score(&self, sentence: Sentence) -> Sentence { + pub fn predict_with_score(&self, mut sentence: Sentence) -> Sentence { let boundaries_size = sentence.boundaries.len(); - if boundaries_size == 0 { - sentence - } else { - self.predict_partial_with_score(sentence, 0..boundaries_size) - } - } - - /// Sets the window size of words in the dictionary. 
- /// - /// # Arguments - /// - /// * `size` - The window size. - /// - /// # Returns - /// - /// A predictor with the specified window size. - pub fn dict_window_size(mut self, size: usize) -> Self { - if let Some(dict_scorer) = self.dict_scorer.as_mut() { - dict_scorer.window_size(size); + if boundaries_size != 0 { + let mut ys = vec![ScoreValue::default(); boundaries_size]; + self.predict_impl(&sentence, &mut ys); + let mut scores = sentence + .boundary_scores + .take() + .unwrap_or_else(|| vec![0.; boundaries_size]); + for (y, (b, s)) in ys + .into_iter() + .zip(sentence.boundaries.iter_mut().zip(scores.iter_mut())) + { + *b = if y >= ScoreValue::default() { + BoundaryType::WordBoundary + } else { + BoundaryType::NotWordBoundary + }; + + #[cfg(feature = "model-quantize")] + let y = y as f64 * self.quantize_multiplier; + + *s = y; + } + sentence.boundary_scores.replace(scores); } - self + sentence } } @@ -794,142 +734,4 @@ mod tests { s.boundary_scores().unwrap(), ); } - - #[test] - fn test_predict_partial_1() { - let model = generate_model_1(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial(s, 1..5); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - } - - #[test] - fn test_predict_partial_2() { - let model = generate_model_2(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial(s, 2..7); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - } - - #[test] - fn test_predict_partial_3() { - let model = generate_model_3(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial(s, 2..6); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - } - - #[test] - fn test_predict_partial_with_score_1() { - let model = generate_model_1(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial_with_score(s, 1..5); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - assert_eq!( - &[0.0, -2.5, 22.5, 66.0, 66.5, 0.0, 0.0, 0.0], - s.boundary_scores().unwrap(), - ); - } - - #[test] - fn test_predict_partial_with_score_2() { - let model = generate_model_2(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial_with_score(s, 2..7); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - assert_eq!( - &[0.0, 0.0, 
-9.75, 14.25, 26.0, 8.5, -19.75, 0.0], - s.boundary_scores().unwrap(), - ); - } - - #[test] - fn test_predict_partial_with_score_3() { - let model = generate_model_3(); - let p = Predictor::new(model); - let s = Sentence::from_raw("我らは全世界の国民").unwrap(); - let s = p.predict_partial_with_score(s, 2..6); - assert_eq!( - &[ - BoundaryType::Unknown, - BoundaryType::Unknown, - BoundaryType::NotWordBoundary, - BoundaryType::WordBoundary, - BoundaryType::WordBoundary, - BoundaryType::NotWordBoundary, - BoundaryType::Unknown, - BoundaryType::Unknown, - ], - s.boundaries(), - ); - assert_eq!( - &[0.0, 0.0, -20.75, 4.5, 16.25, -3.0, 0.0, 0.0], - s.boundary_scores().unwrap(), - ); - } } diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs index 0254d663..2e4b81dd 100644 --- a/vaporetto/src/type_scorer.rs +++ b/vaporetto/src/type_scorer.rs @@ -20,10 +20,10 @@ impl TypeScorer { } } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { match self { - TypeScorer::Pma(pma) => pma.add_scores(sentence, start, ys), - TypeScorer::Cache(cache) => cache.add_scores(sentence, start, ys), + TypeScorer::Pma(pma) => pma.add_scores(sentence, ys), + TypeScorer::Cache(cache) => cache.add_scores(sentence, ys), } } } @@ -47,20 +47,12 @@ impl TypeScorerPma { } } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let type_start = if start >= self.window_size { - start + 1 - self.window_size - } else { - 0 - }; - let type_end = std::cmp::min( - start + ys.len() + self.window_size, - sentence.char_type.len(), - ); - let char_type = &sentence.char_type[type_start..type_end]; - let padding = start - type_start + 1; - for m in self.pma.find_overlapping_no_suffix_iter(&char_type) { - let offset = m.end() as isize - self.window_size as isize - padding as isize; + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + for m in self + .pma + .find_overlapping_no_suffix_iter(&sentence.char_type) + { + let offset = m.end() as isize - self.window_size as isize - 1; let weights = &self.weights[m.pattern()]; if offset >= 0 { for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) { @@ -111,28 +103,17 @@ impl TypeScorerCache { } } - pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) { - let type_start = if start >= self.window_size { - start + 1 - self.window_size - } else { - 0 - }; - let type_end = std::cmp::min( - start + ys.len() + self.window_size, - sentence.char_type.len(), - ); - let char_type = &sentence.char_type[type_start..type_end]; - let offset = self.window_size + start; + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { let mut seqid = 0; - for i in 0..offset { - if let Some(ct) = char_type.get(i) { + for i in 0..self.window_size { + if let Some(ct) = sentence.char_type.get(i) { seqid = self.increment_seqid(seqid, *ct); } else { seqid = self.increment_seqid_without_char(seqid); }; } for (i, y) in ys.iter_mut().enumerate() { - if let Some(ct) = char_type.get(i + offset) { + if let Some(ct) = sentence.char_type.get(i + self.window_size) { seqid = self.increment_seqid(seqid, *ct); } else { seqid = self.increment_seqid_without_char(seqid); From 23a9a7df8afea3ac96eceaed5b93856c4161676e Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 29 Nov 2021 00:49:35 +0900 Subject: [PATCH 10/60] Update rust.yml (#2) --- .github/workflows/rust.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 0e735a18..79e78881 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,8 +1,8 @@ on: push: - branches: [ main ] + branches: [ main, develop ] pull_request: - branches: [ main ] + branches: [ main, develop ] name: build From 3ce7d7211683b1a721d4334ca68d16cea3635e84 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 29 Nov 2021 16:04:53 +0900 Subject: [PATCH 11/60] Support SIMD (#1) * Add simd feature * Use cfg_attr * Disable simd in stable Rust * Use std::simd * CharScorerVector -> CharScorerNaive * Fix var name * Remove unnecessary checking * Fix job name of CI * Add simd_len() function to CharScorerSimd * Remove unnecessary switch --- .github/workflows/rust.yml | 4 +- vaporetto/Cargo.toml | 1 + vaporetto/src/char_scorer.rs | 85 +++++++++++++++++++++++++++++++++++- vaporetto/src/lib.rs | 1 + vaporetto/src/predictor.rs | 71 +++++++++++++++++++++++++++--- 5 files changed, 152 insertions(+), 10 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 79e78881..3e0dbc1b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -48,11 +48,11 @@ jobs: command: test args: --release -p vaporetto --no-default-features - - name: Run cargo test (vaporetto / all-features) + - name: Run cargo test (vaporetto / features kytea+train) uses: actions-rs/cargo@v1 with: command: test - args: --release -p vaporetto --all-features + args: --release -p vaporetto --features kytea,train nightly: name: Nightly diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index ca1531c6..ab056251 100644 --- a/vaporetto/Cargo.toml +++ b/vaporetto/Cargo.toml @@ -26,6 +26,7 @@ default = ["model-quantize"] kytea = ["byteorder"] model-quantize = [] train = ["liblinear"] +simd = [] [package.metadata.docs.rs] all-features = true diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index da28bb1a..987ecc60 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -2,13 +2,52 @@ use crate::model::ScoreValue; use crate::sentence::Sentence; use daachorse::DoubleArrayAhoCorasick; -pub struct CharScorer { +#[cfg(feature = "simd")] +use std::simd::i32x8; + +pub enum CharScorer { + Naive(CharScorerNaive), + + #[cfg(feature = "simd")] + Simd(CharScorerSimd), +} + +impl CharScorer { + pub fn new( + pma: DoubleArrayAhoCorasick, + weights: Vec>, + window_size: usize, + ) -> Self { + #[cfg(not(feature = "simd"))] + { + Self::Naive(CharScorerNaive::new(pma, weights, window_size)) + } + + #[cfg(feature = "simd")] + if window_size <= 4 { + Self::Simd(CharScorerSimd::new(pma, weights, window_size)) + } else { + Self::Naive(CharScorerNaive::new(pma, weights, window_size)) + } + } + + pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [ScoreValue]) { + match self { + CharScorer::Naive(naive) => naive.add_scores(sentence, &mut ys[padding..]), + + #[cfg(feature = "simd")] + CharScorer::Simd(simd) => simd.add_scores(sentence, padding, ys), + } + } +} + +pub struct CharScorerNaive { pma: DoubleArrayAhoCorasick, weights: Vec>, window_size: usize, } -impl CharScorer { +impl CharScorerNaive { pub fn new( pma: DoubleArrayAhoCorasick, weights: Vec>, @@ -38,3 +77,45 @@ impl CharScorer { } } } + +#[cfg(feature = "simd")] +pub struct CharScorerSimd { + pma: DoubleArrayAhoCorasick, + weights: Vec, + window_size: usize, +} + +#[cfg(feature = "simd")] +impl CharScorerSimd { + pub fn new(pma: DoubleArrayAhoCorasick, weights: Vec>, 
window_size: usize) -> Self { + let weights: Vec<_> = weights + .iter() + .map(|w| { + let mut s = [0i32; 8]; + s[..w.len()].copy_from_slice(&w); + i32x8::from_array(s) + }) + .collect(); + Self { + pma, + weights, + window_size, + } + } + + pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [ScoreValue]) { + for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { + let m_end = sentence.str_to_char_pos[m.end()]; + let offset = padding as isize + m_end as isize - self.window_size as isize - 1; + let weights = &self.weights[m.pattern()]; + let ys_slice = &mut ys[offset as usize..offset as usize + 8]; + let mut target = i32x8::from_slice(ys_slice); + target += weights; + ys_slice.copy_from_slice(target.as_array()); + } + } + + pub const fn simd_len() -> usize { + 8 + } +} diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index aead06a6..b59cd2aa 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -1,4 +1,5 @@ #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(feature = "simd", feature(portable_simd))] //! # Vaporetto //! diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index ee4d7af8..b83d6f16 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -6,6 +6,9 @@ use crate::model::{DictWeight, Model, ScoreValue}; use crate::sentence::{BoundaryType, Sentence}; use crate::type_scorer::TypeScorer; +#[cfg(feature = "simd")] +use crate::char_scorer::CharScorerSimd; + use daachorse::DoubleArrayAhoCorasick; /// Predictor. @@ -18,6 +21,9 @@ pub struct Predictor { #[cfg(feature = "model-quantize")] quantize_multiplier: f64, + + #[cfg(feature = "simd")] + padding: usize, } impl Predictor { @@ -88,6 +94,9 @@ impl Predictor { #[cfg(feature = "model-quantize")] quantize_multiplier: model.quantize_multiplier, + + #[cfg(feature = "simd")] + padding: model.char_window_size.max(model.type_window_size), } } @@ -173,12 +182,12 @@ impl Predictor { result } - fn predict_impl(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + fn predict_impl(&self, sentence: &Sentence, padding: usize, ys: &mut [ScoreValue]) { ys.fill(self.bias); - self.char_scorer.add_scores(sentence, ys); - self.type_scorer.add_scores(sentence, ys); + self.char_scorer.add_scores(sentence, padding, ys); + self.type_scorer.add_scores(sentence, &mut ys[padding..]); if let Some(dict_scorer) = self.dict_scorer.as_ref() { - dict_scorer.add_scores(sentence, ys); + dict_scorer.add_scores(sentence, &mut ys[padding..]); } } @@ -193,9 +202,11 @@ impl Predictor { /// A sentence with predicted boundary information. pub fn predict(&self, mut sentence: Sentence) -> Sentence { let boundaries_size = sentence.boundaries.len(); + + #[cfg(not(feature = "simd"))] if boundaries_size != 0 { let mut ys = vec![ScoreValue::default(); boundaries_size]; - self.predict_impl(&sentence, &mut ys); + self.predict_impl(&sentence, 0, &mut ys); for (y, b) in ys.into_iter().zip(sentence.boundaries.iter_mut()) { *b = if y >= ScoreValue::default() { BoundaryType::WordBoundary @@ -204,6 +215,24 @@ impl Predictor { }; } } + + #[cfg(feature = "simd")] + if boundaries_size != 0 { + let ys_size = boundaries_size + self.padding + CharScorerSimd::simd_len() - 1; + let mut ys = vec![ScoreValue::default(); ys_size]; + self.predict_impl(&sentence, self.padding, &mut ys); + for (&y, b) in ys[self.padding..] 
+ .into_iter() + .zip(sentence.boundaries.iter_mut()) + { + *b = if y >= ScoreValue::default() { + BoundaryType::WordBoundary + } else { + BoundaryType::NotWordBoundary + }; + } + } + sentence } @@ -218,9 +247,11 @@ impl Predictor { /// A sentence with predicted boundary information. pub fn predict_with_score(&self, mut sentence: Sentence) -> Sentence { let boundaries_size = sentence.boundaries.len(); + + #[cfg(not(feature = "simd"))] if boundaries_size != 0 { let mut ys = vec![ScoreValue::default(); boundaries_size]; - self.predict_impl(&sentence, &mut ys); + self.predict_impl(&sentence, 0, &mut ys); let mut scores = sentence .boundary_scores .take() @@ -242,6 +273,34 @@ impl Predictor { } sentence.boundary_scores.replace(scores); } + + #[cfg(feature = "simd")] + if boundaries_size != 0 { + let ys_size = boundaries_size + self.padding + CharScorerSimd::simd_len() - 1; + let mut ys = vec![ScoreValue::default(); ys_size]; + self.predict_impl(&sentence, self.padding, &mut ys); + let mut scores = sentence + .boundary_scores + .take() + .unwrap_or_else(|| vec![0.; boundaries_size]); + for (&y, (b, s)) in ys[self.padding..] + .into_iter() + .zip(sentence.boundaries.iter_mut().zip(scores.iter_mut())) + { + *b = if y >= ScoreValue::default() { + BoundaryType::WordBoundary + } else { + BoundaryType::NotWordBoundary + }; + + #[cfg(feature = "model-quantize")] + let y = y as f64 * self.quantize_multiplier; + + *s = y; + } + sentence.boundary_scores.replace(scores); + } + sentence } } From 6d329163f46e9fce585f0c3bd1968dd7ed3834be Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 29 Nov 2021 20:50:13 +0900 Subject: [PATCH 12/60] Add descriptions of features (#3) * Update README.md * Update lib.rs * Update README.md --- vaporetto/README.md | 17 +++++++++++------ vaporetto/src/lib.rs | 10 ++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/vaporetto/README.md b/vaporetto/README.md index 9e309661..6b774112 100644 --- a/vaporetto/README.md +++ b/vaporetto/README.md @@ -14,14 +14,19 @@ let mut f = BufReader::new(File::open("model.raw").unwrap()); let model = Model::read(&mut f).unwrap(); let predictor = Predictor::new(model); -for line in stdin().lock().lines() { - let s = Sentence::from_raw(line.unwrap()).unwrap(); - let s = predictor.predict(s); - let toks = s.to_tokenized_string().unwrap(); - println!("{}", toks); -} +let s = Sentence::from_raw("火星猫の生態").unwrap(); +let s = predictor.predict(s); + +println!("{:?}", s.to_tokenized_vec().unwrap()); +// ["火星", "猫", "の", "生態"] ``` +## Feature flags + +* `kytea` - Enables the reader for models generated by KyTea. +* `train` - Enables the trainer. +* `simd` - Use the SIMD operations for prediction. (Nightly version of Rust is required.) + ## License Licensed under either of diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index b59cd2aa..2d537fc8 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -17,12 +17,10 @@ //! let model = Model::read(&mut f).unwrap(); //! let predictor = Predictor::new(model); //! -//! for line in stdin().lock().lines() { -//! let s = Sentence::from_raw(line.unwrap()).unwrap(); -//! let s = predictor.predict(s); -//! let toks = s.to_tokenized_string().unwrap(); -//! println!("{}", toks); -//! } +//! let s = Sentence::from_raw("火星猫の生態").unwrap(); +//! let s = predictor.predict(s); +//! +//! println!("{:?}", s.to_tokenized_vec().unwrap()); //! ``` //! //! Training requires **crate feature** `train`. For more details, see [`Trainer`]. 
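The README example above stops at `to_tokenized_vec`; the per-boundary scores that these patches repeatedly touch are exposed through `predict_with_score`. A minimal sketch of reading them, under the same assumption as the README example (a model file at the placeholder path `model.raw`):

```rust
use std::fs::File;
use std::io::BufReader;

use vaporetto::{BoundaryType, Model, Predictor, Sentence};

fn main() {
    let mut f = BufReader::new(File::open("model.raw").unwrap());
    let predictor = Predictor::new(Model::read(&mut f).unwrap());

    let s = predictor.predict_with_score(Sentence::from_raw("火星猫の生態").unwrap());

    // predict_with_score() fills boundary_scores(); a non-negative score at a
    // position corresponds to BoundaryType::WordBoundary at that position.
    for (b, score) in s.boundaries().iter().zip(s.boundary_scores().unwrap()) {
        println!("{}\t{}", score, *b == BoundaryType::WordBoundary);
    }
}
```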
From 5eaa397faaa81325559d5e90df6ac121282cbe4f Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 30 Nov 2021 13:30:39 +0900 Subject: [PATCH 13/60] Validate length of patterns and weights (#4) * Validate length of patterns and weights * Format * Fix bugs * Fix format * Add comments --- vaporetto/src/char_scorer.rs | 39 +++++++++------- vaporetto/src/dict_scorer.rs | 14 +++--- vaporetto/src/kytea_model.rs | 14 ++---- vaporetto/src/model.rs | 10 ++--- vaporetto/src/predictor.rs | 87 +++++++++++++++--------------------- vaporetto/src/trainer.rs | 4 +- vaporetto/src/type_scorer.rs | 41 +++++++++-------- vaporetto/src/utils.rs | 6 ++- 8 files changed, 106 insertions(+), 109 deletions(-) diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 987ecc60..07c6652e 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -13,21 +13,20 @@ pub enum CharScorer { } impl CharScorer { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { + /// # Panics + /// + /// `ngrams` and `weights` must have same number of entries. + pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { #[cfg(not(feature = "simd"))] { - Self::Naive(CharScorerNaive::new(pma, weights, window_size)) + Self::Naive(CharScorerNaive::new(ngrams, weights, window_size)) } #[cfg(feature = "simd")] if window_size <= 4 { - Self::Simd(CharScorerSimd::new(pma, weights, window_size)) + Self::Simd(CharScorerSimd::new(ngrams, weights, window_size)) } else { - Self::Naive(CharScorerNaive::new(pma, weights, window_size)) + Self::Naive(CharScorerNaive::new(ngrams, weights, window_size)) } } @@ -48,13 +47,15 @@ pub struct CharScorerNaive { } impl CharScorerNaive { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { + /// # Panics + /// + /// `ngrams` and `weights` must have same number of entries. + pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { + if ngrams.len() != weights.len() { + panic!("ngrams.len() != weights.len()"); + } Self { - pma, + pma: DoubleArrayAhoCorasick::new(ngrams).unwrap(), weights, window_size, } @@ -87,7 +88,13 @@ pub struct CharScorerSimd { #[cfg(feature = "simd")] impl CharScorerSimd { - pub fn new(pma: DoubleArrayAhoCorasick, weights: Vec>, window_size: usize) -> Self { + /// # Panics + /// + /// `ngrams` and `weights` must have same number of entries. + pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { + if ngrams.len() != weights.len() { + panic!("ngrams.len() != weights.len()"); + } let weights: Vec<_> = weights .iter() .map(|w| { @@ -97,7 +104,7 @@ impl CharScorerSimd { }) .collect(); Self { - pma, + pma: DoubleArrayAhoCorasick::new(ngrams).unwrap(), weights, window_size, } diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs index 668a9d30..f8ae1604 100644 --- a/vaporetto/src/dict_scorer.rs +++ b/vaporetto/src/dict_scorer.rs @@ -9,13 +9,15 @@ pub struct DictScorer { } impl DictScorer { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec, - word_wise_score: bool, - ) -> Self { + /// # Panics + /// + /// `ngrams` and `weights` must have same number of entries. 
+ pub fn new(words: &[String], weights: Vec, word_wise_score: bool) -> Self { + if word_wise_score && words.len() != weights.len() { + panic!("word_wise_score == true && words.len() != weights.len()"); + } Self { - pma, + pma: DoubleArrayAhoCorasick::new(words).unwrap(), weights, word_wise_score, } diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 60585291..bc721177 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -409,17 +409,11 @@ impl TryFrom for Model { .type_dict .ok_or_else(|| anyhow!("no type dictionary."))?; - let mut char_ngrams: Vec> = vec![]; + let mut char_ngrams: Vec = vec![]; let mut char_ngram_weights = vec![]; for (char_ngram, v) in char_dict.dump_items() { let weight_size = config.char_w as usize * 2 - char_ngram.len() + 1; - char_ngrams.push( - char_ngram - .into_iter() - .collect::() - .as_bytes() - .to_vec(), - ); + char_ngrams.push(char_ngram.into_iter().collect::()); char_ngram_weights.push(v[..weight_size].to_vec()); } @@ -437,7 +431,7 @@ impl TryFrom for Model { type_ngram_weights.push(v[..weight_size].to_vec()); } - let mut dict: Vec> = vec![]; + let mut dict: Vec = vec![]; let mut dict_weights = vec![]; if let Some(kytea_dict) = model.dict { for (w, data) in kytea_dict.dump_items() { @@ -452,7 +446,7 @@ impl TryFrom for Model { } } dict_weights.push(weights); - dict.push(w.into_iter().collect::().as_bytes().to_vec()); + dict.push(w.into_iter().collect()); } } diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index 35747b0e..6a500ee9 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -33,9 +33,9 @@ pub struct DictWeight { /// Model data. #[derive(Serialize, Deserialize)] pub struct Model { - pub(crate) char_ngrams: Vec>, + pub(crate) char_ngrams: Vec, pub(crate) type_ngrams: Vec>, - pub(crate) dict: Vec>, + pub(crate) dict: Vec, pub(crate) char_ngram_weights: Vec>, pub(crate) type_ngram_weights: Vec>, @@ -93,7 +93,7 @@ impl Model { pub(crate) fn from_liblinear_model( model: impl LibLinearModel, fid_manager: FeatureIDManager, - dict: Vec>, + dict: Vec, char_window_size: usize, type_window_size: usize, dict_word_max_size: usize, @@ -139,9 +139,9 @@ impl Model { match feature.feature { FeatureContent::CharacterNgram(char_ngram) => { - let id = char_ngram_ids.get_id(char_ngram.as_bytes()); + let id = char_ngram_ids.get_id(&char_ngram); if id == char_ngram_weights.len() { - char_ngrams.push(char_ngram.as_bytes().to_vec()); + char_ngrams.push(char_ngram.to_string()); char_ngram_weights.push(vec![ WeightValue::default(); char_window_size * 2 diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index b83d6f16..89e5088e 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -9,8 +9,6 @@ use crate::type_scorer::TypeScorer; #[cfg(feature = "simd")] use crate::char_scorer::CharScorerSimd; -use daachorse::DoubleArrayAhoCorasick; - /// Predictor. 
pub struct Predictor { bias: ScoreValue, @@ -69,20 +67,16 @@ impl Predictor { #[cfg(feature = "model-quantize")] let bias = bias as i32; - let char_pma = DoubleArrayAhoCorasick::new(char_ngrams).unwrap(); - let type_pma = DoubleArrayAhoCorasick::new(model.type_ngrams).unwrap(); - - let char_scorer = CharScorer::new(char_pma, char_ngram_weights, model.char_window_size); - let type_scorer = TypeScorer::new(type_pma, type_ngram_weights, model.type_window_size); + let char_scorer = CharScorer::new(&char_ngrams, char_ngram_weights, model.char_window_size); + let type_scorer = TypeScorer::new( + &model.type_ngrams, + type_ngram_weights, + model.type_window_size, + ); let dict_scorer = if dict.is_empty() { None } else { - let dict_pma = DoubleArrayAhoCorasick::new(dict).unwrap(); - Some(DictScorer::new( - dict_pma, - dict_weights, - model.dict_word_wise, - )) + Some(DictScorer::new(&dict, dict_weights, model.dict_word_wise)) }; Self { @@ -101,13 +95,13 @@ impl Predictor { } fn merge_dict_weights( - dict: Vec>, + dict: Vec, dict_weights: Vec, - words: &[Vec], + words: &[String], word_weights: &mut Vec>, char_window_size: usize, dict_word_wise: bool, - ) -> (Vec>, Vec) { + ) -> (Vec, Vec) { let mut word_map = HashMap::new(); for (i, word) in words.iter().cloned().enumerate() { word_map.insert(word, i); @@ -116,7 +110,7 @@ impl Predictor { if dict_word_wise { let mut new_dict_weights = vec![]; for (word, weight) in dict.into_iter().zip(dict_weights) { - let word_size = std::str::from_utf8(&word).unwrap().chars().count(); + let word_size = word.chars().count(); match word_map.get(&word) { Some(&idx) if char_window_size >= word_size => { let start = char_window_size - word_size; @@ -136,7 +130,7 @@ impl Predictor { (new_dict, new_dict_weights) } else { for word in dict { - let word_size = std::str::from_utf8(&word).unwrap().chars().count(); + let word_size = word.chars().count(); match word_map.get(&word) { Some(&idx) if char_window_size >= word_size => { let start = char_window_size - word_size; @@ -156,15 +150,18 @@ impl Predictor { } } - fn merge_weights(words: &[Vec], weights: &[Vec]) -> Vec> { + fn merge_weights
<P>
(words: &[P], weights: &[Vec]) -> Vec> + where + P: AsRef<[u8]>, + { let mut result = vec![]; let word_ids = words .iter() - .cloned() .enumerate() - .map(|(i, w)| (w, i)) + .map(|(i, w)| (w.as_ref().to_vec(), i)) .collect::, usize>>(); for seq in words { + let seq = seq.as_ref(); let mut new_weights: Option> = None; for st in (0..seq.len()).rev() { if let Some(&idx) = word_ids.get(&seq[st..]) { @@ -338,18 +335,14 @@ mod tests { fn generate_model_1() -> Model { Model { char_ngrams: vec![ - "我ら".as_bytes().to_vec(), - "全世界".as_bytes().to_vec(), - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "界".as_bytes().to_vec(), + "我ら".to_string(), + "全世界".to_string(), + "国民".to_string(), + "世界".to_string(), + "界".to_string(), ], type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], - dict: vec![ - "全世界".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "世".as_bytes().to_vec(), - ], + dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], #[cfg(not(feature = "model-quantize"))] char_ngram_weights: vec![ vec![0.5, 1.0, 1.5, 2.0, 2.5], @@ -447,18 +440,14 @@ mod tests { fn generate_model_2() -> Model { Model { char_ngrams: vec![ - "我ら".as_bytes().to_vec(), - "全世界".as_bytes().to_vec(), - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "界".as_bytes().to_vec(), + "我ら".to_string(), + "全世界".to_string(), + "国民".to_string(), + "世界".to_string(), + "界".to_string(), ], type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], - dict: vec![ - "全世界".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "世".as_bytes().to_vec(), - ], + dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], #[cfg(not(feature = "model-quantize"))] char_ngram_weights: vec![ vec![0.25, 0.5, 0.75], @@ -566,18 +555,14 @@ mod tests { fn generate_model_3() -> Model { Model { char_ngrams: vec![ - "我ら".as_bytes().to_vec(), - "全世界".as_bytes().to_vec(), - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "界".as_bytes().to_vec(), + "我ら".to_string(), + "全世界".to_string(), + "国民".to_string(), + "世界".to_string(), + "界".to_string(), ], type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], - dict: vec![ - "国民".as_bytes().to_vec(), - "世界".as_bytes().to_vec(), - "世".as_bytes().to_vec(), - ], + dict: vec!["国民".to_string(), "世界".to_string(), "世".to_string()], #[cfg(not(feature = "model-quantize"))] char_ngram_weights: vec![ vec![0.25, 0.5, 0.75], diff --git a/vaporetto/src/trainer.rs b/vaporetto/src/trainer.rs index b81de851..50a5f6a9 100644 --- a/vaporetto/src/trainer.rs +++ b/vaporetto/src/trainer.rs @@ -72,7 +72,7 @@ impl From for liblinear::SolverType { /// Dataset manager. 
#[cfg_attr(docsrs, doc(cfg(feature = "train")))] pub struct Dataset<'a> { - dictionary: Vec>, + dictionary: Vec, feature_extractor: FeatureExtractor, example_generator: ExampleGenerator, char_window_size: usize, @@ -118,7 +118,7 @@ impl<'a> Dataset<'a> { dictionary: dictionary .as_ref() .iter() - .map(|word| (word.as_ref() as &[u8]).to_vec()) + .map(|word| (word.as_ref() as &str).to_string()) .collect(), feature_extractor: FeatureExtractor::new( char_ngram_size, diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs index 2e4b81dd..f2d2b6da 100644 --- a/vaporetto/src/type_scorer.rs +++ b/vaporetto/src/type_scorer.rs @@ -8,15 +8,14 @@ pub enum TypeScorer { } impl TypeScorer { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { + /// # Panics + /// + /// `ngrams` and `weights` must have same number of entries. + pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { if window_size <= 3 { - Self::Cache(TypeScorerCache::new(pma, weights, window_size)) + Self::Cache(TypeScorerCache::new(ngrams, weights, window_size)) } else { - Self::Pma(TypeScorerPma::new(pma, weights, window_size)) + Self::Pma(TypeScorerPma::new(ngrams, weights, window_size)) } } @@ -35,13 +34,15 @@ pub struct TypeScorerPma { } impl TypeScorerPma { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { + /// # Panics + /// + /// `ngrams` and `weights` must have same number of entries. + pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { + if ngrams.len() != weights.len() { + panic!("ngrams.len() != weights.len()"); + } Self { - pma, + pma: DoubleArrayAhoCorasick::new(ngrams).unwrap(), weights, window_size, } @@ -74,11 +75,15 @@ pub struct TypeScorerCache { } impl TypeScorerCache { - pub fn new( - pma: DoubleArrayAhoCorasick, - weights: Vec>, - window_size: usize, - ) -> Self { + /// # Panics + /// + /// `ngrams` and `weights` must have same number of entries. 
+ pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { + if ngrams.len() != weights.len() { + panic!("ngrams.len() != weights.len()"); + } + let pma = DoubleArrayAhoCorasick::new(ngrams).unwrap(); + let sequence_size = window_size * 2; let all_sequences = ALPHABET_SIZE.pow(sequence_size as u32); diff --git a/vaporetto/src/utils.rs b/vaporetto/src/utils.rs index 47b51b80..c8f3c422 100644 --- a/vaporetto/src/utils.rs +++ b/vaporetto/src/utils.rs @@ -46,7 +46,11 @@ impl StringIdManager { } } - pub fn get_id(&mut self, key: &[u8]) -> usize { + pub fn get_id(&mut self, key: K) -> usize + where + K: AsRef<[u8]>, + { + let key = key.as_ref(); self.map.get(key).copied().unwrap_or_else(|| { let new_id = self.map.len(); self.map.insert(key.into(), new_id); From 9d3ca6b09280bdc43c2cced229632d00ad627516 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Wed, 1 Dec 2021 00:29:52 +0900 Subject: [PATCH 14/60] Remove FP support and use i32 in model file (#5) * Use i32 for holding quantized weights * Remove model-quantize feature and remove supporting FP numbers * Use 24bit for quantization * 24bit -> 16bit * Fix a bug * Add a comment * Rename BIT_DEPTH -> QUANTIZE_BIT_DEPTH --- vaporetto/Cargo.toml | 3 +- vaporetto/src/char_scorer.rs | 13 ++- vaporetto/src/dict_scorer.rs | 4 +- vaporetto/src/kytea_model.rs | 7 +- vaporetto/src/model.rs | 65 ++++++-------- vaporetto/src/predictor.rs | 168 ++++------------------------------- vaporetto/src/type_scorer.rs | 23 +++-- 7 files changed, 66 insertions(+), 217 deletions(-) diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index ab056251..44161ed3 100644 --- a/vaporetto/Cargo.toml +++ b/vaporetto/Cargo.toml @@ -22,9 +22,8 @@ byteorder = { version = "1.4", optional = true } # Unlicense or MIT liblinear = { version = "1", optional = true } # MIT [features] -default = ["model-quantize"] +default = [] kytea = ["byteorder"] -model-quantize = [] train = ["liblinear"] simd = [] diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 07c6652e..205cbe64 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -1,4 +1,3 @@ -use crate::model::ScoreValue; use crate::sentence::Sentence; use daachorse::DoubleArrayAhoCorasick; @@ -16,7 +15,7 @@ impl CharScorer { /// # Panics /// /// `ngrams` and `weights` must have same number of entries. - pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { + pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { #[cfg(not(feature = "simd"))] { Self::Naive(CharScorerNaive::new(ngrams, weights, window_size)) @@ -30,7 +29,7 @@ impl CharScorer { } } - pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) { match self { CharScorer::Naive(naive) => naive.add_scores(sentence, &mut ys[padding..]), @@ -42,7 +41,7 @@ impl CharScorer { pub struct CharScorerNaive { pma: DoubleArrayAhoCorasick, - weights: Vec>, + weights: Vec>, window_size: usize, } @@ -50,7 +49,7 @@ impl CharScorerNaive { /// # Panics /// /// `ngrams` and `weights` must have same number of entries. 
- pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { + pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { if ngrams.len() != weights.len() { panic!("ngrams.len() != weights.len()"); } @@ -61,7 +60,7 @@ impl CharScorerNaive { } } - pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { let m_end = sentence.str_to_char_pos[m.end()]; let offset = m_end as isize - self.window_size as isize - 1; @@ -110,7 +109,7 @@ impl CharScorerSimd { } } - pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) { for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { let m_end = sentence.str_to_char_pos[m.end()]; let offset = padding as isize + m_end as isize - self.window_size as isize - 1; diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs index f8ae1604..5afc62f5 100644 --- a/vaporetto/src/dict_scorer.rs +++ b/vaporetto/src/dict_scorer.rs @@ -1,4 +1,4 @@ -use crate::model::{DictWeight, ScoreValue}; +use crate::model::DictWeight; use crate::sentence::Sentence; use daachorse::DoubleArrayAhoCorasick; @@ -23,7 +23,7 @@ impl DictScorer { } } - pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { for m in self.pma.find_overlapping_iter(&sentence.text) { let m_start = sentence.str_to_char_pos[m.start()]; let m_end = sentence.str_to_char_pos[m.end()]; diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index bc721177..a486e0d9 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -401,7 +401,7 @@ impl TryFrom for Model { let feature_lookup = wordseg_model .feature_lookup .ok_or_else(|| anyhow!("no lookup data."))?; - let bias = feature_lookup.biases[0]; + let bias = feature_lookup.biases[0] as i32; let char_dict = feature_lookup .char_dict .ok_or_else(|| anyhow!("no character dictionary."))?; @@ -414,7 +414,7 @@ impl TryFrom for Model { for (char_ngram, v) in char_dict.dump_items() { let weight_size = config.char_w as usize * 2 - char_ngram.len() + 1; char_ngrams.push(char_ngram.into_iter().collect::()); - char_ngram_weights.push(v[..weight_size].to_vec()); + char_ngram_weights.push(v[..weight_size].iter().map(|&w| w as i32).collect()); } let mut type_ngrams: Vec> = vec![]; @@ -428,7 +428,7 @@ impl TryFrom for Model { .as_bytes() .to_vec(), ); - type_ngram_weights.push(v[..weight_size].to_vec()); + type_ngram_weights.push(v[..weight_size].iter().map(|&w| w as i32).collect()); } let mut dict: Vec = vec![]; @@ -455,7 +455,6 @@ impl TryFrom for Model { type_ngrams, dict, - #[cfg(feature = "model-quantize")] quantize_multiplier, char_ngram_weights, diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index 6a500ee9..211eafd3 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -11,23 +11,19 @@ use crate::sentence::BoundaryType; use crate::utils::{FeatureIDManager, StringIdManager}; #[cfg(feature = "train")] use liblinear::LibLinearModel; + #[cfg(feature = "train")] const EPSILON: f64 = 1e-6; -#[cfg(not(feature = "model-quantize"))] -pub type WeightValue = f64; -#[cfg(feature = "model-quantize")] -pub type WeightValue = i16; -#[cfg(not(feature = "model-quantize"))] -pub type ScoreValue = f64; -#[cfg(feature = "model-quantize")] 
-pub type ScoreValue = i32; +// Bit depth for weight quantization. +#[cfg(feature = "train")] +const QUANTIZE_BIT_DEPTH: u8 = 16; #[derive(Clone, Copy, Default, Serialize, Deserialize)] pub struct DictWeight { - pub right: ScoreValue, - pub inner: ScoreValue, - pub left: ScoreValue, + pub right: i32, + pub inner: i32, + pub left: i32, } /// Model data. @@ -37,16 +33,15 @@ pub struct Model { pub(crate) type_ngrams: Vec>, pub(crate) dict: Vec, - pub(crate) char_ngram_weights: Vec>, - pub(crate) type_ngram_weights: Vec>, + pub(crate) char_ngram_weights: Vec>, + pub(crate) type_ngram_weights: Vec>, pub(crate) dict_weights: Vec, - #[cfg(feature = "model-quantize")] pub(crate) quantize_multiplier: f64, pub(crate) dict_word_wise: bool, - pub(crate) bias: WeightValue, + pub(crate) bias: i32, pub(crate) char_window_size: usize, pub(crate) type_window_size: usize, } @@ -113,20 +108,16 @@ impl Model { let mut char_ngram_ids = StringIdManager::new(); let mut type_ngram_ids = StringIdManager::new(); - #[cfg(feature = "model-quantize")] - let quantize_multiplier = { - let mut weight_max = bias.abs(); - for fid in 0..model.num_features() { - let weight = model.feature_coefficient(fid as i32, wb_idx).abs(); - if weight > weight_max { - weight_max = weight; - } + let mut weight_max = bias.abs(); + for fid in 0..model.num_features() { + let weight = model.feature_coefficient(fid as i32, wb_idx).abs(); + if weight > weight_max { + weight_max = weight; } - weight_max / 32767. - }; + } + let quantize_multiplier = weight_max / ((1 << (QUANTIZE_BIT_DEPTH - 1)) - 1) as f64; - #[cfg(feature = "model-quantize")] - let bias = (bias / quantize_multiplier) as i16; + let bias = (bias / quantize_multiplier) as i32; for (feature, fid) in fid_manager.map { let weight = model.feature_coefficient(fid as i32 + 1, wb_idx); @@ -134,7 +125,6 @@ impl Model { continue; } - #[cfg(feature = "model-quantize")] let weight = weight / quantize_multiplier; match feature.feature { @@ -143,29 +133,27 @@ impl Model { if id == char_ngram_weights.len() { char_ngrams.push(char_ngram.to_string()); char_ngram_weights.push(vec![ - WeightValue::default(); + 0; char_window_size * 2 - char_ngram.chars().count() + 1 ]); } - char_ngram_weights[id][feature.rel_position] = weight as WeightValue; + char_ngram_weights[id][feature.rel_position] = weight as i32; } FeatureContent::CharacterTypeNgram(type_ngram) => { let id = type_ngram_ids.get_id(type_ngram) as usize; if id == type_ngram_weights.len() { type_ngrams.push(type_ngram.to_vec()); - type_ngram_weights.push(vec![ - WeightValue::default(); - type_window_size * 2 - type_ngram.len() + 1 - ]); + type_ngram_weights + .push(vec![0; type_window_size * 2 - type_ngram.len() + 1]); } - type_ngram_weights[id][feature.rel_position] = weight as WeightValue; + type_ngram_weights[id][feature.rel_position] = weight as i32; } FeatureContent::DictionaryWord(size) => match feature.rel_position { - 0 => dict_weights[size - 1].right = weight as ScoreValue, - 1 => dict_weights[size - 1].inner = weight as ScoreValue, - 2 => dict_weights[size - 1].left = weight as ScoreValue, + 0 => dict_weights[size - 1].right = weight as i32, + 1 => dict_weights[size - 1].inner = weight as i32, + 2 => dict_weights[size - 1].left = weight as i32, _ => panic!("Invalid rel_position"), }, }; @@ -175,7 +163,6 @@ impl Model { type_ngrams, dict, - #[cfg(feature = "model-quantize")] quantize_multiplier, char_ngram_weights, diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 89e5088e..e9c04e19 100644 --- 
a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use crate::char_scorer::CharScorer; use crate::dict_scorer::DictScorer; -use crate::model::{DictWeight, Model, ScoreValue}; +use crate::model::{DictWeight, Model}; use crate::sentence::{BoundaryType, Sentence}; use crate::type_scorer::TypeScorer; @@ -11,13 +11,12 @@ use crate::char_scorer::CharScorerSimd; /// Predictor. pub struct Predictor { - bias: ScoreValue, + bias: i32, char_scorer: CharScorer, type_scorer: TypeScorer, dict_scorer: Option, - #[cfg(feature = "model-quantize")] quantize_multiplier: f64, #[cfg(feature = "simd")] @@ -41,16 +40,8 @@ impl Predictor { let dict = model.dict; let dict_weights = model.dict_weights; - let mut char_ngram_weights: Vec<_> = model - .char_ngram_weights - .into_iter() - .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) - .collect(); - let type_ngram_weights: Vec<_> = model - .type_ngram_weights - .into_iter() - .map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect()) - .collect(); + let mut char_ngram_weights = model.char_ngram_weights; + let type_ngram_weights = model.type_ngram_weights; let (dict, dict_weights) = Self::merge_dict_weights( dict, @@ -64,9 +55,6 @@ impl Predictor { let char_ngram_weights = Self::merge_weights(&char_ngrams, &char_ngram_weights); let type_ngram_weights = Self::merge_weights(&model.type_ngrams, &type_ngram_weights); - #[cfg(feature = "model-quantize")] - let bias = bias as i32; - let char_scorer = CharScorer::new(&char_ngrams, char_ngram_weights, model.char_window_size); let type_scorer = TypeScorer::new( &model.type_ngrams, @@ -86,7 +74,6 @@ impl Predictor { type_scorer, dict_scorer, - #[cfg(feature = "model-quantize")] quantize_multiplier: model.quantize_multiplier, #[cfg(feature = "simd")] @@ -98,7 +85,7 @@ impl Predictor { dict: Vec, dict_weights: Vec, words: &[String], - word_weights: &mut Vec>, + word_weights: &mut Vec>, char_window_size: usize, dict_word_wise: bool, ) -> (Vec, Vec) { @@ -150,7 +137,7 @@ impl Predictor { } } - fn merge_weights
<P>(words: &[P], weights: &[Vec<ScoreValue>]) -> Vec<Vec<ScoreValue>>
+ fn merge_weights<P>
(words: &[P], weights: &[Vec]) -> Vec> where P: AsRef<[u8]>, { @@ -179,7 +166,7 @@ impl Predictor { result } - fn predict_impl(&self, sentence: &Sentence, padding: usize, ys: &mut [ScoreValue]) { + fn predict_impl(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) { ys.fill(self.bias); self.char_scorer.add_scores(sentence, padding, ys); self.type_scorer.add_scores(sentence, &mut ys[padding..]); @@ -202,10 +189,10 @@ impl Predictor { #[cfg(not(feature = "simd"))] if boundaries_size != 0 { - let mut ys = vec![ScoreValue::default(); boundaries_size]; + let mut ys = vec![0; boundaries_size]; self.predict_impl(&sentence, 0, &mut ys); for (y, b) in ys.into_iter().zip(sentence.boundaries.iter_mut()) { - *b = if y >= ScoreValue::default() { + *b = if y >= 0 { BoundaryType::WordBoundary } else { BoundaryType::NotWordBoundary @@ -216,13 +203,13 @@ impl Predictor { #[cfg(feature = "simd")] if boundaries_size != 0 { let ys_size = boundaries_size + self.padding + CharScorerSimd::simd_len() - 1; - let mut ys = vec![ScoreValue::default(); ys_size]; + let mut ys = vec![0; ys_size]; self.predict_impl(&sentence, self.padding, &mut ys); for (&y, b) in ys[self.padding..] .into_iter() .zip(sentence.boundaries.iter_mut()) { - *b = if y >= ScoreValue::default() { + *b = if y >= 0 { BoundaryType::WordBoundary } else { BoundaryType::NotWordBoundary @@ -247,7 +234,7 @@ impl Predictor { #[cfg(not(feature = "simd"))] if boundaries_size != 0 { - let mut ys = vec![ScoreValue::default(); boundaries_size]; + let mut ys = vec![0; boundaries_size]; self.predict_impl(&sentence, 0, &mut ys); let mut scores = sentence .boundary_scores @@ -257,16 +244,13 @@ impl Predictor { .into_iter() .zip(sentence.boundaries.iter_mut().zip(scores.iter_mut())) { - *b = if y >= ScoreValue::default() { + *b = if y >= 0 { BoundaryType::WordBoundary } else { BoundaryType::NotWordBoundary }; - #[cfg(feature = "model-quantize")] - let y = y as f64 * self.quantize_multiplier; - - *s = y; + *s = y as f64 * self.quantize_multiplier; } sentence.boundary_scores.replace(scores); } @@ -274,7 +258,7 @@ impl Predictor { #[cfg(feature = "simd")] if boundaries_size != 0 { let ys_size = boundaries_size + self.padding + CharScorerSimd::simd_len() - 1; - let mut ys = vec![ScoreValue::default(); ys_size]; + let mut ys = vec![0; ys_size]; self.predict_impl(&sentence, self.padding, &mut ys); let mut scores = sentence .boundary_scores @@ -284,16 +268,13 @@ impl Predictor { .into_iter() .zip(sentence.boundaries.iter_mut().zip(scores.iter_mut())) { - *b = if y >= ScoreValue::default() { + *b = if y >= 0 { BoundaryType::WordBoundary } else { BoundaryType::NotWordBoundary }; - #[cfg(feature = "model-quantize")] - let y = y as f64 * self.quantize_multiplier; - - *s = y; + *s = y as f64 * self.quantize_multiplier; } sentence.boundary_scores.replace(scores); } @@ -343,15 +324,6 @@ mod tests { ], type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], - #[cfg(not(feature = "model-quantize"))] - char_ngram_weights: vec![ - vec![0.5, 1.0, 1.5, 2.0, 2.5], - vec![3.0, 3.5, 4.0, 4.5], - vec![5.0, 5.5, 6.0, 6.5, 7.0], - vec![7.5, 8.0, 8.5, 9.0, 9.5], - vec![10.0, 10.5, 11.0, 11.5, 12.0, 12.5], - ], - #[cfg(feature = "model-quantize")] char_ngram_weights: vec![ vec![1, 2, 3, 4, 5], vec![6, 7, 8, 9], @@ -359,34 +331,12 @@ mod tests { vec![15, 16, 17, 18, 19], vec![20, 21, 22, 23, 24, 25], ], - #[cfg(not(feature = "model-quantize"))] - type_ngram_weights: vec![ - vec![13.0, 13.5, 
14.0, 14.5], - vec![15.0, 15.5, 16.0, 16.5], - vec![17.0, 17.5, 18.0], - vec![18.5, 19.0, 19.5], - ], - #[cfg(feature = "model-quantize")] type_ngram_weights: vec![ vec![26, 27, 28, 29], vec![30, 31, 32, 33], vec![34, 35, 36], vec![37, 38, 39], ], - #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![ - DictWeight { - right: 20.0, - inner: 20.5, - left: 21.0, - }, - DictWeight { - right: 21.5, - inner: 22.0, - left: 22.5, - }, - ], - #[cfg(feature = "model-quantize")] dict_weights: vec![ DictWeight { right: 40, @@ -399,12 +349,8 @@ mod tests { left: 45, }, ], - #[cfg(feature = "model-quantize")] quantize_multiplier: 0.5, dict_word_wise: false, - #[cfg(not(feature = "model-quantize"))] - bias: -100.0, - #[cfg(feature = "model-quantize")] bias: -200, char_window_size: 3, type_window_size: 2, @@ -448,15 +394,6 @@ mod tests { ], type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], - #[cfg(not(feature = "model-quantize"))] - char_ngram_weights: vec![ - vec![0.25, 0.5, 0.75], - vec![1.0, 1.25], - vec![1.5, 1.75, 2.0], - vec![2.25, 2.5, 2.75], - vec![3.0, 3.25, 3.5, 3.75], - ], - #[cfg(feature = "model-quantize")] char_ngram_weights: vec![ vec![1, 2, 3], vec![4, 5], @@ -464,39 +401,12 @@ mod tests { vec![9, 10, 11], vec![12, 13, 14, 15], ], - #[cfg(not(feature = "model-quantize"))] - type_ngram_weights: vec![ - vec![4.0, 4.25, 4.5, 4.75, 5.0, 5.25], - vec![5.5, 5.75, 6.0, 6.25, 6.5, 6.75], - vec![7.0, 7.25, 7.5, 7.75, 8.0], - vec![8.25, 8.5, 8.75, 9.0, 9.25], - ], - #[cfg(feature = "model-quantize")] type_ngram_weights: vec![ vec![16, 17, 18, 19, 20, 21], vec![22, 23, 24, 25, 26, 27], vec![28, 29, 30, 31, 32], vec![33, 34, 35, 36, 37], ], - #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![ - DictWeight { - right: 9.5, - inner: 9.75, - left: 10.0, - }, - DictWeight { - right: 10.25, - inner: 10.5, - left: 10.75, - }, - DictWeight { - right: 11.0, - inner: 11.25, - left: 11.5, - }, - ], - #[cfg(feature = "model-quantize")] dict_weights: vec![ DictWeight { right: 38, @@ -514,12 +424,8 @@ mod tests { left: 46, }, ], - #[cfg(feature = "model-quantize")] quantize_multiplier: 0.25, dict_word_wise: false, - #[cfg(not(feature = "model-quantize"))] - bias: -71.25, - #[cfg(feature = "model-quantize")] bias: -285, char_window_size: 2, type_window_size: 3, @@ -563,15 +469,6 @@ mod tests { ], type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], dict: vec!["国民".to_string(), "世界".to_string(), "世".to_string()], - #[cfg(not(feature = "model-quantize"))] - char_ngram_weights: vec![ - vec![0.25, 0.5, 0.75], - vec![1.0, 1.25], - vec![1.5, 1.75, 2.0], - vec![2.25, 2.5, 2.75], - vec![3.0, 3.25, 3.5, 3.75], - ], - #[cfg(feature = "model-quantize")] char_ngram_weights: vec![ vec![1, 2, 3], vec![4, 5], @@ -579,39 +476,12 @@ mod tests { vec![9, 10, 11], vec![12, 13, 14, 15], ], - #[cfg(not(feature = "model-quantize"))] - type_ngram_weights: vec![ - vec![4.0, 4.25, 4.5, 4.75, 5.0, 5.25], - vec![5.5, 5.75, 6.0, 6.25, 6.5, 6.75], - vec![7.0, 7.25, 7.5, 7.75, 8.0], - vec![8.25, 8.5, 8.75, 9.0, 9.25], - ], - #[cfg(feature = "model-quantize")] type_ngram_weights: vec![ vec![16, 17, 18, 19, 20, 21], vec![22, 23, 24, 25, 26, 27], vec![28, 29, 30, 31, 32], vec![33, 34, 35, 36, 37], ], - #[cfg(not(feature = "model-quantize"))] - dict_weights: vec![ - DictWeight { - right: 9.5, - inner: 9.75, - left: 11.0, - }, - DictWeight { - right: 10.25, - inner: 10.5, - left: 10.75, - }, - DictWeight { - 
right: 11.0, - inner: 11.25, - left: 11.5, - }, - ], - #[cfg(feature = "model-quantize")] dict_weights: vec![ DictWeight { right: 38, @@ -629,12 +499,8 @@ mod tests { left: 46, }, ], - #[cfg(feature = "model-quantize")] quantize_multiplier: 0.25, dict_word_wise: true, - #[cfg(not(feature = "model-quantize"))] - bias: -71.25, - #[cfg(feature = "model-quantize")] bias: -285, char_window_size: 2, type_window_size: 3, diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs index f2d2b6da..696068fe 100644 --- a/vaporetto/src/type_scorer.rs +++ b/vaporetto/src/type_scorer.rs @@ -1,4 +1,3 @@ -use crate::model::ScoreValue; use crate::sentence::Sentence; use daachorse::DoubleArrayAhoCorasick; @@ -11,7 +10,7 @@ impl TypeScorer { /// # Panics /// /// `ngrams` and `weights` must have same number of entries. - pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { + pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { if window_size <= 3 { Self::Cache(TypeScorerCache::new(ngrams, weights, window_size)) } else { @@ -19,7 +18,7 @@ impl TypeScorer { } } - pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { match self { TypeScorer::Pma(pma) => pma.add_scores(sentence, ys), TypeScorer::Cache(cache) => cache.add_scores(sentence, ys), @@ -29,7 +28,7 @@ impl TypeScorer { pub struct TypeScorerPma { pma: DoubleArrayAhoCorasick, - weights: Vec>, + weights: Vec>, window_size: usize, } @@ -37,7 +36,7 @@ impl TypeScorerPma { /// # Panics /// /// `ngrams` and `weights` must have same number of entries. - pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { + pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { if ngrams.len() != weights.len() { panic!("ngrams.len() != weights.len()"); } @@ -48,7 +47,7 @@ impl TypeScorerPma { } } - pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { for m in self .pma .find_overlapping_no_suffix_iter(&sentence.char_type) @@ -69,7 +68,7 @@ impl TypeScorerPma { } pub struct TypeScorerCache { - scores: Vec, + scores: Vec, window_size: usize, sequence_mask: usize, } @@ -78,7 +77,7 @@ impl TypeScorerCache { /// # Panics /// /// `ngrams` and `weights` must have same number of entries. 
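TypeScorer::new above switches to the precomputed table only for window_size <= 3 because the table holds ALPHABET_SIZE.pow(2 * window_size) scores. A rough illustration of that growth; the alphabet size of 6 (one id per character type) is an assumption for the example:

    fn cache_entries(alphabet_size: usize, window_size: u32) -> usize {
        alphabet_size.pow(2 * window_size)
    }
    // cache_entries(6, 3) == 46_656 i32 scores (~182 KiB), cheap to
    // build and hold; cache_entries(6, 4) == 1_679_616 (~6.4 MiB), and
    // each further window step multiplies the table by 36, so wider
    // windows fall back to TypeScorerPma.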
- pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { + pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { if ngrams.len() != weights.len() { panic!("ngrams.len() != weights.len()"); } @@ -88,13 +87,13 @@ impl TypeScorerCache { let all_sequences = ALPHABET_SIZE.pow(sequence_size as u32); let mut sequence = vec![0u8; sequence_size]; - let mut scores = vec![0 as ScoreValue; all_sequences]; + let mut scores = vec![0; all_sequences]; for (i, score) in scores.iter_mut().enumerate() { if !Self::seqid_to_seq(i, &mut sequence) { continue; } - let mut y = ScoreValue::default(); + let mut y = 0; for m in pma.find_overlapping_no_suffix_iter(&sequence) { y += weights[m.pattern()][sequence_size - m.end()]; } @@ -108,7 +107,7 @@ impl TypeScorerCache { } } - pub fn add_scores(&self, sentence: &Sentence, ys: &mut [ScoreValue]) { + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { let mut seqid = 0; for i in 0..self.window_size { if let Some(ct) = sentence.char_type.get(i) { @@ -141,7 +140,7 @@ impl TypeScorerCache { } #[inline(always)] - fn get_score(&self, seqid: usize) -> ScoreValue { + fn get_score(&self, seqid: usize) -> i32 { self.scores[seqid] } From 06782f612685bb5fd69e60771923cf427a7af2ed Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Wed, 1 Dec 2021 11:30:30 +0900 Subject: [PATCH 15/60] Separate n-gram feature management (#6) * Use i32 for holding quantized weights * Remove model-quantize feature and remove supporting FP numbers * Use 24bit for quantization * 24bit -> 16bit * Fix a bug * Add a comment * Rename BIT_DEPTH -> QUANTIZE_BIT_DEPTH * Add NgramModel * Fix a bug --- vaporetto/src/char_scorer.rs | 49 +++---- vaporetto/src/kytea_model.rs | 27 ++-- vaporetto/src/lib.rs | 1 + vaporetto/src/model.rs | 45 +++--- vaporetto/src/ngram_model.rs | 63 +++++++++ vaporetto/src/predictor.rs | 257 +++++++++++++++++++---------------- vaporetto/src/type_scorer.rs | 38 ++---- 7 files changed, 272 insertions(+), 208 deletions(-) create mode 100644 vaporetto/src/ngram_model.rs diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 205cbe64..a69e7f41 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -1,6 +1,8 @@ -use crate::sentence::Sentence; use daachorse::DoubleArrayAhoCorasick; +use crate::ngram_model::NgramModel; +use crate::sentence::Sentence; + #[cfg(feature = "simd")] use std::simd::i32x8; @@ -12,20 +14,17 @@ pub enum CharScorer { } impl CharScorer { - /// # Panics - /// - /// `ngrams` and `weights` must have same number of entries. - pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { + pub fn new(model: NgramModel, window_size: usize) -> Self { #[cfg(not(feature = "simd"))] { - Self::Naive(CharScorerNaive::new(ngrams, weights, window_size)) + Self::Naive(CharScorerNaive::new(model, window_size)) } #[cfg(feature = "simd")] if window_size <= 4 { - Self::Simd(CharScorerSimd::new(ngrams, weights, window_size)) + Self::Simd(CharScorerSimd::new(model, window_size)) } else { - Self::Naive(CharScorerNaive::new(ngrams, weights, window_size)) + Self::Naive(CharScorerNaive::new(model, window_size)) } } @@ -46,16 +45,11 @@ pub struct CharScorerNaive { } impl CharScorerNaive { - /// # Panics - /// - /// `ngrams` and `weights` must have same number of entries. 
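TypeScorerCache::new above enumerates every type sequence of length 2 * window_size, runs the automaton once per sequence, and stores the summed score, so add_scores becomes one table read per character. The table key is a rolling integer id; a minimal sketch of that update, where the bit width is an assumption and the crate keeps an equivalent sequence_mask:

    // Pack each character-type id into a fixed number of bits and keep
    // the id of the current window updated in O(1) per character.
    const TYPE_BITS: u32 = 3; // assumed width, enough for one type id

    fn next_seqid(seqid: usize, type_id: usize, mask: usize) -> usize {
        ((seqid << TYPE_BITS) | type_id) & mask
    }

    // Per character: seqid = next_seqid(seqid, type_id, mask);
    //                ys[pos] += scores[seqid];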
- pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { - if ngrams.len() != weights.len() { - panic!("ngrams.len() != weights.len()"); - } + pub fn new(mut model: NgramModel, window_size: usize) -> Self { + model.merge_weights(); Self { - pma: DoubleArrayAhoCorasick::new(ngrams).unwrap(), - weights, + pma: DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap(), + weights: model.data.into_iter().map(|d| d.weights).collect(), window_size, } } @@ -87,23 +81,20 @@ pub struct CharScorerSimd { #[cfg(feature = "simd")] impl CharScorerSimd { - /// # Panics - /// - /// `ngrams` and `weights` must have same number of entries. - pub fn new(ngrams: &[String], weights: Vec>, window_size: usize) -> Self { - if ngrams.len() != weights.len() { - panic!("ngrams.len() != weights.len()"); - } - let weights: Vec<_> = weights - .iter() - .map(|w| { + pub fn new(mut model: NgramModel, window_size: usize) -> Self { + model.merge_weights(); + let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap(); + let weights = model + .data + .into_iter() + .map(|d| { let mut s = [0i32; 8]; - s[..w.len()].copy_from_slice(&w); + s[..d.weights.len()].copy_from_slice(&d.weights); i32x8::from_array(s) }) .collect(); Self { - pma: DoubleArrayAhoCorasick::new(ngrams).unwrap(), + pma, weights, window_size, } diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index a486e0d9..13fdaa8a 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -5,6 +5,7 @@ use anyhow::{anyhow, Result}; use byteorder::{LittleEndian, ReadBytesExt}; use crate::model::{DictWeight, Model}; +use crate::ngram_model::{NgramData, NgramModel}; struct KyteaConfig { _model_tag: String, @@ -409,26 +410,26 @@ impl TryFrom for Model { .type_dict .ok_or_else(|| anyhow!("no type dictionary."))?; - let mut char_ngrams: Vec = vec![]; - let mut char_ngram_weights = vec![]; + let mut char_ngrams = vec![]; for (char_ngram, v) in char_dict.dump_items() { let weight_size = config.char_w as usize * 2 - char_ngram.len() + 1; - char_ngrams.push(char_ngram.into_iter().collect::()); - char_ngram_weights.push(v[..weight_size].iter().map(|&w| w as i32).collect()); + char_ngrams.push(NgramData { + ngram: char_ngram.into_iter().collect(), + weights: v[..weight_size].iter().map(|&w| w as i32).collect(), + }); } - let mut type_ngrams: Vec> = vec![]; - let mut type_ngram_weights = vec![]; + let mut type_ngrams = vec![]; for (type_ngram, v) in type_dict.dump_items() { let weight_size = config.type_w as usize * 2 - type_ngram.len() + 1; - type_ngrams.push( - type_ngram + type_ngrams.push(NgramData { + ngram: type_ngram .into_iter() .collect::() .as_bytes() .to_vec(), - ); - type_ngram_weights.push(v[..weight_size].iter().map(|&w| w as i32).collect()); + weights: v[..weight_size].iter().map(|&w| w as i32).collect(), + }); } let mut dict: Vec = vec![]; @@ -451,14 +452,12 @@ impl TryFrom for Model { } Ok(Self { - char_ngrams, - type_ngrams, + char_ngram_model: NgramModel::new(char_ngrams), + type_ngram_model: NgramModel::new(type_ngrams), dict, quantize_multiplier, - char_ngram_weights, - type_ngram_weights, dict_weights, dict_word_wise: true, bias, diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index 2d537fc8..c81fce84 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -31,6 +31,7 @@ mod utils; mod char_scorer; mod dict_scorer; mod model; +mod ngram_model; mod predictor; mod sentence; mod type_scorer; diff --git a/vaporetto/src/model.rs 
b/vaporetto/src/model.rs index 211eafd3..2d038892 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -3,9 +3,13 @@ use std::io::{Read, Write}; use anyhow::Result; use serde::{Deserialize, Serialize}; +use crate::ngram_model::NgramModel; + #[cfg(feature = "train")] use crate::feature::FeatureContent; #[cfg(feature = "train")] +use crate::ngram_model::NgramData; +#[cfg(feature = "train")] use crate::sentence::BoundaryType; #[cfg(feature = "train")] use crate::utils::{FeatureIDManager, StringIdManager}; @@ -29,12 +33,9 @@ pub struct DictWeight { /// Model data. #[derive(Serialize, Deserialize)] pub struct Model { - pub(crate) char_ngrams: Vec, - pub(crate) type_ngrams: Vec>, + pub(crate) char_ngram_model: NgramModel, + pub(crate) type_ngram_model: NgramModel>, pub(crate) dict: Vec, - - pub(crate) char_ngram_weights: Vec>, - pub(crate) type_ngram_weights: Vec>, pub(crate) dict_weights: Vec, pub(crate) quantize_multiplier: f64, @@ -102,8 +103,6 @@ impl Model { let bias = model.label_bias(wb_idx); let mut char_ngrams = vec![]; let mut type_ngrams = vec![]; - let mut char_ngram_weights = vec![]; - let mut type_ngram_weights = vec![]; let mut dict_weights = vec![DictWeight::default(); dict_word_max_size]; let mut char_ngram_ids = StringIdManager::new(); let mut type_ngram_ids = StringIdManager::new(); @@ -130,25 +129,23 @@ impl Model { match feature.feature { FeatureContent::CharacterNgram(char_ngram) => { let id = char_ngram_ids.get_id(&char_ngram); - if id == char_ngram_weights.len() { - char_ngrams.push(char_ngram.to_string()); - char_ngram_weights.push(vec![ - 0; - char_window_size * 2 - - char_ngram.chars().count() - + 1 - ]); + if id == char_ngrams.len() { + char_ngrams.push(NgramData { + ngram: char_ngram.to_string(), + weights: vec![0; char_window_size * 2 - char_ngram.chars().count() + 1], + }); } - char_ngram_weights[id][feature.rel_position] = weight as i32; + char_ngrams[id].weights[feature.rel_position] = weight as i32; } FeatureContent::CharacterTypeNgram(type_ngram) => { let id = type_ngram_ids.get_id(type_ngram) as usize; - if id == type_ngram_weights.len() { - type_ngrams.push(type_ngram.to_vec()); - type_ngram_weights - .push(vec![0; type_window_size * 2 - type_ngram.len() + 1]); + if id == type_ngrams.len() { + type_ngrams.push(NgramData { + ngram: type_ngram.to_vec(), + weights: vec![0; type_window_size * 2 - type_ngram.len() + 1], + }); } - type_ngram_weights[id][feature.rel_position] = weight as i32; + type_ngrams[id].weights[feature.rel_position] = weight as i32; } FeatureContent::DictionaryWord(size) => match feature.rel_position { 0 => dict_weights[size - 1].right = weight as i32, @@ -159,14 +156,12 @@ impl Model { }; } Self { - char_ngrams, - type_ngrams, + char_ngram_model: NgramModel::new(char_ngrams), + type_ngram_model: NgramModel::new(type_ngrams), dict, quantize_multiplier, - char_ngram_weights, - type_ngram_weights, dict_weights, dict_word_wise: false, bias, diff --git a/vaporetto/src/ngram_model.rs b/vaporetto/src/ngram_model.rs new file mode 100644 index 00000000..dbff7a91 --- /dev/null +++ b/vaporetto/src/ngram_model.rs @@ -0,0 +1,63 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Serialize, Deserialize)] +pub struct NgramData +where + T: Clone, +{ + pub(crate) ngram: T, + pub(crate) weights: Vec, +} + +#[derive(Serialize, Deserialize)] +pub struct NgramModel +where + T: Clone, +{ + pub(crate) data: Vec>, + merged: bool, +} + +impl NgramModel +where + T: AsRef<[u8]> + Clone, +{ + #[cfg(any(feature = 
"train", test))] + pub fn new(data: Vec>) -> Self { + Self { + data, + merged: false, + } + } + + pub fn merge_weights(&mut self) { + if self.merged { + return; + } + self.merged = true; + let ngrams = self + .data + .iter() + .cloned() + .map(|d| (d.ngram.as_ref().to_vec(), d.weights)) + .collect::>(); + for NgramData { ngram, weights } in &mut self.data { + let ngram = ngram.as_ref(); + let mut new_weights: Option> = None; + for st in (0..ngram.len()).rev() { + if let Some(weights) = ngrams.get(&ngram[st..]) { + if let Some(new_weights) = new_weights.as_mut() { + for (w_new, w) in new_weights.iter_mut().zip(weights) { + *w_new += *w; + } + } else { + new_weights.replace(weights.clone()); + } + } + } + *weights = new_weights.unwrap(); + } + } +} diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index e9c04e19..89d61836 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use crate::char_scorer::CharScorer; use crate::dict_scorer::DictScorer; use crate::model::{DictWeight, Model}; +use crate::ngram_model::NgramModel; use crate::sentence::{BoundaryType, Sentence}; use crate::type_scorer::TypeScorer; @@ -36,31 +37,21 @@ impl Predictor { pub fn new(model: Model) -> Self { let bias = model.bias; - let char_ngrams = model.char_ngrams; + let mut char_ngram_model = model.char_ngram_model; + let type_ngram_model = model.type_ngram_model; let dict = model.dict; let dict_weights = model.dict_weights; - let mut char_ngram_weights = model.char_ngram_weights; - let type_ngram_weights = model.type_ngram_weights; - let (dict, dict_weights) = Self::merge_dict_weights( dict, dict_weights, - &char_ngrams, - &mut char_ngram_weights, + &mut char_ngram_model, model.char_window_size, model.dict_word_wise, ); - let char_ngram_weights = Self::merge_weights(&char_ngrams, &char_ngram_weights); - let type_ngram_weights = Self::merge_weights(&model.type_ngrams, &type_ngram_weights); - - let char_scorer = CharScorer::new(&char_ngrams, char_ngram_weights, model.char_window_size); - let type_scorer = TypeScorer::new( - &model.type_ngrams, - type_ngram_weights, - model.type_window_size, - ); + let char_scorer = CharScorer::new(char_ngram_model, model.char_window_size); + let type_scorer = TypeScorer::new(type_ngram_model, model.type_window_size); let dict_scorer = if dict.is_empty() { None } else { @@ -84,13 +75,17 @@ impl Predictor { fn merge_dict_weights( dict: Vec, dict_weights: Vec, - words: &[String], - word_weights: &mut Vec>, + char_ngram_model: &mut NgramModel, char_window_size: usize, dict_word_wise: bool, ) -> (Vec, Vec) { let mut word_map = HashMap::new(); - for (i, word) in words.iter().cloned().enumerate() { + for (i, word) in char_ngram_model + .data + .iter() + .map(|d| d.ngram.clone()) + .enumerate() + { word_map.insert(word, i); } let mut new_dict = vec![]; @@ -102,11 +97,11 @@ impl Predictor { Some(&idx) if char_window_size >= word_size => { let start = char_window_size - word_size; let end = start + word_size; - word_weights[idx][start] += weight.right; + char_ngram_model.data[idx].weights[start] += weight.right; for i in start + 1..end { - word_weights[idx][i] += weight.inner; + char_ngram_model.data[idx].weights[i] += weight.inner; } - word_weights[idx][end] += weight.left; + char_ngram_model.data[idx].weights[end] += weight.left; } _ => { new_dict.push(word); @@ -124,11 +119,11 @@ impl Predictor { let end = start + word_size; let word_size_idx = std::cmp::min(word_size, dict_weights.len()) - 1; let weight = 
&dict_weights[word_size_idx]; - word_weights[idx][start] += weight.right; + char_ngram_model.data[idx].weights[start] += weight.right; for i in start + 1..end { - word_weights[idx][i] += weight.inner; + char_ngram_model.data[idx].weights[i] += weight.inner; } - word_weights[idx][end] += weight.left; + char_ngram_model.data[idx].weights[end] += weight.left; } _ => new_dict.push(word), } @@ -137,35 +132,6 @@ impl Predictor { } } - fn merge_weights
<P>
(words: &[P], weights: &[Vec]) -> Vec> - where - P: AsRef<[u8]>, - { - let mut result = vec![]; - let word_ids = words - .iter() - .enumerate() - .map(|(i, w)| (w.as_ref().to_vec(), i)) - .collect::, usize>>(); - for seq in words { - let seq = seq.as_ref(); - let mut new_weights: Option> = None; - for st in (0..seq.len()).rev() { - if let Some(&idx) = word_ids.get(&seq[st..]) { - if let Some(new_weights) = new_weights.as_mut() { - for (w_new, w) in new_weights.iter_mut().zip(&weights[idx]) { - *w_new += *w; - } - } else { - new_weights.replace(weights[idx].clone()); - } - } - } - result.push(new_weights.unwrap()); - } - result - } - fn predict_impl(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) { ys.fill(self.bias); self.char_scorer.add_scores(sentence, padding, ys); @@ -287,6 +253,8 @@ impl Predictor { mod tests { use super::*; + use crate::ngram_model::NgramData; + /// Input: 我 ら は 全 世 界 の 国 民 /// bias: -200 .. .. .. .. .. .. .. /// words: @@ -315,28 +283,47 @@ mod tests { /// 世: 40 42 fn generate_model_1() -> Model { Model { - char_ngrams: vec![ - "我ら".to_string(), - "全世界".to_string(), - "国民".to_string(), - "世界".to_string(), - "界".to_string(), - ], - type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], + char_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: "我ら".to_string(), + weights: vec![1, 2, 3, 4, 5], + }, + NgramData { + ngram: "全世界".to_string(), + weights: vec![6, 7, 8, 9], + }, + NgramData { + ngram: "国民".to_string(), + weights: vec![10, 11, 12, 13, 14], + }, + NgramData { + ngram: "世界".to_string(), + weights: vec![15, 16, 17, 18, 19], + }, + NgramData { + ngram: "界".to_string(), + weights: vec![20, 21, 22, 23, 24, 25], + }, + ]), + type_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: b"H".to_vec(), + weights: vec![26, 27, 28, 29], + }, + NgramData { + ngram: b"K".to_vec(), + weights: vec![30, 31, 32, 33], + }, + NgramData { + ngram: b"KH".to_vec(), + weights: vec![34, 35, 36], + }, + NgramData { + ngram: b"HK".to_vec(), + weights: vec![37, 38, 39], + }, + ]), dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], - char_ngram_weights: vec![ - vec![1, 2, 3, 4, 5], - vec![6, 7, 8, 9], - vec![10, 11, 12, 13, 14], - vec![15, 16, 17, 18, 19], - vec![20, 21, 22, 23, 24, 25], - ], - type_ngram_weights: vec![ - vec![26, 27, 28, 29], - vec![30, 31, 32, 33], - vec![34, 35, 36], - vec![37, 38, 39], - ], dict_weights: vec![ DictWeight { right: 40, @@ -385,28 +372,47 @@ mod tests { /// 世: 38 40 fn generate_model_2() -> Model { Model { - char_ngrams: vec![ - "我ら".to_string(), - "全世界".to_string(), - "国民".to_string(), - "世界".to_string(), - "界".to_string(), - ], - type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], + char_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: "我ら".to_string(), + weights: vec![1, 2, 3], + }, + NgramData { + ngram: "全世界".to_string(), + weights: vec![4, 5], + }, + NgramData { + ngram: "国民".to_string(), + weights: vec![6, 7, 8], + }, + NgramData { + ngram: "世界".to_string(), + weights: vec![9, 10, 11], + }, + NgramData { + ngram: "界".to_string(), + weights: vec![12, 13, 14, 15], + }, + ]), + type_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: b"H".to_vec(), + weights: vec![16, 17, 18, 19, 20, 21], + }, + NgramData { + ngram: b"K".to_vec(), + weights: vec![22, 23, 24, 25, 26, 27], + }, + NgramData { + ngram: b"KH".to_vec(), + weights: vec![28, 29, 30, 31, 32], + }, + NgramData { + ngram: b"HK".to_vec(), + weights: vec![33, 34, 35, 36, 37], + }, + ]), 
dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], - char_ngram_weights: vec![ - vec![1, 2, 3], - vec![4, 5], - vec![6, 7, 8], - vec![9, 10, 11], - vec![12, 13, 14, 15], - ], - type_ngram_weights: vec![ - vec![16, 17, 18, 19, 20, 21], - vec![22, 23, 24, 25, 26, 27], - vec![28, 29, 30, 31, 32], - vec![33, 34, 35, 36, 37], - ], dict_weights: vec![ DictWeight { right: 38, @@ -460,28 +466,47 @@ mod tests { /// 世: 44 46 fn generate_model_3() -> Model { Model { - char_ngrams: vec![ - "我ら".to_string(), - "全世界".to_string(), - "国民".to_string(), - "世界".to_string(), - "界".to_string(), - ], - type_ngrams: vec![b"H".to_vec(), b"K".to_vec(), b"KH".to_vec(), b"HK".to_vec()], + char_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: "我ら".to_string(), + weights: vec![1, 2, 3], + }, + NgramData { + ngram: "全世界".to_string(), + weights: vec![4, 5], + }, + NgramData { + ngram: "国民".to_string(), + weights: vec![6, 7, 8], + }, + NgramData { + ngram: "世界".to_string(), + weights: vec![9, 10, 11], + }, + NgramData { + ngram: "界".to_string(), + weights: vec![12, 13, 14, 15], + }, + ]), + type_ngram_model: NgramModel::new(vec![ + NgramData { + ngram: b"H".to_vec(), + weights: vec![16, 17, 18, 19, 20, 21], + }, + NgramData { + ngram: b"K".to_vec(), + weights: vec![22, 23, 24, 25, 26, 27], + }, + NgramData { + ngram: b"KH".to_vec(), + weights: vec![28, 29, 30, 31, 32], + }, + NgramData { + ngram: b"HK".to_vec(), + weights: vec![33, 34, 35, 36, 37], + }, + ]), dict: vec!["国民".to_string(), "世界".to_string(), "世".to_string()], - char_ngram_weights: vec![ - vec![1, 2, 3], - vec![4, 5], - vec![6, 7, 8], - vec![9, 10, 11], - vec![12, 13, 14, 15], - ], - type_ngram_weights: vec![ - vec![16, 17, 18, 19, 20, 21], - vec![22, 23, 24, 25, 26, 27], - vec![28, 29, 30, 31, 32], - vec![33, 34, 35, 36, 37], - ], dict_weights: vec![ DictWeight { right: 38, diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs index 696068fe..5bc9299d 100644 --- a/vaporetto/src/type_scorer.rs +++ b/vaporetto/src/type_scorer.rs @@ -1,20 +1,19 @@ -use crate::sentence::Sentence; use daachorse::DoubleArrayAhoCorasick; +use crate::ngram_model::NgramModel; +use crate::sentence::Sentence; + pub enum TypeScorer { Pma(TypeScorerPma), Cache(TypeScorerCache), } impl TypeScorer { - /// # Panics - /// - /// `ngrams` and `weights` must have same number of entries. - pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { + pub fn new(model: NgramModel>, window_size: usize) -> Self { if window_size <= 3 { - Self::Cache(TypeScorerCache::new(ngrams, weights, window_size)) + Self::Cache(TypeScorerCache::new(model, window_size)) } else { - Self::Pma(TypeScorerPma::new(ngrams, weights, window_size)) + Self::Pma(TypeScorerPma::new(model, window_size)) } } @@ -33,16 +32,11 @@ pub struct TypeScorerPma { } impl TypeScorerPma { - /// # Panics - /// - /// `ngrams` and `weights` must have same number of entries. 
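The merge_weights pass that NgramModel now provides (see ngram_model.rs above; the scorers below call it before building their automata) is what makes find_overlapping_no_suffix_iter sufficient: if pattern q is a suffix of pattern p, every occurrence of p ends exactly where an occurrence of q ends, so q's weights can be folded into p up front and only the longest match needs reporting. A self-contained restatement over byte strings:

    use std::collections::HashMap;

    fn merge_suffix_weights(patterns: &[Vec<u8>], weights: &mut Vec<Vec<i32>>) {
        let index: HashMap<&[u8], usize> = patterns
            .iter()
            .enumerate()
            .map(|(i, p)| (p.as_slice(), i))
            .collect();
        let originals = weights.clone();
        for (i, p) in patterns.iter().enumerate() {
            let mut merged: Option<Vec<i32>> = None;
            for st in (0..p.len()).rev() {
                if let Some(&j) = index.get(&p[st..]) {
                    if let Some(m) = merged.as_mut() {
                        // Index 0 of each weight vector refers to the
                        // same boundary, since all suffix matches end
                        // at the same position; `zip` stops at the
                        // shorter vector.
                        for (mw, w) in m.iter_mut().zip(&originals[j]) {
                            *mw += *w;
                        }
                    } else {
                        merged = Some(originals[j].clone());
                    }
                }
            }
            // `p` is a suffix of itself, so `merged` is always set.
            weights[i] = merged.unwrap();
        }
    }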
- pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { - if ngrams.len() != weights.len() { - panic!("ngrams.len() != weights.len()"); - } + pub fn new(mut model: NgramModel>, window_size: usize) -> Self { + model.merge_weights(); Self { - pma: DoubleArrayAhoCorasick::new(ngrams).unwrap(), - weights, + pma: DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap(), + weights: model.data.into_iter().map(|d| d.weights).collect(), window_size, } } @@ -74,14 +68,10 @@ pub struct TypeScorerCache { } impl TypeScorerCache { - /// # Panics - /// - /// `ngrams` and `weights` must have same number of entries. - pub fn new(ngrams: &[Vec], weights: Vec>, window_size: usize) -> Self { - if ngrams.len() != weights.len() { - panic!("ngrams.len() != weights.len()"); - } - let pma = DoubleArrayAhoCorasick::new(ngrams).unwrap(); + pub fn new(mut model: NgramModel>, window_size: usize) -> Self { + model.merge_weights(); + let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap(); + let weights: Vec> = model.data.into_iter().map(|d| d.weights).collect(); let sequence_size = window_size * 2; let all_sequences = ALPHABET_SIZE.pow(sequence_size as u32); From ed4849ccbfde12d562c297c3913140eb76317d31 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Wed, 1 Dec 2021 16:58:31 +0900 Subject: [PATCH 16/60] Add DictModel (#7) * Add DictModel * Model DictWeight into dict_model * Fix bugs * Fix --- vaporetto/src/dict_model.rs | 137 ++++++++++++++++++++++++ vaporetto/src/dict_scorer.rs | 82 +++++++++++--- vaporetto/src/kytea_model.rs | 16 +-- vaporetto/src/lib.rs | 1 + vaporetto/src/model.rs | 22 ++-- vaporetto/src/ngram_model.rs | 2 +- vaporetto/src/predictor.rs | 202 +++++++++++++---------------------- 7 files changed, 293 insertions(+), 169 deletions(-) create mode 100644 vaporetto/src/dict_model.rs diff --git a/vaporetto/src/dict_model.rs b/vaporetto/src/dict_model.rs new file mode 100644 index 00000000..b84cd4ac --- /dev/null +++ b/vaporetto/src/dict_model.rs @@ -0,0 +1,137 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::ngram_model::NgramModel; + +#[derive(Clone, Copy, Default, Serialize, Deserialize)] +pub struct DictWeight { + pub right: i32, + pub inner: i32, + pub left: i32, +} + +#[derive(Serialize, Deserialize)] +pub enum DictModel { + Wordwise(DictModelWordwise), + Lengthwise(DictModelLengthwise), +} + +impl DictModel { + pub fn merge_dict_weights( + &mut self, + char_ngram_model: &mut NgramModel, + char_window_size: usize, + ) { + match self { + Self::Wordwise(model) => model.merge_dict_weights(char_ngram_model, char_window_size), + Self::Lengthwise(model) => model.merge_dict_weights(char_ngram_model, char_window_size), + } + } + + pub fn is_empty(&self) -> bool { + match self { + Self::Wordwise(model) => model.is_empty(), + Self::Lengthwise(model) => model.is_empty(), + } + } +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct WordwiseDictData { + pub(crate) word: String, + pub(crate) weights: DictWeight, +} + +#[derive(Serialize, Deserialize)] +pub struct DictModelWordwise { + pub(crate) data: Vec, +} + +impl DictModelWordwise { + pub fn merge_dict_weights( + &mut self, + char_ngram_model: &mut NgramModel, + char_window_size: usize, + ) { + let mut word_map = HashMap::new(); + for (i, word) in char_ngram_model + .data + .iter() + .map(|d| d.ngram.clone()) + .enumerate() + { + word_map.insert(word, i); + } + let mut new_data = vec![]; + for data in self.data.drain(..) 
{ + let word_size = data.word.chars().count(); + match word_map.get(&data.word) { + Some(&idx) if char_window_size >= word_size => { + let start = char_window_size - word_size; + let end = start + word_size; + char_ngram_model.data[idx].weights[start] += data.weights.right; + for i in start + 1..end { + char_ngram_model.data[idx].weights[i] += data.weights.inner; + } + char_ngram_model.data[idx].weights[end] += data.weights.left; + } + _ => { + new_data.push(data); + } + } + } + self.data = new_data; + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } +} + +#[derive(Serialize, Deserialize)] +pub struct DictModelLengthwise { + pub(crate) words: Vec, + pub(crate) weights: Vec, +} + +impl DictModelLengthwise { + pub fn merge_dict_weights( + &mut self, + char_ngram_model: &mut NgramModel, + char_window_size: usize, + ) { + let mut word_map = HashMap::new(); + for (i, word) in char_ngram_model + .data + .iter() + .map(|d| d.ngram.clone()) + .enumerate() + { + word_map.insert(word, i); + } + let mut new_words = vec![]; + for word in self.words.drain(..) { + let word_size = word.chars().count(); + match word_map.get(&word) { + Some(&idx) if char_window_size >= word_size => { + let start = char_window_size - word_size; + let end = start + word_size; + let word_size_idx = word_size.min(self.weights.len()) - 1; + let weight = &self.weights[word_size_idx]; + char_ngram_model.data[idx].weights[start] += weight.right; + for i in start + 1..end { + char_ngram_model.data[idx].weights[i] += weight.inner; + } + char_ngram_model.data[idx].weights[end] += weight.left; + } + _ => new_words.push(word), + } + } + self.words = new_words; + } + + pub fn is_empty(&self) -> bool { + self.words.is_empty() + } +} diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs index 5afc62f5..b489f76d 100644 --- a/vaporetto/src/dict_scorer.rs +++ b/vaporetto/src/dict_scorer.rs @@ -1,25 +1,45 @@ -use crate::model::DictWeight; -use crate::sentence::Sentence; use daachorse::DoubleArrayAhoCorasick; -pub struct DictScorer { +use crate::dict_model::{DictModel, DictModelLengthwise, DictModelWordwise, DictWeight}; +use crate::sentence::Sentence; + +pub enum DictScorer { + Wordwise(DictScorerWordwise), + Lengthwise(DictScorerLengthwise), +} + +impl DictScorer { + pub fn new(model: DictModel) -> Self { + match model { + DictModel::Wordwise(model) => Self::Wordwise(DictScorerWordwise::new(model)), + DictModel::Lengthwise(model) => Self::Lengthwise(DictScorerLengthwise::new(model)), + } + } + + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { + match self { + Self::Wordwise(model) => model.add_scores(sentence, ys), + Self::Lengthwise(model) => model.add_scores(sentence, ys), + } + } +} + +pub struct DictScorerWordwise { pma: DoubleArrayAhoCorasick, weights: Vec, - word_wise_score: bool, } -impl DictScorer { - /// # Panics - /// - /// `ngrams` and `weights` must have same number of entries. 
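The DictWeight triple scores the three regions around one dictionary hit. Writing ys[k] for the boundary between characters k and k + 1, a match covering characters [m_start, m_end) adds right to the boundary at the word's start, inner to every boundary inside the word, and left to the boundary at its end, exactly as the two add_scores implementations that follow do (the field names appear to denote which side of the boundary the word lies on). A compact restatement:

    fn add_dict_match(w: &DictWeight, m_start: usize, m_end: usize, ys: &mut [i32]) {
        if m_start != 0 {
            ys[m_start - 1] += w.right; // boundary just before the word
        }
        for y in &mut ys[m_start..m_end - 1] {
            *y += w.inner; // boundaries inside the word
        }
        if m_end <= ys.len() {
            ys[m_end - 1] += w.left; // boundary just after the word
        }
    }
    // E.g. a match on characters 3..5 of a 9-character sentence
    // (ys.len() == 8) touches ys[2] (right), ys[3] (inner), ys[4] (left).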
- pub fn new(words: &[String], weights: Vec, word_wise_score: bool) -> Self { - if word_wise_score && words.len() != weights.len() { - panic!("word_wise_score == true && words.len() != weights.len()"); +impl DictScorerWordwise { + pub fn new(model: DictModelWordwise) -> Self { + let mut words = vec![]; + let mut weights = vec![]; + for pair in model.data { + words.push(pair.word); + weights.push(pair.weights); } Self { pma: DoubleArrayAhoCorasick::new(words).unwrap(), weights, - word_wise_score, } } @@ -27,11 +47,39 @@ impl DictScorer { for m in self.pma.find_overlapping_iter(&sentence.text) { let m_start = sentence.str_to_char_pos[m.start()]; let m_end = sentence.str_to_char_pos[m.end()]; - let idx = if self.word_wise_score { - m.pattern() - } else { - std::cmp::min(m_end - m_start, self.weights.len()) - 1 - }; + let idx = m.pattern(); + let dict_weight = self.weights[idx]; + if m_start != 0 { + ys[m_start - 1] += dict_weight.right; + } + for y in &mut ys[m_start..m_end - 1] { + *y += dict_weight.inner; + } + if m_end <= ys.len() { + ys[m_end - 1] += dict_weight.left; + } + } + } +} + +pub struct DictScorerLengthwise { + pma: DoubleArrayAhoCorasick, + weights: Vec, +} + +impl DictScorerLengthwise { + pub fn new(model: DictModelLengthwise) -> Self { + Self { + pma: DoubleArrayAhoCorasick::new(model.words).unwrap(), + weights: model.weights, + } + } + + pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { + for m in self.pma.find_overlapping_iter(&sentence.text) { + let m_start = sentence.str_to_char_pos[m.start()]; + let m_end = sentence.str_to_char_pos[m.end()]; + let idx = (m_end - m_start).min(self.weights.len()) - 1; let dict_weight = self.weights[idx]; if m_start != 0 { ys[m_start - 1] += dict_weight.right; diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 13fdaa8a..12696ef2 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -4,7 +4,8 @@ use std::io::BufRead; use anyhow::{anyhow, Result}; use byteorder::{LittleEndian, ReadBytesExt}; -use crate::model::{DictWeight, Model}; +use crate::dict_model::{DictModel, DictModelWordwise, DictWeight, WordwiseDictData}; +use crate::model::Model; use crate::ngram_model::{NgramData, NgramModel}; struct KyteaConfig { @@ -432,8 +433,7 @@ impl TryFrom for Model { }); } - let mut dict: Vec = vec![]; - let mut dict_weights = vec![]; + let mut dict_data = vec![]; if let Some(kytea_dict) = model.dict { for (w, data) in kytea_dict.dump_items() { let word_len = std::cmp::min(w.len(), config.dict_n as usize) - 1; @@ -446,20 +446,20 @@ impl TryFrom for Model { weights.left += feature_lookup.dict_vec[offset + 2] as i32; } } - dict_weights.push(weights); - dict.push(w.into_iter().collect()); + dict_data.push(WordwiseDictData { + word: w.into_iter().collect(), + weights, + }); } } Ok(Self { char_ngram_model: NgramModel::new(char_ngrams), type_ngram_model: NgramModel::new(type_ngrams), - dict, + dict_model: DictModel::Wordwise(DictModelWordwise { data: dict_data }), quantize_multiplier, - dict_weights, - dict_word_wise: true, bias, char_window_size: config.char_w as usize, type_window_size: config.type_w as usize, diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index c81fce84..15fecd47 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -29,6 +29,7 @@ mod utils; mod char_scorer; +mod dict_model; mod dict_scorer; mod model; mod ngram_model; diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index 2d038892..085e8cbd 100644 --- a/vaporetto/src/model.rs +++ 
b/vaporetto/src/model.rs @@ -3,8 +3,11 @@ use std::io::{Read, Write}; use anyhow::Result; use serde::{Deserialize, Serialize}; +use crate::dict_model::DictModel; use crate::ngram_model::NgramModel; +#[cfg(feature = "train")] +use crate::dict_model::{DictModelLengthwise, DictWeight}; #[cfg(feature = "train")] use crate::feature::FeatureContent; #[cfg(feature = "train")] @@ -23,25 +26,15 @@ const EPSILON: f64 = 1e-6; #[cfg(feature = "train")] const QUANTIZE_BIT_DEPTH: u8 = 16; -#[derive(Clone, Copy, Default, Serialize, Deserialize)] -pub struct DictWeight { - pub right: i32, - pub inner: i32, - pub left: i32, -} - /// Model data. #[derive(Serialize, Deserialize)] pub struct Model { pub(crate) char_ngram_model: NgramModel, pub(crate) type_ngram_model: NgramModel>, - pub(crate) dict: Vec, - pub(crate) dict_weights: Vec, + pub(crate) dict_model: DictModel, pub(crate) quantize_multiplier: f64, - pub(crate) dict_word_wise: bool, - pub(crate) bias: i32, pub(crate) char_window_size: usize, pub(crate) type_window_size: usize, @@ -158,12 +151,13 @@ impl Model { Self { char_ngram_model: NgramModel::new(char_ngrams), type_ngram_model: NgramModel::new(type_ngrams), - dict, + dict_model: DictModel::Lengthwise(DictModelLengthwise { + words: dict, + weights: dict_weights, + }), quantize_multiplier, - dict_weights, - dict_word_wise: false, bias, char_window_size, type_window_size, diff --git a/vaporetto/src/ngram_model.rs b/vaporetto/src/ngram_model.rs index dbff7a91..28ce97e6 100644 --- a/vaporetto/src/ngram_model.rs +++ b/vaporetto/src/ngram_model.rs @@ -24,7 +24,7 @@ impl NgramModel where T: AsRef<[u8]> + Clone, { - #[cfg(any(feature = "train", test))] + #[cfg(any(feature = "train", feature = "kytea", test))] pub fn new(data: Vec>) -> Self { Self { data, diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 89d61836..24e670cb 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -1,9 +1,6 @@ -use std::collections::HashMap; - use crate::char_scorer::CharScorer; use crate::dict_scorer::DictScorer; -use crate::model::{DictWeight, Model}; -use crate::ngram_model::NgramModel; +use crate::model::Model; use crate::sentence::{BoundaryType, Sentence}; use crate::type_scorer::TypeScorer; @@ -39,23 +36,16 @@ impl Predictor { let mut char_ngram_model = model.char_ngram_model; let type_ngram_model = model.type_ngram_model; - let dict = model.dict; - let dict_weights = model.dict_weights; - - let (dict, dict_weights) = Self::merge_dict_weights( - dict, - dict_weights, - &mut char_ngram_model, - model.char_window_size, - model.dict_word_wise, - ); + let mut dict_model = model.dict_model; + + dict_model.merge_dict_weights(&mut char_ngram_model, model.char_window_size); let char_scorer = CharScorer::new(char_ngram_model, model.char_window_size); let type_scorer = TypeScorer::new(type_ngram_model, model.type_window_size); - let dict_scorer = if dict.is_empty() { + let dict_scorer = if dict_model.is_empty() { None } else { - Some(DictScorer::new(&dict, dict_weights, model.dict_word_wise)) + Some(DictScorer::new(dict_model)) }; Self { @@ -72,66 +62,6 @@ impl Predictor { } } - fn merge_dict_weights( - dict: Vec, - dict_weights: Vec, - char_ngram_model: &mut NgramModel, - char_window_size: usize, - dict_word_wise: bool, - ) -> (Vec, Vec) { - let mut word_map = HashMap::new(); - for (i, word) in char_ngram_model - .data - .iter() - .map(|d| d.ngram.clone()) - .enumerate() - { - word_map.insert(word, i); - } - let mut new_dict = vec![]; - if dict_word_wise { - let mut new_dict_weights 
= vec![]; - for (word, weight) in dict.into_iter().zip(dict_weights) { - let word_size = word.chars().count(); - match word_map.get(&word) { - Some(&idx) if char_window_size >= word_size => { - let start = char_window_size - word_size; - let end = start + word_size; - char_ngram_model.data[idx].weights[start] += weight.right; - for i in start + 1..end { - char_ngram_model.data[idx].weights[i] += weight.inner; - } - char_ngram_model.data[idx].weights[end] += weight.left; - } - _ => { - new_dict.push(word); - new_dict_weights.push(weight); - } - } - } - (new_dict, new_dict_weights) - } else { - for word in dict { - let word_size = word.chars().count(); - match word_map.get(&word) { - Some(&idx) if char_window_size >= word_size => { - let start = char_window_size - word_size; - let end = start + word_size; - let word_size_idx = std::cmp::min(word_size, dict_weights.len()) - 1; - let weight = &dict_weights[word_size_idx]; - char_ngram_model.data[idx].weights[start] += weight.right; - for i in start + 1..end { - char_ngram_model.data[idx].weights[i] += weight.inner; - } - char_ngram_model.data[idx].weights[end] += weight.left; - } - _ => new_dict.push(word), - } - } - (new_dict, dict_weights) - } - } - fn predict_impl(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) { ys.fill(self.bias); self.char_scorer.add_scores(sentence, padding, ys); @@ -253,7 +183,10 @@ impl Predictor { mod tests { use super::*; - use crate::ngram_model::NgramData; + use crate::dict_model::{ + DictModel, DictModelLengthwise, DictModelWordwise, DictWeight, WordwiseDictData, + }; + use crate::ngram_model::{NgramData, NgramModel}; /// Input: 我 ら は 全 世 界 の 国 民 /// bias: -200 .. .. .. .. .. .. .. @@ -323,21 +256,22 @@ mod tests { weights: vec![37, 38, 39], }, ]), - dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], - dict_weights: vec![ - DictWeight { - right: 40, - inner: 41, - left: 42, - }, - DictWeight { - right: 43, - inner: 44, - left: 45, - }, - ], + dict_model: DictModel::Lengthwise(DictModelLengthwise { + words: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], + weights: vec![ + DictWeight { + right: 40, + inner: 41, + left: 42, + }, + DictWeight { + right: 43, + inner: 44, + left: 45, + }, + ], + }), quantize_multiplier: 0.5, - dict_word_wise: false, bias: -200, char_window_size: 3, type_window_size: 2, @@ -412,26 +346,27 @@ mod tests { weights: vec![33, 34, 35, 36, 37], }, ]), - dict: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], - dict_weights: vec![ - DictWeight { - right: 38, - inner: 39, - left: 40, - }, - DictWeight { - right: 41, - inner: 42, - left: 43, - }, - DictWeight { - right: 44, - inner: 45, - left: 46, - }, - ], + dict_model: DictModel::Lengthwise(DictModelLengthwise { + words: vec!["全世界".to_string(), "世界".to_string(), "世".to_string()], + weights: vec![ + DictWeight { + right: 38, + inner: 39, + left: 40, + }, + DictWeight { + right: 41, + inner: 42, + left: 43, + }, + DictWeight { + right: 44, + inner: 45, + left: 46, + }, + ], + }), quantize_multiplier: 0.25, - dict_word_wise: false, bias: -285, char_window_size: 2, type_window_size: 3, @@ -506,26 +441,35 @@ mod tests { weights: vec![33, 34, 35, 36, 37], }, ]), - dict: vec!["国民".to_string(), "世界".to_string(), "世".to_string()], - dict_weights: vec![ - DictWeight { - right: 38, - inner: 39, - left: 40, - }, - DictWeight { - right: 41, - inner: 42, - left: 43, - }, - DictWeight { - right: 44, - inner: 45, - left: 46, - }, - ], + dict_model: DictModel::Wordwise(DictModelWordwise { + data: vec![ + 
WordwiseDictData { + word: "国民".to_string(), + weights: DictWeight { + right: 38, + inner: 39, + left: 40, + }, + }, + WordwiseDictData { + word: "世界".to_string(), + weights: DictWeight { + right: 41, + inner: 42, + left: 43, + }, + }, + WordwiseDictData { + word: "世".to_string(), + weights: DictWeight { + right: 44, + inner: 45, + left: 46, + }, + }, + ], + }), quantize_multiplier: 0.25, - dict_word_wise: true, bias: -285, char_window_size: 2, type_window_size: 3, From 35a6decaa04d3aaf6c89c754f1ffc3aef2360ec3 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Wed, 1 Dec 2021 20:03:42 +0900 Subject: [PATCH 17/60] Add errors module and remove anyhow from deps (#8) * Add errors module and remove anyhow from deps * Add a missing file * impl Error * Fix apis * Fix doc of vaporetto_rules --- evaluate/src/main.rs | 2 +- predict/src/main.rs | 2 +- vaporetto/Cargo.toml | 1 - vaporetto/src/dict_scorer.rs | 20 +++--- vaporetto/src/errors.rs | 115 +++++++++++++++++++++++++++++++++++ vaporetto/src/feature.rs | 9 ++- vaporetto/src/kytea_model.rs | 12 ++-- vaporetto/src/lib.rs | 4 +- vaporetto/src/model.rs | 10 ++- vaporetto/src/predictor.rs | 21 ++++--- vaporetto/src/sentence.rs | 78 ++++++++++++++++-------- vaporetto/src/trainer.rs | 9 +-- vaporetto_rules/src/lib.rs | 2 +- vaporetto_wasm/src/lib.rs | 2 +- 14 files changed, 221 insertions(+), 66 deletions(-) create mode 100644 vaporetto/src/errors.rs diff --git a/evaluate/src/main.rs b/evaluate/src/main.rs index 8dfeb791..3e96ecf0 100644 --- a/evaluate/src/main.rs +++ b/evaluate/src/main.rs @@ -72,7 +72,7 @@ fn main() -> Result<(), Box> { eprintln!("Loading model file..."); let mut f = zstd::Decoder::new(File::open(opt.model)?)?; let model = Model::read(&mut f)?; - let predictor = Predictor::new(model); + let predictor = Predictor::new(model)?; eprintln!("Start tokenization"); let mut n_true_positive = 0; diff --git a/predict/src/main.rs b/predict/src/main.rs index e6210f11..d8201bce 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -70,7 +70,7 @@ fn main() -> Result<(), Box> { eprintln!("Loading model file..."); let mut f = zstd::Decoder::new(File::open(opt.model)?)?; let model = Model::read(&mut f)?; - let predictor = Predictor::new(model); + let predictor = Predictor::new(model)?; eprintln!("Start tokenization"); let mut n_boundaries = 0; diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index 44161ed3..b25d7528 100644 --- a/vaporetto/Cargo.toml +++ b/vaporetto/Cargo.toml @@ -13,7 +13,6 @@ categories = ["text-processing"] autotests = false [dependencies] -anyhow = "1.0" # MIT or Apache-2.0 bincode = "1.3.3" # MIT daachorse = "0.2.0" # MIT or Apache-2.0 serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0 diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs index b489f76d..8d3ab03d 100644 --- a/vaporetto/src/dict_scorer.rs +++ b/vaporetto/src/dict_scorer.rs @@ -1,6 +1,7 @@ use daachorse::DoubleArrayAhoCorasick; use crate::dict_model::{DictModel, DictModelLengthwise, DictModelWordwise, DictWeight}; +use crate::errors::{Result, VaporettoError}; use crate::sentence::Sentence; pub enum DictScorer { @@ -9,11 +10,11 @@ pub enum DictScorer { } impl DictScorer { - pub fn new(model: DictModel) -> Self { - match model { + pub fn new(model: DictModel) -> Result { + Ok(match model { DictModel::Wordwise(model) => Self::Wordwise(DictScorerWordwise::new(model)), - DictModel::Lengthwise(model) => Self::Lengthwise(DictScorerLengthwise::new(model)), - } + DictModel::Lengthwise(model) => 
Self::Lengthwise(DictScorerLengthwise::new(model)?), + }) } pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { @@ -68,11 +69,16 @@ pub struct DictScorerLengthwise { } impl DictScorerLengthwise { - pub fn new(model: DictModelLengthwise) -> Self { - Self { + pub fn new(model: DictModelLengthwise) -> Result { + if model.weights.is_empty() { + return Err(VaporettoError::invalid_model( + "dict_word_max_size must be >= 1", + )); + } + Ok(Self { pma: DoubleArrayAhoCorasick::new(model.words).unwrap(), weights: model.weights, - } + }) } pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) { diff --git a/vaporetto/src/errors.rs b/vaporetto/src/errors.rs new file mode 100644 index 00000000..863da6cf --- /dev/null +++ b/vaporetto/src/errors.rs @@ -0,0 +1,115 @@ +//! Definition of errors. + +use std::error::Error; +use std::fmt; + +#[derive(Debug)] +pub enum VaporettoError { + InvalidModel(InvalidModelError), + InvalidSentence(InvalidSentenceError), + InvalidArgument(InvalidArgumentError), + IOError(std::io::Error), + UTF8Error(std::string::FromUtf8Error), +} + +impl VaporettoError { + pub(crate) fn invalid_model(msg: S) -> Self + where + S: Into, + { + Self::InvalidModel(InvalidModelError { msg: msg.into() }) + } + + pub(crate) fn invalid_sentence(msg: S) -> Self + where + S: Into, + { + Self::InvalidSentence(InvalidSentenceError { msg: msg.into() }) + } + + pub(crate) fn invalid_argument(arg: &'static str, msg: S) -> Self + where + S: Into, + { + Self::InvalidArgument(InvalidArgumentError { + arg, + msg: msg.into(), + }) + } +} + +impl fmt::Display for VaporettoError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::InvalidModel(e) => e.fmt(f), + Self::InvalidSentence(e) => e.fmt(f), + Self::InvalidArgument(e) => e.fmt(f), + Self::IOError(e) => e.fmt(f), + Self::UTF8Error(e) => e.fmt(f), + } + } +} + +impl Error for VaporettoError {} + +pub type Result = std::result::Result; + +/// Error used when the model is invalid. +#[derive(Debug)] +pub struct InvalidModelError { + /// Error message. + pub(crate) msg: String, +} + +impl fmt::Display for InvalidModelError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "InvalidModelError: {}", self.msg) + } +} + +impl Error for InvalidModelError {} + +/// Error used when the sentence is invalid. +#[derive(Debug)] +pub struct InvalidSentenceError { + /// Error message. + pub(crate) msg: String, +} + +impl fmt::Display for InvalidSentenceError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "InvalidSentenceError: {}", self.msg) + } +} + +impl Error for InvalidSentenceError {} + +/// Error used when the argument is invalid. +#[derive(Debug)] +pub struct InvalidArgumentError { + /// Name of the argument. + pub(crate) arg: &'static str, + + /// Error message. 
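With the errors module in place, fallible construction and parsing compose with the ? operator. A minimal usage sketch built only from APIs visible in this patch; both calls already return the crate's Result, so ? simply propagates InvalidModelError or InvalidSentenceError:

    use vaporetto::errors::Result;
    use vaporetto::{Model, Predictor, Sentence};

    fn segment(model: Model, text: &str) -> Result<Sentence> {
        let predictor = Predictor::new(model)?; // fails on an invalid model
        let s = Sentence::from_raw(text)?;      // fails on invalid input text
        Ok(predictor.predict(s))
    }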
+ pub(crate) msg: String, +} + +impl fmt::Display for InvalidArgumentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "InvalidArgumentError: {}: {}", self.arg, self.msg) + } +} + +impl Error for InvalidArgumentError {} + +impl From for VaporettoError { + fn from(error: std::io::Error) -> Self { + Self::IOError(error) + } +} + +impl From for VaporettoError { + fn from(error: std::string::FromUtf8Error) -> Self { + Self::UTF8Error(error) + } +} diff --git a/vaporetto/src/feature.rs b/vaporetto/src/feature.rs index 31a12f9a..8d80ad36 100644 --- a/vaporetto/src/feature.rs +++ b/vaporetto/src/feature.rs @@ -1,6 +1,6 @@ +use crate::errors::{Result, VaporettoError}; use crate::sentence::{BoundaryType, Sentence}; -use anyhow::{anyhow, Result}; use daachorse::DoubleArrayAhoCorasick; #[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)] @@ -55,7 +55,10 @@ impl FeatureExtractor { dict_word_max_size, ); if size == 0 { - return Err(anyhow!("`dictionary` contains an empty string")); + return Err(VaporettoError::invalid_argument( + "dictionary", + "contains an empty string", + )); } dict_word_size.push(size); } @@ -224,7 +227,7 @@ mod tests { assert!(fe.is_err()); assert_eq!( - "`dictionary` contains an empty string", + "InvalidArgumentError: dictionary: contains an empty string", &fe.err().unwrap().to_string() ); } diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 12696ef2..a6de598b 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -1,10 +1,10 @@ use std::convert::TryFrom; use std::io::BufRead; -use anyhow::{anyhow, Result}; use byteorder::{LittleEndian, ReadBytesExt}; use crate::dict_model::{DictModel, DictModelWordwise, DictWeight, WordwiseDictData}; +use crate::errors::{Result, VaporettoError}; use crate::model::Model; use crate::ngram_model::{NgramData, NgramModel}; @@ -392,24 +392,24 @@ impl KyteaModel { } impl TryFrom for Model { - type Error = anyhow::Error; + type Error = VaporettoError; fn try_from(model: KyteaModel) -> Result { let config = &model.config; let wordseg_model = model .wordseg_model - .ok_or_else(|| anyhow!("no word segmentation model."))?; + .ok_or_else(|| VaporettoError::invalid_model("no word segmentation model."))?; let quantize_multiplier = wordseg_model.multiplier; let feature_lookup = wordseg_model .feature_lookup - .ok_or_else(|| anyhow!("no lookup data."))?; + .ok_or_else(|| VaporettoError::invalid_model("no lookup data."))?; let bias = feature_lookup.biases[0] as i32; let char_dict = feature_lookup .char_dict - .ok_or_else(|| anyhow!("no character dictionary."))?; + .ok_or_else(|| VaporettoError::invalid_model("no character dictionary."))?; let type_dict = feature_lookup .type_dict - .ok_or_else(|| anyhow!("no type dictionary."))?; + .ok_or_else(|| VaporettoError::invalid_model("no type dictionary."))?; let mut char_ngrams = vec![]; for (char_ngram, v) in char_dict.dump_items() { diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index 15fecd47..1705d53c 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -15,7 +15,7 @@ //! //! let mut f = BufReader::new(File::open("model.bin").unwrap()); //! let model = Model::read(&mut f).unwrap(); -//! let predictor = Predictor::new(model); +//! let predictor = Predictor::new(model).unwrap(); //! //! let s = Sentence::from_raw("火星猫の生態").unwrap(); //! 
//! let s = predictor.predict(s);
@@ -37,6 +37,8 @@ mod predictor;
 mod sentence;
 mod type_scorer;
 
+pub mod errors;
+
 #[cfg(feature = "train")]
 mod feature;
 #[cfg(feature = "train")]
diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs
index 085e8cbd..cbbfa18b 100644
--- a/vaporetto/src/model.rs
+++ b/vaporetto/src/model.rs
@@ -1,6 +1,5 @@
 use std::io::{Read, Write};
 
-use anyhow::Result;
 use serde::{Deserialize, Serialize};
 
 use crate::dict_model::DictModel;
@@ -50,12 +49,11 @@ impl Model {
     /// # Errors
     ///
     /// When `wtr` generates an error, it will be returned as is.
-    pub fn write<W>(&self, wtr: &mut W) -> Result<()>
+    pub fn write<W>(&self, wtr: &mut W) -> Result<(), bincode::Error>
     where
         W: Write,
     {
-        bincode::serialize_into(wtr, self)?;
-        Ok(())
+        bincode::serialize_into(wtr, self)
     }
 
     /// Creates a model from a reader.
@@ -71,11 +69,11 @@ impl Model {
     /// # Errors
     ///
     /// When `rdr` generates an error, it will be returned as is.
-    pub fn read<R>(rdr: &mut R) -> Result<Self>
+    pub fn read<R>(rdr: &mut R) -> Result<Self, bincode::Error>
     where
         R: Read,
     {
-        Ok(bincode::deserialize_from(rdr)?)
+        bincode::deserialize_from(rdr)
     }
 
     #[cfg(feature = "train")]
diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index 24e670cb..66f953a3 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -1,5 +1,6 @@
 use crate::char_scorer::CharScorer;
 use crate::dict_scorer::DictScorer;
+use crate::errors::Result;
 use crate::model::Model;
 use crate::sentence::{BoundaryType, Sentence};
 use crate::type_scorer::TypeScorer;
@@ -31,7 +32,7 @@ impl Predictor {
     /// # Returns
     ///
     /// A new predictor.
-    pub fn new(model: Model) -> Self {
+    pub fn new(model: Model) -> Result<Self> {
         let bias = model.bias;
 
         let mut char_ngram_model = model.char_ngram_model;
@@ -45,10 +46,10 @@ impl Predictor {
         let dict_scorer = if dict_model.is_empty() {
             None
         } else {
-            Some(DictScorer::new(dict_model))
+            Some(DictScorer::new(dict_model)?)
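            // [Editorial sketch, not part of the original patch: mapping an
            // empty dictionary to `None` above avoids building an automaton
            // over zero patterns, so the `?` on `DictScorer::new` only fires
            // for genuinely malformed models. With this change a caller
            // writes, e.g.,
            //
            //     let predictor = Predictor::new(model)?;
            //
            // instead of assuming that construction cannot fail.]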
}; - Self { + Ok(Self { bias, char_scorer, @@ -59,7 +60,7 @@ impl Predictor { #[cfg(feature = "simd")] padding: model.char_window_size.max(model.type_window_size), - } + }) } fn predict_impl(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) { @@ -479,7 +480,7 @@ mod tests { #[test] fn test_predict_1() { let model = generate_model_1(); - let p = Predictor::new(model); + let p = Predictor::new(model).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict(s); assert_eq!( @@ -500,7 +501,7 @@ mod tests { #[test] fn test_predict_2() { let model = generate_model_2(); - let p = Predictor::new(model); + let p = Predictor::new(model).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict(s); assert_eq!( @@ -521,7 +522,7 @@ mod tests { #[test] fn test_predict_3() { let model = generate_model_3(); - let p = Predictor::new(model); + let p = Predictor::new(model).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict(s); assert_eq!( @@ -542,7 +543,7 @@ mod tests { #[test] fn test_predict_with_score_1() { let model = generate_model_1(); - let p = Predictor::new(model); + let p = Predictor::new(model).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict_with_score(s); assert_eq!( @@ -567,7 +568,7 @@ mod tests { #[test] fn test_predict_with_score_2() { let model = generate_model_2(); - let p = Predictor::new(model); + let p = Predictor::new(model).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict_with_score(s); assert_eq!( @@ -592,7 +593,7 @@ mod tests { #[test] fn test_predict_with_score_3() { let model = generate_model_3(); - let p = Predictor::new(model); + let p = Predictor::new(model).unwrap(); let s = Sentence::from_raw("我らは全世界の国民").unwrap(); let s = p.predict_with_score(s); assert_eq!( diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index 9f7804b4..fb32ca4d 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use crate::errors::{Result, VaporettoError}; /// Character type. 
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
@@ -140,7 +140,7 @@ impl Sentence {
         let text = text.into();
 
         if text.is_empty() {
-            return Err(anyhow!("`text` is empty"));
+            return Err(VaporettoError::invalid_argument("text", "is empty"));
         }
 
         let chars: Vec<char> = text.chars().collect();
@@ -212,7 +212,10 @@ impl Sentence {
         let tokenized_text = tokenized_text.as_ref();
 
         if tokenized_text.is_empty() {
-            return Err(anyhow!("`tokenized_text` is empty"));
+            return Err(VaporettoError::invalid_argument(
+                "tokenized_text",
+                "is empty",
+            ));
         }
 
         let tokenized_chars: Vec<char> = tokenized_text.chars().collect();
@@ -228,9 +231,15 @@ impl Sentence {
                 }
                 (false, ' ') => {
                     if chars.is_empty() {
-                        return Err(anyhow!("`tokenized_text` starts with a whitespace"));
+                        return Err(VaporettoError::invalid_argument(
+                            "tokenized_text",
+                            "starts with a whitespace",
+                        ));
                     } else if prev_boundary {
-                        return Err(anyhow!("`tokenized_text` contains consecutive whitespaces"));
+                        return Err(VaporettoError::invalid_argument(
+                            "tokenized_text",
+                            "contains consecutive whitespaces",
+                        ));
                     }
                     prev_boundary = true;
                 }
@@ -249,7 +258,10 @@
             };
         }
         if prev_boundary {
-            return Err(anyhow!("`tokenized_text` ends with a whitespace"));
+            return Err(VaporettoError::invalid_argument(
+                "tokenized_text",
+                "ends with a whitespace",
+            ));
         }
 
         let (char_to_str_pos, str_to_char_pos, char_type) = Self::common_info(&chars);
@@ -296,7 +308,9 @@
                 }
                 BoundaryType::NotWordBoundary => (),
                 BoundaryType::Unknown => {
-                    return Err(anyhow!("sentence contains an unknown boundary"));
+                    return Err(VaporettoError::invalid_sentence(
+                        "contains an unknown boundary",
+                    ));
                 }
             }
             match c {
@@ -344,7 +358,9 @@
                 }
                 BoundaryType::NotWordBoundary => (),
                 BoundaryType::Unknown => {
-                    return Err(anyhow!("sentence contains an unknown boundary"));
+                    return Err(VaporettoError::invalid_sentence(
+                        "contains an unknown boundary",
+                    ));
                 }
             }
         }
@@ -389,14 +405,14 @@
         let labeled_text = labeled_text.as_ref();
 
         if labeled_text.is_empty() {
-            return Err(anyhow!("`labeled_text` is empty"));
+            return Err(VaporettoError::invalid_argument("labeled_text", "is empty"));
         }
 
         let labeled_chars: Vec<char> = labeled_text.chars().collect();
         if labeled_chars.len() & 0x01 == 0 {
-            return Err(anyhow!(
-                "invalid length for `labeled_text`: {}",
-                labeled_chars.len()
+            return Err(VaporettoError::invalid_argument(
+                "labeled_text",
+                format!("invalid length: {}", labeled_chars.len()),
             ));
         }
         let mut chars = Vec::with_capacity(labeled_chars.len() / 2 + 1);
@@ -407,7 +423,12 @@
                 ' ' => BoundaryType::Unknown,
                 '|' => BoundaryType::WordBoundary,
                 '-' => BoundaryType::NotWordBoundary,
-                _ => return Err(anyhow!("invalid boundary character: '{}'", c)),
+                _ => {
+                    return Err(VaporettoError::invalid_argument(
+                        "labeled_text",
+                        format!("contains invalid boundary character: '{}'", c),
+                    ))
+                }
             });
         }
         for c in labeled_chars.into_iter().step_by(2) {
@@ -527,7 +548,7 @@
         } else {
             match self.str_to_char_pos.get(index) {
                 Some(index) if *index != 0 => Ok(*index),
-                _ => Err(anyhow!("invalid index")),
+                _ => Err(VaporettoError::invalid_argument("index", "invalid index")),
             }
         }
     }
@@ -556,7 +577,10 @@ mod tests {
 
         let s = Sentence::from_raw("");
 
         assert!(s.is_err());
-        assert_eq!("`text` is empty", &s.err().unwrap().to_string());
+        assert_eq!(
+            "InvalidArgumentError: text: is empty",
+            &s.err().unwrap().to_string()
+        );
     }
 
     #[test]
@@ -612,7 +636,10 @@
         let s = Sentence::from_tokenized("");
 
         assert!(s.is_err());
-        assert_eq!("`tokenized_text` is empty", &s.err().unwrap().to_string());
+        assert_eq!(
+            "InvalidArgumentError: tokenized_text: is empty",
+            &s.err().unwrap().to_string()
+        );
     }
 
     #[test]
@@ -621,7 +648,7 @@
         assert!(s.is_err());
 
         assert_eq!(
-            "`tokenized_text` starts with a whitespace",
+            "InvalidArgumentError: tokenized_text: starts with a whitespace",
             &s.err().unwrap().to_string()
         );
     }
@@ -632,7 +659,7 @@
         assert!(s.is_err());
 
         assert_eq!(
-            "`tokenized_text` ends with a whitespace",
+            "InvalidArgumentError: tokenized_text: ends with a whitespace",
             &s.err().unwrap().to_string()
         );
     }
@@ -643,7 +670,7 @@
         assert!(s.is_err());
 
         assert_eq!(
-            "`tokenized_text` contains consecutive whitespaces",
+            "InvalidArgumentError: tokenized_text: contains consecutive whitespaces",
            &s.err().unwrap().to_string()
        );
    }
@@ -778,7 +805,7 @@
         assert!(result.is_err());
 
         assert_eq!(
-            "sentence contains an unknown boundary",
+            "InvalidSentenceError: contains an unknown boundary",
             result.err().unwrap().to_string()
         );
     }
@@ -810,7 +837,7 @@
         assert!(result.is_err());
 
         assert_eq!(
-            "sentence contains an unknown boundary",
+            "InvalidSentenceError: contains an unknown boundary",
             result.err().unwrap().to_string()
         );
     }
@@ -830,7 +857,10 @@
         let s = Sentence::from_partial_annotation("");
 
         assert!(s.is_err());
-        assert_eq!("`labeled_text` is empty", &s.err().unwrap().to_string());
+        assert_eq!(
+            "InvalidArgumentError: labeled_text: is empty",
+            &s.err().unwrap().to_string()
+        );
     }
 
     #[test]
@@ -839,7 +869,7 @@
         assert!(s.is_err());
 
         assert_eq!(
-            "invalid length for `labeled_text`: 12",
+            "InvalidArgumentError: labeled_text: invalid length: 12",
             &s.err().unwrap().to_string()
         );
     }
@@ -850,7 +880,7 @@
         assert!(s.is_err());
 
         assert_eq!(
-            "invalid boundary character: '?'",
+            "InvalidArgumentError: labeled_text: contains invalid boundary character: '?'",
             &s.err().unwrap().to_string()
         );
     }
diff --git a/vaporetto/src/trainer.rs b/vaporetto/src/trainer.rs
index 50a5f6a9..c67d1cf2 100644
--- a/vaporetto/src/trainer.rs
+++ b/vaporetto/src/trainer.rs
@@ -1,8 +1,7 @@
 use std::collections::BTreeMap;
 use std::str::FromStr;
 
-use anyhow::{anyhow, Result};
-
+use crate::errors::{Result, VaporettoError};
 use crate::feature::{ExampleGenerator, FeatureExtractor};
 use crate::model::Model;
 use crate::sentence::Sentence;
@@ -237,14 +236,16 @@ impl Trainer {
         let mut builder = liblinear::Builder::new();
         let training_input =
             liblinear::util::TrainingInput::from_sparse_features(dataset.ys, dataset.xs)
-                .map_err(|e| anyhow!("liblinear error: {:?}", e))?;
+                .map_err(|e| VaporettoError::invalid_model(format!("liblinear error: {:?}", e)))?;
         builder.problem().input_data(training_input).bias(self.bias);
         builder
             .parameters()
             .solver_type(solver.into())
             .stopping_criterion(self.epsilon)
             .constraints_violation_cost(self.cost);
-        let model = builder.build_model().map_err(|e| anyhow!(e.to_string()))?;
+        let model = builder
+            .build_model()
+            .map_err(|e| VaporettoError::invalid_model(e.to_string()))?;
 
         Ok(Model::from_liblinear_model(
             model,
diff --git a/vaporetto_rules/src/lib.rs b/vaporetto_rules/src/lib.rs
index 8864c586..9ac6969d 100644
--- a/vaporetto_rules/src/lib.rs
+++ b/vaporetto_rules/src/lib.rs
@@ -17,7 +17,7 @@
 //!
 //! let mut f = BufReader::new(File::open("model.bin").unwrap());
 //! let model = Model::read(&mut f).unwrap();
-//! let mut predictor = Predictor::new(model);
+//! let mut predictor = Predictor::new(model).unwrap();
 //!
 //! let pre_filters: Vec<Box<dyn StringFilter>> = vec![
//!     Box::new(KyteaFullwidthFilter::new()),
diff --git a/vaporetto_wasm/src/lib.rs b/vaporetto_wasm/src/lib.rs
index 3d75ef72..31ec48aa 100644
--- a/vaporetto_wasm/src/lib.rs
+++ b/vaporetto_wasm/src/lib.rs
@@ -28,7 +28,7 @@ impl Vaporetto {
         let mut buff = vec![];
         decoder.read_to_end(&mut buff).unwrap();
         let model = Model::read(&mut buff.as_slice()).unwrap();
-        let predictor = Predictor::new(model);
+        let predictor = Predictor::new(model).unwrap();
         let post_filters: Vec<_> = filters
             .chars()
             .map(|c| {

From 0e05d795504188dbc5031f913b8a29fedad60e9a Mon Sep 17 00:00:00 2001
From: Koichi Akabe
Date: Thu, 2 Dec 2021 12:28:11 +0900
Subject: [PATCH 18/60] Return Error when patterns or weights are invalid (#9)

* Return Error when daachorse returns error in initialization
* Validate size of weight vector
* Refactoring
* Apply suggestions from code review

Co-authored-by: Shunsuke Kanda
Co-authored-by: Shunsuke Kanda

---
 vaporetto/src/char_scorer.rs | 64 ++++++++++++++++++++++--------------
 vaporetto/src/dict_scorer.rs | 15 +++++----
 vaporetto/src/predictor.rs   |  4 +--
 vaporetto/src/type_scorer.rs | 51 +++++++++++++++++++---------
 4 files changed, 86 insertions(+), 48 deletions(-)

diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs
index a69e7f41..6fb0ae52 100644
--- a/vaporetto/src/char_scorer.rs
+++ b/vaporetto/src/char_scorer.rs
@@ -1,5 +1,6 @@
 use daachorse::DoubleArrayAhoCorasick;
 
+use crate::errors::{Result, VaporettoError};
 use crate::ngram_model::NgramModel;
 use crate::sentence::Sentence;
 
@@ -14,18 +15,18 @@ pub enum CharScorer {
 }
 
 impl CharScorer {
-    pub fn new(model: NgramModel<String>, window_size: usize) -> Self {
+    pub fn new(model: NgramModel<String>, window_size: usize) -> Result<Self> {
         #[cfg(not(feature = "simd"))]
         {
-            Self::Naive(CharScorerNaive::new(model, window_size))
+            Ok(Self::Naive(CharScorerNaive::new(model, window_size)?))
         }
 
         #[cfg(feature = "simd")]
-        if window_size <= 4 {
-            Self::Simd(CharScorerSimd::new(model, window_size))
+        Ok(if window_size <= 4 {
+            Self::Simd(CharScorerSimd::new(model, window_size)?)
         } else {
-            Self::Naive(CharScorerNaive::new(model, window_size))
+            Self::Naive(CharScorerNaive::new(model, window_size)?)
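            // [Editorial aside, not part of the original patch: the
            // `window_size <= 4` cutoff matches the SIMD path, which packs a
            // merged weight vector into a single `i32x8`; such a vector can
            // cover up to `2 * window_size` boundary positions, and
            // 2 * 4 = 8 is exactly the lane count.]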
+        })
     }
 
     pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) {
@@ -45,13 +46,24 @@ pub struct CharScorerNaive {
 }
 
 impl CharScorerNaive {
-    pub fn new(mut model: NgramModel<String>, window_size: usize) -> Self {
+    pub fn new(mut model: NgramModel<String>, window_size: usize) -> Result<Self> {
         model.merge_weights();
-        Self {
-            pma: DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap(),
-            weights: model.data.into_iter().map(|d| d.weights).collect(),
-            window_size,
+        let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram))
+            .map_err(|_| VaporettoError::invalid_model("invalid character n-grams"))?;
+        let mut weights = vec![];
+        for d in model.data {
+            if d.weights.len() <= 2 * window_size - d.ngram.chars().count() {
+                return Err(VaporettoError::invalid_model(
+                    "invalid size of weight vector",
+                ));
+            }
+            weights.push(d.weights);
         }
+        Ok(Self {
+            pma,
+            weights,
+            window_size,
+        })
     }
 
     pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) {
@@ -81,23 +93,27 @@ pub struct CharScorerSimd {
 
 #[cfg(feature = "simd")]
 impl CharScorerSimd {
-    pub fn new(mut model: NgramModel<String>, window_size: usize) -> Self {
+    pub fn new(mut model: NgramModel<String>, window_size: usize) -> Result<Self> {
         model.merge_weights();
-        let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap();
-        let weights = model
-            .data
-            .into_iter()
-            .map(|d| {
-                let mut s = [0i32; 8];
-                s[..d.weights.len()].copy_from_slice(&d.weights);
-                i32x8::from_array(s)
-            })
-            .collect();
-        Self {
+        let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram))
+            .map_err(|_| VaporettoError::invalid_model("invalid character n-grams"))?;
+        let mut weights = vec![];
+        for d in model.data {
+            let mut s = [0i32; 8];
+            if let Some(s) = s.get_mut(..d.weights.len()) {
+                s.copy_from_slice(&d.weights);
+            } else {
+                return Err(VaporettoError::invalid_model(
+                    "invalid size of weight vector",
+                ));
+            }
+            weights.push(i32x8::from_array(s));
+        }
+        Ok(Self {
             pma,
             weights,
             window_size,
-        }
+        })
     }
 
     pub fn add_scores(&self, sentence: &Sentence, padding: usize, ys: &mut [i32]) {
diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs
index 8d3ab03d..f6f6106c 100644
--- a/vaporetto/src/dict_scorer.rs
+++ b/vaporetto/src/dict_scorer.rs
@@ -12,7 +12,7 @@ impl DictScorer {
     pub fn new(model: DictModel) -> Result<Self> {
         Ok(match model {
-            DictModel::Wordwise(model) => Self::Wordwise(DictScorerWordwise::new(model)),
+            DictModel::Wordwise(model) => Self::Wordwise(DictScorerWordwise::new(model)?),
             DictModel::Lengthwise(model) => Self::Lengthwise(DictScorerLengthwise::new(model)?),
         })
     }
@@ -31,17 +31,16 @@ pub struct DictScorerWordwise {
 }
 
 impl DictScorerWordwise {
-    pub fn new(model: DictModelWordwise) -> Self {
+    pub fn new(model: DictModelWordwise) -> Result<Self> {
         let mut words = vec![];
         let mut weights = vec![];
         for pair in model.data {
             words.push(pair.word);
             weights.push(pair.weights);
         }
-        Self {
-            pma: DoubleArrayAhoCorasick::new(words).unwrap(),
-            weights,
-        }
+        let pma = DoubleArrayAhoCorasick::new(words)
+            .map_err(|_| VaporettoError::invalid_model("invalid dictionary"))?;
+        Ok(Self { pma, weights })
     }
 
     pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) {
@@ -75,8 +74,10 @@ impl DictScorerLengthwise {
                 "dict_word_max_size must be >= 1",
             ));
         }
+        let pma = DoubleArrayAhoCorasick::new(model.words)
+            .map_err(|_| VaporettoError::invalid_model("invalid dictionary"))?;
         Ok(Self {
-            pma: DoubleArrayAhoCorasick::new(model.words).unwrap(),
+            pma,
             weights: model.weights,
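            // [Editorial aside, not part of the original patch: in the
            // lengthwise model, `weights[k]` applies to dictionary matches of
            // k + 1 characters and the last entry covers all longer matches;
            // `add_scores` selects it with
            // `(m_end - m_start).min(self.weights.len()) - 1`, which is why an
            // empty weight vector had to be rejected above.]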
        })
     }
diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs
index 66f953a3..298ba9e7 100644
--- a/vaporetto/src/predictor.rs
+++ b/vaporetto/src/predictor.rs
@@ -41,8 +41,8 @@ impl Predictor {
 
         dict_model.merge_dict_weights(&mut char_ngram_model, model.char_window_size);
 
-        let char_scorer = CharScorer::new(char_ngram_model, model.char_window_size);
-        let type_scorer = TypeScorer::new(type_ngram_model, model.type_window_size);
+        let char_scorer = CharScorer::new(char_ngram_model, model.char_window_size)?;
+        let type_scorer = TypeScorer::new(type_ngram_model, model.type_window_size)?;
         let dict_scorer = if dict_model.is_empty() {
             None
         } else {
diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs
index 5bc9299d..78e0febe 100644
--- a/vaporetto/src/type_scorer.rs
+++ b/vaporetto/src/type_scorer.rs
@@ -1,5 +1,6 @@
 use daachorse::DoubleArrayAhoCorasick;
 
+use crate::errors::{Result, VaporettoError};
 use crate::ngram_model::NgramModel;
 use crate::sentence::Sentence;
 
@@ -9,12 +10,12 @@ pub enum TypeScorer {
 }
 
 impl TypeScorer {
-    pub fn new(model: NgramModel<Vec<u8>>, window_size: usize) -> Self {
-        if window_size <= 3 {
-            Self::Cache(TypeScorerCache::new(model, window_size))
+    pub fn new(model: NgramModel<Vec<u8>>, window_size: usize) -> Result<Self> {
+        Ok(if window_size <= 3 {
+            Self::Cache(TypeScorerCache::new(model, window_size)?)
         } else {
-            Self::Pma(TypeScorerPma::new(model, window_size))
+            Self::Pma(TypeScorerPma::new(model, window_size)?)
+        })
     }
 
     pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) {
@@ -32,13 +33,24 @@ pub struct TypeScorerPma {
 }
 
 impl TypeScorerPma {
-    pub fn new(mut model: NgramModel<Vec<u8>>, window_size: usize) -> Self {
+    pub fn new(mut model: NgramModel<Vec<u8>>, window_size: usize) -> Result<Self> {
         model.merge_weights();
-        Self {
-            pma: DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap(),
-            weights: model.data.into_iter().map(|d| d.weights).collect(),
-            window_size,
+        let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram))
+            .map_err(|_| VaporettoError::invalid_model("invalid character type n-grams"))?;
+        let mut weights = vec![];
+        for d in model.data {
+            if d.weights.len() <= 2 * window_size - d.ngram.len() {
+                return Err(VaporettoError::invalid_model(
+                    "invalid size of weight vector",
+                ));
+            }
+            weights.push(d.weights);
         }
+        Ok(Self {
+            pma,
+            weights,
+            window_size,
+        })
     }
 
     pub fn add_scores(&self, sentence: &Sentence, ys: &mut [i32]) {
@@ -68,10 +80,19 @@ pub struct TypeScorerCache {
 }
 
 impl TypeScorerCache {
-    pub fn new(mut model: NgramModel<Vec<u8>>, window_size: usize) -> Self {
+    pub fn new(mut model: NgramModel<Vec<u8>>, window_size: usize) -> Result<Self> {
         model.merge_weights();
-        let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram)).unwrap();
-        let weights: Vec<Vec<i32>> = model.data.into_iter().map(|d| d.weights).collect();
+        let pma = DoubleArrayAhoCorasick::new(model.data.iter().map(|d| &d.ngram))
+            .map_err(|_| VaporettoError::invalid_model("invalid character type n-grams"))?;
+        let mut weights = vec![];
+        for d in model.data {
+            if d.weights.len() <= 2 * window_size - d.ngram.len() {
+                return Err(VaporettoError::invalid_model(
+                    "invalid size of weight vector",
+                ));
+            }
+            weights.push(d.weights);
+        }
 
         let sequence_size = window_size * 2;
         let all_sequences = ALPHABET_SIZE.pow(sequence_size as u32);
@@ -90,11 +111,11 @@
             *score = y;
         }
 
-        Self {
+        Ok(Self {
             scores,
             window_size,
             sequence_mask: (1 << (ALPHABET_SHIFT * sequence_size)) - 1,
-        }
+        })
     }
 
     pub fn add_scores(&self, sentence: &Sentence, ys:
&mut [i32]) { From 6b342e79c58483b326cbd21445f3f72210efb750 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 2 Dec 2021 19:09:28 +0900 Subject: [PATCH 19/60] Get weights without boundary checking (#10) * Get weights without boundary checking * Fix --- vaporetto/src/char_scorer.rs | 8 ++++++-- vaporetto/src/dict_scorer.rs | 9 ++++++--- vaporetto/src/type_scorer.rs | 4 +++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 6fb0ae52..4cf82351 100644 --- a/vaporetto/src/char_scorer.rs +++ b/vaporetto/src/char_scorer.rs @@ -70,7 +70,9 @@ impl CharScorerNaive { for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { let m_end = sentence.str_to_char_pos[m.end()]; let offset = m_end as isize - self.window_size as isize - 1; - let weights = &self.weights[m.pattern()]; + // Both the weights and the PMA always have the same number of items. + // Therefore, the following code is safe. + let weights = unsafe { self.weights.get_unchecked(m.pattern()) }; if offset >= 0 { for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) { *y += w; @@ -120,7 +122,9 @@ impl CharScorerSimd { for m in self.pma.find_overlapping_no_suffix_iter(&sentence.text) { let m_end = sentence.str_to_char_pos[m.end()]; let offset = padding as isize + m_end as isize - self.window_size as isize - 1; - let weights = &self.weights[m.pattern()]; + // Both the weights and the PMA always have the same number of items. + // Therefore, the following code is safe. + let weights = unsafe { self.weights.get_unchecked(m.pattern()) }; let ys_slice = &mut ys[offset as usize..offset as usize + 8]; let mut target = i32x8::from_slice(ys_slice); target += weights; diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs index f6f6106c..dcc64502 100644 --- a/vaporetto/src/dict_scorer.rs +++ b/vaporetto/src/dict_scorer.rs @@ -47,8 +47,9 @@ impl DictScorerWordwise { for m in self.pma.find_overlapping_iter(&sentence.text) { let m_start = sentence.str_to_char_pos[m.start()]; let m_end = sentence.str_to_char_pos[m.end()]; - let idx = m.pattern(); - let dict_weight = self.weights[idx]; + // Both the weights and the PMA always have the same number of items. + // Therefore, the following code is safe. + let dict_weight = unsafe { self.weights.get_unchecked(m.pattern()) }; if m_start != 0 { ys[m_start - 1] += dict_weight.right; } @@ -87,7 +88,9 @@ impl DictScorerLengthwise { let m_start = sentence.str_to_char_pos[m.start()]; let m_end = sentence.str_to_char_pos[m.end()]; let idx = (m_end - m_start).min(self.weights.len()) - 1; - let dict_weight = self.weights[idx]; + // The upper bound of idx is weights.len() - 1. + // Therefore, the following code is safe. + let dict_weight = unsafe { self.weights.get_unchecked(idx) }; if m_start != 0 { ys[m_start - 1] += dict_weight.right; } diff --git a/vaporetto/src/type_scorer.rs b/vaporetto/src/type_scorer.rs index 78e0febe..21a1c365 100644 --- a/vaporetto/src/type_scorer.rs +++ b/vaporetto/src/type_scorer.rs @@ -59,7 +59,9 @@ impl TypeScorerPma { .find_overlapping_no_suffix_iter(&sentence.char_type) { let offset = m.end() as isize - self.window_size as isize - 1; - let weights = &self.weights[m.pattern()]; + // Both the weights and the PMA always have the same number of items. + // Therefore, the following code is safe. 
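            // [Editorial aside, not part of the original patch: the invariant
            // holds because `weights` is filled from the same `model.data`
            // that the PMA patterns are built from, so every pattern id the
            // automaton returns is a valid index. A bounds-checked sketch of
            // the same lookup would be
            //
            //     let weights = self.weights.get(m.pattern()).expect("pma and weights are in sync");
            //
            // which the patch avoids to save a branch in this hot loop.]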
+ let weights = unsafe { self.weights.get_unchecked(m.pattern()) }; if offset >= 0 { for (w, y) in weights.iter().zip(&mut ys[offset as usize..]) { *y += w; From 389d1bfb1512bf2715f3557da50b6cb60a2463e8 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 2 Dec 2021 20:04:53 +0900 Subject: [PATCH 20/60] Simplify speed comparison on README (#11) * Simplify speed comparison on README * Update README.md Co-authored-by: Shunsuke Kanda Co-authored-by: Shunsuke Kanda --- README.md | 54 +++++++++++++----------------------------------------- 1 file changed, 13 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 97a925d1..8c27a863 100644 --- a/README.md +++ b/README.md @@ -86,47 +86,19 @@ You can specify all arguments above multiple times. ## Speed Comparison of Various Tokenizers -You can find the comparison script at [here](https://github.com/legalforce-research/tokenizer-speed-bench). - -### Experimental Setup - -* Document: Japanese training data of Kyoto Free Translation Task -* Models: - * KyTea and Vaporetto: Compact LR model (jp-0.4.7-6) - * MeCab, Kuromoji, and Lindera: IPAdic - * Sudachi and Sudachi.rs: system_core.dic (v20210802) - -### Results - -* VM instance on Google Cloud Platform (c2-standard-16, Debian) - - | Tool Name (version) | Speed (×10^6 chars/s) | σ | - | -------------------------- | ---------------------:|-------| - | KyTea (2020-04-03) | 0.777 | 0.020 | - | Vaporetto (0.1.6) | **4.426** | 0.182 | - | | | | - | MeCab (2020-09-14) | 2.736 | 0.041 | - | | | | - | Kuromoji (Atilika's 0.9.0) | 0.423 | 0.013 | - | Lindera (0.8.0) | 1.002 | 0.014 | - | | | | - | Sudachi (0.5.2) | 0.251 | 0.012 | - | Sudachi.rs (0.6.0-rc1) | 0.644 | 0.012 | - -* MacBook Pro (2017, Processor: 2.3 GHz Intel Core i5, Memory: 8 GB 2133 MHz LPDDR3) - - | Tool Name (version) | Speed (×10^6 chars/s) | σ | - | -------------------------- | ---------------------:|-------| - | KyTea (2020-04-03) | 0.490 | 0.003 | - | Vaporetto (0.1.6) | **3.016** | 0.113 | - | | | | - | MeCab (2020-09-14) | 1.418 | 0.007 | - | | | | - | Kuromoji (Atilika's 0.9.0) | 1.197 | 0.034 | - | Lindera (0.8.0) | 0.542 | 0.010 | - | | | | - | Sudachi (0.5.2) | 0.439 | 0.026 | - | Sudachi.rs (0.6.0-rc1) | 0.427 | 0.009 | +Details can be found [here](https://github.com/legalforce-research/vaporetto/wiki/Speed-Comparison). + +| Tool Name (version) | Speed [M chars/s] | STD | +| --------------------------------- | -----------------:| -------------:| +| KyTea (2020-04-03) | 1.463 | 0.012 | +| Vaporetto (0.3.0) | **9.716** | 0.115 | +| Vaporetto (0.3.0, `feature=simd`) | **11.035** | 0.144 | +| | | | +| MeCab (2020-09-14) | 4.621 | 0.047 | +| Kuromoji (0.9.0) | 1.470 | 0.074 | +| Lindera (0.8.1) | 1.444 | 0.022 | +| Sudachi (0.5.3) | 0.322 | 0.029 | +| sudachi.rs (0.6.0) | 0.961 | 0.008 | ## Disclaimer From 28f23f1c26497fe4eefaf9c7046b8d55f65de736 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Fri, 3 Dec 2021 14:56:39 +0900 Subject: [PATCH 21/60] Add a figure in README (#12) * Add a figure * Fix * Fix * Update README.md --- README.md | 12 +- figures/comparison.ngp | 1518 ++++++++++++++++++++++++++++++++++++++++ figures/comparison.svg | 179 +++++ figures/comparison.txt | 9 + 4 files changed, 1707 insertions(+), 11 deletions(-) create mode 100644 figures/comparison.ngp create mode 100644 figures/comparison.svg create mode 100644 figures/comparison.txt diff --git a/README.md b/README.md index 8c27a863..da2f46ff 100644 --- a/README.md +++ b/README.md @@ -88,17 +88,7 @@ You can specify all arguments above multiple times. 
Details can be found [here](https://github.com/legalforce-research/vaporetto/wiki/Speed-Comparison). -| Tool Name (version) | Speed [M chars/s] | STD | -| --------------------------------- | -----------------:| -------------:| -| KyTea (2020-04-03) | 1.463 | 0.012 | -| Vaporetto (0.3.0) | **9.716** | 0.115 | -| Vaporetto (0.3.0, `feature=simd`) | **11.035** | 0.144 | -| | | | -| MeCab (2020-09-14) | 4.621 | 0.047 | -| Kuromoji (0.9.0) | 1.470 | 0.074 | -| Lindera (0.8.1) | 1.444 | 0.022 | -| Sudachi (0.5.3) | 0.322 | 0.029 | -| sudachi.rs (0.6.0) | 0.961 | 0.008 | +![](./figures/comparison.svg) ## Disclaimer diff --git a/figures/comparison.ngp b/figures/comparison.ngp new file mode 100644 index 00000000..1378083c --- /dev/null +++ b/figures/comparison.ngp @@ -0,0 +1,1518 @@ +#!ngraph +#%creator: Ngraph +#%version: 6.09.03 +new axis name:fX1 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=12 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=3200 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=left + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=1200 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fY1 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0.5 + axis::max=3.5 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=3200 + axis::direction=9000 + axis::baseline=true + axis::length=3000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=right + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + 
axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fU1 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=200 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:0' + axis::gauge=right + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=left + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fR1 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=17700 + axis::y=3200 + axis::direction=9000 + axis::baseline=true + axis::length=3000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:1' + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=left + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +axis::grouping 1 0 1 2 3 + +new axis name:fX2 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=12 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=4400 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + 
axis::wave_width=40 + axis::reference= + axis::gauge=left + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=1200 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fY2 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0.5 + axis::max=1.5 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=4400 + axis::direction=9000 + axis::baseline=true + axis::length=1000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=right + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fU2 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=3400 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:4' + axis::gauge=right + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=left + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + 
axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fR2 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=17700 + axis::y=4400 + axis::direction=9000 + axis::baseline=true + axis::length=1000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:5' + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=left + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +axis::grouping 1 4 5 6 7 + +new axis name:fX3 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=12 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=6600 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=left + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=1200 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fY3 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0.5 + axis::max=2.5 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=6600 + axis::direction=9000 + axis::baseline=true + axis::length=2000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + 
axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=right + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fU3 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=4600 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:8' + axis::gauge=right + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=left + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fR3 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=17700 + axis::y=6600 + axis::direction=9000 + axis::baseline=true + axis::length=2000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:9' + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + 
axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=left + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +axis::grouping 1 8 9 10 11 + +new axis name:fX4 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=12 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=8800 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=left + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=right + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=1200 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fY4 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0.5 + axis::max=2.5 + axis::inc=1 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=8800 + axis::direction=9000 + axis::baseline=true + axis::length=2000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference= + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=right + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fU4 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=6400 + axis::y=6800 + axis::direction=0 + axis::baseline=true + axis::length=11300 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + 
axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:12' + axis::gauge=right + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=left + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=center + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +new axis name:fR4 + axis::hidden=false + axis::R=0 + axis::G=0 + axis::B=0 + axis::A=255 + axis::clip=true + axis::redraw_flag=true + axis::min=0 + axis::max=0 + axis::inc=0 + axis::div=0 + axis::type=linear + axis::x=17700 + axis::y=8800 + axis::direction=9000 + axis::baseline=true + axis::length=2000 + axis::width=40 + axis::style= + axis::auto_scale_margin=500 + axis::adjust_axis= + axis::adjust_position=0 + axis::arrow=none + axis::arrow_length=72426 + axis::arrow_width=60000 + axis::wave=none + axis::wave_length=300 + axis::wave_width=40 + axis::reference='axis:13' + axis::gauge=none + axis::gauge_min=0 + axis::gauge_max=0 + axis::gauge_style= + axis::gauge_length1=100 + axis::gauge_width1=40 + axis::gauge_length2=200 + axis::gauge_width2=40 + axis::gauge_length3=300 + axis::gauge_width3=40 + axis::gauge_R=0 + axis::gauge_G=0 + axis::gauge_B=0 + axis::gauge_A=255 + axis::num=none + axis::num_begin=0 + axis::num_step=0 + axis::num_num=-1 + axis::num_auto_norm=5 + axis::num_head= + axis::num_format='%g' + axis::num_tail= + axis::num_log_pow=true + axis::num_pt=2000 + axis::num_space=0 + axis::num_font='Sans-serif' + axis::num_font_style=0 + axis::num_script_size=7000 + axis::num_align=left + axis::num_no_zero=regular + axis::num_direction=horizontal + axis::num_shift_p=0 + axis::num_shift_n=100 + axis::num_R=0 + axis::num_G=0 + axis::num_B=0 + axis::num_A=255 + axis::num_date_format= + axis::num_math= + +axis::grouping 1 12 13 14 15 + +new data + data::hidden=false + data::R=0 + data::G=0 + data::B=0 + data::A=255 + data::clip=true + data::redraw_flag=true + data::source=file + data::save_path=relative + data::x=2 + data::y=0 + data::type=bar_fill_x + data::interpolation=spline + data::fit= + data::math_x= + data::math_y='4-Y' + data::func_f= + data::func_g= + data::func_h= + data::smooth_x=0 + data::smooth_y=0 + data::averaging_type=simple + data::mark_type=0 + data::mark_size=200 + data::line_width=40 + data::line_style= + data::line_join=bevel + data::line_miter_limit=1000 + data::R2=0 + data::G2=0 + data::B2=0 + data::A2=255 + data::remark='#%'\''' + data::ifs=',' + data::csv=false + data::head_skip=1 + data::read_step=1 + data::final_line=4 + data::mask= + data::move_data= + data::move_data_x= + data::move_data_y= + data::axis_x='axis:0' + data::axis_y='axis:1' + data::data_clip=true + data::range_min=1 + data::range_max=10 + data::range_div=512 + data::array= + data::file='./comparison.txt' + +new data + data::hidden=false + data::R=0 
+ data::G=0 + data::B=0 + data::A=255 + data::clip=true + data::redraw_flag=true + data::source=file + data::save_path=relative + data::x=2 + data::y=0 + data::type=bar_fill_x + data::interpolation=spline + data::fit= + data::math_x= + data::math_y= + data::func_f= + data::func_g= + data::func_h= + data::smooth_x=0 + data::smooth_y=0 + data::averaging_type=simple + data::mark_type=0 + data::mark_size=200 + data::line_width=40 + data::line_style= + data::line_join=bevel + data::line_miter_limit=1000 + data::R2=0 + data::G2=0 + data::B2=0 + data::A2=255 + data::remark='#%'\''' + data::ifs=',' + data::csv=false + data::head_skip=4 + data::read_step=1 + data::final_line=5 + data::mask= + data::move_data= + data::move_data_x= + data::move_data_y= + data::axis_x='axis:4' + data::axis_y='axis:5' + data::data_clip=true + data::range_min=1 + data::range_max=10 + data::range_div=512 + data::array= + data::file='./comparison.txt' + +new data + data::hidden=false + data::R=0 + data::G=0 + data::B=0 + data::A=255 + data::clip=true + data::redraw_flag=true + data::source=file + data::save_path=relative + data::x=2 + data::y=0 + data::type=bar_fill_x + data::interpolation=spline + data::fit= + data::math_x= + data::math_y='3-Y' + data::func_f= + data::func_g= + data::func_h= + data::smooth_x=0 + data::smooth_y=0 + data::averaging_type=simple + data::mark_type=0 + data::mark_size=200 + data::line_width=40 + data::line_style= + data::line_join=bevel + data::line_miter_limit=1000 + data::R2=0 + data::G2=0 + data::B2=0 + data::A2=255 + data::remark='#%'\''' + data::ifs=',' + data::csv=false + data::head_skip=5 + data::read_step=1 + data::final_line=7 + data::mask= + data::move_data= + data::move_data_x= + data::move_data_y= + data::axis_x='axis:8' + data::axis_y='axis:9' + data::data_clip=true + data::range_min=1 + data::range_max=10 + data::range_div=512 + data::array= + data::file='./comparison.txt' + +new data + data::hidden=false + data::R=0 + data::G=0 + data::B=0 + data::A=255 + data::clip=true + data::redraw_flag=true + data::source=file + data::save_path=relative + data::x=2 + data::y=0 + data::type=bar_fill_x + data::interpolation=spline + data::fit= + data::math_x= + data::math_y='3-Y' + data::func_f= + data::func_g= + data::func_h= + data::smooth_x=0 + data::smooth_y=0 + data::averaging_type=simple + data::mark_type=0 + data::mark_size=200 + data::line_width=40 + data::line_style= + data::line_join=bevel + data::line_miter_limit=1000 + data::R2=0 + data::G2=0 + data::B2=0 + data::A2=255 + data::remark='#%'\''' + data::ifs=',' + data::csv=false + data::head_skip=7 + data::read_step=1 + data::final_line=9 + data::mask= + data::move_data= + data::move_data_x= + data::move_data_y= + data::axis_x='axis:12' + data::axis_y='axis:13' + data::data_clip=true + data::range_min=1 + data::range_max=10 + data::range_div=512 + data::array= + data::file='./comparison.txt' + +new text + text::hidden=false + text::R=0 + text::G=0 + text::B=0 + text::A=255 + text::clip=true + text::redraw_flag=true + text::text='KyTea (2020-04-03)' + text::x=200 + text::y=800 + text::pt=1200 + text::font='Sans-serif' + text::style=0 + text::space=0 + text::direction=0 + text::script_size=7000 + text::raw=false + +new text + text::hidden=false + text::R=0 + text::G=0 + text::B=0 + text::A=255 + text::clip=true + text::redraw_flag=true + text::text='Vaporetto (0.3.0)' + text::x=200 + text::y=1800 + text::pt=1200 + text::font='Sans-serif' + text::style=0 + text::space=0 + text::direction=0 + text::script_size=7000 + text::raw=false + 
+new text
+ text::hidden=false
+ text::R=0
+ text::G=0
+ text::B=0
+ text::A=255
+ text::clip=true
+ text::redraw_flag=true
+ text::text='Analysis Speed [×10^6@ chars/s]'
+ text::x=6400
+ text::y=10000
+ text::pt=1200
+ text::font='Sans-serif'
+ text::style=0
+ text::space=0
+ text::direction=0
+ text::script_size=7000
+ text::raw=false
+
+new text
+ text::hidden=false
+ text::R=0
+ text::G=0
+ text::B=0
+ text::A=255
+ text::clip=true
+ text::redraw_flag=true
+ text::text='MeCab (2020-09-14)'
+ text::x=200
+ text::y=4000
+ text::pt=1200
+ text::font='Sans-serif'
+ text::style=0
+ text::space=0
+ text::direction=0
+ text::script_size=7000
+ text::raw=false
+
+new text
+ text::hidden=false
+ text::R=0
+ text::G=0
+ text::B=0
+ text::A=255
+ text::clip=true
+ text::redraw_flag=true
+ text::text='Kuromoji (0.9.0)'
+ text::x=200
+ text::y=5200
+ text::pt=1200
+ text::font='Sans-serif'
+ text::style=0
+ text::space=0
+ text::direction=0
+ text::script_size=7000
+ text::raw=false
+
+new text
+ text::hidden=false
+ text::R=0
+ text::G=0
+ text::B=0
+ text::A=255
+ text::clip=true
+ text::redraw_flag=true
+ text::text='Lindera (0.8.1)'
+ text::x=200
+ text::y=6200
+ text::pt=1200
+ text::font='Sans-serif'
+ text::style=0
+ text::space=0
+ text::direction=0
+ text::script_size=7000
+ text::raw=false
+
+new text
+ text::hidden=false
+ text::R=0
+ text::G=0
+ text::B=0
+ text::A=255
+ text::clip=true
+ text::redraw_flag=true
+ text::text='Sudachi (0.5.3)'
+ text::x=200
+ text::y=7400
+ text::pt=1200
+ text::font='Sans-serif'
+ text::style=0
+ text::space=0
+ text::direction=0
+ text::script_size=7000
+ text::raw=false
+
+new text
+ text::hidden=false
+ text::R=0
+ text::G=0
+ text::B=0
+ text::A=255
+ text::clip=true
+ text::redraw_flag=true
+ text::text='sudachi.rs (0.6.0)'
+ text::x=200
+ text::y=8400
+ text::pt=1200
+ text::font='Sans-serif'
+ text::style=0
+ text::space=0
+ text::direction=0
+ text::script_size=7000
+ text::raw=false
+
+new text
+ text::hidden=false
+ text::R=0
+ text::G=0
+ text::B=0
+ text::A=255
+ text::clip=true
+ text::redraw_flag=true
+ text::text='Vaporetto (0.3.0, feature=simd)'
+ text::x=200
+ text::y=2800
+ text::pt=1200
+ text::font='Sans-serif'
+ text::style=0
+ text::space=0
+ text::direction=0
+ text::script_size=7000
+ text::raw=false
+
+new gra name:viewer
+ gra::left_margin=0
+ gra::top_margin=0
+ gra::zoom=10000
+ gra::paper_width=18000
+ gra::paper_height=10200
+ gra::decimalsign=period
+ gra::draw_obj='axisgrid axis data merge legend rectangle arc path mark text'
diff --git a/figures/comparison.svg b/figures/comparison.svg
new file mode 100644
index 00000000..5a75e90e
--- /dev/null
+++ b/figures/comparison.svg
@@ -0,0 +1,179 @@
[figure: bar chart of analysis speed (×10^6 chars/s) per tokenizer; the 179 lines of SVG markup were stripped during extraction. See figures/comparison.txt below for the underlying data.]
diff --git a/figures/comparison.txt b/figures/comparison.txt
new file mode 100644
index 00000000..6a25dd8d
--- /dev/null
+++ b/figures/comparison.txt
@@ -0,0 +1,9 @@
+Tool Name (version),Speed [M chars/s],STD
+KyTea (2020-04-03),1.463,0.012
+Vaporetto (0.3.0),9.716,0.115
+Vaporetto (0.3.0+feature=simd),11.035,0.144
+MeCab (2020-09-14),4.621,0.047
+Kuromoji (0.9.0),1.470,0.074
+Lindera (0.8.1),1.444,0.022
+Sudachi (0.5.3),0.322,0.029
+sudachi.rs (0.6.0),0.961,0.008
From
79e7b6e28290f833631a3b057789dfa9cd9a1432 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 7 Dec 2021 11:22:48 +0900 Subject: [PATCH 22/60] Dict manipulation (#13) * wip * Add manipulate_model command * wip * Update doc * Update README.md * Fix * fix * Fix tests * Update README.md * Update README.md * Update Cargo.toml * Update README.md * Update README.md * Update README.md * Update README.md --- Cargo.toml | 1 + README.md | 62 ++++++++++++++++++++++--- manipulate_model/Cargo.toml | 11 +++++ manipulate_model/src/main.rs | 85 ++++++++++++++++++++++++++++++++++ vaporetto/src/dict_model.rs | 89 ++++++++++++++++++++++++++++++++---- vaporetto/src/dict_scorer.rs | 6 +-- vaporetto/src/kytea_model.rs | 10 ++-- vaporetto/src/lib.rs | 1 + vaporetto/src/model.rs | 12 ++++- vaporetto/src/predictor.rs | 26 +++++------ 10 files changed, 264 insertions(+), 39 deletions(-) create mode 100644 manipulate_model/Cargo.toml create mode 100644 manipulate_model/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 09c8e826..3c5b1193 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "vaporetto", "vaporetto_rules", + "manipulate_model", "predict", "train", "evaluate", diff --git a/README.md b/README.md index da2f46ff..655a2b4a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # 🛥 VAporetto: POintwise pREdicTion based TOkenizer Vaporetto is a fast and lightweight pointwise prediction based tokenizer. +This repository includes both a Rust crate that provides APIs for Vaporetto and CLI frontends. [![Crates.io](https://img.shields.io/crates/v/vaporetto)](https://crates.io/crates/vaporetto) [![Documentation](https://docs.rs/vaporetto/badge.svg)](https://docs.rs/vaporetto) @@ -8,9 +9,7 @@ Vaporetto is a fast and lightweight pointwise prediction based tokenizer. [Technical details](https://tech.legalforce.co.jp/entry/2021/09/28/180844) (Japanese) -## Overview - -This repository includes both a Rust crate that provides APIs for Vaporetto and CLI frontends. +## Example Usage ### Try Word Segmentation @@ -36,12 +35,12 @@ Each model is compressed, so you need to decompress the downloaded model file li To convert a KyTea model into a Vaporetto model, run the following command in the Vaporetto root directory. If necessary, the Rust code will be compiled before the conversion process. ``` -% cargo run --release -p convert_kytea_model -- --model-in path/to/jp-0.4.7-5.mod --model-out path/to/jp-0.4.7-5-tokenize.model.zstd +% cargo run --release -p convert_kytea_model -- --model-in path/to/jp-0.4.7-5.mod --model-out path/to/jp-0.4.7-5-tokenize.model.zst ``` Now you can perform tokenization. Run the following command: ``` -% echo '火星猫の生態の調査結果' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize.model.zstd +% echo '火星猫の生態の調査結果' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize.model.zst ``` The following will be output: @@ -75,7 +74,7 @@ Here is an example: To train a model, use the following command: ``` -% cargo run --release -p train -- --model ./your.model.zstd --tok path/to/full.txt --part path/to/part.txt --dict path/to/dict.txt +% cargo run --release -p train -- --model ./your.model.zst --tok path/to/full.txt --part path/to/part.txt --dict path/to/dict.txt ``` `--tok` argument specifies a fully annotated corpus, and `--part` argument specifies a partially annotated corpus. @@ -84,6 +83,57 @@ A word dictionary is a file with words per line. You can specify all arguments above multiple times. 
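[Editor's aside, not part of the patch: the commands above wrap the `vaporetto` crate, which can also be called directly. The following is a minimal, unofficial sketch assembled from the APIs that appear in this patch series; it assumes the `zstd` crate (as used by the CLI tools), the model path is a placeholder, and error handling is simplified.]

```rust
use std::fs::File;

use vaporetto::{Model, Predictor, Sentence};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Models are distributed zstd-compressed, so decompress while reading.
    let mut f = zstd::Decoder::new(File::open("path/to/your.model.zst")?)?;
    let model = Model::read(&mut f)?;
    let predictor = Predictor::new(model)?;

    // Predict word boundaries and print the tokenized result.
    let s = Sentence::from_raw("火星猫の生態の調査結果")?;
    let s = predictor.predict(s);
    println!("{}", s.to_tokenized_string()?);
    Ok(())
}
```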
+### Model Manipulation
+
+For example, `メロンパン` is split into two tokens in the following command:
+```
+% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize.model.zst
+朝食 は メロン パン 1 個 だっ た
+```
+
+Sometimes, the model outputs different results than what you expect.
+You can make the `メロンパン` into a single token by manipulating the model following the steps below:
+
+1. Dump a dictionary by the following command:
+   ```
+   % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --dump-dict path/to/dictionary.csv
+   ```
+
+2. Edit the dictionary.
+
+   The dictionary is a CSV file. Each row contains a word and its corresponding weights in the following order:
+
+   * `right_weight` - A weight that is added when the word is found to the right of the boundary.
+   * `inside_weight` - A weight that is added when the word overlaps the boundary.
+   * `left_weight` - A weight that is added when the word is found to the left of the boundary.
+
+   Vaporetto splits a text when the total weight of the boundary is a positive number, so we add a new entry as follows:
+   ```diff
+    メロレオストーシス,6944,-2553,5319
+    メロン,8924,-10861,7081
+   +メロンパン,0,-100000,0
+    メロン果実,4168,-1165,3558
+    メロヴィング,6999,-15413,7583
+   ```
+
+   In this case, `-100000` will be added when the boundary is inside of the word `メロンパン`.
+
+   Note that Vaporetto uses 32-bit integers for the total weight, so you have to be careful about overflow.
+
+   In addition, the dictionary cannot contain duplicate words.
+   When a word is already contained in the dictionary, you have to edit its existing weights.
+
+3. Replace the weight data of the model file:
+   ```
+   % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --replace-dict path/to/dictionary.csv --model-out path/to/jp-0.4.7-5-tokenize-new.model.zst
+   ```
+
+Now `メロンパン` is split into a single token.
+```
+% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize-new.model.zst
+朝食 は メロンパン 1 個 だっ た
+```
+
 ## Speed Comparison of Various Tokenizers
 
 Details can be found [here](https://github.com/legalforce-research/vaporetto/wiki/Speed-Comparison).
diff --git a/manipulate_model/Cargo.toml b/manipulate_model/Cargo.toml
new file mode 100644
index 00000000..5139cfd9
--- /dev/null
+++ b/manipulate_model/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "manipulate_model"
+version = "0.1.0"
+edition = "2018"
+
+[dependencies]
+csv = "1.1" # Unlicense OR MIT
+serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0
+structopt = "0.3" # MIT or Apache-2.0
+vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0
+zstd = "0.9" # MIT
diff --git a/manipulate_model/src/main.rs b/manipulate_model/src/main.rs
new file mode 100644
index 00000000..db4e6e87
--- /dev/null
+++ b/manipulate_model/src/main.rs
@@ -0,0 +1,85 @@
+use std::fs;
+use std::path::PathBuf;
+
+use serde::{Deserialize, Serialize};
+use structopt::StructOpt;
+use vaporetto::{Model, WordWeightRecord};
+
+#[derive(StructOpt, Debug)]
+#[structopt(
+    name = "manipulate_model",
+    about = "A program to manipulate trained models."
+)]
+struct Opt {
+    /// Input path of the model file
+    #[structopt(long)]
+    model_in: PathBuf,
+
+    /// Output path of the model file
+    #[structopt(long)]
+    model_out: Option<PathBuf>,
+
+    /// Output a dictionary contained in the model.
+    #[structopt(long)]
+    dump_dict: Option<PathBuf>,
+
+    /// Replace a dictionary if the argument is specified.
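+    /// (Editor's note, not in the original patch: the CSV is read with headers,
+    /// so it should start with the `word,right,inside,left` header row that
+    /// `--dump-dict` writes via `WordWeightRecordFlatten` below.)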
+ #[structopt(long)] + replace_dict: Option, +} + +#[derive(Deserialize, Serialize)] +struct WordWeightRecordFlatten { + word: String, + right: i32, + inside: i32, + left: i32, +} + +fn main() -> Result<(), Box> { + let opt = Opt::from_args(); + + eprintln!("Loading model file..."); + let mut f = zstd::Decoder::new(fs::File::open(opt.model_in)?)?; + let mut model = Model::read(&mut f)?; + + if let Some(path) = opt.dump_dict { + eprintln!("Saving dictionary file..."); + let file = fs::File::create(path)?; + let mut wtr = csv::Writer::from_writer(file); + for data in model.dump_dictionary() { + wtr.serialize(WordWeightRecordFlatten { + word: data.get_word().to_string(), + right: data.get_right_weight(), + inside: data.get_inside_weight(), + left: data.get_left_weight(), + })?; + } + } + + if let Some(path) = opt.replace_dict { + eprintln!("Loading dictionary file..."); + let file = fs::File::open(path)?; + let mut rdr = csv::Reader::from_reader(file); + let mut dict = vec![]; + for result in rdr.deserialize() { + let record: WordWeightRecordFlatten = result?; + dict.push(WordWeightRecord::new( + record.word, + record.right, + record.inside, + record.left, + )); + } + model.replace_dictionary(dict); + } + + if let Some(path) = opt.model_out { + eprintln!("Saving model file..."); + let mut f = zstd::Encoder::new(fs::File::create(path)?, 19)?; + model.write(&mut f)?; + f.finish()?; + } + + Ok(()) +} diff --git a/vaporetto/src/dict_model.rs b/vaporetto/src/dict_model.rs index b84cd4ac..279a3cd2 100644 --- a/vaporetto/src/dict_model.rs +++ b/vaporetto/src/dict_model.rs @@ -7,7 +7,7 @@ use crate::ngram_model::NgramModel; #[derive(Clone, Copy, Default, Serialize, Deserialize)] pub struct DictWeight { pub right: i32, - pub inner: i32, + pub inside: i32, pub left: i32, } @@ -35,17 +35,70 @@ impl DictModel { Self::Lengthwise(model) => model.is_empty(), } } + + pub fn dump_dictionary(&self) -> Vec { + match self { + Self::Wordwise(model) => model.dump_dictionary(), + Self::Lengthwise(model) => model.dump_dictionary(), + } + } } +/// Record of weights for each word. #[derive(Clone, Serialize, Deserialize)] -pub struct WordwiseDictData { +pub struct WordWeightRecord { pub(crate) word: String, pub(crate) weights: DictWeight, } +impl WordWeightRecord { + /// Creates a new word weight record. + /// + /// # Arguments + /// + /// * `word` - A word. + /// * `right` - A weight of the boundary when the word is found at right. + /// * `inside` - A weight of the boundary when the word is overlapped on the boundary. + /// * `left` - A weight of the boundary when the word is found at left. + /// + /// # Returns + /// + /// A new record. + pub const fn new(word: String, right: i32, inside: i32, left: i32) -> Self { + Self { + word, + weights: DictWeight { + right, + inside, + left, + }, + } + } + + /// Gets a reference to the word. + pub fn get_word(&self) -> &str { + &self.word + } + + /// Gets a `right` weight. + pub const fn get_right_weight(&self) -> i32 { + self.weights.right + } + + /// Gets a `inside` weight. + pub const fn get_inside_weight(&self) -> i32 { + self.weights.inside + } + + /// Gets a `left` weight. + pub const fn get_left_weight(&self) -> i32 { + self.weights.left + } +} + #[derive(Serialize, Deserialize)] pub struct DictModelWordwise { - pub(crate) data: Vec, + pub(crate) dict: Vec, } impl DictModelWordwise { @@ -63,8 +116,8 @@ impl DictModelWordwise { { word_map.insert(word, i); } - let mut new_data = vec![]; - for data in self.data.drain(..) 
{ + let mut new_dict = vec![]; + for data in self.dict.drain(..) { let word_size = data.word.chars().count(); match word_map.get(&data.word) { Some(&idx) if char_window_size >= word_size => { @@ -72,20 +125,24 @@ impl DictModelWordwise { let end = start + word_size; char_ngram_model.data[idx].weights[start] += data.weights.right; for i in start + 1..end { - char_ngram_model.data[idx].weights[i] += data.weights.inner; + char_ngram_model.data[idx].weights[i] += data.weights.inside; } char_ngram_model.data[idx].weights[end] += data.weights.left; } _ => { - new_data.push(data); + new_dict.push(data); } } } - self.data = new_data; + self.dict = new_dict; } pub fn is_empty(&self) -> bool { - self.data.is_empty() + self.dict.is_empty() + } + + pub fn dump_dictionary(&self) -> Vec { + self.dict.clone() } } @@ -121,7 +178,7 @@ impl DictModelLengthwise { let weight = &self.weights[word_size_idx]; char_ngram_model.data[idx].weights[start] += weight.right; for i in start + 1..end { - char_ngram_model.data[idx].weights[i] += weight.inner; + char_ngram_model.data[idx].weights[i] += weight.inside; } char_ngram_model.data[idx].weights[end] += weight.left; } @@ -134,4 +191,16 @@ impl DictModelLengthwise { pub fn is_empty(&self) -> bool { self.words.is_empty() } + + pub fn dump_dictionary(&self) -> Vec { + let mut result = vec![]; + for word in &self.words { + let word = word.clone(); + let word_size = word.chars().count(); + let word_size_idx = word_size.min(self.weights.len()) - 1; + let weights = self.weights[word_size_idx]; + result.push(WordWeightRecord { word, weights }); + } + result + } } diff --git a/vaporetto/src/dict_scorer.rs b/vaporetto/src/dict_scorer.rs index dcc64502..59268538 100644 --- a/vaporetto/src/dict_scorer.rs +++ b/vaporetto/src/dict_scorer.rs @@ -34,7 +34,7 @@ impl DictScorerWordwise { pub fn new(model: DictModelWordwise) -> Result { let mut words = vec![]; let mut weights = vec![]; - for pair in model.data { + for pair in model.dict { words.push(pair.word); weights.push(pair.weights); } @@ -54,7 +54,7 @@ impl DictScorerWordwise { ys[m_start - 1] += dict_weight.right; } for y in &mut ys[m_start..m_end - 1] { - *y += dict_weight.inner; + *y += dict_weight.inside; } if m_end <= ys.len() { ys[m_end - 1] += dict_weight.left; @@ -95,7 +95,7 @@ impl DictScorerLengthwise { ys[m_start - 1] += dict_weight.right; } for y in &mut ys[m_start..m_end - 1] { - *y += dict_weight.inner; + *y += dict_weight.inside; } if m_end <= ys.len() { ys[m_end - 1] += dict_weight.left; diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index a6de598b..95ee1b03 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -3,7 +3,7 @@ use std::io::BufRead; use byteorder::{LittleEndian, ReadBytesExt}; -use crate::dict_model::{DictModel, DictModelWordwise, DictWeight, WordwiseDictData}; +use crate::dict_model::{DictModel, DictModelWordwise, DictWeight, WordWeightRecord}; use crate::errors::{Result, VaporettoError}; use crate::model::Model; use crate::ngram_model::{NgramData, NgramModel}; @@ -433,7 +433,7 @@ impl TryFrom for Model { }); } - let mut dict_data = vec![]; + let mut dict = vec![]; if let Some(kytea_dict) = model.dict { for (w, data) in kytea_dict.dump_items() { let word_len = std::cmp::min(w.len(), config.dict_n as usize) - 1; @@ -442,11 +442,11 @@ impl TryFrom for Model { if data.in_dict >> j & 1 == 1 { let offset = 3 * config.dict_n as usize * j + 3 * word_len; weights.right += feature_lookup.dict_vec[offset] as i32; - weights.inner += 
feature_lookup.dict_vec[offset + 1] as i32; + weights.inside += feature_lookup.dict_vec[offset + 1] as i32; weights.left += feature_lookup.dict_vec[offset + 2] as i32; } } - dict_data.push(WordwiseDictData { + dict.push(WordWeightRecord { word: w.into_iter().collect(), weights, }); @@ -456,7 +456,7 @@ impl TryFrom for Model { Ok(Self { char_ngram_model: NgramModel::new(char_ngrams), type_ngram_model: NgramModel::new(type_ngrams), - dict_model: DictModel::Wordwise(DictModelWordwise { data: dict_data }), + dict_model: DictModel::Wordwise(DictModelWordwise { dict }), quantize_multiplier, diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index 1705d53c..c1214fe7 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -47,6 +47,7 @@ mod trainer; #[cfg(feature = "kytea")] mod kytea_model; +pub use dict_model::WordWeightRecord; pub use model::Model; pub use predictor::Predictor; pub use sentence::{BoundaryType, CharacterType, Sentence}; diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index cbbfa18b..58bdc492 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -2,7 +2,7 @@ use std::io::{Read, Write}; use serde::{Deserialize, Serialize}; -use crate::dict_model::DictModel; +use crate::dict_model::{DictModel, DictModelWordwise, WordWeightRecord}; use crate::ngram_model::NgramModel; #[cfg(feature = "train")] @@ -140,7 +140,7 @@ impl Model { } FeatureContent::DictionaryWord(size) => match feature.rel_position { 0 => dict_weights[size - 1].right = weight as i32, - 1 => dict_weights[size - 1].inner = weight as i32, + 1 => dict_weights[size - 1].inside = weight as i32, 2 => dict_weights[size - 1].left = weight as i32, _ => panic!("Invalid rel_position"), }, @@ -161,4 +161,12 @@ impl Model { type_window_size, } } + + pub fn dump_dictionary(&self) -> Vec { + self.dict_model.dump_dictionary() + } + + pub fn replace_dictionary(&mut self, dict: Vec) { + self.dict_model = DictModel::Wordwise(DictModelWordwise { dict }); + } } diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 298ba9e7..e993b03d 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -185,7 +185,7 @@ mod tests { use super::*; use crate::dict_model::{ - DictModel, DictModelLengthwise, DictModelWordwise, DictWeight, WordwiseDictData, + DictModel, DictModelLengthwise, DictModelWordwise, DictWeight, WordWeightRecord, }; use crate::ngram_model::{NgramData, NgramModel}; @@ -262,12 +262,12 @@ mod tests { weights: vec![ DictWeight { right: 40, - inner: 41, + inside: 41, left: 42, }, DictWeight { right: 43, - inner: 44, + inside: 44, left: 45, }, ], @@ -352,17 +352,17 @@ mod tests { weights: vec![ DictWeight { right: 38, - inner: 39, + inside: 39, left: 40, }, DictWeight { right: 41, - inner: 42, + inside: 42, left: 43, }, DictWeight { right: 44, - inner: 45, + inside: 45, left: 46, }, ], @@ -443,28 +443,28 @@ mod tests { }, ]), dict_model: DictModel::Wordwise(DictModelWordwise { - data: vec![ - WordwiseDictData { + dict: vec![ + WordWeightRecord { word: "国民".to_string(), weights: DictWeight { right: 38, - inner: 39, + inside: 39, left: 40, }, }, - WordwiseDictData { + WordWeightRecord { word: "世界".to_string(), weights: DictWeight { right: 41, - inner: 42, + inside: 42, left: 43, }, }, - WordwiseDictData { + WordWeightRecord { word: "世".to_string(), weights: DictWeight { right: 44, - inner: 45, + inside: 45, left: 46, }, }, From 6efa477ac5527626245743ca39147858e2d1adad Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 7 Dec 2021 12:40:15 +0900 Subject: [PATCH 
23/60] Add update functions to Sentence (#14) * Add chars field * Separate parsers * Add tests * Use update_raw() in predict command * Fix format and refactoring --- predict/src/main.rs | 40 +- vaporetto/src/sentence.rs | 871 +++++++++++++++++++++++++++++++------- 2 files changed, 751 insertions(+), 160 deletions(-) diff --git a/predict/src/main.rs b/predict/src/main.rs index d8201bce..d8fb3570 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -75,23 +75,31 @@ fn main() -> Result<(), Box> { eprintln!("Start tokenization"); let mut n_boundaries = 0; let start = Instant::now(); - for line in stdin().lock().lines() { - let line = line?; - let s = if opt.no_norm { - let s = Sentence::from_raw(line)?; - predictor.predict(s) - } else { + let mut s = Sentence::from_raw(" ")?; + if opt.no_norm { + for line in stdin().lock().lines() { + let line = line?; + s.update_raw(line)?; + s = predictor.predict(s); + s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); + n_boundaries += s.boundaries().len(); + let toks = s.to_tokenized_string()?; + println!("{}", toks); + } + } else { + let mut s_norm = Sentence::from_raw(" ")?; + for line in stdin().lock().lines() { + let line = line?; let norm = fullwidth_filter.filter(&line); - let mut s_orig = Sentence::from_raw(line)?; - let s = Sentence::from_raw(norm)?; - let s = predictor.predict(s); - s_orig.boundaries_mut().clone_from_slice(s.boundaries()); - s_orig - }; - let s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); - n_boundaries += s.boundaries().len(); - let toks = s.to_tokenized_string()?; - println!("{}", toks); + s.update_raw(line)?; + s_norm.update_raw(norm)?; + s_norm = predictor.predict(s_norm); + s.boundaries_mut().clone_from_slice(s_norm.boundaries()); + s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); + n_boundaries += s.boundaries().len(); + let toks = s.to_tokenized_string()?; + println!("{}", toks); + } } let duration = start.elapsed(); eprintln!("Elapsed: {} [sec]", duration.as_secs_f64()); diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index fb32ca4d..c01e9528 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -80,6 +80,7 @@ pub enum BoundaryType { #[derive(Debug, PartialEq, Clone)] pub struct Sentence { pub(crate) text: String, + pub(crate) chars: Vec, pub(crate) str_to_char_pos: Vec, pub(crate) char_to_str_pos: Vec, pub(crate) char_type: Vec, @@ -88,31 +89,205 @@ pub struct Sentence { } impl Sentence { - fn common_info(chars: &[char]) -> (Vec, Vec, Vec) { - let mut char_to_str_pos = Vec::with_capacity(chars.len() + 1); - let mut char_type = Vec::with_capacity(chars.len()); + fn internal_new(text: String, chars: Vec, boundaries: Vec) -> Self { + let mut s = Self { + text, + chars, + str_to_char_pos: Vec::with_capacity(0), + char_to_str_pos: Vec::with_capacity(0), + char_type: Vec::with_capacity(0), + boundaries, + boundary_scores: None, + }; + s.update_common_info(); + s + } + + fn clear(&mut self) { + self.text.clear(); + self.text.push(' '); + self.chars.clear(); + self.chars.push(' '); + self.str_to_char_pos.clear(); + self.str_to_char_pos.push(0); + self.str_to_char_pos.push(1); + self.char_to_str_pos.clear(); + self.char_to_str_pos.push(0); + self.char_to_str_pos.push(1); + self.char_type.clear(); + self.char_type.push(CharacterType::Other as u8); + self.boundaries.clear(); + self.boundary_scores = None; + } + + fn parse_raw_text( + raw_text: &str, + chars: &mut Vec, + boundaries: &mut Vec, + ) -> Result<()> { + if raw_text.is_empty() { + 
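+            // (Editor's note, not in the original patch: when any parse_* helper
+            // fails, the update_* callers below reset the sentence to a single
+            // white space via clear() before returning the error.)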
return Err(VaporettoError::invalid_argument("raw_text", "is empty")); + } + + chars.clear(); + + for c in raw_text.chars() { + chars.push(c); + } + boundaries.clear(); + boundaries.resize(chars.len() - 1, BoundaryType::Unknown); + + Ok(()) + } + + fn parse_tokenized_text( + tokenized_text: &str, + text: &mut String, + chars: &mut Vec, + boundaries: &mut Vec, + ) -> Result<()> { + if tokenized_text.is_empty() { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "is empty", + )); + } + + text.clear(); + text.reserve(tokenized_text.len()); + chars.clear(); + boundaries.clear(); + + let mut prev_boundary = false; + let mut escape = false; + for c in tokenized_text.chars() { + match (escape, c) { + (false, '\\') => { + escape = true; + } + (false, ' ') => { + if chars.is_empty() { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "starts with a whitespace", + )); + } else if prev_boundary { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "contains consecutive whitespaces", + )); + } + prev_boundary = true; + } + (_, _) => { + if !chars.is_empty() { + boundaries.push(if prev_boundary { + BoundaryType::WordBoundary + } else { + BoundaryType::NotWordBoundary + }); + } + prev_boundary = false; + escape = false; + text.push(c); + chars.push(c); + } + }; + } + + if prev_boundary { + return Err(VaporettoError::invalid_argument( + "tokenized_text", + "ends with a whitespace", + )); + } + + Ok(()) + } + + fn parse_partial_annotation( + labeled_text: &str, + text: &mut String, + chars: &mut Vec, + boundaries: &mut Vec, + ) -> Result<()> { + if labeled_text.is_empty() { + return Err(VaporettoError::invalid_argument("labeled_text", "is empty")); + } + + let labeled_chars: Vec = labeled_text.chars().collect(); + if labeled_chars.len() % 2 == 0 { + return Err(VaporettoError::invalid_argument( + "labeled_text", + format!("invalid length: {}", labeled_chars.len()), + )); + } + + text.clear(); + text.reserve(labeled_text.len() - labeled_chars.len() / 2); + chars.clear(); + boundaries.clear(); + + for c in labeled_chars.iter().skip(1).step_by(2) { + boundaries.push(match c { + ' ' => BoundaryType::Unknown, + '|' => BoundaryType::WordBoundary, + '-' => BoundaryType::NotWordBoundary, + _ => { + return Err(VaporettoError::invalid_argument( + "labeled_text", + format!("contains invalid boundary character: '{}'", c), + )) + } + }); + } + for c in labeled_chars.into_iter().step_by(2) { + text.push(c); + chars.push(c); + } + + Ok(()) + } + + /// Updates char_to_str_pos, str_to_char_pos, and char_type. + /// + /// This function allocates: + /// + /// * char_to_str_pos: chars.len() + 1 + /// * str_to_char_pos: text.len() + 1 + /// * char_type: chars.len() + /// + /// If these variables already have sufficient spaces, this function reuses them. 
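+    ///
+    /// (Editor's example, not in the original patch: for the text "あa",
+    /// char_to_str_pos becomes [0, 3, 4] and str_to_char_pos becomes
+    /// [0, 0, 0, 1, 2], since 'あ' occupies three UTF-8 bytes.)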
+    fn update_common_info(&mut self) {
+        self.char_to_str_pos.clear();
+        self.str_to_char_pos.clear();
+        self.char_type.clear();
         let mut pos = 0;
-        char_to_str_pos.push(0);
-        for &c in chars {
+        self.char_to_str_pos.push(0);
+        for &c in &self.chars {
             pos += c.len_utf8();
-            char_to_str_pos.push(pos);
-            char_type.push(CharacterType::get_type(c) as u8)
+            self.char_to_str_pos.push(pos);
+            self.char_type.push(CharacterType::get_type(c) as u8)
         }
-        let mut str_to_char_pos = vec![0; char_to_str_pos.last().unwrap_or(&0) + 1];
-        for (i, &j) in char_to_str_pos.iter().enumerate() {
-            // j < str_to_char_pos.len()
+
+        debug_assert!(pos == self.text.len());
+
+        self.str_to_char_pos.fill(0);
+        self.str_to_char_pos.resize(self.text.len() + 1, 0);
+        for (i, &j) in self.char_to_str_pos.iter().enumerate() {
+            // j is always lower than pos + 1, so the following is safe.
             unsafe {
-                *str_to_char_pos.get_unchecked_mut(j) = i;
+                *self.str_to_char_pos.get_unchecked_mut(j) = i;
             }
         }
-        (char_to_str_pos, str_to_char_pos, char_type)
     }
 
     /// Creates a new [`Sentence`] from a given string.
     ///
     /// # Arguments
     ///
-    /// * `text` - A raw string without any annotation.
+    /// * `raw_text` - A raw string without any annotation.
     ///
     /// # Returns
     ///
@@ -120,7 +295,7 @@ impl Sentence {
     ///
     /// # Errors
     ///
-    /// If the given `text` is empty, an error variant will be returned.
+    /// If the given `raw_text` is empty, an error variant will be returned.
     ///
     /// # Examples
     ///
     /// ```
@@ -133,29 +308,56 @@ impl Sentence {
     /// let s = Sentence::from_raw("");
     /// assert!(s.is_err());
     /// ```
-    pub fn from_raw<S>(text: S) -> Result<Self>
+    pub fn from_raw<S>(raw_text: S) -> Result<Self>
     where
         S: Into<String>,
     {
-        let text = text.into();
+        let raw_text = raw_text.into();
 
-        if text.is_empty() {
-            return Err(VaporettoError::invalid_argument("text", "is empty"));
-        }
+        let mut chars = Vec::with_capacity(0);
+        let mut boundaries = Vec::with_capacity(0);
+        Self::parse_raw_text(&raw_text, &mut chars, &mut boundaries)?;
 
-        let chars: Vec<char> = text.chars().collect();
-        let boundaries = vec![BoundaryType::Unknown; chars.len() - 1];
+        Ok(Self::internal_new(raw_text, chars, boundaries))
+    }
 
-        let (char_to_str_pos, str_to_char_pos, char_type) = Self::common_info(&chars);
+    /// Updates the [`Sentence`] using a given string.
+    ///
+    /// # Arguments
+    ///
+    /// * `raw_text` - A raw string without any annotation.
+    ///
+    /// # Errors
+    ///
+    /// If the given `raw_text` is empty, an error variant will be returned.
+    /// When an error occurs, the sentence will be replaced with a white space.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use vaporetto::Sentence;
+    ///
+    /// let mut s = Sentence::from_raw("How are you?").unwrap();
+    /// s.update_raw("I am fine.").unwrap();
+    /// assert_eq!("I am fine.", s.to_raw_string());
+    /// ```
+    pub fn update_raw<S>(&mut self, raw_text: S) -> Result<()>
+    where
+        S: Into<String>,
+    {
+        let raw_text = raw_text.into();
 
-        Ok(Self {
-            text,
-            str_to_char_pos,
-            char_to_str_pos,
-            char_type,
-            boundaries,
-            boundary_scores: None,
-        })
+        match Self::parse_raw_text(&raw_text, &mut self.chars, &mut self.boundaries) {
+            Ok(_) => {
+                self.text = raw_text;
+                self.update_common_info();
+                Ok(())
+            }
+            Err(e) => {
+                self.clear();
+                Err(e)
+            }
+        }
     }
 
     /// Gets a string without any annotation.
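[Editor's aside, not part of the patch: the new `update_*` methods exist so callers can reuse one `Sentence` allocation across many inputs, as the reworked predict command above does. A minimal sketch of that pattern, assuming a `predictor` built as in the earlier patches:]

```rust
use std::io::{prelude::*, stdin};

use vaporetto::{Predictor, Sentence};

fn tokenize_stdin(predictor: &Predictor) -> Result<(), Box<dyn std::error::Error>> {
    // Allocate one Sentence up front; update_raw() then reuses its internal
    // buffers instead of reallocating for every line.
    let mut s = Sentence::from_raw(" ")?; // placeholder content
    for line in stdin().lock().lines() {
        s.update_raw(line?)?;
        s = predictor.predict(s);
        println!("{}", s.to_tokenized_string()?);
    }
    Ok(())
}
```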
@@ -211,68 +413,61 @@
     {
         let tokenized_text = tokenized_text.as_ref();
 
-        if tokenized_text.is_empty() {
-            return Err(VaporettoError::invalid_argument(
-                "tokenized_text",
-                "is empty",
-            ));
-        }
+        let mut text = String::with_capacity(0);
+        let mut chars = Vec::with_capacity(0);
+        let mut boundaries = Vec::with_capacity(0);
 
-        let tokenized_chars: Vec<char> = tokenized_text.chars().collect();
-        let mut chars = Vec::with_capacity(tokenized_chars.len());
-        let mut boundaries = Vec::with_capacity(tokenized_chars.len() - 1);
+        Self::parse_tokenized_text(tokenized_text, &mut text, &mut chars, &mut boundaries)?;
 
-        let mut prev_boundary = false;
-        let mut escape = false;
-        for c in tokenized_chars {
-            match (escape, c) {
-                (false, '\\') => {
-                    escape = true;
-                }
-                (false, ' ') => {
-                    if chars.is_empty() {
-                        return Err(VaporettoError::invalid_argument(
-                            "tokenized_text",
-                            "starts with a whitespace",
-                        ));
-                    } else if prev_boundary {
-                        return Err(VaporettoError::invalid_argument(
-                            "tokenized_text",
-                            "contains consecutive whitespaces",
-                        ));
-                    }
-                    prev_boundary = true;
-                }
-                (_, _) => {
-                    if !chars.is_empty() {
-                        boundaries.push(if prev_boundary {
-                            BoundaryType::WordBoundary
-                        } else {
-                            BoundaryType::NotWordBoundary
-                        });
-                    }
-                    prev_boundary = false;
-                    escape = false;
-                    chars.push(c);
-                }
-            };
-        }
-        if prev_boundary {
-            return Err(VaporettoError::invalid_argument(
-                "tokenized_text",
-                "ends with a whitespace",
-            ));
-        }
+        Ok(Self::internal_new(text, chars, boundaries))
+    }
 
-        let (char_to_str_pos, str_to_char_pos, char_type) = Self::common_info(&chars);
-        Ok(Self {
-            text: chars.iter().collect(),
-            char_to_str_pos,
-            str_to_char_pos,
-            char_type,
-            boundaries,
-            boundary_scores: None,
-        })
+    /// Updates the [`Sentence`] using a tokenized string.
+    ///
+    /// # Arguments
+    ///
+    /// * `tokenized_text` - A tokenized string containing whitespaces for word boundaries.
+    ///
+    /// # Errors
+    ///
+    /// This function will return an error variant when:
+    ///
+    /// * `tokenized_text` is empty.
+    /// * `tokenized_text` starts/ends with a whitespace.
+    /// * `tokenized_text` contains consecutive whitespaces.
+    ///
+    /// When an error occurs, the sentence will be replaced with a white space.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use vaporetto::Sentence;
+    ///
+    /// let mut s = Sentence::from_tokenized("How are you?").unwrap();
+    /// s.update_tokenized("I am fine").unwrap();
+    /// assert_eq!("Iamfine", s.to_raw_string());
+    /// ```
+    pub fn update_tokenized<S>(&mut self, tokenized_text: S) -> Result<()>
+    where
+        S: AsRef<str>,
+    {
+        let tokenized_text = tokenized_text.as_ref();
+
+        match Self::parse_tokenized_text(
+            tokenized_text,
+            &mut self.text,
+            &mut self.chars,
+            &mut self.boundaries,
+        ) {
+            Ok(_) => {
+                self.update_common_info();
+                Ok(())
+            }
+            Err(e) => {
+                self.clear();
+                Err(e)
+            }
+        }
     }
 
     /// Generates a string with whitespaces for word boundaries.
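[Editor's aside, not part of the patch: in the tokenized format, `\` escapes the next character, so a token can contain a literal space or backslash. A small illustration distilled from the tests below:]

```rust
use vaporetto::Sentence;

// "\\ " keeps a literal space inside a token; "\\\\" keeps a literal backslash.
let s = Sentence::from_tokenized("火星 猫 の 生態 ( M \\ et\\ al. )").unwrap();
assert_eq!("火星猫の生態(M et al.)", s.to_raw_string());
```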
@@ -404,46 +599,60 @@
     {
         let labeled_text = labeled_text.as_ref();
 
-        if labeled_text.is_empty() {
-            return Err(VaporettoError::invalid_argument("labeled_text", "is empty"));
-        }
+        let mut text = String::with_capacity(0);
+        let mut chars = Vec::with_capacity(0);
+        let mut boundaries = Vec::with_capacity(0);
+        Self::parse_partial_annotation(labeled_text, &mut text, &mut chars, &mut boundaries)?;
 
-        let labeled_chars: Vec<char> = labeled_text.chars().collect();
-        if labeled_chars.len() & 0x01 == 0 {
-            return Err(VaporettoError::invalid_argument(
-                "labeled_text",
-                format!("invalid length: {}", labeled_chars.len()),
-            ));
-        }
-        let mut chars = Vec::with_capacity(labeled_chars.len() / 2 + 1);
-        let mut boundaries = Vec::with_capacity(labeled_chars.len() / 2);
+        Ok(Self::internal_new(text, chars, boundaries))
+    }
 
-        for c in labeled_chars.iter().skip(1).step_by(2) {
-            boundaries.push(match c {
-                ' ' => BoundaryType::Unknown,
-                '|' => BoundaryType::WordBoundary,
-                '-' => BoundaryType::NotWordBoundary,
-                _ => {
-                    return Err(VaporettoError::invalid_argument(
-                        "labeled_text",
-                        format!("contains invalid boundary character: '{}'", c),
-                    ))
-                }
-            });
-        }
-        for c in labeled_chars.into_iter().step_by(2) {
-            chars.push(c);
-        }
+    /// Updates the [`Sentence`] using a string with partial annotations.
+    ///
+    /// # Arguments
+    ///
+    /// * `labeled_text` - A string with partial annotations.
+    ///
+    /// # Errors
+    ///
+    /// This function will return an error variant when:
+    ///
+    /// * `labeled_text` is empty.
+    /// * The length of `labeled_text` is an even number.
+    /// * `labeled_text` contains invalid boundary characters.
+    ///
+    /// When an error occurs, the sentence will be replaced with a white space.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use vaporetto::Sentence;
+    ///
+    /// let mut s = Sentence::from_partial_annotation("g-o-o-d|i-d e-a").unwrap();
+    /// s.update_partial_annotation("h-e-l-l-o").unwrap();
+    /// assert_eq!("hello", s.to_raw_string());
+    /// ```
+    pub fn update_partial_annotation<S>(&mut self, labeled_text: S) -> Result<()>
+    where
+        S: AsRef<str>,
+    {
+        let labeled_text = labeled_text.as_ref();
 
-        let (char_to_str_pos, str_to_char_pos, char_type) = Self::common_info(&chars);
-        Ok(Self {
-            text: chars.iter().collect(),
-            char_to_str_pos,
-            str_to_char_pos,
-            char_type,
-            boundaries,
-            boundary_scores: None,
-        })
+        match Self::parse_partial_annotation(
+            labeled_text,
+            &mut self.text,
+            &mut self.chars,
+            &mut self.boundaries,
+        ) {
+            Ok(_) => {
+                self.update_common_info();
+                Ok(())
+            }
+            Err(e) => {
+                self.clear();
+                Err(e)
+            }
+        }
     }
 
     /// Generates a string with partial annotations.
@@ -501,9 +710,27 @@
     ///
     /// # Returns
     ///
-    /// A mutable reference to the boundary information.
-    pub fn boundaries_mut(&mut self) -> &mut [BoundaryType] {
-        &mut self.boundaries
+    /// A mutable reference to the boundary information.
+    pub fn boundaries_mut(&mut self) -> &mut [BoundaryType] {
+        &mut self.boundaries
+    }
+
+    /// Gets a reference to the characters.
+    ///
+    /// # Returns
+    ///
+    /// A reference to the characters.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use vaporetto::Sentence;
+    ///
+    /// let s = Sentence::from_raw("A1あエ漢?").unwrap();
+    /// assert_eq!(&['A', '1', 'あ', 'エ', '漢', '?'], s.chars());
+    /// ```
+    pub fn chars(&self) -> &[char] {
+        &self.chars
     }
 
     /// Gets a reference to the character type information.
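[Editor's aside, not part of the patch: in a partially annotated string, every second character is a boundary label: `-` marks a known non-boundary, `|` a known boundary, and a space leaves the boundary unannotated. A sketch based on the tests below:]

```rust
use vaporetto::Sentence;

let s = Sentence::from_partial_annotation("火-星 猫|の|生-態").unwrap();
assert_eq!("火星猫の生態", s.to_raw_string());
// The space after 星 leaves that boundary as BoundaryType::Unknown.
```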
@@ -576,19 +803,41 @@ mod tests { fn test_sentence_from_raw_empty() { let s = Sentence::from_raw(""); - assert!(s.is_err()); assert_eq!( - "InvalidArgumentError: text: is empty", + "InvalidArgumentError: raw_text: is empty", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_raw_empty() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_raw(""); + + assert_eq!( + "InvalidArgumentError: raw_text: is empty", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: ct2u8vec![Other], + boundaries: vec![], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_raw_one() { let s = Sentence::from_raw("あ"); let expected = Sentence { text: "あ".to_string(), + chars: vec!['あ'], str_to_char_pos: vec![0, 0, 0, 1], char_to_str_pos: vec![0, 3], char_type: ct2u8vec![Hiragana], @@ -598,12 +847,33 @@ mod tests { assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_raw_one() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_raw("あ").unwrap(); + + let expected = Sentence { + text: "あ".to_string(), + chars: vec!['あ'], + str_to_char_pos: vec![0, 0, 0, 1], + char_to_str_pos: vec![0, 3], + char_type: ct2u8vec![Hiragana], + boundaries: vec![], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_raw() { let s = Sentence::from_raw("Rustで良いプログラミング体験を!"); let expected = Sentence { text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], str_to_char_pos: vec![ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, @@ -621,6 +891,34 @@ mod tests { assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_raw() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_raw("Rustで良いプログラミング体験を!").unwrap(); + + let expected = Sentence { + text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], + str_to_char_pos: vec![ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + char_to_str_pos: vec![ + 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, + ], + char_type: ct2u8vec![ + Roman, Roman, Roman, Roman, Hiragana, Kanji, Hiragana, Katakana, Katakana, + Katakana, Katakana, Katakana, Katakana, Katakana, Kanji, Kanji, Hiragana, Other, + ], + boundaries: vec![Unknown; 17], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_to_raw() { let s = Sentence::from_raw("Rustで良いプログラミング体験を!"); @@ -635,52 +933,137 @@ mod tests { fn test_sentence_from_tokenized_empty() { let s = Sentence::from_tokenized(""); - assert!(s.is_err()); assert_eq!( "InvalidArgumentError: tokenized_text: is empty", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_tokenized_empty() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized(""); + + assert_eq!( + "InvalidArgumentError: tokenized_text: is empty", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + 
char_to_str_pos: vec![0, 1], + char_type: ct2u8vec![Other], + boundaries: vec![], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_start_with_space() { let s = Sentence::from_tokenized(" Rust で 良い プログラミング 体験 を !"); - assert!(s.is_err()); assert_eq!( "InvalidArgumentError: tokenized_text: starts with a whitespace", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_tokenized_start_with_space() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized(" Rust で 良い プログラミング 体験 を !"); + + assert_eq!( + "InvalidArgumentError: tokenized_text: starts with a whitespace", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: ct2u8vec![Other], + boundaries: vec![], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_end_with_space() { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を ! "); - assert!(s.is_err()); assert_eq!( "InvalidArgumentError: tokenized_text: ends with a whitespace", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_tokenized_end_with_space() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized("Rust で 良い プログラミング 体験 を ! "); + + assert_eq!( + "InvalidArgumentError: tokenized_text: ends with a whitespace", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: ct2u8vec![Other], + boundaries: vec![], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_two_spaces() { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !"); - assert!(s.is_err()); assert_eq!( "InvalidArgumentError: tokenized_text: contains consecutive whitespaces", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_tokenized_two_spaces() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_tokenized("Rust で 良い プログラミング 体験 を !"); + + assert_eq!( + "InvalidArgumentError: tokenized_text: contains consecutive whitespaces", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: ct2u8vec![Other], + boundaries: vec![], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_one() { let s = Sentence::from_tokenized("あ"); let expected = Sentence { text: "あ".to_string(), + chars: vec!['あ'], str_to_char_pos: vec![0, 0, 0, 1], char_to_str_pos: vec![0, 3], char_type: ct2u8vec![Hiragana], @@ -690,12 +1073,33 @@ mod tests { assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_tokenized_one() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("あ").unwrap(); + + let expected = Sentence { + text: "あ".to_string(), + chars: vec!['あ'], + str_to_char_pos: vec![0, 0, 0, 1], + char_to_str_pos: vec![0, 3], + char_type: ct2u8vec![Hiragana], + boundaries: vec![], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized() { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !"); let expected = Sentence { text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 
'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], str_to_char_pos: vec![ 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, @@ -731,12 +1135,63 @@ mod tests { assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_tokenized() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("Rust で 良い プログラミング 体験 を !") + .unwrap(); + + let expected = Sentence { + text: "Rustで良いプログラミング体験を!".to_string(), + chars: vec![ + 'R', 'u', 's', 't', 'で', '良', 'い', 'プ', 'ロ', 'グ', 'ラ', 'ミ', 'ン', 'グ', + '体', '験', 'を', '!', + ], + str_to_char_pos: vec![ + 0, 1, 2, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, + 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 18, + ], + char_to_str_pos: vec![ + 0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, + ], + char_type: ct2u8vec![ + Roman, Roman, Roman, Roman, Hiragana, Kanji, Hiragana, Katakana, Katakana, + Katakana, Katakana, Katakana, Katakana, Katakana, Kanji, Kanji, Hiragana, Other, + ], + boundaries: vec![ + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + ], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_from_tokenized_with_escape_whitespace() { - let s = Sentence::from_tokenized("火星 猫 の 生態 ( M \\ et\\ al. )"); + let s = Sentence::from_tokenized("火星 猫 の 生態 ( M \\ et\\ al. )").unwrap(); let expected = Sentence { text: "火星猫の生態(M et al.)".to_string(), + chars: vec![ + '火', '星', '猫', 'の', '生', '態', '(', 'M', ' ', 'e', 't', ' ', 'a', 'l', '.', + ')', + ], str_to_char_pos: vec![ 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -767,7 +1222,52 @@ mod tests { ], boundary_scores: None, }; - assert_eq!(expected, s.unwrap()); + assert_eq!(expected, s); + } + + #[test] + fn test_sentence_update_tokenized_escape_whitespace() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("火星 猫 の 生態 ( M \\ et\\ al. 
)") + .unwrap(); + + let expected = Sentence { + text: "火星猫の生態(M et al.)".to_string(), + chars: vec![ + '火', '星', '猫', 'の', '生', '態', '(', 'M', ' ', 'e', 't', ' ', 'a', 'l', '.', + ')', + ], + str_to_char_pos: vec![ + 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, + ], + char_to_str_pos: vec![ + 0, 3, 6, 9, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + ], + char_type: ct2u8vec![ + Kanji, Kanji, Kanji, Hiragana, Kanji, Kanji, Other, Roman, Other, Roman, Roman, + Other, Roman, Roman, Other, Other, + ], + boundaries: vec![ + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + NotWordBoundary, + WordBoundary, + ], + boundary_scores: None, + }; + assert_eq!(expected, s); } #[test] @@ -776,6 +1276,7 @@ mod tests { let expected = Sentence { text: "改行に\\nを用いる".to_string(), + chars: vec!['改', '行', 'に', '\\', 'n', 'を', '用', 'い', 'る'], str_to_char_pos: vec![ 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, ], @@ -798,12 +1299,41 @@ mod tests { assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_tokenized_with_escape_backslash() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_tokenized("改行 に \\\\n を 用い る").unwrap(); + + let expected = Sentence { + text: "改行に\\nを用いる".to_string(), + chars: vec!['改', '行', 'に', '\\', 'n', 'を', '用', 'い', 'る'], + str_to_char_pos: vec![ + 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 4, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, + ], + char_to_str_pos: vec![0, 3, 6, 9, 10, 11, 14, 17, 20, 23], + char_type: ct2u8vec![ + Kanji, Kanji, Hiragana, Other, Roman, Hiragana, Kanji, Hiragana, Hiragana, + ], + boundaries: vec![ + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + WordBoundary, + NotWordBoundary, + WordBoundary, + ], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_to_tokenized_string_unknown() { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); let result = s.unwrap().to_tokenized_string(); - assert!(result.is_err()); assert_eq!( "InvalidSentenceError: contains an unknown boundary", result.err().unwrap().to_string() @@ -835,7 +1365,6 @@ mod tests { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態").unwrap(); let result = s.to_tokenized_vec(); - assert!(result.is_err()); assert_eq!( "InvalidSentenceError: contains an unknown boundary", result.err().unwrap().to_string() @@ -856,21 +1385,41 @@ mod tests { fn test_sentence_from_partial_annotation_empty() { let s = Sentence::from_partial_annotation(""); - assert!(s.is_err()); assert_eq!( "InvalidArgumentError: labeled_text: is empty", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_partial_annotation_empty() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_partial_annotation(""); + + assert_eq!( + "InvalidArgumentError: labeled_text: is empty", + &result.err().unwrap().to_string() + ); + } + #[test] fn test_sentence_from_partial_annotation_invalid_length() { - let s = Sentence::from_partial_annotation("火-星 猫|の|生-態 "); + let result = Sentence::from_partial_annotation("火-星 猫|の|生-態 "); - assert!(s.is_err()); assert_eq!( "InvalidArgumentError: labeled_text: invalid length: 12", - &s.err().unwrap().to_string() + &result.err().unwrap().to_string() + ); + } + + #[test] + fn 
test_sentence_update_partial_annotation_invalid_length() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_partial_annotation("火-星 猫|の|生-態 "); + + assert_eq!( + "InvalidArgumentError: labeled_text: invalid length: 12", + &result.err().unwrap().to_string() ); } @@ -878,19 +1427,30 @@ mod tests { fn test_sentence_from_partial_annotation_invalid_boundary_character() { let s = Sentence::from_partial_annotation("火-星?猫|の|生-態"); - assert!(s.is_err()); assert_eq!( "InvalidArgumentError: labeled_text: contains invalid boundary character: '?'", &s.err().unwrap().to_string() ); } + #[test] + fn test_sentence_update_partial_annotation_invalid_boundary_character() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = s.update_partial_annotation("火-星?猫|の|生-態"); + + assert_eq!( + "InvalidArgumentError: labeled_text: contains invalid boundary character: '?'", + &result.err().unwrap().to_string() + ); + } + #[test] fn test_sentence_from_partial_annotation_one() { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); let expected = Sentence { text: "火星猫の生態".to_string(), + chars: vec!['火', '星', '猫', 'の', '生', '態'], str_to_char_pos: vec![0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], char_to_str_pos: vec![0, 3, 6, 9, 12, 15, 18], char_type: ct2u8vec![Kanji, Kanji, Kanji, Hiragana, Kanji, Kanji], @@ -906,6 +1466,29 @@ mod tests { assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_update_partial_annotation_one() { + let mut s = Sentence::from_raw("12345").unwrap(); + s.update_partial_annotation("火-星 猫|の|生-態").unwrap(); + + let expected = Sentence { + text: "火星猫の生態".to_string(), + chars: vec!['火', '星', '猫', 'の', '生', '態'], + str_to_char_pos: vec![0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 6], + char_to_str_pos: vec![0, 3, 6, 9, 12, 15, 18], + char_type: ct2u8vec![Kanji, Kanji, Kanji, Hiragana, Kanji, Kanji], + boundaries: vec![ + NotWordBoundary, + Unknown, + WordBoundary, + WordBoundary, + NotWordBoundary, + ], + boundary_scores: None, + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_to_partial_annotation_string() { let s = Sentence::from_partial_annotation("火-星 猫|の|生-態"); From d64e05572160bfba2492335ac397de57bed1bedd Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 7 Dec 2021 15:28:11 +0900 Subject: [PATCH 24/60] Update readme (#15) * Update figure * Update README --- README.md | 6 ++-- figures/comparison.svg | 66 +++++++++++++++++++++--------------------- figures/comparison.txt | 18 ++++++------ 3 files changed, 46 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 655a2b4a..89d5383b 100644 --- a/README.md +++ b/README.md @@ -117,9 +117,9 @@ You can make the `メロンパン` into a single token by manipulating the model ``` In this case, `-100000` will be added when the boundary is inside of the word `メロンパン`. - + Note that Vaporetto uses 32-bit integers for the total weight, so you have to be careful about overflow. - + In addition, The dictionary cannot contain duplicated words. When the word is already contained in the dictionary, you have to edit existing weights. @@ -136,6 +136,8 @@ Now `メロンパン` is split into a single token. ## Speed Comparison of Various Tokenizers +Vaporetto is 6.9 times faster than KyTea. With `feature=simd`, it becomes 7.8 times faster. (`simd` option requires Nightly Rust.) + Details can be found [here](https://github.com/legalforce-research/vaporetto/wiki/Speed-Comparison). 
 ![](./figures/comparison.svg)
diff --git a/figures/comparison.svg b/figures/comparison.svg
index 5a75e90e..5f93598f 100644
--- a/figures/comparison.svg
+++ b/figures/comparison.svg
@@ -1,6 +1,6 @@
 [Editor's note: the SVG markup in this hunk was lost when this document was extracted (the tags were stripped). The hunk updated the chart header of figures/comparison.svg.]
@@ -124,45 +124,45 @@
 [Editor's note: likewise stripped; this hunk redrew the speed bars to match the re-measured values in figures/comparison.txt below.]
diff --git a/figures/comparison.txt b/figures/comparison.txt
index 6a25dd8d..595c5403 100644
--- a/figures/comparison.txt
+++ b/figures/comparison.txt
@@ -1,9 +1,9 @@
-Tool Name (version),Speed [M chars/s],STD
-KyTea (2020-04-03),1.463,0.012
-Vaporetto (0.3.0),9.716,0.115
-Vaporetto (0.3.0+feature=simd),11.035,0.144
-MeCab (2020-09-14),4.621,0.047
-Kuromoji (0.9.0),1.470,0.074
-Lindera (0.8.1),1.444,0.022
-Sudachi (0.5.3),0.322,0.029
-sudachi.rs (0.6.0),0.961,0.008
+Tool Name (version),Speed [M chars/s]
+KyTea (2020-04-03),1.4674450789921388
+Vaporetto (0.3.0),10.07734841348238
+Vaporetto (0.3.0+feature=simd),11.414333204815095
+MeCab (2020-09-14),4.619055018595073
+Kuromoji (0.9.0),1.4837693905013502
+Lindera (0.8.1),1.4499374143314385
+Sudachi (0.5.3),0.3185670881795747
+sudachi.rs (0.6.0),0.9658781319147613
From dc0bff81db7afe4af27cae8446dbe0f59feee3d1 Mon Sep 17 00:00:00 2001
From: Koichi Akabe
Date: Mon, 13 Dec 2021 12:53:30 +0900
Subject: [PATCH 25/60] Add --scores option to the predict command (#17)

* Use u32 instead of f64 for exporting scores
* Fix API of vaporetto_rules
* Add --scores option
* Update README
* Fix
* Fix doc
* Apply suggestions from code review

Co-authored-by: Shunsuke Kanda

* Fix

Co-authored-by: Shunsuke Kanda
---
 README.md | 39 ++++++--
 predict/src/main.rs | 96 +++++++++++++------
 train/src/main.rs | 2 +-
 vaporetto/src/kytea_model.rs | 8 +-
 vaporetto/src/model.rs | 6 --
 vaporetto/src/predictor.rs | 30 ++----
 vaporetto/src/sentence.rs | 4 +-
 vaporetto_rules/README.md | 9 +-
 vaporetto_rules/src/lib.rs | 14 +--
 .../src/string_filters/kytea_fullwidth.rs | 9 +-
 10 files changed, 127 insertions(+), 90 deletions(-)

diff --git a/README.md b/README.md
index 89d5383b..167ccfba 100644
--- a/README.md
+++ b/README.md
@@ -85,14 +85,27 @@ You can specify all arguments above multiple times.
 
 ### Model Manipulation
 
-For example, `メロンパン` is split into two tokens in the following command:
+Sometimes, your model will output different results than what you expect.
+For example, `メロンパン` is split into two tokens in the following command.
+We use the `--scores` option to show the score of each character boundary:
 ```
-% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize.model.zst
+% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize.model.zst
 朝食 は メロン パン 1 個 だっ た
-```
-
-Sometimes, the model outputs different results than what you expect.
-You can make the `メロンパン` into a single token by manipulating the model following the steps below:
+0:朝食 -15398
+1:食は 24623
+2:はメ 30261
+3:メロ -26885
+4:ロン -38896
+5:ンパ 8162
+6:パン -23416
+7:ン1 23513
+8:1個 18435
+9:個だ 24964
+10:だっ -15065
+11:った 14178
+```
+
+To concatenate `メロンパン` into a single token, manipulate the model in the following steps so that the score of `ンパ` becomes negative:
 
 1. Dump a dictionary by the following command:
    ```
    % cargo run --release -p manipulate_model -- --model-in path/to/jp-0.4.7-5-tokenize.model.zst --dump-dict path/to/dictionary.csv
    ```
@@ -130,8 +143,20 @@ Now `メロンパン` is split into a single token.
``` -% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --model path/to/jp-0.4.7-5-tokenize-new.model.zst +% echo '朝食はメロンパン1個だった' | cargo run --release -p predict -- --scores --model path/to/jp-0.4.7-5-tokenize-new.model.zst 朝食 は メロンパン 1 個 だっ た +0:朝食 -15398 +1:食は 24623 +2:はメ 30261 +3:メロ -126885 +4:ロン -138896 +5:ンパ -91838 +6:パン -123416 +7:ン1 23513 +8:1個 18435 +9:個だ 24964 +10:だっ -15065 +11:った 14178 ``` ## Speed Comparison of Various Tokenizers diff --git a/predict/src/main.rs b/predict/src/main.rs index d8fb3570..7d96a42e 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -1,11 +1,12 @@ use std::fs::File; use std::io::{prelude::*, stdin}; use std::path::PathBuf; +use std::rc::Rc; use std::str::FromStr; use std::time::Instant; use structopt::StructOpt; -use vaporetto::{CharacterType, Model, Predictor, Sentence}; +use vaporetto::{errors::VaporettoError, CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ sentence_filters::{ConcatGraphemeClustersFilter, KyteaWsConstFilter}, string_filters::KyteaFullwidthFilter, @@ -46,15 +47,65 @@ struct Opt { #[structopt(long)] wsconst: Vec, + /// Prints scores. + #[structopt(long)] + scores: bool, + /// Do not normalize input strings before prediction. #[structopt(long)] no_norm: bool, } +fn print_scores(s: &Sentence) { + if let Some(scores) = s.boundary_scores().as_ref() { + for (i, score) in scores.iter().enumerate() { + println!("{}:{}{} {}", i, s.chars()[i], s.chars()[i + 1], score); + } + println!(); + } +} + +fn tokenize( + predictor: &Predictor, + text: impl Into, + mut buf1: Sentence, + mut buf2: Sentence, + pre_filters: &[Box], + post_filters: &[Box], +) -> Result<(String, Sentence, Sentence), VaporettoError> { + let text = text.into(); + if pre_filters.is_empty() { + buf1.update_raw(text)?; + } else { + let text_rc = Rc::new(text); + let filt_text = Rc::try_unwrap( + pre_filters + .iter() + .fold(Rc::clone(&text_rc), |s, filter| Rc::new(filter.filter(&s))), + ) + .unwrap(); + let text = Rc::try_unwrap(text_rc).unwrap(); + buf1.update_raw(filt_text)?; + buf2.update_raw(text)?; + } + buf1 = predictor.predict_with_score(buf1); + buf1 = post_filters.iter().fold(buf1, |s, filter| filter.filter(s)); + let result = if pre_filters.is_empty() { + buf1.to_tokenized_string()? + } else { + buf2.boundaries_mut().copy_from_slice(buf1.boundaries()); + buf2.to_tokenized_string()? 
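+        // (Editor's note, not in the original patch: buf1 holds the normalized
+        // text that was actually scored, buf2 the unnormalized input; copying
+        // the boundaries over lets us print tokens in their original form.)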
+ }; + Ok((result, buf1, buf2)) +} + fn main() -> Result<(), Box> { let opt = Opt::from_args(); - let fullwidth_filter = KyteaFullwidthFilter::new(); + let mut pre_filters: Vec> = vec![]; + if !opt.no_norm { + pre_filters.push(Box::new(KyteaFullwidthFilter::new())); + } let mut post_filters: Vec> = vec![]; for wsconst in &opt.wsconst { match wsconst { @@ -73,39 +124,26 @@ fn main() -> Result<(), Box> { let predictor = Predictor::new(model)?; eprintln!("Start tokenization"); - let mut n_boundaries = 0; + let mut n_chars = 0; let start = Instant::now(); - let mut s = Sentence::from_raw(" ")?; - if opt.no_norm { - for line in stdin().lock().lines() { - let line = line?; - s.update_raw(line)?; - s = predictor.predict(s); - s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); - n_boundaries += s.boundaries().len(); - let toks = s.to_tokenized_string()?; - println!("{}", toks); - } - } else { - let mut s_norm = Sentence::from_raw(" ")?; - for line in stdin().lock().lines() { - let line = line?; - let norm = fullwidth_filter.filter(&line); - s.update_raw(line)?; - s_norm.update_raw(norm)?; - s_norm = predictor.predict(s_norm); - s.boundaries_mut().clone_from_slice(s_norm.boundaries()); - s = post_filters.iter().fold(s, |s, filter| filter.filter(s)); - n_boundaries += s.boundaries().len(); - let toks = s.to_tokenized_string()?; - println!("{}", toks); + let mut buf1 = Sentence::from_raw(" ")?; + let mut buf2 = Sentence::from_raw(" ")?; + for line in stdin().lock().lines() { + let ret = tokenize(&predictor, line?, buf1, buf2, &pre_filters, &post_filters)?; + let result = ret.0; + buf1 = ret.1; + buf2 = ret.2; + println!("{}", result); + if opt.scores { + print_scores(&buf1); } + n_chars += buf1.chars().len(); } let duration = start.elapsed(); eprintln!("Elapsed: {} [sec]", duration.as_secs_f64()); eprintln!( - "Speed: {} [boundaries/sec]", - n_boundaries as f64 / duration.as_secs_f64() + "Speed: {} [chars/sec]", + n_chars as f64 / duration.as_secs_f64() ); Ok(()) diff --git a/train/src/main.rs b/train/src/main.rs index 76c6c590..04b5db2f 100644 --- a/train/src/main.rs +++ b/train/src/main.rs @@ -138,7 +138,7 @@ fn main() -> Result<(), Box> { let line = if opt.no_norm { line } else { - fullwidth_filter.filter(line) + fullwidth_filter.filter(&line) }; dictionary.insert(line); } diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 95ee1b03..984bd46c 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -240,7 +240,7 @@ struct LinearModel { _solver_type: u8, _labels: Vec, _bias: bool, - multiplier: f64, + _multiplier: f64, feature_lookup: Option>, } @@ -264,7 +264,7 @@ impl Readable for Option { _solver_type: solver_type, _labels: labels, _bias: bias, - multiplier, + _multiplier: multiplier, feature_lookup, })) } @@ -399,7 +399,6 @@ impl TryFrom for Model { let wordseg_model = model .wordseg_model .ok_or_else(|| VaporettoError::invalid_model("no word segmentation model."))?; - let quantize_multiplier = wordseg_model.multiplier; let feature_lookup = wordseg_model .feature_lookup .ok_or_else(|| VaporettoError::invalid_model("no lookup data."))?; @@ -457,9 +456,6 @@ impl TryFrom for Model { char_ngram_model: NgramModel::new(char_ngrams), type_ngram_model: NgramModel::new(type_ngrams), dict_model: DictModel::Wordwise(DictModelWordwise { dict }), - - quantize_multiplier, - bias, char_window_size: config.char_w as usize, type_window_size: config.type_w as usize, diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index 58bdc492..f5a9a815 
100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -31,9 +31,6 @@ pub struct Model { pub(crate) char_ngram_model: NgramModel, pub(crate) type_ngram_model: NgramModel>, pub(crate) dict_model: DictModel, - - pub(crate) quantize_multiplier: f64, - pub(crate) bias: i32, pub(crate) char_window_size: usize, pub(crate) type_window_size: usize, @@ -153,9 +150,6 @@ impl Model { words: dict, weights: dict_weights, }), - - quantize_multiplier, - bias, char_window_size, type_window_size, diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index e993b03d..fe84b21e 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -16,8 +16,6 @@ pub struct Predictor { type_scorer: TypeScorer, dict_scorer: Option, - quantize_multiplier: f64, - #[cfg(feature = "simd")] padding: usize, } @@ -56,8 +54,6 @@ impl Predictor { type_scorer, dict_scorer, - quantize_multiplier: model.quantize_multiplier, - #[cfg(feature = "simd")] padding: model.char_window_size.max(model.type_window_size), }) @@ -133,23 +129,14 @@ impl Predictor { if boundaries_size != 0 { let mut ys = vec![0; boundaries_size]; self.predict_impl(&sentence, 0, &mut ys); - let mut scores = sentence - .boundary_scores - .take() - .unwrap_or_else(|| vec![0.; boundaries_size]); - for (y, (b, s)) in ys - .into_iter() - .zip(sentence.boundaries.iter_mut().zip(scores.iter_mut())) - { + for (&y, b) in ys.iter().zip(sentence.boundaries.iter_mut()) { *b = if y >= 0 { BoundaryType::WordBoundary } else { BoundaryType::NotWordBoundary }; - - *s = y as f64 * self.quantize_multiplier; } - sentence.boundary_scores.replace(scores); + sentence.boundary_scores.replace(ys); } #[cfg(feature = "simd")] @@ -160,7 +147,7 @@ impl Predictor { let mut scores = sentence .boundary_scores .take() - .unwrap_or_else(|| vec![0.; boundaries_size]); + .unwrap_or_else(|| vec![0; boundaries_size]); for (&y, (b, s)) in ys[self.padding..] 
.into_iter() .zip(sentence.boundaries.iter_mut().zip(scores.iter_mut())) @@ -171,7 +158,7 @@ impl Predictor { BoundaryType::NotWordBoundary }; - *s = y as f64 * self.quantize_multiplier; + *s = y; } sentence.boundary_scores.replace(scores); } @@ -272,7 +259,6 @@ mod tests { }, ], }), - quantize_multiplier: 0.5, bias: -200, char_window_size: 3, type_window_size: 2, @@ -367,7 +353,6 @@ mod tests { }, ], }), - quantize_multiplier: 0.25, bias: -285, char_window_size: 2, type_window_size: 3, @@ -470,7 +455,6 @@ mod tests { }, ], }), - quantize_multiplier: 0.25, bias: -285, char_window_size: 2, type_window_size: 3, @@ -560,7 +544,7 @@ mod tests { s.boundaries(), ); assert_eq!( - &[-38.5, -2.5, 22.5, 66.0, 66.5, 72.0, 25.0, -16.0], + &[-77, -5, 45, 132, 133, 144, 50, -32], s.boundary_scores().unwrap(), ); } @@ -585,7 +569,7 @@ mod tests { s.boundaries(), ); assert_eq!( - &[-34.5, -27.25, -9.75, 14.25, 26.0, 8.5, -19.75, -28.5], + &[-138, -109, -39, 57, 104, 34, -79, -114], s.boundary_scores().unwrap(), ); } @@ -610,7 +594,7 @@ mod tests { s.boundaries(), ); assert_eq!( - &[-34.5, -27.25, -20.75, 4.5, 16.25, -3.0, -10.25, -18.75], + &[-138, -109, -83, 18, 65, -12, -41, -75], s.boundary_scores().unwrap(), ); } diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index c01e9528..8fd2c800 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -85,7 +85,7 @@ pub struct Sentence { pub(crate) char_to_str_pos: Vec, pub(crate) char_type: Vec, pub(crate) boundaries: Vec, - pub(crate) boundary_scores: Option>, + pub(crate) boundary_scores: Option>, } impl Sentence { @@ -756,7 +756,7 @@ impl Sentence { /// # Returns /// /// If the predictor inserted, the boundary score information is returned. Otherwise, None. - pub fn boundary_scores(&self) -> Option<&[f64]> { + pub fn boundary_scores(&self) -> Option<&[i32]> { self.boundary_scores.as_deref() } diff --git a/vaporetto_rules/README.md b/vaporetto_rules/README.md index 6527833c..f8edbeac 100644 --- a/vaporetto_rules/README.md +++ b/vaporetto_rules/README.md @@ -8,6 +8,7 @@ vaporetto_rules is rule-base filters for Vaporetto. ```rust use std::fs::File; use std::io::BufReader; +use std::rc::Rc; use vaporetto::{CharacterType, Model, Predictor, Sentence}; use vaporetto_rules::{ @@ -18,9 +19,9 @@ use vaporetto_rules::{ let mut f = BufReader::new(File::open("model.bin").unwrap()); let model = Model::read(&mut f).unwrap(); -let mut predictor = Predictor::new(model); +let mut predictor = Predictor::new(model).unwrap(); -let pre_filters: Vec>> = vec![ +let pre_filters: Vec> = vec![ Box::new(KyteaFullwidthFilter::new()), ]; let post_filters: Vec> = vec![ @@ -31,7 +32,9 @@ let post_filters: Vec> = vec![ let input = "Vaporettoは仲良し家族👨‍👨‍👧‍👦を離れ離れにさせません。" .to_string(); -let preproc_input = pre_filters.iter().fold(input, |s, filter| filter.filter(s)); +let input = Rc::new(input); +let preproc_input = pre_filters.iter().fold(input, |s, filter| Rc::new(filter.filter(&s))); +let preproc_input = Rc::try_unwrap(preproc_input).unwrap(); let sentence = Sentence::from_raw(preproc_input).unwrap(); let sentence = predictor.predict(sentence); diff --git a/vaporetto_rules/src/lib.rs b/vaporetto_rules/src/lib.rs index 9ac6969d..305e0f30 100644 --- a/vaporetto_rules/src/lib.rs +++ b/vaporetto_rules/src/lib.rs @@ -7,6 +7,7 @@ //! ```no_run //! use std::fs::File; //! use std::io::BufReader; +//! use std::rc::Rc; //! //! use vaporetto::{CharacterType, Model, Predictor, Sentence}; //! use vaporetto_rules::{ @@ -19,7 +20,7 @@ //! 
let model = Model::read(&mut f).unwrap(); //! let mut predictor = Predictor::new(model).unwrap(); //! -//! let pre_filters: Vec>> = vec![ +//! let pre_filters: Vec> = vec![ //! Box::new(KyteaFullwidthFilter::new()), //! ]; //! let post_filters: Vec> = vec![ @@ -30,7 +31,9 @@ //! let input = "Vaporettoは仲良し家族👨‍👨‍👧‍👦を離れ離れにさせません。" //! .to_string(); //! -//! let preproc_input = pre_filters.iter().fold(input, |s, filter| filter.filter(s)); +//! let input = Rc::new(input); +//! let preproc_input = pre_filters.iter().fold(input, |s, filter| Rc::new(filter.filter(&s))); +//! let preproc_input = Rc::try_unwrap(preproc_input).unwrap(); //! //! let sentence = Sentence::from_raw(preproc_input).unwrap(); //! let sentence = predictor.predict(sentence); @@ -62,10 +65,7 @@ pub trait SentenceFilter { fn filter(&self, sentence: Sentence) -> Sentence; } -pub trait StringFilter -where - S: AsRef, -{ +pub trait StringFilter { /// Filter a specified string using rules. /// /// # Arguments: @@ -75,5 +75,5 @@ where /// # Returns /// /// A processed string. - fn filter(&self, string: S) -> String; + fn filter(&self, string: &str) -> String; } diff --git a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs index befead3b..3dc841fc 100644 --- a/vaporetto_rules/src/string_filters/kytea_fullwidth.rs +++ b/vaporetto_rules/src/string_filters/kytea_fullwidth.rs @@ -20,10 +20,7 @@ impl Default for KyteaFullwidthFilter { } } -impl StringFilter for KyteaFullwidthFilter -where - S: AsRef, -{ +impl StringFilter for KyteaFullwidthFilter { /// Replace alphanumerics and symbols to full-width characters. /// /// # Arguments: @@ -33,8 +30,8 @@ where /// # Returns /// /// A processed text. - fn filter(&self, string: S) -> String { - let mut chars: Vec<_> = string.as_ref().chars().collect(); + fn filter(&self, string: &str) -> String { + let mut chars: Vec<_> = string.chars().collect(); for c in &mut chars { *c = match *c { 'a' => 'a', From c749f33cd14e1445275a0e29914e77d46663f1f9 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 16 Dec 2021 11:57:01 +0900 Subject: [PATCH 26/60] Use self contained serializer instead of serde (#18) * Implement R/W * fix * Refactoring * fix * fix * fix --- vaporetto/Cargo.toml | 8 +- vaporetto/src/dict_model.rs | 163 +++++++++++++++++++++++++++++++++-- vaporetto/src/errors.rs | 4 +- vaporetto/src/model.rs | 25 ++++-- vaporetto/src/ngram_model.rs | 115 +++++++++++++++++++++++- 5 files changed, 293 insertions(+), 22 deletions(-) diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index b25d7528..d0be026b 100644 --- a/vaporetto/Cargo.toml +++ b/vaporetto/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vaporetto" version = "0.2.0" -edition = "2018" +edition = "2021" authors = ["Koichi Akabe "] description = "Vaporetto: a pointwise prediction based tokenizer" license = "MIT OR Apache-2.0" @@ -13,16 +13,14 @@ categories = ["text-processing"] autotests = false [dependencies] -bincode = "1.3.3" # MIT daachorse = "0.2.0" # MIT or Apache-2.0 -serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0 +byteorder = "1.4" # Unlicense or MIT -byteorder = { version = "1.4", optional = true } # Unlicense or MIT liblinear = { version = "1", optional = true } # MIT [features] default = [] -kytea = ["byteorder"] +kytea = [] train = ["liblinear"] simd = [] diff --git a/vaporetto/src/dict_model.rs b/vaporetto/src/dict_model.rs index 279a3cd2..39b08974 100644 --- a/vaporetto/src/dict_model.rs +++ b/vaporetto/src/dict_model.rs @@ -1,23 +1,51 @@ use 
std::collections::HashMap; +use std::io::{Read, Write}; +use std::mem; -use serde::{Deserialize, Serialize}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use crate::errors::{Result, VaporettoError}; use crate::ngram_model::NgramModel; -#[derive(Clone, Copy, Default, Serialize, Deserialize)] +#[derive(Clone, Copy, Default)] pub struct DictWeight { pub right: i32, pub inside: i32, pub left: i32, } -#[derive(Serialize, Deserialize)] +impl DictWeight { + pub fn serialize(&self, mut buf: W) -> Result + where + W: Write, + { + buf.write_i32::(self.right)?; + buf.write_i32::(self.inside)?; + buf.write_i32::(self.left)?; + Ok(mem::size_of::() * 3) + } + + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + Ok(Self { + right: buf.read_i32::()?, + inside: buf.read_i32::()?, + left: buf.read_i32::()?, + }) + } +} + pub enum DictModel { Wordwise(DictModelWordwise), Lengthwise(DictModelLengthwise), } impl DictModel { + const TYPE_ID_WORDWISE: u8 = 0; + const TYPE_ID_LENGTHWISE: u8 = 1; + pub fn merge_dict_weights( &mut self, char_ngram_model: &mut NgramModel, @@ -42,15 +70,74 @@ impl DictModel { Self::Lengthwise(model) => model.dump_dictionary(), } } + + pub fn serialize(&self, mut buf: W) -> Result + where + W: Write, + { + let size = match self { + Self::Wordwise(model) => { + buf.write_u8(Self::TYPE_ID_WORDWISE)?; + model.serialize(buf)? + } + Self::Lengthwise(model) => { + buf.write_u8(Self::TYPE_ID_LENGTHWISE)?; + model.serialize(buf)? + } + }; + Ok(mem::size_of::() + size) + } + + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let type_id = buf.read_u8()?; + match type_id { + Self::TYPE_ID_WORDWISE => Ok(Self::Wordwise(DictModelWordwise::deserialize(buf)?)), + Self::TYPE_ID_LENGTHWISE => { + Ok(Self::Lengthwise(DictModelLengthwise::deserialize(buf)?)) + } + _ => Err(VaporettoError::invalid_model( + "invalid type_id of dict_model", + )), + } + } } /// Record of weights for each word. -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone)] pub struct WordWeightRecord { pub(crate) word: String, pub(crate) weights: DictWeight, } +impl WordWeightRecord { + pub fn serialize(&self, mut buf: W) -> Result + where + W: Write, + { + let word_size = self.word.len(); + buf.write_u32::(word_size.try_into().unwrap())?; + buf.write_all(self.word.as_bytes())?; + let weights_size = self.weights.serialize(&mut buf)?; + Ok(mem::size_of::() + word_size + weights_size) + } + + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let word_size = buf.read_u32::()?; + let mut str_bytes = vec![0; word_size.try_into().unwrap()]; + buf.read_exact(&mut str_bytes)?; + Ok(Self { + word: String::from_utf8(str_bytes)?, + weights: DictWeight::deserialize(&mut buf)?, + }) + } +} + impl WordWeightRecord { /// Creates a new word weight record. 
/// @@ -96,7 +183,6 @@ impl WordWeightRecord { } } -#[derive(Serialize, Deserialize)] pub struct DictModelWordwise { pub(crate) dict: Vec, } @@ -144,9 +230,33 @@ impl DictModelWordwise { pub fn dump_dictionary(&self) -> Vec { self.dict.clone() } + + pub fn serialize(&self, mut buf: W) -> Result + where + W: Write, + { + let dict_size = self.dict.len(); + buf.write_u32::(dict_size.try_into().unwrap())?; + let mut total_size = mem::size_of::(); + for entry in &self.dict { + total_size += entry.serialize(&mut buf)?; + } + Ok(total_size) + } + + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let dict_size = buf.read_u32::()?; + let mut dict = Vec::with_capacity(dict_size.try_into().unwrap()); + for _ in 0..dict_size { + dict.push(WordWeightRecord::deserialize(&mut buf)?); + } + Ok(Self { dict }) + } } -#[derive(Serialize, Deserialize)] pub struct DictModelLengthwise { pub(crate) words: Vec, pub(crate) weights: Vec, @@ -203,4 +313,45 @@ impl DictModelLengthwise { } result } + + pub fn serialize(&self, mut buf: W) -> Result + where + W: Write, + { + let words_size = self.words.len(); + let weights_size = self.weights.len(); + buf.write_u32::(words_size.try_into().unwrap())?; + buf.write_u32::(weights_size.try_into().unwrap())?; + let mut total_size = mem::size_of::() * 2; + for word in &self.words { + let word_size = word.len(); + buf.write_u32::(word_size.try_into().unwrap())?; + buf.write_all(word.as_bytes())?; + total_size += mem::size_of::() + word_size; + } + for weight in &self.weights { + total_size += weight.serialize(&mut buf)?; + } + Ok(total_size) + } + + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let words_size = buf.read_u32::()?; + let weights_size = buf.read_u32::()?; + let mut words = Vec::with_capacity(words_size.try_into().unwrap()); + for _ in 0..words_size { + let word_size = buf.read_u32::()?; + let mut word_bytes = vec![0; word_size.try_into().unwrap()]; + buf.read_exact(&mut word_bytes)?; + words.push(String::from_utf8(word_bytes)?); + } + let mut weights = Vec::with_capacity(weights_size.try_into().unwrap()); + for _ in 0..weights_size { + weights.push(DictWeight::deserialize(&mut buf)?); + } + Ok(Self { words, weights }) + } } diff --git a/vaporetto/src/errors.rs b/vaporetto/src/errors.rs index 863da6cf..5597a8ed 100644 --- a/vaporetto/src/errors.rs +++ b/vaporetto/src/errors.rs @@ -3,6 +3,8 @@ use std::error::Error; use std::fmt; +pub type Result = std::result::Result; + #[derive(Debug)] pub enum VaporettoError { InvalidModel(InvalidModelError), @@ -52,8 +54,6 @@ impl fmt::Display for VaporettoError { impl Error for VaporettoError {} -pub type Result = std::result::Result; - /// Error used when the model is invalid. #[derive(Debug)] pub struct InvalidModelError { diff --git a/vaporetto/src/model.rs b/vaporetto/src/model.rs index f5a9a815..d9508855 100644 --- a/vaporetto/src/model.rs +++ b/vaporetto/src/model.rs @@ -1,8 +1,9 @@ use std::io::{Read, Write}; -use serde::{Deserialize, Serialize}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use crate::dict_model::{DictModel, DictModelWordwise, WordWeightRecord}; +use crate::errors::Result; use crate::ngram_model::NgramModel; #[cfg(feature = "train")] @@ -26,7 +27,6 @@ const EPSILON: f64 = 1e-6; const QUANTIZE_BIT_DEPTH: u8 = 16; /// Model data. 
-#[derive(Serialize, Deserialize)] pub struct Model { pub(crate) char_ngram_model: NgramModel, pub(crate) type_ngram_model: NgramModel>, @@ -46,11 +46,17 @@ impl Model { /// # Errors /// /// When `wtr` generates an error, it will be returned as is. - pub fn write(&self, wtr: &mut W) -> Result<(), bincode::Error> + pub fn write(&self, mut wtr: W) -> Result<()> where W: Write, { - bincode::serialize_into(wtr, self) + self.char_ngram_model.serialize(&mut wtr)?; + self.type_ngram_model.serialize(&mut wtr)?; + self.dict_model.serialize(&mut wtr)?; + wtr.write_i32::(self.bias)?; + wtr.write_u32::(self.char_window_size.try_into().unwrap())?; + wtr.write_u32::(self.type_window_size.try_into().unwrap())?; + Ok(()) } /// Creates a model from a reader. @@ -66,11 +72,18 @@ impl Model { /// # Errors /// /// When `rdr` generates an error, it will be returned as is. - pub fn read(rdr: &mut R) -> Result + pub fn read(mut rdr: R) -> Result where R: Read, { - bincode::deserialize_from(rdr) + Ok(Self { + char_ngram_model: NgramModel::::deserialize(&mut rdr)?, + type_ngram_model: NgramModel::>::deserialize(&mut rdr)?, + dict_model: DictModel::deserialize(&mut rdr)?, + bias: rdr.read_i32::()?, + char_window_size: rdr.read_u32::()?.try_into().unwrap(), + type_window_size: rdr.read_u32::()?.try_into().unwrap(), + }) } #[cfg(feature = "train")] diff --git a/vaporetto/src/ngram_model.rs b/vaporetto/src/ngram_model.rs index 28ce97e6..eceaead6 100644 --- a/vaporetto/src/ngram_model.rs +++ b/vaporetto/src/ngram_model.rs @@ -1,8 +1,12 @@ use std::collections::HashMap; +use std::io::{Read, Write}; +use std::mem; -use serde::{Deserialize, Serialize}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -#[derive(Clone, Serialize, Deserialize)] +use crate::errors::Result; + +#[derive(Clone)] pub struct NgramData where T: Clone, @@ -11,7 +15,62 @@ where pub(crate) weights: Vec, } -#[derive(Serialize, Deserialize)] +impl NgramData +where + T: AsRef<[u8]> + Clone, +{ + pub fn serialize(&self, mut buf: W) -> Result + where + W: Write, + { + let ngram = self.ngram.as_ref(); + let ngram_size = ngram.len(); + let weights_size = self.weights.len(); + buf.write_u32::(ngram_size.try_into().unwrap())?; + buf.write_u32::(weights_size.try_into().unwrap())?; + buf.write_all(ngram)?; + for &w in &self.weights { + buf.write_i32::(w)?; + } + Ok(mem::size_of::() * 2 + ngram_size + mem::size_of::() * weights_size) + } +} + +impl NgramData { + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let ngram_size = buf.read_u32::()?; + let weights_size = buf.read_u32::()?; + let mut ngram_bytes = vec![0; ngram_size.try_into().unwrap()]; + buf.read_exact(&mut ngram_bytes)?; + let ngram = String::from_utf8(ngram_bytes)?; + let mut weights = vec![]; + for _ in 0..weights_size { + weights.push(buf.read_i32::()?); + } + Ok(Self { ngram, weights }) + } +} + +impl NgramData> { + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let ngram_size = buf.read_u32::()?; + let weights_size = buf.read_u32::()?; + let mut ngram = vec![0; ngram_size.try_into().unwrap()]; + buf.read_exact(&mut ngram)?; + let mut weights = Vec::with_capacity(weights_size.try_into().unwrap()); + for _ in 0..weights_size { + weights.push(buf.read_i32::()?); + } + Ok(Self { ngram, weights }) + } +} + pub struct NgramModel where T: Clone, @@ -60,4 +119,54 @@ where *weights = new_weights.unwrap(); } } + + pub fn serialize(&self, mut buf: W) -> Result + where + W: Write, + { + let data_size = self.data.len(); + 
buf.write_u32::(data_size.try_into().unwrap())?; + let mut total_size = mem::size_of::(); + for d in &self.data { + total_size += d.serialize(&mut buf)?; + } + buf.write_u8(self.merged.into())?; + Ok(total_size + mem::size_of::()) + } +} + +impl NgramModel { + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let data_size = buf.read_u32::()?; + let mut data = Vec::with_capacity(data_size.try_into().unwrap()); + for _ in 0..data_size { + data.push(NgramData::::deserialize(&mut buf)?); + } + let merged_u8 = buf.read_u8()?; + Ok(Self { + data, + merged: merged_u8 != 0, + }) + } +} + +impl NgramModel> { + pub fn deserialize(mut buf: R) -> Result + where + R: Read, + { + let data_size = buf.read_u32::()?; + let mut data = Vec::with_capacity(data_size.try_into().unwrap()); + for _ in 0..data_size { + data.push(NgramData::>::deserialize(&mut buf)?); + } + let merged_u8 = buf.read_u8()?; + Ok(Self { + data, + merged: merged_u8 != 0, + }) + } } From c4e2d1949957fc8c93c2fce0c7436e865e126155 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 16 Dec 2021 18:19:46 +0900 Subject: [PATCH 27/60] Support comment in the dictionary (#19) * Support inserting comments to words in the dictionary * Update README * fmt --- README.md | 13 +++++++------ manipulate_model/src/main.rs | 3 +++ vaporetto/src/dict_model.rs | 31 +++++++++++++++++++++++++------ vaporetto/src/kytea_model.rs | 1 + vaporetto/src/predictor.rs | 3 +++ 5 files changed, 39 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 167ccfba..a32a3673 100644 --- a/README.md +++ b/README.md @@ -114,19 +114,20 @@ To concatenate `メロンパン` into a single token, manipulate the model in th 2. Edit the dictionary. - The dictionary is a csv file. Each row contains a word and corresponding weights in the following order: + The dictionary is a csv file. Each row contains a word, corresponding weights, and a comment in the following order: * `right_weight` - A weight that is added when the word is found to the right of the boundary. * `inside_weight` - A weight that is added when the word is overlapped on the boundary. * `left_weight` - A weight that is added when the word is found to the left of the boundary. + * `comment` - A comment that does not affect the behaviour. Vaporetto splits a text when the total weight of the boundary is a positive number, so we add a new entry as follows: ```diff - メロレオストーシス,6944,-2553,5319 - メロン,8924,-10861,7081 - +メロンパン,0,-100000,0 - メロン果実,4168,-1165,3558 - メロヴィング,6999,-15413,7583 + メロレオストーシス,6944,-2553,5319, + メロン,8924,-10861,7081, + +メロンパン,0,-100000,0,melon🍈 in English. + メロン果実,4168,-1165,3558, + メロヴィング,6999,-15413,7583, ``` In this case, `-100000` will be added when the boundary is inside of the word `メロンパン`. 
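The same entry can also be added programmatically instead of editing the CSV by hand. Below is a minimal sketch built on the `WordWeightRecord` constructor and `get_comment()` accessor that this patch extends with a `comment` field (see the `manipulate_model` and `dict_model` diffs that follow); the `use vaporetto::WordWeightRecord;` import assumes the type is re-exported at the crate root, as `manipulate_model` uses it:

```rust
use vaporetto::WordWeightRecord;

fn main() {
    // Arguments follow the CSV column order:
    // word, right_weight, inside_weight, left_weight, comment.
    let record = WordWeightRecord::new(
        "メロンパン".to_string(),
        0,       // right_weight
        -100000, // inside_weight: suppresses boundaries inside the word
        0,       // left_weight
        "melon🍈 in English.".to_string(), // comment: no effect on the behaviour
    );
    assert_eq!("melon🍈 in English.", record.get_comment());
}
```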
diff --git a/manipulate_model/src/main.rs b/manipulate_model/src/main.rs index db4e6e87..f074ca85 100644 --- a/manipulate_model/src/main.rs +++ b/manipulate_model/src/main.rs @@ -34,6 +34,7 @@ struct WordWeightRecordFlatten { right: i32, inside: i32, left: i32, + comment: String, } fn main() -> Result<(), Box> { @@ -53,6 +54,7 @@ fn main() -> Result<(), Box> { right: data.get_right_weight(), inside: data.get_inside_weight(), left: data.get_left_weight(), + comment: data.get_comment().to_string(), })?; } } @@ -69,6 +71,7 @@ fn main() -> Result<(), Box> { record.right, record.inside, record.left, + record.comment, )); } model.replace_dictionary(dict); diff --git a/vaporetto/src/dict_model.rs b/vaporetto/src/dict_model.rs index 39b08974..1affe071 100644 --- a/vaporetto/src/dict_model.rs +++ b/vaporetto/src/dict_model.rs @@ -110,6 +110,7 @@ impl DictModel { pub struct WordWeightRecord { pub(crate) word: String, pub(crate) weights: DictWeight, + pub(crate) comment: String, } impl WordWeightRecord { @@ -118,10 +119,13 @@ impl WordWeightRecord { W: Write, { let word_size = self.word.len(); + let comment_size = self.comment.len(); buf.write_u32::(word_size.try_into().unwrap())?; + buf.write_u32::(comment_size.try_into().unwrap())?; buf.write_all(self.word.as_bytes())?; + buf.write_all(self.comment.as_bytes())?; let weights_size = self.weights.serialize(&mut buf)?; - Ok(mem::size_of::() + word_size + weights_size) + Ok(mem::size_of::() * 2 + word_size + weights_size + comment_size) } pub fn deserialize(mut buf: R) -> Result @@ -129,11 +133,15 @@ impl WordWeightRecord { R: Read, { let word_size = buf.read_u32::()?; - let mut str_bytes = vec![0; word_size.try_into().unwrap()]; - buf.read_exact(&mut str_bytes)?; + let comment_size = buf.read_u32::()?; + let mut word_bytes = vec![0; word_size.try_into().unwrap()]; + buf.read_exact(&mut word_bytes)?; + let mut comment_bytes = vec![0; comment_size.try_into().unwrap()]; + buf.read_exact(&mut comment_bytes)?; Ok(Self { - word: String::from_utf8(str_bytes)?, + word: String::from_utf8(word_bytes)?, weights: DictWeight::deserialize(&mut buf)?, + comment: String::from_utf8(comment_bytes)?, }) } } @@ -147,11 +155,12 @@ impl WordWeightRecord { /// * `right` - A weight of the boundary when the word is found at right. /// * `inside` - A weight of the boundary when the word is overlapped on the boundary. /// * `left` - A weight of the boundary when the word is found at left. + /// * `comment` - A comment that does not affect the behaviour. /// /// # Returns /// /// A new record. - pub const fn new(word: String, right: i32, inside: i32, left: i32) -> Self { + pub const fn new(word: String, right: i32, inside: i32, left: i32, comment: String) -> Self { Self { word, weights: DictWeight { @@ -159,6 +168,7 @@ impl WordWeightRecord { inside, left, }, + comment, } } @@ -181,6 +191,11 @@ impl WordWeightRecord { pub const fn get_left_weight(&self) -> i32 { self.weights.left } + + /// Gets a reference to the comment. 
+ pub fn get_comment(&self) -> &str { + &self.comment + } } pub struct DictModelWordwise { @@ -309,7 +324,11 @@ impl DictModelLengthwise { let word_size = word.chars().count(); let word_size_idx = word_size.min(self.weights.len()) - 1; let weights = self.weights[word_size_idx]; - result.push(WordWeightRecord { word, weights }); + result.push(WordWeightRecord { + word, + weights, + comment: "".to_string(), + }); } result } diff --git a/vaporetto/src/kytea_model.rs b/vaporetto/src/kytea_model.rs index 984bd46c..e18575ac 100644 --- a/vaporetto/src/kytea_model.rs +++ b/vaporetto/src/kytea_model.rs @@ -448,6 +448,7 @@ impl TryFrom for Model { dict.push(WordWeightRecord { word: w.into_iter().collect(), weights, + comment: "".to_string(), }); } } diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index fe84b21e..89ea4971 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -436,6 +436,7 @@ mod tests { inside: 39, left: 40, }, + comment: "".to_string(), }, WordWeightRecord { word: "世界".to_string(), @@ -444,6 +445,7 @@ mod tests { inside: 42, left: 43, }, + comment: "".to_string(), }, WordWeightRecord { word: "世".to_string(), @@ -452,6 +454,7 @@ mod tests { inside: 45, left: 46, }, + comment: "".to_string(), }, ], }), From f899ae73bb644361f5bf3e2d27c2361ac08f4064 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 20 Dec 2021 11:12:30 +0900 Subject: [PATCH 28/60] Reimplement merge_weights() using a stack (#20) * Reimplement merging * fix --- vaporetto/src/ngram_model.rs | 43 +++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/vaporetto/src/ngram_model.rs b/vaporetto/src/ngram_model.rs index eceaead6..4fa30179 100644 --- a/vaporetto/src/ngram_model.rs +++ b/vaporetto/src/ngram_model.rs @@ -96,27 +96,40 @@ where return; } self.merged = true; - let ngrams = self + let mut check = vec![false; self.data.len()]; + let ngram_ids: HashMap<_, _> = self .data .iter() .cloned() - .map(|d| (d.ngram.as_ref().to_vec(), d.weights)) - .collect::>(); - for NgramData { ngram, weights } in &mut self.data { - let ngram = ngram.as_ref(); - let mut new_weights: Option> = None; - for st in (0..ngram.len()).rev() { - if let Some(weights) = ngrams.get(&ngram[st..]) { - if let Some(new_weights) = new_weights.as_mut() { - for (w_new, w) in new_weights.iter_mut().zip(weights) { - *w_new += *w; - } - } else { - new_weights.replace(weights.clone()); + .enumerate() + .map(|(i, d)| (d.ngram.as_ref().to_vec(), i)) + .collect(); + let mut stack = vec![]; + for i in 0..self.data.len() { + if check[i] { + continue; + } + stack.push(i); + let ngram = self.data[i].ngram.as_ref(); + for j in 1..ngram.len() { + if let Some(&k) = ngram_ids.get(&ngram[j..]) { + stack.push(k); + if check[k] { + break; } } } - *weights = new_weights.unwrap(); + let mut idx_from = stack.pop().unwrap(); + check[idx_from] = true; + while let Some(idx_to) = stack.pop() { + let mut new_weights = self.data[idx_from].weights.clone(); + for (w1, w2) in new_weights.iter_mut().zip(&self.data[idx_to].weights) { + *w1 += w2; + } + self.data[idx_to].weights = new_weights; + idx_from = idx_to; + check[idx_to] = true; + } } } From 9c43a4e77cc1adaa6b67f2f1d517c2e255ddca21 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 20 Dec 2021 12:45:53 +0900 Subject: [PATCH 29/60] Add --metric option to `evaluate` command (#21) * Add --metric option * format --- evaluate/src/main.rs | 113 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 95 insertions(+), 18 deletions(-) diff 
--git a/evaluate/src/main.rs b/evaluate/src/main.rs index 3e96ecf0..da279869 100644
--- a/evaluate/src/main.rs
+++ b/evaluate/src/main.rs
@@ -33,6 +33,23 @@ impl FromStr for WsConst {
 }
 }
+#[derive(Debug)]
+enum EvaluationMetric {
+ CharBoundaryAccuracy,
+ WordAccuracy,
+}
+
+impl FromStr for EvaluationMetric {
+ type Err = &'static str;
+ fn from_str(metric: &str) -> Result<Self, Self::Err> {
+ match metric {
+ "char" => Ok(Self::CharBoundaryAccuracy),
+ "word" => Ok(Self::WordAccuracy),
+ _ => Err("Could not parse a metric value"),
+ }
+ }
+}
+
 #[derive(StructOpt, Debug)]
 #[structopt(
 name = "evaluate",
@@ -51,6 +68,12 @@ struct Opt {
 /// Do not normalize input strings before prediction.
 #[structopt(long)]
 no_norm: bool,
+
+ /// Evaluation metric: {char, word}.
+ /// char: evaluates each character boundary.
+ /// word: evaluates each word using Nagata's method.
+ #[structopt(long, default_value = "char")]
+ metric: EvaluationMetric,
 }
 fn main() -> Result<(), Box> {
@@ -75,10 +98,8 @@ fn main() -> Result<(), Box> {
 let predictor = Predictor::new(model)?;
 eprintln!("Start tokenization");
- let mut n_true_positive = 0;
- let mut n_false_positive = 0;
- let mut n_false_negative = 0;
+ let mut results = vec![];
 for line in stdin().lock().lines() {
 let s = Sentence::from_tokenized(line?)?;
 let s = if opt.no_norm {
@@ -92,25 +113,81 @@ fn main() -> Result<(), Box> {
 let reference = s.boundaries().to_vec();
 let s = predictor.predict(s);
 let s = post_filters.iter().fold(s, |s, filter| filter.filter(s));
- for (&r, &h) in reference.iter().zip(s.boundaries()) {
- if r == h {
- if h == BoundaryType::WordBoundary {
- n_true_positive += 1;
+ results.push((reference, s.boundaries().to_vec()));
+ }
+
+ match opt.metric {
+ EvaluationMetric::CharBoundaryAccuracy => {
+ let mut n_tp = 0;
+ let mut n_tn = 0;
+ let mut n_fp = 0;
+ let mut n_fn = 0;
+ for (rs, hs) in results {
+ for (r, h) in rs.into_iter().zip(hs) {
+ if r == h {
+ if h == BoundaryType::WordBoundary {
+ n_tp += 1;
+ } else {
+ n_tn += 1;
+ }
+ } else if h == BoundaryType::WordBoundary {
+ n_fp += 1;
+ } else {
+ n_fn += 1;
+ }
+ }
+ }
+ let precision = n_tp as f64 / (n_tp + n_fp) as f64;
+ let recall = n_tp as f64 / (n_tp + n_fn) as f64;
+ let f1 = 2. * precision * recall / (precision + recall);
+ println!("Precision: {}", precision);
+ println!("Recall: {}", recall);
+ println!("F1: {}", f1);
+ println!("TP: {}, TN: {}, FP: {}, FN: {}", n_tp, n_tn, n_fp, n_fn);
+ }
+ EvaluationMetric::WordAccuracy => {
+ // Reference:
+ // Masaaki Nagata. 1994. A stochastic Japanese morphological analyzer using a forward-DP
+ // backward-A* n-best search algorithm. In COLING 1994 Volume 1: The 15th International
+ // Conference on Computational Linguistics.
+ let mut n_sys = 0;
+ let mut n_ref = 0;
+ let mut n_cor = 0;
+ let mut matched = true;
+ for (rs, hs) in results {
+ for (r, h) in rs.into_iter().zip(hs) {
+ if r == h {
+ if h == BoundaryType::WordBoundary {
+ if matched {
+ n_cor += 1;
+ }
+ matched = true;
+ n_ref += 1;
+ n_sys += 1;
+ }
+ } else {
+ if h == BoundaryType::WordBoundary {
+ n_sys += 1;
+ } else {
+ n_ref += 1;
+ }
+ matched = false;
+ }
 }
- } else if h == BoundaryType::WordBoundary {
- n_false_positive += 1;
- } else {
- n_false_negative += 1;
 }
+ if matched {
+ n_cor += 1;
+ }
+ n_sys += 1;
+ n_ref += 1;
+ let precision = n_cor as f64 / n_sys as f64;
+ let recall = n_cor as f64 / n_ref as f64;
+ let f1 = 2.
* precision * recall / (precision + recall); + println!("Precision: {}", precision); + println!("Recall: {}", recall); + println!("F1: {}", f1); } } - let precision = n_true_positive as f64 / (n_true_positive + n_false_positive) as f64; - let recall = n_true_positive as f64 / (n_true_positive + n_false_negative) as f64; - let f1 = 2. * precision * recall / (precision + recall); - println!("Precision: {}", precision); - println!("Recall: {}", recall); - println!("F1: {}", f1); - Ok(()) } From e067b7d7a70053b39686a79a69ce248d0167676e Mon Sep 17 00:00:00 2001 From: Shunsuke Kanda Date: Tue, 21 Dec 2021 13:36:30 +0900 Subject: [PATCH 30/60] handle empty line (#22) --- evaluate/src/main.rs | 6 +++++- predict/src/main.rs | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/evaluate/src/main.rs b/evaluate/src/main.rs index da279869..81f7d690 100644 --- a/evaluate/src/main.rs +++ b/evaluate/src/main.rs @@ -101,7 +101,11 @@ fn main() -> Result<(), Box> { let mut results = vec![]; for line in stdin().lock().lines() { - let s = Sentence::from_tokenized(line?)?; + let line = line?; + if line.is_empty() { + continue; + } + let s = Sentence::from_tokenized(line)?; let s = if opt.no_norm { s } else { diff --git a/predict/src/main.rs b/predict/src/main.rs index 7d96a42e..6edc29a7 100644 --- a/predict/src/main.rs +++ b/predict/src/main.rs @@ -129,7 +129,12 @@ fn main() -> Result<(), Box> { let mut buf1 = Sentence::from_raw(" ")?; let mut buf2 = Sentence::from_raw(" ")?; for line in stdin().lock().lines() { - let ret = tokenize(&predictor, line?, buf1, buf2, &pre_filters, &post_filters)?; + let line = line?; + if line.is_empty() { + println!(); + continue; + } + let ret = tokenize(&predictor, line, buf1, buf2, &pre_filters, &post_filters)?; let result = ret.0; buf1 = ret.1; buf2 = ret.2; From 1718315062b77d878af6007616f393e134637d21 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Tue, 21 Dec 2021 18:23:40 +0900 Subject: [PATCH 31/60] Fix error message (#23) * Fix error messages * Update README --- README.md | 3 +++ vaporetto/src/sentence.rs | 54 ++++++++++++++++++++++----------------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index a32a3673..d10e2426 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,9 @@ To train a model, use the following command: You can also specify a word dictionary with `--dict` argument. A word dictionary is a file with words per line. +The trainer does not accept empty lines. +Therefore, remove all empty lines from the corpus before training. + You can specify all arguments above multiple times. 
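Since the trainer rejects empty lines, a corpus may need a cleanup pass before training. Below is a minimal sketch of such a filter, mirroring the `line.is_empty()` checks that this series adds to `evaluate` and `predict`; the standalone program itself is illustrative and not part of the repository:

```rust
use std::io::{stdin, stdout, BufRead, Write};

fn main() -> std::io::Result<()> {
    let stdout = stdout();
    let mut out = stdout.lock();
    for line in stdin().lock().lines() {
        let line = line?;
        // The trainer does not accept empty lines, so drop them here.
        if line.is_empty() {
            continue;
        }
        writeln!(out, "{}", line)?;
    }
    Ok(())
}
```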
### Model Manipulation diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index 8fd2c800..837ea50f 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -126,7 +126,10 @@ impl Sentence { boundaries: &mut Vec, ) -> Result<()> { if raw_text.is_empty() { - return Err(VaporettoError::invalid_argument("raw_text", "is empty")); + return Err(VaporettoError::invalid_argument( + "raw_text", + "must contain at least one character", + )); } chars.clear(); @@ -149,7 +152,7 @@ impl Sentence { if tokenized_text.is_empty() { return Err(VaporettoError::invalid_argument( "tokenized_text", - "is empty", + "must contain at least one character", )); } @@ -169,12 +172,12 @@ impl Sentence { if chars.is_empty() { return Err(VaporettoError::invalid_argument( "tokenized_text", - "starts with a whitespace", + "must not start with a whitespace", )); } else if prev_boundary { return Err(VaporettoError::invalid_argument( "tokenized_text", - "contains consecutive whitespaces", + "must not contain consecutive whitespaces", )); } prev_boundary = true; @@ -198,7 +201,7 @@ impl Sentence { if prev_boundary { return Err(VaporettoError::invalid_argument( "tokenized_text", - "ends with a whitespace", + "must not end with a whitespace", )); } @@ -212,14 +215,17 @@ impl Sentence { boundaries: &mut Vec, ) -> Result<()> { if labeled_text.is_empty() { - return Err(VaporettoError::invalid_argument("labeled_text", "is empty")); + return Err(VaporettoError::invalid_argument( + "labeled_text", + "must contain at least one character", + )); } let labeled_chars: Vec = labeled_text.chars().collect(); if labeled_chars.len() % 2 == 0 { return Err(VaporettoError::invalid_argument( "labeled_text", - format!("invalid length: {}", labeled_chars.len()), + "must contain odd number of characters", )); } @@ -236,7 +242,7 @@ impl Sentence { _ => { return Err(VaporettoError::invalid_argument( "labeled_text", - format!("contains invalid boundary character: '{}'", c), + format!("contains an invalid boundary character: '{}'", c), )) } }); @@ -804,7 +810,7 @@ mod tests { let s = Sentence::from_raw(""); assert_eq!( - "InvalidArgumentError: raw_text: is empty", + "InvalidArgumentError: raw_text: must contain at least one character", &s.err().unwrap().to_string() ); } @@ -815,7 +821,7 @@ mod tests { let result = s.update_raw(""); assert_eq!( - "InvalidArgumentError: raw_text: is empty", + "InvalidArgumentError: raw_text: must contain at least one character", &result.err().unwrap().to_string() ); @@ -934,7 +940,7 @@ mod tests { let s = Sentence::from_tokenized(""); assert_eq!( - "InvalidArgumentError: tokenized_text: is empty", + "InvalidArgumentError: tokenized_text: must contain at least one character", &s.err().unwrap().to_string() ); } @@ -945,7 +951,7 @@ mod tests { let result = s.update_tokenized(""); assert_eq!( - "InvalidArgumentError: tokenized_text: is empty", + "InvalidArgumentError: tokenized_text: must contain at least one character", &result.err().unwrap().to_string() ); @@ -966,7 +972,7 @@ mod tests { let s = Sentence::from_tokenized(" Rust で 良い プログラミング 体験 を !"); assert_eq!( - "InvalidArgumentError: tokenized_text: starts with a whitespace", + "InvalidArgumentError: tokenized_text: must not start with a whitespace", &s.err().unwrap().to_string() ); } @@ -977,7 +983,7 @@ mod tests { let result = s.update_tokenized(" Rust で 良い プログラミング 体験 を !"); assert_eq!( - "InvalidArgumentError: tokenized_text: starts with a whitespace", + "InvalidArgumentError: tokenized_text: must not start with a whitespace", 
&result.err().unwrap().to_string() ); @@ -998,7 +1004,7 @@ mod tests { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を ! "); assert_eq!( - "InvalidArgumentError: tokenized_text: ends with a whitespace", + "InvalidArgumentError: tokenized_text: must not end with a whitespace", &s.err().unwrap().to_string() ); } @@ -1009,7 +1015,7 @@ mod tests { let result = s.update_tokenized("Rust で 良い プログラミング 体験 を ! "); assert_eq!( - "InvalidArgumentError: tokenized_text: ends with a whitespace", + "InvalidArgumentError: tokenized_text: must not end with a whitespace", &result.err().unwrap().to_string() ); @@ -1030,7 +1036,7 @@ mod tests { let s = Sentence::from_tokenized("Rust で 良い プログラミング 体験 を !"); assert_eq!( - "InvalidArgumentError: tokenized_text: contains consecutive whitespaces", + "InvalidArgumentError: tokenized_text: must not contain consecutive whitespaces", &s.err().unwrap().to_string() ); } @@ -1041,7 +1047,7 @@ mod tests { let result = s.update_tokenized("Rust で 良い プログラミング 体験 を !"); assert_eq!( - "InvalidArgumentError: tokenized_text: contains consecutive whitespaces", + "InvalidArgumentError: tokenized_text: must not contain consecutive whitespaces", &result.err().unwrap().to_string() ); @@ -1386,7 +1392,7 @@ mod tests { let s = Sentence::from_partial_annotation(""); assert_eq!( - "InvalidArgumentError: labeled_text: is empty", + "InvalidArgumentError: labeled_text: must contain at least one character", &s.err().unwrap().to_string() ); } @@ -1397,7 +1403,7 @@ mod tests { let result = s.update_partial_annotation(""); assert_eq!( - "InvalidArgumentError: labeled_text: is empty", + "InvalidArgumentError: labeled_text: must contain at least one character", &result.err().unwrap().to_string() ); } @@ -1407,7 +1413,7 @@ mod tests { let result = Sentence::from_partial_annotation("火-星 猫|の|生-態 "); assert_eq!( - "InvalidArgumentError: labeled_text: invalid length: 12", + "InvalidArgumentError: labeled_text: must contain odd number of characters", &result.err().unwrap().to_string() ); } @@ -1418,7 +1424,7 @@ mod tests { let result = s.update_partial_annotation("火-星 猫|の|生-態 "); assert_eq!( - "InvalidArgumentError: labeled_text: invalid length: 12", + "InvalidArgumentError: labeled_text: must contain odd number of characters", &result.err().unwrap().to_string() ); } @@ -1428,7 +1434,7 @@ mod tests { let s = Sentence::from_partial_annotation("火-星?猫|の|生-態"); assert_eq!( - "InvalidArgumentError: labeled_text: contains invalid boundary character: '?'", + "InvalidArgumentError: labeled_text: contains an invalid boundary character: '?'", &s.err().unwrap().to_string() ); } @@ -1439,7 +1445,7 @@ mod tests { let result = s.update_partial_annotation("火-星?猫|の|生-態"); assert_eq!( - "InvalidArgumentError: labeled_text: contains invalid boundary character: '?'", + "InvalidArgumentError: labeled_text: contains an invalid boundary character: '?'", &result.err().unwrap().to_string() ); } From 006a904cbbdb86643b665a8f423b90b7406b2771 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Wed, 22 Dec 2021 15:48:13 +0900 Subject: [PATCH 32/60] Portable simd feature (#24) * Add portable-simd feature * Update README * fmt * fix * Update README.md * Enable simd when portable-simd is specified * Fix README * fix --- evaluate/Cargo.toml | 2 +- predict/Cargo.toml | 2 +- vaporetto/Cargo.toml | 1 + vaporetto/README.md | 4 +++- vaporetto/src/char_scorer.rs | 24 ++++++++++++++++++++---- vaporetto/src/lib.rs | 5 ++++- vaporetto/src/predictor.rs | 5 +++-- vaporetto/src/trainer.rs | 6 +----- 8 files changed, 34 insertions(+), 15 
deletions(-)

diff --git a/evaluate/Cargo.toml b/evaluate/Cargo.toml index a4d7b1eb..a05e29f4 100644
--- a/evaluate/Cargo.toml
+++ b/evaluate/Cargo.toml
@@ -5,6 +5,6 @@ edition = "2018"
 [dependencies]
 structopt = "0.3" # MIT or Apache-2.0
-vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0
+vaporetto = { path = "../vaporetto", features = ["simd"] } # MIT or Apache-2.0
 vaporetto_rules = { path = "../vaporetto_rules" } # MIT or Apache-2.0
 zstd = "0.9" # MIT
diff --git a/predict/Cargo.toml b/predict/Cargo.toml index 66040c0b..5817a39f 100644
--- a/predict/Cargo.toml
+++ b/predict/Cargo.toml
@@ -5,6 +5,6 @@
 [dependencies]
 structopt = "0.3" # MIT or Apache-2.0
-vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0
+vaporetto = { path = "../vaporetto", features = ["simd"] } # MIT or Apache-2.0
 vaporetto_rules = { path = "../vaporetto_rules" } # MIT or Apache-2.0
 zstd = "0.9" # MIT
diff --git a/vaporetto/Cargo.toml b/vaporetto/Cargo.toml index d0be026b..69f57f22 100644
--- a/vaporetto/Cargo.toml
+++ b/vaporetto/Cargo.toml
@@ -23,6 +23,7 @@ default = []
 kytea = []
 train = ["liblinear"]
 simd = []
+portable-simd = ["simd"]
 [package.metadata.docs.rs]
 all-features = true
diff --git a/vaporetto/README.md b/vaporetto/README.md index 6b774112..4d0038e3 100644
--- a/vaporetto/README.md
+++ b/vaporetto/README.md
@@ -25,7 +25,9 @@ println!("{:?}", s.to_tokenized_vec().unwrap());
 * `kytea` - Enables the reader for models generated by KyTea.
 * `train` - Enables the trainer.
-* `simd` - Use the SIMD operations for prediction. (Nightly version of Rust is required.)
+* `simd` - Uses a SIMD-conscious data layout, expecting that your compiler enables SIMD optimization.
+* `portable-simd` - Uses the [portable SIMD API](https://github.com/rust-lang/portable-simd) instead
+ of our SIMD-conscious data layout. (Nightly Rust is required.)
 ## License
diff --git a/vaporetto/src/char_scorer.rs b/vaporetto/src/char_scorer.rs index 4cf82351..27e14b55 100644
--- a/vaporetto/src/char_scorer.rs
+++ b/vaporetto/src/char_scorer.rs
@@ -4,7 +4,7 @@ use crate::errors::{Result, VaporettoError};
 use crate::ngram_model::NgramModel;
 use crate::sentence::Sentence;
-#[cfg(feature = "simd")]
+#[cfg(all(feature = "simd", feature = "portable-simd"))]
 use std::simd::i32x8;
 pub enum CharScorer {
@@ -89,7 +89,12 @@ impl CharScorerNaive {
 #[cfg(feature = "simd")]
 pub struct CharScorerSimd {
 pma: DoubleArrayAhoCorasick,
+
+ #[cfg(feature = "portable-simd")]
 weights: Vec<i32x8>,
+ #[cfg(not(feature = "portable-simd"))]
+ weights: Vec<[i32; 8]>,
+
 window_size: usize,
 }
@@ -109,7 +114,10 @@ impl CharScorerSimd {
 "invalid size of weight vector",
 ));
 }
+ #[cfg(feature = "portable-simd")]
 weights.push(i32x8::from_array(s));
+ #[cfg(not(feature = "portable-simd"))]
+ weights.push(s);
 }
 Ok(Self {
 pma,
@@ -126,9 +134,17 @@ impl CharScorerSimd {
 // Therefore, the following code is safe.
let weights = unsafe { self.weights.get_unchecked(m.pattern()) }; let ys_slice = &mut ys[offset as usize..offset as usize + 8]; - let mut target = i32x8::from_slice(ys_slice); - target += weights; - ys_slice.copy_from_slice(target.as_array()); + + #[cfg(feature = "portable-simd")] + { + let mut target = i32x8::from_slice(ys_slice); + target += weights; + ys_slice.copy_from_slice(target.as_array()); + } + #[cfg(not(feature = "portable-simd"))] + for (y, w) in ys_slice.iter_mut().zip(weights) { + *y += w; + } } } diff --git a/vaporetto/src/lib.rs b/vaporetto/src/lib.rs index c1214fe7..db0e803a 100644 --- a/vaporetto/src/lib.rs +++ b/vaporetto/src/lib.rs @@ -1,5 +1,8 @@ #![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr( + all(feature = "simd", feature = "portable-simd"), + feature(portable_simd) +)] //! # Vaporetto //! diff --git a/vaporetto/src/predictor.rs b/vaporetto/src/predictor.rs index 89ea4971..bf43db50 100644 --- a/vaporetto/src/predictor.rs +++ b/vaporetto/src/predictor.rs @@ -99,7 +99,7 @@ impl Predictor { let mut ys = vec![0; ys_size]; self.predict_impl(&sentence, self.padding, &mut ys); for (&y, b) in ys[self.padding..] - .into_iter() + .iter() .zip(sentence.boundaries.iter_mut()) { *b = if y >= 0 { @@ -148,8 +148,9 @@ impl Predictor { .boundary_scores .take() .unwrap_or_else(|| vec![0; boundaries_size]); + scores.resize(boundaries_size, 0); for (&y, (b, s)) in ys[self.padding..] - .into_iter() + .iter() .zip(sentence.boundaries.iter_mut().zip(scores.iter_mut())) { *b = if y >= 0 { diff --git a/vaporetto/src/trainer.rs b/vaporetto/src/trainer.rs index c67d1cf2..055df7b4 100644 --- a/vaporetto/src/trainer.rs +++ b/vaporetto/src/trainer.rs @@ -147,11 +147,7 @@ impl<'a> Dataset<'a> { let mut feature_ids = BTreeMap::new(); for f in example.features { let fid = self.fid_manager.get_id(f) + 1; - if let Some(v) = feature_ids.get_mut(&fid) { - *v += 1.0; - } else { - feature_ids.insert(fid, 1.0); - } + *feature_ids.entry(fid).or_insert(0.0) += 1.0; } self.xs.push(feature_ids.into_iter().collect()); self.ys.push(example.label as u8 as f64); From 3430df841a2c1c26b5a019077636f18f186858e8 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Thu, 23 Dec 2021 12:15:17 +0900 Subject: [PATCH 33/60] Reimplement the portable JS builder in Python (#25) * Use Python instead of Bash * Add py * fix * fix * fix * fix * Update vaporetto_wasm/README.md Co-authored-by: Shunsuke Kanda Co-authored-by: Shunsuke Kanda --- vaporetto_wasm/README.md | 9 ++--- vaporetto_wasm/build_portable_js.py | 55 +++++++++++++++++++++++++++++ vaporetto_wasm/build_portable_js.sh | 16 --------- vaporetto_wasm/src/lib.rs | 4 +-- vaporetto_wasm/www/index.js | 2 +- 5 files changed, 63 insertions(+), 23 deletions(-) create mode 100755 vaporetto_wasm/build_portable_js.py delete mode 100755 vaporetto_wasm/build_portable_js.sh diff --git a/vaporetto_wasm/README.md b/vaporetto_wasm/README.md index 7507e68d..9d7f4fde 100644 --- a/vaporetto_wasm/README.md +++ b/vaporetto_wasm/README.md @@ -2,14 +2,14 @@ ## How to build? -1. Build a model file refering the [documentation](../README.md). +1. Build a model file following the [documentation](../README.md). -2. Build a JS file containing a web assembly using `build_portable_js.sh`. +2. Build a JS file containing a web assembly using `build_portable_js.py`. This script requires a model file, an identifier, and an output path. - + The identifier must consist of alphanumeric characters and underscores. 
```
- ./build_portable_js.sh <model> <identifier> <output>
+ ./build_portable_js.py --model <model> --identifier <identifier> --output <output>
```
3. You can use the generated JS file like the following code:
+