diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index bb789004..304f85b8 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -55,6 +55,11 @@ namespace kiwi } }; + template + struct WordLL; + + using Wid = uint32_t; + class PathEvaluator { public: @@ -115,19 +120,45 @@ namespace kiwi const Vector& prevSpStates, const KGraphNode* graph, const size_t graphSize, - size_t topN, + const size_t topN, bool openEnd, bool splitComplex = false, const std::unordered_set* blocklist = nullptr ); - template - static float evalPath(const Kiwi* kw, const KGraphNode* startNode, const KGraphNode* node, - CacheTy& cache, const Vector& ownFormList, - size_t i, size_t ownFormId, CandTy&& cands, bool unknownForm, + template + static float evalPath(const Kiwi* kw, + const KGraphNode* startNode, + const KGraphNode* node, + const size_t topN, + Vector>>& cache, + const Vector& ownFormList, + size_t i, + size_t ownFormId, + CandTy&& cands, + bool unknownForm, bool splitComplex = false, const std::unordered_set* blocklist = nullptr ); + + template + static void evalSingleMorpheme( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + array seq, + array oseq, + size_t chSize, + uint8_t combSocket, + size_t ownFormId, + const Morpheme* curMorph, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const float ignoreCondScore, + const float nodeLevelDiscount + ); }; using FnFindBestPath = decltype(&PathEvaluator::findBestPath>); @@ -598,36 +629,6 @@ namespace kiwi return ret; } - template - void emplaceMaxCnt(Map& dest, Vector& vector, Key&& key, CompareKey ckey, Value&& value, size_t maxCnt, Comp comparator) - { - auto p = dest.emplace(std::piecewise_construct, std::forward_as_tuple(key), std::forward_as_tuple()); - auto itp = p.first; - if (p.second) - { - itp->second.reserve(maxCnt); - } - - if (itp->second.size() < maxCnt) - { - itp->second.emplace_back(ckey, vector.size()); - vector.emplace_back(std::forward(value)); - push_heap(itp->second.begin(), itp->second.end(), comparator); - } - else - { - if (comparator(ckey, itp->second.front().first)) - { - pop_heap(itp->second.begin(), itp->second.end(), comparator); - itp->second.back().first = ckey; - vector[itp->second.back().second] = value; - push_heap(itp->second.begin(), itp->second.end(), comparator); - } - } - } - - using Wid = uint32_t; - template struct WordLL { @@ -638,31 +639,25 @@ namespace kiwi Wid wid = 0; uint16_t ownFormId = 0; uint8_t combineSocket = 0; - SpecialState spState; + SpecialState spState, rootSpState; WordLL() = default; WordLL(const Morpheme* _morph, float _accScore, float _accTypoCost, const WordLL* _parent, LmState _lmState, SpecialState _spState) - : morpheme{ _morph }, accScore{ _accScore }, accTypoCost{ _accTypoCost }, parent{ _parent }, lmState{ _lmState }, spState(_spState) + : morpheme{ _morph }, + accScore{ _accScore }, + accTypoCost{ _accTypoCost }, + parent{ _parent }, + lmState{ _lmState }, + spState{ _spState }, + rootSpState{ parent ? parent->rootSpState : spState } { } - }; - - template - struct WordLLP - { - const Morpheme* lastMorpheme = nullptr; - float accScore = 0, accTypoCost = 0; - const WordLL* parent = nullptr; - LmState lmState; - SpecialState spState; - WordLLP() = default; - - WordLLP(const Morpheme* _lastMorph, float _accScore, float _accTypoCost, const WordLL* _parent, LmState _lmState, SpecialState _spState) - : lastMorpheme{ _lastMorph }, accScore{ _accScore }, accTypoCost{ _accTypoCost }, - parent{ _parent }, lmState{ _lmState }, spState(_spState) + const WordLL* root() const { + if (parent) return parent->root(); + else return this; } }; @@ -670,22 +665,21 @@ namespace kiwi struct PathHash { LmState lmState; - uint16_t lastMorpheme; uint8_t spState; - PathHash(LmState _lmState = {}, uint16_t _lastMorpheme = 0, SpecialState _spState = {}) - : lmState{ _lmState }, lastMorpheme{ _lastMorpheme }, spState{ _spState } + PathHash(LmState _lmState = {}, SpecialState _spState = {}) + : lmState{ _lmState }, spState{ _spState } { } PathHash(const WordLL& wordLl, const Morpheme* morphBase) - : lmState{ wordLl.lmState }, lastMorpheme{ (uint16_t)(wordLl.morpheme - morphBase) }, spState{ wordLl.spState } + : PathHash{ wordLl.lmState, wordLl.root()->spState } { } bool operator==(const PathHash& o) const { - return lmState == o.lmState && lastMorpheme == o.lastMorpheme && spState == o.spState; + return lmState == o.lmState && spState == o.spState; } }; @@ -694,18 +688,25 @@ namespace kiwi { using LmState = SbgState; + KnLMState<_arch, VocabTy> lmState; array lastMorphemes; + uint8_t spState; + + PathHash(LmState _lmState = {}, SpecialState _spState = {}) + : lmState{ _lmState }, spState{ _spState } + { + _lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); + } + PathHash(const WordLL& wordLl, const Morpheme* morphBase) + : PathHash{ wordLl.lmState, wordLl.root()->spState } { - wordLl.lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); - lastMorphemes.back() = wordLl.morpheme - morphBase; - lastMorphemes[0] = ((lastMorphemes[0] << 8) >> 8) | ((VocabTy)wordLl.spState) << ((sizeof(VocabTy) - 1) * 8); } bool operator==(const PathHash& o) const { - return lastMorphemes == o.lastMorphemes; + return lmState == o.lmState && lastMorphemes == o.lastMorphemes && spState == o.spState; } }; @@ -744,34 +745,209 @@ namespace kiwi } }; - template - void evalTrigram(const LangModel& langMdl, - const Morpheme* morphBase, - const Vector& ownForms, + inline bool hasLeftBoundary(const KGraphNode* node) + { + // 시작 지점은 항상 왼쪽 경계로 처리 + if (node->getPrev()->endPos == 0) return true; + + // 이전 노드의 끝지점이 현재 노드보다 작은 경우 왼쪽 경계로 처리 + if (node->getPrev()->endPos < node->startPos) return true; + + // 이전 노드가 구두점이나 특수 문자인 경우 + if (!node->getPrev()->uform.empty()) + { + // 닫는 괄호는 왼쪽 경계로 처리하지 않음 + auto c = node->getPrev()->uform.back(); + auto tag = identifySpecialChr(c); + if (tag == POSTag::ssc || c == u'"' || c == u'\'') return false; + + // 나머지 특수문자는 왼쪽 경계로 처리 + if (POSTag::sf <= tag && tag <= POSTag::sb) return true; + } + return false; + } + + inline bool isInflectendaNP(const Morpheme* morph) + { + return morph->tag == POSTag::np && morph->kform->size() == 1 && ( + (*morph->kform)[0] == u'나' || (*morph->kform)[0] == u'너' || (*morph->kform)[0] == u'저'); + } + + inline bool isInflectendaJ(const Morpheme* morph) + { + return (morph->tag == POSTag::jks || morph->tag == POSTag::jkc) && morph->kform->size() == 1 && (*morph->kform)[0] == u'가'; + } + + inline bool isVerbL(const Morpheme* morph) + { + return isVerbClass(morph->tag) && morph->kform && !morph->kform->empty() && morph->kform->back() == u'ᆯ'; + } + + inline bool isBadPairOfVerbL(const Morpheme* morph) + { + auto onset = (morph->kform && !morph->kform->empty()) ? morph->kform->front() : 0; + return onset == u'으' || onset == u'느' || (u'사' <= onset && onset <= u'시'); + } + + inline bool isPositiveVerb(const Morpheme* morph) + { + return isVerbClass(morph->tag) && FeatureTestor::isMatched(morph->kform, CondPolarity::positive); + } + + inline bool isNegativeVerb(const Morpheme* morph) + { + return isVerbClass(morph->tag) && FeatureTestor::isMatched(morph->kform, CondPolarity::negative); + } + + inline bool isVerbVowel(const Morpheme* morph) + { + return isVerbClass(morph->tag) && morph->kform && !morph->kform->empty() && !isHangulCoda(morph->kform->back()); + } + + inline uint8_t hashSbTypeOrder(uint8_t type, uint8_t order) + { + return ((type << 1) ^ (type >> 7) ^ order) % 63 + 1; + } + + struct RuleBasedScorer + { + Kiwi::SpecialMorph curMorphSpecialType; + size_t curMorphSbType; + int curMorphSbOrder; + bool vowelE, infJ, badPairOfL, positiveE, contractableE; + CondPolarity condP; + + RuleBasedScorer(const Kiwi* kw, const Morpheme* curMorph, const KGraphNode* node) + : + curMorphSpecialType{ kw->determineSpecialMorphType(kw->morphToId(curMorph)) }, + curMorphSbType{ curMorph->tag == POSTag::sb ? getSBType(joinHangul(*curMorph->kform)) : 0 }, + curMorphSbOrder{ curMorphSbType ? curMorph->senseId : 0 }, + vowelE{ isEClass(curMorph->tag) && curMorph->kform && hasNoOnset(*curMorph->kform) }, + infJ{ isInflectendaJ(curMorph) }, + badPairOfL{ isBadPairOfVerbL(curMorph) }, + positiveE{ isEClass(curMorph->tag) && node->form && node->form->form[0] == u'아' }, + contractableE{ isEClass(curMorph->tag) && curMorph->kform && !curMorph->kform->empty() && (*curMorph->kform)[0] == u'어' }, + condP{ curMorph->polar } + { + } + + float operator()(const Morpheme* prevMorpheme, const SpecialState prevSpState) const + { + float accScore = 0; + + // 불규칙 활용 형태소 뒤에 모음 어미가 붙는 경우 벌점 부여 + if (vowelE && isIrregular(prevMorpheme->tag)) + { + accScore -= 10; + } + // 나/너/저 뒤에 주격 조사 '가'가 붙는 경우 벌점 부여 + if (infJ && isInflectendaNP(prevMorpheme)) + { + accScore -= 5; + } + // ㄹ 받침 용언 뒤에 으/느/ㅅ으로 시작하는 형태소가 올 경우 벌점 부여 + if (badPairOfL && isVerbL(prevMorpheme)) + { + accScore -= 7; + } + // 동사 뒤가 아니거나, 앞의 동사가 양성이 아닌데, 양성모음용 어미가 등장한 경우 벌점 부여 + if (positiveE && !isPositiveVerb(prevMorpheme)) + { + accScore -= 100; + } + // 아/어로 시작하는 어미가 받침 없는 동사 뒤에서 축약되지 않은 경우 벌점 부여 + if (contractableE && isVerbVowel(prevMorpheme)) + { + accScore -= 3; + } + // 형용사 사용 불가 어미인데 형용사 뒤에 등장 + if (condP == CondPolarity::non_adj && (prevMorpheme->tag == POSTag::va || prevMorpheme->tag == POSTag::xsa)) + { + accScore -= 10; + } + if (curMorphSpecialType <= Kiwi::SpecialMorph::singleQuoteNA) + { + if (static_cast(curMorphSpecialType) != prevSpState.singleQuote) + { + accScore -= 2; + } + } + else if (curMorphSpecialType <= Kiwi::SpecialMorph::doubleQuoteNA) + { + if ((static_cast(curMorphSpecialType) - 3) != prevSpState.doubleQuote) + { + accScore -= 2; + } + } + + // discount for SB in form "[가-하]." + if (curMorphSbType == 5) + { + accScore -= 5; + } + + if (curMorphSbType && isEClass(prevMorpheme->tag) && prevMorpheme->tag != POSTag::ef) + { + accScore -= 10; + } + + if (curMorphSbType && prevSpState.bulletHash == hashSbTypeOrder(curMorphSbType, curMorphSbOrder)) + { + accScore += 3; + } + + return accScore; + } + }; + + template + void PathEvaluator::evalSingleMorpheme( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, const Vector>>& cache, - array seq, - size_t chSize, - const Morpheme* curMorph, - const KGraphNode* node, - const KGraphNode* startNode, - _Map& maxWidLL, - Vector>& nextPredCands, - float ignoreCondScore, - float spacePenalty, - bool allowedSpaceBetweenChunk + array seq, + array oseq, + size_t chSize, + uint8_t combSocket, + size_t ownFormId, + const Morpheme* curMorph, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const float ignoreCondScore, + const float nodeLevelDiscount ) { - size_t vocabSize = langMdl.knlm->getHeader().vocab_size; + // pair: [index, size] + thread_local UnorderedMap, pair> bestPathIndex; + thread_local Vector> bestPathValues; + bestPathIndex.clear(); + bestPathValues.clear(); + + const LangModel& langMdl = kw->langMdl; + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + + float additionalScore = curMorph->userScore + nodeLevelDiscount; + additionalScore += kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + + float discountForCombining = curMorph->combineSocket ? -15 : 0; + + const size_t vocabSize = langMdl.knlm->getHeader().vocab_size; for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { assert(prev != node); - for (auto& p : cache[prev - startNode]) + for (auto& prevPath : cache[prev - startNode]) { - float candScore = p.accScore; - if (p.combineSocket) + float candScore = prevPath.accScore + additionalScore; + if (prevPath.combineSocket) { // merge with only the same socket - if (p.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex)) + if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex)) { continue; } @@ -780,19 +956,19 @@ namespace kiwi if (allowedSpaceBetweenChunk) candScore -= spacePenalty; else continue; } - seq[0] = morphBase[p.wid].getCombined()->lmMorphemeId; + seq[0] = morphBase[prevPath.wid].getCombined()->lmMorphemeId; } const kchar_t* leftFormFirst, * leftFormLast; - if (p.ownFormId) + if (prevPath.ownFormId) { - leftFormFirst = ownForms[p.ownFormId - 1].data(); - leftFormLast = ownForms[p.ownFormId - 1].data() + ownForms[0].size(); + leftFormFirst = ownForms[prevPath.ownFormId - 1].data(); + leftFormLast = ownForms[prevPath.ownFormId - 1].data() + ownForms[0].size(); } - else if (morphBase[p.wid].kform) + else if (morphBase[prevPath.wid].kform) { - leftFormFirst = morphBase[p.wid].kform->data(); - leftFormLast = morphBase[p.wid].kform->data() + morphBase[p.wid].kform->size(); + leftFormFirst = morphBase[prevPath.wid].kform->data(); + leftFormLast = morphBase[prevPath.wid].kform->data() + morphBase[prevPath.wid].kform->size(); } else { @@ -802,7 +978,7 @@ namespace kiwi CondVowel cvowel = curMorph->vowel; CondPolarity cpolar = curMorph->polar; - if (p.morpheme->tag == POSTag::ssc) + if (prevPath.morpheme->tag == POSTag::ssc) { // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 } @@ -815,11 +991,11 @@ namespace kiwi if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) continue; } - auto cLmState = p.lmState; + auto cLmState = prevPath.lmState; Wid lSeq = 0; if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex)) { - lSeq = p.wid; + lSeq = prevPath.wid; } else { @@ -835,89 +1011,70 @@ namespace kiwi candScore += ll; } } - emplaceMaxCnt( - maxWidLL, - nextPredCands, - lSeq, - candScore, - WordLLP{ &morphBase[p.wid], candScore, p.accTypoCost + node->typoCost, &p, move(cLmState), p.spState }, - 3, - GenericGreater{} - ); + + { + const auto* prevMorpheme = &morphBase[prevPath.wid]; + const auto prevSpState = prevPath.spState; + candScore += ruleBasedScorer(prevMorpheme, prevSpState); + + PathHash ph{ cLmState, prevPath.rootSpState }; + auto inserted = bestPathIndex.emplace(ph, make_pair(bestPathValues.size(), 1)); + if (inserted.second) + { + bestPathValues.emplace_back(curMorph, candScore, prevPath.accTypoCost + node->typoCost, &prevPath, move(cLmState), prevPath.spState); + } + else + { + auto& target = bestPathValues[inserted.first->second.first]; + if (candScore > target.accScore) + { + target = WordLL{ curMorph, candScore, prevPath.accTypoCost + node->typoCost, &prevPath, move(cLmState), prevPath.spState }; + } + ++inserted.first->second.second; + } + } + continueFor:; } } - } - inline bool hasLeftBoundary(const KGraphNode* node) - { - // 시작 지점은 항상 왼쪽 경계로 처리 - if (node->getPrev()->endPos == 0) return true; - - // 이전 노드의 끝지점이 현재 노드보다 작은 경우 왼쪽 경계로 처리 - if (node->getPrev()->endPos < node->startPos) return true; - - // 이전 노드가 구두점이나 특수 문자인 경우 - if (!node->getPrev()->uform.empty()) + for (auto& p : bestPathIndex) { - // 닫는 괄호는 왼쪽 경계로 처리하지 않음 - auto c = node->getPrev()->uform.back(); - auto tag = identifySpecialChr(c); - if (tag == POSTag::ssc || c == u'"' || c == u'\'') return false; - - // 나머지 특수문자는 왼쪽 경계로 처리 - if (POSTag::sf <= tag && tag <= POSTag::sb) return true; - } - return false; - } - - inline bool isInflectendaNP(const Morpheme* morph) - { - return morph->tag == POSTag::np && morph->kform->size() == 1 && ( - (*morph->kform)[0] == u'나' || (*morph->kform)[0] == u'너' || (*morph->kform)[0] == u'저'); - } - - inline bool isInflectendaJ(const Morpheme* morph) - { - return (morph->tag == POSTag::jks || morph->tag == POSTag::jkc) && morph->kform->size() == 1 && (*morph->kform)[0] == u'가'; - } + const auto index = p.second.first; + const auto size = p.second.second; + resultOut.emplace_back(move(bestPathValues[index])); + auto& newPath = resultOut.back(); - inline bool isVerbL(const Morpheme* morph) - { - return isVerbClass(morph->tag) && morph->kform && !morph->kform->empty() && morph->kform->back() == u'ᆯ'; - } - - inline bool isBadPairOfVerbL(const Morpheme* morph) - { - auto onset = (morph->kform && !morph->kform->empty()) ? morph->kform->front() : 0; - return onset == u'으' || onset == u'느' || (u'사' <= onset && onset <= u'시'); - } - - inline bool isPositiveVerb(const Morpheme* morph) - { - return isVerbClass(morph->tag) && FeatureTestor::isMatched(morph->kform, CondPolarity::positive); - } - - inline bool isNegativeVerb(const Morpheme* morph) - { - return isVerbClass(morph->tag) && FeatureTestor::isMatched(morph->kform, CondPolarity::negative); - } - - inline bool isVerbVowel(const Morpheme* morph) - { - return isVerbClass(morph->tag) && morph->kform && !morph->kform->empty() && !isHangulCoda(morph->kform->back()); - } + // fill the rest information of resultOut + if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) newPath.spState.singleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) newPath.spState.singleQuote = 0; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) newPath.spState.doubleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) newPath.spState.doubleQuote = 0; + if (ruleBasedScorer.curMorphSbType) + { + newPath.spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + } - inline uint8_t hashSbTypeOrder(uint8_t type, uint8_t order) - { - return ((type << 1) ^ (type >> 7) ^ order) % 63 + 1; + if (curMorph->chunks.empty() || curMorph->complex) + { + newPath.wid = oseq[0]; + newPath.combineSocket = combSocket; + newPath.ownFormId = ownFormId; + } + else + { + newPath.wid = oseq[chSize - 1]; + } + } + return; } - template + template float PathEvaluator::evalPath(const Kiwi* kw, const KGraphNode* startNode, const KGraphNode* node, - CacheTy& cache, + const size_t topN, + Vector>>& cache, const Vector& ownFormList, size_t i, size_t ownFormId, @@ -929,12 +1086,22 @@ namespace kiwi { const size_t langVocabSize = kw->langMdl.knlm->getHeader().vocab_size; auto& nCache = cache[i]; + Vector> refCache; float whitespaceDiscount = 0; if (node->uform.empty() && node->endPos - node->startPos > node->form->form.size()) { whitespaceDiscount = -kw->spacePenalty * (node->endPos - node->startPos - node->form->form.size()); } + const float typoDiscount = -node->typoCost * kw->typoCostWeight; + float unknownFormDiscount = 0; + if (unknownForm) + { + size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); + unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); + } + + const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; float tMax = -INFINITY; for (bool ignoreCond : {false, true}) @@ -954,11 +1121,12 @@ namespace kiwi auto lastTag = kw->morphemes[p.wid].tag; if (!isJClass(lastTag) && !isEClass(lastTag)) continue; nCache.emplace_back(p); - nCache.back().accScore += curMorph->userScore * kw->typoCostWeight; - nCache.back().accTypoCost -= curMorph->userScore; - nCache.back().parent = &p; - nCache.back().morpheme = &kw->morphemes[curMorph->lmMorphemeId]; - nCache.back().wid = curMorph->lmMorphemeId; + auto& newPath = nCache.back(); + newPath.accScore += curMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= curMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; + newPath.wid = curMorph->lmMorphemeId; } } continue; @@ -967,10 +1135,9 @@ namespace kiwi array seq = { 0, }; array oseq = { 0, }; uint8_t combSocket = 0; - CondVowel condV = CondVowel::none; - CondPolarity condP = CondPolarity::none; + CondVowel condV = curMorph->vowel; + CondPolarity condP = curMorph->polar; size_t chSize = 1; - bool isUserWord = false; // if the morpheme has chunk set if (!curMorph->chunks.empty() && !curMorph->complex) { @@ -1005,7 +1172,6 @@ namespace kiwi seq[0] = curMorph->lmMorphemeId; if (within(curMorph->getCombined() ? curMorph->getCombined() : curMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) { - isUserWord = true; oseq[0] = curMorph - kw->morphemes.data(); } else @@ -1014,141 +1180,27 @@ namespace kiwi } combSocket = curMorph->combineSocket; } - condV = curMorph->vowel; - condP = curMorph->polar; - - thread_local BMap, 3>> maxWidLL; - thread_local Vector> nextPredCands; - maxWidLL.clear(); - nextPredCands.clear(); - evalTrigram(kw->langMdl, kw->morphemes.data(), ownFormList, cache, seq, chSize, curMorph, node, startNode, maxWidLL, nextPredCands, ignoreCond ? -10 : 0, kw->spacePenalty, kw->spaceTolerance > 0); - if (maxWidLL.empty()) continue; - - float estimatedLL = curMorph->userScore + whitespaceDiscount - node->typoCost * kw->typoCostWeight; - // if a form of the node is unknown, calculate log poisson distribution for word-tag - if (unknownForm) - { - size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); - estimatedLL -= unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias; - } - - float discountForCombining = curMorph->combineSocket ? -15 : 0; - estimatedLL += kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); - - auto curMorphSpecialType = kw->determineSpecialMorphType(kw->morphToId(curMorph)); - auto curMorphSbType = curMorph->tag == POSTag::sb ? getSBType(joinHangul(*curMorph->kform)) : 0; - auto curMorphSbOrder = curMorphSbType ? curMorph->senseId : 0; - - bool vowelE = isEClass(curMorph->tag) && curMorph->kform && hasNoOnset(*curMorph->kform); - bool infJ = isInflectendaJ(curMorph); - bool badPairOfL = isBadPairOfVerbL(curMorph); - bool positiveE = isEClass(curMorph->tag) && node->form && node->form->form[0] == u'아'; - bool contractableE = isEClass(curMorph->tag) && curMorph->kform && !curMorph->kform->empty() && (*curMorph->kform)[0] == u'어'; - - for (auto& p : maxWidLL) - { - for (auto& qp : p.second) - { - auto& q = nextPredCands[qp.second]; - q.accScore += estimatedLL; - // 불규칙 활용 형태소 뒤에 모음 어미가 붙는 경우 벌점 부여 - if (vowelE && isIrregular(q.lastMorpheme->tag)) - { - q.accScore -= 10; - } - // 나/너/저 뒤에 주격 조사 '가'가 붙는 경우 벌점 부여 - if (infJ && isInflectendaNP(q.lastMorpheme)) - { - q.accScore -= 5; - } - // ㄹ 받침 용언 뒤에 으/느/ㅅ으로 시작하는 형태소가 올 경우 벌점 부여 - if (badPairOfL && isVerbL(q.lastMorpheme)) - { - q.accScore -= 7; - } - // 동사 뒤가 아니거나, 앞의 동사가 양성이 아닌데, 양성모음용 어미가 등장한 경우 벌점 부여 - if (positiveE && !isPositiveVerb(q.lastMorpheme)) - { - q.accScore -= 100; - } - // 아/어로 시작하는 어미가 받침 없는 동사 뒤에서 축약되지 않은 경우 벌점 부여 - if (contractableE && isVerbVowel(q.lastMorpheme)) - { - q.accScore -= 3; - } - // 형용사 사용 불가 어미인데 형용사 뒤에 등장 - if (condP == CondPolarity::non_adj && (q.lastMorpheme->tag == POSTag::va || q.lastMorpheme->tag == POSTag::xsa)) - { - q.accScore -= 10; - } - if (curMorphSpecialType <= Kiwi::SpecialMorph::singleQuoteNA) - { - if (static_cast(curMorphSpecialType) != q.spState.singleQuote) - { - q.accScore -= 2; - } - } - else if (curMorphSpecialType <= Kiwi::SpecialMorph::doubleQuoteNA) - { - if ((static_cast(curMorphSpecialType) - 3) != q.spState.doubleQuote) - { - q.accScore -= 2; - } - } - - // discount for SB in form "[가-하]." - if (curMorphSbType == 5) - { - q.accScore -= 5; - } - - if (curMorphSbType && isEClass(q.lastMorpheme->tag) && q.lastMorpheme->tag != POSTag::ef) - { - q.accScore -= 10; - } - - if (curMorphSbType && q.spState.bulletHash == hashSbTypeOrder(curMorphSbType, curMorphSbOrder)) - { - q.accScore += 3; - } - - tMax = max(tMax, q.accScore + discountForCombining); - } - } - - for (auto& p : maxWidLL) - { - for (auto& qp : p.second) - { - auto& q = nextPredCands[qp.second]; - if (q.accScore <= tMax - kw->cutOffThreshold) continue; - nCache.emplace_back(curMorph, q.accScore, q.accTypoCost, q.parent, q.lmState, q.spState); - - if (curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) nCache.back().spState.singleQuote = 1; - else if (curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) nCache.back().spState.singleQuote = 0; - else if (curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) nCache.back().spState.doubleQuote = 1; - else if (curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) nCache.back().spState.doubleQuote = 0; - if (curMorphSbType) - { - nCache.back().spState.bulletHash = hashSbTypeOrder(curMorphSbType, curMorphSbOrder + 1); - } - auto& back = nCache.back(); - if (curMorph->chunks.empty() || curMorph->complex) - { - back.wid = oseq[0]; - back.combineSocket = combSocket; - back.ownFormId = ownFormId; - } - else - { - back.wid = oseq[chSize - 1]; - } - } - } + evalSingleMorpheme(nCache, kw, ownFormList, cache, seq, oseq, chSize, combSocket, ownFormId, curMorph, node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount); } if (!nCache.empty()) break; } + + tMax = -INFINITY; + for (auto& c : nCache) + { + if (c.morpheme->combineSocket) continue; + tMax = max(tMax, c.accScore); + } + + size_t validCount = 0; + for (size_t i = 0; i < nCache.size(); ++i) + { + if (nCache[i].accScore + kw->cutOffThreshold < tMax) continue; + if (validCount != i) nCache[validCount] = move(nCache[i]); + validCount++; + } + nCache.resize(validCount); return tMax; } @@ -1250,7 +1302,7 @@ namespace kiwi const Vector& prevSpStates, const KGraphNode* graph, const size_t graphSize, - size_t topN, + const size_t topN, bool openEnd, bool splitComplex, const std::unordered_set* blocklist @@ -1274,7 +1326,7 @@ namespace kiwi // start node if (prevSpStates.empty()) { - cache.front().emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ kw->langMdl }, SpecialState{}); + cache[0].emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ kw->langMdl }, SpecialState{}); } else { @@ -1283,7 +1335,7 @@ namespace kiwi uniqStates.erase(unique(uniqStates.begin(), uniqStates.end()), uniqStates.end()); for (auto& spState : uniqStates) { - cache.front().emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ kw->langMdl }, spState); + cache[0].emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ kw->langMdl }, spState); } } @@ -1301,7 +1353,7 @@ namespace kiwi if (node->form) { - tMax = evalPath(kw, startNode, node, cache, ownFormList, i, ownFormId, node->form->candidate, false, splitComplex, blocklist); + tMax = evalPath(kw, startNode, node, topN, cache, ownFormList, i, ownFormId, node->form->candidate, false, splitComplex, blocklist); if (all_of(node->form->candidate.begin(), node->form->candidate.end(), [](const Morpheme* m) { return m->combineSocket || (!m->chunks.empty() && !m->complex); @@ -1309,50 +1361,12 @@ namespace kiwi { ownFormList.emplace_back(node->form->form); ownFormId = ownFormList.size(); - tMax = min(tMax, evalPath(kw, startNode, node, cache, ownFormList, i, ownFormId, unknownNodeLCands, true, splitComplex, blocklist)); + tMax = min(tMax, evalPath(kw, startNode, node, topN, cache, ownFormList, i, ownFormId, unknownNodeLCands, true, splitComplex, blocklist)); }; } else { - tMax = evalPath(kw, startNode, node, cache, ownFormList, i, ownFormId, unknownNodeCands, true, splitComplex, blocklist); - } - - // heuristically remove cands having lower ll to speed up - if (cache[i].size() > topN) - { - UnorderedMap, pair*, float>> bestPathes; - float cutoffScore = -INFINITY, cutoffScoreWithCombined = -INFINITY; - for (auto& c : cache[i]) - { - PathHash ph{ c, kw->morphemes.data() }; - auto insertResult = bestPathes.emplace(ph, make_pair(&c, c.accScore)); - if (!insertResult.second) - { - if (c.accScore > insertResult.first->second.second) - { - insertResult.first->second = make_pair(&c, c.accScore); - } - } - if (c.combineSocket) - { - cutoffScoreWithCombined = max(cutoffScoreWithCombined, c.accScore); - } - else - { - cutoffScore = max(cutoffScore, c.accScore); - } - } - cutoffScore -= kw->cutOffThreshold; - cutoffScoreWithCombined -= kw->cutOffThreshold; - - Vector> reduced; - for (auto& p : bestPathes) - { - auto& c = *p.second.first; - float cutoff = (c.combineSocket) ? cutoffScoreWithCombined : cutoffScore; - if (reduced.size() < topN || c.accScore >= cutoff) reduced.emplace_back(move(c)); - } - cache[i] = move(reduced); + tMax = evalPath(kw, startNode, node, topN, cache, ownFormList, i, ownFormId, unknownNodeCands, true, splitComplex, blocklist); } #ifdef DEBUG_PRINT diff --git a/src/Knlm.hpp b/src/Knlm.hpp index db69c0bc..1f9f9baf 100644 --- a/src/Knlm.hpp +++ b/src/Knlm.hpp @@ -341,7 +341,8 @@ namespace kiwi } } - float progress(ptrdiff_t& node_idx, KeyType next) const + template + float progress(IdxType& node_idx, KeyType next) const { float acc = 0; while (1) @@ -358,7 +359,7 @@ namespace kiwi { if (htx_data) { - ptrdiff_t lv; + IdxType lv; if (nst::search( &key_data[0], value_data, @@ -413,7 +414,7 @@ namespace kiwi } if (htx_data) { - ptrdiff_t lv; + IdxType lv; if (nst::search( &key_data[0], value_data, diff --git a/src/LmState.hpp b/src/LmState.hpp index 185e9124..ea225886 100644 --- a/src/LmState.hpp +++ b/src/LmState.hpp @@ -30,12 +30,12 @@ namespace kiwi template class KnLMState { - ptrdiff_t node = 0; + int32_t node = 0; public: static constexpr ArchType arch = _arch; KnLMState() = default; - KnLMState(const LangModel& lm) : node{ static_cast&>(*lm.knlm).getBosNodeIdx() } {} + KnLMState(const LangModel& lm) : node{ (int32_t)static_cast&>(*lm.knlm).getBosNodeIdx() } {} bool operator==(const KnLMState& other) const {