Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AutoJoiner 개선 #152

Merged
merged 5 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/kiwi/Joiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace kiwi
template<class LmState> friend struct Candidate;
const CompiledRule* cr = nullptr;
KString stack;
std::vector<std::pair<uint32_t, uint32_t>> ranges;
size_t activeStart = 0;
POSTag lastTag = POSTag::unknown, anteLastTag = POSTag::unknown;

Expand All @@ -45,8 +46,8 @@ namespace kiwi
void add(const std::u16string& form, POSTag tag, Space space = Space::none);
void add(const char16_t* form, POSTag tag, Space space = Space::none);

std::u16string getU16() const;
std::string getU8() const;
std::u16string getU16(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
std::string getU8(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
};

template<class LmState>
Expand Down Expand Up @@ -115,8 +116,8 @@ namespace kiwi
void add(const std::u16string& form, POSTag tag, bool inferRegularity = true, Space space = Space::none);
void add(const char16_t* form, POSTag tag, bool inferRegularity = true, Space space = Space::none);

std::u16string getU16() const;
std::string getU8() const;
std::u16string getU16(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
std::string getU8(std::vector<std::pair<uint32_t, uint32_t>>* rangesOut = nullptr) const;
};
}
}
25 changes: 25 additions & 0 deletions include/kiwi/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,31 @@ namespace kiwi
return ret;
}

/**
 * Joins a decomposed Hangul sequence into a UTF-16 string while recording
 * where every input element ended up in the output.
 *
 * A character accepted by isHangulCoda() that directly follows a character
 * accepted by isHangulSyllable() is folded into that preceding syllable when
 * the syllable has no final consonant yet ((syllable - 0xAC00) % 28 == 0).
 * Every other character is appended unchanged.
 *
 * @param first        iterator to the first input character
 * @param last         iterator one past the final input character
 * @param positionOut  cleared, then filled with one entry per input element:
 *                     the index in the returned string of the character that
 *                     holds (or absorbed) that element. A folded coda shares
 *                     its index with the syllable it was merged into.
 * @return the joined string
 */
template<class It, class Ty, class Alloc>
inline std::u16string joinHangul(It first, It last, std::vector<Ty, Alloc>& positionOut)
{
	const auto inputLen = std::distance(first, last);

	std::u16string joined;
	joined.reserve(inputLen);
	positionOut.clear();
	positionOut.reserve(inputLen);

	for (; first != last; ++first)
	{
		auto ch = *first;

		// A coda can only attach to an immediately preceding syllable ...
		const bool followsSyllable = isHangulCoda(ch) && !joined.empty() && isHangulSyllable(joined.back());
		// ... and only if that syllable's final-consonant slot is still empty.
		const bool foldIntoPrevious = followsSyllable && (joined.back() - 0xAC00) % 28 == 0;

		if (foldIntoPrevious) joined.back() += ch - 0x11A7;
		else joined.push_back(ch);

		// In both cases the current element now lives in the last output character.
		positionOut.emplace_back(joined.size() - 1);
	}
	return joined;
}

inline std::u16string joinHangul(const KString& hangul)
{
return joinHangul(hangul.begin(), hangul.end());
Expand Down
12 changes: 6 additions & 6 deletions src/Combiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ Vector<KString> CompiledRule::combineImpl(
return ret;
}

pair<KString, size_t> CompiledRule::combineOneImpl(
tuple<KString, size_t, size_t> CompiledRule::combineOneImpl(
U16StringView leftForm, POSTag leftTag,
U16StringView rightForm, POSTag rightTag,
CondVowel cv, CondPolarity cp
Expand All @@ -1163,12 +1163,12 @@ pair<KString, size_t> CompiledRule::combineOneImpl(
{
for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second]))
{
if(p.score >= 0) return make_pair(p.str, p.rightBegin);
if(p.score >= 0) return make_tuple(p.str, p.leftEnd, p.rightBegin);
KString ret;
ret.reserve(leftForm.size() + rightForm.size());
ret.insert(ret.end(), leftForm.begin(), leftForm.end());
ret.insert(ret.end(), rightForm.begin(), rightForm.end());
return make_pair(ret, leftForm.size());
return make_tuple(ret, leftForm.size(), leftForm.size());
}
}

Expand All @@ -1183,7 +1183,7 @@ pair<KString, size_t> CompiledRule::combineOneImpl(
{
for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second]))
{
return make_pair(p.str, p.rightBegin);
return make_tuple(p.str, p.leftEnd, p.rightBegin);
}
}
}
Expand All @@ -1198,14 +1198,14 @@ pair<KString, size_t> CompiledRule::combineOneImpl(
ret.insert(ret.end(), leftForm.begin(), leftForm.end());
ret.push_back(u'아'); // `어`를 `아`로 교체하여 삽입
ret.insert(ret.end(), rightForm.begin() + 1, rightForm.end());
return make_pair(ret, leftForm.size());
return make_tuple(ret, leftForm.size(), leftForm.size());
}
}
KString ret;
ret.reserve(leftForm.size() + rightForm.size());
ret.insert(ret.end(), leftForm.begin(), leftForm.end());
ret.insert(ret.end(), rightForm.begin(), rightForm.end());
return make_pair(ret, leftForm.size());
return make_tuple(ret, leftForm.size(), leftForm.size());
}

Vector<tuple<size_t, size_t, CondPolarity>> CompiledRule::testLeftPattern(U16StringView leftForm, size_t ruleId) const
Expand Down
5 changes: 4 additions & 1 deletion src/Combiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,10 @@ namespace kiwi
CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none
) const;

std::pair<KString, size_t> combineOneImpl(
/**
* @return tuple(combinedForm, leftFormBoundary, rightFormBoundary)
*/
std::tuple<KString, size_t, size_t> combineOneImpl(
U16StringView leftForm, POSTag leftTag,
U16StringView rightForm, POSTag rightTag,
CondVowel cv = CondVowel::none, CondPolarity cp = CondPolarity::none
Expand Down
122 changes: 102 additions & 20 deletions src/Joiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace kiwi
if (l == POSTag::sn && r == POSTag::nr) return false;
if (l == POSTag::sso || l == POSTag::ssc) return false;
if (r == POSTag::sso) return true;
if ((isJClass(l) || isEClass(l)) && r == POSTag::ss) return true;

if (r == POSTag::vx && rform.size() == 1 && (rform[0] == u'하' || rform[0] == u'지')) return false;

Expand Down Expand Up @@ -79,9 +80,11 @@ namespace kiwi

void Joiner::add(U16StringView form, POSTag tag, Space space)
{
KString normForm = normalizeHangul(form);
if (stack.size() == activeStart)
{
stack += normalizeHangul(form);
ranges.emplace_back(stack.size(), stack.size() + normForm.size());
stack += normForm;
lastTag = tag;
return;
}
Expand All @@ -90,7 +93,8 @@ namespace kiwi
{
if (stack.empty() || !isSpace(stack.back())) stack.push_back(u' ');
activeStart = stack.size();
stack += normalizeHangul(form);
ranges.emplace_back(stack.size(), stack.size() + normForm.size());
stack += normForm;
}
else
{
Expand All @@ -100,8 +104,6 @@ namespace kiwi
cv = isHangulSyllable(stack[activeStart - 1]) ? CondVowel::vowel : CondVowel::non_vowel;
}

KString normForm = normalizeHangul(form);

if (!stack.empty() && (isJClass(tag) || isEClass(tag)))
{
if (isEClass(tag) && normForm[0] == u'아') normForm[0] = u'어';
Expand Down Expand Up @@ -148,8 +150,10 @@ namespace kiwi
}
auto r = cr->combineOneImpl({ stack.data() + activeStart, stack.size() - activeStart }, lastTag, normForm, tag, cv);
stack.erase(stack.begin() + activeStart, stack.end());
stack += r.first;
activeStart += r.second;
ranges.back().second = activeStart + get<1>(r);
ranges.emplace_back(activeStart + get<2>(r), activeStart + get<0>(r).size());
stack += get<0>(r);
activeStart += get<2>(r);
}
anteLastTag = lastTag;
lastTag = tag;
Expand All @@ -165,14 +169,46 @@ namespace kiwi
return add(U16StringView{ form }, tag, space);
}

u16string Joiner::getU16() const
u16string Joiner::getU16(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return joinHangul(stack);
if (rangesOut)
{
rangesOut->clear();
rangesOut->reserve(ranges.size());
Vector<uint32_t> u16pos;
auto ret = joinHangul(stack.begin(), stack.end(), u16pos);
u16pos.emplace_back(ret.size());
for (auto& r : ranges)
{
auto endOffset = u16pos[r.second] + (r.second > 0 && u16pos[r.second - 1] == u16pos[r.second] ? 1 : 0);
rangesOut->emplace_back(u16pos[r.first], endOffset);
}
return ret;
}
else
{
return joinHangul(stack);
}
}

string Joiner::getU8() const
string Joiner::getU8(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return utf16To8(joinHangul(stack));
auto u16 = getU16(rangesOut);
if (rangesOut)
{
Vector<uint32_t> positions;
auto ret = utf16To8(u16, positions);
for (auto& r : *rangesOut)
{
r.first = positions[r.first];
r.second = positions[r.second];
}
return ret;
}
else
{
return utf16To8(u16);
}
}

AutoJoiner::~AutoJoiner()
Expand Down Expand Up @@ -264,24 +300,31 @@ namespace kiwi
if (!node) break;
}

// prevent unknown or partial tag
POSTag fixedTag = tag;
if (tag == POSTag::unknown || tag == POSTag::p)
{
fixedTag = POSTag::nnp;
}

if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie)))
{
Vector<const Morpheme*> cands;
foreachMorpheme(formHead, [&](const Morpheme* m)
{
if (inferRegularity && clearIrregular(m->tag) == clearIrregular(tag))
if (inferRegularity && clearIrregular(m->tag) == clearIrregular(fixedTag))
{
cands.emplace_back(m);
}
else if (!inferRegularity && m->tag == tag)
else if (!inferRegularity && m->tag == fixedTag)
{
cands.emplace_back(m);
}
});

if (cands.size() <= 1)
{
auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(tag)) : cands[0]->lmMorphemeId;
auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId;
if (!cands.empty()) tag = cands[0]->tag;
for (auto& cand : candidates)
{
Expand All @@ -308,11 +351,36 @@ namespace kiwi
n.score += n.lmState.next(kiwi->langMdl, cands[0]->lmMorphemeId);
n.joiner.add(form, cands[0]->tag, space);
}

UnorderedMap<LmState, pair<float, uint32_t>> bestScoreByState;
for (size_t i = 0; i < candidates.size(); ++i)
{
auto& c = candidates[i];
auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, i));
if (!inserted.second)
{
if (inserted.first->second.first < c.score)
{
inserted.first->second = make_pair(c.score, i);
}
}
}

if (bestScoreByState.size() < candidates.size())
{
Vector<Candidate<LmState>> newCandidates;
newCandidates.reserve(bestScoreByState.size());
for (auto& p : bestScoreByState)
{
newCandidates.emplace_back(move(candidates[p.second.second]));
}
candidates = move(newCandidates);
}
}
}
else
{
auto lmId = getDefaultMorphemeId(clearIrregular(tag));
auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag));
for (auto& cand : candidates)
{
cand.score += cand.lmState.next(kiwi->langMdl, lmId);
Expand Down Expand Up @@ -422,19 +490,33 @@ namespace kiwi

/**
 * Visitor that returns the joined UTF-16 text of the first candidate in the
 * visited candidate vector.
 *
 * The stored `rangesOut` pointer is forwarded to Joiner::getU16(). The
 * previous body returned `getU16()` without the argument first, which left
 * `rangesOut` unused and made the forwarding return unreachable.
 */
struct GetU16Visitor
{
	/// Optional output for ranges; passed through to Joiner::getU16() as-is (may be null).
	vector<pair<uint32_t, uint32_t>>* rangesOut;

	GetU16Visitor(vector<pair<uint32_t, uint32_t>>* _rangesOut)
		: rangesOut{ _rangesOut }
	{
	}

	/// Reads `o[0]` without a bounds check, so `o` must not be empty.
	template<class LmState>
	u16string operator()(const Vector<Candidate<LmState>>& o) const
	{
		return o[0].joiner.getU16(rangesOut);
	}
};

/**
 * Visitor that returns the joined UTF-8 text of the first candidate in the
 * visited candidate vector.
 *
 * The stored `rangesOut` pointer is forwarded to Joiner::getU8(). The
 * previous body returned `getU8()` without the argument first, which left
 * `rangesOut` unused and made the forwarding return unreachable.
 */
struct GetU8Visitor
{
	/// Optional output for ranges; passed through to Joiner::getU8() as-is (may be null).
	vector<pair<uint32_t, uint32_t>>* rangesOut;

	GetU8Visitor(vector<pair<uint32_t, uint32_t>>* _rangesOut)
		: rangesOut{ _rangesOut }
	{
	}

	/// Reads `o[0]` without a bounds check, so `o` must not be empty.
	template<class LmState>
	string operator()(const Vector<Candidate<LmState>>& o) const
	{
		return o[0].joiner.getU8(rangesOut);
	}
};

Expand All @@ -458,14 +540,14 @@ namespace kiwi
return mapbox::util::apply_visitor(AddVisitor{ this, form, tag, false, space }, reinterpret_cast<CandVector&>(candBuf));
}

u16string AutoJoiner::getU16() const
u16string AutoJoiner::getU16(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return mapbox::util::apply_visitor(GetU16Visitor{}, reinterpret_cast<const CandVector&>(candBuf));
return mapbox::util::apply_visitor(GetU16Visitor{ rangesOut }, reinterpret_cast<const CandVector&>(candBuf));
}

string AutoJoiner::getU8() const
string AutoJoiner::getU8(vector<pair<uint32_t, uint32_t>>* rangesOut) const
{
return mapbox::util::apply_visitor(GetU8Visitor{}, reinterpret_cast<const CandVector&>(candBuf));
return mapbox::util::apply_visitor(GetU8Visitor{ rangesOut }, reinterpret_cast<const CandVector&>(candBuf));
}
}
}
Loading
Loading