Commit 0d8372c by Hieu Hoang, committed Mar 4, 2021 (1 parent: 96ed0ba).
Showing 6 changed files with 325 additions and 308 deletions.
@@ -0,0 +1,175 @@
#include "embedding.h"
#include "data/factored_vocab.h"

namespace marian {

Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
    : LayerBase(graph, options), inference_(opt<bool>("inference")) {
  std::string name = opt<std::string>("prefix");
  int dimVoc = opt<int>("dimVocab");
  int dimEmb = opt<int>("dimEmb");

  bool fixed = opt<bool>("fixed", false);

  factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
  if (factoredVocab_) {
    dimVoc = (int)factoredVocab_->factorVocabSize();
    LOG_ONCE(info, "[embedding] Factored embeddings enabled");
  }

  // Embedding layer initialization should depend only on embedding size, hence fanIn=false
  auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
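  // Aside (a hedged reading, not part of the original commit): if fanOut-only
  // scaling draws entries from U[-sqrt(3/dimEmb), +sqrt(3/dimEmb)], each
  // coordinate has variance 1/dimEmb, so E[||row||^2] = dimEmb * (1/dimEmb) = 1,
  // i.e. embedding rows have roughly unit length in expectation.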

  if (options_->has("embFile")) {
    std::string file = opt<std::string>("embFile");
    if (!file.empty()) {
      bool norm = opt<bool>("normalization", false);
      initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
    }
  }

  E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
}

// helper to embed a sequence of words (given as indices) via factored embeddings
Expr Embedding::multiRows(const Words& data, float dropProb) const {
  auto graph = E_->graph();
  auto factoredData = factoredVocab_->csr_rows(data);
  // multi-hot factor vectors are represented as a sparse CSR matrix
  // [row index = word position index] -> set of factor indices for word at this position
  ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
  // the CSR matrix is passed in pieces
  auto weights = graph->constant({(int)factoredData.weights.size()}, inits::fromVector(factoredData.weights));
  auto indices = graph->constant({(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32);
  auto offsets = graph->constant({(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32);
  // apply dropout
  // We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
  if(!inference_)
    weights = dropout(weights, dropProb);
  // perform the product
  return csr_dot(factoredData.shape, weights, indices, offsets, E_);
}
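
// Aside (editor's sketch, not part of the original commit): a dense CPU
// reference for what csr_dot() computes above, e.g. useful in a unit test.
// The name and signature are hypothetical; E is assumed row-major
// [dimVoc x dimEmb], and <vector>/<cstdint> are assumed available via the
// existing includes.
static std::vector<float> csrDotDenseReference(const std::vector<float>& weights,
                                               const std::vector<uint32_t>& indices,
                                               const std::vector<uint32_t>& offsets,
                                               const std::vector<float>& E,
                                               int dimEmb) {
  size_t numRows = offsets.size() - 1;                    // one CSR row per word position
  std::vector<float> out(numRows * dimEmb, 0.f);
  for(size_t r = 0; r < numRows; ++r)                     // for each word position
    for(uint32_t k = offsets[r]; k < offsets[r + 1]; ++k) // for each factor of that word
      for(int d = 0; d < dimEmb; ++d)                     // accumulate the weighted factor embedding
        out[r * dimEmb + d] += weights[k] * E[(size_t)indices[k] * dimEmb + d];
  return out;
}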

std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
  auto graph = E_->graph();
  int dimBatch = (int)subBatch->batchSize();
  int dimEmb = E_->shape()[-1];
  int dimWidth = (int)subBatch->batchWidth();

  // factored embeddings:
  //  - regular:
  //     - y = x @ E      x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
  //  - factored:
  //     - u = x @ M      one-hot to U-dimensional multi-hot (all factors in one concatenated space)
  //        - each row of M contains the set of factors for one word => we want a CSR matrix
  //     - y = (x @ M) @ E   (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
  //  - first compute x @ M on the CPU
  //     - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
  //        - shape (U, specifically) not actually needed here
  //        - foreach input x[i]
  //           - locate row M[i,*]
  //           - copy through its index values (std::vector<push_back>)
  //           - create a matching ones vector (we can keep growing)
  //        - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
  //     - CSR matrix product with E
  //        - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
  //        - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
  //  - weighting:
  //     - core factors' gradients are sums over all words that use the factors;
  //        - core factors' embeddings move very fast
  //        - words will need to make up for the move; rare words cannot
  //     - so, we multiply each factor with 1/refCount
  //        - core factors get weighed down a lot
  //        - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
  //        - but forward pass weighs them down, so that all factors are in a similar numeric range
  //        - if it is required to be in a different range, the embeddings can still learn that, but more slowly
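  //  - aside (illustration with made-up numbers, not from the original commit):
  //    if a shared capitalization factor is referenced by 50000 words
  //    (refCount = 50000) while a rare lemma factor is referenced by one, the
  //    forward pass scales them by 1/50000 and 1/1 respectively, keeping both
  //    contributions in a similar numeric range despite very different gradient magnitudes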

  auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
#if 1
  auto batchMask = graph->constant({dimWidth, dimBatch, 1},
                                   inits::fromVector(subBatch->mask()));
#else // @TODO: this is dead code now, get rid of it
  // experimental: hide inline-fix source tokens from cross attention
  auto batchMask = graph->constant({dimWidth, dimBatch, 1},
                                   inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
#endif
  // give the graph inputs readable names for debugging and ONNX
  batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");

  return std::make_tuple(batchEmbeddings, batchMask);
}

Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
  if (factoredVocab_) {
    Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
    selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
    //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
    return selectedEmbs;
  }
  else
    return applyIndices(toWordIndexVector(words), shape);
}

Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
  ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
  auto embIdxExpr = E_->graph()->indices(embIdx);
  embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
  auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
  selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
  // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
  if(!inference_)
    selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
  return selectedEmbs;
}

// standard encoder word embeddings
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
  auto options = New<Options>(
      "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
      "dimEmb", opt<int>("dim-emb"),
      "dropout", dropoutEmbeddings_,
      "inference", inference_,
      "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
      "fixed", embeddingFix_,
      "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
  if(options_->hasAndNotEmpty("embedding-vectors")) {
    auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
    options->set(
        "embFile", embFiles[batchIndex_],
        "normalization", opt<bool>("embedding-normalization"));
  }
  return New<Embedding>(graph_, options);
}

// ULR word embeddings
/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
  return New<ULREmbedding>(graph_, New<Options>(
      "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
      "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR mono tgt
      "dimUlrEmb", opt<int>("ulr-dim-emb"),
      "dimEmb", opt<int>("dim-emb"),
      "ulr-dropout", opt<float>("ulr-dropout"),
      "dropout", dropoutEmbeddings_,
      "inference", inference_,
      "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
      "ulrQueryFile", opt<std::string>("ulr-query-vectors"),
      "ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
}

// get embedding layer for this encoder or decoder
// This is lazy mostly because the constructors of the consuming objects are not
// guaranteed presently to have access to their graph.
Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
  if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
    if (embeddingLayers_.size() <= batchIndex_)
      embeddingLayers_.resize(batchIndex_ + 1);
    if (ulr)
      embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
    else
      embeddingLayers_[batchIndex_] = createEmbeddingLayer();
  }
  return embeddingLayers_[batchIndex_];
}
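
// Aside (hypothetical usage, not part of this commit): a consuming encoder or
// decoder would typically fetch the layer lazily and embed one sub-batch, e.g.
//   auto embeddingLayer = getEmbeddingLayer(/*ulr=*/false);
//   Expr batchEmbeddings, batchMask;
//   std::tie(batchEmbeddings, batchMask) = embeddingLayer->apply(subBatch);
// where subBatch is the Ptr<data::SubBatch> for this stream's batchIndex_.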

}

@@ -0,0 +1,148 @@
#pragma once
#include "marian.h"
#include "generic.h"

namespace marian {

// A regular embedding layer.
// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
class Embedding : public LayerBase, public IEmbeddingLayer {
  Expr E_;
  Ptr<FactoredVocab> factoredVocab_;
  Expr multiRows(const Words& data, float dropProb) const;
  bool inference_{false};

public:
  Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);

  std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;

  Expr apply(const Words& words, const Shape& shape) const override final;

  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
};

class ULREmbedding : public LayerBase, public IEmbeddingLayer {
  std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
  bool inference_{false};

public:
  ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : LayerBase(graph, options), inference_(opt<bool>("inference")) {
    std::string name = "url_embed"; //opt<std::string>("prefix");
    int dimKeys = opt<int>("dimTgtVoc");
    int dimQueries = opt<int>("dimSrcVoc");
    int dimEmb = opt<int>("dimEmb");
    int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
    bool fixed = opt<bool>("fixed", false);

    // Embedding layer initialization should depend only on embedding size, hence fanIn=false
    auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);

    std::string queryFile = opt<std::string>("ulrQueryFile");
    std::string keyFile = opt<std::string>("ulrKeysFile");
    bool trainTrans = opt<bool>("ulrTrainTransform", false);
    if (!queryFile.empty() && !keyFile.empty()) {
      initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
      name = "ulr_query";
      fixed = true;
      auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
      ulrEmbeddings_.push_back(query_embed);
      // keys embeds
      initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
      name = "ulr_keys";
      fixed = true;
      auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
      ulrEmbeddings_.push_back(key_embed);
      // actual trainable embedding
      initFunc = inits::glorotUniform();
      name = "ulr_embed";
      fixed = false;
      auto ulr_embed = graph_->param(name, { dimKeys, dimEmb }, initFunc, fixed); // note the reversed dim
      ulrEmbeddings_.push_back(ulr_embed);
      // init trainable src embedding
      name = "ulr_src_embed";
      auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
      ulrEmbeddings_.push_back(ulr_src_embed);
      // ULR transformation matrix
      //initFunc = inits::eye(1.f); // identity matrix - is it OK to init with identity, or shall we restrict this to the fixed case only?
      if (trainTrans) {
        initFunc = inits::glorotUniform();
        fixed = false;
      }
      else {
        initFunc = inits::eye(); // identity matrix
        fixed = true;
      }
      name = "ulr_transform";
      auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
      ulrEmbeddings_.push_back(ulrTransform);

      initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable, 0 means no universal embeddings; should be zero for top-frequency words only
      fixed = true;
      name = "ulr_shared";
      auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
      ulrEmbeddings_.push_back(share_embed);
    }
  }

  std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
    auto queryEmbed   = ulrEmbeddings_[0]; // Q : dimQueries x dimUlrEmb
    auto keyEmbed     = ulrEmbeddings_[1]; // K : dimKeys x dimUlrEmb
    auto uniEmbed     = ulrEmbeddings_[2]; // E : dimKeys x dimEmb
    auto srcEmbed     = ulrEmbeddings_[3]; // I : dimQueries x dimEmb
    auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb x dimUlrEmb
    auto ulrSharable  = ulrEmbeddings_[5]; // alpha : dimQueries x 1
    int dimBatch = (int)subBatch->batchSize();
    int dimEmb   = uniEmbed->shape()[-1];
    int dimWords = (int)subBatch->batchWidth();
    // D = K.A.QT
    // dim(K)  = univ_tok_vocab x uni_embed_size
    // dim(A)  = uni_embed_size x uni_embed_size
    // dim(QT) = uni_embed_size x total_merged_vocab_size
    // dim(D)  = univ_tok_vocab x total_merged_vocab_size
    // note: all of the above can be precomputed and serialized if A is not trainable, and during decoding (TBD)
    // here we need to handle the mini-batch
    // extract rows corresponding to Xs in this minibatch from Q
    auto embIdx = toWordIndexVector(subBatch->data());
    auto queryEmbeddings = rows(queryEmbed, embIdx);
    auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
    auto alpha = rows(ulrSharable, embIdx);      // extract sharable flags
    auto qt = dot(queryEmbeddings, ulrTransform, false, false); // A: transform embeddings based on similarity; A : dimUlrEmb x dimUlrEmb
    auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]);
    qt = qt / sqrtDim; // normalize according to embedding size, to keep the dot product from growing large in magnitude with larger embedding sizes
    auto z = dot(qt, keyEmbed, false, true); // query-key similarity
    float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default: no dropout
    if(!inference_)
      z = dropout(z, dropProb);

    float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default: no temperature
    // the softmax temperature controls the randomness of predictions:
    // at high temperature, softmax outputs get closer to each other;
    // at low temperatures, the softmax becomes more similar to a "hardmax"
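    // numeric illustration: for z = [2, 1], softmax(z / 0.5) ≈ [0.88, 0.12]
    // (sharper, more hardmax-like), while softmax(z / 2) ≈ [0.62, 0.38] (flatter)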
    auto weights = softmax(z / tau); // assume default is dim=-1; what about temperature? - scalar??
    auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
    auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be an elementwise broadcast
    auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
    auto graph = ulrEmbeddings_.front()->graph();
    auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
                                     inits::fromVector(subBatch->mask()));
    if(!inference_)
      batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
    return std::make_tuple(batchEmbeddings, batchMask);
  }

  Expr apply(const Words& words, const Shape& shape) const override final {
    return applyIndices(toWordIndexVector(words), shape);
  }

  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
    embIdx; shape; // suppress unused-parameter warnings
    ABORT("not implemented"); // @TODO: implement me
  }
};

}
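
For reference, the graph built by ULREmbedding::apply() computes, per word, e = src + alpha * softmax((q A / sqrt(dimUlrEmb)) K^T / tau) U. The sketch below is an editor's dense CPU illustration of that formula, not code from this commit: the name ulrMixReference and its flat row-major std::vector signature are hypothetical, and the ulr-dropout applied to z in the real code is omitted.

#include <algorithm>
#include <cmath>
#include <vector>

// Dense per-word reference: e = src + alpha * softmax((q A / sqrt(d)) K^T / tau) U
// q:[d], A:[d x d], K:[dimKeys x d], U:[dimKeys x dimEmb], src:[dimEmb]; all row-major.
std::vector<float> ulrMixReference(const std::vector<float>& q, const std::vector<float>& A,
                                   const std::vector<float>& K, const std::vector<float>& U,
                                   const std::vector<float>& src, float alpha, float tau,
                                   int d, int dimKeys, int dimEmb) {
  // qt = q @ A, scaled by 1/sqrt(d) to keep dot products from growing with d
  std::vector<float> qt(d, 0.f);
  for(int j = 0; j < d; ++j)
    for(int i = 0; i < d; ++i)
      qt[j] += q[i] * A[i * d + j];
  float sqrtDim = std::sqrt((float)d);
  for(auto& v : qt)
    v /= sqrtDim;
  // z = qt @ K^T, then a numerically stable temperature softmax over the keys
  std::vector<float> w(dimKeys, 0.f);
  for(int k = 0; k < dimKeys; ++k)
    for(int j = 0; j < d; ++j)
      w[k] += qt[j] * K[k * d + j];
  float maxZ = *std::max_element(w.begin(), w.end());
  float sum = 0.f;
  for(auto& v : w) {
    v = std::exp((v - maxZ) / tau); // shifting by maxZ leaves the softmax unchanged
    sum += v;
  }
  for(auto& v : w)
    v /= sum;
  // e = src + alpha * (w @ U): mix the trainable src embedding with the universal average
  std::vector<float> e(src);
  for(int k = 0; k < dimKeys; ++k)
    for(int dd = 0; dd < dimEmb; ++dd)
      e[dd] += alpha * w[k] * U[k * dimEmb + dd];
  return e;
}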