Skip to content

Commit

Permalink
Add cache array for type weights (#9)
Browse files Browse the repository at this point in the history
* make type_predictor

* add score caching

* fix bug for partial predict

* fix with clippy

* move mod declarative

* rm redundant add_type_ngram_scores

* rename

* fix error

Co-authored-by: Koichi Akabe <[email protected]>
  • Loading branch information
kampersanda and vbkaisetsu authored Nov 1, 2021
1 parent 40ab049 commit 514564b
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 34 deletions.
1 change: 1 addition & 0 deletions vaporetto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod utils;
mod model;
mod predictor;
mod sentence;
mod type_scorer;

#[cfg(feature = "train")]
mod feature;
Expand Down
44 changes: 10 additions & 34 deletions vaporetto/src/predictor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@ use crossbeam_channel::{Receiver, Sender};

use crate::model::{Model, ScoreValue, WeightValue};
use crate::sentence::{BoundaryType, Sentence};
use crate::type_scorer::TypeScorer;

use daachorse::DoubleArrayAhoCorasick;

/// Predictor.
///
/// Holds the pattern-matching automata, weight tables and window sizes that the
/// scoring methods of this type read while accumulating boundary scores.
///
/// NOTE(review): this listing shows both the stand-alone type fields
/// (`type_pma`, `type_weights`, `type_window_size`) and `type_scorer`, which the
/// constructor builds from those same three values — confirm which set is live.
pub struct Predictor {
/// Automaton built from `model.words`.
word_pma: DoubleArrayAhoCorasick,
/// Automaton built from `model.types`; searched by `add_type_ngram_scores`.
type_pma: DoubleArrayAhoCorasick,
/// Automaton built from `model.dict`.
dict_pma: DoubleArrayAhoCorasick,
/// Weight vectors paired with `word_pma` (presumably indexed by pattern id — confirm).
word_weights: Vec<Vec<ScoreValue>>,
/// Weight vectors indexed by the pattern id reported by `type_pma` matches.
type_weights: Vec<Vec<ScoreValue>>,
/// Three score values per dictionary entry (slot meaning not visible here — TODO confirm).
dict_weights: Vec<[ScoreValue; 3]>,
/// Copied from `model.dict_word_wise`.
dict_word_wise: bool,
/// Value every score slot is filled with before feature scores are added.
bias: ScoreValue,
/// Copied from `model.char_window_size`.
char_window_size: usize,
/// Copied from `model.type_window_size`; bounds the slice scanned for type n-grams.
type_window_size: usize,
/// Set to 1 by the constructor; bounds the slice scanned by `add_dict_scores`.
dict_window_size: usize,

/// Character-type scorer built from the type automaton, type weights and type window size.
type_scorer: TypeScorer,

/// Copied from `model.quantize_multiplier`; only present with the `model-quantize` feature.
#[cfg(feature = "model-quantize")]
quantize_multiplier: f64,
}
Expand Down Expand Up @@ -60,19 +61,21 @@ impl Predictor {
let word_pma = DoubleArrayAhoCorasick::new(model.words).unwrap();
let type_pma = DoubleArrayAhoCorasick::new(model.types).unwrap();
let dict_pma = DoubleArrayAhoCorasick::new(model.dict).unwrap();

let type_scorer = TypeScorer::new(type_pma, type_weights, model.type_window_size);

Self {
word_pma,
type_pma,
dict_pma,
word_weights,
type_weights,
dict_weights,
dict_word_wise: model.dict_word_wise,
bias,
char_window_size: model.char_window_size,
type_window_size: model.type_window_size,
dict_window_size: 1,

type_scorer,

#[cfg(feature = "model-quantize")]
quantize_multiplier: model.quantize_multiplier,
}
Expand Down Expand Up @@ -135,33 +138,6 @@ impl Predictor {
}
}

/// Adds character-type n-gram weights to `ys`, where `ys[k]` receives the
/// contributions for boundary `start + k` of `sentence`.
///
/// Only the part of `sentence.char_type` that can influence those boundaries
/// (at most `type_window_size` characters on either side) is scanned.
fn add_type_ngram_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
    let window = self.type_window_size;

    // First and one-past-last character of the slice to scan, clamped to the sentence.
    let type_start = (start + 1).saturating_sub(window);
    let type_end = (start + ys.len() + window).min(sentence.char_type.len());
    let char_type = &sentence.char_type[type_start..type_end];

    // Number of scanned characters that precede boundary `start`.
    let padding = start - type_start + 1;

    for m in self.type_pma.find_overlapping_no_suffix_iter(&char_type) {
        let pattern_weights = &self.type_weights[m.pattern()];
        // Slot of `ys` that lines up with the first weight; negative when the
        // leading weights belong to boundaries before `start`.
        let offset = m.end() as isize - window as isize - padding as isize;
        let (src, dst) = if offset < 0 {
            (&pattern_weights[-offset as usize..], &mut ys[..])
        } else {
            (&pattern_weights[..], &mut ys[offset as usize..])
        };
        for (y, w) in dst.iter_mut().zip(src) {
            *y += w;
        }
    }
}

fn add_dict_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
let char_start = if start >= self.dict_window_size {
start + 1 - self.dict_window_size
Expand Down Expand Up @@ -209,7 +185,7 @@ impl Predictor {
) {
ys.fill(self.bias);
self.add_word_ngram_scores(sentence, range.start, ys);
self.add_type_ngram_scores(sentence, range.start, ys);
self.type_scorer.add_scores(sentence, range.start, ys);
self.add_dict_scores(sentence, range.start, ys);
}

Expand Down
205 changes: 205 additions & 0 deletions vaporetto/src/type_scorer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
use crate::model::ScoreValue;
use crate::sentence::Sentence;
use daachorse::DoubleArrayAhoCorasick;

/// Scorer for character-type n-gram features, with two interchangeable back ends.
pub enum TypeScorer {
/// Scores by running the pattern-matching automaton over the input on every call.
Pma(TypeScorerPma),
/// Scores by looking up a table that was precomputed at construction time.
Cache(TypeScorerCache),
}

impl TypeScorer {
    /// Picks a back end from `window_size`: windows of at most 3 use the
    /// precomputed table, anything larger uses the automaton directly.
    pub fn new(
        pma: DoubleArrayAhoCorasick,
        weights: Vec<Vec<ScoreValue>>,
        window_size: usize,
    ) -> Self {
        match window_size {
            0..=3 => Self::Cache(TypeScorerCache::new(pma, weights, window_size)),
            _ => Self::Pma(TypeScorerPma::new(pma, weights, window_size)),
        }
    }

    /// Forwards to the selected back end, which adds its scores for the
    /// boundaries beginning at `start` into `ys`.
    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
        match self {
            Self::Cache(scorer) => scorer.add_scores(sentence, start, ys),
            Self::Pma(scorer) => scorer.add_scores(sentence, start, ys),
        }
    }
}

/// Type scorer that searches the input with a pattern-matching automaton on every call.
pub struct TypeScorerPma {
/// Automaton over character-type bytes; its pattern ids index `weights`.
pma: DoubleArrayAhoCorasick,
/// One weight vector per pattern.
weights: Vec<Vec<ScoreValue>>,
/// Number of characters considered on each side of a boundary.
window_size: usize,
}

impl TypeScorerPma {
    /// Stores the automaton, its per-pattern weights and the window size as given.
    pub fn new(
        pma: DoubleArrayAhoCorasick,
        weights: Vec<Vec<ScoreValue>>,
        window_size: usize,
    ) -> Self {
        Self {
            window_size,
            weights,
            pma,
        }
    }

    /// Adds character-type n-gram weights to `ys`, where `ys[k]` receives the
    /// contributions for boundary `start + k` of `sentence`.
    ///
    /// Only the part of `sentence.char_type` that can influence those boundaries
    /// (at most `window_size` characters on either side) is scanned.
    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
        let window = self.window_size;

        // First and one-past-last character of the slice to scan, clamped to the sentence.
        let type_start = (start + 1).saturating_sub(window);
        let type_end = (start + ys.len() + window).min(sentence.char_type.len());
        let char_type = &sentence.char_type[type_start..type_end];

        // Number of scanned characters that precede boundary `start`.
        let padding = start - type_start + 1;

        for m in self.pma.find_overlapping_no_suffix_iter(&char_type) {
            let pattern_weights = &self.weights[m.pattern()];
            // Slot of `ys` that lines up with the first weight; negative when the
            // leading weights belong to boundaries before `start`.
            let offset = m.end() as isize - window as isize - padding as isize;
            let (src, dst) = if offset < 0 {
                (&pattern_weights[-offset as usize..], &mut ys[..])
            } else {
                (&pattern_weights[..], &mut ys[offset as usize..])
            };
            for (y, w) in dst.iter_mut().zip(src) {
                *y += w;
            }
        }
    }
}

/// Type scorer that replaces automaton searches with lookups in a precomputed table.
pub struct TypeScorerCache {
/// Score for every packed sequence id; indexed directly by the id.
scores: Vec<ScoreValue>,
/// Number of characters considered on each side of a boundary.
window_size: usize,
/// Bit mask applied after each shift so an id keeps only `2 * window_size` slots.
sequence_mask: usize,
}

impl TypeScorerCache {
    /// Precomputes the total type n-gram score of every possible window.
    ///
    /// A window is `2 * window_size` character types packed into a sequence id,
    /// `ALPHABET_SHIFT` bits per slot with the newest character in the lowest
    /// bits. For each id that decodes to a valid sequence, the automaton is run
    /// once over the decoded bytes and the weight that each match contributes to
    /// the boundary in the middle of the window is summed into the table.
    ///
    /// The table holds `ALPHABET_SIZE.pow(2 * window_size)` entries, so this is
    /// only suitable for small windows (`TypeScorer::new` uses it for
    /// `window_size <= 3`).
    pub fn new(
        pma: DoubleArrayAhoCorasick,
        weights: Vec<Vec<ScoreValue>>,
        window_size: usize,
    ) -> Self {
        let sequence_size = window_size * 2;
        let all_sequences = ALPHABET_SIZE.pow(sequence_size as u32);

        let mut sequence = vec![0u8; sequence_size];
        let mut scores = vec![0 as ScoreValue; all_sequences];

        for (i, score) in scores.iter_mut().enumerate() {
            // Ids containing the unused slot value keep the initial zero score.
            if !Self::seqid_to_seq(i, &mut sequence) {
                continue;
            }
            let mut y = ScoreValue::default();
            for m in pma.find_overlapping_no_suffix_iter(&sequence) {
                // Select the weight that lines up with the window's central boundary.
                y += weights[m.pattern()][sequence_size - m.end()];
            }
            *score = y;
        }

        Self {
            scores,
            window_size,
            sequence_mask: (1 << (ALPHABET_SHIFT * sequence_size)) - 1,
        }
    }

    /// Adds the cached type n-gram score of boundary `start + k` to `ys[k]`.
    ///
    /// The sequence id is rolled over `sentence.char_type` one character at a
    /// time; positions outside the sentence are fed as the empty slot.
    pub fn add_scores(&self, sentence: &Sentence, start: usize, ys: &mut [ScoreValue]) {
        let type_start = if start >= self.window_size {
            start + 1 - self.window_size
        } else {
            0
        };
        let type_end = std::cmp::min(
            start + ys.len() + self.window_size,
            sentence.char_type.len(),
        );
        let char_type = &sentence.char_type[type_start..type_end];

        // Index *within `char_type`* of the last character in the window of
        // boundary `start`. The slice begins at `type_start`, so the index must
        // be taken relative to it; using the absolute `window_size + start`
        // primes the id with the wrong characters whenever `type_start > 0`.
        // The subtraction can only saturate for `window_size == 0`, where the
        // mask is zero and the offset has no effect on the result.
        let offset = (self.window_size + start).saturating_sub(type_start);

        // Prime the id with every character of the first window except its last.
        let mut seqid = 0;
        for i in 0..offset {
            seqid = self.advance(seqid, char_type.get(i));
        }
        // Each boundary consumes one more character and then reads its score.
        for (i, y) in ys.iter_mut().enumerate() {
            seqid = self.advance(seqid, char_type.get(i + offset));
            *y += self.get_score(seqid);
        }
    }

    /// Decodes `seqid` into `sequence`, lowest slot into the last element.
    ///
    /// Returns `false` as soon as a slot holds `ALPHABET_MASK`, which no
    /// character type maps to; `sequence` is then only partially written.
    /// Panics if `seqid` has bits left over after `sequence.len()` slots.
    fn seqid_to_seq(mut seqid: usize, sequence: &mut [u8]) -> bool {
        for i in (0..sequence.len()).rev() {
            let x = seqid & ALPHABET_MASK;
            if x == ALPHABET_MASK {
                return false; // invalid
            }
            sequence[i] = ID_TO_TYPE[x];
            seqid >>= ALPHABET_SHIFT;
        }
        assert_eq!(seqid, 0);
        true
    }

    /// Returns the precomputed score stored for `seqid`.
    #[inline(always)]
    fn get_score(&self, seqid: usize) -> ScoreValue {
        self.scores[seqid]
    }

    /// Rolls `seqid` forward by one position: pushes the character's id when
    /// one is present, otherwise pushes the empty slot.
    #[inline(always)]
    fn advance(&self, seqid: usize, char_type: Option<&u8>) -> usize {
        match char_type {
            Some(&ct) => self.increment_seqid(seqid, ct),
            None => self.increment_seqid_without_char(seqid),
        }
    }

    /// Shifts `seqid` by one slot and appends the id of `char_type`, dropping
    /// the oldest slot via the mask.
    #[inline(always)]
    fn increment_seqid(&self, seqid: usize, char_type: u8) -> usize {
        let char_id = TYPE_TO_ID[char_type as usize] as usize;
        debug_assert!((1..=6).contains(&char_id));
        ((seqid << ALPHABET_SHIFT) | char_id) & self.sequence_mask
    }

    /// Shifts `seqid` by one slot and appends the empty slot (value 0).
    #[inline(always)]
    const fn increment_seqid_without_char(&self, seqid: usize) -> usize {
        (seqid << ALPHABET_SHIFT) & self.sequence_mask
    }
}

// Parameters for packing character types into sequence ids.

// Number of values one slot can hold (2^ALPHABET_SHIFT).
const ALPHABET_SIZE: usize = 8;
// Mask that extracts the lowest slot; also the slot value treated as invalid when decoding.
const ALPHABET_MASK: usize = ALPHABET_SIZE - 1;
// Bits per slot.
const ALPHABET_SHIFT: usize = 3;
// Character-type byte -> slot id, built at compile time.
const TYPE_TO_ID: [u32; 256] = make_type_to_id();
// Slot id -> character-type byte, built at compile time.
const ID_TO_TYPE: [u8; 256] = make_id_to_type();

/// Builds the table mapping a character-type byte to its slot id.
///
/// `Digit`, `Roman`, `Hiragana`, `Katakana`, `Kanji` and `Other` receive the
/// ids 1 through 6 in that order; every other byte maps to 0.
const fn make_type_to_id() -> [u32; 256] {
    use crate::sentence::CharacterType::*;

    const TYPE_COUNT: usize = 6;
    // Listed in id order: the entry at position `k` is assigned id `k + 1`.
    let ordered: [usize; TYPE_COUNT] = [
        Digit as usize,
        Roman as usize,
        Hiragana as usize,
        Katakana as usize,
        Kanji as usize,
        Other as usize,
    ];

    let mut table = [0u32; 256];
    let mut k = 0;
    // `for` is not available in a `const fn`, hence the manual counter.
    while k < TYPE_COUNT {
        table[ordered[k]] = k as u32 + 1;
        k += 1;
    }
    table
}

/// Builds the table mapping a slot id back to its character-type byte.
///
/// Ids 1 through 6 map to `Digit`, `Roman`, `Hiragana`, `Katakana`, `Kanji`
/// and `Other` in that order; every other index holds 0.
const fn make_id_to_type() -> [u8; 256] {
    use crate::sentence::CharacterType::*;

    const TYPE_COUNT: usize = 6;
    // Listed in id order: the entry at position `k` belongs to id `k + 1`.
    let ordered: [u8; TYPE_COUNT] = [
        Digit as u8,
        Roman as u8,
        Hiragana as u8,
        Katakana as u8,
        Kanji as u8,
        Other as u8,
    ];

    let mut table = [0u8; 256];
    let mut k = 0;
    // `for` is not available in a `const fn`, hence the manual counter.
    while k < TYPE_COUNT {
        table[k + 1] = ordered[k];
        k += 1;
    }
    table
}

0 comments on commit 514564b

Please sign in to comment.