Skip to content

Commit

Permalink
Move back to a simple recursive graph traversal to find translations
Browse files Browse the repository at this point in the history
That way begnum and endnum finally work

On the other hand, simple benchmarks seem to indicate that the traversal is
quite a bit slower: running all YAML tests now takes 6 seconds instead of 4.
  • Loading branch information
egli committed Dec 6, 2024
1 parent 27aea84 commit 16c57a3
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 62 deletions.
2 changes: 0 additions & 2 deletions src/translator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,6 @@ mod tests {
}

#[test]
#[ignore = "Not implemented properly"]
fn begnum_test() {
let rules = vec![
parse_rule("digit 1 1"),
Expand All @@ -662,7 +661,6 @@ mod tests {
}

#[test]
#[ignore = "Not implemented properly"]
fn endnum_test() {
let rules = vec![
parse_rule("digit 1 1"),
Expand Down
154 changes: 94 additions & 60 deletions src/translator/trie.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use crate::translator::boundaries::{word_end, word_start};
use crate::translator::boundaries::{word_end, word_start, word_number, number_word};

use super::Translation;

Expand Down Expand Up @@ -56,6 +56,12 @@ impl TrieNode {
/// Looks up the child node stored under the `End(NotWord)` transition.
///
/// Returns `None` when this node has no such outgoing transition.
fn not_word_end_transition(&self) -> Option<&TrieNode> {
    let key = Transition::End(Boundary::NotWord);
    self.transitions.get(&key)
}
/// Looks up the child node stored under the `End(WordNumber)` transition.
///
/// Returns `None` when this node has no such outgoing transition.
fn word_num_transition(&self) -> Option<&TrieNode> {
    let key = Transition::End(Boundary::WordNumber);
    self.transitions.get(&key)
}
/// Looks up the child node stored under the `Start(NumberWord)` transition.
///
/// Note that, unlike the word/number lookup, this boundary is keyed as a
/// `Start` transition. Returns `None` when this node has no such outgoing
/// transition.
fn num_word_transition(&self) -> Option<&TrieNode> {
    let key = Transition::Start(Boundary::NumberWord);
    self.transitions.get(&key)
}
}

#[derive(Default, Debug)]
Expand Down Expand Up @@ -109,81 +115,74 @@ impl Trie {
}

fn find_translations_from_node<'a>(
&'a self,
&self,
input: &str,
prev: Option<char>,
node: &'a TrieNode,
) -> Vec<&'a Translation> {
let mut current_node = node;
let mut matching_rules = Vec::new();
let mut prev: Option<char> = None;
let mut chars = input.chars();

while let Some(c) = chars.next() {
if let Some(node) = current_node.char_transition(c) {
current_node = node;
if let Some(ref translation) = node.translation {
matching_rules.push(translation)
}
} else if let Some(node) = current_node.char_case_insensitive_transition(c) {
current_node = node;
if let Some(ref translation) = node.translation {
matching_rules.push(translation)
}
} else if let Some(node) = current_node.word_end_transition() {
current_node = node;
if word_end(prev, Some(c)) {
if let Some(ref translation) = node.translation {
matching_rules.push(translation)
}
}
} else if let Some(node) = current_node.not_word_end_transition() {
current_node = node;
if !word_end(prev, Some(c)) {
if let Some(ref translation) = node.translation {
matching_rules.push(translation)
}
}
} else {
prev = Some(c);
break;
// if this node has a translation add it to the list of matching rules
if let Some(ref translation) = node.translation {
matching_rules.push(translation)
}
let c = chars.next();
if let Some(c) = c {
let bytes = c.len_utf8();
if let Some(node) = node.char_transition(c) {
matching_rules.extend(self.find_translations_from_node(
&input[bytes..],
Some(c),
node,
));
}
if let Some(node) = node.char_case_insensitive_transition(c) {
matching_rules.extend(self.find_translations_from_node(
&input[bytes..],
Some(c),
node,
));
}
prev = Some(c);
}
// at this point we have either
// - exhausted the input (chars.next() is None) or
// - exhausted the trie (current_node has no applicable transitions)
// TODO: assert this invariant (how can we do this without the
// side-effecting chars.next()?)
if let Some(node) = current_node.word_end_transition() {
if word_end(prev, chars.next()) {
if let Some(ref translation) = node.translation {
matching_rules.push(translation)
}
if let Some(node) = node.word_start_transition() {
if word_start(prev, c) {
matching_rules.extend(self.find_translations_from_node(&input[..], prev, node));
}
} else if let Some(node) = current_node.not_word_end_transition() {
if !word_end(prev, chars.next()) {
if let Some(ref translation) = node.translation {
matching_rules.push(translation)
}
}
if let Some(node) = node.not_word_start_transition() {
if !word_start(prev, c) {
matching_rules.extend(self.find_translations_from_node(&input[..], prev, node));
}
}
if let Some(node) = node.word_end_transition() {
if word_end(prev, c) {
matching_rules.extend(self.find_translations_from_node(&input[..], prev, node));
}
}
if let Some(node) = node.not_word_end_transition() {
if !word_end(prev, c) {
matching_rules.extend(self.find_translations_from_node(&input[..], prev, node));
}
}
if let Some(node) = node.word_num_transition() {
if word_number(prev, c) {
matching_rules.extend(self.find_translations_from_node(&input[..], prev, node));
}
}
if let Some(node) = node.num_word_transition() {
dbg!(prev, c);
if number_word(prev, c) {
matching_rules.extend(self.find_translations_from_node(&input[..], prev, node));
}
}
matching_rules
}

pub fn find_translations(&self, input: &str, before: Option<char>) -> Vec<&Translation> {
pub fn find_translations(&self, input: &str, prev: Option<char>) -> Vec<&Translation> {
let mut matching_rules = Vec::new();

if word_start(before, input.chars().next()) {
if let Some(node) = self.root.word_start_transition() {
matching_rules = self.find_translations_from_node(input, node);
}
} else {
if let Some(node) = self.root.not_word_start_transition() {
matching_rules = self.find_translations_from_node(input, node);
}
}

matching_rules.extend(self.find_translations_from_node(input, &self.root));
matching_rules.extend(self.find_translations_from_node(input, prev, dbg!(&self.root)));
matching_rules.sort_by_key(|translation| translation.weight);
matching_rules
}
Expand Down Expand Up @@ -300,6 +299,41 @@ mod tests {
assert_eq!(trie.find_translations("foobar", Some('c')), vec![&foo]);
}

/// A rule inserted with `Boundary::Word` and `Boundary::WordNumber`
/// (presumably the leading and trailing boundary, respectively) must only
/// match when the input is followed by a digit and preceded by a non-word
/// character.
#[test]
fn find_translations_with_word_num_boundary() {
    let expected = Translation::new("aaa".into(), "A".into(), 5);
    let no_match: Vec<&Translation> = Vec::new();

    let mut trie = Trie::new();
    trie.insert(
        "aaa".into(),
        "A".into(),
        Boundary::Word,
        Boundary::WordNumber,
    );

    // No digit after the match.
    assert_eq!(trie.find_translations("aaa", None), no_match);
    // Digit after the match, non-letter before it.
    for before in [' ', '.'] {
        assert_eq!(
            trie.find_translations("aaa1", Some(before)),
            vec![&expected]
        );
    }
    // Digit after the match, but a letter before it.
    assert_eq!(trie.find_translations("aaa1", Some('c')), no_match);
}

/// A rule inserted with `Boundary::NumberWord` and `Boundary::Word`
/// (presumably the leading and trailing boundary, respectively) must only
/// match when the input is preceded by a digit and is not followed by
/// another letter.
#[test]
fn find_translations_with_num_word_boundary() {
    let expected = Translation::new("st".into(), "S".into(), 4);
    let no_match: Vec<&Translation> = Vec::new();

    let mut trie = Trie::new();
    trie.insert(
        "st".into(),
        "S".into(),
        Boundary::NumberWord,
        Boundary::Word,
    );

    // Nothing, a space, or punctuation before the match: no digit, no match.
    for before in [None, Some(' '), Some('.')] {
        assert_eq!(trie.find_translations("st", before), no_match);
    }
    // Digit before the match.
    assert_eq!(trie.find_translations("st", Some('1')), vec![&expected]);
    // Digit before the match, but a letter directly after it.
    assert_eq!(trie.find_translations("sta", Some('1')), no_match);
}

#[test]
fn find_translations_case_insensitive_test() {
let mut trie = Trie::new();
Expand Down

0 comments on commit 16c57a3

Please sign in to comment.