Skip to content

Commit

Permalink
Construct weights of character n-grams with dictionary features (#8)
Browse files Browse the repository at this point in the history
* wip

* Fix

* wip

* wip

* Use struct instead of slice for dict weights

* Remove unnecessary pub(crate)

* Fix for not model-quantize

* Fix format

* Use derive(Default)
  • Loading branch information
vbkaisetsu authored Nov 1, 2021
1 parent 514564b commit 9c8318e
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 37 deletions.
10 changes: 5 additions & 5 deletions vaporetto/src/kytea_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::io::BufRead;
use anyhow::{anyhow, Result};
use byteorder::{LittleEndian, ReadBytesExt};

use crate::model::Model;
use crate::model::{DictWeight, Model};

struct KyteaConfig {
_model_tag: String,
Expand Down Expand Up @@ -430,13 +430,13 @@ impl TryFrom<KyteaModel> for Model {
if let Some(kytea_dict) = model.dict {
for (w, data) in kytea_dict.dump_items() {
let word_len = std::cmp::min(w.len(), config.dict_n as usize) - 1;
let mut weights = [0i32; 3];
let mut weights = DictWeight::default();
for j in 0..kytea_dict.n_dicts as usize {
if data.in_dict >> j & 1 == 1 {
let offset = 3 * config.dict_n as usize * j + 3 * word_len;
weights[0] += feature_lookup.dict_vec[offset] as i32;
weights[1] += feature_lookup.dict_vec[offset + 1] as i32;
weights[2] += feature_lookup.dict_vec[offset + 2] as i32;
weights.right += feature_lookup.dict_vec[offset] as i32;
weights.inner += feature_lookup.dict_vec[offset + 1] as i32;
weights.left += feature_lookup.dict_vec[offset + 2] as i32;
}
}
dict_weights.push(weights);
Expand Down
22 changes: 15 additions & 7 deletions vaporetto/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ pub type ScoreValue = f64;
#[cfg(feature = "model-quantize")]
pub type ScoreValue = i32;

#[derive(Clone, Copy, Default, Serialize, Deserialize)]
pub struct DictWeight {
pub right: ScoreValue,
pub inner: ScoreValue,
pub left: ScoreValue,
}

/// Model data.
#[derive(Serialize, Deserialize)]
pub struct Model {
Expand All @@ -32,7 +39,7 @@ pub struct Model {

pub(crate) word_weights: Vec<Vec<WeightValue>>,
pub(crate) type_weights: Vec<Vec<WeightValue>>,
pub(crate) dict_weights: Vec<[ScoreValue; 3]>,
pub(crate) dict_weights: Vec<DictWeight>,

#[cfg(feature = "model-quantize")]
pub(crate) quantize_multiplier: f64,
Expand Down Expand Up @@ -102,9 +109,7 @@ impl Model {
let mut types = vec![];
let mut word_weights = vec![];
let mut type_weights = vec![];
let mut dict_weights: Vec<[_; 3]> = (0..dict_word_max_size)
.map(|_| [ScoreValue::default(); 3])
.collect();
let mut dict_weights = vec![DictWeight::default(); dict_word_max_size];
let mut word_ids = StringIdManager::new();
let mut type_ids = StringIdManager::new();

Expand Down Expand Up @@ -155,9 +160,12 @@ impl Model {
}
type_weights[id][feature.rel_position] = weight as WeightValue;
}
FeatureContent::DictionaryWord(size) => {
dict_weights[size - 1][feature.rel_position] = weight as ScoreValue;
}
FeatureContent::DictionaryWord(size) => match feature.rel_position {
0 => dict_weights[size - 1].right = weight as ScoreValue,
1 => dict_weights[size - 1].inner = weight as ScoreValue,
2 => dict_weights[size - 1].left = weight as ScoreValue,
_ => panic!("Invalid rel_position"),
},
};
}
Self {
Expand Down
210 changes: 185 additions & 25 deletions vaporetto/src/predictor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::thread;
#[cfg(feature = "multithreading")]
use crossbeam_channel::{Receiver, Sender};

use crate::model::{Model, ScoreValue, WeightValue};
use crate::model::{DictWeight, Model, ScoreValue};
use crate::sentence::{BoundaryType, Sentence};
use crate::type_scorer::TypeScorer;

Expand All @@ -22,7 +22,7 @@ pub struct Predictor {
word_pma: DoubleArrayAhoCorasick,
dict_pma: DoubleArrayAhoCorasick,
word_weights: Vec<Vec<ScoreValue>>,
dict_weights: Vec<[ScoreValue; 3]>,
dict_weights: Vec<DictWeight>,
dict_word_wise: bool,
bias: ScoreValue,
char_window_size: usize,
Expand All @@ -46,21 +46,40 @@ impl Predictor {
/// A new predictor.
pub fn new(model: Model) -> Self {
let bias = model.bias;
let word_weights = Self::merge_weights(&model.words, &model.word_weights);
let type_weights = Self::merge_weights(&model.types, &model.type_weights);

let words = model.words;
let dict = model.dict;
let dict_weights = model.dict_weights;

let mut word_weights: Vec<_> = model
.word_weights
.into_iter()
.map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect())
.collect();
let type_weights: Vec<_> = model
.type_weights
.into_iter()
.map(|ws| ws.into_iter().map(|w| w as ScoreValue).collect())
.collect();

let (dict, dict_weights) = Self::merge_dict_weights(
dict,
dict_weights,
&words,
&mut word_weights,
model.char_window_size,
model.dict_word_wise,
);

let word_weights = Self::merge_weights(&words, &word_weights);
let type_weights = Self::merge_weights(&model.types, &type_weights);

#[cfg(feature = "model-quantize")]
let bias = bias as i32;
#[cfg(feature = "model-quantize")]
let dict_weights = dict_weights
.iter()
.map(|ws| [ws[0] as i32, ws[1] as i32, ws[2] as i32])
.collect();

let word_pma = DoubleArrayAhoCorasick::new(model.words).unwrap();
let word_pma = DoubleArrayAhoCorasick::new(words).unwrap();
let type_pma = DoubleArrayAhoCorasick::new(model.types).unwrap();
let dict_pma = DoubleArrayAhoCorasick::new(model.dict).unwrap();
let dict_pma = DoubleArrayAhoCorasick::new(dict).unwrap();

let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size);

Expand All @@ -81,7 +100,63 @@ impl Predictor {
}
}

fn merge_weights(words: &[Vec<u8>], weights: &[Vec<WeightValue>]) -> Vec<Vec<ScoreValue>> {
fn merge_dict_weights(
dict: Vec<Vec<u8>>,
dict_weights: Vec<DictWeight>,
words: &[Vec<u8>],
word_weights: &mut Vec<Vec<ScoreValue>>,
char_window_size: usize,
dict_word_wise: bool,
) -> (Vec<Vec<u8>>, Vec<DictWeight>) {
let mut word_map = HashMap::new();
for (i, word) in words.iter().cloned().enumerate() {
word_map.insert(word, i);
}
let mut new_dict = vec![];
if dict_word_wise {
let mut new_dict_weights = vec![];
for (word, weight) in dict.into_iter().zip(dict_weights) {
let word_size = std::str::from_utf8(&word).unwrap().chars().count();
match word_map.get(&word) {
Some(&idx) if char_window_size >= word_size => {
let start = char_window_size - word_size;
let end = start + word_size;
word_weights[idx][start] += weight.right;
for i in start + 1..end {
word_weights[idx][i] += weight.inner;
}
word_weights[idx][end] += weight.left;
}
_ => {
new_dict.push(word);
new_dict_weights.push(weight);
}
}
}
(new_dict, new_dict_weights)
} else {
for word in dict {
let word_size = std::str::from_utf8(&word).unwrap().chars().count();
match word_map.get(&word) {
Some(&idx) if char_window_size >= word_size => {
let start = char_window_size - word_size;
let end = start + word_size;
let word_size_idx = std::cmp::min(word_size, dict_weights.len()) - 1;
let weight = &dict_weights[word_size_idx];
word_weights[idx][start] += weight.right;
for i in start + 1..end {
word_weights[idx][i] += weight.inner;
}
word_weights[idx][end] += weight.left;
}
_ => new_dict.push(word),
}
}
(new_dict, dict_weights)
}
}

fn merge_weights(words: &[Vec<u8>], weights: &[Vec<ScoreValue>]) -> Vec<Vec<ScoreValue>> {
let mut result = vec![];
let word_ids = words
.iter()
Expand All @@ -95,11 +170,10 @@ impl Predictor {
if let Some(&idx) = word_ids.get(&seq[st..]) {
if let Some(new_weights) = new_weights.as_mut() {
for (w_new, w) in new_weights.iter_mut().zip(&weights[idx]) {
*w_new += *w as ScoreValue;
*w_new += *w;
}
} else {
new_weights
.replace(weights[idx].iter().map(|&w| w as ScoreValue).collect());
new_weights.replace(weights[idx].clone());
}
}
}
Expand Down Expand Up @@ -160,19 +234,19 @@ impl Predictor {
} else {
std::cmp::min(m_end - m_start, self.dict_weights.len()) - 1
};
let [w_right, w_center, w_left] = self.dict_weights[idx];
let dict_weight = self.dict_weights[idx];
if m_start >= padding && m_start < padding + ys.len() {
ys[m_start - padding] += w_right;
ys[m_start - padding] += dict_weight.right;
}
let range_start = std::cmp::max(0, m_start as isize - padding as isize + 1);
let range_end = std::cmp::min(m_end as isize - padding as isize, ys.len() as isize);
if range_start < range_end {
for y in &mut ys[range_start as usize..range_end as usize] {
*y += w_center;
*y += dict_weight.inner;
}
}
if m_end >= padding && m_end < ys.len() + padding {
ys[m_end - padding] += w_left;
ys[m_end - padding] += dict_weight.left;
}
}
}
Expand Down Expand Up @@ -553,9 +627,31 @@ mod tests {
vec![37, 38, 39],
],
#[cfg(not(feature = "model-quantize"))]
dict_weights: vec![[20.0, 20.5, 21.0], [21.5, 22.0, 22.5]],
dict_weights: vec![
DictWeight {
right: 20.0,
inner: 20.5,
left: 21.0,
},
DictWeight {
right: 21.5,
inner: 22.0,
left: 22.5,
},
],
#[cfg(feature = "model-quantize")]
dict_weights: vec![[40, 41, 42], [43, 44, 45]],
dict_weights: vec![
DictWeight {
right: 40,
inner: 41,
left: 42,
},
DictWeight {
right: 43,
inner: 44,
left: 45,
},
],
#[cfg(feature = "model-quantize")]
quantize_multiplier: 0.5,
dict_word_wise: false,
Expand Down Expand Up @@ -640,9 +736,41 @@ mod tests {
vec![33, 34, 35, 36, 37],
],
#[cfg(not(feature = "model-quantize"))]
dict_weights: vec![[9.5, 9.75, 10.0], [10.25, 10.5, 10.75], [11.0, 11.25, 11.5]],
dict_weights: vec![
DictWeight {
right: 9.5,
inner: 9.75,
left: 10.0,
},
DictWeight {
right: 10.25,
inner: 10.5,
left: 10.75,
},
DictWeight {
right: 11.0,
inner: 11.25,
left: 11.5,
},
],
#[cfg(feature = "model-quantize")]
dict_weights: vec![[38, 39, 40], [41, 42, 43], [44, 45, 46]],
dict_weights: vec![
DictWeight {
right: 38,
inner: 39,
left: 40,
},
DictWeight {
right: 41,
inner: 42,
left: 43,
},
DictWeight {
right: 44,
inner: 45,
left: 46,
},
],
#[cfg(feature = "model-quantize")]
quantize_multiplier: 0.25,
dict_word_wise: false,
Expand Down Expand Up @@ -727,9 +855,41 @@ mod tests {
vec![33, 34, 35, 36, 37],
],
#[cfg(not(feature = "model-quantize"))]
dict_weights: vec![[9.5, 9.75, 11.0], [10.25, 10.5, 10.75], [11.0, 11.25, 11.5]],
dict_weights: vec![
DictWeight {
right: 9.5,
inner: 9.75,
left: 11.0,
},
DictWeight {
right: 10.25,
inner: 10.5,
left: 10.75,
},
DictWeight {
right: 11.0,
inner: 11.25,
left: 11.5,
},
],
#[cfg(feature = "model-quantize")]
dict_weights: vec![[38, 39, 40], [41, 42, 43], [44, 45, 46]],
dict_weights: vec![
DictWeight {
right: 38,
inner: 39,
left: 40,
},
DictWeight {
right: 41,
inner: 42,
left: 43,
},
DictWeight {
right: 44,
inner: 45,
left: 46,
},
],
#[cfg(feature = "model-quantize")]
quantize_multiplier: 0.25,
dict_word_wise: true,
Expand Down
1 change: 1 addition & 0 deletions vaporetto/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ impl StringIdManager {
}

#[cfg(test)]
#[allow(unused_macros)]
macro_rules! ct2u8 {
( $( $v:path ),* ) => {
ct2u8!( $( $v, )* )
Expand Down

0 comments on commit 9c8318e

Please sign in to comment.