diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 117b5d878..08964ec08 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -310,6 +310,20 @@ Expr atleast_4d(Expr a) { return atleast_nd(a, 4); } +Expr addPosEmbedding(Expr embeddings, float scaleFactor, int startPos) { + int dimEmb = embeddings->shape()[-1]; + int dimWords = embeddings->shape()[-3]; + auto graph = embeddings->graph(); + if (!graph->isInference() || graph->getBackend()->getDeviceId().type == DeviceType::cpu) { + auto signal = graph->constant({dimWords, 1, dimEmb}, + inits::sinusoidalPositionEmbeddings(startPos)); + return scaleFactor * embeddings + signal; + } + + // Mode is GPU inference + return Expression(embeddings, scaleFactor, startPos); +} + Expr atleast_nd(Expr a, size_t dims) { if(a->shape().size() >= dims) return a; diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 2d8283348..19f6eefcc 100755 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -173,6 +173,8 @@ Expr atleast_3d(Expr a); Expr atleast_4d(Expr a); Expr atleast_nd(Expr a, size_t dims); +Expr addPosEmbedding(Expr embeddings, float scaleFactor, int startPos); + // create a constant of shape a->shape() and initialize with init // @TODO: add a && version, to avoid a ref count. NodeInitializers are typically temps. // @TODO: and/or make this a template on init diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index 43f38f6fe..9b2f1fafe 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -754,6 +754,54 @@ struct ScalarProductNodeOp : public NaryNodeOp { int axis_; }; +struct PosEmbeddingNodeOp : public NaryNodeOp { + PosEmbeddingNodeOp(Expr embeddings, float scaleFactor, int startPos) + : NaryNodeOp({embeddings}, newShape(embeddings)), + scaleFactor_(scaleFactor), + startPos_(startPos) {} + + Shape newShape(Expr a) { + return a->shape(); + } + + NodeOps forwardOps() override { + using namespace functional; + + return {NodeOp(AddPosEmbeddings(val_, child(0)->val(), scaleFactor_, startPos_))}; + } + + NodeOps backwardOps() override { + ABORT("Not Implemented. Inference Optimization"); + } + + const std::string type() override { return "Add Positional Embedding"; } + + const std::string color() override { return "blue"; } + + virtual size_t hash() override { + size_t seed = NaryNodeOp::hash(); + util::hash_combine(seed, scaleFactor_); + util::hash_combine(seed, startPos_); + return seed; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast(node); + if(!cnode) + return false; + if(scaleFactor_ != cnode->scaleFactor_) + return false; + if(startPos_ != cnode->startPos_) + return false; + return true; + } + + float scaleFactor_; + int startPos_; +}; + struct RowsNodeOp : public NaryNodeOp { RowsNodeOp(Expr a, Expr indices) : NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) { diff --git a/src/models/transformer.h b/src/models/transformer.h index de6304002..59b4366ec 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -31,6 +31,7 @@ class Transformer : public EncoderOrDecoderBase { std::unordered_map cache_; // caching transformation of the encoder that should not be created again mutable/*lazy*/ std::vector sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings + std::unordered_map constantCache_; // attention weights produced by step() // If enabled, it is set once per batch during training, and once per step during translation. // It can be accessed by getAlignments(). @TODO: move into a state or return-value object @@ -85,8 +86,8 @@ class Transformer : public EncoderOrDecoderBase { } else { // @TODO : test if embeddings should be scaled when trainable // according to paper embeddings are scaled up by \sqrt(d_m) - embeddings = std::sqrt((float)dimEmb) * embeddings; // embeddings were initialized to unit length; so norms will be in order of sqrt(dimEmb) + float scaleFactor = std::sqrt((float)dimEmb); #ifdef USE_ONNX // TODO 'Sin' op and constant sine generate different result. So, use constant when 'USE_ONNX' is not defined for now. // precompute the arguments to sin() (the cos(x) are expressed as sin(x+pi/2)) if (sinusoidalEmbeddingsFreq_.empty()) { @@ -101,12 +102,10 @@ class Transformer : public EncoderOrDecoderBase { auto positionRange = graph_->constant({ dimWords, 1, 1 }, inits::range((float)start, (float)start + (float)dimWords)); positionRange->set_name("data_" + std::to_string(batchIndex_) + "_posrange"); auto signal = sin(positionRange * frequencies + cosOffsets); + embeddings = scaleFactor * embeddings + signal; #else // USE_ONNX - auto signal = graph_->constant({dimWords, 1, dimEmb}, - inits::sinusoidalPositionEmbeddings(start)); + embeddings = addPosEmbedding(embeddings, scaleFactor, start); #endif // USE_ONNX - - embeddings = embeddings + signal; } return embeddings; @@ -117,13 +116,27 @@ class Transformer : public EncoderOrDecoderBase { return addPositionalEmbeddings(input, start, trainPosEmbeddings); } - Expr triangleMask(int length) const { - // fill triangle mask - std::vector vMask(length * length, 0); - for(int i = 0; i < length; ++i) - for(int j = 0; j <= i; ++j) - vMask[i * length + j] = 1.f; - return graph_->constant({1, length, length}, inits::fromVector(vMask)); + Expr triangleMask(int length) { + if (inference_ && constantCache_.count("triangleMask") > 0 && + constantCache_["triangleMask"]->shape().elements() == length * length) { + return constantCache_["triangleMask"]; + } else { + // fill triangle mask + std::vector vMask(length * length, 0); + for(int i = 0; i < length; ++i) + for(int j = 0; j <= i; ++j) + vMask[i * length + j] = 1.f; + + Expr e = graph_->constant({1, length, length}, inits::fromVector(vMask)); + if(inference_) { + e->setMemoize(true); + if (constantCache_.count("triangleMask") > 0) { + constantCache_["triangleMask"]->setMemoize(false); + } + constantCache_["triangleMask"] = e; + } + return e; + } } // convert multiplicative 1/0 mask to additive 0/-inf log mask, and transpose to match result of bdot() op in Attention() @@ -714,11 +727,11 @@ class DecoderTransformer : public Transformer { // Used for position embeddings and creating new decoder states. int startPos = (int)state->getPosition(); - auto scaledEmbeddings = addSpecialEmbeddings(embeddings, startPos); + auto scaledEmbeddings = addSpecialEmbeddings(embeddings, startPos); // TODO1 scaledEmbeddings = atleast_nd(scaledEmbeddings, 4); // reorganize batch and timestep - auto query = transposeTimeBatch(scaledEmbeddings); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] + auto query = transposeTimeBatch(scaledEmbeddings); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] TODO1 auto prevQuery = query; // keep handle to untransformed embeddings, potentially used for a final skip connection diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index 4236a1de8..fabdc948e 100755 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -24,6 +24,10 @@ namespace cpu { ABORT("Not implemented"); } +void AddPosEmbeddings(marian::Tensor result, const marian::Tensor& embeddings, float scaleFactor, int startPos) { + ABORT("Not implemented"); +} + template void CopyCastTo(To* out, const From* in, int length) { for(int i = 0; i < length; ++i) diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu index 90a6bc667..b9abcaafe 100644 --- a/src/tensors/gpu/tensor_operators.cu +++ b/src/tensors/gpu/tensor_operators.cu @@ -1958,6 +1958,7 @@ __global__ void gLNormalization(T* out, } } + void LayerNormalization(Tensor out, Tensor in, Tensor gamma, @@ -2917,5 +2918,54 @@ void PoolingWithMaskingBackward(Tensor adj, width, lastWidth); } + +template +__global__ void gComputeSinusoidalPosEmb(T* result, const T* const embedding, float scaleFactor, int startPos, + int dimWords, int dim2, int dimEmb, int elts) { + + const float numTimescales = (float)dimEmb / 2; + const float logTimescaleIncrement = log(10000.f) / (numTimescales - 1.f); + + const size_t offset = blockDim.x * blockIdx.x + threadIdx.x; + + if (offset < elts) { + // Dim2 of signal is 1 so we ignore it + const int colInSignal = offset % dimEmb; + const int rowInSignal = (offset / (dimEmb * dim2)) % dimWords; + + const int p = rowInSignal + startPos; + const int i = colInSignal % (int) numTimescales; + const float v = p * exp(i * -logTimescaleIncrement); + const T posSignal = colInSignal < (int) numTimescales? sin(v) : cos(v); + const T scaledEmbedding = (T)scaleFactor * embedding[offset]; + result[offset] = scaledEmbedding + posSignal; + } +} + +void AddPosEmbeddings(marian::Tensor result, const marian::Tensor& embeddings, float scaleFactor, int startPos) { + + int dimEmb = embeddings->shape()[-1]; + int broadcastDim = embeddings->shape()[-2]; + int dimWords = embeddings->shape()[-3]; + int elts = embeddings->shape().elements(); + + int threads = std::min(elts, MAX_THREADS); + int blocks = (elts + threads - 1) / threads; + + cudaSetDevice(result->getDeviceId().no); + + if(result->type() == Type::float32) { + gComputeSinusoidalPosEmb<<>>( + result->data(), embeddings->data(), scaleFactor, startPos, dimWords, broadcastDim, dimEmb, elts); +#if COMPILE_FP16 + } else if(result->type() == Type::float16) { + gComputeSinusoidalPosEmb<<>>( + result->data(), embeddings->data(), scaleFactor, startPos, dimWords, broadcastDim, dimEmb, elts); +#endif + } else { + ABORT("gComputeSinusoidalPosEmb not implemented for type {}", result->type()); + } +} + } // namespace gpu } // namespace marian diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h index d278803f5..2c7353abf 100644 --- a/src/tensors/tensor_operators.h +++ b/src/tensors/tensor_operators.h @@ -99,6 +99,7 @@ void Reduce(Functor functor, AggFunctor aggFunctor, float aggInit, } // clang-format off +DISPATCH4(AddPosEmbeddings, marian::Tensor, const marian::Tensor&, float, int); DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float) DISPATCH8(ProdBatched, marian::Tensor, Ptr, const marian::Tensor, const marian::Tensor, bool, bool, float, float) DISPATCH9(CSRProd, marian::Tensor, Ptr, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)