From 2d95d7fdebc0ed435627350775688c6846ea233d Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 28 Oct 2023 20:26:11 +0900 Subject: [PATCH] improved PathHash impl --- src/Kiwi.cpp | 17 ++++++++++++----- src/LmState.hpp | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index 319a3673..bb789004 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -678,6 +678,11 @@ namespace kiwi { } + PathHash(const WordLL& wordLl, const Morpheme* morphBase) + : lmState{ wordLl.lmState }, lastMorpheme{ (uint16_t)(wordLl.morpheme - morphBase) }, spState{ wordLl.spState } + { + } + bool operator==(const PathHash& o) const { return lmState == o.lmState && lastMorpheme == o.lastMorpheme && spState == o.spState; @@ -687,13 +692,15 @@ namespace kiwi template struct PathHash> { + using LmState = SbgState; + array lastMorphemes; - PathHash(const SbgState& _lmState = {}, uint16_t _lastMorpheme = 0, SpecialState _spState = {}) + PathHash(const WordLL& wordLl, const Morpheme* morphBase) { - _lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); - lastMorphemes.back() = _lastMorpheme; - lastMorphemes[0] = ((lastMorphemes[0] << 8) >> 8) | ((uint8_t)_spState) << ((sizeof(VocabTy) - 1) * 8); + wordLl.lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); + lastMorphemes.back() = wordLl.morpheme - morphBase; + lastMorphemes[0] = ((lastMorphemes[0] << 8) >> 8) | ((VocabTy)wordLl.spState) << ((sizeof(VocabTy) - 1) * 8); } bool operator==(const PathHash& o) const @@ -1317,7 +1324,7 @@ namespace kiwi float cutoffScore = -INFINITY, cutoffScoreWithCombined = -INFINITY; for (auto& c : cache[i]) { - PathHash ph{ c.lmState, (uint16_t)(c.morpheme - kw->morphemes.data()), c.spState }; + PathHash ph{ c, kw->morphemes.data() }; auto insertResult = bestPathes.emplace(ph, make_pair(&c, c.accScore)); if (!insertResult.second) { diff --git a/src/LmState.hpp b/src/LmState.hpp index 3d8c1fd6..185e9124 100644 --- a/src/LmState.hpp +++ b/src/LmState.hpp @@ -54,7 +54,7 @@ namespace kiwi }; template - class SbgState : KnLMState<_arch, VocabTy> + class SbgState : public KnLMState<_arch, VocabTy> { size_t historyPos = 0; std::array history = { {0,} };