Merged PR 18366: Fix generation of special control characters for default vocabulary

This PR extends the suppression controlled by --allow-special to default vocabulary items as well. If the default vocabulary contains symbols that appear to come from the SentencePiece byte-fallback mechanism, the control characters among them are suppressed.
rjai authored and emjotde committed Mar 30, 2021
1 parent 7d1f941 commit 4408e88
Showing 2 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]

### Added
- Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special
- Better suppression of unwanted output symbols, specifically "\n" from SentencePiece with byte-fallback. Can be deactivated with --allow-special
- Display decoder time statistics with marian-decoder --stat-freq 10 ...
- Support for MS-internal binary shortlist
26 changes: 26 additions & 0 deletions src/data/default_vocab.cpp
@@ -28,6 +28,9 @@ class DefaultVocab : public IVocab {

std::vector<std::string> suffixes_ = { ".yml", ".yaml", ".json" };

// Contains control characters added to vocab, possibly due to byte-fallback
std::vector<Word> controlChars_;

class VocabFreqOrderer {
private:
const std::unordered_map<std::string, size_t>& counter_;
@@ -71,6 +74,16 @@ class DefaultVocab : public IVocab {
return decode(sentence, /*ignoreEOS=*/true);
}

// SentencePiece with byte-fallback may generate control symbols with output sampling.
// Let's mark them as special and suppress them later on output. This is generally safe
// for UTF-8 since control chars are not used as partial bytes in multi-byte sequences.
// They only appear in single-byte chars as themselves and this is what we suppress.
void addSpecialWords(std::vector<Word>& special) const override {
  special.reserve(special.size() + controlChars_.size());
  for(auto c : controlChars_)
    special.push_back(c);
}

virtual std::string type() const override { return "DefaultVocab"; }

virtual Word getEosId() const override { return eosId_; }
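
The comment above `addSpecialWords` rests on a property of UTF-8: bytes below 0x20 never occur inside a multi-byte sequence. The following is a minimal standalone sketch that checks this exhaustively. It is not part of the commit; `encodeUtf8` and `controlBytesOnlyInSingleByteChars` are hypothetical helper names introduced only for illustration.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not from Marian): encode one Unicode code point as UTF-8.
std::string encodeUtf8(char32_t cp) {
  assert(cp <= 0x10FFFF);
  std::string out;
  if(cp < 0x80) {
    out += static_cast<char>(cp);                          // single byte: 0xxxxxxx
  } else if(cp < 0x800) {
    out += static_cast<char>(0xC0 | (cp >> 6));            // lead byte:   110xxxxx
    out += static_cast<char>(0x80 | (cp & 0x3F));          // continuation: 10xxxxxx
  } else if(cp < 0x10000) {
    out += static_cast<char>(0xE0 | (cp >> 12));           // lead byte:   1110xxxx
    out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
    out += static_cast<char>(0x80 | (cp & 0x3F));
  } else {
    out += static_cast<char>(0xF0 | (cp >> 18));           // lead byte:   11110xxx
    out += static_cast<char>(0x80 | ((cp >> 12) & 0x3F));
    out += static_cast<char>(0x80 | ((cp >> 6) & 0x3F));
    out += static_cast<char>(0x80 | (cp & 0x3F));
  }
  return out;
}

// Hypothetical check: returns true if no multi-byte UTF-8 sequence
// contains a byte in the control range 0x00..0x1F.
bool controlBytesOnlyInSingleByteChars() {
  for(char32_t cp = 0; cp <= 0x10FFFF; ++cp) {
    if(cp >= 0xD800 && cp <= 0xDFFF)
      continue;                       // surrogates are not valid scalar values
    std::string bytes = encodeUtf8(cp);
    if(bytes.size() == 1)
      continue;                       // single-byte chars may be control chars
    for(unsigned char b : bytes)
      if(b < 0x20)
        return false;
  }
  return true;
}
```

Every lead or continuation byte of a multi-byte sequence has its high bit set, so the check is expected to pass. This is why suppressing the byte pieces for 0x00 to 0x1F should not break the decoding of other characters.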
@@ -130,6 +143,8 @@ class DefaultVocab : public IVocab {
}
ABORT_IF(id2str_.empty(), "Empty vocabulary: ", vocabPath);

populateControlChars();

addRequiredVocabulary(vocabPath, isJson);

return std::max(id2str_.size(), maxSize);
@@ -172,6 +187,17 @@ class DefaultVocab : public IVocab {

private:

// Creates the first 32 control characters as done in byte-fallback and checks if they exist in the vocab.
// This makes sure that we do not waste computational effort on suppression if they don't actually appear.
void populateControlChars() {
  for(int i = 0; i < 32; ++i) {
    std::string bytePiece = fmt::format("<0x{:02X}>", i); // 0 becomes <0x00>, 10 becomes <0x0A>, note uppercase A and lowercase x
    auto id = (*this)[bytePiece];
    if(id != unkId_)
      controlChars_.push_back(id);
  }
}

virtual void addRequiredVocabulary(const std::string& vocabPath, bool isJson) {
// look up ids for </s> and <unk>, which are required
// The name backCompatStr is alternatively accepted for Yaml vocabs if id
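
To illustrate the lookup performed by `populateControlChars`, here is a standalone sketch that avoids Marian's types. It assumes the vocabulary is a plain `std::unordered_map` from piece string to integer id, and it uses `std::snprintf` in place of `fmt::format`. The names `bytePiece` and `findControlPieces` are hypothetical and do not appear in the commit.

```cpp
#include <cassert>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical helper: format a byte value the way byte-fallback pieces
// are spelled in the diff above, e.g. 10 -> "<0x0A>".
std::string bytePiece(int byte) {
  assert(byte >= 0 && byte <= 255);
  char buf[8];  // "<0xNN>" is 6 characters plus the terminating null
  std::snprintf(buf, sizeof(buf), "<0x%02X>", static_cast<unsigned>(byte));
  return buf;
}

// Hypothetical stand-in for the vocabulary lookup: collect the ids of the
// first 32 byte pieces (the control characters) that exist in the vocab.
std::vector<int> findControlPieces(const std::unordered_map<std::string, int>& vocab) {
  std::vector<int> ids;
  for(int i = 0; i < 32; ++i) {
    auto it = vocab.find(bytePiece(i));
    if(it != vocab.end())
      ids.push_back(it->second);  // only pieces actually present are recorded
  }
  return ids;
}
```

In this sketch, a vocabulary without any byte-fallback pieces yields an empty list, which matches the stated intent of not spending effort on suppression when no control pieces exist.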