diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b47663b4e..e71a7b711 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -72,6 +72,7 @@ set(MARIAN_SOURCES
   layers/loss.cpp
   layers/weight.cpp
   layers/lsh.cpp
+  layers/embedding.cpp
 
   rnn/cells.cpp
   rnn/attention.cpp
diff --git a/src/layers/constructors.h b/src/layers/constructors.h
index a2c38197f..7771919d0 100644
--- a/src/layers/constructors.h
+++ b/src/layers/constructors.h
@@ -2,6 +2,7 @@
 
 #include "layers/factory.h"
 #include "layers/generic.h"
+#include "layers/embedding.h"
 
 namespace marian {
 namespace mlp {
diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp
new file mode 100644
index 000000000..488fbb8be
--- /dev/null
+++ b/src/layers/embedding.cpp
@@ -0,0 +1,175 @@
+#include "embedding.h"
+#include "data/factored_vocab.h"
+
+namespace marian {
+
+Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+    : LayerBase(graph, options), inference_(opt<bool>("inference")) {
+  std::string name = opt<std::string>("prefix");
+  int dimVoc = opt<int>("dimVocab");
+  int dimEmb = opt<int>("dimEmb");
+
+  bool fixed = opt<bool>("fixed", false);
+
+  factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
+  if(factoredVocab_) {
+    dimVoc = (int)factoredVocab_->factorVocabSize();
+    LOG_ONCE(info, "[embedding] Factored embeddings enabled");
+  }
+
+  // Embedding layer initialization should depend only on the embedding size, hence fanIn=false
+  auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
+
+  if(options_->has("embFile")) {
+    std::string file = opt<std::string>("embFile");
+    if(!file.empty()) {
+      bool norm = opt<bool>("normalization", false);
+      initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
+    }
+  }
+
+  E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
+}
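+
+// Illustrative sketch (not taken from the callers): the constructor above is configured
+// entirely through an Options object. The option names are the ones read above; the values
+// here are hypothetical. In practice EncoderDecoderLayerBase::createEmbeddingLayer() below
+// assembles these options from the command-line configuration:
+//
+//   auto embOptions = New<Options>(
+//       "dimVocab",  32000,            // rows of E_ (vocabulary size)
+//       "dimEmb",    512,              // columns of E_ (embedding dimension)
+//       "dropout",   0.1f,             // read later by apply()/applyIndices()
+//       "inference", false,            // true disables dropout
+//       "prefix",    "encoder_Wemb",   // parameter name of E_ in the graph
+//       "fixed",     false,            // true excludes E_ from training
+//       "vocab",     "vocab.yml");     // probed for a factored vocabulary
+//   auto embedding = New<Embedding>(graph, embOptions);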
+
+// helper to embed a sequence of words (given as indices) via factored embeddings
+Expr Embedding::multiRows(const Words& data, float dropProb) const {
+  auto graph = E_->graph();
+  auto factoredData = factoredVocab_->csr_rows(data);
+  // multi-hot factor vectors are represented as a sparse CSR matrix
+  // [row index = word position index] -> set of factor indices for word at this position
+  ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
+  // the CSR matrix is passed in pieces
+  auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
+  auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
+  auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
+  // apply dropout
+  // We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
+  if(!inference_)
+    weights = dropout(weights, dropProb);
+  // perform the product
+  return csr_dot(factoredData.shape, weights, indices, offsets, E_);
+}
+
+std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
+  auto graph = E_->graph();
+  int dimBatch = (int)subBatch->batchSize();
+  int dimEmb = E_->shape()[-1];
+  int dimWidth = (int)subBatch->batchWidth();
+
+  // factored embeddings:
+  //  - regular:
+  //     - y = x @ E      x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
+  //  - factored:
+  //     - u = x @ M      one-hot to U-dimensional multi-hot (all factors in one concatenated space)
+  //        - each row of M contains the set of factors for one word => we want a CSR matrix
+  //     - y = (x @ M) @ E   (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
+  //  - first compute x @ M on the CPU
+  //     - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
+  //        - shape (U, specifically) not actually needed here
+  //     - foreach input x[i]
+  //        - locate row M[i,*]
+  //        - copy through its index values (std::vector)
+  //        - create a matching ones vector (we can keep growing)
+  //     - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
+  //     - CSR matrix product with E
+  //        - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
+  //        - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
+  //  - weighting:
+  //     - core factors' gradients are sums over all words that use the factors;
+  //        - core factors' embeddings move very fast
+  //        - words will need to make up for the move; rare words cannot
+  //     - so, we multiply each factor with 1/refCount
+  //        - core factors get weighed down a lot
+  //        - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
+  //        - but the forward pass weighs them down, so that all factors are in a similar numeric range
+  //        - if it is required to be in a different range, the embeddings can still learn that, but more slowly
+
+  auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
+#if 1
+  auto batchMask = graph->constant({dimWidth, dimBatch, 1},
+                                   inits::fromVector(subBatch->mask()));
+#else // @TODO: this is dead code now, get rid of it
+  // experimental: hide inline-fix source tokens from cross attention
+  auto batchMask = graph->constant({dimWidth, dimBatch, 1},
+                                   inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
+#endif
+  // give the graph inputs readable names for debugging and ONNX
+  batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
+
+  return std::make_tuple(batchEmbeddings, batchMask);
+}
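+
+// Worked example of the factored path above (illustrative only; the factor inventory and
+// indices are made up, and real weights may additionally be scaled, e.g. by 1/refCount):
+// for two words with factor sets {0, 3} and {0, 4} in a factor space of size U = 5,
+// factoredVocab_->csr_rows() would return roughly
+//   weights = {1, 1, 1, 1}   // one entry per (word, factor) pair
+//   indices = {0, 3, 0, 4}   // factor (column) index of each entry
+//   offsets = {0, 2, 4}      // row r spans indices[offsets[r]..offsets[r+1])
+//   shape   = {2, 5}         // [#word positions x U]
+// and csr_dot(shape, weights, indices, offsets, E_) computes the sparse product M @ E_,
+// i.e. row 0 = E_[0,:] + E_[3,:] and row 1 = E_[0,:] + E_[4,:]:
+// each word embedding is the weighted sum of its factor embeddings.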
+
+Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
+  if(factoredVocab_) {
+    Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
+    selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+    //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
+    return selectedEmbs;
+  }
+  else
+    return applyIndices(toWordIndexVector(words), shape);
+}
+
+Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
+  ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
+  auto embIdxExpr = E_->graph()->indices(embIdx);
+  embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
+  auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
+  selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+  // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
+  if(!inference_)
+    selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
+  return selectedEmbs;
+}
+
+// standard encoder word embeddings
+/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
+  auto options = New<Options>(
+      "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
+      "dimEmb", opt<int>("dim-emb"),
+      "dropout", dropoutEmbeddings_,
+      "inference", inference_,
+      "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
+      "fixed", embeddingFix_,
+      "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
+  if(options_->hasAndNotEmpty("embedding-vectors")) {
+    auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
+    options->set(
+        "embFile", embFiles[batchIndex_],
+        "normalization", opt<bool>("embedding-normalization"));
+  }
+  return New<Embedding>(graph_, options);
+}
+
+// ULR word embeddings
+/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
+  return New<ULREmbedding>(graph_, New<Options>(
+      "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0],  // ULR multi-lingual src
+      "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1],  // ULR mono tgt
+      "dimUlrEmb", opt<int>("ulr-dim-emb"),
+      "dimEmb", opt<int>("dim-emb"),
+      "ulr-dropout", opt<float>("ulr-dropout"),
+      "dropout", dropoutEmbeddings_,
+      "inference", inference_,
+      "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
+      "ulrQueryFile", opt<std::string>("ulr-query-vectors"),
+      "ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
+}
+
+// get the embedding layer for this encoder or decoder
+// This is lazy mostly because the constructors of the consuming objects are not
+// guaranteed presently to have access to their graph.
+Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
+  if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
+    if(embeddingLayers_.size() <= batchIndex_)
+      embeddingLayers_.resize(batchIndex_ + 1);
+    if(ulr)
+      embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
+    else
+      embeddingLayers_[batchIndex_] = createEmbeddingLayer();
+  }
+  return embeddingLayers_[batchIndex_];
+}
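+
+// Typical call site (sketch; assumes the caller derives from EncoderDecoderLayerBase and
+// `batch` is the current training/translation batch -- the exact batch type and indexing
+// are not part of this file):
+//
+//   auto embedding = getEmbeddingLayer(/*ulr=*/false);  // created lazily, cached per batchIndex_
+//   Expr batchEmbeddings, batchMask;
+//   std::tie(batchEmbeddings, batchMask) = embedding->apply((*batch)[batchIndex_]);
+//   // batchEmbeddings : [dimWidth, dimBatch, dimEmb], batchMask : [dimWidth, dimBatch, 1]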
+
+}
diff --git a/src/layers/embedding.h b/src/layers/embedding.h
new file mode 100644
index 000000000..91fd0b9d7
--- /dev/null
+++ b/src/layers/embedding.h
@@ -0,0 +1,148 @@
+#pragma once
+#include "marian.h"
+#include "generic.h"
+
+namespace marian {
+
+// A regular embedding layer.
+// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
+// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
+// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
+class Embedding : public LayerBase, public IEmbeddingLayer {
+  Expr E_;
+  Ptr<FactoredVocab> factoredVocab_;
+  Expr multiRows(const Words& data, float dropProb) const;
+  bool inference_{false};
+
+public:
+  Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
+
+  std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
+
+  Expr apply(const Words& words, const Shape& shape) const override final;
+
+  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
+};
+
+class ULREmbedding : public LayerBase, public IEmbeddingLayer {
+  std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
+  bool inference_{false};
+
+public:
+  ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+      : LayerBase(graph, options), inference_(opt<bool>("inference")) {
+    std::string name = "url_embed"; //opt<std::string>("prefix");
+    int dimKeys = opt<int>("dimTgtVoc");
+    int dimQueries = opt<int>("dimSrcVoc");
+    int dimEmb = opt<int>("dimEmb");
+    int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
+    bool fixed = opt<bool>("fixed", false);
+
+    // Embedding layer initialization should depend only on the embedding size, hence fanIn=false
+    auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);
+
+    std::string queryFile = opt<std::string>("ulrQueryFile");
+    std::string keyFile = opt<std::string>("ulrKeysFile");
+    bool trainTrans = opt<bool>("ulrTrainTransform", false);
+    if(!queryFile.empty() && !keyFile.empty()) {
+      initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
+      name = "ulr_query";
+      fixed = true;
+      auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
+      ulrEmbeddings_.push_back(query_embed);
+      // keys embeddings
+      initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
+      name = "ulr_keys";
+      fixed = true;
+      auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
+      ulrEmbeddings_.push_back(key_embed);
+      // actual trainable embedding
+      initFunc = inits::glorotUniform();
+      name = "ulr_embed";
+      fixed = false;
+      auto ulr_embed = graph_->param(name, { dimKeys, dimEmb }, initFunc, fixed); // note the reverse dim
+      ulrEmbeddings_.push_back(ulr_embed);
+      // init trainable src embedding
+      name = "ulr_src_embed";
+      auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
+      ulrEmbeddings_.push_back(ulr_src_embed);
+      // ulr transformation matrix
+      //initFunc = inits::eye(1.f); // identity matrix - is it OK to init with identity, or should we restrict this to the fixed case only?
+      if(trainTrans) {
+        initFunc = inits::glorotUniform();
+        fixed = false;
+      } else {
+        initFunc = inits::eye(); // identity matrix
+        fixed = true;
+      }
+      name = "ulr_transform";
+      auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
+      ulrEmbeddings_.push_back(ulrTransform);
+
+      initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable, 0 means no universal embeddings - should be zero for top-frequency words only
+      fixed = true;
+      name = "ulr_shared";
+      auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
+      ulrEmbeddings_.push_back(share_embed);
+    }
+  }
+
+  std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
+    auto queryEmbed   = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
+    auto keyEmbed     = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
+    auto uniEmbed     = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
+    auto srcEmbed     = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
+    auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb*dimUlrEmb
+    auto ulrSharable  = ulrEmbeddings_[5]; // alpha : dimQueries*1
+    int dimBatch = (int)subBatch->batchSize();
+    int dimEmb = uniEmbed->shape()[-1];
+    int dimWords = (int)subBatch->batchWidth();
+    // D = K.A.QT
+    // dim(K) = univ_tok_vocab * uni_embed_size
+    // dim(A) = uni_embed_size * uni_embed_size
+    // dim(Q) = uni_embed_size * total_merged_vocab_size
+    // dim(D) = univ_tok_vocab * total_merged_vocab_size
+    // note: all of the above can be precomputed and serialized if A is not trainable, and during decoding (TBD)
+    // here we need to handle the mini-batch
+    // extract rows corresponding to Xs in this minibatch from Q
+    auto embIdx = toWordIndexVector(subBatch->data());
+    auto queryEmbeddings = rows(queryEmbed, embIdx);
+    auto srcEmbeddings = rows(srcEmbed, embIdx);  // extract trainable src embeddings
+    auto alpha = rows(ulrSharable, embIdx);       // extract sharable flags
+    auto qt = dot(queryEmbeddings, ulrTransform, false, false); // A: transform embeddings based on similarity; A : dimUlrEmb*dimUlrEmb
+    auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]);
+    qt = qt / sqrtDim; // normalize according to the embedding size to avoid the dot product growing large in magnitude with larger embedding sizes
+    auto z = dot(qt, keyEmbed, false, true); // query-key similarity
+    float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default: no dropout
+    if(!inference_)
+      z = dropout(z, dropProb);
+
+    float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default: no temperature
+    // The temperature in the softmax controls the randomness of predictions:
+    // with a high temperature, softmax outputs are closer to each other;
+    // with low temperatures, the softmax becomes more similar to a "hardmax".
+    auto weights = softmax(z / tau); // assumes the default dim=-1; @TODO: handle the temperature scaling more explicitly?
+    auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
+    auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be an elementwise broadcast
+    auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
+    auto graph = ulrEmbeddings_.front()->graph();
+    auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
+                                     inits::fromVector(subBatch->mask()));
+    if(!inference_)
+      batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
+    return std::make_tuple(batchEmbeddings, batchMask);
+  }
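+
+  // In formulas, apply() above computes, for each source position with word index x
+  // (symbols as in the comments above; the dropout on z is omitted here):
+  //   q   = Q[x,:]                          // query row of the fixed multilingual embeddings
+  //   z   = (q A / sqrt(dimUlrEmb)) K^T     // scaled similarity to all universal tokens
+  //   w   = softmax(z / tau)                // attention weights over universal embeddings
+  //   e   = w E                             // mixed universal embedding, size dimEmb
+  //   out = I[x,:] + alpha[x] * e           // blended with the trainable src embedding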
+
+  Expr apply(const Words& words, const Shape& shape) const override final {
+    return applyIndices(toWordIndexVector(words), shape);
+  }
+
+  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
+    embIdx; shape;
+    ABORT("not implemented"); // @TODO: implement me
+  }
+};
+
+}
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
index d44f40206..36612577c 100644
--- a/src/layers/generic.cpp
+++ b/src/layers/generic.cpp
@@ -438,172 +438,4 @@ namespace marian {
       }
     }
   }
-
-  Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
-      : LayerBase(graph, options), inference_(opt<bool>("inference")) {
-    std::string name = opt<std::string>("prefix");
-    int dimVoc = opt<int>("dimVocab");
-    int dimEmb = opt<int>("dimEmb");
-
-    bool fixed = opt<bool>("fixed", false);
-
-    factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
-    if (factoredVocab_) {
-      dimVoc = (int)factoredVocab_->factorVocabSize();
-      LOG_ONCE(info, "[embedding] Factored embeddings enabled");
-    }
-
-    // Embedding layer initialization should depend only on embedding size, hence fanIn=false
-    auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
-
-    if (options_->has("embFile")) {
-      std::string file = opt<std::string>("embFile");
-      if (!file.empty()) {
-        bool norm = opt<bool>("normalization", false);
-        initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
-      }
-    }
-
-    E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
-  }
-
-  // helper to embed a sequence of words (given as indices) via factored embeddings
-  Expr Embedding::multiRows(const Words& data, float dropProb) const {
-    auto graph = E_->graph();
-    auto factoredData = factoredVocab_->csr_rows(data);
-    // multi-hot factor vectors are represented as a sparse CSR matrix
-    // [row index = word position index] -> set of factor indices for word at this position
-    ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
-    // the CSR matrix is passed in pieces
-    auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
-    auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
-    auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
-    // apply dropout
-    // We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
-    if(!inference_)
-      weights = dropout(weights, dropProb);
-    // perform the product
-    return csr_dot(factoredData.shape, weights, indices, offsets, E_);
-  }
-
-  std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
-    auto graph = E_->graph();
-    int dimBatch = (int)subBatch->batchSize();
-    int dimEmb = E_->shape()[-1];
-    int dimWidth = (int)subBatch->batchWidth();
-
-    // factored embeddings:
-    //  - regular:
-    //     - y = x @ E      x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
-    //  - factored:
-    //     - u = x @ M      one-hot to U-dimensional multi-hot (all factors in one concatenated space)
-    //        - each row of M contains the set of factors for one word => we want a CSR matrix
-    //     - y = (x @ M) @ E   (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
-    //  - first compute x @ M on the CPU
-    //     - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
-    //        - shape (U, specifically) not actually needed here
-    //     - foreach input x[i]
-    //        - locate row M[i,*]
-    //        - copy through its index values (std::vector)
-    //        - create a matching ones vector (we can keep growing)
-    //     - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
-    //     - CSR matrix product with E
-    //        - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
-    //        - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
-    //  - weighting:
-    //     - core factors' gradients are sums over all words that use the factors;
-    //        - core factors' embeddings move very fast
-    //        - words will need to make up for the move; rare words cannot
-    //     - so, we multiply each factor with 1/refCount
-    //        - core factors get weighed down a lot
-    //        - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
-    //        - but forward pass weighs them down, so that all factors are in a similar numeric range
-    //        - if it is required to be in a different range, the embeddings can still learn that, but more slowly
-
-    auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
-#if 1
-    auto batchMask = graph->constant({dimWidth, dimBatch, 1},
-                                     inits::fromVector(subBatch->mask()));
-#else // @TODO: this is dead code now, get rid of it
-    // experimental: hide inline-fix source tokens from cross attention
-    auto batchMask = graph->constant({dimWidth, dimBatch, 1},
-                                     inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
-#endif
-    // give the graph inputs readable names for debugging and ONNX
-    batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
-
-    return std::make_tuple(batchEmbeddings, batchMask);
-  }
-
-  Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
-    if (factoredVocab_) {
-      Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
-      selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
-      //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
-      return selectedEmbs;
-    }
-    else
-      return applyIndices(toWordIndexVector(words), shape);
-  }
-
-  Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
-    ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
-    auto embIdxExpr = E_->graph()->indices(embIdx);
-    embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
-    auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
-    selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
-    // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
-    if(!inference_)
-      selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
-    return selectedEmbs;
-  }
-
-  // standard encoder word embeddings
-  /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
-    auto options = New<Options>(
-        "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
-        "dimEmb", opt<int>("dim-emb"),
-        "dropout", dropoutEmbeddings_,
-        "inference", inference_,
-        "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
-        "fixed", embeddingFix_,
-        "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
-    if(options_->hasAndNotEmpty("embedding-vectors")) {
-      auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
-      options->set(
-          "embFile", embFiles[batchIndex_],
-          "normalization", opt<bool>("embedding-normalization"));
-    }
-    return New<Embedding>(graph_, options);
-  }
-
-  // ULR word embeddings
-  /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
-    return New<ULREmbedding>(graph_, New<Options>(
-        "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0],  // ULR multi-lingual src
-        "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1],  // ULR monon tgt
-        "dimUlrEmb", opt<int>("ulr-dim-emb"),
-        "dimEmb", opt<int>("dim-emb"),
-        "ulr-dropout", opt<float>("ulr-dropout"),
-        "dropout", dropoutEmbeddings_,
-        "inference", inference_,
-        "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
-        "ulrQueryFile", opt<std::string>("ulr-query-vectors"),
-        "ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
-  }
-
-  // get embedding layer for this encoder or decoder
-  // This is lazy mostly because the constructors of the consuming objects are not
-  // guaranteed presently to have access to their graph.
-  Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
-    if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
-      if (embeddingLayers_.size() <= batchIndex_)
-        embeddingLayers_.resize(batchIndex_ + 1);
-      if (ulr)
-        embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
-      else
-        embeddingLayers_[batchIndex_] = createEmbeddingLayer();
-    }
-    return embeddingLayers_[batchIndex_];
-  }
 }  // namespace marian
diff --git a/src/layers/generic.h b/src/layers/generic.h
index f47bb45e2..9fbaea7c7 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -295,146 +295,6 @@ class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
 }
 }  // namespace mlp
 
-// A regular embedding layer.
-// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
-// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
-// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
-class Embedding : public LayerBase, public IEmbeddingLayer {
-  Expr E_;
-  Ptr<FactoredVocab> factoredVocab_;
-  Expr multiRows(const Words& data, float dropProb) const;
-  bool inference_{false};
-
-public:
-  Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
-
-  std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
-
-  Expr apply(const Words& words, const Shape& shape) const override final;
-
-  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
-};
-
-class ULREmbedding : public LayerBase, public IEmbeddingLayer {
-  std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
-  bool inference_{false};
-
-public:
-  ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
-      : LayerBase(graph, options), inference_(opt<bool>("inference")) {
-    std::string name = "url_embed"; //opt<std::string>("prefix");
-    int dimKeys = opt<int>("dimTgtVoc");
-    int dimQueries = opt<int>("dimSrcVoc");
-    int dimEmb = opt<int>("dimEmb");
-    int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
-    bool fixed = opt<bool>("fixed", false);
-
-    // Embedding layer initialization should depend only on embedding size, hence fanIn=false
-    auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);
-
-    std::string queryFile = opt<std::string>("ulrQueryFile");
-    std::string keyFile = opt<std::string>("ulrKeysFile");
-    bool trainTrans = opt<bool>("ulrTrainTransform", false);
-    if (!queryFile.empty() && !keyFile.empty()) {
-      initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
-      name = "ulr_query";
-      fixed = true;
-      auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
-      ulrEmbeddings_.push_back(query_embed);
-      // keys embeds
-      initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
-      name = "ulr_keys";
-      fixed = true;
-      auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
-      ulrEmbeddings_.push_back(key_embed);
-      // actual trainable embedding
-      initFunc = inits::glorotUniform();
-      name = "ulr_embed";
-      fixed = false;
-      auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
-      ulrEmbeddings_.push_back(ulr_embed);
-      // init trainable src embedding
-      name = "ulr_src_embed";
-      auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
-      ulrEmbeddings_.push_back(ulr_src_embed);
-      // ulr transformation matrix
-      //initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only
-      if (trainTrans) {
-        initFunc = inits::glorotUniform();
-        fixed = false;
-      }
-      else
-      {
-        initFunc = inits::eye(); // identity matrix
-        fixed = true;
-      }
-      name = "ulr_transform";
-      auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
-      ulrEmbeddings_.push_back(ulrTransform);
-
-      initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
-      fixed = true;
-      name = "ulr_shared";
-      auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
-      ulrEmbeddings_.push_back(share_embed);
-    }
-  }
-
-  std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
-    auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
-    auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
-    auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
-    auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
-    auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
-    auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
-    int dimBatch = (int)subBatch->batchSize();
-    int dimEmb = uniEmbed->shape()[-1];
-    int dimWords = (int)subBatch->batchWidth();
-    // D = K.A.QT
-    // dimm(K) = univ_tok_vocab*uni_embed_size
-    // dim A = uni_embed_size*uni_embed_size
-    // dim Q: uni_embed_size * total_merged_vocab_size
-    // dim D = univ_tok_vocab * total_merged_vocab_size
-    // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD)
-    // here we need to handle the mini-batch
-    // extract raws corresponding to Xs in this minibatch from Q
-    auto embIdx = toWordIndexVector(subBatch->data());
-    auto queryEmbeddings = rows(queryEmbed, embIdx);
-    auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
-    auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
-    auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
-    auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
-    qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes
-    auto z = dot(qt, keyEmbed, false, true); // query-key similarity
-    float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
-    if(!inference_)
-      z = dropout(z, dropProb);
-
-    float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
-    // temperature in softmax is to control randomness of predictions
-    // high temperature Softmax outputs are more close to each other
-    // low temperatures the softmax become more similar to "hardmax"
-    auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
-    auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
-    auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
-    auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
-    auto graph = ulrEmbeddings_.front()->graph();
-    auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
-                                     inits::fromVector(subBatch->mask()));
-    if(!inference_)
-      batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
-    return std::make_tuple(batchEmbeddings, batchMask);
-  }
-
-  Expr apply(const Words& words, const Shape& shape) const override final {
-    return applyIndices(toWordIndexVector(words), shape);
-  }
-
-  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
-    embIdx; shape;
-    ABORT("not implemented"); // @TODO: implement me
-  }
-};
 
 // --- a few layers with built-in parameters created on the fly, without proper object
 // @TODO: change to a proper layer object