diff --git a/include/kiwi/SubstringExtractor.h b/include/kiwi/SubstringExtractor.h index 48296dce..ec71a48e 100644 --- a/include/kiwi/SubstringExtractor.h +++ b/include/kiwi/SubstringExtractor.h @@ -24,22 +24,42 @@ namespace kiwi UnorderedMap token2id; Vector id2Token; Vector buf; + Vector tokenClusters; + Vector tokenCnts; std::shared_ptr threadPool; template void _addArray(It first, It last); + + Vector> computeClusterScore() const; + public: - PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers); + PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers, + const std::vector>& clusters = {} + ); void addArray(const uint16_t* first, const uint16_t* last); void addArray(const uint32_t* first, const uint32_t* last); void addArray(const uint64_t* first, const uint64_t* last); utils::FrozenTrie count() const; std::unique_ptr buildLM( - size_t lastMinCf, + const std::vector& minCfByOrder, size_t bosTokenId, size_t eosTokenId, size_t unkTokenId, ArchType archType = ArchType::none ) const; }; + + class ClusterData + { + const std::pair* clusterScores = nullptr; + size_t clusterSize = 0; + public: + ClusterData(); + ClusterData(const void* _ptr, size_t _size); + + size_t size() const; + size_t cluster(size_t i) const; + float score(size_t i) const; + }; } diff --git a/src/SubstringExtractor.cpp b/src/SubstringExtractor.cpp index b437255e..cc41541b 100644 --- a/src/SubstringExtractor.cpp +++ b/src/SubstringExtractor.cpp @@ -173,14 +173,45 @@ namespace kiwi template using PrefixTrieNode = utils::TrieNodeEx>>; - PrefixCounter::PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers) + PrefixCounter::PrefixCounter( + size_t _prefixSize, + size_t _minCf, + size_t _numWorkers, + const std::vector>& clusters + ) : prefixSize(_prefixSize), minCf(_minCf), id2Token(2), buf(1) { - if (_numWorkers == 0) _numWorkers = min(thread::hardware_concurrency(), 8u); + if (_numWorkers == (size_t)-1) _numWorkers = min(thread::hardware_concurrency(), 8u); if (_numWorkers > 1) { threadPool = make_unique(_numWorkers); } + + if (clusters.empty()) return; + + unordered_set alreadyAllocated; + for (auto cs : clusters) + { + if (cs.empty()) continue; + sort(cs.begin(), cs.end()); + const auto cid = cs[0]; + for (auto c : cs) + { + if (alreadyAllocated.find(c) != alreadyAllocated.end()) + { + throw runtime_error("Duplicated cluster id"); + } + alreadyAllocated.insert(c); + + if (c >= tokenClusters.size()) + { + const auto e = c + 1; + tokenClusters.resize(e, -1); + tokenCnts.resize(e); + } + tokenClusters[c] = cid; + } + } } template @@ -188,7 +219,16 @@ namespace kiwi { for (; first != last; ++first) { - const auto token = *first; + auto token = *first; + if (token < tokenClusters.size()) + { + if (tokenClusters[token] != (size_t)-1) + { + tokenCnts[token]++; + token = tokenClusters[token]; + } + } + auto it = token2id.find(token); if (it == token2id.end()) { @@ -295,18 +335,89 @@ namespace kiwi return utils::freezeTrie(move(trie), ArchType::balanced); } + Vector> PrefixCounter::computeClusterScore() const + { + UnorderedMap clusterCnts; + for (size_t i = 0; i < tokenClusters.size(); ++i) + { + if (tokenClusters[i] != (size_t)-1) + { + clusterCnts[tokenClusters[i]] += tokenCnts[i]; + } + } + + Vector> ret; + ret.reserve(tokenClusters.size()); + for (size_t i = 0; i < tokenClusters.size(); ++i) + { + if (tokenClusters[i] == (size_t)-1) + { + ret.emplace_back(-1, 0); + } + else + { + ret.emplace_back(tokenClusters[i], (float)log((double)tokenCnts[i] / clusterCnts[tokenClusters[i]])); + } + } + return ret; + } + unique_ptr PrefixCounter::buildLM( - size_t lastMinCf, + const std::vector& minCfByOrder, size_t bosTokenId, size_t eosTokenId, size_t unkTokenId, ArchType archType) const { + Vector extraBuf; + if (!tokenClusters.empty()) + { + auto clusterScore = computeClusterScore(); + extraBuf.resize(clusterScore.size() * sizeof(uint64_t) + sizeof(uint64_t) * 2); + memcpy(extraBuf.data(), "UNIGRAM\0", sizeof(uint64_t)); + uint64_t size = clusterScore.size(); + memcpy(extraBuf.data() + sizeof(uint64_t), &size, sizeof(uint64_t)); + memcpy(extraBuf.data() + sizeof(uint64_t) * 2, clusterScore.data(), clusterScore.size() * sizeof(uint64_t)); + } + utils::MemoryOwner mem; { auto trie = count(); - mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCf, lastMinCf, unkTokenId, bosTokenId, eosTokenId, 1e-5f, 0, false); + mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCfByOrder, unkTokenId, bosTokenId, eosTokenId, + 1e-5f, 0, false, nullptr, (const Vector*)nullptr, + extraBuf.data(), extraBuf.size()); } return lm::KnLangModelBase::create(move(mem), archType); } + + ClusterData::ClusterData() = default; + + ClusterData::ClusterData(const void* _ptr, size_t _size) + { + if (!_ptr || !_size) return; + if (_size < sizeof(uint64_t) * 2) throw runtime_error("Invalid cluster data"); + auto ptr = (const uint64_t*)_ptr; + if (memcmp(ptr, "UNIGRAM\0", sizeof(uint64_t)) != 0) throw runtime_error("Invalid cluster data"); + const auto size = ptr[1]; + if (_size < sizeof(uint64_t) * 2 + size * sizeof(uint64_t)) throw runtime_error("Invalid cluster data"); + clusterScores = (const pair*)(ptr + 2); + clusterSize = size; + } + + size_t ClusterData::size() const + { + return clusterSize; + } + + size_t ClusterData::cluster(size_t i) const + { + if (i >= clusterSize || clusterScores[i].first == (uint32_t)-1) return i; + return clusterScores[i].first; + } + + float ClusterData::score(size_t i) const + { + if (i >= clusterSize || clusterScores[i].first == (uint32_t)-1) return 0; + return clusterScores[i].second; + } }