Skip to content

Commit 01affef

Browse files
committed
feat(user_dictionary): predict word
1 parent 729aa62 commit 01affef

File tree

3 files changed

+70
-11
lines changed

3 files changed

+70
-11
lines changed

src/rime/dict/user_dictionary.cc

+60-8
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616
#include <rime/ticket.h>
1717
#include <rime/algo/dynamics.h>
1818
#include <rime/algo/syllabifier.h>
19+
#include <rime/algo/strings.h>
1920
#include <rime/dict/db.h>
2021
#include <rime/dict/table.h>
2122
#include <rime/dict/user_dictionary.h>
23+
#include <rime/dict/vocabulary.h>
2224

2325
namespace rime {
2426

2527
struct DfsState {
2628
size_t depth_limit;
29+
size_t predict_word_from_depth;
2730
TickCount present_tick;
2831
Code code;
2932
vector<double> credibility;
@@ -32,13 +35,15 @@ struct DfsState {
3235
string key;
3336
string value;
3437

38+
// Current lookup depth: the number of elements held in `code`.
size_t depth() const { return code.size(); }
39+
3540
// Returns true when the current record's `key` begins with `prefix`
// immediately followed by a tab character.
bool IsExactMatch(const string& prefix) {
3641
return boost::starts_with(key, prefix + '\t');
3742
}
3843
// Returns true when the current record's `key` begins with `prefix`;
// unlike IsExactMatch, no tab is required after the prefix.
bool IsPrefixMatch(const string& prefix) {
3944
return boost::starts_with(key, prefix);
4045
}
41-
void RecruitEntry(size_t pos);
46+
void RecruitEntry(size_t pos, map<string, SyllableId>* syllabary = nullptr);
4247
bool NextEntry() {
4348
if (!accessor->GetNextRecord(&key, &value)) {
4449
key.clear();
@@ -63,11 +68,30 @@ struct DfsState {
6368
}
6469
};
6570

66-
void DfsState::RecruitEntry(size_t pos) {
71+
void DfsState::RecruitEntry(size_t pos, map<string, SyllableId>* syllabary) {
72+
string full_code;
6773
auto e = UserDictionary::CreateDictEntry(key, value, present_tick,
68-
credibility.back());
74+
credibility.back(),
75+
syllabary ? &full_code : nullptr);
6976
if (e) {
70-
e->code = code;
77+
if (syllabary) {
78+
vector<string> syllables =
79+
strings::split(full_code, " ", strings::SplitBehavior::SkipToken);
80+
Code numeric_code;
81+
for (auto s = syllables.begin(); s != syllables.end(); ++s) {
82+
auto found = syllabary->find(*s);
83+
if (found == syllabary->end()) {
84+
LOG(ERROR) << "failed to recruit dict entry '" << e->text
85+
<< "', unrecognized syllable: " << *s;
86+
return;
87+
}
88+
numeric_code.push_back(found->second);
89+
}
90+
e->code = numeric_code;
91+
e->matching_code_size = code.size();
92+
} else {
93+
e->code = code;
94+
}
7195
DLOG(INFO) << "add entry at pos " << pos;
7296
query_result[pos].push_back(e);
7397
}
@@ -230,10 +254,36 @@ void UserDictionary::DfsLookup(const SyllableGraph& syll_graph,
230254
if (!state->NextEntry()) // reached the end of db
231255
break;
232256
}
233-
// the caller can limit the number of syllables to look up
234-
if ((!state->depth_limit || state->code.size() < state->depth_limit) &&
235-
state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore'
236-
DfsLookup(syll_graph, end_pos, prefix, state);
257+
auto next_index = syll_graph.indices.find(end_pos);
258+
if (next_index == syll_graph.indices.end()) {
259+
// reached the end of input, predict word if requested
260+
if (state->predict_word_from_depth != 0 &&
261+
state->depth() >= state->predict_word_from_depth) {
262+
while (state->IsPrefixMatch(prefix)) {
263+
DLOG(INFO) << "prefix match found for '" << prefix << "'.";
264+
if (syllabary_.empty()) {
265+
Syllabary syllabary;
266+
if (!table_->GetSyllabary(&syllabary)) {
267+
LOG(ERROR) << "failed to get syllabary for user dict: "
268+
<< name();
269+
break;
270+
}
271+
SyllableId syllable_id = 0;
272+
for (auto s = syllabary.begin(); s != syllabary.end(); ++s) {
273+
syllabary_[*s] = syllable_id++;
274+
}
275+
}
276+
state->RecruitEntry(end_pos, &syllabary_);
277+
if (!state->NextEntry()) // reached the end of db
278+
break;
279+
}
280+
}
281+
} else {
282+
// the caller can limit the number of syllables to look up
283+
if ((!state->depth_limit || state->depth() < state->depth_limit) &&
284+
state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore'
285+
DfsLookup(syll_graph, end_pos, prefix, state);
286+
}
237287
}
238288
}
239289
if (!state->IsPrefixMatch(current_prefix)) // 'b |' vs. 'g o \tGo'
@@ -254,12 +304,14 @@ an<UserDictEntryCollector> UserDictionary::Lookup(
254304
const SyllableGraph& syll_graph,
255305
size_t start_pos,
256306
size_t depth_limit,
307+
size_t predict_word_from_depth,
257308
double initial_credibility) {
258309
if (!table_ || !prism_ || !loaded() ||
259310
start_pos >= syll_graph.interpreted_length)
260311
return nullptr;
261312
DfsState state;
262313
state.depth_limit = depth_limit;
314+
state.predict_word_from_depth = predict_word_from_depth;
263315
FetchTickCount();
264316
state.present_tick = tick_ + 1;
265317
state.credibility.push_back(initial_credibility);

src/rime/dict/user_dictionary.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
5959
an<UserDictEntryCollector> Lookup(const SyllableGraph& syllable_graph,
6060
size_t start_pos,
6161
size_t depth_limit = 0,
62+
size_t predict_word_from_depth = 0,
6263
double initial_credibility = 0.0);
6364
size_t LookupWords(UserDictEntryIterator* result,
6465
const string& input,
@@ -82,7 +83,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
8283
const string& value,
8384
TickCount present_tick,
8485
double credibility = 0.0,
85-
string* full_code = NULL);
86+
string* full_code = nullptr);
8687

8788
protected:
8889
bool Initialize();
@@ -98,6 +99,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
9899
an<Db> db_;
99100
an<Table> table_;
100101
an<Prism> prism_;
102+
map<string, SyllableId> syllabary_;
101103
TickCount tick_ = 0;
102104
time_t transaction_time_ = 0;
103105
};

src/rime/gear/script_translator.cc

+7-2
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,11 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) {
356356

357357
phrase_ = dict->Lookup(syllable_graph, 0, predict_word);
358358
if (user_dict) {
359-
user_phrase_ = user_dict->Lookup(syllable_graph, 0);
359+
const size_t kUnlimitedDepth = 0;
360+
const size_t kNumSyllablesToPredictWord = 4;
361+
user_phrase_ =
362+
user_dict->Lookup(syllable_graph, 0, kUnlimitedDepth,
363+
predict_word ? kNumSyllablesToPredictWord : 0);
360364
}
361365
if (!phrase_ && !user_phrase_)
362366
return false;
@@ -371,7 +375,8 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) {
371375
phrase_ && phrase_iter_->first == consumed &&
372376
is_exact_match_phrase(phrase_iter_->second.Peek());
373377
bool has_exact_match_user_phrase =
374-
user_phrase_ && user_phrase_iter_->first == consumed;
378+
user_phrase_ && user_phrase_iter_->first == consumed &&
379+
is_exact_match_phrase(user_phrase_iter_->second.Peek());
375380
bool has_at_least_two_syllables = syllable_graph.edges.size() >= 2;
376381
if (!has_exact_match_phrase && !has_exact_match_user_phrase &&
377382
has_at_least_two_syllables) {

0 commit comments

Comments
 (0)