diff --git a/src/translator.rs b/src/translator.rs index 8e3c650..ead3bfa 100644 --- a/src/translator.rs +++ b/src/translator.rs @@ -648,7 +648,6 @@ mod tests { } #[test] - #[ignore = "Not implemented properly"] fn begnum_test() { let rules = vec![ parse_rule("digit 1 1"), @@ -662,7 +661,6 @@ mod tests { } #[test] - #[ignore = "Not implemented properly"] fn endnum_test() { let rules = vec![ parse_rule("digit 1 1"), diff --git a/src/translator/trie.rs b/src/translator/trie.rs index e41ac07..ca251ee 100644 --- a/src/translator/trie.rs +++ b/src/translator/trie.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::translator::boundaries::{word_end, word_start}; +use crate::translator::boundaries::{word_end, word_start, word_number, number_word}; use super::Translation; @@ -56,6 +56,12 @@ impl TrieNode { fn not_word_end_transition(&self) -> Option<&TrieNode> { self.transitions.get(&Transition::End(Boundary::NotWord)) } + fn word_num_transition(&self) -> Option<&TrieNode> { + self.transitions.get(&Transition::End(Boundary::WordNumber)) + } + fn num_word_transition(&self) -> Option<&TrieNode> { + self.transitions.get(&Transition::Start(Boundary::NumberWord)) + } } #[derive(Default, Debug)] @@ -109,81 +115,74 @@ impl Trie { } fn find_translations_from_node<'a>( - &'a self, + &self, input: &str, + prev: Option, node: &'a TrieNode, ) -> Vec<&'a Translation> { - let mut current_node = node; let mut matching_rules = Vec::new(); - let mut prev: Option = None; let mut chars = input.chars(); - while let Some(c) = chars.next() { - if let Some(node) = current_node.char_transition(c) { - current_node = node; - if let Some(ref translation) = node.translation { - matching_rules.push(translation) - } - } else if let Some(node) = current_node.char_case_insensitive_transition(c) { - current_node = node; - if let Some(ref translation) = node.translation { - matching_rules.push(translation) - } - } else if let Some(node) = current_node.word_end_transition() { - current_node = node; - if word_end(prev, Some(c)) { - if let Some(ref translation) = node.translation { - matching_rules.push(translation) - } - } - } else if let Some(node) = current_node.not_word_end_transition() { - current_node = node; - if !word_end(prev, Some(c)) { - if let Some(ref translation) = node.translation { - matching_rules.push(translation) - } - } - } else { - prev = Some(c); - break; + // if this node has a translation add it to the list of matching rules + if let Some(ref translation) = node.translation { + matching_rules.push(translation) + } + let c = chars.next(); + if let Some(c) = c { + let bytes = c.len_utf8(); + if let Some(node) = node.char_transition(c) { + matching_rules.extend(self.find_translations_from_node( + &input[bytes..], + Some(c), + node, + )); + } + if let Some(node) = node.char_case_insensitive_transition(c) { + matching_rules.extend(self.find_translations_from_node( + &input[bytes..], + Some(c), + node, + )); } - prev = Some(c); } - // at this point we have either - // - exhausted the input (chars.next() is None) or - // - exhausted the trie (current_node has no applicable transitions) - // TODO: assert this invariant (how can we do this without the - // side-effecting chars.next()?) - if let Some(node) = current_node.word_end_transition() { - if word_end(prev, chars.next()) { - if let Some(ref translation) = node.translation { - matching_rules.push(translation) - } + if let Some(node) = node.word_start_transition() { + if word_start(prev, c) { + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); } - } else if let Some(node) = current_node.not_word_end_transition() { - if !word_end(prev, chars.next()) { - if let Some(ref translation) = node.translation { - matching_rules.push(translation) - } + } + if let Some(node) = node.not_word_start_transition() { + if !word_start(prev, c) { + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); + } + } + if let Some(node) = node.word_end_transition() { + if word_end(prev, c) { + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); + } + } + if let Some(node) = node.not_word_end_transition() { + if !word_end(prev, c) { + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); + } + } + if let Some(node) = node.word_num_transition() { + if word_number(prev, c) { + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); + } + } + if let Some(node) = node.num_word_transition() { + dbg!(prev, c); + if number_word(prev, c) { + matching_rules.extend(self.find_translations_from_node(&input[..], prev, node)); } } matching_rules } - pub fn find_translations(&self, input: &str, before: Option) -> Vec<&Translation> { + pub fn find_translations(&self, input: &str, prev: Option) -> Vec<&Translation> { let mut matching_rules = Vec::new(); - if word_start(before, input.chars().next()) { - if let Some(node) = self.root.word_start_transition() { - matching_rules = self.find_translations_from_node(input, node); - } - } else { - if let Some(node) = self.root.not_word_start_transition() { - matching_rules = self.find_translations_from_node(input, node); - } - } - - matching_rules.extend(self.find_translations_from_node(input, &self.root)); + matching_rules.extend(self.find_translations_from_node(input, prev, dbg!(&self.root))); matching_rules.sort_by_key(|translation| translation.weight); matching_rules } @@ -300,6 +299,41 @@ mod tests { assert_eq!(trie.find_translations("foobar", Some('c')), vec![&foo]); } + #[test] + fn find_translations_with_word_num_boundary() { + let mut trie = Trie::new(); + let empty = Vec::<&Translation>::new(); + let foo = Translation::new("aaa".into(), "A".into(), 5); + trie.insert( + "aaa".into(), + "A".into(), + Boundary::Word, + Boundary::WordNumber, + ); + assert_eq!(trie.find_translations("aaa", None), empty); + assert_eq!(trie.find_translations("aaa1", Some(' ')), vec![&foo]); + assert_eq!(trie.find_translations("aaa1", Some('.')), vec![&foo]); + assert_eq!(trie.find_translations("aaa1", Some('c')), empty); + } + + #[test] + fn find_translations_with_num_word_boundary() { + let mut trie = Trie::new(); + let empty = Vec::<&Translation>::new(); + let foo = Translation::new("st".into(), "S".into(), 4); + trie.insert( + "st".into(), + "S".into(), + Boundary::NumberWord, + Boundary::Word, + ); + assert_eq!(trie.find_translations("st", None), empty); + assert_eq!(trie.find_translations("st", Some(' ')), empty); + assert_eq!(trie.find_translations("st", Some('.')), empty); + assert_eq!(trie.find_translations("st", Some('1')), vec![&foo]); + assert_eq!(trie.find_translations("sta", Some('1')), empty); + } + #[test] fn find_translations_case_insensitive_test() { let mut trie = Trie::new();