diff --git a/CHANGELOG.md b/CHANGELOG.md index 56ede4e55..cc9f179cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added +- Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special - Better suppression of unwanted output symbols, specifically "\n" from SentencePiece with byte-fallback. Can be deactivated with --allow-special - Display decoder time statistics with marian-decoder --stat-freq 10 ... - Support for MS-internal binary shortlist diff --git a/src/data/default_vocab.cpp b/src/data/default_vocab.cpp index 7706a1c11..2d92f4f64 100644 --- a/src/data/default_vocab.cpp +++ b/src/data/default_vocab.cpp @@ -28,6 +28,9 @@ class DefaultVocab : public IVocab { std::vector suffixes_ = { ".yml", ".yaml", ".json" }; + // Contains control characters added to vocab, possibly due to byte-fallback + std::vector controlChars_; + class VocabFreqOrderer { private: const std::unordered_map& counter_; @@ -71,6 +74,16 @@ class DefaultVocab : public IVocab { return decode(sentence, /*ignoreEOS=*/true); } + // SentencePiece with byte-fallback may generate control symbols with output sampling. + // Let's mark them as special and suppress them later on output. This is generally safe + // for UTF-8 since control chars are not used as partial bytes in multi-byte sequences. + // They only appear in single-byte chars as themselves and this is what we suppress. + void addSpecialWords(std::vector& special) const override { + special.reserve(special.size() + controlChars_.size()); + for(auto c : controlChars_) + special.push_back(c); + } + virtual std::string type() const override { return "DefaultVocab"; } virtual Word getEosId() const override { return eosId_; } @@ -130,6 +143,8 @@ class DefaultVocab : public IVocab { } ABORT_IF(id2str_.empty(), "Empty vocabulary: ", vocabPath); + populateControlChars(); + addRequiredVocabulary(vocabPath, isJson); return std::max(id2str_.size(), maxSize); @@ -172,6 +187,17 @@ class DefaultVocab : public IVocab { private: + // Creates the first 32 control characters as done in byte-fallback and checks if they exist in the vocab. + // This makes sure that we do not waste computational effort on suppression if they don't actually appear. + void populateControlChars() { + for(int i = 0; i < 32; ++i) { + std::string bytePiece = fmt::format("<0x{:02X}>", i); // 0 becomes <0x00>, 10 becomes <0x0A>, note uppercase A and lowercase x + auto id = (*this)[bytePiece]; + if(id != unkId_) + controlChars_.push_back(id); + } + } + virtual void addRequiredVocabulary(const std::string& vocabPath, bool isJson) { // look up ids for and , which are required // The name backCompatStr is alternatively accepted for Yaml vocabs if id