Skip to content

Commit

Permalink
improved PathHash impl
Browse files Browse the repository at this point in the history
  • Loading branch information
bab2min committed Oct 28, 2023
1 parent c2548bd commit 2d95d7f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
17 changes: 12 additions & 5 deletions src/Kiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,11 @@ namespace kiwi
{
}

PathHash(const WordLL<LmState>& wordLl, const Morpheme* morphBase)
: lmState{ wordLl.lmState }, lastMorpheme{ (uint16_t)(wordLl.morpheme - morphBase) }, spState{ wordLl.spState }
{
}

bool operator==(const PathHash& o) const
{
return lmState == o.lmState && lastMorpheme == o.lastMorpheme && spState == o.spState;
Expand All @@ -687,13 +692,15 @@ namespace kiwi
template<size_t windowSize, ArchType _arch, class VocabTy>
struct PathHash<SbgState<windowSize, _arch, VocabTy>>
{
using LmState = SbgState<windowSize, _arch, VocabTy>;

array<VocabTy, 4> lastMorphemes;

PathHash(const SbgState<windowSize, _arch, VocabTy>& _lmState = {}, uint16_t _lastMorpheme = 0, SpecialState _spState = {})
PathHash(const WordLL<LmState>& wordLl, const Morpheme* morphBase)
{
_lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size());
lastMorphemes.back() = _lastMorpheme;
lastMorphemes[0] = ((lastMorphemes[0] << 8) >> 8) | ((uint8_t)_spState) << ((sizeof(VocabTy) - 1) * 8);
wordLl.lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size());
lastMorphemes.back() = wordLl.morpheme - morphBase;
lastMorphemes[0] = ((lastMorphemes[0] << 8) >> 8) | ((VocabTy)wordLl.spState) << ((sizeof(VocabTy) - 1) * 8);
}

bool operator==(const PathHash& o) const
Expand Down Expand Up @@ -1317,7 +1324,7 @@ namespace kiwi
float cutoffScore = -INFINITY, cutoffScoreWithCombined = -INFINITY;
for (auto& c : cache[i])
{
PathHash<LmState> ph{ c.lmState, (uint16_t)(c.morpheme - kw->morphemes.data()), c.spState };
PathHash<LmState> ph{ c, kw->morphemes.data() };
auto insertResult = bestPathes.emplace(ph, make_pair(&c, c.accScore));
if (!insertResult.second)
{
Expand Down
2 changes: 1 addition & 1 deletion src/LmState.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ namespace kiwi
};

template<size_t windowSize, ArchType _arch, class VocabTy>
class SbgState : KnLMState<_arch, VocabTy>
class SbgState : public KnLMState<_arch, VocabTy>
{
size_t historyPos = 0;
std::array<VocabTy, windowSize> history = { {0,} };
Expand Down

0 comments on commit 2d95d7f

Please sign in to comment.