Skip to content

Commit

Permalink
feat(user_dictionary): predict word
Browse files Browse the repository at this point in the history
  • Loading branch information
lotem committed Mar 3, 2024
1 parent 729aa62 commit 01affef
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 11 deletions.
68 changes: 60 additions & 8 deletions src/rime/dict/user_dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
#include <rime/ticket.h>
#include <rime/algo/dynamics.h>
#include <rime/algo/syllabifier.h>
#include <rime/algo/strings.h>
#include <rime/dict/db.h>
#include <rime/dict/table.h>
#include <rime/dict/user_dictionary.h>
#include <rime/dict/vocabulary.h>

namespace rime {

struct DfsState {
size_t depth_limit;
size_t predict_word_from_depth;
TickCount present_tick;
Code code;
vector<double> credibility;
Expand All @@ -32,13 +35,15 @@ struct DfsState {
string key;
string value;

size_t depth() const { return code.size(); }

bool IsExactMatch(const string& prefix) {
return boost::starts_with(key, prefix + '\t');
}
bool IsPrefixMatch(const string& prefix) {
return boost::starts_with(key, prefix);
}
void RecruitEntry(size_t pos);
void RecruitEntry(size_t pos, map<string, SyllableId>* syllabary = nullptr);
bool NextEntry() {
if (!accessor->GetNextRecord(&key, &value)) {
key.clear();
Expand All @@ -63,11 +68,30 @@ struct DfsState {
}
};

void DfsState::RecruitEntry(size_t pos) {
void DfsState::RecruitEntry(size_t pos, map<string, SyllableId>* syllabary) {
string full_code;
auto e = UserDictionary::CreateDictEntry(key, value, present_tick,
credibility.back());
credibility.back(),
syllabary ? &full_code : nullptr);
if (e) {
e->code = code;
if (syllabary) {
vector<string> syllables =
strings::split(full_code, " ", strings::SplitBehavior::SkipToken);
Code numeric_code;
for (auto s = syllables.begin(); s != syllables.end(); ++s) {
auto found = syllabary->find(*s);
if (found == syllabary->end()) {
LOG(ERROR) << "failed to recruit dict entry '" << e->text
<< "', unrecognized syllable: " << *s;
return;
}
numeric_code.push_back(found->second);
}
e->code = numeric_code;
e->matching_code_size = code.size();
} else {
e->code = code;
}
DLOG(INFO) << "add entry at pos " << pos;
query_result[pos].push_back(e);
}
Expand Down Expand Up @@ -230,10 +254,36 @@ void UserDictionary::DfsLookup(const SyllableGraph& syll_graph,
if (!state->NextEntry()) // reached the end of db
break;
}
// the caller can limit the number of syllables to look up
if ((!state->depth_limit || state->code.size() < state->depth_limit) &&
state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore'
DfsLookup(syll_graph, end_pos, prefix, state);
auto next_index = syll_graph.indices.find(end_pos);
if (next_index == syll_graph.indices.end()) {
// reached the end of input, predict word if requested
if (state->predict_word_from_depth != 0 &&
state->depth() >= state->predict_word_from_depth) {
while (state->IsPrefixMatch(prefix)) {
DLOG(INFO) << "prefix match found for '" << prefix << "'.";
if (syllabary_.empty()) {
Syllabary syllabary;
if (!table_->GetSyllabary(&syllabary)) {
LOG(ERROR) << "failed to get syllabary for user dict: "
<< name();
break;
}
SyllableId syllable_id = 0;
for (auto s = syllabary.begin(); s != syllabary.end(); ++s) {
syllabary_[*s] = syllable_id++;
}
}
state->RecruitEntry(end_pos, &syllabary_);
if (!state->NextEntry()) // reached the end of db
break;
}
}
} else {
// the caller can limit the number of syllables to look up
if ((!state->depth_limit || state->depth() < state->depth_limit) &&
state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore'
DfsLookup(syll_graph, end_pos, prefix, state);
}
}
}
if (!state->IsPrefixMatch(current_prefix)) // 'b |' vs. 'g o \tGo'
Expand All @@ -254,12 +304,14 @@ an<UserDictEntryCollector> UserDictionary::Lookup(
const SyllableGraph& syll_graph,
size_t start_pos,
size_t depth_limit,
size_t predict_word_from_depth,
double initial_credibility) {
if (!table_ || !prism_ || !loaded() ||
start_pos >= syll_graph.interpreted_length)
return nullptr;
DfsState state;
state.depth_limit = depth_limit;
state.predict_word_from_depth = predict_word_from_depth;
FetchTickCount();
state.present_tick = tick_ + 1;
state.credibility.push_back(initial_credibility);
Expand Down
4 changes: 3 additions & 1 deletion src/rime/dict/user_dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
an<UserDictEntryCollector> Lookup(const SyllableGraph& syllable_graph,
size_t start_pos,
size_t depth_limit = 0,
size_t predict_word_from_depth = 0,
double initial_credibility = 0.0);
size_t LookupWords(UserDictEntryIterator* result,
const string& input,
Expand All @@ -82,7 +83,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
const string& value,
TickCount present_tick,
double credibility = 0.0,
string* full_code = NULL);
string* full_code = nullptr);

protected:
bool Initialize();
Expand All @@ -98,6 +99,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
an<Db> db_;
an<Table> table_;
an<Prism> prism_;
map<string, SyllableId> syllabary_;
TickCount tick_ = 0;
time_t transaction_time_ = 0;
};
Expand Down
9 changes: 7 additions & 2 deletions src/rime/gear/script_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,11 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) {

phrase_ = dict->Lookup(syllable_graph, 0, predict_word);
if (user_dict) {
user_phrase_ = user_dict->Lookup(syllable_graph, 0);
const size_t kUnlimitedDepth = 0;
const size_t kNumSyllablesToPredictWord = 4;
user_phrase_ =
user_dict->Lookup(syllable_graph, 0, kUnlimitedDepth,
predict_word ? kNumSyllablesToPredictWord : 0);
}
if (!phrase_ && !user_phrase_)
return false;
Expand All @@ -371,7 +375,8 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) {
phrase_ && phrase_iter_->first == consumed &&
is_exact_match_phrase(phrase_iter_->second.Peek());
bool has_exact_match_user_phrase =
user_phrase_ && user_phrase_iter_->first == consumed;
user_phrase_ && user_phrase_iter_->first == consumed &&
is_exact_match_phrase(user_phrase_iter_->second.Peek());
bool has_at_least_two_syllables = syllable_graph.edges.size() >= 2;
if (!has_exact_match_phrase && !has_exact_match_user_phrase &&
has_at_least_two_syllables) {
Expand Down

0 comments on commit 01affef

Please sign in to comment.