From ffee8e4d72a4d2ecd859575007877d12acbee5b3 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 28 Nov 2021 17:08:13 +0000 Subject: [PATCH] Add string_view to dictionary for fast lookup --- CMakeLists.txt | 2 ++ src/dictionary.cc | 66 +++++++++++++++++++++++++++++++++++++++++------ src/dictionary.h | 17 +++++++----- 3 files changed, 70 insertions(+), 15 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a93e06f4e..17e7f43b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,8 @@ cmake_minimum_required(VERSION 2.8.9) project(fasttext) +set(CMAKE_CXX_STANDARD 17) + # The version number. set (fasttext_VERSION_MAJOR 0) set (fasttext_VERSION_MINOR 1) diff --git a/src/dictionary.cc b/src/dictionary.cc index cb396cd14..c23e35f6f 100644 --- a/src/dictionary.cc +++ b/src/dictionary.cc @@ -42,11 +42,11 @@ Dictionary::Dictionary(std::shared_ptr args, std::istream& in) load(in); } -int32_t Dictionary::find(const std::string& w) const { +int32_t Dictionary::find(const std::string_view w) const { return find(w, hash(w)); } -int32_t Dictionary::find(const std::string& w, uint32_t h) const { +int32_t Dictionary::find(const std::string_view w, uint32_t h) const { int32_t word2intsize = word2int_.size(); int32_t id = h % word2intsize; while (word2int_[id] != -1 && words_[word2int_[id]].word != w) { @@ -126,12 +126,12 @@ bool Dictionary::discard(int32_t id, real rand) const { return rand > pdiscard_[id]; } -int32_t Dictionary::getId(const std::string& w, uint32_t h) const { +int32_t Dictionary::getId(const std::string_view w, uint32_t h) const { int32_t id = find(w, h); return word2int_[id]; } -int32_t Dictionary::getId(const std::string& w) const { +int32_t Dictionary::getId(const std::string_view w) const { int32_t h = find(w); return word2int_[h]; } @@ -142,7 +142,7 @@ entry_type Dictionary::getType(int32_t id) const { return words_[id].type; } -entry_type Dictionary::getType(const std::string& w) const { +entry_type Dictionary::getType(const std::string_view w) const { return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word; } @@ -160,7 +160,7 @@ std::string Dictionary::getWord(int32_t id) const { // Since all fasttext models that were already released were trained // using signed char, we fixed the hash function to make models // compatible whatever compiler is used. -uint32_t Dictionary::hash(const std::string& str) const { +uint32_t Dictionary::hash(const std::string_view str) const { uint32_t h = 2166136261; for (size_t i = 0; i < str.size(); i++) { h = h ^ uint32_t(int8_t(str[i])); @@ -324,11 +324,16 @@ void Dictionary::addWordNgrams( void Dictionary::addSubwords( std::vector& line, - const std::string& token, + const std::string_view token, int32_t wid) const { if (wid < 0) { // out of vocab if (token != EOS) { - computeSubwords(BOW + token + EOW, line); + std::string concat; + concat.reserve(BOW.size() + token.size() + EOW.size()); + concat += BOW; + concat.append(token.data(), token.size()); + concat += EOW; + computeSubwords(concat, line); } } else { if (args_->maxn <= 0) { // in vocab w/o subwords @@ -406,6 +411,51 @@ int32_t Dictionary::getLine( return ntokens; } +namespace { +bool readWordNoNewline(std::string_view& in, std::string_view& word) { + const std::string_view spaces(" \n\r\t\v\f\0"); + std::string_view::size_type begin = in.find_first_not_of(spaces); + if (begin == std::string_view::npos) { + in.remove_prefix(in.size()); + return false; + } + in.remove_prefix(begin); + word = in.substr(0, in.find_first_of(spaces)); + in.remove_prefix(word.size()); + return true; +} +} // namespace + +int32_t Dictionary::getStringNoNewline( + std::string_view in, + std::vector& words, + std::vector& labels) const { + std::vector word_hashes; + std::string_view token; + int32_t ntokens = 0; + + words.clear(); + labels.clear(); + while (readWordNoNewline(in, token)) { + uint32_t h = hash(token); + int32_t wid = getId(token, h); + entry_type type = wid < 0 ? getType(token) : getType(wid); + + ntokens++; + if (type == entry_type::word) { + addSubwords(words, token, wid); + word_hashes.push_back(h); + } else if (type == entry_type::label && wid >= 0) { + labels.push_back(wid - nwords_); + } + if (token == EOS) { + break; + } + } + addWordNgrams(words, word_hashes, args_->wordNgrams); + return ntokens; +} + void Dictionary::pushHash(std::vector& hashes, int32_t id) const { if (pruneidx_size_ == 0 || id < 0) { return; diff --git a/src/dictionary.h b/src/dictionary.h index aa57989c4..e25e8011f 100644 --- a/src/dictionary.h +++ b/src/dictionary.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -36,13 +37,13 @@ class Dictionary { static const int32_t MAX_VOCAB_SIZE = 30000000; static const int32_t MAX_LINE_SIZE = 1024; - int32_t find(const std::string&) const; - int32_t find(const std::string&, uint32_t h) const; + int32_t find(const std::string_view) const; + int32_t find(const std::string_view, uint32_t h) const; void initTableDiscard(); void initNgrams(); void reset(std::istream&) const; void pushHash(std::vector&, int32_t) const; - void addSubwords(std::vector&, const std::string&, int32_t) const; + void addSubwords(std::vector&, const std::string_view, int32_t) const; std::shared_ptr args_; std::vector word2int_; @@ -71,10 +72,10 @@ class Dictionary { int32_t nwords() const; int32_t nlabels() const; int64_t ntokens() const; - int32_t getId(const std::string&) const; - int32_t getId(const std::string&, uint32_t h) const; + int32_t getId(const std::string_view) const; + int32_t getId(const std::string_view, uint32_t h) const; entry_type getType(int32_t) const; - entry_type getType(const std::string&) const; + entry_type getType(const std::string_view) const; bool discard(int32_t, real) const; std::string getWord(int32_t) const; const std::vector& getSubwords(int32_t) const; @@ -87,7 +88,7 @@ class Dictionary { const std::string&, std::vector&, std::vector* substrings = nullptr) const; - uint32_t hash(const std::string& str) const; + uint32_t hash(const std::string_view str) const; void add(const std::string&); bool readWord(std::istream&, std::string&) const; void readFromFile(std::istream&); @@ -99,6 +100,8 @@ class Dictionary { const; int32_t getLine(std::istream&, std::vector&, std::minstd_rand&) const; + int32_t getStringNoNewline(std::string_view, std::vector&, + std::vector&) const; void threshold(int64_t, int64_t); void prune(std::vector&); bool isPruned() {