From 96ed0baf5a8a0c39508fb53ff13e0ab309fd7cea Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 02:24:37 +0000 Subject: [PATCH 01/14] chmod -x --- src/common/config_parser.cpp | 0 src/common/definitions.h | 0 src/common/file_stream.cpp | 0 src/common/io_item.h | 0 src/common/options.h | 0 src/common/utils.cpp | 0 src/data/batch.h | 0 src/data/corpus.cpp | 0 src/data/corpus_base.cpp | 0 src/data/factored_vocab.cpp | 0 src/data/factored_vocab.h | 0 src/data/vocab.cpp | 0 src/data/vocab.h | 0 src/data/vocab_base.h | 0 src/functional/operators.h | 0 src/functional/shape.h | 0 src/functional/tensor.h | 0 src/functional/tmp.h | 0 src/graph/auto_tuner.h | 0 src/graph/expression_operators.h | 0 src/graph/node.cpp | 0 src/graph/node_initializers.cpp | 0 src/graph/node_initializers.h | 0 src/layers/constructors.h | 0 src/layers/factory.h | 0 src/layers/generic.cpp | 0 src/layers/generic.h | 0 src/layers/guided_alignment.h | 0 src/layers/loss.cpp | 0 src/layers/loss.h | 0 src/microsoft/quicksand.cpp | 0 src/microsoft/quicksand.h | 0 src/models/amun.h | 0 src/models/bert.h | 0 src/models/char_s2s.h | 0 src/models/classifier.h | 0 src/models/costs.h | 0 src/models/encoder_decoder.cpp | 0 src/models/encoder_decoder.h | 0 src/models/model_factory.cpp | 0 src/models/model_factory.h | 0 src/models/nematus.h | 0 src/models/s2s.h | 0 src/models/states.h | 0 src/models/transformer.h | 0 src/models/transformer_factory.h | 0 src/models/transformer_stub.cpp | 0 src/optimizers/exponential_smoothing.cpp | 0 src/optimizers/exponential_smoothing.h | 0 src/rnn/attention.h | 0 src/rnn/cells.h | 0 src/rnn/constructors.h | 0 src/tensors/rand.cpp | 0 src/tensors/tensor.cpp | 0 src/tensors/tensor.h | 0 src/training/graph_group_sync.cpp | 0 src/training/graph_group_sync.h | 0 src/training/scheduler.h | 0 src/training/validator.h | 0 src/translator/beam_search.cpp | 0 src/translator/output_printer.h | 0 src/translator/scorers.h | 0 src/translator/translator.h | 0 63 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 src/common/config_parser.cpp mode change 100755 => 100644 src/common/definitions.h mode change 100755 => 100644 src/common/file_stream.cpp mode change 100755 => 100644 src/common/io_item.h mode change 100755 => 100644 src/common/options.h mode change 100755 => 100644 src/common/utils.cpp mode change 100755 => 100644 src/data/batch.h mode change 100755 => 100644 src/data/corpus.cpp mode change 100755 => 100644 src/data/corpus_base.cpp mode change 100755 => 100644 src/data/factored_vocab.cpp mode change 100755 => 100644 src/data/factored_vocab.h mode change 100755 => 100644 src/data/vocab.cpp mode change 100755 => 100644 src/data/vocab.h mode change 100755 => 100644 src/data/vocab_base.h mode change 100755 => 100644 src/functional/operators.h mode change 100755 => 100644 src/functional/shape.h mode change 100755 => 100644 src/functional/tensor.h mode change 100755 => 100644 src/functional/tmp.h mode change 100755 => 100644 src/graph/auto_tuner.h mode change 100755 => 100644 src/graph/expression_operators.h mode change 100755 => 100644 src/graph/node.cpp mode change 100755 => 100644 src/graph/node_initializers.cpp mode change 100755 => 100644 src/graph/node_initializers.h mode change 100755 => 100644 src/layers/constructors.h mode change 100755 => 100644 src/layers/factory.h mode change 100755 => 100644 src/layers/generic.cpp mode change 100755 => 100644 src/layers/generic.h mode change 100755 => 100644 src/layers/guided_alignment.h mode change 100755 => 100644 
src/layers/loss.cpp mode change 100755 => 100644 src/layers/loss.h mode change 100755 => 100644 src/microsoft/quicksand.cpp mode change 100755 => 100644 src/microsoft/quicksand.h mode change 100755 => 100644 src/models/amun.h mode change 100755 => 100644 src/models/bert.h mode change 100755 => 100644 src/models/char_s2s.h mode change 100755 => 100644 src/models/classifier.h mode change 100755 => 100644 src/models/costs.h mode change 100755 => 100644 src/models/encoder_decoder.cpp mode change 100755 => 100644 src/models/encoder_decoder.h mode change 100755 => 100644 src/models/model_factory.cpp mode change 100755 => 100644 src/models/model_factory.h mode change 100755 => 100644 src/models/nematus.h mode change 100755 => 100644 src/models/s2s.h mode change 100755 => 100644 src/models/states.h mode change 100755 => 100644 src/models/transformer.h mode change 100755 => 100644 src/models/transformer_factory.h mode change 100755 => 100644 src/models/transformer_stub.cpp mode change 100755 => 100644 src/optimizers/exponential_smoothing.cpp mode change 100755 => 100644 src/optimizers/exponential_smoothing.h mode change 100755 => 100644 src/rnn/attention.h mode change 100755 => 100644 src/rnn/cells.h mode change 100755 => 100644 src/rnn/constructors.h mode change 100755 => 100644 src/tensors/rand.cpp mode change 100755 => 100644 src/tensors/tensor.cpp mode change 100755 => 100644 src/tensors/tensor.h mode change 100755 => 100644 src/training/graph_group_sync.cpp mode change 100755 => 100644 src/training/graph_group_sync.h mode change 100755 => 100644 src/training/scheduler.h mode change 100755 => 100644 src/training/validator.h mode change 100755 => 100644 src/translator/beam_search.cpp mode change 100755 => 100644 src/translator/output_printer.h mode change 100755 => 100644 src/translator/scorers.h mode change 100755 => 100644 src/translator/translator.h diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp old mode 100755 new mode 100644 diff --git a/src/common/definitions.h b/src/common/definitions.h old mode 100755 new mode 100644 diff --git a/src/common/file_stream.cpp b/src/common/file_stream.cpp old mode 100755 new mode 100644 diff --git a/src/common/io_item.h b/src/common/io_item.h old mode 100755 new mode 100644 diff --git a/src/common/options.h b/src/common/options.h old mode 100755 new mode 100644 diff --git a/src/common/utils.cpp b/src/common/utils.cpp old mode 100755 new mode 100644 diff --git a/src/data/batch.h b/src/data/batch.h old mode 100755 new mode 100644 diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp old mode 100755 new mode 100644 diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp old mode 100755 new mode 100644 diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp old mode 100755 new mode 100644 diff --git a/src/data/factored_vocab.h b/src/data/factored_vocab.h old mode 100755 new mode 100644 diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp old mode 100755 new mode 100644 diff --git a/src/data/vocab.h b/src/data/vocab.h old mode 100755 new mode 100644 diff --git a/src/data/vocab_base.h b/src/data/vocab_base.h old mode 100755 new mode 100644 diff --git a/src/functional/operators.h b/src/functional/operators.h old mode 100755 new mode 100644 diff --git a/src/functional/shape.h b/src/functional/shape.h old mode 100755 new mode 100644 diff --git a/src/functional/tensor.h b/src/functional/tensor.h old mode 100755 new mode 100644 diff --git a/src/functional/tmp.h b/src/functional/tmp.h old mode 100755 new mode 100644 
diff --git a/src/graph/auto_tuner.h b/src/graph/auto_tuner.h old mode 100755 new mode 100644 diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h old mode 100755 new mode 100644 diff --git a/src/graph/node.cpp b/src/graph/node.cpp old mode 100755 new mode 100644 diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp old mode 100755 new mode 100644 diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h old mode 100755 new mode 100644 diff --git a/src/layers/constructors.h b/src/layers/constructors.h old mode 100755 new mode 100644 diff --git a/src/layers/factory.h b/src/layers/factory.h old mode 100755 new mode 100644 diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp old mode 100755 new mode 100644 diff --git a/src/layers/generic.h b/src/layers/generic.h old mode 100755 new mode 100644 diff --git a/src/layers/guided_alignment.h b/src/layers/guided_alignment.h old mode 100755 new mode 100644 diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp old mode 100755 new mode 100644 diff --git a/src/layers/loss.h b/src/layers/loss.h old mode 100755 new mode 100644 diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp old mode 100755 new mode 100644 diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h old mode 100755 new mode 100644 diff --git a/src/models/amun.h b/src/models/amun.h old mode 100755 new mode 100644 diff --git a/src/models/bert.h b/src/models/bert.h old mode 100755 new mode 100644 diff --git a/src/models/char_s2s.h b/src/models/char_s2s.h old mode 100755 new mode 100644 diff --git a/src/models/classifier.h b/src/models/classifier.h old mode 100755 new mode 100644 diff --git a/src/models/costs.h b/src/models/costs.h old mode 100755 new mode 100644 diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp old mode 100755 new mode 100644 diff --git a/src/models/encoder_decoder.h b/src/models/encoder_decoder.h old mode 100755 new mode 100644 diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp old mode 100755 new mode 100644 diff --git a/src/models/model_factory.h b/src/models/model_factory.h old mode 100755 new mode 100644 diff --git a/src/models/nematus.h b/src/models/nematus.h old mode 100755 new mode 100644 diff --git a/src/models/s2s.h b/src/models/s2s.h old mode 100755 new mode 100644 diff --git a/src/models/states.h b/src/models/states.h old mode 100755 new mode 100644 diff --git a/src/models/transformer.h b/src/models/transformer.h old mode 100755 new mode 100644 diff --git a/src/models/transformer_factory.h b/src/models/transformer_factory.h old mode 100755 new mode 100644 diff --git a/src/models/transformer_stub.cpp b/src/models/transformer_stub.cpp old mode 100755 new mode 100644 diff --git a/src/optimizers/exponential_smoothing.cpp b/src/optimizers/exponential_smoothing.cpp old mode 100755 new mode 100644 diff --git a/src/optimizers/exponential_smoothing.h b/src/optimizers/exponential_smoothing.h old mode 100755 new mode 100644 diff --git a/src/rnn/attention.h b/src/rnn/attention.h old mode 100755 new mode 100644 diff --git a/src/rnn/cells.h b/src/rnn/cells.h old mode 100755 new mode 100644 diff --git a/src/rnn/constructors.h b/src/rnn/constructors.h old mode 100755 new mode 100644 diff --git a/src/tensors/rand.cpp b/src/tensors/rand.cpp old mode 100755 new mode 100644 diff --git a/src/tensors/tensor.cpp b/src/tensors/tensor.cpp old mode 100755 new mode 100644 diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h old mode 
100755 new mode 100644 diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp old mode 100755 new mode 100644 diff --git a/src/training/graph_group_sync.h b/src/training/graph_group_sync.h old mode 100755 new mode 100644 diff --git a/src/training/scheduler.h b/src/training/scheduler.h old mode 100755 new mode 100644 diff --git a/src/training/validator.h b/src/training/validator.h old mode 100755 new mode 100644 diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp old mode 100755 new mode 100644 diff --git a/src/translator/output_printer.h b/src/translator/output_printer.h old mode 100755 new mode 100644 diff --git a/src/translator/scorers.h b/src/translator/scorers.h old mode 100755 new mode 100644 diff --git a/src/translator/translator.h b/src/translator/translator.h old mode 100755 new mode 100644 From 0d8372c590b290dc9b10200ea068c407954229a8 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 02:46:19 +0000 Subject: [PATCH 02/14] move embedding to its own file --- src/CMakeLists.txt | 1 + src/layers/constructors.h | 1 + src/layers/embedding.cpp | 175 ++++++++++++++++++++++++++++++++++++++ src/layers/embedding.h | 148 ++++++++++++++++++++++++++++++++ src/layers/generic.cpp | 168 ------------------------------------ src/layers/generic.h | 140 ------------------------------ 6 files changed, 325 insertions(+), 308 deletions(-) create mode 100644 src/layers/embedding.cpp create mode 100644 src/layers/embedding.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b47663b4e..e71a7b711 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -72,6 +72,7 @@ set(MARIAN_SOURCES layers/loss.cpp layers/weight.cpp layers/lsh.cpp + layers/embedding.cpp rnn/cells.cpp rnn/attention.cpp diff --git a/src/layers/constructors.h b/src/layers/constructors.h index a2c38197f..7771919d0 100644 --- a/src/layers/constructors.h +++ b/src/layers/constructors.h @@ -2,6 +2,7 @@ #include "layers/factory.h" #include "layers/generic.h" +#include "layers/embedding.h" namespace marian { namespace mlp { diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp new file mode 100644 index 000000000..488fbb8be --- /dev/null +++ b/src/layers/embedding.cpp @@ -0,0 +1,175 @@ +#include "embedding.h" +#include "data/factored_vocab.h" + +namespace marian { + +Embedding::Embedding(Ptr graph, Ptr options) +: LayerBase(graph, options), inference_(opt("inference")) { +std::string name = opt("prefix"); +int dimVoc = opt("dimVocab"); +int dimEmb = opt("dimEmb"); + +bool fixed = opt("fixed", false); + +factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); +if (factoredVocab_) { + dimVoc = (int)factoredVocab_->factorVocabSize(); + LOG_ONCE(info, "[embedding] Factored embeddings enabled"); +} + +// Embedding layer initialization should depend only on embedding size, hence fanIn=false +auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length + +if (options_->has("embFile")) { + std::string file = opt("embFile"); + if (!file.empty()) { + bool norm = opt("normalization", false); + initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); + } +} + +E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); +} + +// helper to embed a sequence of words (given as indices) via factored embeddings +Expr Embedding::multiRows(const Words& data, float dropProb) const { +auto graph = E_->graph(); +auto factoredData = factoredVocab_->csr_rows(data); +// multi-hot factor vectors are 
represented as a sparse CSR matrix +// [row index = word position index] -> set of factor indices for word at this position +ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??"); +// the CSR matrix is passed in pieces +auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights)); +auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32); +auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32); +// apply dropout +// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors. +if(!inference_) + weights = dropout(weights, dropProb); +// perform the product +return csr_dot(factoredData.shape, weights, indices, offsets, E_); +} + +std::tuple Embedding::apply(Ptr subBatch) const /*override final*/ { +auto graph = E_->graph(); +int dimBatch = (int)subBatch->batchSize(); +int dimEmb = E_->shape()[-1]; +int dimWidth = (int)subBatch->batchWidth(); + +// factored embeddings: +// - regular: +// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] +// - factored: +// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) +// - each row of M contains the set of factors for one word => we want a CSR matrix +// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] +// - first compute x @ M on the CPU +// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): +// - shape (U, specifically) not actually needed here +// - foreach input x[i] +// - locate row M[i,*] +// - copy through its index values (std::vector) +// - create a matching ones vector (we can keep growing) +// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) +// - CSR matrix product with E +// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) +// - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()). 
+// - weighting: +// - core factors' gradients are sums over all words that use the factors; +// - core factors' embeddings move very fast +// - words will need to make up for the move; rare words cannot +// - so, we multiply each factor with 1/refCount +// - core factors get weighed down a lot +// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before +// - but forward pass weighs them down, so that all factors are in a similar numeric range +// - if it is required to be in a different range, the embeddings can still learn that, but more slowly + +auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); +#if 1 +auto batchMask = graph->constant({dimWidth, dimBatch, 1}, + inits::fromVector(subBatch->mask())); +#else // @TODO: this is dead code now, get rid of it +// experimental: hide inline-fix source tokens from cross attention +auto batchMask = graph->constant({dimWidth, dimBatch, 1}, + inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); +#endif +// give the graph inputs readable names for debugging and ONNX +batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); + +return std::make_tuple(batchEmbeddings, batchMask); +} + +Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ { +if (factoredVocab_) { + Expr selectedEmbs = multiRows(words, options_->get("dropout", 0.0f)); // [(B*W) x E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + //selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout + return selectedEmbs; +} +else + return applyIndices(toWordIndexVector(words), shape); +} + +Expr Embedding::applyIndices(const std::vector& embIdx, const Shape& shape) const /*override final*/ { +ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); +auto embIdxExpr = E_->graph()->indices(embIdx); +embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? +auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] +selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] +// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately) +if(!inference_) + selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); +return selectedEmbs; +} + +// standard encoder word embeddings +/*private*/ Ptr EncoderDecoderLayerBase::createEmbeddingLayer() const { +auto options = New( + "dimVocab", opt>("dim-vocabs")[batchIndex_], + "dimEmb", opt("dim-emb"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "prefix", (opt("tied-embeddings-src") || opt("tied-embeddings-all")) ? 
"Wemb" : prefix_ + "_Wemb", + "fixed", embeddingFix_, + "vocab", opt>("vocabs")[batchIndex_]); // for factored embeddings +if(options_->hasAndNotEmpty("embedding-vectors")) { + auto embFiles = opt>("embedding-vectors"); + options->set( + "embFile", embFiles[batchIndex_], + "normalization", opt("embedding-normalization")); +} +return New(graph_, options); +} + +// ULR word embeddings +/*private*/ Ptr EncoderDecoderLayerBase::createULREmbeddingLayer() const { +return New(graph_, New( + "dimSrcVoc", opt>("dim-vocabs")[0], // ULR multi-lingual src + "dimTgtVoc", opt>("dim-vocabs")[1], // ULR monon tgt + "dimUlrEmb", opt("ulr-dim-emb"), + "dimEmb", opt("dim-emb"), + "ulr-dropout", opt("ulr-dropout"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "ulrTrainTransform", opt("ulr-trainable-transformation"), + "ulrQueryFile", opt("ulr-query-vectors"), + "ulrKeysFile", opt("ulr-keys-vectors"))); +} + +// get embedding layer for this encoder or decoder +// This is lazy mostly because the constructors of the consuming objects are not +// guaranteed presently to have access to their graph. +Ptr EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const { +if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy + if (embeddingLayers_.size() <= batchIndex_) + embeddingLayers_.resize(batchIndex_ + 1); + if (ulr) + embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR + else + embeddingLayers_[batchIndex_] = createEmbeddingLayer(); +} +return embeddingLayers_[batchIndex_]; +} + +} + diff --git a/src/layers/embedding.h b/src/layers/embedding.h new file mode 100644 index 000000000..91fd0b9d7 --- /dev/null +++ b/src/layers/embedding.h @@ -0,0 +1,148 @@ +#pragma once +#include "marian.h" +#include "generic.h" + +namespace marian { + +// A regular embedding layer. +// Note that this also applies dropout if the option is passed (pass 0 when in inference mode). +// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in +// EncoderDecoderLayerBase, which knows to pass on all required parameters from options. 
+class Embedding : public LayerBase, public IEmbeddingLayer { + Expr E_; + Ptr factoredVocab_; + Expr multiRows(const Words& data, float dropProb) const; + bool inference_{false}; + +public: + Embedding(Ptr graph, Ptr options); + + std::tuple apply(Ptr subBatch) const override final; + + Expr apply(const Words& words, const Shape& shape) const override final; + + Expr applyIndices(const std::vector& embIdx, const Shape& shape) const override final; +}; + +class ULREmbedding : public LayerBase, public IEmbeddingLayer { + std::vector ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members + bool inference_{false}; + +public: + ULREmbedding(Ptr graph, Ptr options) + : LayerBase(graph, options), inference_(opt("inference")) { + std::string name = "url_embed"; //opt("prefix"); + int dimKeys = opt("dimTgtVoc"); + int dimQueries = opt("dimSrcVoc"); + int dimEmb = opt("dimEmb"); + int dimUlrEmb = opt("dimUlrEmb"); // ULR mono embed size + bool fixed = opt("fixed", false); + + // Embedding layer initialization should depend only on embedding size, hence fanIn=false + auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); + + std::string queryFile = opt("ulrQueryFile"); + std::string keyFile = opt("ulrKeysFile"); + bool trainTrans = opt("ulrTrainTransform", false); + if (!queryFile.empty() && !keyFile.empty()) { + initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); + name = "ulr_query"; + fixed = true; + auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed); + ulrEmbeddings_.push_back(query_embed); + // keys embeds + initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); + name = "ulr_keys"; + fixed = true; + auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed); + ulrEmbeddings_.push_back(key_embed); + // actual trainable embedding + initFunc = inits::glorotUniform(); + name = "ulr_embed"; + fixed = false; + auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim + ulrEmbeddings_.push_back(ulr_embed); + // init trainable src embedding + name = "ulr_src_embed"; + auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed); + ulrEmbeddings_.push_back(ulr_src_embed); + // ulr transformation matrix + //initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only + if (trainTrans) { + initFunc = inits::glorotUniform(); + fixed = false; + } + else + { + initFunc = inits::eye(); // identity matrix + fixed = true; + } + name = "ulr_transform"; + auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed); + ulrEmbeddings_.push_back(ulrTransform); + + initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only + fixed = true; + name = "ulr_shared"; + auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed); + ulrEmbeddings_.push_back(share_embed); + } + } + + std::tuple apply(Ptr subBatch) const override final { + auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb + auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb + auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb + auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb + auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb + auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 + int 
dimBatch = (int)subBatch->batchSize(); + int dimEmb = uniEmbed->shape()[-1]; + int dimWords = (int)subBatch->batchWidth(); + // D = K.A.QT + // dimm(K) = univ_tok_vocab*uni_embed_size + // dim A = uni_embed_size*uni_embed_size + // dim Q: uni_embed_size * total_merged_vocab_size + // dim D = univ_tok_vocab * total_merged_vocab_size + // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD) + // here we need to handle the mini-batch + // extract raws corresponding to Xs in this minibatch from Q + auto embIdx = toWordIndexVector(subBatch->data()); + auto queryEmbeddings = rows(queryEmbed, embIdx); + auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings + auto alpha = rows(ulrSharable, embIdx); // extract sharable flags + auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb + auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]); + qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes + auto z = dot(qt, keyEmbed, false, true); // query-key similarity + float dropProb = this->options_->get("ulr-dropout", 0.0f); // default no dropout + if(!inference_) + z = dropout(z, dropProb); + + float tau = this->options_->get("ulr-softmax-temperature", 1.0f); // default no temperature + // temperature in softmax is to control randomness of predictions + // high temperature Softmax outputs are more close to each other + // low temperatures the softmax become more similar to "hardmax" + auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? + auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE + auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast + auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb }); + auto graph = ulrEmbeddings_.front()->graph(); + auto batchMask = graph->constant({ dimWords, dimBatch, 1 }, + inits::fromVector(subBatch->mask())); + if(!inference_) + batchEmbeddings = dropout(batchEmbeddings, options_->get("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1}); + return std::make_tuple(batchEmbeddings, batchMask); + } + + Expr apply(const Words& words, const Shape& shape) const override final { + return applyIndices(toWordIndexVector(words), shape); + } + + Expr applyIndices(const std::vector& embIdx, const Shape& shape) const override final { + embIdx; shape; + ABORT("not implemented"); // @TODO: implement me + } +}; + +} diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index d44f40206..36612577c 100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -438,172 +438,4 @@ namespace marian { } } } - - Embedding::Embedding(Ptr graph, Ptr options) - : LayerBase(graph, options), inference_(opt("inference")) { - std::string name = opt("prefix"); - int dimVoc = opt("dimVocab"); - int dimEmb = opt("dimEmb"); - - bool fixed = opt("fixed", false); - - factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); - if (factoredVocab_) { - dimVoc = (int)factoredVocab_->factorVocabSize(); - LOG_ONCE(info, "[embedding] Factored embeddings enabled"); - } - - // Embedding layer initialization should depend only on embedding size, hence fanIn=false - auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length - - if 
(options_->has("embFile")) { - std::string file = opt("embFile"); - if (!file.empty()) { - bool norm = opt("normalization", false); - initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); - } - } - - E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); - } - - // helper to embed a sequence of words (given as indices) via factored embeddings - Expr Embedding::multiRows(const Words& data, float dropProb) const { - auto graph = E_->graph(); - auto factoredData = factoredVocab_->csr_rows(data); - // multi-hot factor vectors are represented as a sparse CSR matrix - // [row index = word position index] -> set of factor indices for word at this position - ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??"); - // the CSR matrix is passed in pieces - auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights)); - auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32); - auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32); - // apply dropout - // We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors. - if(!inference_) - weights = dropout(weights, dropProb); - // perform the product - return csr_dot(factoredData.shape, weights, indices, offsets, E_); - } - - std::tuple Embedding::apply(Ptr subBatch) const /*override final*/ { - auto graph = E_->graph(); - int dimBatch = (int)subBatch->batchSize(); - int dimEmb = E_->shape()[-1]; - int dimWidth = (int)subBatch->batchWidth(); - - // factored embeddings: - // - regular: - // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] - // - factored: - // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) - // - each row of M contains the set of factors for one word => we want a CSR matrix - // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] - // - first compute x @ M on the CPU - // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): - // - shape (U, specifically) not actually needed here - // - foreach input x[i] - // - locate row M[i,*] - // - copy through its index values (std::vector) - // - create a matching ones vector (we can keep growing) - // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) - // - CSR matrix product with E - // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) - // - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()). 
- // - weighting: - // - core factors' gradients are sums over all words that use the factors; - // - core factors' embeddings move very fast - // - words will need to make up for the move; rare words cannot - // - so, we multiply each factor with 1/refCount - // - core factors get weighed down a lot - // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before - // - but forward pass weighs them down, so that all factors are in a similar numeric range - // - if it is required to be in a different range, the embeddings can still learn that, but more slowly - - auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); -#if 1 - auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->mask())); -#else // @TODO: this is dead code now, get rid of it - // experimental: hide inline-fix source tokens from cross attention - auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); -#endif - // give the graph inputs readable names for debugging and ONNX - batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); - - return std::make_tuple(batchEmbeddings, batchMask); - } - - Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ { - if (factoredVocab_) { - Expr selectedEmbs = multiRows(words, options_->get("dropout", 0.0f)); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] - //selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout - return selectedEmbs; - } - else - return applyIndices(toWordIndexVector(words), shape); - } - - Expr Embedding::applyIndices(const std::vector& embIdx, const Shape& shape) const /*override final*/ { - ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); - auto embIdxExpr = E_->graph()->indices(embIdx); - embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? - auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] - // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately) - if(!inference_) - selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); - return selectedEmbs; - } - - // standard encoder word embeddings - /*private*/ Ptr EncoderDecoderLayerBase::createEmbeddingLayer() const { - auto options = New( - "dimVocab", opt>("dim-vocabs")[batchIndex_], - "dimEmb", opt("dim-emb"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "prefix", (opt("tied-embeddings-src") || opt("tied-embeddings-all")) ? 
"Wemb" : prefix_ + "_Wemb", - "fixed", embeddingFix_, - "vocab", opt>("vocabs")[batchIndex_]); // for factored embeddings - if(options_->hasAndNotEmpty("embedding-vectors")) { - auto embFiles = opt>("embedding-vectors"); - options->set( - "embFile", embFiles[batchIndex_], - "normalization", opt("embedding-normalization")); - } - return New(graph_, options); - } - - // ULR word embeddings - /*private*/ Ptr EncoderDecoderLayerBase::createULREmbeddingLayer() const { - return New(graph_, New( - "dimSrcVoc", opt>("dim-vocabs")[0], // ULR multi-lingual src - "dimTgtVoc", opt>("dim-vocabs")[1], // ULR monon tgt - "dimUlrEmb", opt("ulr-dim-emb"), - "dimEmb", opt("dim-emb"), - "ulr-dropout", opt("ulr-dropout"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "ulrTrainTransform", opt("ulr-trainable-transformation"), - "ulrQueryFile", opt("ulr-query-vectors"), - "ulrKeysFile", opt("ulr-keys-vectors"))); - } - - // get embedding layer for this encoder or decoder - // This is lazy mostly because the constructors of the consuming objects are not - // guaranteed presently to have access to their graph. - Ptr EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const { - if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy - if (embeddingLayers_.size() <= batchIndex_) - embeddingLayers_.resize(batchIndex_ + 1); - if (ulr) - embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR - else - embeddingLayers_[batchIndex_] = createEmbeddingLayer(); - } - return embeddingLayers_[batchIndex_]; - } } // namespace marian diff --git a/src/layers/generic.h b/src/layers/generic.h index f47bb45e2..9fbaea7c7 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -295,146 +295,6 @@ class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { } // namespace mlp -// A regular embedding layer. -// Note that this also applies dropout if the option is passed (pass 0 when in inference mode). -// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in -// EncoderDecoderLayerBase, which knows to pass on all required parameters from options. 
-class Embedding : public LayerBase, public IEmbeddingLayer { - Expr E_; - Ptr factoredVocab_; - Expr multiRows(const Words& data, float dropProb) const; - bool inference_{false}; - -public: - Embedding(Ptr graph, Ptr options); - - std::tuple apply(Ptr subBatch) const override final; - - Expr apply(const Words& words, const Shape& shape) const override final; - - Expr applyIndices(const std::vector& embIdx, const Shape& shape) const override final; -}; - -class ULREmbedding : public LayerBase, public IEmbeddingLayer { - std::vector ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members - bool inference_{false}; - -public: - ULREmbedding(Ptr graph, Ptr options) - : LayerBase(graph, options), inference_(opt("inference")) { - std::string name = "url_embed"; //opt("prefix"); - int dimKeys = opt("dimTgtVoc"); - int dimQueries = opt("dimSrcVoc"); - int dimEmb = opt("dimEmb"); - int dimUlrEmb = opt("dimUlrEmb"); // ULR mono embed size - bool fixed = opt("fixed", false); - - // Embedding layer initialization should depend only on embedding size, hence fanIn=false - auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); - - std::string queryFile = opt("ulrQueryFile"); - std::string keyFile = opt("ulrKeysFile"); - bool trainTrans = opt("ulrTrainTransform", false); - if (!queryFile.empty() && !keyFile.empty()) { - initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); - name = "ulr_query"; - fixed = true; - auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(query_embed); - // keys embeds - initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); - name = "ulr_keys"; - fixed = true; - auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(key_embed); - // actual trainable embedding - initFunc = inits::glorotUniform(); - name = "ulr_embed"; - fixed = false; - auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim - ulrEmbeddings_.push_back(ulr_embed); - // init trainable src embedding - name = "ulr_src_embed"; - auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(ulr_src_embed); - // ulr transformation matrix - //initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only - if (trainTrans) { - initFunc = inits::glorotUniform(); - fixed = false; - } - else - { - initFunc = inits::eye(); // identity matrix - fixed = true; - } - name = "ulr_transform"; - auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(ulrTransform); - - initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only - fixed = true; - name = "ulr_shared"; - auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed); - ulrEmbeddings_.push_back(share_embed); - } - } - - std::tuple apply(Ptr subBatch) const override final { - auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb - auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb - auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb - auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb - auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb - auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 - int 
dimBatch = (int)subBatch->batchSize(); - int dimEmb = uniEmbed->shape()[-1]; - int dimWords = (int)subBatch->batchWidth(); - // D = K.A.QT - // dimm(K) = univ_tok_vocab*uni_embed_size - // dim A = uni_embed_size*uni_embed_size - // dim Q: uni_embed_size * total_merged_vocab_size - // dim D = univ_tok_vocab * total_merged_vocab_size - // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD) - // here we need to handle the mini-batch - // extract raws corresponding to Xs in this minibatch from Q - auto embIdx = toWordIndexVector(subBatch->data()); - auto queryEmbeddings = rows(queryEmbed, embIdx); - auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings - auto alpha = rows(ulrSharable, embIdx); // extract sharable flags - auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb - auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]); - qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes - auto z = dot(qt, keyEmbed, false, true); // query-key similarity - float dropProb = this->options_->get("ulr-dropout", 0.0f); // default no dropout - if(!inference_) - z = dropout(z, dropProb); - - float tau = this->options_->get("ulr-softmax-temperature", 1.0f); // default no temperature - // temperature in softmax is to control randomness of predictions - // high temperature Softmax outputs are more close to each other - // low temperatures the softmax become more similar to "hardmax" - auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? - auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE - auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast - auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb }); - auto graph = ulrEmbeddings_.front()->graph(); - auto batchMask = graph->constant({ dimWords, dimBatch, 1 }, - inits::fromVector(subBatch->mask())); - if(!inference_) - batchEmbeddings = dropout(batchEmbeddings, options_->get("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1}); - return std::make_tuple(batchEmbeddings, batchMask); - } - - Expr apply(const Words& words, const Shape& shape) const override final { - return applyIndices(toWordIndexVector(words), shape); - } - - Expr applyIndices(const std::vector& embIdx, const Shape& shape) const override final { - embIdx; shape; - ABORT("not implemented"); // @TODO: implement me - } -}; // --- a few layers with built-in parameters created on the fly, without proper object // @TODO: change to a proper layer object From ca47eabca5cb9bb11a3e4fe45afa77501128b4b9 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 03:24:25 +0000 Subject: [PATCH 03/14] move output to its own file --- src/CMakeLists.txt | 1 + src/layers/constructors.h | 1 + src/layers/generic.cpp | 223 ------------------------------------ src/layers/generic.h | 62 ---------- src/layers/output.cpp | 233 ++++++++++++++++++++++++++++++++++++++ src/layers/output.h | 73 ++++++++++++ 6 files changed, 308 insertions(+), 285 deletions(-) create mode 100644 src/layers/output.cpp create mode 100644 src/layers/output.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e71a7b711..170aafd10 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -73,6 +73,7 @@ set(MARIAN_SOURCES layers/weight.cpp layers/lsh.cpp 
layers/embedding.cpp + layers/output.cpp rnn/cells.cpp rnn/attention.cpp diff --git a/src/layers/constructors.h b/src/layers/constructors.h index 7771919d0..e25449aa4 100644 --- a/src/layers/constructors.h +++ b/src/layers/constructors.h @@ -3,6 +3,7 @@ #include "layers/factory.h" #include "layers/generic.h" #include "layers/embedding.h" +#include "layers/output.h" namespace marian { namespace mlp { diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index 36612577c..d5baf6da1 100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -215,227 +215,4 @@ namespace marian { return Logits(std::move(newLogits), factoredVocab_); } - namespace mlp { - /*private*/ void Output::lazyConstruct(int inputDim) { - // We must construct lazily since we won't know tying nor input dim in constructor. - if (Wt_) - return; - - // this option is only set in the decoder - if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { - auto k = opt>("output-approx-knn")[0]; - auto nbits = opt>("output-approx-knn")[1]; - lsh_ = New(k, nbits); - } - - auto name = options_->get("prefix"); - auto numOutputClasses = options_->get("dim"); - - factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); - if (factoredVocab_) { - numOutputClasses = (int)factoredVocab_->factorVocabSize(); - LOG_ONCE(info, "[embedding] Factored outputs enabled"); - } - - if(tiedParam_) { - Wt_ = tiedParam_; - } else { - if (graph_->get(name + "_W")) { // support of legacy models that did not transpose - Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); - isLegacyUntransposedW = true; - } - else // this is the regular case: - Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); - } - - if(hasBias_) - b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros()); - - /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); - ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); - if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix -#define HARDMAX_HACK -#ifdef HARDMAX_HACK - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number -#endif - auto range = factoredVocab_->getGroupRange(0); - auto lemmaVocabDim = (int)(range.second - range.first); - auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length - lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed - } - } - - Logits Output::applyAsLogits(Expr input) /*override final*/ { - lazyConstruct(input->shape()[-1]); - - auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { - if(b) - return affine(x, W, b, transA, transB); - else - return dot(x, W, transA, transB); - }; - - auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { - if(lsh_) { - ABORT_IF( transA, "Transposed query not supported for LSH"); - ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); - return lsh_->apply(x, W, b); // knows how to deal with undefined bias - } else { - return affineOrDot(x, W, b, transA, transB); - } - }; - - if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed - cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? 
-1 : 0, shortlist_->indices()); - if(hasBias_) - cachedShortb_ = index_select(b_ , -1, shortlist_->indices()); - } - - if (factoredVocab_) { - auto graph = input->graph(); - - // project each factor separately - auto numGroups = factoredVocab_->getNumGroups(); - std::vector> allLogits(numGroups, nullptr); // (note: null entries for absent factors) - Expr input1 = input; // [B... x D] - Expr Plemma = nullptr; // used for lemmaDimEmb=-1 - Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 - for (size_t g = 0; g < numGroups; g++) { - auto range = factoredVocab_->getGroupRange(g); - if (g > 0 && range.first == range.second) // empty entry - continue; - ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g); - // slice this group's section out of W_ - Expr factorWt, factorB; - if (g == 0 && shortlist_) { - factorWt = cachedShortWt_; - factorB = cachedShortb_; - } - else { - factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second)); - if(hasBias_) - factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); - } - /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); - if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) - LOG_ONCE(info, "[embedding] using lemma conditioning with gate"); - // this mimics one transformer layer - // - attention over two inputs: - // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas. - // - input = hidden state FF(h_enc+h_dec) - // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention) - // - multi-head to allow for multiple conditions to be modeled - // - add & norm, for gradient flow and scaling - // - FF layer --this is expensive; it is per-factor - // multi-head attention - int inputDim = input->shape()[-1]; - int heads = 8; - auto name = options_->get("prefix") + "_factor" + std::to_string(g); - auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform()); - auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform()); - auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform()); - auto toMultiHead = [&](Expr x, int heads) { - const auto& shape = x->shape(); - int inputDim = shape[-1]; - int otherDim = shape.elements() / inputDim; - ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads); - return reshape(x, { otherDim, heads, 1, inputDim / heads }); - }; - input1 = inputLemma; - auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query - auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax. - auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values - auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other - auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] - auto sm = sigmoid(zm); // [B... x H x 1] - auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] - auto r = reshape(rm, input->shape()); // [B... 
x D] - // add & norm - input1 = r + input1; - input1 = layerNorm(input1, name + "_att"); - // FF layer - auto ffnDropProb = 0.1f; // @TODO: get as a parameter - auto ffnDim = inputDim * 2; // @TODO: get as a parameter - auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb); - f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); - // add & norm - input1 = f + input1; - input1 = layerNorm(input1, name + "_ffn"); - } - // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a matrix - Expr factorLogits; - if(g == 0) - factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - else - factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - - // optionally add lemma-dependent bias - if (Plemma) { // [B... x U0] - int lemmaVocabDim = Plemma->shape()[-1]; - int factorVocabDim = factorLogits->shape()[-1]; - auto name = options_->get("prefix"); - Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma - auto b = dot(Plemma, lemmaBt, false, true); // [B... x U] - factorLogits = factorLogits + b; - } - allLogits[g] = New(factorLogits, nullptr); - // optionally add a soft embedding of lemma back to create some lemma dependency - // @TODO: if this works, move it into lazyConstruct - if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure - LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version"); - // get expected lemma embedding vector - auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set - auto factorSoftmax = exp(factorLogSoftmax); - inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] - } - else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max - LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version"); - // get max-lemma embedding vector - auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set - auto factorHardmax = eq(factorLogits, maxVal); - inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] - } - else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias - ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented"); - LOG_ONCE(info, "[embedding] using lemma-dependent bias"); - auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) - auto z = /*stopGradient*/(factorLogSoftmax); - Plemma = exp(z); // [B... x U] - } - else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix - LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb); - // compute softmax. 
We compute logsoftmax() separately because this way, computation will be reused later via CSE - auto factorLogSoftmax = logsoftmax(factorLogits); - auto factorSoftmax = exp(factorLogSoftmax); -#ifdef HARDMAX_HACK - bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation) - if (hardmax) { - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; - LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb); - auto maxVal = max(factorSoftmax, -1); - factorSoftmax = eq(factorSoftmax, maxVal); - } -#endif - // re-embedding lookup, soft-indexed by softmax - if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix - cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); - auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L] - // project it back to regular hidden dim - int inputDim = input1->shape()[-1]; - auto name = options_->get("prefix"); - // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1 - Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension - auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] - // augment the original hidden vector with this additional information - input1 = input1 + f; - } - } - return Logits(std::move(allLogits), factoredVocab_); - } else if (shortlist_) { - return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } else { - return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } - } - } } // namespace marian diff --git a/src/layers/generic.h b/src/layers/generic.h index 9fbaea7c7..6d953fd8c 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -233,68 +233,6 @@ class Dense : public LayerBase, public IUnaryLayer { } // namespace mlp -class LSH; - -namespace mlp { - -class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { -private: - // parameters held by this layer - Expr Wt_; // weight matrix is stored transposed for efficiency - Expr b_; - Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] - bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form - bool hasBias_{true}; - - Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) - Expr cachedShortb_; // these match the current value of shortlist_ - Expr cachedShortLemmaEt_; - Ptr factoredVocab_; - - // optional parameters set/updated after construction - Expr tiedParam_; - Ptr shortlist_; - Ptr lsh_; - - void lazyConstruct(int inputDim); -public: - Output(Ptr graph, Ptr options) - : LayerBase(graph, options), - hasBias_{!options->get("output-omit-bias", false)} { - clear(); - } - - void tieTransposed(Expr tied) { - if (Wt_) - ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created"); - else - tiedParam_ = tied; - } - - void setShortlist(Ptr shortlist) override final { - if (shortlist_) - ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()"); - else { - ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??"); - shortlist_ = shortlist; - } - // cachedShortWt_ and cachedShortb_ will be created 
lazily inside apply() - } - - // this is expected to be called in sync with graph->clear(), which invalidates - // cachedShortWt_ etc. in the graph's short-term cache - void clear() override final { - shortlist_ = nullptr; - cachedShortWt_ = nullptr; - cachedShortb_ = nullptr; - cachedShortLemmaEt_ = nullptr; - } - - Logits applyAsLogits(Expr input) override final; -}; - -} // namespace mlp - // --- a few layers with built-in parameters created on the fly, without proper object // @TODO: change to a proper layer object diff --git a/src/layers/output.cpp b/src/layers/output.cpp new file mode 100644 index 000000000..bf8fa5886 --- /dev/null +++ b/src/layers/output.cpp @@ -0,0 +1,233 @@ +#include "output.h" +#include "data/factored_vocab.h" +#include "common/timer.h" +#include "layers/lsh.h" +#include "layers/loss.h" + +namespace marian { +namespace mlp { + +/*private*/ void Output::lazyConstruct(int inputDim) { + // We must construct lazily since we won't know tying nor input dim in constructor. + if (Wt_) + return; + + // this option is only set in the decoder + if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { + auto k = opt>("output-approx-knn")[0]; + auto nbits = opt>("output-approx-knn")[1]; + lsh_ = New(k, nbits); + } + + auto name = options_->get("prefix"); + auto numOutputClasses = options_->get("dim"); + + factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); + if (factoredVocab_) { + numOutputClasses = (int)factoredVocab_->factorVocabSize(); + LOG_ONCE(info, "[embedding] Factored outputs enabled"); + } + + if(tiedParam_) { + Wt_ = tiedParam_; + } else { + if (graph_->get(name + "_W")) { // support of legacy models that did not transpose + Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); + isLegacyUntransposedW = true; + } + else // this is the regular case: + Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); + } + + if(hasBias_) + b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros()); + + /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); + ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); + if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix +#define HARDMAX_HACK +#ifdef HARDMAX_HACK + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number +#endif + auto range = factoredVocab_->getGroupRange(0); + auto lemmaVocabDim = (int)(range.second - range.first); + auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length + lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed + } +} + +Logits Output::applyAsLogits(Expr input) /*override final*/ { + lazyConstruct(input->shape()[-1]); + + auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { + if(b) + return affine(x, W, b, transA, transB); + else + return dot(x, W, transA, transB); + }; + + auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { + if(lsh_) { + ABORT_IF( transA, "Transposed query not supported for LSH"); + ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); + return lsh_->apply(x, W, b); // knows how to deal with undefined bias + } else { + return affineOrDot(x, W, b, transA, transB); + } + }; + + if (shortlist_ && !cachedShortWt_) { // 
shortlisted versions of parameters are cached within one batch, then clear()ed + cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); + if(hasBias_) + cachedShortb_ = index_select(b_ , -1, shortlist_->indices()); + } + + if (factoredVocab_) { + auto graph = input->graph(); + + // project each factor separately + auto numGroups = factoredVocab_->getNumGroups(); + std::vector> allLogits(numGroups, nullptr); // (note: null entries for absent factors) + Expr input1 = input; // [B... x D] + Expr Plemma = nullptr; // used for lemmaDimEmb=-1 + Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 + for (size_t g = 0; g < numGroups; g++) { + auto range = factoredVocab_->getGroupRange(g); + if (g > 0 && range.first == range.second) // empty entry + continue; + ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g); + // slice this group's section out of W_ + Expr factorWt, factorB; + if (g == 0 && shortlist_) { + factorWt = cachedShortWt_; + factorB = cachedShortb_; + } + else { + factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second)); + if(hasBias_) + factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); + } + /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); + if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) + LOG_ONCE(info, "[embedding] using lemma conditioning with gate"); + // this mimics one transformer layer + // - attention over two inputs: + // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas. + // - input = hidden state FF(h_enc+h_dec) + // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention) + // - multi-head to allow for multiple conditions to be modeled + // - add & norm, for gradient flow and scaling + // - FF layer --this is expensive; it is per-factor + // multi-head attention + int inputDim = input->shape()[-1]; + int heads = 8; + auto name = options_->get("prefix") + "_factor" + std::to_string(g); + auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform()); + auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform()); + auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform()); + auto toMultiHead = [&](Expr x, int heads) { + const auto& shape = x->shape(); + int inputDim = shape[-1]; + int otherDim = shape.elements() / inputDim; + ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads); + return reshape(x, { otherDim, heads, 1, inputDim / heads }); + }; + input1 = inputLemma; + auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query + auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax. + auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values + auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other + auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] + auto sm = sigmoid(zm); // [B... x H x 1] + auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] + auto r = reshape(rm, input->shape()); // [B... 
x D] + // add & norm + input1 = r + input1; + input1 = layerNorm(input1, name + "_att"); + // FF layer + auto ffnDropProb = 0.1f; // @TODO: get as a parameter + auto ffnDim = inputDim * 2; // @TODO: get as a parameter + auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb); + f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); + // add & norm + input1 = f + input1; + input1 = layerNorm(input1, name + "_ffn"); + } + // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a matrix + Expr factorLogits; + if(g == 0) + factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + else + factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + + // optionally add lemma-dependent bias + if (Plemma) { // [B... x U0] + int lemmaVocabDim = Plemma->shape()[-1]; + int factorVocabDim = factorLogits->shape()[-1]; + auto name = options_->get("prefix"); + Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma + auto b = dot(Plemma, lemmaBt, false, true); // [B... x U] + factorLogits = factorLogits + b; + } + allLogits[g] = New(factorLogits, nullptr); + // optionally add a soft embedding of lemma back to create some lemma dependency + // @TODO: if this works, move it into lazyConstruct + if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure + LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version"); + // get expected lemma embedding vector + auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set + auto factorSoftmax = exp(factorLogSoftmax); + inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } + else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max + LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version"); + // get max-lemma embedding vector + auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set + auto factorHardmax = eq(factorLogits, maxVal); + inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } + else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias + ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented"); + LOG_ONCE(info, "[embedding] using lemma-dependent bias"); + auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) + auto z = /*stopGradient*/(factorLogSoftmax); + Plemma = exp(z); // [B... x U] + } + else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix + LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb); + // compute softmax. 
We compute logsoftmax() separately because this way, computation will be reused later via CSE + auto factorLogSoftmax = logsoftmax(factorLogits); + auto factorSoftmax = exp(factorLogSoftmax); +#ifdef HARDMAX_HACK + bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation) + if (hardmax) { + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; + LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb); + auto maxVal = max(factorSoftmax, -1); + factorSoftmax = eq(factorSoftmax, maxVal); + } +#endif + // re-embedding lookup, soft-indexed by softmax + if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix + cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); + auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L] + // project it back to regular hidden dim + int inputDim = input1->shape()[-1]; + auto name = options_->get("prefix"); + // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1 + Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension + auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] + // augment the original hidden vector with this additional information + input1 = input1 + f; + } + } + return Logits(std::move(allLogits), factoredVocab_); + } else if (shortlist_) { + return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true)); + } else { + return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); + } +} + +} +} \ No newline at end of file diff --git a/src/layers/output.h b/src/layers/output.h new file mode 100644 index 000000000..d091556a4 --- /dev/null +++ b/src/layers/output.h @@ -0,0 +1,73 @@ +#pragma once + +#include "marian.h" +#include "generic.h" +#include "data/shortlist.h" +#include "layers/factory.h" + +namespace marian { +class LSH; + +namespace mlp { + +class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { +private: + // parameters held by this layer + Expr Wt_; // weight matrix is stored transposed for efficiency + Expr b_; + Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] + bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form + bool hasBias_{true}; + + Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) + Expr cachedShortb_; // these match the current value of shortlist_ + Expr cachedShortLemmaEt_; + Ptr factoredVocab_; + + // optional parameters set/updated after construction + Expr tiedParam_; + Ptr shortlist_; + Ptr lsh_; + + void lazyConstruct(int inputDim); +public: + Output(Ptr graph, Ptr options) + : LayerBase(graph, options), + hasBias_{!options->get("output-omit-bias", false)} { + clear(); + } + + void tieTransposed(Expr tied) { + if (Wt_) + ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created"); + else + tiedParam_ = tied; + } + + void setShortlist(Ptr shortlist) override final { + if (shortlist_) + ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()"); + else { + ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??"); + shortlist_ = 
shortlist; + } + // cachedShortWt_ and cachedShortb_ will be created lazily inside apply() + } + + // this is expected to be called in sync with graph->clear(), which invalidates + // cachedShortWt_ etc. in the graph's short-term cache + void clear() override final { + shortlist_ = nullptr; + cachedShortWt_ = nullptr; + cachedShortb_ = nullptr; + cachedShortLemmaEt_ = nullptr; + } + + Logits applyAsLogits(Expr input) override final; +}; + +} // namespace mlp + +} + + From f7266886f0d478d802a88f2ce82b71f27c37bf07 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 04:18:19 +0000 Subject: [PATCH 04/14] move logits to its own file --- src/layers/embedding.h | 2 ++ src/layers/generic.h | 66 ------------------------------------------ src/layers/loss.h | 2 +- src/layers/output.h | 1 + src/models/states.h | 2 +- 5 files changed, 5 insertions(+), 68 deletions(-) diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 91fd0b9d7..b7898c76e 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -4,6 +4,8 @@ namespace marian { +class FactoredVocab; + // A regular embedding layer. // Note that this also applies dropout if the option is passed (pass 0 when in inference mode). // It is best to not use Embedding directly, but rather via getEmbeddingLayer() in diff --git a/src/layers/generic.h b/src/layers/generic.h index 6d953fd8c..eddd597e8 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -97,72 +97,6 @@ class EncoderDecoderLayerBase : public LayerBase { Ptr getEmbeddingLayer(bool ulr = false) const; }; -class FactoredVocab; - -// To support factors, any output projection (that is followed by a softmax) must -// retain multiple outputs, one for each factor. Such layer returns not a single Expr, -// but a Logits object that contains multiple. -// This allows to compute softmax values in a factored manner, where we never create -// a fully expanded list of all factor combinations. 
-class RationalLoss; -class Logits { -public: - Logits() {} - explicit Logits(Ptr logits) { // single-output constructor - logits_.push_back(logits); - } - explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count) - Logits(std::vector>&& logits, Ptr embeddingFactorMapping) // factored-output constructor - : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} - Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors - Expr getFactoredLogits(size_t groupIndex, Ptr shortlist = nullptr, const std::vector& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle - //Ptr getRationalLoss() const; // assume it holds a loss: get that - Expr applyLossFunction(const Words& labels, const std::function& lossFn) const; - Logits applyUnaryFunction(const std::function& f) const; // clone this but apply f to all loss values - Logits applyUnaryFunctions(const std::function& f1, const std::function& fother) const; // clone this but apply f1 to first and fother to to all other values - - struct MaskedFactorIndices { - std::vector indices; // factor index, or 0 if masked - std::vector masks; - void reserve(size_t n) { indices.reserve(n); masks.reserve(n); } - void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries - MaskedFactorIndices() {} - MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case - }; - std::vector factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices - Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only - size_t getNumFactorGroups() const { return logits_.size(); } - bool empty() const { return logits_.empty(); } - Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_ -private: - // helper functions - Ptr graph() const; - Expr constant(const Shape& shape, const std::vector& data) const { return graph()->constant(shape, inits::fromVector(data)); } - Expr constant(const Shape& shape, const std::vector& data) const { return graph()->constant(shape, inits::fromVector(data)); } - template Expr constant(const std::vector& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector - Expr indices(const std::vector& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type - std::vector getFactorMasks(size_t factorGroup, const std::vector& indices) const; -private: - // members - // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr - std::vector> logits_; // [group id][B..., num factors in group] - Ptr factoredVocab_; -}; - -// Unary function that returns a Logits object -// Also implements IUnaryLayer, since Logits can be cast to Expr. -// This interface is implemented by all layers that are of the form of a unary function -// that returns multiple logits, to support factors. 
-struct IUnaryLogitLayer : public IUnaryLayer { - virtual Logits applyAsLogits(Expr) = 0; - virtual Logits applyAsLogits(const std::vector& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub - return applyAsLogits(es.front()); - } - virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } - virtual Expr apply(const std::vector& es) override { return applyAsLogits(es).getLogits(); } -}; - namespace mlp { class Dense : public LayerBase, public IUnaryLayer { diff --git a/src/layers/loss.h b/src/layers/loss.h index d7bc19e4a..ba93cdac7 100644 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -1,7 +1,7 @@ #pragma once #include "graph/expression_operators.h" -#include "layers/generic.h" // for Logits (Frank's factor hack) +#include "layers/logits.h" // for Logits (Frank's factor hack) #include "data/types.h" namespace marian { diff --git a/src/layers/output.h b/src/layers/output.h index d091556a4..92e7eb25e 100644 --- a/src/layers/output.h +++ b/src/layers/output.h @@ -2,6 +2,7 @@ #include "marian.h" #include "generic.h" +#include "logits.h" #include "data/shortlist.h" #include "layers/factory.h" diff --git a/src/models/states.h b/src/models/states.h index c2f9ee05a..cfb6fd1b8 100644 --- a/src/models/states.h +++ b/src/models/states.h @@ -1,7 +1,7 @@ #pragma once #include "marian.h" -#include "layers/generic.h" // @HACK: for factored embeddings only so far +#include "layers/logits.h" // @HACK: for factored embeddings only so far #include "rnn/types.h" namespace marian { From 42406cc715a301db2c3087d0def6ff70332d1074 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 04:23:35 +0000 Subject: [PATCH 05/14] move logits to its own file --- src/CMakeLists.txt | 1 + src/layers/generic.cpp | 71 ------------------------------------------ 2 files changed, 1 insertion(+), 71 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 170aafd10..a3155d68e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -74,6 +74,7 @@ set(MARIAN_SOURCES layers/lsh.cpp layers/embedding.cpp layers/output.cpp + layers/logits.cpp rnn/cells.cpp rnn/attention.cpp diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index d5baf6da1..237336226 100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -143,76 +143,5 @@ namespace marian { #endif } - void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { - bool isValid = FactoredVocab::isFactorValid(factorIndex); - indices.push_back(isValid ? 
(WordIndex)factorIndex : 0); - masks.push_back((float)isValid); - } - - std::vector Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return {MaskedFactorIndices(words)}; - } - auto numGroups = factoredVocab_->getNumGroups(); - std::vector res(numGroups); - for (size_t g = 0; g < numGroups; g++) { - auto& resg = res[g]; - resg.reserve(words.size()); - for (const auto& word : words) - resg.push_back(factoredVocab_->getFactor(word, g)); - } - return res; - } - - //// use first factor of each word to determine whether it has a specific factor - //std::vector Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0 - // std::vector res; - // res.reserve(words.size()); - // for (const auto& word : words) { - // auto lemma = factoredVocab_->getFactor(word, 0); - // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - // } - // return res; - //} - - // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor - // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas - std::vector Logits::getFactorMasks(size_t factorGroup, const std::vector& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 - size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size(); - std::vector res; - res.reserve(n); - // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab - for (size_t i = 0; i < n; i++) { - auto lemma = indices.empty() ? 
i : (indices[i] - factoredVocab_->getGroupRange(0).first); - res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - } - return res; - } - - Logits Logits::applyUnaryFunction(const std::function& f) const { // clone this but apply f to all loss values - std::vector> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New(f(l->loss()), l->count())); - return Logits(std::move(newLogits), factoredVocab_); - } - - Logits Logits::applyUnaryFunctions(const std::function& f1, const std::function& fother) const { - std::vector> newLogits; - bool first = true; - for (const auto& l : logits_) { - newLogits.emplace_back(New((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others - first = false; - } - return Logits(std::move(newLogits), factoredVocab_); - } - - // @TODO: code dup with above; we can merge it into applyToRationalLoss() - Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_ - std::vector> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New(l->loss(), count)); - return Logits(std::move(newLogits), factoredVocab_); - } } // namespace marian From b88c3fcb71ad06326f60cc930fd283d77cd41a6c Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 04:35:00 +0000 Subject: [PATCH 06/14] costs.cpp --- src/CMakeLists.txt | 1 + src/common/timer.cpp | 0 src/models/costs.cpp | 16 ++++++++++++++++ src/models/costs.h | 7 +------ 4 files changed, 18 insertions(+), 6 deletions(-) create mode 100644 src/common/timer.cpp create mode 100644 src/models/costs.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a3155d68e..c59d8bf61 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -87,6 +87,7 @@ set(MARIAN_SOURCES models/model_factory.cpp models/encoder_decoder.cpp models/transformer_stub.cpp + models/costs.cpp rescorer/score_collector.cpp embedder/vector_collector.cpp diff --git a/src/common/timer.cpp b/src/common/timer.cpp new file mode 100644 index 000000000..e69de29bb diff --git a/src/models/costs.cpp b/src/models/costs.cpp new file mode 100644 index 000000000..5105f5904 --- /dev/null +++ b/src/models/costs.cpp @@ -0,0 +1,16 @@ +#include "costs.h" + +namespace marian { +namespace models { + +Ptr LogSoftmaxStep::apply(Ptr state) { +// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) +state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); +// @TODO: This is becoming more and more opaque ^^. Can we simplify this? +return state; +} + + +} +} + diff --git a/src/models/costs.h b/src/models/costs.h index 3d8f2c515..2d34c53a9 100644 --- a/src/models/costs.h +++ b/src/models/costs.h @@ -282,12 +282,7 @@ class ILogProbStep { class LogSoftmaxStep : public ILogProbStep { public: virtual ~LogSoftmaxStep() {} - virtual Ptr apply(Ptr state) override { - // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) - state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); - // @TODO: This is becoming more and more opaque ^^. Can we simplify this? 
- return state; - } + virtual Ptr apply(Ptr state) override; }; // Gumbel-max noising for sampling during beam-search From 085c8a7a98a462ea809dc4efcff7ba0f5b49066d Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 07:57:10 +0000 Subject: [PATCH 07/14] more code from .h -> .cpp --- src/layers/generic.cpp | 134 ----------------------------------------- 1 file changed, 134 deletions(-) diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index 237336226..02e820e57 100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -4,144 +4,10 @@ #include "layers/constructors.h" #include "layers/loss.h" #include "data/factored_vocab.h" -#include "rnn/types.h" // for State::select() #include "models/states.h" // for EncoderState #include "layers/lsh.h" namespace marian { - Logits::Logits(Expr logits) : Logits(New(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count) - - Ptr Logits::graph() const { - ABORT_IF(logits_.empty(), "Empty logits object??"); - return logits_.front()->loss()->graph(); - } - - // This function assumes that the object holds one or more factor logits. - // It applies the supplied loss function to each, and then returns the aggregate loss over all factors. - Expr Logits::applyLossFunction(const Words& labels, const std::function& lossFn) const { - LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto firstLogits = logits_.front()->loss(); - ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), - "Labels not matching logits shape ({} != {}, {})??", - labels.size() * firstLogits->shape()[-1], - firstLogits->shape().elements(), - firstLogits->shape()); - - // base case (no factors) - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return lossFn(firstLogits, indices(toWordIndexVector(labels))); - } - - auto numGroups = factoredVocab_->getNumGroups(); - - // split labels into individual factor labels - auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened] - - //Expr indices = this->indices(toWordIndexVector(labels)); - // accumulate all CEs for all words that have the factor - // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. - Expr loss; - for (size_t g = 0; g < numGroups; g++) { - if (!logits_[g]) - continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] - const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) - auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply - auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor - auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) - // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. - auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] - if(loss) - factorLoss = cast(factorLoss, loss->value_type()); - factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor - loss = loss ? (loss + factorLoss) : factorLoss; // [B... 
x 1] - } - return loss; - } - - // This function assumes this object holds a single factor that represents a rational loss (with count). - //Ptr Logits::getRationalLoss() const { - // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs"); - // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count"); - // return logits_.front(); - //} - - // get logits for one factor group - // For groupIndex == 0, the function also requires the shortlist if there is one. - Expr Logits::getFactoredLogits(size_t groupIndex, Ptr shortlist /*= nullptr*/, const std::vector& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] - - // normalize for decoding: - // - all secondary factors: subtract their max - // - lemma: add all maxes of applicable factors - if (groupIndex > 0) { - sel = sel - max(sel, -1); - } - else { - auto numGroups = getNumFactorGroups(); - for (size_t g = 1; g < numGroups; g++) { - auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32 - auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector())); - sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0 - } - } - - // if selIdx are given, then we must reshuffle accordingly - if (!hypIndices.empty()) // use the same function that shuffles decoder state - sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); - - return sel; - } - - // used for breakDown() only - // Index is flattened - Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - return logits_[groupIndex]->loss()->val(); - } - - // This function assumes that the object holds one or more factor logits, which are summed up - // into output-vocab logits according to the factored model (with correct normalization of factors). - // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. 
- // @TODO: remove altogether - Expr Logits::getLogits() const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return getFactoredLogits(0); - } - -#ifdef FACTOR_FULL_EXPANSION - // compute normalized factor log probs - std::vector logProbs(logits_.size()); - for (size_t g = 0; g < logits_.size(); g++) - logProbs[g] = logsoftmax(logits_[g]->loss()); - auto y = concatenate(logProbs, /*axis=*/ -1); - - // sum up the unit logits across factors for each target word - auto graph = y->graph(); - auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] - y = dot_csr( - y, // [B x U] - factorMatrix.shape, - graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), - graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), - graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), - /*transB=*/ true); // -> [B x V] - - // mask out gaps - auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] - y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask)); - - return y; -#else - ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible -#endif - } } // namespace marian From 7c1cb8462a2adc6540ee78f111a75ed4fbdd66ad Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 4 Mar 2021 07:59:54 +0000 Subject: [PATCH 08/14] add logits.cpp --- src/layers/logits.cpp | 212 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 src/layers/logits.cpp diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp new file mode 100644 index 000000000..cd2203e4d --- /dev/null +++ b/src/layers/logits.cpp @@ -0,0 +1,212 @@ +#include "logits.h" +#include "loss.h" +#include "data/factored_vocab.h" +#include "rnn/types.h" // for State::select() + +namespace marian { + Logits::Logits(Expr logits) : Logits(New(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count) + + Ptr Logits::graph() const { + ABORT_IF(logits_.empty(), "Empty logits object??"); + return logits_.front()->loss()->graph(); + } + + // This function assumes that the object holds one or more factor logits. + // It applies the supplied loss function to each, and then returns the aggregate loss over all factors. + Expr Logits::applyLossFunction(const Words& labels, const std::function& lossFn) const { + LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto firstLogits = logits_.front()->loss(); + ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), + "Labels not matching logits shape ({} != {}, {})??", + labels.size() * firstLogits->shape()[-1], + firstLogits->shape().elements(), + firstLogits->shape()); + + // base case (no factors) + if (!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return lossFn(firstLogits, indices(toWordIndexVector(labels))); + } + + auto numGroups = factoredVocab_->getNumGroups(); + + // split labels into individual factor labels + auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... 
flattened] + + //Expr indices = this->indices(toWordIndexVector(labels)); + // accumulate all CEs for all words that have the factor + // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. + Expr loss; + for (size_t g = 0; g < numGroups; g++) { + if (!logits_[g]) + continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] + const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) + auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply + auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor + auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) + // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. + auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + if(loss) + factorLoss = cast(factorLoss, loss->value_type()); + factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor + loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1] + } + return loss; + } + + // This function assumes this object holds a single factor that represents a rational loss (with count). + //Ptr Logits::getRationalLoss() const { + // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs"); + // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count"); + // return logits_.front(); + //} + + // get logits for one factor group + // For groupIndex == 0, the function also requires the shortlist if there is one. + Expr Logits::getFactoredLogits(size_t groupIndex, Ptr shortlist /*= nullptr*/, const std::vector& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] + + // normalize for decoding: + // - all secondary factors: subtract their max + // - lemma: add all maxes of applicable factors + if (groupIndex > 0) { + sel = sel - max(sel, -1); + } + else { + auto numGroups = getNumFactorGroups(); + for (size_t g = 1; g < numGroups; g++) { + auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32 + auto factorMasks = constant(getFactorMasks(g, shortlist ? 
shortlist->indices() : std::vector())); + sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0 + } + } + + // if selIdx are given, then we must reshuffle accordingly + if (!hypIndices.empty()) // use the same function that shuffles decoder state + sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); + + return sel; + } + + // used for breakDown() only + // Index is flattened + Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + return logits_[groupIndex]->loss()->val(); + } + + // This function assumes that the object holds one or more factor logits, which are summed up + // into output-vocab logits according to the factored model (with correct normalization of factors). + // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. + // @TODO: remove altogether + Expr Logits::getLogits() const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + if (!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return getFactoredLogits(0); + } + +#ifdef FACTOR_FULL_EXPANSION + // compute normalized factor log probs + std::vector logProbs(logits_.size()); + for (size_t g = 0; g < logits_.size(); g++) + logProbs[g] = logsoftmax(logits_[g]->loss()); + auto y = concatenate(logProbs, /*axis=*/ -1); + + // sum up the unit logits across factors for each target word + auto graph = y->graph(); + auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] + y = dot_csr( + y, // [B x U] + factorMatrix.shape, + graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), + graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), + graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), + /*transB=*/ true); // -> [B x V] + + // mask out gaps + auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] + y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask)); + + return y; +#else + ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible +#endif + } + + void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { + bool isValid = FactoredVocab::isFactorValid(factorIndex); + indices.push_back(isValid ? 
(WordIndex)factorIndex : 0); + masks.push_back((float)isValid); + } + + std::vector Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices + if (!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return {MaskedFactorIndices(words)}; + } + auto numGroups = factoredVocab_->getNumGroups(); + std::vector res(numGroups); + for (size_t g = 0; g < numGroups; g++) { + auto& resg = res[g]; + resg.reserve(words.size()); + for (const auto& word : words) + resg.push_back(factoredVocab_->getFactor(word, g)); + } + return res; + } + + //// use first factor of each word to determine whether it has a specific factor + //std::vector Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0 + // std::vector res; + // res.reserve(words.size()); + // for (const auto& word : words) { + // auto lemma = factoredVocab_->getFactor(word, 0); + // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); + // } + // return res; + //} + + // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor + // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas + std::vector Logits::getFactorMasks(size_t factorGroup, const std::vector& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 + size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size(); + std::vector res; + res.reserve(n); + // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab + for (size_t i = 0; i < n; i++) { + auto lemma = indices.empty() ? 
i : (indices[i] - factoredVocab_->getGroupRange(0).first); + res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); + } + return res; + } + + Logits Logits::applyUnaryFunction(const std::function& f) const { // clone this but apply f to all loss values + std::vector> newLogits; + for (const auto& l : logits_) + newLogits.emplace_back(New(f(l->loss()), l->count())); + return Logits(std::move(newLogits), factoredVocab_); + } + + Logits Logits::applyUnaryFunctions(const std::function& f1, const std::function& fother) const { + std::vector> newLogits; + bool first = true; + for (const auto& l : logits_) { + newLogits.emplace_back(New((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others + first = false; + } + return Logits(std::move(newLogits), factoredVocab_); + } + + // @TODO: code dup with above; we can merge it into applyToRationalLoss() + Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_ + std::vector> newLogits; + for (const auto& l : logits_) + newLogits.emplace_back(New(l->loss(), count)); + return Logits(std::move(newLogits), factoredVocab_); + } +} \ No newline at end of file From 55f4216552bca148091f15b72c5c2e5b486d4c79 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Fri, 5 Mar 2021 06:12:28 +0000 Subject: [PATCH 09/14] add .h --- src/layers/logits.h | 76 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 src/layers/logits.h diff --git a/src/layers/logits.h b/src/layers/logits.h new file mode 100644 index 000000000..4196e0d0a --- /dev/null +++ b/src/layers/logits.h @@ -0,0 +1,76 @@ +#pragma once + +#include "marian.h" +#include "data/shortlist.h" +#include "generic.h" + +namespace marian { + +class FactoredVocab; + +// To support factors, any output projection (that is followed by a softmax) must +// retain multiple outputs, one for each factor. Such layer returns not a single Expr, +// but a Logits object that contains multiple. +// This allows to compute softmax values in a factored manner, where we never create +// a fully expanded list of all factor combinations. 
+class RationalLoss; +class Logits { +public: + Logits() {} + explicit Logits(Ptr logits) { // single-output constructor + logits_.push_back(logits); + } + explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count) + Logits(std::vector>&& logits, Ptr embeddingFactorMapping) // factored-output constructor + : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} + Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors + Expr getFactoredLogits(size_t groupIndex, Ptr shortlist = nullptr, const std::vector& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle + //Ptr getRationalLoss() const; // assume it holds a loss: get that + Expr applyLossFunction(const Words& labels, const std::function& lossFn) const; + Logits applyUnaryFunction(const std::function& f) const; // clone this but apply f to all loss values + Logits applyUnaryFunctions(const std::function& f1, const std::function& fother) const; // clone this but apply f1 to first and fother to to all other values + + struct MaskedFactorIndices { + std::vector indices; // factor index, or 0 if masked + std::vector masks; + void reserve(size_t n) { indices.reserve(n); masks.reserve(n); } + void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries + MaskedFactorIndices() {} + MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case + }; + std::vector factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices + Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only + size_t getNumFactorGroups() const { return logits_.size(); } + bool empty() const { return logits_.empty(); } + Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_ +private: + // helper functions + Ptr graph() const; + Expr constant(const Shape& shape, const std::vector& data) const { return graph()->constant(shape, inits::fromVector(data)); } + Expr constant(const Shape& shape, const std::vector& data) const { return graph()->constant(shape, inits::fromVector(data)); } + template Expr constant(const std::vector& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector + Expr indices(const std::vector& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type + std::vector getFactorMasks(size_t factorGroup, const std::vector& indices) const; +private: + // members + // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr + std::vector> logits_; // [group id][B..., num factors in group] + Ptr factoredVocab_; +}; + +// Unary function that returns a Logits object +// Also implements IUnaryLayer, since Logits can be cast to Expr. +// This interface is implemented by all layers that are of the form of a unary function +// that returns multiple logits, to support factors. 
+struct IUnaryLogitLayer : public IUnaryLayer { + virtual Logits applyAsLogits(Expr) = 0; + virtual Logits applyAsLogits(const std::vector& es) { + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + return applyAsLogits(es.front()); + } + virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } + virtual Expr apply(const std::vector& es) override { return applyAsLogits(es).getLogits(); } +}; + +} + From ba196637847c50c76d5d0edfcfe39b9cedb0d1d0 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Fri, 5 Mar 2021 22:54:05 -0700 Subject: [PATCH 10/14] clang-format -i --- src/layers/constructors.h | 70 ++++--- src/layers/embedding.cpp | 282 ++++++++++++++----------- src/layers/embedding.h | 108 +++++----- src/layers/generic.cpp | 11 +- src/layers/generic.h | 98 +++++---- src/layers/logits.cpp | 424 +++++++++++++++++++++----------------- src/layers/logits.h | 110 ++++++---- src/layers/loss.cpp | 32 +-- src/layers/loss.h | 181 ++++++++-------- src/layers/output.cpp | 336 +++++++++++++++++------------- src/layers/output.h | 37 ++-- src/models/costs.cpp | 14 +- src/models/costs.h | 158 +++++++------- src/models/states.h | 70 ++++--- 14 files changed, 1068 insertions(+), 863 deletions(-) diff --git a/src/layers/constructors.h b/src/layers/constructors.h index e25449aa4..9e9de2077 100644 --- a/src/layers/constructors.h +++ b/src/layers/constructors.h @@ -1,8 +1,8 @@ #pragma once +#include "layers/embedding.h" #include "layers/factory.h" #include "layers/generic.h" -#include "layers/embedding.h" #include "layers/output.h" namespace marian { @@ -45,6 +45,7 @@ struct LogitLayerFactory : public Factory { // @TODO: In the long run, I hope we can get rid of the abstract factories altogether. class OutputFactory : public LogitLayerFactory { using LogitLayerFactory::LogitLayerFactory; + protected: std::string tiedTransposedName_; Ptr shortlist_; @@ -55,9 +56,7 @@ class OutputFactory : public LogitLayerFactory { return Accumulator(*this); } - void setShortlist(Ptr shortlist) { - shortlist_ = shortlist; - } + void setShortlist(Ptr shortlist) { shortlist_ = shortlist; } Ptr construct(Ptr graph) override { auto output = New(graph, options_); @@ -89,8 +88,7 @@ class MLP : public IUnaryLogitLayer, public IHasShortList { std::vector> layers_; public: - MLP(Ptr graph, Ptr options) - : graph_(graph), options_(options) {} + MLP(Ptr graph, Ptr options) : graph_(graph), options_(options) {} Expr apply(const std::vector& av) override { Expr output; @@ -106,46 +104,53 @@ class MLP : public IUnaryLogitLayer, public IHasShortList { } Logits applyAsLogits(const std::vector& av) override { - // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type + // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different + // return type auto lastLayer = std::dynamic_pointer_cast(layers_.back()); - ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); - if (layers_.size() == 1) { - if (av.size() == 1) + ABORT_IF( + !lastLayer, + "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); + if(layers_.size() == 1) { + if(av.size() == 1) return lastLayer->applyAsLogits(av[0]); else return lastLayer->applyAsLogits(av); - } - else { + } else { Expr output; - if (av.size() == 1) + if(av.size() == 1) output = layers_[0]->apply(av[0]); else output = layers_[0]->apply(av); - for (size_t i = 1; i < layers_.size() - 1; ++i) + for(size_t i = 1; i < 
layers_.size() - 1; ++i) output = layers_[i]->apply(output); return lastLayer->applyAsLogits(output); } } - Expr apply(Expr e) override { return apply(std::vector{ e }); } - Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector{ e }); } + Expr apply(Expr e) override { return apply(std::vector{e}); } + Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector{e}); } void push_back(Ptr layer) { layers_.push_back(layer); } void push_back(Ptr layer) { layers_.push_back(layer); } void setShortlist(Ptr shortlist) override final { auto p = tryAsHasShortlist(); - ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists"); + ABORT_IF( + !p, + "setShortlist() called on an MLP with an output layer that does not support short lists"); p->setShortlist(shortlist); } void clear() override final { auto p = tryAsHasShortlist(); - if (p) + if(p) p->clear(); } + private: - Ptr tryAsHasShortlist() const { return std::dynamic_pointer_cast(layers_.back()); } + Ptr tryAsHasShortlist() const { + return std::dynamic_pointer_cast(layers_.back()); + } }; /** @@ -154,6 +159,7 @@ class MLP : public IUnaryLogitLayer, public IHasShortList { */ class MLPFactory : public Factory { using Factory::Factory; + private: std::vector> layers_; @@ -177,23 +183,27 @@ class MLPFactory : public Factory { // which will go away if we get rid of the abstract factories, and instead just construct // all layers immediately, which is my long-term goal for Marian. private: - template + template class AsLayerFactory : public LayerFactory { - WrappedFactory us; + WrappedFactory us; + public: - AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} - Ptr construct(Ptr graph) override final { - auto p = std::static_pointer_cast(us.construct(graph)); - ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); - return p; - } + AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} + Ptr construct(Ptr graph) override final { + auto p = std::static_pointer_cast(us.construct(graph)); + ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); + return p; + } }; - template - static inline AsLayerFactory asLayerFactory(const WrappedFactory& wrapped) { return wrapped; } + template + static inline AsLayerFactory asLayerFactory(const WrappedFactory& wrapped) { + return wrapped; + } + public: Accumulator push_back(const Accumulator& lf) { push_back(AsLayerFactory(lf)); - //layers_.push_back(New>(asLayerFactory((OutputFactory&)lf))); + // layers_.push_back(New>(asLayerFactory((OutputFactory&)lf))); return Accumulator(*this); } }; diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp index 488fbb8be..5a448f611 100644 --- a/src/layers/embedding.cpp +++ b/src/layers/embedding.cpp @@ -3,173 +3,205 @@ namespace marian { -Embedding::Embedding(Ptr graph, Ptr options) -: LayerBase(graph, options), inference_(opt("inference")) { -std::string name = opt("prefix"); -int dimVoc = opt("dimVocab"); -int dimEmb = opt("dimEmb"); +Embedding::Embedding(Ptr graph, Ptr options) + : LayerBase(graph, options), inference_(opt("inference")) { + std::string name = opt("prefix"); + int dimVoc = opt("dimVocab"); + int dimEmb = opt("dimEmb"); -bool fixed = opt("fixed", false); + bool fixed = opt("fixed", false); -factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); -if (factoredVocab_) { + factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); + if(factoredVocab_) { dimVoc = 
(int)factoredVocab_->factorVocabSize(); LOG_ONCE(info, "[embedding] Factored embeddings enabled"); -} + } -// Embedding layer initialization should depend only on embedding size, hence fanIn=false -auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length + // Embedding layer initialization should depend only on embedding size, hence fanIn=false + auto initFunc = inits::glorotUniform( + /*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length -if (options_->has("embFile")) { + if(options_->has("embFile")) { std::string file = opt("embFile"); - if (!file.empty()) { - bool norm = opt("normalization", false); - initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); + if(!file.empty()) { + bool norm = opt("normalization", false); + initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); } -} + } -E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); + E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); } // helper to embed a sequence of words (given as indices) via factored embeddings Expr Embedding::multiRows(const Words& data, float dropProb) const { -auto graph = E_->graph(); -auto factoredData = factoredVocab_->csr_rows(data); -// multi-hot factor vectors are represented as a sparse CSR matrix -// [row index = word position index] -> set of factor indices for word at this position -ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??"); -// the CSR matrix is passed in pieces -auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights)); -auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32); -auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32); -// apply dropout -// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors. -if(!inference_) + auto graph = E_->graph(); + auto factoredData = factoredVocab_->csr_rows(data); + // multi-hot factor vectors are represented as a sparse CSR matrix + // [row index = word position index] -> set of factor indices for word at this position + ABORT_IF(factoredData.shape + != Shape({(int)factoredData.offsets.size() - 1 /*=rows of CSR*/, E_->shape()[0]}), + "shape mismatch??"); + // the CSR matrix is passed in pieces + auto weights = graph->constant({(int)factoredData.weights.size()}, + inits::fromVector(factoredData.weights)); + auto indices = graph->constant( + {(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32); + auto offsets = graph->constant( + {(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32); + // apply dropout + // We apply it to the weights, i.e. factors get dropped out separately, but always as entire + // vectors. 
+ if(!inference_) weights = dropout(weights, dropProb); -// perform the product -return csr_dot(factoredData.shape, weights, indices, offsets, E_); + // perform the product + return csr_dot(factoredData.shape, weights, indices, offsets, E_); } -std::tuple Embedding::apply(Ptr subBatch) const /*override final*/ { -auto graph = E_->graph(); -int dimBatch = (int)subBatch->batchSize(); -int dimEmb = E_->shape()[-1]; -int dimWidth = (int)subBatch->batchWidth(); - -// factored embeddings: -// - regular: -// - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] -// - factored: -// - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) -// - each row of M contains the set of factors for one word => we want a CSR matrix -// - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] -// - first compute x @ M on the CPU -// - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): -// - shape (U, specifically) not actually needed here -// - foreach input x[i] -// - locate row M[i,*] -// - copy through its index values (std::vector) -// - create a matching ones vector (we can keep growing) -// - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) -// - CSR matrix product with E -// - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) -// - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()). -// - weighting: -// - core factors' gradients are sums over all words that use the factors; -// - core factors' embeddings move very fast -// - words will need to make up for the move; rare words cannot -// - so, we multiply each factor with 1/refCount -// - core factors get weighed down a lot -// - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before -// - but forward pass weighs them down, so that all factors are in a similar numeric range -// - if it is required to be in a different range, the embeddings can still learn that, but more slowly - -auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); +std::tuple Embedding::apply(Ptr subBatch) const +/*override final*/ { + auto graph = E_->graph(); + int dimBatch = (int)subBatch->batchSize(); + int dimEmb = E_->shape()[-1]; + int dimWidth = (int)subBatch->batchWidth(); + + // factored embeddings: + // - regular: + // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] + // - factored: + // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) + // - each row of M contains the set of factors for one word => we want a CSR matrix + // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] + // - first compute x @ M on the CPU + // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): + // - shape (U, specifically) not actually needed here + // - foreach input x[i] + // - locate row M[i,*] + // - copy through its index values (std::vector) + // - create a matching ones vector (we can keep growing) + // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) + // - CSR matrix product with E + // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) + // - double-check if all dimensions are specified. Probably not for transpose (which would + // be like csc_dot()). 
+ // - weighting: + // - core factors' gradients are sums over all words that use the factors; + // - core factors' embeddings move very fast + // - words will need to make up for the move; rare words cannot + // - so, we multiply each factor with 1/refCount + // - core factors get weighed down a lot + // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as + // before + // - but forward pass weighs them down, so that all factors are in a similar numeric range + // - if it is required to be in a different range, the embeddings can still learn that, but + // more slowly + + auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); #if 1 -auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->mask())); -#else // @TODO: this is dead code now, get rid of it -// experimental: hide inline-fix source tokens from cross attention -auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); + auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask())); +#else // @TODO: this is dead code now, get rid of it + // experimental: hide inline-fix source tokens from cross attention + auto batchMask + = graph->constant({dimWidth, dimBatch, 1}, + inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); #endif -// give the graph inputs readable names for debugging and ONNX -batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); + // give the graph inputs readable names for debugging and ONNX + batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); -return std::make_tuple(batchEmbeddings, batchMask); + return std::make_tuple(batchEmbeddings, batchMask); } Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ { -if (factoredVocab_) { - Expr selectedEmbs = multiRows(words, options_->get("dropout", 0.0f)); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] - //selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout + if(factoredVocab_) { + Expr selectedEmbs = multiRows(words, options_->get("dropout", 0.0f)); // [(B*W) x E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + // selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { + // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout return selectedEmbs; -} -else + } else return applyIndices(toWordIndexVector(words), shape); } -Expr Embedding::applyIndices(const std::vector& embIdx, const Shape& shape) const /*override final*/ { -ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); -auto embIdxExpr = E_->graph()->indices(embIdx); -embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? -auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] -selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] -// @BUGBUG: We should not broadcast along dimBatch=[-2]. 
Then we can also dropout before reshape() (test that separately) -if(!inference_) - selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); -return selectedEmbs; +Expr Embedding::applyIndices(const std::vector& embIdx, const Shape& shape) const +/*override final*/ { + ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); + auto embIdxExpr = E_->graph()->indices(embIdx); + embIdxExpr->set_name("data_" + + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? + auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() + // (test that separately) + if(!inference_) + selectedEmbs = dropout( + selectedEmbs, options_->get("dropout", 0.0f), {selectedEmbs->shape()[-3], 1, 1}); + return selectedEmbs; } // standard encoder word embeddings /*private*/ Ptr EncoderDecoderLayerBase::createEmbeddingLayer() const { -auto options = New( - "dimVocab", opt>("dim-vocabs")[batchIndex_], - "dimEmb", opt("dim-emb"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "prefix", (opt("tied-embeddings-src") || opt("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb", - "fixed", embeddingFix_, - "vocab", opt>("vocabs")[batchIndex_]); // for factored embeddings -if(options_->hasAndNotEmpty("embedding-vectors")) { + auto options = New( + "dimVocab", + opt>("dim-vocabs")[batchIndex_], + "dimEmb", + opt("dim-emb"), + "dropout", + dropoutEmbeddings_, + "inference", + inference_, + "prefix", + (opt("tied-embeddings-src") || opt("tied-embeddings-all")) ? "Wemb" + : prefix_ + "_Wemb", + "fixed", + embeddingFix_, + "vocab", + opt>("vocabs")[batchIndex_]); // for factored embeddings + if(options_->hasAndNotEmpty("embedding-vectors")) { auto embFiles = opt>("embedding-vectors"); options->set( - "embFile", embFiles[batchIndex_], - "normalization", opt("embedding-normalization")); -} -return New(graph_, options); + "embFile", embFiles[batchIndex_], "normalization", opt("embedding-normalization")); + } + return New(graph_, options); } // ULR word embeddings /*private*/ Ptr EncoderDecoderLayerBase::createULREmbeddingLayer() const { -return New(graph_, New( - "dimSrcVoc", opt>("dim-vocabs")[0], // ULR multi-lingual src - "dimTgtVoc", opt>("dim-vocabs")[1], // ULR monon tgt - "dimUlrEmb", opt("ulr-dim-emb"), - "dimEmb", opt("dim-emb"), - "ulr-dropout", opt("ulr-dropout"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "ulrTrainTransform", opt("ulr-trainable-transformation"), - "ulrQueryFile", opt("ulr-query-vectors"), - "ulrKeysFile", opt("ulr-keys-vectors"))); + return New( + graph_, + New("dimSrcVoc", + opt>("dim-vocabs")[0], // ULR multi-lingual src + "dimTgtVoc", + opt>("dim-vocabs")[1], // ULR monon tgt + "dimUlrEmb", + opt("ulr-dim-emb"), + "dimEmb", + opt("dim-emb"), + "ulr-dropout", + opt("ulr-dropout"), + "dropout", + dropoutEmbeddings_, + "inference", + inference_, + "ulrTrainTransform", + opt("ulr-trainable-transformation"), + "ulrQueryFile", + opt("ulr-query-vectors"), + "ulrKeysFile", + opt("ulr-keys-vectors"))); } // get embedding layer for this encoder or decoder // This is lazy mostly because the constructors of the consuming objects are not // guaranteed presently to have access to their graph. 
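As a rough, generic sketch of the lazy-creation idiom that comment describes (hypothetical getOrCreate helper and std::string stand-in; the real getter below additionally switches between ULR and standard embedding layers):

#include <functional>
#include <memory>
#include <string>
#include <vector>

// Lazily grow a per-index cache and build each entry only on first request.
template <class T>
std::shared_ptr<T> getOrCreate(std::vector<std::shared_ptr<T>>& cache,
                               size_t index,
                               const std::function<std::shared_ptr<T>()>& create) {
  if(cache.size() <= index)
    cache.resize(index + 1);
  if(!cache[index])
    cache[index] = create();  // constructed only when first needed
  return cache[index];
}

int main() {
  std::vector<std::shared_ptr<std::string>> layers;  // stands in for embeddingLayers_
  auto l = getOrCreate<std::string>(layers, 1, [] { return std::make_shared<std::string>("Wemb"); });
  return l->empty();  // 0: the entry for stream index 1 now exists
}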
Ptr EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const { -if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy - if (embeddingLayers_.size() <= batchIndex_) - embeddingLayers_.resize(batchIndex_ + 1); - if (ulr) - embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR + if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy + if(embeddingLayers_.size() <= batchIndex_) + embeddingLayers_.resize(batchIndex_ + 1); + if(ulr) + embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR else - embeddingLayers_[batchIndex_] = createEmbeddingLayer(); -} -return embeddingLayers_[batchIndex_]; -} - + embeddingLayers_[batchIndex_] = createEmbeddingLayer(); + } + return embeddingLayers_[batchIndex_]; } +} // namespace marian diff --git a/src/layers/embedding.h b/src/layers/embedding.h index b7898c76e..6edb31409 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -1,6 +1,6 @@ #pragma once -#include "marian.h" #include "generic.h" +#include "marian.h" namespace marian { @@ -19,7 +19,8 @@ class Embedding : public LayerBase, public IEmbeddingLayer { public: Embedding(Ptr graph, Ptr options); - std::tuple apply(Ptr subBatch) const override final; + std::tuple apply( + Ptr subBatch) const override final; Expr apply(const Words& words, const Shape& shape) const override final; @@ -27,17 +28,18 @@ class Embedding : public LayerBase, public IEmbeddingLayer { }; class ULREmbedding : public LayerBase, public IEmbeddingLayer { - std::vector ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members + std::vector + ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members bool inference_{false}; public: - ULREmbedding(Ptr graph, Ptr options) - : LayerBase(graph, options), inference_(opt("inference")) { - std::string name = "url_embed"; //opt("prefix"); + ULREmbedding(Ptr graph, Ptr options) + : LayerBase(graph, options), inference_(opt("inference")) { + std::string name = "url_embed"; // opt("prefix"); int dimKeys = opt("dimTgtVoc"); int dimQueries = opt("dimSrcVoc"); int dimEmb = opt("dimEmb"); - int dimUlrEmb = opt("dimUlrEmb"); // ULR mono embed size + int dimUlrEmb = opt("dimUlrEmb"); // ULR mono embed size bool fixed = opt("fixed", false); // Embedding layer initialization should depend only on embedding size, hence fanIn=false @@ -46,58 +48,61 @@ class ULREmbedding : public LayerBase, public IEmbeddingLayer { std::string queryFile = opt("ulrQueryFile"); std::string keyFile = opt("ulrKeysFile"); bool trainTrans = opt("ulrTrainTransform", false); - if (!queryFile.empty() && !keyFile.empty()) { + if(!queryFile.empty() && !keyFile.empty()) { initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); name = "ulr_query"; fixed = true; - auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed); + auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(query_embed); // keys embeds initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); name = "ulr_keys"; fixed = true; - auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed); + auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(key_embed); // actual trainable embedding initFunc = inits::glorotUniform(); name = "ulr_embed"; fixed = false; - auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, 
initFunc, fixed); // note the reverse dim + auto ulr_embed + = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim ulrEmbeddings_.push_back(ulr_embed); // init trainable src embedding name = "ulr_src_embed"; - auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed); + auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulr_src_embed); // ulr transformation matrix - //initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only - if (trainTrans) { + // initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall + // we make this to the fixed case only + if(trainTrans) { initFunc = inits::glorotUniform(); fixed = false; - } - else - { - initFunc = inits::eye(); // identity matrix + } else { + initFunc = inits::eye(); // identity matrix fixed = true; } name = "ulr_transform"; - auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed); + auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulrTransform); - initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only + initFunc = inits::fromValue( + 1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no + // universal embeddings - should be zero for top freq only fixed = true; name = "ulr_shared"; - auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed); + auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed); ulrEmbeddings_.push_back(share_embed); } } - std::tuple apply(Ptr subBatch) const override final { - auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb - auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb - auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb - auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb - auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb - auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 + std::tuple apply( + Ptr subBatch) const override final { + auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb + auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb + auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb + auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb + auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb + auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 int dimBatch = (int)subBatch->batchSize(); int dimEmb = uniEmbed->shape()[-1]; int dimWords = (int)subBatch->batchWidth(); @@ -106,34 +111,42 @@ class ULREmbedding : public LayerBase, public IEmbeddingLayer { // dim A = uni_embed_size*uni_embed_size // dim Q: uni_embed_size * total_merged_vocab_size // dim D = univ_tok_vocab * total_merged_vocab_size - // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD) - // here we need to handle the mini-batch - // extract raws corresponding to Xs in this minibatch from Q + // note all above can be precombuted and serialized if A is not trainiable and during decoding + // (TBD) here we need to handle the mini-batch extract raws corresponding to Xs in this + // minibatch from Q auto embIdx = toWordIndexVector(subBatch->data()); auto queryEmbeddings = rows(queryEmbed, embIdx); - auto 
srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings - auto alpha = rows(ulrSharable, embIdx); // extract sharable flags - auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb - auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]); - qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes - auto z = dot(qt, keyEmbed, false, true); // query-key similarity + auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings + auto alpha = rows(ulrSharable, embIdx); // extract sharable flags + auto qt = dot(queryEmbeddings, + ulrTransform, + false, + false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb + auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]); + qt = qt / sqrtDim; // normalize according to embedding size to avoid the dot product growing large in + // magnitude with larger embedding sizes + auto z = dot(qt, keyEmbed, false, true); // query-key similarity float dropProb = this->options_->get("ulr-dropout", 0.0f); // default no dropout if(!inference_) z = dropout(z, dropProb); - float tau = this->options_->get("ulr-softmax-temperature", 1.0f); // default no temperature + float tau + = this->options_->get("ulr-softmax-temperature", 1.0f); // default no temperature // temperature in softmax is to control randomness of predictions // high temperature Softmax outputs are more close to each other // low temperatures the softmax become more similar to "hardmax" - auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? + auto weights + = softmax(z / tau); // assume default is dim=-1, what about temperature? - scalar ??
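The effect of tau described in the comments above is easy to verify numerically. A minimal standalone sketch with toy query-key similarity scores (plain std::vector, nothing Marian-specific):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Plain softmax of z / tau, matching the softmax(z / tau) call above.
std::vector<float> softmaxWithTemperature(std::vector<float> z, float tau) {
  float maxZ = *std::max_element(z.begin(), z.end());
  float sum = 0.f;
  for(float& v : z) { v = std::exp((v - maxZ) / tau); sum += v; }
  for(float& v : z) v /= sum;
  return z;
}

int main() {
  std::vector<float> z = {2.f, 1.f, 0.f};  // toy query-key similarity scores
  for(float tau : {0.5f, 1.0f, 4.0f}) {
    auto w = softmaxWithTemperature(z, tau);
    // small tau -> close to "hardmax"; large tau -> weights approach uniform
    std::printf("tau=%.1f: %.2f %.2f %.2f\n", tau, w[0], w[1], w[2]);
  }
  return 0;
}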
auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE - auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast - auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb }); + auto chosenEmbeddings_mix + = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast + auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb}); auto graph = ulrEmbeddings_.front()->graph(); - auto batchMask = graph->constant({ dimWords, dimBatch, 1 }, - inits::fromVector(subBatch->mask())); + auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask())); if(!inference_) - batchEmbeddings = dropout(batchEmbeddings, options_->get("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1}); + batchEmbeddings = dropout(batchEmbeddings, + options_->get("dropout-embeddings", 0.0f), + {batchEmbeddings->shape()[-3], 1, 1}); return std::make_tuple(batchEmbeddings, batchMask); } @@ -142,9 +155,10 @@ class ULREmbedding : public LayerBase, public IEmbeddingLayer { } Expr applyIndices(const std::vector& embIdx, const Shape& shape) const override final { - embIdx; shape; - ABORT("not implemented"); // @TODO: implement me + embIdx; + shape; + ABORT("not implemented"); // @TODO: implement me } }; -} +} // namespace marian diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index 02e820e57..8e2ecfd79 100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -1,13 +1,10 @@ #include "marian.h" -#include "layers/generic.h" +#include "data/factored_vocab.h" #include "layers/constructors.h" +#include "layers/generic.h" #include "layers/loss.h" -#include "data/factored_vocab.h" -#include "models/states.h" // for EncoderState #include "layers/lsh.h" +#include "models/states.h" // for EncoderState -namespace marian { - - -} // namespace marian +namespace marian {} // namespace marian diff --git a/src/layers/generic.h b/src/layers/generic.h index eddd597e8..89f5c1e9d 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -5,12 +5,14 @@ #include "data/shortlist.h" #include "layers/factory.h" -namespace marian { namespace mlp { - /** - * @brief Activation functions - */ - enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish }; -}} +namespace marian { +namespace mlp { +/** + * @brief Activation functions + */ +enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish }; +} // namespace mlp +} // namespace marian namespace marian { @@ -23,8 +25,7 @@ class LayerBase { Ptr options_; public: - LayerBase(Ptr graph, Ptr options) - : graph_(graph), options_(options) {} + LayerBase(Ptr graph, Ptr options) : graph_(graph), options_(options) {} template T opt(const std::string key) const { @@ -42,7 +43,7 @@ struct IUnaryLayer { virtual ~IUnaryLayer() {} virtual Expr apply(Expr) = 0; virtual Expr apply(const std::vector& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub return apply(es.front()); } }; @@ -54,7 +55,8 @@ struct IHasShortList { // Embedding from corpus sub-batch to (emb, mask) struct IEmbeddingLayer { - virtual std::tuple apply(Ptr subBatch) const = 0; + virtual std::tuple apply( + Ptr subBatch) const = 0; virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0; @@ -63,28 +65,29 @@ struct IEmbeddingLayer { virtual ~IEmbeddingLayer() {} }; -// base class for Encoder and Decoder classes, which have embeddings 
and a batch index (=stream index) +// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream +// index) class EncoderDecoderLayerBase : public LayerBase { protected: const std::string prefix_; const bool embeddingFix_; - const float dropoutEmbeddings_; // this drops out full embedding vectors + const float dropoutEmbeddings_; // this drops out full embedding vectors const bool inference_; const size_t batchIndex_; - mutable std::vector> embeddingLayers_; // (lazily created) + mutable std::vector> embeddingLayers_; // (lazily created) - EncoderDecoderLayerBase(Ptr graph, - Ptr options, - const std::string& prefix, + EncoderDecoderLayerBase(Ptr graph, + Ptr options, + const std::string& prefix, size_t batchIndex, float dropoutEmbeddings, - bool embeddingFix) : - LayerBase(graph, options), - prefix_(options->get("prefix", prefix)), - embeddingFix_(embeddingFix), - dropoutEmbeddings_(dropoutEmbeddings), - inference_(options->get("inference", false)), - batchIndex_(options->get("index", batchIndex)) {} + bool embeddingFix) + : LayerBase(graph, options), + prefix_(options->get("prefix", prefix)), + embeddingFix_(embeddingFix), + dropoutEmbeddings_(dropoutEmbeddings), + inference_(options->get("inference", false)), + batchIndex_(options->get("index", batchIndex)) {} virtual ~EncoderDecoderLayerBase() {} @@ -101,8 +104,7 @@ namespace mlp { class Dense : public LayerBase, public IUnaryLayer { public: - Dense(Ptr graph, Ptr options) - : LayerBase(graph, options) {} + Dense(Ptr graph, Ptr options) : LayerBase(graph, options) {} Expr apply(const std::vector& inputs) override { ABORT_IF(inputs.empty(), "No inputs"); @@ -124,21 +126,17 @@ class Dense : public LayerBase, public IUnaryLayer { if(inputs.size() > 1) num = std::to_string(i); - Expr W = g->param( - name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform()); + Expr W = g->param(name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform()); Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros()); if(useLayerNorm) { if(useNematusNorm) { - auto ln_s = g->param( - name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f)); + auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f)); auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros()); - outputs.push_back( - layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); + outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); } else { - auto gamma = g->param( - name + "_gamma" + num, {1, dim}, inits::fromValue(1.0)); + auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::fromValue(1.0)); outputs.push_back(layerNorm(dot(in, W), gamma, b)); } @@ -165,39 +163,35 @@ class Dense : public LayerBase, public IUnaryLayer { Expr apply(Expr input) override { return apply(std::vector({input})); } }; -} // namespace mlp - +} // namespace mlp // --- a few layers with built-in parameters created on the fly, without proper object // @TODO: change to a proper layer object // like affine() but with built-in parameters, activation, and dropout -static inline -Expr denseInline(Expr x, - std::string prefix, - std::string suffix, - int outDim, - Ptr initFn = inits::glorotUniform(), - const std::function& actFn = nullptr, - float dropProb = 0.0f) -{ +static inline Expr denseInline(Expr x, + std::string prefix, + std::string suffix, + int outDim, + Ptr initFn = inits::glorotUniform(), + const std::function& actFn = nullptr, + float dropProb = 0.0f) { auto graph = x->graph(); - auto W = graph->param(prefix + 
"_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform()); - auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros()); + auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform()); + auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros()); x = affine(x, W, b); - if (actFn) + if(actFn) x = actFn(x); - x = dropout(x, dropProb); // @TODO: check for infernce? + x = dropout(x, dropProb); // @TODO: check for infernce? return x; } -static inline -Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) { +static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) { int dimModel = x->shape()[-1]; - auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones()); - auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros()); + auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, {1, dimModel}, inits::ones()); + auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, {1, dimModel}, inits::zeros()); return marian::layerNorm(x, scale, bias, 1e-6f); } diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp index cd2203e4d..772c57150 100644 --- a/src/layers/logits.cpp +++ b/src/layers/logits.cpp @@ -1,212 +1,250 @@ #include "logits.h" -#include "loss.h" #include "data/factored_vocab.h" -#include "rnn/types.h" // for State::select() +#include "loss.h" +#include "rnn/types.h" // for State::select() namespace marian { - Logits::Logits(Expr logits) : Logits(New(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count) - - Ptr Logits::graph() const { - ABORT_IF(logits_.empty(), "Empty logits object??"); - return logits_.front()->loss()->graph(); +Logits::Logits(Expr logits) + : Logits(New(logits, nullptr)) { +} // single-output constructor from Expr only (RationalLoss has no count) + +Ptr Logits::graph() const { + ABORT_IF(logits_.empty(), "Empty logits object??"); + return logits_.front()->loss()->graph(); +} + +// This function assumes that the object holds one or more factor logits. +// It applies the supplied loss function to each, and then returns the aggregate loss over all +// factors. +Expr Logits::applyLossFunction( + const Words& labels, + const std::function& lossFn) const { + LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto firstLogits = logits_.front()->loss(); + ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), + "Labels not matching logits shape ({} != {}, {})??", + labels.size() * firstLogits->shape()[-1], + firstLogits->shape().elements(), + firstLogits->shape()); + + // base case (no factors) + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return lossFn(firstLogits, indices(toWordIndexVector(labels))); } - // This function assumes that the object holds one or more factor logits. - // It applies the supplied loss function to each, and then returns the aggregate loss over all factors. 
- Expr Logits::applyLossFunction(const Words& labels, const std::function& lossFn) const { - LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto firstLogits = logits_.front()->loss(); - ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), - "Labels not matching logits shape ({} != {}, {})??", - labels.size() * firstLogits->shape()[-1], - firstLogits->shape().elements(), - firstLogits->shape()); - - // base case (no factors) - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return lossFn(firstLogits, indices(toWordIndexVector(labels))); - } - - auto numGroups = factoredVocab_->getNumGroups(); - - // split labels into individual factor labels - auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened] - - //Expr indices = this->indices(toWordIndexVector(labels)); - // accumulate all CEs for all words that have the factor - // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. - Expr loss; - for (size_t g = 0; g < numGroups; g++) { - if (!logits_[g]) - continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] - const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) - auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply - auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor - auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) - // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. - auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] - if(loss) - factorLoss = cast(factorLoss, loss->value_type()); - factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor - loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1] - } - return loss; + auto numGroups = factoredVocab_->getNumGroups(); + + // split labels into individual factor labels + auto allMaskedFactoredLabels + = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened] + + // Expr indices = this->indices(toWordIndexVector(labels)); + // accumulate all CEs for all words that have the factor + // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. + Expr loss; + for(size_t g = 0; g < numGroups; g++) { + if(!logits_[g]) + continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] + const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) + auto factorIndices = indices( + maskedFactoredLabels + .indices); // [B... flattened] factor-label indices, or 0 if factor does not apply + auto factorMask + = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with + // 0 for labels that don't have this factor + auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) + // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask + // it out next. 
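As a toy numeric illustration of that masking step (per-factor losses multiplied by a 0/1 factor mask and summed across groups; plain vectors instead of Exprs, values made up for illustration):

#include <cstdio>
#include <vector>

int main() {
  // Two factor groups over a flattened batch of 3 labels.
  // loss[g][i]: per-label CE for factor group g; mask[g][i] = 0 where the
  // label's lemma does not carry that factor (its dummy index 0 is masked out).
  std::vector<std::vector<float>> loss = {{1.2f, 0.7f, 2.0f},   // group 0 (lemma): applies everywhere
                                          {0.4f, 9.9f, 0.3f}};  // group 1: label 1 has no such factor
  std::vector<std::vector<float>> mask = {{1.f, 1.f, 1.f},
                                          {1.f, 0.f, 1.f}};

  std::vector<float> total(3, 0.f);
  for(size_t g = 0; g < loss.size(); ++g)
    for(size_t i = 0; i < total.size(); ++i)
      total[i] += loss[g][i] * mask[g][i];  // a masked factor contributes nothing

  for(size_t i = 0; i < total.size(); ++i)
    std::printf("label %zu: %.2f\n", i, total[i]);  // 1.60, 0.70, 2.30
  return 0;
}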
+ auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + if(loss) + factorLoss = cast(factorLoss, loss->value_type()); + factorLoss + = factorLoss + * cast( + reshape(factorMask, factorLoss->shape()), + factorLoss->value_type()); // mask out factor for words that do not have that factor + loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1] } - - // This function assumes this object holds a single factor that represents a rational loss (with count). - //Ptr Logits::getRationalLoss() const { - // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs"); - // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count"); - // return logits_.front(); - //} - - // get logits for one factor group - // For groupIndex == 0, the function also requires the shortlist if there is one. - Expr Logits::getFactoredLogits(size_t groupIndex, Ptr shortlist /*= nullptr*/, const std::vector& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] - - // normalize for decoding: - // - all secondary factors: subtract their max - // - lemma: add all maxes of applicable factors - if (groupIndex > 0) { - sel = sel - max(sel, -1); - } - else { - auto numGroups = getNumFactorGroups(); - for (size_t g = 1; g < numGroups; g++) { - auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32 - auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector())); - sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0 - } + return loss; +} + +// This function assumes this object holds a single factor that represents a rational loss (with +// count). +// Ptr Logits::getRationalLoss() const { +// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on +// multi-factor outputs"); ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational +// loss without count"); return logits_.front(); +//} + +// get logits for one factor group +// For groupIndex == 0, the function also requires the shortlist if there is one. +Expr Logits::getFactoredLogits(size_t groupIndex, + Ptr shortlist /*= nullptr*/, + const std::vector& hypIndices /*= {}*/, + size_t beamSize /*= 0*/) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] + + // normalize for decoding: + // - all secondary factors: subtract their max + // - lemma: add all maxes of applicable factors + if(groupIndex > 0) { + sel = sel - max(sel, -1); + } else { + auto numGroups = getNumFactorGroups(); + for(size_t g = 1; g < numGroups; g++) { + auto factorMaxima = max(logits_[g]->loss(), + -1); // we cast since loss is likely ce-loss which has type float32 + auto factorMasks = constant( + getFactorMasks(g, shortlist ? 
shortlist->indices() : std::vector())); + sel = sel + + cast(factorMaxima, sel->value_type()) + * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor + // get multiplied with 0 } - - // if selIdx are given, then we must reshuffle accordingly - if (!hypIndices.empty()) // use the same function that shuffles decoder state - sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); - - return sel; } - // used for breakDown() only - // Index is flattened - Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - return logits_[groupIndex]->loss()->val(); + // if selIdx are given, then we must reshuffle accordingly + if(!hypIndices.empty()) // use the same function that shuffles decoder state + sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); + + return sel; +} + +// used for breakDown() only +// Index is flattened +Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + return logits_[groupIndex]->loss()->val(); +} + +// This function assumes that the object holds one or more factor logits, which are summed up +// into output-vocab logits according to the factored model (with correct normalization of factors). +// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. +// @TODO: remove altogether +Expr Logits::getLogits() const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return getFactoredLogits(0); } - // This function assumes that the object holds one or more factor logits, which are summed up - // into output-vocab logits according to the factored model (with correct normalization of factors). - // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. 
- // @TODO: remove altogether - Expr Logits::getLogits() const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return getFactoredLogits(0); - } - #ifdef FACTOR_FULL_EXPANSION - // compute normalized factor log probs - std::vector logProbs(logits_.size()); - for (size_t g = 0; g < logits_.size(); g++) - logProbs[g] = logsoftmax(logits_[g]->loss()); - auto y = concatenate(logProbs, /*axis=*/ -1); - - // sum up the unit logits across factors for each target word - auto graph = y->graph(); - auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] - y = dot_csr( - y, // [B x U] - factorMatrix.shape, - graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), - graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), - graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), - /*transB=*/ true); // -> [B x V] - - // mask out gaps - auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] - y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask)); - - return y; + // compute normalized factor log probs + std::vector logProbs(logits_.size()); + for(size_t g = 0; g < logits_.size(); g++) + logProbs[g] = logsoftmax(logits_[g]->loss()); + auto y = concatenate(logProbs, /*axis=*/-1); + + // sum up the unit logits across factors for each target word + auto graph = y->graph(); + auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] + y = dot_csr( + y, // [B x U] + factorMatrix.shape, + graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), + graph->constant({(int)factorMatrix.indices.size()}, + inits::fromVector(factorMatrix.indices), + Type::uint32), + graph->constant({(int)factorMatrix.offsets.size()}, + inits::fromVector(factorMatrix.offsets), + Type::uint32), + /*transB=*/true); // -> [B x V] + + // mask out gaps + auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] + y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask)); + + return y; #else - ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible + ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible #endif +} + +void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { + bool isValid = FactoredVocab::isFactorValid(factorIndex); + indices.push_back(isValid ? (WordIndex)factorIndex : 0); + masks.push_back((float)isValid); +} + +std::vector Logits::factorizeWords(const Words& words) + const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return {MaskedFactorIndices(words)}; } - - void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { - bool isValid = FactoredVocab::isFactorValid(factorIndex); - indices.push_back(isValid ? 
(WordIndex)factorIndex : 0); - masks.push_back((float)isValid); + auto numGroups = factoredVocab_->getNumGroups(); + std::vector res(numGroups); + for(size_t g = 0; g < numGroups; g++) { + auto& resg = res[g]; + resg.reserve(words.size()); + for(const auto& word : words) + resg.push_back(factoredVocab_->getFactor(word, g)); } - - std::vector Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return {MaskedFactorIndices(words)}; - } - auto numGroups = factoredVocab_->getNumGroups(); - std::vector res(numGroups); - for (size_t g = 0; g < numGroups; g++) { - auto& resg = res[g]; - resg.reserve(words.size()); - for (const auto& word : words) - resg.push_back(factoredVocab_->getFactor(word, g)); - } - return res; - } - - //// use first factor of each word to determine whether it has a specific factor - //std::vector Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0 - // std::vector res; - // res.reserve(words.size()); - // for (const auto& word : words) { - // auto lemma = factoredVocab_->getFactor(word, 0); - // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - // } - // return res; - //} - - // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor - // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas - std::vector Logits::getFactorMasks(size_t factorGroup, const std::vector& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 - size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size(); - std::vector res; - res.reserve(n); - // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab - for (size_t i = 0; i < n; i++) { - auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first); - res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - } - return res; + return res; +} + +//// use first factor of each word to determine whether it has a specific factor +// std::vector Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 +// for words that do have this factor; else 0 +// std::vector res; +// res.reserve(words.size()); +// for (const auto& word : words) { +// auto lemma = factoredVocab_->getFactor(word, 0); +// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); +// } +// return res; +//} + +// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor +// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas +std::vector Logits::getFactorMasks(size_t factorGroup, const std::vector& indices) + const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 + size_t n + = indices.empty() + ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) + : indices.size(); + std::vector res; + res.reserve(n); + // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this + // into FactoredVocab + for(size_t i = 0; i < n; i++) { + auto lemma = indices.empty() ? 
i : (indices[i] - factoredVocab_->getGroupRange(0).first); + res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); } - - Logits Logits::applyUnaryFunction(const std::function& f) const { // clone this but apply f to all loss values - std::vector> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New(f(l->loss()), l->count())); - return Logits(std::move(newLogits), factoredVocab_); - } - - Logits Logits::applyUnaryFunctions(const std::function& f1, const std::function& fother) const { - std::vector> newLogits; - bool first = true; - for (const auto& l : logits_) { - newLogits.emplace_back(New((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others - first = false; - } - return Logits(std::move(newLogits), factoredVocab_); - } - - // @TODO: code dup with above; we can merge it into applyToRationalLoss() - Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_ - std::vector> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New(l->loss(), count)); - return Logits(std::move(newLogits), factoredVocab_); + return res; +} + +Logits Logits::applyUnaryFunction( + const std::function& f) const { // clone this but apply f to all loss values + std::vector> newLogits; + for(const auto& l : logits_) + newLogits.emplace_back(New(f(l->loss()), l->count())); + return Logits(std::move(newLogits), factoredVocab_); +} + +Logits Logits::applyUnaryFunctions(const std::function& f1, + const std::function& fother) const { + std::vector> newLogits; + bool first = true; + for(const auto& l : logits_) { + newLogits.emplace_back(New((first ? f1 : fother)(l->loss()), + l->count())); // f1 for first, fother for all others + first = false; } -} \ No newline at end of file + return Logits(std::move(newLogits), factoredVocab_); +} + +// @TODO: code dup with above; we can merge it into applyToRationalLoss() +Logits Logits::withCounts( + const Expr& count) const { // create new Logits with 'count' implanted into all logits_ + std::vector> newLogits; + for(const auto& l : logits_) + newLogits.emplace_back(New(l->loss(), count)); + return Logits(std::move(newLogits), factoredVocab_); +} +} // namespace marian \ No newline at end of file diff --git a/src/layers/logits.h b/src/layers/logits.h index 4196e0d0a..c61a9e742 100644 --- a/src/layers/logits.h +++ b/src/layers/logits.h @@ -1,8 +1,8 @@ #pragma once -#include "marian.h" #include "data/shortlist.h" #include "generic.h" +#include "marian.h" namespace marian { @@ -16,46 +16,77 @@ class FactoredVocab; class RationalLoss; class Logits { public: - Logits() {} - explicit Logits(Ptr logits) { // single-output constructor - logits_.push_back(logits); - } - explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count) - Logits(std::vector>&& logits, Ptr embeddingFactorMapping) // factored-output constructor + Logits() {} + explicit Logits(Ptr logits) { // single-output constructor + logits_.push_back(logits); + } + explicit Logits( + Expr logits); // single-output constructor from Expr only (RationalLoss has no count) + Logits(std::vector>&& logits, + Ptr embeddingFactorMapping) // factored-output constructor : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} - Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors - Expr getFactoredLogits(size_t groupIndex, Ptr shortlist = nullptr, const std::vector& hypIndices = {}, size_t beamSize = 0) 
const; // get logits for only one factor group, with optional reshuffle - //Ptr getRationalLoss() const; // assume it holds a loss: get that - Expr applyLossFunction(const Words& labels, const std::function& lossFn) const; - Logits applyUnaryFunction(const std::function& f) const; // clone this but apply f to all loss values - Logits applyUnaryFunctions(const std::function& f1, const std::function& fother) const; // clone this but apply f1 to first and fother to to all other values + Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors + Expr getFactoredLogits( + size_t groupIndex, + Ptr shortlist = nullptr, + const std::vector& hypIndices = {}, + size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle + // Ptr getRationalLoss() const; // assume it holds a loss: get that + Expr applyLossFunction( + const Words& labels, + const std::function& lossFn) const; + Logits applyUnaryFunction( + const std::function& f) const; // clone this but apply f to all loss values + Logits applyUnaryFunctions(const std::function& f1, + const std::function& fother) + const; // clone this but apply f1 to first and fother to to all other values - struct MaskedFactorIndices { - std::vector indices; // factor index, or 0 if masked - std::vector masks; - void reserve(size_t n) { indices.reserve(n); masks.reserve(n); } - void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries - MaskedFactorIndices() {} - MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case - }; - std::vector factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices - Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only - size_t getNumFactorGroups() const { return logits_.size(); } - bool empty() const { return logits_.empty(); } - Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_ + struct MaskedFactorIndices { + std::vector indices; // factor index, or 0 if masked + std::vector masks; + void reserve(size_t n) { + indices.reserve(n); + masks.reserve(n); + } + void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 + // for invalid entries + MaskedFactorIndices() {} + MaskedFactorIndices(const Words& words) { + indices = toWordIndexVector(words); + } // we can leave masks uninitialized for this special use case + }; + std::vector factorizeWords( + const Words& words) const; // breaks encoded Word into individual factor indices + Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only + size_t getNumFactorGroups() const { return logits_.size(); } + bool empty() const { return logits_.empty(); } + Logits withCounts( + const Expr& count) const; // create new Logits with 'count' implanted into all logits_ private: - // helper functions - Ptr graph() const; - Expr constant(const Shape& shape, const std::vector& data) const { return graph()->constant(shape, inits::fromVector(data)); } - Expr constant(const Shape& shape, const std::vector& data) const { return graph()->constant(shape, inits::fromVector(data)); } - template Expr constant(const std::vector& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector - Expr indices(const std::vector& data) const { return graph()->indices(data); } // 
actually the same as constant(data) for this data type - std::vector getFactorMasks(size_t factorGroup, const std::vector& indices) const; + // helper functions + Ptr graph() const; + Expr constant(const Shape& shape, const std::vector& data) const { + return graph()->constant(shape, inits::fromVector(data)); + } + Expr constant(const Shape& shape, const std::vector& data) const { + return graph()->constant(shape, inits::fromVector(data)); + } + template + Expr constant(const std::vector& data) const { + return constant(Shape{(int)data.size()}, data); + } // same as constant() but assuming vector + Expr indices(const std::vector& data) const { + return graph()->indices(data); + } // actually the same as constant(data) for this data type + std::vector getFactorMasks(size_t factorGroup, + const std::vector& indices) const; + private: - // members - // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr - std::vector> logits_; // [group id][B..., num factors in group] - Ptr factoredVocab_; + // members + // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just + // by the Expr + std::vector> logits_; // [group id][B..., num factors in group] + Ptr factoredVocab_; }; // Unary function that returns a Logits object @@ -65,12 +96,11 @@ class Logits { struct IUnaryLogitLayer : public IUnaryLayer { virtual Logits applyAsLogits(Expr) = 0; virtual Logits applyAsLogits(const std::vector& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub return applyAsLogits(es.front()); } virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } virtual Expr apply(const std::vector& es) override { return applyAsLogits(es).getLogits(); } }; -} - +} // namespace marian diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp index 67d388326..695276af8 100644 --- a/src/layers/loss.cpp +++ b/src/layers/loss.cpp @@ -13,26 +13,30 @@ Ptr newLoss(Ptr options, bool inference) { bool wordScores = options->get("word-scores", false); return New(wordScores); } else if(unlikelihood) { - ABORT_IF(!options->hasAndNotEmpty("data-weighting") - && options->get("data-weighting-type") != "word", - "Unlikelihood loss training requires error annotation in form of per-target-label scores"); - return New(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on values given for data-weighting - } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum? + ABORT_IF( + !options->hasAndNotEmpty("data-weighting") + && options->get("data-weighting-type") != "word", + "Unlikelihood loss training requires error annotation in form of per-target-label scores"); + return New( + smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on + // values given for data-weighting + } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. + // E.g. what about ce-sum? 
return New(smoothing, factorWeight); } } // see loss.h for detailed explanations of each class Ptr newMultiLoss(Ptr options) { - std::string multiLossType = options->get("multi-loss-type", "sum"); - if(multiLossType == "sum") // sum of sums - return New(); - else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale - return New(); - else if(multiLossType == "mean") // sum of means - return New(); - else - ABORT("Unknown multi-loss-type {}", multiLossType); + std::string multiLossType = options->get("multi-loss-type", "sum"); + if(multiLossType == "sum") // sum of sums + return New(); + else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale + return New(); + else if(multiLossType == "mean") // sum of means + return New(); + else + ABORT("Unknown multi-loss-type {}", multiLossType); } } // namespace marian diff --git a/src/layers/loss.h b/src/layers/loss.h index ba93cdac7..c662f9911 100644 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -1,8 +1,8 @@ #pragma once -#include "graph/expression_operators.h" -#include "layers/logits.h" // for Logits (Frank's factor hack) #include "data/types.h" +#include "graph/expression_operators.h" +#include "layers/logits.h" // for Logits (Frank's factor hack) namespace marian { @@ -22,21 +22,18 @@ namespace marian { */ class RationalLoss { protected: - Expr loss_; // numerator - Expr count_; // denominator + Expr loss_; // numerator + Expr count_; // denominator - RationalLoss() = default; // protected + RationalLoss() = default; // protected public: - RationalLoss(Expr loss, Expr count) - : loss_(loss), count_(count) {} + RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {} RationalLoss(Expr loss, float count) - : loss_(loss), - count_(constant_like(loss, inits::fromValue(count))) {} + : loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {} - RationalLoss(const RationalLoss& other) - : loss_(other.loss_), count_(other.count_) {} + RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {} virtual ~RationalLoss() = default; @@ -50,7 +47,7 @@ class RationalLoss { } template - T loss() const { // this will fail if loss is not a single value + T loss() const { // this will fail if loss is not a single value ABORT_IF(!loss_, "Loss has not been defined"); return loss_->val()->scalar(); } @@ -65,7 +62,7 @@ class RationalLoss { } template - T count() const { // this will fail if loss is not a single value + T count() const { // this will fail if loss is not a single value ABORT_IF(!count_, "Labels have not been defined"); return count_->val()->scalar(); } @@ -85,21 +82,21 @@ class RationalLoss { * RationalLoss object. 
*/ struct StaticLoss { - float loss; // numerator - float count; // denominator + float loss; // numerator + float count; // denominator StaticLoss() : loss(0.f), count(0.f) {} StaticLoss(const RationalLoss& dynamic) - : loss(dynamic.loss()), count(dynamic.count()) {} + : loss(dynamic.loss()), count(dynamic.count()) {} - StaticLoss operator +(const StaticLoss& other) const { + StaticLoss operator+(const StaticLoss& other) const { StaticLoss res(*this); res += other; return res; } - StaticLoss& operator +=(const StaticLoss& other) { + StaticLoss& operator+=(const StaticLoss& other) { loss = loss + other.loss; count = count + other.count; return *this; @@ -139,32 +136,21 @@ class MultiRationalLoss : public RationalLoss { public: MultiRationalLoss() : RationalLoss() {} - MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { - push_back(rl); - } + MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); } virtual void push_back(const RationalLoss& current) { - loss_ = accumulateLoss(current); - count_ = accumulateCount(current); + loss_ = accumulateLoss(current); + count_ = accumulateCount(current); partialLosses_.push_back(current); } - const RationalLoss& operator[](size_t i) { - return partialLosses_[i]; - } + const RationalLoss& operator[](size_t i) { return partialLosses_[i]; } - auto begin() -> decltype(partialLosses_.begin()) const { - return partialLosses_.begin(); - } + auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); } - auto end() -> decltype(partialLosses_.end()) const { - return partialLosses_.end(); - } - - size_t size() const { - return partialLosses_.size(); - } + auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); } + size_t size() const { return partialLosses_.size(); } }; /** @@ -212,17 +198,19 @@ class ScaledMultiRationalLoss : public MultiRationalLoss { virtual Expr accumulateLoss(const RationalLoss& current) override { if(loss_) { const auto& first = partialLosses_.front(); - return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss + return loss_ + + current.loss() * first.count() + / current.count(); // scale up/down to match scale of first loss } else { - return current.loss(); // first reference loss, keeps to scale with this one + return current.loss(); // first reference loss, keeps to scale with this one } } virtual Expr accumulateCount(const RationalLoss& current) override { if(count_) { - return count_; // Keep first label count // or: count_ + first.count() / current.count(); + return count_; // Keep first label count // or: count_ + first.count() / current.count(); } else { - return current.count(); // This is the first loss + return current.count(); // This is the first loss } } @@ -253,9 +241,10 @@ class MeanMultiRationalLoss : public MultiRationalLoss { virtual Expr accumulateCount(const RationalLoss& current) override { if(count_) - return count_; // keep the existing '1' + return count_; // keep the existing '1' else - return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_ + return current.count()->graph()->ones( + {1}, current.loss()->value_type()); // just '1' as labels are factored into loss_ } public: @@ -279,18 +268,21 @@ class LabelwiseLoss { protected: std::vector axes_; - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) = 0; + virtual Expr compute(Logits logits, + const Words& 
labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) + = 0; // label counts are available, reduce together with loss to obtain counts RationalLoss reduce(Expr loss, Expr labels) { ABORT_IF(!loss, "Loss has not been computed"); ABORT_IF(!labels, "Labels have not been computed"); - Expr lossSum = cast(loss, Type::float32); // accumulate in float32 - Expr labelsSum = cast(labels, Type::float32); // accumulate in float32 + Expr lossSum = cast(loss, Type::float32); // accumulate in float32 + Expr labelsSum = cast(labels, Type::float32); // accumulate in float32 for(int i = 0; i < axes_.size(); ++i) { - lossSum = sum(lossSum, axes_[i]); + lossSum = sum(lossSum, axes_[i]); labelsSum = sum(labelsSum, axes_[i]); } @@ -301,7 +293,7 @@ class LabelwiseLoss { RationalLoss reduce(Expr loss) { ABORT_IF(!loss, "Loss has not been computed"); - Expr lossSum = cast(loss, Type::float32); + Expr lossSum = cast(loss, Type::float32); for(int i = 0; i < axes_.size(); ++i) lossSum = sum(lossSum, axes_[i]); @@ -311,17 +303,18 @@ class LabelwiseLoss { } public: - LabelwiseLoss(const std::vector& axes) - : axes_(axes) { } + LabelwiseLoss(const std::vector& axes) : axes_(axes) {} - virtual RationalLoss apply(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) { + virtual RationalLoss apply(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) { Expr loss = compute(logits, labels, mask, labelWeights); if(mask) - return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting + return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting else - return reduce(loss); // we have no mask, assume all items are labels + return reduce(loss); // we have no mask, assume all items are labels } }; @@ -331,28 +324,34 @@ class LabelwiseLoss { class CrossEntropyLoss : public LabelwiseLoss { public: CrossEntropyLoss(float labelSmoothing, float factorWeight) - : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1 + : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) { + } // cross-entropy already reduces over axis -1 CrossEntropyLoss(const std::vector& axes, float labelSmoothing, float factorWeight) - : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1 - labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {} + : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1 + labelSmoothing_(labelSmoothing), + factorWeight_(factorWeight) {} virtual ~CrossEntropyLoss() {} -protected: - float labelSmoothing_; // interpolation factor for label smoothing, see below - float factorWeight_; // give extra weight to factors - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) override { - // logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up +protected: + float labelSmoothing_; // interpolation factor for label smoothing, see below + float factorWeight_; // give extra weight to factors + + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) override { + // logits may be factored; in that case, the getLoss() function computes one loss for each, and + // sums them up int inFactor = false; auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { - logits = atleast_3d(logits); // we always assume a time and 
batch dimension exists. + logits = atleast_3d(logits); // we always assume a time and batch dimension exists. // for bert training or classification the time dimension is lost. // Here safeguard against 2d classifier output, adds 1 on the left, non-op. - + Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32); - if (inFactor && factorWeight_ != 1.0f) { + if(inFactor && factorWeight_ != 1.0f) { LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_); ce = ce * factorWeight_; } @@ -365,8 +364,10 @@ class CrossEntropyLoss : public LabelwiseLoss { if(labelWeights) { // We currently do not know how to use target factors and word-level label weights together - bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights. - ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors"); + bool wordlevel = labelWeights->shape()[-3] + > 1; // Time-dimension is not trivially 1, hence we have word-level weights. + ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, + "CE loss with word-level label weights is not implemented for factors"); ce = ce * cast(labelWeights, Type::float32); } @@ -374,13 +375,12 @@ class CrossEntropyLoss : public LabelwiseLoss { } }; - /** * @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an * implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319. - * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are not - * zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it is going - * to flip over to use SUL for that sentence to penalize the selected word. + * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are + * not zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it + * is going to flip over to use SUL for that sentence to penalize the selected word. * * SUL is implemented as: * -log(gather(1 - softmax(logits), -1, indices)) @@ -390,35 +390,45 @@ class CrossEntropyLoss : public LabelwiseLoss { class SequenceUnlikelihoodLoss : public CrossEntropyLoss { public: SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight) - : CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1 + : CrossEntropyLoss(labelSmoothing, factorWeight) { + } // cross-entropy already reduces over axis -1 SequenceUnlikelihoodLoss(const std::vector& axes, float labelSmoothing, float factorWeight) - : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {} + : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {} protected: - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) override { - auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) override { + auto ce = CrossEntropyLoss::compute( + logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE if(!labelWeights) - return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)? + return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)? 
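For reference, the mixed CE/unlikelihood schedule described in the SequenceUnlikelihoodLoss comment and implemented in this compute() override can be written out as follows. This is a reconstruction from the surrounding code only, with t indexing target positions, b batch entries, w the word-level label weights, m the padding mask and y the gold indices:

    e_{t,b} = (1 - w_{t,b}) \, m_{t,b}   \quad \text{(1 marks an annotated error position)}
    \mathrm{UL}_{t,b} = -\log\bigl(1 - \mathrm{softmax}(\mathrm{logits}_{t,b})[y_{t,b}]\bigr)
    \mathrm{onlyCE}_{b} = \mathbf{1}\Bigl[\textstyle\sum_t e_{t,b} = 0\Bigr]
    \mathrm{cost}_{t,b} = \mathrm{onlyCE}_{b}\,\mathrm{CE}_{t,b} + (1 - \mathrm{onlyCE}_{b})\, e_{t,b}\,\mathrm{UL}_{t,b}

So a sentence without annotated errors is trained with plain cross-entropy, while a sentence containing at least one error contributes only unlikelihood terms at the error positions; the per-position result is then reduced over the usual loss axes as in LabelwiseLoss::reduce().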
// We currently do not know how to use target factors and word-level label weights together ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors"); - ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete. - // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete) + ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by + // default 1, which would make this obsolete. + // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask + // again to eliminate padding (might be obsolete) auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32); auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { return cast(unlikelihood(logits, indices), Type::float32); }); - // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only on_ the errors with UL. This is the "mixed" training - // schedule from https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily switch between CE and UL. - auto onlyCe = eq(sum(errorMask, /*axis=*/-3), 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present - ceUl = errorMask * ceUl; // don't use for correct label or padding + // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only + // on_ the errors with UL. This is the "mixed" training schedule from + // https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily + // switch between CE and UL. + auto onlyCe = eq(sum(errorMask, /*axis=*/-3), + 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present + ceUl = errorMask * ceUl; // don't use for correct label or padding - auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never simultanously used as cost per batch entry + auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never + // simultanously used as cost per batch entry return cost; } @@ -463,7 +473,6 @@ class RescorerLoss : public CrossEntropyLoss { } }; - /** * @brief Factory for label-wise loss functions */ diff --git a/src/layers/output.cpp b/src/layers/output.cpp index bf8fa5886..1d9c7b4b0 100644 --- a/src/layers/output.cpp +++ b/src/layers/output.cpp @@ -1,120 +1,131 @@ #include "output.h" -#include "data/factored_vocab.h" #include "common/timer.h" -#include "layers/lsh.h" +#include "data/factored_vocab.h" #include "layers/loss.h" +#include "layers/lsh.h" namespace marian { namespace mlp { /*private*/ void Output::lazyConstruct(int inputDim) { - // We must construct lazily since we won't know tying nor input dim in constructor. - if (Wt_) + // We must construct lazily since we won't know tying nor input dim in constructor. 
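To tie together the loss.h interfaces reformatted above, here is a minimal usage sketch mirroring the call site in EncoderDecoderCECost later in this series. The wrapper function and the factory return types (Ptr<LabelwiseLoss>, Ptr<MultiRationalLoss>) are assumptions for illustration; the individual calls (newLoss, newMultiLoss, apply, push_back, StaticLoss) all appear in the hunks of this patch set.

#include "layers/loss.h"

namespace marian {
// Hypothetical wrapper, for illustration only.
Ptr<MultiRationalLoss> computeBatchLoss(Ptr<Options> options,
                                        Logits logits,
                                        const Words& targetWords,
                                        Expr targetMask) {
  Ptr<LabelwiseLoss> loss = newLoss(options, /*inference=*/false);  // e.g. CrossEntropyLoss
  Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options);         // "sum", "scaled" or "mean"

  RationalLoss partial = loss->apply(logits, targetWords, targetMask);  // graph nodes loss_ and count_
  multiLoss->push_back(partial);                                        // accumulates numerator and denominator

  return multiLoss;  // once the graph has been run, StaticLoss(*multiLoss) yields plain
                     // float loss/count values for reporting
}
}  // namespace marian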
+ if(Wt_) return; - // this option is only set in the decoder - if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { - auto k = opt>("output-approx-knn")[0]; + // this option is only set in the decoder + if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { + auto k = opt>("output-approx-knn")[0]; auto nbits = opt>("output-approx-knn")[1]; lsh_ = New(k, nbits); - } + } - auto name = options_->get("prefix"); - auto numOutputClasses = options_->get("dim"); + auto name = options_->get("prefix"); + auto numOutputClasses = options_->get("dim"); - factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); - if (factoredVocab_) { + factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get("vocab", "")); + if(factoredVocab_) { numOutputClasses = (int)factoredVocab_->factorVocabSize(); LOG_ONCE(info, "[embedding] Factored outputs enabled"); - } + } - if(tiedParam_) { + if(tiedParam_) { Wt_ = tiedParam_; - } else { - if (graph_->get(name + "_W")) { // support of legacy models that did not transpose - Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); - isLegacyUntransposedW = true; - } - else // this is the regular case: - Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); - } + } else { + if(graph_->get(name + "_W")) { // support of legacy models that did not transpose + Wt_ = graph_->param( + name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); + isLegacyUntransposedW = true; + } else // this is the regular case: + Wt_ = graph_->param( + name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); + } - if(hasBias_) + if(hasBias_) b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros()); - /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); - ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); - if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix + /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); + ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); + if(lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix #define HARDMAX_HACK #ifdef HARDMAX_HACK - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number #endif auto range = factoredVocab_->getGroupRange(0); auto lemmaVocabDim = (int)(range.second - range.first); - auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length - lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed - } + auto initFunc = inits::glorotUniform( + /*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length + lemmaEt_ = graph_->param(name + "_lemmaEt", + {lemmaDimEmb, lemmaVocabDim}, + initFunc); // [L x U] L=lemmaDimEmb; transposed for speed + } } Logits Output::applyAsLogits(Expr input) /*override final*/ { - lazyConstruct(input->shape()[-1]); + lazyConstruct(input->shape()[-1]); - auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { + auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { if(b) - return affine(x, W, b, transA, transB); + return affine(x, W, b, transA, 
transB); else - return dot(x, W, transA, transB); - }; + return dot(x, W, transA, transB); + }; - auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { + auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { if(lsh_) { - ABORT_IF( transA, "Transposed query not supported for LSH"); - ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); - return lsh_->apply(x, W, b); // knows how to deal with undefined bias + ABORT_IF(transA, "Transposed query not supported for LSH"); + ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); + return lsh_->apply(x, W, b); // knows how to deal with undefined bias } else { - return affineOrDot(x, W, b, transA, transB); + return affineOrDot(x, W, b, transA, transB); } - }; + }; - if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed - cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); + if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one + // batch, then clear()ed + cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); if(hasBias_) - cachedShortb_ = index_select(b_ , -1, shortlist_->indices()); - } + cachedShortb_ = index_select(b_, -1, shortlist_->indices()); + } - if (factoredVocab_) { + if(factoredVocab_) { auto graph = input->graph(); // project each factor separately auto numGroups = factoredVocab_->getNumGroups(); - std::vector> allLogits(numGroups, nullptr); // (note: null entries for absent factors) - Expr input1 = input; // [B... x D] - Expr Plemma = nullptr; // used for lemmaDimEmb=-1 - Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 - for (size_t g = 0; g < numGroups; g++) { - auto range = factoredVocab_->getGroupRange(g); - if (g > 0 && range.first == range.second) // empty entry + std::vector> allLogits(numGroups, + nullptr); // (note: null entries for absent factors) + Expr input1 = input; // [B... x D] + Expr Plemma = nullptr; // used for lemmaDimEmb=-1 + Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 + for(size_t g = 0; g < numGroups; g++) { + auto range = factoredVocab_->getGroupRange(g); + if(g > 0 && range.first == range.second) // empty entry continue; - ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g); - // slice this group's section out of W_ - Expr factorWt, factorB; - if (g == 0 && shortlist_) { + ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second, + "Factor groups must be consecutive (group {} vs predecessor)", + g); + // slice this group's section out of W_ + Expr factorWt, factorB; + if(g == 0 && shortlist_) { factorWt = cachedShortWt_; - factorB = cachedShortb_; - } - else { - factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second)); + factorB = cachedShortb_; + } else { + factorWt = slice( + Wt_, isLegacyUntransposedW ? 
-1 : 0, Slice((int)range.first, (int)range.second)); if(hasBias_) - factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); - } - /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); - if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) + factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); + } + /*const*/ int lemmaDimEmb = options_->get("lemma-dim-emb", 0); + if((lemmaDimEmb == -2 || lemmaDimEmb == -3) + && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) LOG_ONCE(info, "[embedding] using lemma conditioning with gate"); // this mimics one transformer layer // - attention over two inputs: - // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas. + // - e = current lemma. We use the original embedding vector; specifically, expectation + // over all lemmas. // - input = hidden state FF(h_enc+h_dec) - // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention) + // - dot-prod attention to allow both sides to influence (unlike our recurrent + // self-attention) // - multi-head to allow for multiple conditions to be modeled // - add & norm, for gradient flow and scaling // - FF layer --this is expensive; it is per-factor @@ -122,112 +133,161 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ { int inputDim = input->shape()[-1]; int heads = 8; auto name = options_->get("prefix") + "_factor" + std::to_string(g); - auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform()); - auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform()); - auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform()); + auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform()); + auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform()); + auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform()); auto toMultiHead = [&](Expr x, int heads) { - const auto& shape = x->shape(); - int inputDim = shape[-1]; - int otherDim = shape.elements() / inputDim; - ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads); - return reshape(x, { otherDim, heads, 1, inputDim / heads }); + const auto& shape = x->shape(); + int inputDim = shape[-1]; + int otherDim = shape.elements() / inputDim; + ABORT_IF(inputDim / heads * heads != inputDim, + "inputDim ({}) must be multiple of number of heads ({})", + inputDim, + heads); + return reshape(x, {otherDim, heads, 1, inputDim / heads}); }; input1 = inputLemma; - auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query - auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax. - auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values - auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other - auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] - auto sm = sigmoid(zm); // [B... x H x 1] - auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] - auto r = reshape(rm, input->shape()); // [B... x D] + auto qm = toMultiHead(dot(input1, Wq), heads); // [B... 
x H x D/H] projected query + auto kdm = toMultiHead(dot(input1 - input, Wk), + heads); // [B... x H x D/H] the two data vectors projected as keys. + // Use diff and sigmoid, instead of softmax. + auto vem = toMultiHead( + dot(input1, Wv), + heads); // [B... x H x D/H] one of the two data vectors projected as values + auto vim = toMultiHead(dot(input, Wv), heads); // [B... x H x D/H] the other + auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] + auto sm = sigmoid(zm); // [B... x H x 1] + auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] + auto r = reshape(rm, input->shape()); // [B... x D] // add & norm input1 = r + input1; input1 = layerNorm(input1, name + "_att"); // FF layer - auto ffnDropProb = 0.1f; // @TODO: get as a parameter - auto ffnDim = inputDim * 2; // @TODO: get as a parameter - auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb); - f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); + auto ffnDropProb = 0.1f; // @TODO: get as a parameter + auto ffnDim = inputDim * 2; // @TODO: get as a parameter + auto f = denseInline(input1, + name + "_ffn", + /*suffix=*/"1", + ffnDim, + inits::glorotUniform(), + (ActivationFunction*)relu, + ffnDropProb); + f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); // add & norm input1 = f + input1; input1 = layerNorm(input1, name + "_ffn"); - } - // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a matrix - Expr factorLogits; - if(g == 0) - factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - else - factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - - // optionally add lemma-dependent bias - if (Plemma) { // [B... x U0] + } + // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a + // matrix + Expr factorLogits; + if(g == 0) + factorLogits = affineOrLSH( + input1, + factorWt, + factorB, + false, + /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + else + factorLogits = affineOrDot( + input1, + factorWt, + factorB, + false, + /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + + // optionally add lemma-dependent bias + if(Plemma) { // [B... x U0] int lemmaVocabDim = Plemma->shape()[-1]; int factorVocabDim = factorLogits->shape()[-1]; auto name = options_->get("prefix"); - Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma - auto b = dot(Plemma, lemmaBt, false, true); // [B... x U] + Expr lemmaBt + = graph_->param(name + "_lemmaBt_" + std::to_string(g), + {factorVocabDim, lemmaVocabDim}, + inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma + auto b = dot(Plemma, lemmaBt, false, true); // [B... 
x U] factorLogits = factorLogits + b; - } - allLogits[g] = New(factorLogits, nullptr); - // optionally add a soft embedding of lemma back to create some lemma dependency - // @TODO: if this works, move it into lazyConstruct - if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure + } + allLogits[g] = New(factorLogits, nullptr); + // optionally add a soft embedding of lemma back to create some lemma dependency + // @TODO: if this works, move it into lazyConstruct + if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version"); // get expected lemma embedding vector - auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set + auto factorLogSoftmax = logsoftmax( + factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set auto factorSoftmax = exp(factorLogSoftmax); - inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] - } - else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max + inputLemma = dot(factorSoftmax, + factorWt, + false, + /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version"); // get max-lemma embedding vector - auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set + auto maxVal = max(factorLogits, + -1); // [B... x U] note: with shortlist, this is not the full lemma set auto factorHardmax = eq(factorLogits, maxVal); - inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] - } - else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias + inputLemma = dot(factorHardmax, + factorWt, + false, + /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented"); LOG_ONCE(info, "[embedding] using lemma-dependent bias"); - auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) - auto z = /*stopGradient*/(factorLogSoftmax); - Plemma = exp(z); // [B... x U] - } - else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix + auto factorLogSoftmax + = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) + auto z = /*stopGradient*/ (factorLogSoftmax); + Plemma = exp(z); // [B... x U] + } else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb); - // compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE + // compute softmax. We compute logsoftmax() separately because this way, computation will be + // reused later via CSE auto factorLogSoftmax = logsoftmax(factorLogits); auto factorSoftmax = exp(factorLogSoftmax); #ifdef HARDMAX_HACK - bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation) - if (hardmax) { - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; - LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. 
Actual dim is {}", lemmaDimEmb); - auto maxVal = max(factorSoftmax, -1); - factorSoftmax = eq(factorSoftmax, maxVal); + bool hardmax = (lemmaDimEmb & 1) + != 0; // odd value triggers hardmax for now (for quick experimentation) + if(hardmax) { + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; + LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb); + auto maxVal = max(factorSoftmax, -1); + factorSoftmax = eq(factorSoftmax, maxVal); } #endif // re-embedding lookup, soft-indexed by softmax - if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix - cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); - auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L] + if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix + cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); + auto e = dot(factorSoftmax, + cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, + false, + true); // [B... x L] // project it back to regular hidden dim int inputDim = input1->shape()[-1]; auto name = options_->get("prefix"); - // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1 - Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension - auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] + // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also + // length 1 + Expr lemmaWt + = inputDim == lemmaDimEmb + ? nullptr + : graph_->param(name + "_lemmaWt", + {inputDim, lemmaDimEmb}, + inits::glorotUniform()); // [D x L] D=hidden-vector dimension + auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] // augment the original hidden vector with this additional information input1 = input1 + f; - } + } } return Logits(std::move(allLogits), factoredVocab_); - } else if (shortlist_) { - return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } else { - return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } + } else if(shortlist_) { + return Logits(affineOrLSH(input, + cachedShortWt_, + cachedShortb_, + false, + /*transB=*/isLegacyUntransposedW ? false : true)); + } else { + return Logits( + affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? 
false : true)); + } } -} -} \ No newline at end of file +} // namespace mlp +} // namespace marian \ No newline at end of file diff --git a/src/layers/output.h b/src/layers/output.h index 92e7eb25e..2b6f49861 100644 --- a/src/layers/output.h +++ b/src/layers/output.h @@ -1,10 +1,10 @@ #pragma once -#include "marian.h" -#include "generic.h" -#include "logits.h" #include "data/shortlist.h" +#include "generic.h" #include "layers/factory.h" +#include "logits.h" +#include "marian.h" namespace marian { class LSH; @@ -14,42 +14,45 @@ namespace mlp { class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { private: // parameters held by this layer - Expr Wt_; // weight matrix is stored transposed for efficiency + Expr Wt_; // weight matrix is stored transposed for efficiency Expr b_; - Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] - bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form + Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] + bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form bool hasBias_{true}; Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) Expr cachedShortb_; // these match the current value of shortlist_ Expr cachedShortLemmaEt_; Ptr factoredVocab_; - + // optional parameters set/updated after construction Expr tiedParam_; Ptr shortlist_; Ptr lsh_; void lazyConstruct(int inputDim); + public: Output(Ptr graph, Ptr options) - : LayerBase(graph, options), - hasBias_{!options->get("output-omit-bias", false)} { + : LayerBase(graph, options), hasBias_{!options->get("output-omit-bias", false)} { clear(); } void tieTransposed(Expr tied) { - if (Wt_) - ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created"); + if(Wt_) + ABORT_IF(tiedParam_.get() != tied.get(), + "Tied output projection cannot be changed once weights have been created"); else tiedParam_ = tied; } void setShortlist(Ptr shortlist) override final { - if (shortlist_) - ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()"); + if(shortlist_) + ABORT_IF(shortlist.get() != shortlist_.get(), + "Output shortlist cannot be changed except after clear()"); else { - ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??"); + ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, + "No shortlist but cached parameters??"); shortlist_ = shortlist; } // cachedShortWt_ and cachedShortb_ will be created lazily inside apply() @@ -60,7 +63,7 @@ class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { void clear() override final { shortlist_ = nullptr; cachedShortWt_ = nullptr; - cachedShortb_ = nullptr; + cachedShortb_ = nullptr; cachedShortLemmaEt_ = nullptr; } @@ -69,6 +72,4 @@ class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { } // namespace mlp -} - - +} // namespace marian diff --git a/src/models/costs.cpp b/src/models/costs.cpp index 5105f5904..c688b2119 100644 --- a/src/models/costs.cpp +++ b/src/models/costs.cpp @@ -4,13 +4,11 @@ namespace marian { namespace models { Ptr LogSoftmaxStep::apply(Ptr state) { -// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) -state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); -// @TODO: This is becoming more and more opaque ^^. 
Can we simplify this? -return state; -} - - -} + // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) + state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); + // @TODO: This is becoming more and more opaque ^^. Can we simplify this? + return state; } +} // namespace models +} // namespace marian diff --git a/src/models/costs.h b/src/models/costs.h index 2d34c53a9..e5463bfd0 100644 --- a/src/models/costs.h +++ b/src/models/costs.h @@ -4,8 +4,8 @@ #include "layers/guided_alignment.h" #include "layers/loss.h" #include "layers/weight.h" -#include "models/encoder_decoder.h" #include "models/encoder_classifier.h" +#include "models/encoder_decoder.h" #include "models/encoder_pooler.h" namespace marian { @@ -22,10 +22,12 @@ namespace models { class ICost { public: - virtual Ptr apply(Ptr model, - Ptr graph, // @TODO: why needed? Can it be gotten from model? - Ptr batch, - bool clearGraph = true) = 0; + virtual Ptr apply( + Ptr model, + Ptr graph, // @TODO: why needed? Can it be gotten from model? + Ptr batch, + bool clearGraph = true) + = 0; virtual ~ICost() {} }; @@ -45,10 +47,9 @@ class EncoderDecoderCECost : public ICost { : options_(options), inference_(options->get("inference", false)) { loss_ = newLoss(options_, inference_); - toBeWeighted_ - = (options_->hasAndNotEmpty("data-weighting") && !inference_) - || (options_->has("dynamic-weighting") && options_->get("dynamic-weighting") - && !inference_); + toBeWeighted_ = (options_->hasAndNotEmpty("data-weighting") && !inference_) + || (options_->has("dynamic-weighting") + && options_->get("dynamic-weighting") && !inference_); if(toBeWeighted_) weighter_ = WeightingFactory(options_); } @@ -56,9 +57,9 @@ class EncoderDecoderCECost : public ICost { virtual ~EncoderDecoderCECost() {} Ptr apply(Ptr model, - Ptr graph, - Ptr batch, - bool clearGraph = true) override { + Ptr graph, + Ptr batch, + bool clearGraph = true) override { auto encdec = std::static_pointer_cast(model); auto corpusBatch = std::static_pointer_cast(batch); @@ -72,17 +73,17 @@ class EncoderDecoderCECost : public ICost { Ptr multiLoss = newMultiLoss(options_); // @TODO: adapt to multi-objective training with multiple decoders - auto partialLoss = loss_->apply(state->getLogProbs(), - state->getTargetWords(), - state->getTargetMask(), - weights); + auto partialLoss = loss_->apply( + state->getLogProbs(), state->getTargetWords(), state->getTargetMask(), weights); multiLoss->push_back(partialLoss); if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) { - auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1] + auto attentionVectors + = encdec->getDecoders()[0] + ->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1] ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments"); - auto attention = concatenate(attentionVectors, /*axis =*/ -1); + auto attention = concatenate(attentionVectors, /*axis =*/-1); auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention); multiLoss->push_back(alignmentLoss); @@ -109,10 +110,9 @@ class EncoderClassifierCECost : public ICost { } Ptr apply(Ptr model, - Ptr graph, - Ptr batch, - bool clearGraph = true) override { - + Ptr graph, + Ptr batch, + bool clearGraph = true) override { auto enccls = std::static_pointer_cast(model); auto corpusBatch = std::static_pointer_cast(batch); @@ -141,21 +141,20 @@ class EncoderPoolerRankCost : public ICost 
{ public: EncoderPoolerRankCost(Ptr options) - : options_(options), - inference_(options->get("inference", false)) { - auto trainEmbedderRank = options->get>("train-embedder-rank", {}); - ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set"); - - margin_ = std::stof(trainEmbedderRank[0]); - if(trainEmbedderRank.size() > 1) - normalizer_ = std::stof(trainEmbedderRank[1]); + : options_(options), inference_(options->get("inference", false)) { + auto trainEmbedderRank = options->get>("train-embedder-rank", {}); + ABORT_IF(trainEmbedderRank.empty(), + "EncoderPoolerRankCost expects train-embedder-rank to be set"); + + margin_ = std::stof(trainEmbedderRank[0]); + if(trainEmbedderRank.size() > 1) + normalizer_ = std::stof(trainEmbedderRank[1]); } Ptr apply(Ptr model, Ptr graph, Ptr batch, bool clearGraph = true) override { - auto encpool = std::static_pointer_cast(model); auto corpusBatch = std::static_pointer_cast(batch); std::vector dotProducts = encpool->apply(graph, corpusBatch, clearGraph); @@ -167,28 +166,41 @@ class EncoderPoolerRankCost : public ICost { ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss"); // multi-objective training - auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability - auto exponent = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product + auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability + auto exponent + = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product auto dp = exp(exponent); Expr dn1, dn2; - if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator. - dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example - dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example + if(normalizer_ + != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the + // magnitude of the sum of negative examples in the denominator. + dn1 = normalizer_ + * mean(exp(dotProducts[1] - maxDot), + -1); // dot product of anchor and first negative example + dn2 = normalizer_ + * mean(exp(dotProducts[2] - maxDot), + -1); // dot product of positive examples and first negative example } else { - dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example - dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example + dn1 = sum(exp(dotProducts[1] - maxDot), + -1); // dot product of anchor and first negative example + dn2 = sum(exp(dotProducts[2] - maxDot), + -1); // dot product of positive examples and first negative example } // We rewrite the loss so it looks more like a log-softmax, presumably more stable? 
- // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m) - auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples - auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples - auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2); - + // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + // + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m) + auto marginLoss1 + = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples + auto marginLoss2 + = log(dp + dn2) + - exponent; // symmetric version of the above with positive example vs negative examples + auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2); + RationalLoss loss(marginLoss, (float)dimBatch); multiLoss->push_back(loss); - + return multiLoss; } }; @@ -199,8 +211,7 @@ class Trainer : public ICriterionFunction { Ptr cost_; public: - Trainer(Ptr model, Ptr cost) - : model_(model), cost_(cost) {} + Trainer(Ptr model, Ptr cost) : model_(model), cost_(cost) {} virtual ~Trainer() {} @@ -219,8 +230,8 @@ class Trainer : public ICriterionFunction { } virtual Ptr build(Ptr graph, - Ptr batch, - bool clearGraph = true) override { + Ptr batch, + bool clearGraph = true) override { return cost_->apply(model_, graph, batch, clearGraph); }; @@ -230,24 +241,25 @@ class Trainer : public ICriterionFunction { class ILogProb { public: virtual Logits apply(Ptr model, - Ptr graph, - Ptr batch, - bool clearGraph = true) = 0; + Ptr graph, + Ptr batch, + bool clearGraph = true) + = 0; }; -// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth? -// Beam search uses it for the former meaning, while 'marian score' and validation in the latter. -// This class is for the former use. The latter is done using Trainer. +// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for +// the ground truth? +// Beam search uses it for the former meaning, while 'marian score' and validation in the +// latter. This class is for the former use. The latter is done using Trainer. 
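Returning to the margin loss in EncoderPoolerRankCost above: writing φ for the anchor–positive dot product, d_i for the dot products involving the negative example, M for the maximum dot product (subtracted for numeric stability) and m for the margin, the rewrite quoted in its comments expands to

    dp = \exp(\varphi - M - m), \qquad dn_i = \exp(d_i - M)

    -\log\frac{dp}{dp + \sum_i dn_i}
      = \log\Bigl(dp + \sum_i dn_i\Bigr) - \log dp
      = \log\Bigl(dp + \sum_i dn_i\Bigr) - (\varphi - M - m),

which is exactly log(dp + dn1) - exponent in that hunk; the second, symmetric term does the same with the positive-example dot products, and a non-zero normalizer replaces the sum over negatives with normalizer · mean.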
class Scorer : public IModel { protected: Ptr model_; Ptr logProb_; public: - Scorer(Ptr model, Ptr cost) - : model_(model), logProb_(cost) {} + Scorer(Ptr model, Ptr cost) : model_(model), logProb_(cost) {} - virtual ~Scorer(){} + virtual ~Scorer() {} Ptr getModel() { return model_; } @@ -264,8 +276,8 @@ class Scorer : public IModel { } virtual Logits build(Ptr graph, - Ptr batch, - bool clearGraph = true) override { + Ptr batch, + bool clearGraph = true) override { return logProb_->apply(model_, graph, batch, clearGraph); }; @@ -293,10 +305,10 @@ class GumbelSoftmaxStep : public ILogProbStep { virtual ~GumbelSoftmaxStep() {} virtual Ptr apply(Ptr state) override { state->setLogProbs(state->getLogProbs().applyUnaryFunctions( - [](Expr logits){ // lemma gets gumbelled - return logsoftmax(logits + constant_like(logits, inits::gumbel())); - }, - logsoftmax)); // factors don't + [](Expr logits) { // lemma gets gumbelled + return logsoftmax(logits + constant_like(logits, inits::gumbel())); + }, + logsoftmax)); // factors don't return state; } }; @@ -311,8 +323,7 @@ class Stepwise : public IEncoderDecoder { Ptr cost_; public: - Stepwise(Ptr encdec, Ptr cost) - : encdec_(encdec), cost_(cost) {} + Stepwise(Ptr encdec, Ptr cost) : encdec_(encdec), cost_(cost) {} virtual void load(Ptr graph, const std::string& name, @@ -346,12 +357,13 @@ class Stepwise : public IEncoderDecoder { return encdec_->startState(graph, batch); } - virtual Ptr step(Ptr graph, - Ptr state, - const std::vector& hypIndices, // [beamIndex * activeBatchSize + batchIndex] - const Words& words, // [beamIndex * activeBatchSize + batchIndex] - const std::vector& batchIndices, // [batchIndex] - int beamSize) override { + virtual Ptr step( + Ptr graph, + Ptr state, + const std::vector& hypIndices, // [beamIndex * activeBatchSize + batchIndex] + const Words& words, // [beamIndex * activeBatchSize + batchIndex] + const std::vector& batchIndices, // [batchIndex] + int beamSize) override { auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize); return cost_->apply(nextState); } @@ -369,9 +381,7 @@ class Stepwise : public IEncoderDecoder { encdec_->setShortlistGenerator(shortlistGenerator); }; - virtual Ptr getShortlist() override { - return encdec_->getShortlist(); - }; + virtual Ptr getShortlist() override { return encdec_->getShortlist(); }; virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); } }; diff --git a/src/models/states.h b/src/models/states.h index cfb6fd1b8..20dd59c95 100644 --- a/src/models/states.h +++ b/src/models/states.h @@ -1,7 +1,7 @@ #pragma once +#include "layers/logits.h" // @HACK: for factored embeddings only so far #include "marian.h" -#include "layers/logits.h" // @HACK: for factored embeddings only so far #include "rnn/types.h" namespace marian { @@ -9,7 +9,7 @@ namespace marian { class EncoderState { private: Expr context_; - Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask + Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask Ptr batch_; public: @@ -19,31 +19,34 @@ class EncoderState { EncoderState() {} virtual ~EncoderState() {} - virtual Expr getContext() const { return context_; } - virtual Expr getAttended() const { return context_; } - virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed + virtual Expr getContext() const { return context_; } + virtual Expr getAttended() const { return context_; } + virtual Expr getMask() 
const { + return mask_; + } // source batch mask; may have additional positions suppressed - virtual const Words& getSourceWords() { - return batch_->front()->data(); - } + virtual const Words& getSourceWords() { return batch_->front()->data(); } // Sub-select active batch entries from encoder context and context mask - Ptr select(const std::vector& batchIndices) { // [batchIndex] indices of active batch entries - // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout - return New(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_); + Ptr select( + const std::vector& batchIndices) { // [batchIndex] indices of active batch entries + // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer + // gets transposed to the same dimension layout + return New( + index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_); } }; class DecoderState { protected: - rnn::States states_; // states of individual decoder layers + rnn::States states_; // states of individual decoder layers Logits logProbs_; std::vector> encStates_; Ptr batch_; - Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded + Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded Expr targetMask_; - Words targetWords_; // target labels + Words targetWords_; // target labels // Keep track of current target token position during translation size_t position_{0}; @@ -57,26 +60,30 @@ class DecoderState { virtual ~DecoderState() {} // @TODO: Do we need all these to be virtual? - virtual const std::vector>& getEncoderStates() const { - return encStates_; - } + virtual const std::vector>& getEncoderStates() const { return encStates_; } virtual Logits getLogProbs() const { return logProbs_; } virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; } - // @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop - virtual Ptr select(const std::vector& hypIndices, // [beamIndex * activeBatchSize + batchIndex] - const std::vector& batchIndices, // [batchIndex] - int beamSize) const { - + // @TODO: should this be a constructor? Then derived classes can call this without the New<> in + // the loop + virtual Ptr select( + const std::vector& hypIndices, // [beamIndex * activeBatchSize + batchIndex] + const std::vector& batchIndices, // [batchIndex] + int beamSize) const { std::vector> newEncStates; for(auto& es : encStates_) - // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries - newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices)); + // If the size of the batch dimension of the encoder state context changed, subselect the + // correct batch entries + newEncStates.push_back( + es->getContext()->shape()[-2] == batchIndices.size() ? 
es : es->select(batchIndices)); // hypindices matches batchIndices in terms of batch dimension, so we only need hypIndices - auto selectedState = New( - states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_); + auto selectedState + = New(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), + logProbs_, + newEncStates, + batch_); // Set positon of new state based on the target token position of current state selectedState->setPosition(getPosition()); @@ -86,7 +93,9 @@ class DecoderState { virtual const rnn::States& getStates() const { return states_; } virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; }; - virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; } + virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { + targetHistoryEmbeddings_ = targetHistoryEmbeddings; + } virtual const Words& getTargetWords() const { return targetWords_; }; virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; } @@ -94,9 +103,7 @@ class DecoderState { virtual Expr getTargetMask() const { return targetMask_; }; virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; } - virtual const Words& getSourceWords() const { - return getEncoderStates()[0]->getSourceWords(); - } + virtual const Words& getSourceWords() const { return getEncoderStates()[0]->getSourceWords(); } Ptr getBatch() const { return batch_; } @@ -111,7 +118,8 @@ class DecoderState { /** * Classifier output based on DecoderState - * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output. + * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have + * stateful output. 
*/ class ClassifierState { private: From cd018e8d0404687c0bd13f64962bd22617b80331 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Mon, 8 Mar 2021 03:09:03 -0800 Subject: [PATCH 11/14] Update formatting --- .clang-format | 4 +- src/layers/embedding.cpp | 77 +++++++++++++++-------------------- src/layers/embedding.h | 87 ++++++++++++++++++---------------------- src/layers/logits.cpp | 31 ++++++-------- 4 files changed, 87 insertions(+), 112 deletions(-) diff --git a/.clang-format b/.clang-format index 670df0753..bda0b0e0f 100644 --- a/.clang-format +++ b/.clang-format @@ -3,7 +3,7 @@ Language: Cpp # BasedOnStyle: Google AccessModifierOffset: -2 AlignAfterOpenBracket: Align -AlignConsecutiveAssignments: false +AlignConsecutiveAssignments: true AlignConsecutiveDeclarations: false AlignEscapedNewlinesLeft: true AlignOperands: true @@ -71,7 +71,7 @@ PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 PointerAlignment: Left -ReflowComments: true +ReflowComments: false SortIncludes: true SpaceAfterCStyleCast: false SpaceBeforeAssignmentOperators: true diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp index 5a448f611..92c4ad6d2 100644 --- a/src/layers/embedding.cpp +++ b/src/layers/embedding.cpp @@ -6,8 +6,8 @@ namespace marian { Embedding::Embedding(Ptr graph, Ptr options) : LayerBase(graph, options), inference_(opt("inference")) { std::string name = opt("prefix"); - int dimVoc = opt("dimVocab"); - int dimEmb = opt("dimEmb"); + int dimVoc = opt("dimVocab"); + int dimEmb = opt("dimEmb"); bool fixed = opt("fixed", false); @@ -25,7 +25,7 @@ Embedding::Embedding(Ptr graph, Ptr options) std::string file = opt("embFile"); if(!file.empty()) { bool norm = opt("normalization", false); - initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); + initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); } } @@ -34,7 +34,7 @@ Embedding::Embedding(Ptr graph, Ptr options) // helper to embed a sequence of words (given as indices) via factored embeddings Expr Embedding::multiRows(const Words& data, float dropProb) const { - auto graph = E_->graph(); + auto graph = E_->graph(); auto factoredData = factoredVocab_->csr_rows(data); // multi-hot factor vectors are represented as a sparse CSR matrix // [row index = word position index] -> set of factor indices for word at this position @@ -59,9 +59,9 @@ Expr Embedding::multiRows(const Words& data, float dropProb) const { std::tuple Embedding::apply(Ptr subBatch) const /*override final*/ { - auto graph = E_->graph(); + auto graph = E_->graph(); int dimBatch = (int)subBatch->batchSize(); - int dimEmb = E_->shape()[-1]; + int dimEmb = E_->shape()[-1]; int dimWidth = (int)subBatch->batchWidth(); // factored embeddings: @@ -113,7 +113,7 @@ std::tuple Embedding::apply(Ptrget("dropout", 0.0f)); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] // selectedEmbs = dropout(selectedEmbs, options_->get("dropout", 0.0f), { // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout return selectedEmbs; @@ -128,7 +128,7 @@ Expr Embedding::applyIndices(const std::vector& embIdx, const Shape& embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? 
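As a quick illustration of the two .clang-format switches flipped at the top of this patch (the code below is invented for the example, not taken from the repository):

// AlignConsecutiveAssignments: true -- '=' signs of consecutive assignments are lined up:
std::string name = "Wemb";
int dimVoc       = 32000;
bool fixed       = false;

// ReflowComments: false -- long // comments keep the line breaks the author wrote instead
// of being re-wrapped to the column limit.

// Layout that should survive formatting untouched is fenced explicitly, as the embedding
// hunks below do around their key/value option lists:
// clang-format off
auto options = New<Options>(
    "dimVocab", 32000,
    "dimEmb",   512);
// clang-format on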
auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() // (test that separately) if(!inference_) @@ -139,22 +139,17 @@ Expr Embedding::applyIndices(const std::vector& embIdx, const Shape& // standard encoder word embeddings /*private*/ Ptr EncoderDecoderLayerBase::createEmbeddingLayer() const { + // clang-format off auto options = New( - "dimVocab", - opt>("dim-vocabs")[batchIndex_], - "dimEmb", - opt("dim-emb"), - "dropout", - dropoutEmbeddings_, - "inference", - inference_, - "prefix", - (opt("tied-embeddings-src") || opt("tied-embeddings-all")) ? "Wemb" - : prefix_ + "_Wemb", - "fixed", - embeddingFix_, - "vocab", - opt>("vocabs")[batchIndex_]); // for factored embeddings + "dimVocab", opt>("dim-vocabs")[batchIndex_], + "dimEmb", opt("dim-emb"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "prefix", (opt("tied-embeddings-src") || opt("tied-embeddings-all")) ? "Wemb" + : prefix_ + "_Wemb", + "fixed", embeddingFix_, + "vocab", opt>("vocabs")[batchIndex_]); // for factored embeddings + // clang-format on if(options_->hasAndNotEmpty("embedding-vectors")) { auto embFiles = opt>("embedding-vectors"); options->set( @@ -165,28 +160,20 @@ Expr Embedding::applyIndices(const std::vector& embIdx, const Shape& // ULR word embeddings /*private*/ Ptr EncoderDecoderLayerBase::createULREmbeddingLayer() const { - return New( - graph_, - New("dimSrcVoc", - opt>("dim-vocabs")[0], // ULR multi-lingual src - "dimTgtVoc", - opt>("dim-vocabs")[1], // ULR monon tgt - "dimUlrEmb", - opt("ulr-dim-emb"), - "dimEmb", - opt("dim-emb"), - "ulr-dropout", - opt("ulr-dropout"), - "dropout", - dropoutEmbeddings_, - "inference", - inference_, - "ulrTrainTransform", - opt("ulr-trainable-transformation"), - "ulrQueryFile", - opt("ulr-query-vectors"), - "ulrKeysFile", - opt("ulr-keys-vectors"))); + // clang-format off + return New(graph_, New( + "dimSrcVoc", opt>("dim-vocabs")[0], // ULR multi-lingual src + "dimTgtVoc", opt>("dim-vocabs")[1], // ULR monon tgt + "dimUlrEmb", opt("ulr-dim-emb"), + "dimEmb", opt("dim-emb"), + "ulr-dropout", opt("ulr-dropout"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "ulrTrainTransform", opt("ulr-trainable-transformation"), + "ulrQueryFile", opt("ulr-query-vectors"), + "ulrKeysFile", opt("ulr-keys-vectors") + )); + // clang-format on } // get embedding layer for this encoder or decoder diff --git a/src/layers/embedding.h b/src/layers/embedding.h index 6edb31409..2fa7b78d0 100644 --- a/src/layers/embedding.h +++ b/src/layers/embedding.h @@ -28,47 +28,45 @@ class Embedding : public LayerBase, public IEmbeddingLayer { }; class ULREmbedding : public LayerBase, public IEmbeddingLayer { - std::vector - ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members + std::vector ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members bool inference_{false}; public: ULREmbedding(Ptr graph, Ptr options) : LayerBase(graph, options), inference_(opt("inference")) { std::string name = "url_embed"; // opt("prefix"); - int dimKeys = opt("dimTgtVoc"); - int dimQueries = opt("dimSrcVoc"); - int dimEmb = opt("dimEmb"); - int dimUlrEmb = opt("dimUlrEmb"); // ULR mono embed size - bool fixed = opt("fixed", false); + int dimKeys = opt("dimTgtVoc"); + int dimQueries = opt("dimSrcVoc"); + int dimEmb = 
opt("dimEmb"); + int dimUlrEmb = opt("dimUlrEmb"); // ULR mono embed size + bool fixed = opt("fixed", false); // Embedding layer initialization should depend only on embedding size, hence fanIn=false auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); std::string queryFile = opt("ulrQueryFile"); - std::string keyFile = opt("ulrKeysFile"); - bool trainTrans = opt("ulrTrainTransform", false); + std::string keyFile = opt("ulrKeysFile"); + bool trainTrans = opt("ulrTrainTransform", false); if(!queryFile.empty() && !keyFile.empty()) { - initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); - name = "ulr_query"; - fixed = true; + initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); + name = "ulr_query"; + fixed = true; auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(query_embed); // keys embeds - initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); - name = "ulr_keys"; - fixed = true; + initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); + name = "ulr_keys"; + fixed = true; auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(key_embed); // actual trainable embedding initFunc = inits::glorotUniform(); - name = "ulr_embed"; - fixed = false; - auto ulr_embed - = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim + name = "ulr_embed"; + fixed = false; + auto ulr_embed = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim ulrEmbeddings_.push_back(ulr_embed); // init trainable src embedding - name = "ulr_src_embed"; + name = "ulr_src_embed"; auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulr_src_embed); // ulr transformation matrix @@ -76,20 +74,20 @@ class ULREmbedding : public LayerBase, public IEmbeddingLayer { // we make this to the fixed case only if(trainTrans) { initFunc = inits::glorotUniform(); - fixed = false; + fixed = false; } else { initFunc = inits::eye(); // identity matrix - fixed = true; + fixed = true; } - name = "ulr_transform"; + name = "ulr_transform"; auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed); ulrEmbeddings_.push_back(ulrTransform); initFunc = inits::fromValue( 1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no // universal embeddings - should be zero for top freq only - fixed = true; - name = "ulr_shared"; + fixed = true; + name = "ulr_shared"; auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed); ulrEmbeddings_.push_back(share_embed); } @@ -97,15 +95,15 @@ class ULREmbedding : public LayerBase, public IEmbeddingLayer { std::tuple apply( Ptr subBatch) const override final { - auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb - auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb - auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb - auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb + auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb + auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb + auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb + auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb - auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 - int dimBatch = (int)subBatch->batchSize(); - int dimEmb = 
uniEmbed->shape()[-1]; - int dimWords = (int)subBatch->batchWidth(); + auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 + int dimBatch = (int)subBatch->batchSize(); + int dimEmb = uniEmbed->shape()[-1]; + int dimWords = (int)subBatch->batchWidth(); // D = K.A.QT // dimm(K) = univ_tok_vocab*uni_embed_size // dim A = uni_embed_size*uni_embed_size @@ -114,18 +112,15 @@ class ULREmbedding : public LayerBase, public IEmbeddingLayer { // note all above can be precombuted and serialized if A is not trainiable and during decoding // (TBD) here we need to handle the mini-batch extract raws corresponding to Xs in this // minibatch from Q - auto embIdx = toWordIndexVector(subBatch->data()); + auto embIdx = toWordIndexVector(subBatch->data()); auto queryEmbeddings = rows(queryEmbed, embIdx); - auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings - auto alpha = rows(ulrSharable, embIdx); // extract sharable flags - auto qt = dot(queryEmbeddings, - ulrTransform, - false, - false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb - auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]); + auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings + auto alpha = rows(ulrSharable, embIdx); // extract sharable flags + auto qt = dot(queryEmbeddings, ulrTransform, false, false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb + auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]); qt = qt / sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in // magnitude with larger embeds sizes - auto z = dot(qt, keyEmbed, false, true); // query-key similarity + auto z = dot(qt, keyEmbed, false, true); // query-key similarity float dropProb = this->options_->get("ulr-dropout", 0.0f); // default no dropout if(!inference_) z = dropout(z, dropProb); @@ -135,13 +130,11 @@ class ULREmbedding : public LayerBase, public IEmbeddingLayer { // temperature in softmax is to control randomness of predictions // high temperature Softmax outputs are more close to each other // low temperatures the softmax become more similar to "hardmax" - auto weights - = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? + auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE - auto chosenEmbeddings_mix - = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast + auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb}); - auto graph = ulrEmbeddings_.front()->graph(); + auto graph = ulrEmbeddings_.front()->graph(); auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask())); if(!inference_) batchEmbeddings = dropout(batchEmbeddings, diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp index 772c57150..8c4d69bde 100644 --- a/src/layers/logits.cpp +++ b/src/layers/logits.cpp @@ -48,17 +48,14 @@ Expr Logits::applyLossFunction( for(size_t g = 0; g < numGroups; g++) { if(!logits_[g]) continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] - const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) - auto factorIndices = indices( - maskedFactoredLabels - .indices); // [B... 
flattened] factor-label indices, or 0 if factor does not apply - auto factorMask - = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with - // 0 for labels that don't have this factor - auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) - // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask - // it out next. - auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + // clang-format off + const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) + auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply + auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor + auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) + // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. + auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + // clang-format on if(loss) factorLoss = cast(factorLoss, loss->value_type()); factorLoss @@ -140,6 +137,7 @@ Expr Logits::getLogits() const { logProbs[g] = logsoftmax(logits_[g]->loss()); auto y = concatenate(logProbs, /*axis=*/-1); + // clang-format off // sum up the unit logits across factors for each target word auto graph = y->graph(); auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] @@ -147,13 +145,10 @@ Expr Logits::getLogits() const { y, // [B x U] factorMatrix.shape, graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), - graph->constant({(int)factorMatrix.indices.size()}, - inits::fromVector(factorMatrix.indices), - Type::uint32), - graph->constant({(int)factorMatrix.offsets.size()}, - inits::fromVector(factorMatrix.offsets), - Type::uint32), + graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), + graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), /*transB=*/true); // -> [B x V] + // clang-format on // mask out gaps auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] @@ -247,4 +242,4 @@ Logits Logits::withCounts( newLogits.emplace_back(New(l->loss(), count)); return Logits(std::move(newLogits), factoredVocab_); } -} // namespace marian \ No newline at end of file +} // namespace marian From a1aaa32c6af15f9bc653b6ceeb81ac1767ef3f39 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Wed, 17 Mar 2021 17:34:09 +0000 Subject: [PATCH 12/14] Merged PR 18201: Install Boost in Azure pipelines Installing Boost manually in all workflows, because it has been recently removed from Azure/GitHub hosted runners. This should fix recent failures of Marian CI builds. 
--- CMakeLists.txt | 17 ++++++++++------- azure-pipelines.yml | 34 +++++++++++++++++++++------------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 057e10a48..4ee339781 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -327,7 +327,7 @@ if(CUDA_FOUND) if(USE_STATIC_LIBS) set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusparse_LIBRARY}) set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusparse_LIBRARY}) - + find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64) # The cuLIBOS library does not seem to exist in Windows CUDA toolkit installs if(CUDA_culibos_LIBRARY) @@ -504,8 +504,8 @@ if(USE_STATIC_LIBS) endif() # Find MPI -if(USE_MPI) - # 2.0 refers to MPI2 standard. OpenMPI is an implementation of that standard regardless of the specific OpenMPI version +if(USE_MPI) + # 2.0 refers to MPI2 standard. OpenMPI is an implementation of that standard regardless of the specific OpenMPI version # e.g. OpenMPI 1.10 implements MPI2 and will be found correctly. find_package(MPI 2.0 REQUIRED) if(MPI_FOUND) @@ -518,19 +518,22 @@ if(USE_MPI) endif(MPI_FOUND) endif(USE_MPI) -# TODO: move inside if(BOOST_COMPONENTS) -if(USE_STATIC_LIBS) - set(Boost_USE_STATIC_LIBS ON) -endif() ############################################################################### # Find Boost if required if(BOOST_COMPONENTS) + if(USE_STATIC_LIBS) + set(Boost_USE_STATIC_LIBS ON) + endif() + find_package(Boost COMPONENTS ${BOOST_COMPONENTS}) if(Boost_FOUND) include_directories(${Boost_INCLUDE_DIRS}) set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES}) set(EXT_LIBS ${EXT_LIBS} ${ZLIB_LIBRARIES}) # hack for static compilation + if(MSVC) + add_definitions(-DBOOST_ALL_NO_LIB=1) # hack for missing date-time stub + endif() else(Boost_FOUND) message(SEND_ERROR "Cannot find Boost libraries. 
Terminating.") endif(Boost_FOUND) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ceb6475d5..a32a82884 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -15,6 +15,7 @@ pool: name: Azure Pipelines variables: + BOOST_ROOT_WINDOWS: "C:/hostedtoolcache/windows/Boost/1.72.0/x86_64" CUDA_PATH_WINDOWS: "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA" MKL_DIR: "$(Build.SourcesDirectory)/mkl" MKL_URL: "https://romang.blob.core.windows.net/mariandev/ci/mkl-2020.1-windows-static.zip" @@ -69,6 +70,14 @@ stages: # key: 'v0 | "$(VCPKG_PACKAGES)" | vcpkg | "$(Agent.OS)"' # path: $(VCPKG_DIR) + # Boost is no longer pre-installed on Azure/GitHub-hosted Windows runners + - pwsh: | + Write-Host "Downloading Boost to $(BOOST_ROOT_WINDOWS)" + $Url = "https://sourceforge.net/projects/boost/files/boost-binaries/1.72.0/boost_1_72_0-msvc-14.2-64.exe" + C:\msys64\usr\bin\wget.exe -nv $Url -O "$(Pipeline.Workspace)/boost.exe" + Start-Process -Wait -FilePath "$(Pipeline.Workspace)/boost.exe" "/SILENT","/SP-","/SUPPRESSMSGBOXES","/DIR=$(BOOST_ROOT_WINDOWS)" + displayName: Download Boost + - pwsh: | git clone https://github.com/Microsoft/vcpkg.git $(VCPKG_DIR) cd $(VCPKG_DIR) @@ -121,9 +130,7 @@ stages: # Set envvars so that CMake can find the installed packages MKLROOT: $(MKL_DIR) CUDA_PATH: $(CUDA_PATH_WINDOWS)/v$(cuda_version) - # Boost is pre-installed on Azure/GitHub-hosted Windows runners - # https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md#boost - BOOST_ROOT: $(BOOST_ROOT_1_72_0) + BOOST_ROOT: $(BOOST_ROOT_WINDOWS) - script: | call "$(VS_PATH)/VC/Auxiliary/Build/vcvarsall.bat" x64 @@ -226,12 +233,18 @@ stages: - checkout: self submodules: true - # The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev - # No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev + # The following packages are already installed on Azure-hosted runners: build-essential openssl libssl-dev + # No need to install libprotobuf{17,10,9v5} on Ubuntu {20,18,16}.04 because it is installed together with libprotobuf-dev - bash: sudo apt-get install -y libgoogle-perftools-dev libprotobuf-dev protobuf-compiler displayName: Install packages - # https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html + # Boost is no longer pre-installed on Azure/GitHub-hosted runners + # TODO: check which Boost components are really needed and update the list + - bash: sudo apt-get install -y libboost-system-dev + displayName: Install Boost + condition: eq(variables.boost, true) + + # https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html - bash: | wget -qO- "https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB" | sudo apt-key add - sudo sh -c "echo deb https://apt.repos.intel.com/mkl all main > /etc/apt/sources.list.d/intel-mkl.list" @@ -240,13 +253,11 @@ stages: displayName: Install MKL condition: eq(variables.cpu, true) - # The script simplifies installation of different versions of CUDA + # The script simplifies installation of different versions of CUDA - bash: ./scripts/ci/install_cuda_ubuntu.sh $(cuda) displayName: Install CUDA condition: eq(variables.gpu, true) - # Boost is already installed on Azure-hosted runners in a non-standard location - # https://github.com/actions/virtual-environments/issues/687#issuecomment-610471671 - 
bash: | mkdir -p build cd build @@ -260,9 +271,6 @@ stages: -DUSE_FBGEMM=$(cpu) \ -DUSE_SENTENCEPIECE=on \ -DUSE_STATIC_LIBS=$(static) \ - -DBOOST_ROOT=$BOOST_ROOT_1_72_0 \ - -DBOOST_INCLUDEDIR=$BOOST_ROOT_1_72_0/include \ - -DBOOST_LIBRARYDIR=$BOOST_ROOT_1_72_0/lib \ -DBoost_ARCHITECTURE=-x64 \ -DCUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda-$(cuda) displayName: Configure CMake @@ -346,7 +354,7 @@ stages: - checkout: self submodules: true - - bash: brew install openblas protobuf + - bash: brew install boost openblas openssl protobuf displayName: Install packages # Openblas location is exported explicitly because openblas is keg-only, which means it was not symlinked into /usr/local/. From e08c52a8df4be5cf2d21cbb6e1eb3a5dcff0d0e2 Mon Sep 17 00:00:00 2001 From: Martin Junczys-Dowmunt Date: Thu, 18 Mar 2021 03:33:13 +0000 Subject: [PATCH 13/14] Merged PR 18185: Support for Microsoft legacy binary shortlist Adds support for Microsoft-internal binary shortlist format. --- src/CMakeLists.txt | 8 +- src/data/factored_vocab.cpp | 1 - src/data/shortlist.cpp | 153 +++++++ src/data/shortlist.h | 47 ++ .../shortlist/logging/LoggerMacros.h | 25 ++ src/microsoft/shortlist/utils/Converter.cpp | 59 +++ src/microsoft/shortlist/utils/Converter.h | 83 ++++ .../shortlist/utils/ParameterTree.cpp | 417 ++++++++++++++++++ src/microsoft/shortlist/utils/ParameterTree.h | 185 ++++++++ src/microsoft/shortlist/utils/PrintTypes.h | 16 + src/microsoft/shortlist/utils/StringUtils.cpp | 338 ++++++++++++++ src/microsoft/shortlist/utils/StringUtils.h | 98 ++++ src/translator/translator.h | 3 +- 13 files changed, 1429 insertions(+), 4 deletions(-) create mode 100644 src/data/shortlist.cpp create mode 100644 src/microsoft/shortlist/logging/LoggerMacros.h create mode 100644 src/microsoft/shortlist/utils/Converter.cpp create mode 100644 src/microsoft/shortlist/utils/Converter.h create mode 100644 src/microsoft/shortlist/utils/ParameterTree.cpp create mode 100644 src/microsoft/shortlist/utils/ParameterTree.h create mode 100644 src/microsoft/shortlist/utils/PrintTypes.h create mode 100644 src/microsoft/shortlist/utils/StringUtils.cpp create mode 100644 src/microsoft/shortlist/utils/StringUtils.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c59d8bf61..64b86a695 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -40,6 +40,7 @@ set(MARIAN_SOURCES data/corpus_sqlite.cpp data/corpus_nbest.cpp data/text_input.cpp + data/shortlist.cpp 3rd_party/cnpy/cnpy.cpp 3rd_party/ExceptionWithCallStack.cpp @@ -107,10 +108,15 @@ set(MARIAN_SOURCES training/validator.cpp training/communicator.cpp - # this is only compiled to catch build errors, but not linked + # this is only compiled to catch build errors microsoft/quicksand.cpp microsoft/cosmos.cpp + # copied from quicksand to be able to read binary shortlist + microsoft/shortlist/utils/Converter.cpp + microsoft/shortlist/utils/StringUtils.cpp + microsoft/shortlist/utils/ParameterTree.cpp + $ $ $ diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp index 818f37888..17a5bfb74 100644 --- a/src/data/factored_vocab.cpp +++ b/src/data/factored_vocab.cpp @@ -546,7 +546,6 @@ void FactoredVocab::constructNormalizationInfoForVocab() { /*virtual*/ void FactoredVocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const { for (; num-- > 0; ptr++) { auto word = Word::fromWordIndex(*ptr); - auto wordString = word2string(word); auto lemmaIndex = getFactor(word, 0) + groupRanges_[0].first; *ptr = (WordIndex)lemmaIndex; } diff --git a/src/data/shortlist.cpp 
b/src/data/shortlist.cpp new file mode 100644 index 000000000..6f551262d --- /dev/null +++ b/src/data/shortlist.cpp @@ -0,0 +1,153 @@ +#include "data/shortlist.h" +#include "microsoft/shortlist/utils/ParameterTree.h" + +namespace marian { +namespace data { + +// cast current void pointer to T pointer and move forward by num elements +template +const T* get(const void*& current, size_t num = 1) { + const T* ptr = (const T*)current; + current = (const T*)current + num; + return ptr; +} + +QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr options, + Ptr srcVocab, + Ptr trgVocab, + size_t srcIdx, + size_t /*trgIdx*/, + bool /*shared*/) + : options_(options), + srcVocab_(srcVocab), + trgVocab_(trgVocab), + srcIdx_(srcIdx) { + std::vector vals = options_->get>("shortlist"); + + ABORT_IF(vals.empty(), "No path to filter path given"); + std::string fname = vals[0]; + + auto firstNum = vals.size() > 1 ? std::stoi(vals[1]) : 0; + auto bestNum = vals.size() > 2 ? std::stoi(vals[2]) : 0; + float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0; + + if(firstNum != 0 || bestNum != 0 || threshold != 0) { + LOG(warn, "You have provided additional parameters for the Quicksand shortlist, but they are ignored."); + } + + mmap_ = mio::mmap_source(fname); // memory-map the binary file once + const void* current = mmap_.data(); // pointer iterator over binary file + + // compare magic number in binary file to make sure we are reading the right thing + const int32_t MAGIC_NUMBER = 1234567890; + int32_t header_magic_number = *get(current); + ABORT_IF(header_magic_number != MAGIC_NUMBER, "Trying to mmap Quicksand shortlist but encountered wrong magic number"); + + auto config = ::quicksand::ParameterTree::FromBinaryReader(current); + use16bit_ = config->GetBoolReq("use_16_bit"); + + LOG(info, "[data] Mapping Quicksand shortlist from {}", fname); + + idSize_ = sizeof(int32_t); + if (use16bit_) { + idSize_ = sizeof(uint16_t); + } + + // mmap the binary shortlist pieces + numDefaultIds_ = *get(current); + defaultIds_ = get(current, numDefaultIds_); + numSourceIds_ = *get(current); + sourceLengths_ = get(current, numSourceIds_); + sourceOffsets_ = get(current, numSourceIds_); + numShortlistIds_ = *get(current); + sourceToShortlistIds_ = get(current, idSize_ * numShortlistIds_); + + // display parameters + LOG(info, + "[data] Quicksand shortlist has {} source ids, {} default ids and {} shortlist ids", + numSourceIds_, + numDefaultIds_, + numShortlistIds_); +} + +Ptr QuicksandShortlistGenerator::generate(Ptr batch) const { + auto srcBatch = (*batch)[srcIdx_]; + auto maxShortlistSize = trgVocab_->size(); + + std::unordered_set indexSet; + for(int32_t i = 0; i < numDefaultIds_ && i < maxShortlistSize; ++i) { + int32_t id = defaultIds_[i]; + indexSet.insert(id); + } + + // State + std::vector> curShortlists(maxShortlistSize); + auto curShortlistIt = curShortlists.begin(); + + // Because we might fill up our shortlist before reaching max_shortlist_size, we fill the shortlist in order of rank. + // E.g., first rank of word 0, first rank of word 1, ... second rank of word 0, ... 
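// A worked example of the rank-ordered fill described above, with made-up ids: if the source
// batch contains three words whose mapped shortlists are
//   w0 -> [5, 9, 12],   w1 -> [7, 9],   w2 -> [5, 30, 41]
// then the collection loop below visits candidates as 5, 7, 5, 9, 9, 30, 12, 41 (rank 0 of
// every word first, then rank 1, and so on); duplicates collapse in indexSet, so even if
// maxShortlistSize is reached early, every source word has already contributed its
// highest-ranked candidates before any lower-ranked ones are taken.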
+ int32_t maxLength = 0; + for (Word word : srcBatch->data()) { + int32_t sourceId = (int32_t)word.toWordIndex(); + srcVocab_->transcodeToShortlistInPlace((WordIndex*)&sourceId, 1); + + if (sourceId < numSourceIds_) { // if it's a valid source id + const uint8_t* curShortlistIds = sourceToShortlistIds_ + idSize_ * sourceOffsets_[sourceId]; // start position for mapping + int32_t length = sourceLengths_[sourceId]; // how many mappings are there + curShortlistIt->first = curShortlistIds; + curShortlistIt->second = length; + curShortlistIt++; + + if (length > maxLength) + maxLength = length; + } + } + + // collect the actual shortlist mappings + for (int32_t i = 0; i < maxLength && indexSet.size() < maxShortlistSize; i++) { + for (int32_t j = 0; j < curShortlists.size() && indexSet.size() < maxShortlistSize; j++) { + int32_t length = curShortlists[j].second; + if (i < length) { + const uint8_t* source_shortlist_ids_bytes = curShortlists[j].first; + int32_t id = 0; + if (use16bit_) { + const uint16_t* source_shortlist_ids = reinterpret_cast(source_shortlist_ids_bytes); + id = (int32_t)source_shortlist_ids[i]; + } + else { + const int32_t* source_shortlist_ids = reinterpret_cast(source_shortlist_ids_bytes); + id = source_shortlist_ids[i]; + } + indexSet.insert(id); + } + } + } + + // turn into vector and sort (selected indices) + std::vector indices; + indices.reserve(indexSet.size()); + for(auto i : indexSet) + indices.push_back((WordIndex)i); + + std::sort(indices.begin(), indices.end()); + return New(indices); +} + +Ptr createShortlistGenerator(Ptr options, + Ptr srcVocab, + Ptr trgVocab, + size_t srcIdx, + size_t trgIdx, + bool shared) { + std::vector vals = options->get>("shortlist"); + ABORT_IF(vals.empty(), "No path to shortlist given"); + std::string fname = vals[0]; + if(filesystem::Path(fname).extension().string() == ".bin") { + return New(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); + } else { + return New(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); + } +} + +} // namespace data +} // namespace marian diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 395bcfee7..ab6a087b1 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -5,6 +5,7 @@ #include "common/file_stream.h" #include "data/corpus_base.h" #include "data/types.h" +#include "mio/mio.hpp" #include #include @@ -292,5 +293,51 @@ class FakeShortlistGenerator : public ShortlistGenerator { } }; +/* +Legacy binary shortlist for Microsoft-internal use. +*/ +class QuicksandShortlistGenerator : public ShortlistGenerator { +private: + Ptr options_; + Ptr srcVocab_; + Ptr trgVocab_; + + size_t srcIdx_; + + mio::mmap_source mmap_; + + // all the quicksand bits go here + bool use16bit_{false}; + int32_t numDefaultIds_; + int32_t idSize_; + const int32_t* defaultIds_{nullptr}; + int32_t numSourceIds_{0}; + const int32_t* sourceLengths_{nullptr}; + const int32_t* sourceOffsets_{nullptr}; + int32_t numShortlistIds_{0}; + const uint8_t* sourceToShortlistIds_{nullptr}; + +public: + QuicksandShortlistGenerator(Ptr options, + Ptr srcVocab, + Ptr trgVocab, + size_t srcIdx = 0, + size_t trgIdx = 1, + bool shared = false); + + virtual Ptr generate(Ptr batch) const override; +}; + +/* +Shortlist factory to create correct type of shortlist. Currently assumes everything is a text shortlist +unless the extension is *.bin for which the Microsoft legacy binary shortlist is used. 
+*/ +Ptr createShortlistGenerator(Ptr options, + Ptr srcVocab, + Ptr trgVocab, + size_t srcIdx = 0, + size_t trgIdx = 1, + bool shared = false); + } // namespace data } // namespace marian diff --git a/src/microsoft/shortlist/logging/LoggerMacros.h b/src/microsoft/shortlist/logging/LoggerMacros.h new file mode 100644 index 000000000..ca74e737e --- /dev/null +++ b/src/microsoft/shortlist/logging/LoggerMacros.h @@ -0,0 +1,25 @@ +#pragma once + +// Do NOT include this file directly except in special circumstances. +// (E.g., you want to define macros which call these but don't want to include Logger.h everywhere). +// Normally you should include logging/Logger.h + +#define LOG_WRITE(format, ...) do {\ + abort(); \ +} while (0) + +#define LOG_WRITE_STRING(str) do {\ + abort(); \ +} while (0) + +#define LOG_ERROR(format, ...) do {\ + abort(); \ +} while (0) + +#define LOG_ERROR_AND_THROW(format, ...) do {\ + abort(); \ +} while (0) + +#define DECODING_LOGIC_ERROR(format, ...) do {\ + abort(); \ +} while (0) diff --git a/src/microsoft/shortlist/utils/Converter.cpp b/src/microsoft/shortlist/utils/Converter.cpp new file mode 100644 index 000000000..c28178cd6 --- /dev/null +++ b/src/microsoft/shortlist/utils/Converter.cpp @@ -0,0 +1,59 @@ +#include "microsoft/shortlist/utils/Converter.h" + +namespace quicksand { + +#include "microsoft/shortlist/logging/LoggerMacros.h" + + +int64_t Converter::ToInt64(const std::string& str) { + return ConvertSingleInternal(str, "int64_t"); +} + +uint64_t Converter::ToUInt64(const std::string& str) { + return ConvertSingleInternal(str, "int64_t"); +} + +int32_t Converter::ToInt32(const std::string& str) { + return ConvertSingleInternal(str, "int32_t"); +} + +float Converter::ToFloat(const std::string& str) { + // In case the value is out of range of a 32-bit float, but in range of a 64-bit double, + // it's better to convert as a double and then do the conersion. 
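// Concretely (illustrative numbers): the largest finite 32-bit float is roughly 3.4028e38, so a
// string such as "3.5e38" fails stream extraction directly into a float and would be reported
// as a conversion error, yet it parses fine as a 64-bit double; parsing via double first
// therefore succeeds, and the cast below narrows the parsed value to the float that is returned.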
+ return (float)ConvertSingleInternal(str, "float"); +} + +double Converter::ToDouble(const std::string& str) { + return ConvertSingleInternal(str, "double"); +} + +bool Converter::ToBool(const std::string& str) { + bool value = false; + if (!TryConvert(str, /* out */ value)) { + LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type 'bool'", str.c_str()); + } + return value; +} + +std::vector Converter::ToInt32Vector(const std::vector& items) { + return ConvertVectorInternal::const_iterator>(items.begin(), items.end(), "int32_t"); +} + +std::vector Converter::ToInt64Vector(const std::vector& items) { + return ConvertVectorInternal::const_iterator>(items.begin(), items.end(), "int64_t"); +} + +std::vector Converter::ToFloatVector(const std::vector& items) { + return ConvertVectorInternal::const_iterator>(items.begin(), items.end(), "float"); +} + +std::vector Converter::ToDoubleVector(const std::vector& items) { + return ConvertVectorInternal::const_iterator>(items.begin(), items.end(), "double"); +} + +void Converter::HandleConversionError(const std::string& str, const char * type_name) { + str; type_name; // make compiler happy + LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type '%s'", str.c_str(), type_name); +} + +} // namespace quicksand diff --git a/src/microsoft/shortlist/utils/Converter.h b/src/microsoft/shortlist/utils/Converter.h new file mode 100644 index 000000000..9d9dd96d6 --- /dev/null +++ b/src/microsoft/shortlist/utils/Converter.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include +#include + +namespace quicksand { + +class Converter { +public: + static int32_t ToInt32(const std::string& str); + + static int64_t ToInt64(const std::string& str); + + static uint64_t ToUInt64(const std::string& str); + + static float ToFloat(const std::string& str); + + static double ToDouble(const std::string& str); + + static bool ToBool(const std::string& str); + + static std::vector ToInt32Vector(const std::vector& items); + + static std::vector ToInt64Vector(const std::vector& items); + + static std::vector ToFloatVector(const std::vector& items); + + static std::vector ToDoubleVector(const std::vector& items); + + static bool TryConvert(const std::string& str, /* out*/ bool& obj) { + if (str == "True" || str == "true" || str == "TRUE" || str == "Yes" || str == "yes" || str == "1") { + obj = true; + return true; + } + else if (str == "False" || str == "false" || str == "FALSE" || str == "No" || str == "no" || str == "0") { + obj = false; + return true; + } + return false; + } + + template + static bool TryConvert(const std::string& str, /* out*/ T& value) { + std::istringstream ss(str); + value = T(); + if (!(ss >> value)) { + return false; + } + return true; + } + +private: + template + static T ConvertSingleInternal(const std::string& str, const char * type_name); + + template + static std::vector ConvertVectorInternal(I begin, I end, const char * type_name); + + static void HandleConversionError(const std::string& str, const char * type_name); +}; + +template +T Converter::ConvertSingleInternal(const std::string& str, const char * type_name) { + std::istringstream ss(str); + T value = T(); + if (!(ss >> value)) { + HandleConversionError(str, type_name); + } + return value; +} + +template +std::vector Converter::ConvertVectorInternal(I begin, I end, const char * type_name) { + std::vector items; + for (I it = begin; it != end; it++) { + items.push_back(ConvertSingleInternal(*it, type_name)); + } + return items; +} + +} // namespace 
quicksand diff --git a/src/microsoft/shortlist/utils/ParameterTree.cpp b/src/microsoft/shortlist/utils/ParameterTree.cpp new file mode 100644 index 000000000..465d2e0db --- /dev/null +++ b/src/microsoft/shortlist/utils/ParameterTree.cpp @@ -0,0 +1,417 @@ +#include "microsoft/shortlist/utils/ParameterTree.h" + +#include + +#include "microsoft/shortlist/utils/StringUtils.h" +#include "microsoft/shortlist/utils/Converter.h" + +namespace quicksand { + +#include "microsoft/shortlist/logging/LoggerMacros.h" + +std::shared_ptr ParameterTree::m_empty_tree = std::make_shared("params"); + +ParameterTree::ParameterTree() { + m_name = "root"; +} + +ParameterTree::ParameterTree(const std::string& name) { + m_name = name; +} + +ParameterTree::~ParameterTree() { +} + +void ParameterTree::Clear() { + +} + +void ParameterTree::ReplaceVariables( + const std::unordered_map& vars, + bool error_on_unknown_vars) +{ + ReplaceVariablesInternal(vars, error_on_unknown_vars); +} + +void ParameterTree::RegisterInt32(const std::string& name, int32_t * param) { + RegisterItemInternal(name, PARAM_TYPE_INT32, (void *)param); +} + +void ParameterTree::RegisterInt64(const std::string& name, int64_t * param) { + RegisterItemInternal(name, PARAM_TYPE_INT64, (void *)param); +} + +void ParameterTree::RegisterFloat(const std::string& name, float * param) { + RegisterItemInternal(name, PARAM_TYPE_FLOAT, (void *)param); +} + +void ParameterTree::RegisterDouble(const std::string& name, double * param) { + RegisterItemInternal(name, PARAM_TYPE_DOUBLE, (void *)param); +} + +void ParameterTree::RegisterBool(const std::string& name, bool * param) { + RegisterItemInternal(name, PARAM_TYPE_BOOL, (void *)param); +} + +void ParameterTree::RegisterString(const std::string& name, std::string * param) { + RegisterItemInternal(name, PARAM_TYPE_STRING, (void *)param); +} + +std::shared_ptr ParameterTree::FromBinaryReader(const void*& current) { + std::shared_ptr root = std::make_shared(); + root->ReadBinary(current); + return root; +} + +void ParameterTree::SetRegisteredParams() { + for (std::size_t i = 0; i < m_registered_params.size(); i++) { + const RegisteredParam& rp = m_registered_params[i]; + switch (rp.Type()) { + case PARAM_TYPE_INT32: + (*(int32_t *)rp.Data()) = GetInt32Req(rp.Name()); + break; + case PARAM_TYPE_INT64: + (*(int64_t *)rp.Data()) = GetInt64Req(rp.Name()); + break; + default: + LOG_ERROR_AND_THROW("Unknown ParameterType: %d", (int)rp.Type()); + } + } +} + +int32_t ParameterTree::GetInt32Or(const std::string& name, int32_t defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToInt32(*value); +} + +int64_t ParameterTree::GetInt64Or(const std::string& name, int64_t defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToInt64(*value); +} + +uint64_t ParameterTree::GetUInt64Or(const std::string& name, uint64_t defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToUInt64(*value); +} + +double ParameterTree::GetDoubleOr(const std::string& name, double defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToDouble(*value); +} + +float ParameterTree::GetFloatOr(const std::string& name, float defaultValue) const { + const std::string * value = 
GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToFloat(*value); +} + +std::string ParameterTree::GetStringOr(const std::string& name, const std::string& defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return (*value); +} + +bool ParameterTree::GetBoolOr(const std::string& name, bool defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToBool(*value); +} + +int32_t ParameterTree::GetInt32Req(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToInt32(value); +} + +uint64_t ParameterTree::GetUInt64Req(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToUInt64(value); +} + +int64_t ParameterTree::GetInt64Req(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToInt64(value); +} + +double ParameterTree::GetDoubleReq(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToDouble(value); +} + +float ParameterTree::GetFloatReq(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToFloat(value); +} + +bool ParameterTree::GetBoolReq(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToBool(value); +} + +std::string ParameterTree::GetStringReq(const std::string& name) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + LOG_ERROR_AND_THROW("Required parameter <%s> not found in ParameterTree:\n%s", name.c_str(), ToString().c_str()); + } + return (*value); +} + +std::vector ParameterTree::GetFileListReq(const std::string& name) const { + std::vector output = GetFileListOptional(name); + if (output.size() == 0) { + LOG_ERROR_AND_THROW("No files were found for parameter: %s", name.c_str()); + } + return output; +} + +std::vector ParameterTree::GetFileListOptional(const std::string& name) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr || (*value).size() == 0) { + return std::vector(); + } + std::vector all_files = StringUtils::Split(*value, ";"); + return all_files; +} + +std::vector ParameterTree::GetStringListReq(const std::string& name, const std::string& sep) const { + std::string value = GetStringReq(name); + std::vector output = StringUtils::Split(value, sep); + return output; +} + +std::vector ParameterTree::GetStringListOptional(const std::string& name, const std::string& sep) const { + std::string value = GetStringOr(name, ""); + std::vector output = StringUtils::Split(value, sep); + return output; +} + +std::shared_ptr ParameterTree::GetChildReq(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return child; + } + } + LOG_ERROR_AND_THROW("Unable to find child ParameterTree with name '%s'", name.c_str()); + return nullptr; // never happens +} + + +std::shared_ptr ParameterTree::GetChildOrEmpty(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return child; + } + } + return std::make_shared(); +} + +// cast current void pointer to T pointer and move forward by num elements +template +const T* get(const void*& current, size_t num = 1) { + const T* ptr = (const T*)current; + current = (const T*)current + num; + return ptr; +} + +void 
ParameterTree::ReadBinary(const void*& current) { + auto nameLength = *get(current); + auto nameBytes = get(current, nameLength); + m_name = std::string(nameBytes, nameBytes + nameLength); + + auto textLength = *get(current); + auto textBytes = get(current, textLength); + m_text = std::string(textBytes, textBytes + textLength); + + int32_t num_children = *get(current); + m_children.resize(num_children); + for (int32_t i = 0; i < num_children; i++) { + m_children[i].reset(new ParameterTree()); + m_children[i]->ReadBinary(current); + } +} + +std::vector< std::shared_ptr > ParameterTree::GetChildren(const std::string& name) const { + std::vector< std::shared_ptr > children; + for (std::shared_ptr child : m_children) { + if (child->Name() == name) { + children.push_back(child); + } + } + return children; +} + +void ParameterTree::AddParam(const std::string& name, const std::string& text) { + std::shared_ptr child = std::make_shared(name); + child->SetText(text); + m_children.push_back(child); +} + +void ParameterTree::SetParam(const std::string& name, const std::string& text) { + for (const auto& child : m_children) { + if (child->Name() == name) { + child->SetText(text); + return; + } + } + std::shared_ptr child = std::make_shared(name); + child->SetText(text); + m_children.push_back(child); +} + +void ParameterTree::AddChild(std::shared_ptr child) { + m_children.push_back(child); +} + +bool ParameterTree::HasParam(const std::string& name) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return false; + } + return true; +} + +bool ParameterTree::HasChild(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return true; + } + } + return false; +} + +std::string ParameterTree::ToString() const { + std::ostringstream ss; + ToStringInternal(0, ss); + return ss.str(); +} + +const std::string * ParameterTree::GetParamInternal(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return &(child->Text()); + } + } + return nullptr; +} + + +void ParameterTree::RegisterItemInternal(const std::string& name, ParameterType type, void * param) { + if (m_registered_param_names.find(name) != m_registered_param_names.end()) { + LOG_ERROR_AND_THROW("Unable to register duplicate parameter name: '%s'", name.c_str()); + } + m_registered_params.push_back(RegisteredParam(name, type, param)); + m_registered_param_names.insert(name); +} + +void ParameterTree::ToStringInternal(int32_t depth, std::ostream& ss) const { + for (int32_t i = 0; i < 2*depth; i++) { + ss << " "; + } + ss << "<" << m_name << ">"; + if (m_children.size() > 0) { + ss << "\n"; + for (const std::shared_ptr& child : m_children) { + child->ToStringInternal(depth+1, ss); + } + for (int32_t i = 0; i < 2 * depth; i++) { + ss << " "; + } + ss << "\n"; + } + else { + ss << m_text << "\n"; + } +} + +std::shared_ptr ParameterTree::Clone() const { + std::shared_ptr node = std::make_shared(m_name); + node->m_text = m_text; + for (auto& child : m_children) { + node->m_children.push_back(child->Clone()); + } + return node; +} + +void ParameterTree::Merge(const ParameterTree& other) { + m_name = other.m_name; + m_text = other.m_text; + for (auto& other_child : other.m_children) { + if (HasChild(other_child->Name())) { + auto my_child = GetChildReq(other_child->Name()); + if (other_child->Text() != "" && my_child->Text() != "") { + my_child->SetText(other_child->Text()); + } + else { + my_child->Merge(*other_child); + 
} + } + else { + m_children.push_back(other_child->Clone()); + } + } +} + +void ParameterTree::ReplaceVariablesInternal( + const std::unordered_map& vars, + bool error_on_unknown_vars) +{ + std::size_t offset = 0; + std::ostringstream ss; + while (true) { + std::size_t s_pos = m_text.find("$$", offset); + if (s_pos == std::string::npos) { + break; + } + std::size_t e_pos = m_text.find("$$", s_pos + 2); + if (e_pos == std::string::npos) { + break; + } + + if (offset != s_pos) { + ss << m_text.substr(offset, s_pos-offset); + } + + std::string var_name = m_text.substr(s_pos+2, e_pos - (s_pos+2)); + auto it = vars.find(var_name); + if (it != vars.end()) { + std::string value = it->second; + ss << value; + } + else { + if (error_on_unknown_vars) { + LOG_ERROR_AND_THROW("The variable $$%s$$ was not found", var_name.c_str()); + } + else { + ss << "$$" << var_name << "$$"; + } + } + offset = e_pos + 2; + } + ss << m_text.substr(offset); + + m_text = ss.str(); + + for (auto& child : m_children) { + child->ReplaceVariablesInternal(vars, error_on_unknown_vars); + } +} + +} // namespace quicksand + diff --git a/src/microsoft/shortlist/utils/ParameterTree.h b/src/microsoft/shortlist/utils/ParameterTree.h new file mode 100644 index 000000000..1474ff645 --- /dev/null +++ b/src/microsoft/shortlist/utils/ParameterTree.h @@ -0,0 +1,185 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "microsoft/shortlist/utils/StringUtils.h" + +namespace quicksand { + +class ParameterTree { +private: + enum ParameterType { + PARAM_TYPE_INT32, + PARAM_TYPE_INT64, + PARAM_TYPE_UINT64, + PARAM_TYPE_FLOAT, + PARAM_TYPE_DOUBLE, + PARAM_TYPE_BOOL, + PARAM_TYPE_STRING + }; + + class RegisteredParam { + private: + std::string m_name; + ParameterType m_type; + void * m_data; + + public: + RegisteredParam() {} + + RegisteredParam(const std::string& name, + ParameterType type, + void * data) + { + m_name = name; + m_type = type; + m_data = data; + } + + const std::string& Name() const {return m_name;} + const ParameterType& Type() const {return m_type;} + void * Data() const {return m_data;} + }; + + static std::shared_ptr m_empty_tree; + + std::string m_name; + + std::string m_text; + + std::vector< std::shared_ptr > m_children; + + std::unordered_set m_registered_param_names; + + std::vector m_registered_params; + +public: + ParameterTree(); + + ParameterTree(const std::string& name); + + ~ParameterTree(); + + inline const std::string& Text() const { return m_text; } + inline void SetText(const std::string& text) { m_text = text; } + + inline const std::string& Name() const { return m_name; } + inline void SetName(const std::string& name) { m_name = name; } + + void Clear(); + + void ReplaceVariables( + const std::unordered_map& vars, + bool error_on_unknown_vars = true); + + void RegisterInt32(const std::string& name, int32_t * param); + + void RegisterInt64(const std::string& name, int64_t * param); + + void RegisterFloat(const std::string& name, float * param); + + void RegisterDouble(const std::string& name, double * param); + + void RegisterBool(const std::string& name, bool * param); + + void RegisterString(const std::string& name, std::string * param); + + static std::shared_ptr FromBinaryReader(const void*& current); + + void SetRegisteredParams(); + + int32_t GetInt32Req(const std::string& name) const; + + int64_t GetInt64Req(const std::string& name) const; + + uint64_t GetUInt64Req(const std::string& name) const; + + double GetDoubleReq(const std::string& name) const; + + float 
GetFloatReq(const std::string& name) const; + + std::string GetStringReq(const std::string& name) const; + + bool GetBoolReq(const std::string& name) const; + + int32_t GetInt32Or(const std::string& name, int32_t defaultValue) const; + + int64_t GetInt64Or(const std::string& name, int64_t defaultValue) const; + + uint64_t GetUInt64Or(const std::string& name, uint64_t defaultValue) const; + + std::string GetStringOr(const std::string& name, const std::string& defaultValue) const; + + double GetDoubleOr(const std::string& name, double defaultValue) const; + + float GetFloatOr(const std::string& name, float defaultValue) const; + + bool GetBoolOr(const std::string& name, bool defaultValue) const; + + std::vector GetFileListReq(const std::string& name) const; + + std::vector GetFileListOptional(const std::string& name) const; + + std::vector GetStringListReq(const std::string& name, const std::string& sep = " ") const; + + std::vector GetStringListOptional(const std::string& name, const std::string& sep = " ") const; + + std::shared_ptr GetChildReq(const std::string& name) const; + + std::shared_ptr GetChildOrEmpty(const std::string& name) const; + + std::vector< std::shared_ptr > GetChildren(const std::string& name) const; + + inline const std::vector< std::shared_ptr >& GetChildren() const { return m_children; } + + void ReadBinary(const void*& current); + + void AddParam(const std::string& name, const std::string& text); + + template + void AddParam(const std::string& name, const T& obj); + + void SetParam(const std::string& name, const std::string& text); + + template + void SetParam(const std::string& name, const T& obj); + + void AddChild(std::shared_ptr child); + + std::string ToString() const; + + bool HasChild(const std::string& name) const; + + bool HasParam(const std::string& name) const; + + std::shared_ptr Clone() const; + + void Merge(const ParameterTree& other); + +private: + void ReplaceVariablesInternal( + const std::unordered_map& vars, + bool error_on_unknown_vars); + + void RegisterItemInternal(const std::string& name, ParameterType type, void * param); + + const std::string * GetParamInternal(const std::string& name) const; + + void ToStringInternal(int32_t depth, std::ostream& ss) const; +}; + +template +void ParameterTree::AddParam(const std::string& name, const T& obj) { + AddParam(name, StringUtils::ToString(obj)); +} + +template +void ParameterTree::SetParam(const std::string& name, const T& obj) { + SetParam(name, StringUtils::ToString(obj)); +} + +} // namespace quicksand diff --git a/src/microsoft/shortlist/utils/PrintTypes.h b/src/microsoft/shortlist/utils/PrintTypes.h new file mode 100644 index 000000000..6bc1363d2 --- /dev/null +++ b/src/microsoft/shortlist/utils/PrintTypes.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#ifdef QUICKSAND_WINDOWS_BUILD +#define PI32 "d" +#define PI64 "lld" +#define PU32 "u" +#define PU64 "llu" +#else +#define PI32 PRId32 +#define PI64 PRId64 +#define PU32 PRIu32 +#define PU64 PRIu64 +#endif + diff --git a/src/microsoft/shortlist/utils/StringUtils.cpp b/src/microsoft/shortlist/utils/StringUtils.cpp new file mode 100644 index 000000000..7870b5422 --- /dev/null +++ b/src/microsoft/shortlist/utils/StringUtils.cpp @@ -0,0 +1,338 @@ +#include "microsoft/shortlist/utils/StringUtils.h" + +#include +#include +#include + +namespace quicksand { + +#include "microsoft/shortlist/logging/LoggerMacros.h" + +std::string StringUtils::VarArgsToString(const char * format, va_list args) { + if (format == nullptr) { + LOG_ERROR_AND_THROW("'format' 
cannot be null in StringUtils::VarArgsToString"); + } + + std::string output; + // Most of the time the stack buffer (5000 chars) will be sufficient. + // In cases where this is insufficient, dynamically allocate an appropriately sized buffer + char buffer[5000]; +#ifdef QUICKSAND_WINDOWS_BUILD + va_list copy; + va_copy(copy, args); + int ret = vsnprintf_s(buffer, sizeof(buffer), _TRUNCATE, format, copy); + va_end(copy); + if (ret >= 0) { + output = std::string(buffer, buffer + ret); + } + else { + va_list copy2; + va_copy(copy2, args); + int needed_size = _vscprintf(format, copy2); + va_end(copy2); + + if (needed_size < 0) { + LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen"); + } + char * dynamic_buffer = new char[needed_size+1]; + int ret2 = vsnprintf_s(dynamic_buffer, needed_size+1, _TRUNCATE, format, args); + if (ret2 >= 0) { + output = std::string(dynamic_buffer, dynamic_buffer + ret2); + delete[] dynamic_buffer; + } + else { + output = ""; + delete[] dynamic_buffer; + LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen, " + "since we made a call to _vscprintf() to check the dynamic buffer size. The call to _vscprintf() " + "returned %d bytes, but apparently that was not enough. This would imply a bug in MSVC's vsnprintf_s implementation.", needed_size); + } + } +#else + va_list copy; + va_copy(copy, args); + int needed_size = vsnprintf(buffer, sizeof(buffer), format, copy); + va_end(copy); + if (needed_size < (int)sizeof(buffer)) { + output = std::string(buffer, buffer + needed_size); + } + else { + char * dynamic_buffer = new char[needed_size+1]; + int ret = vsnprintf(dynamic_buffer, needed_size + 1, format, args); + if (ret >= 0 && ret < needed_size + 1) { + output = std::string(dynamic_buffer); + delete[] dynamic_buffer; + } + else { + output = ""; + delete[] dynamic_buffer; + LOG_ERROR_AND_THROW("A call to vsnprintf() failed. 
Return value: %d.", + ret); + } + } +#endif + return output; +} + +std::vector StringUtils::SplitIntoLines(const std::string& input) { + std::vector output; + if (input.size() == 0) { + return output; + } + std::size_t start = 0; + for (std::size_t i = 0; i < input.size(); i++) { + char c = input[i]; + if (c == '\r' || c == '\n') { + output.push_back(std::string(input.begin() + start, input.begin() + i)); + start = i+1; + } + if (c == '\r' && i + 1 < input.size() && input[i+1] == '\n') { + i++; + start = i+1; + } + } + // do NOT put an empty length trailing line (but empty length intermediate lines are fine) + if (input.begin() + start != input.end()) { + output.push_back(std::string(input.begin() + start, input.end())); + } + return output; +} + +bool StringUtils::StartsWith(const std::string& str, const std::string& prefix) { + if (str.length() < prefix.length()) + return false; + + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + +bool StringUtils::EndsWith(const std::string& str, const std::string& suffix) { + if (str.length() < suffix.length()) + return false; + + return std::equal(suffix.begin(), suffix.end(), str.end() - suffix.length()); +} + +std::vector StringUtils::SplitFileList(const std::string& input) { + std::vector output; + for (const std::string& s : SplitIntoLines(input)) { + for (const std::string& t : Split(s, ";")) { + std::string f = CleanupWhitespace(t); + output.push_back(f); + } + } + return output; +} + +std::vector StringUtils::Split(const std::string& input, char splitter) { + std::vector output; + if (input.size() == 0) { + return output; + } + std::size_t start = 0; + for (std::size_t i = 0; i < input.size(); i++) { + if (input[i] == splitter) { + output.push_back(std::string(input.begin() + start, input.begin() + i)); + start = i+1; + } + } + output.push_back(std::string(input.begin() + start, input.end())); + return output; +} + +std::vector StringUtils::Split(const std::string& input, const std::string& splitter) { + std::vector output; + if (input.size() == 0) { + return output; + } + std::size_t pos = 0; + while (true) { + std::size_t next_pos = input.find(splitter, pos); + if (next_pos == std::string::npos) { + output.push_back(std::string(input.begin() + pos, input.end())); + break; + } + else { + output.push_back(std::string(input.begin() + pos, input.begin() + next_pos)); + } + pos = next_pos + splitter.size(); + } + return output; +} + +std::string StringUtils::Join(const std::string& joiner, const uint8_t * items, int32_t length) { + std::ostringstream ss; + for (int32_t i = 0; i < length; i++) { + if (i != 0) { + ss << joiner; + } + ss << (int32_t)(items[i]); + } + return ss.str(); +} + +std::string StringUtils::Join(const std::string& joiner, const int8_t * items, int32_t length) { + std::ostringstream ss; + for (int32_t i = 0; i < length; i++) { + if (i != 0) { + ss << joiner; + } + ss << (int32_t)(items[i]); + } + return ss.str(); +} + +std::string StringUtils::PrintString(const char * format, ...) 
{
+    va_list args;
+    va_start(args, format);
+    std::string output = StringUtils::VarArgsToString(format, args);
+    va_end(args);
+
+    return output;
+}
+
+std::vector<std::string> StringUtils::WhitespaceTokenize(const std::string& input) {
+    std::vector<std::string> output;
+    if (input.size() == 0) {
+        return output;
+    }
+    std::size_t size = input.size();
+    std::size_t start = 0;
+    std::size_t end = size;
+    for (std::size_t i = 0; i < size; i++) {
+        char c = input[i];
+        if (IsWhitespace(c)) {
+            start++;
+        }
+        else {
+            break;
+        }
+    }
+    for (std::size_t i = 0; i < size; i++) {
+        char c = input[size-1-i];
+        if (IsWhitespace(c)) {
+            end--;
+        }
+        else {
+            break;
+        }
+    }
+    if (end <= start) {
+        return output;
+    }
+    bool prev_is_whitespace = false;
+    std::size_t token_start = start;
+    for (std::size_t i = start; i < end; i++) {
+        char c = input[i];
+        if (IsWhitespace(c)) {
+            if (!prev_is_whitespace) {
+                output.push_back(std::string(input.begin() + token_start, input.begin() + i));
+            }
+            prev_is_whitespace = true;
+            token_start = i+1;
+        }
+        else {
+            prev_is_whitespace = false;
+        }
+    }
+    output.push_back(std::string(input.begin() + token_start, input.begin() + end));
+    return output;
+}
+
+std::string StringUtils::CleanupWhitespace(const std::string& input) {
+    if (input.size() == 0) {
+        return std::string("");
+    }
+    std::size_t size = input.size();
+    std::size_t start = 0;
+    std::size_t end = size;
+    for (std::size_t i = 0; i < size; i++) {
+        char c = input[i];
+        if (IsWhitespace(c)) {
+            start++;
+        }
+        else {
+            break;
+        }
+    }
+    for (std::size_t i = 0; i < size; i++) {
+        char c = input[size-1-i];
+        if (IsWhitespace(c)) {
+            end--;
+        }
+        else {
+            break;
+        }
+    }
+    if (end <= start) {
+        return std::string("");
+    }
+    std::ostringstream ss;
+    bool prev_is_whitespace = false;
+    for (std::size_t i = start; i < end; i++) {
+        char c = input[i];
+        if (IsWhitespace(c)) {
+            if (!prev_is_whitespace) {
+                ss << ' ';
+            }
+            prev_is_whitespace = true;
+        }
+        else {
+            ss << c;
+            prev_is_whitespace = false;
+        }
+    }
+    return ss.str();
+}
+
+std::string StringUtils::XmlEscape(const std::string& str) {
+    std::ostringstream ss;
+    for (std::size_t i = 0; i < str.size(); i++) {
+        char c = str[i];
+        if (c == '&') {
+            ss << "&amp;";
+        }
+        else if (c == '"') {
+            ss << "&quot;";
+        }
+        else if (c == '\'') {
+            ss << "&apos;";
+        }
+        else if (c == '<') {
+            ss << "&lt;";
+        }
+        else if (c == '>') {
+            ss << "&gt;";
+        }
+        else {
+            ss << c;
+        }
+    }
+    return ss.str();
+}
+
+std::string StringUtils::ToString(const std::string& str) {
+    return str;
+}
+
+std::string StringUtils::ToString(bool obj) {
+    return (obj)?"true":"false";
+}
+
+std::string StringUtils::ToUpper(const std::string& str) {
+    std::vector<char> output;
+    output.reserve(str.size());
+    for (char c : str) {
+        output.push_back((char)toupper((int)c));
+    }
+    return std::string(output.begin(), output.end());
+}
+
+std::string StringUtils::ToLower(const std::string& str) {
+    std::ostringstream ss;
+    for (char c : str) {
+        ss << (char)tolower((int)c);
+    }
+    return ss.str();
+}
+
+} // namespace quicksand
diff --git a/src/microsoft/shortlist/utils/StringUtils.h b/src/microsoft/shortlist/utils/StringUtils.h
new file mode 100644
index 000000000..31bb1fcc0
--- /dev/null
+++ b/src/microsoft/shortlist/utils/StringUtils.h
@@ -0,0 +1,98 @@
+#pragma once
+
+#include <stdarg.h>
+#include <stdint.h>
+#include <string>
+#include <sstream>
+#include <vector>
+
+#include "microsoft/shortlist/utils/PrintTypes.h"
+
+namespace quicksand {
+
+class StringUtils {
+public:
+    template <typename T>
+    static std::string Join(const std::string& joiner, const T& items);
+
+    template <typename T>
+    static std::string Join(const std::string& joiner, const T * items, int32_t length);
+
+    static std::string Join(const std::string& joiner, const uint8_t * items, int32_t length);
+
+    static std::string Join(const std::string& joiner, const int8_t * items, int32_t length);
+
+    static std::vector<std::string> Split(const std::string& input, char splitter);
+
+    static std::vector<std::string> Split(const std::string& input, const std::string& splitter);
+
+    static std::vector<std::string> SplitFileList(const std::string& input);
+
+    static std::string PrintString(const char * format, ...);
+
+    static std::string VarArgsToString(const char * format, va_list args);
+
+    static std::vector<std::string> WhitespaceTokenize(const std::string& input);
+
+    static std::string CleanupWhitespace(const std::string& input);
+
+    static std::string ToString(const std::string& str);
+
+    static std::string ToString(bool obj);
+
+    template <typename T>
+    static std::string ToString(const T& obj);
+
+    static std::string XmlEscape(const std::string& str);
+
+    static std::vector<std::string> SplitIntoLines(const std::string& input);
+
+    static bool StartsWith(const std::string& str, const std::string& prefix);
+
+    static bool EndsWith(const std::string& str, const std::string& suffix);
+
+    inline static bool IsWhitespace(char c) {
+        return (c == ' ' || c == '\t' || c == '\n' || c == '\r');
+    }
+
+    // This should only be used for ASCII, e.g., filenames, NOT for language data
+    static std::string ToLower(const std::string& str);
+
+    // This should only be used for ASCII, e.g., filenames, NOT for language data
+    static std::string ToUpper(const std::string& str);
+};
+
+template <typename T>
+std::string StringUtils::Join(const std::string& joiner, const T& items) {
+    std::ostringstream ss;
+    bool first = true;
+    for (auto it = items.begin(); it != items.end(); it++) {
+        if (!first) {
+            ss << joiner;
+        }
+        ss << (*it);
+        first = false;
+    }
+    return ss.str();
+}
+
+template <typename T>
+std::string StringUtils::Join(const std::string& joiner, const T * items, int32_t length) {
+    std::ostringstream ss;
+    for (int32_t i = 0; i < length; i++) {
+        if (i != 0) {
+            ss << joiner;
+        }
+        ss << items[i];
+    }
+    return ss.str();
+}
+
+template <typename T>
+std::string StringUtils::ToString(const T& obj) {
+    std::ostringstream ss;
+    ss << obj;
+    return ss.str();
+}
+
+} // namespace quicksand
diff --git a/src/translator/translator.h b/src/translator/translator.h
index 1ff19a4ae..82d9343d5 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -60,8 +60,7 @@ class Translate : public ModelTask {
     auto srcVocab = corpus_->getVocabs()[0];
 
     if(options_->hasAndNotEmpty("shortlist"))
-      shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
-          options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
+      shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
 
     auto devices = Config::getDevices(options_);
     numDevices_ = devices.size();

From 8f73923d3134f4799497b7e880963336b8fe4d6b Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Thu, 18 Mar 2021 03:34:44 +0000
Subject: [PATCH 14/14] increase version and update changelog

---
 CHANGELOG.md | 1 +
 VERSION      | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6c0514b89..c3c429204 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## [Unreleased]
 
 ### Added
+- Support for MS-internal binary shortlist
 - Local/global sharding with MPI training via `--sharding local`
 - fp16 support for factors.
 - Correct training with fp16 via `--fp16`.
diff --git a/VERSION b/VERSION
index 460c8b939..d9c8fd407 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v1.10.2
+v1.10.3
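
Usage note (not part of the patch series above): the sketch below illustrates how the quicksand::StringUtils helpers introduced by these patches might be called. It assumes the new header is reachable under the include path used in the patch; the file name demo.cpp and the sample strings are hypothetical.

// demo.cpp -- minimal, illustrative use of the StringUtils helpers (assumed build setup)
#include <iostream>
#include <string>
#include <vector>

#include "microsoft/shortlist/utils/StringUtils.h"

int main() {
  using quicksand::StringUtils;

  // Trim leading/trailing whitespace and collapse interior runs to single spaces.
  std::string raw = "  hello \t world  \n";
  std::string clean = StringUtils::CleanupWhitespace(raw);      // "hello world"

  // Split on whitespace, ignoring leading/trailing blanks.
  std::vector<std::string> tokens = StringUtils::WhitespaceTokenize(raw);

  // Re-join the tokens with a custom separator.
  std::cout << StringUtils::Join(" | ", tokens) << std::endl;   // "hello | world"

  // Escape characters that are special in XML markup.
  std::cout << StringUtils::XmlEscape("a < b && c > d") << std::endl;

  // ASCII-only case conversion (per the header comment, not for language data).
  std::cout << StringUtils::ToUpper(clean) << std::endl;        // "HELLO WORLD"
  return 0;
}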