Skip to content

Commit

Permalink
This commit slightly changes the output for certain models. It introd…
Browse files Browse the repository at this point in the history
…uces a new operator to compute the sinusoidal embeddings on the GPU instead of doing CPU work. This is the reason for the differences since there are small differences in the float results for the positional embeddings. It also caches the triangle count mask for the transformer when doing inference to reduce device-hsot communication. This change should not affect the transofrmer output.
  • Loading branch information
rhenry-nv committed Oct 10, 2020
1 parent 201dc0a commit 5b889da
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 14 deletions.
14 changes: 14 additions & 0 deletions src/graph/expression_operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,20 @@ Expr atleast_4d(Expr a) {
return atleast_nd(a, 4);
}

Expr addPosEmbedding(Expr embeddings, float scaleFactor, int startPos) {
int dimEmb = embeddings->shape()[-1];
int dimWords = embeddings->shape()[-3];
auto graph = embeddings->graph();
if (!graph->isInference() || graph->getBackend()->getDeviceId().type == DeviceType::cpu) {
auto signal = graph->constant({dimWords, 1, dimEmb},
inits::sinusoidalPositionEmbeddings(startPos));
return scaleFactor * embeddings + signal;
}

// Mode is GPU inference
return Expression<PosEmbeddingNodeOp>(embeddings, scaleFactor, startPos);
}

Expr atleast_nd(Expr a, size_t dims) {
if(a->shape().size() >= dims)
return a;
Expand Down
2 changes: 2 additions & 0 deletions src/graph/expression_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ Expr atleast_3d(Expr a);
Expr atleast_4d(Expr a);
Expr atleast_nd(Expr a, size_t dims);

Expr addPosEmbedding(Expr embeddings, float scaleFactor, int startPos);

// create a constant of shape a->shape() and initialize with init
// @TODO: add a && version, to avoid a ref count. NodeInitializers are typically temps.
// @TODO: and/or make this a template on init
Expand Down
48 changes: 48 additions & 0 deletions src/graph/node_operators_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,54 @@ struct ScalarProductNodeOp : public NaryNodeOp {
int axis_;
};

struct PosEmbeddingNodeOp : public NaryNodeOp {
PosEmbeddingNodeOp(Expr embeddings, float scaleFactor, int startPos)
: NaryNodeOp({embeddings}, newShape(embeddings)),
scaleFactor_(scaleFactor),
startPos_(startPos) {}

Shape newShape(Expr a) {
return a->shape();
}

NodeOps forwardOps() override {
using namespace functional;

return {NodeOp(AddPosEmbeddings(val_, child(0)->val(), scaleFactor_, startPos_))};
}

NodeOps backwardOps() override {
ABORT("Not Implemented. Inference Optimization");
}

const std::string type() override { return "Add Positional Embedding"; }

const std::string color() override { return "blue"; }

virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, scaleFactor_);
util::hash_combine(seed, startPos_);
return seed;
}

virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<PosEmbeddingNodeOp>(node);
if(!cnode)
return false;
if(scaleFactor_ != cnode->scaleFactor_)
return false;
if(startPos_ != cnode->startPos_)
return false;
return true;
}

float scaleFactor_;
int startPos_;
};

struct RowsNodeOp : public NaryNodeOp {
RowsNodeOp(Expr a, Expr indices)
: NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {
Expand Down
41 changes: 27 additions & 14 deletions src/models/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Transformer : public EncoderOrDecoderBase {
std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings

std::unordered_map<std::string, Expr> constantCache_;
// attention weights produced by step()
// If enabled, it is set once per batch during training, and once per step during translation.
// It can be accessed by getAlignments(). @TODO: move into a state or return-value object
Expand Down Expand Up @@ -85,8 +86,8 @@ class Transformer : public EncoderOrDecoderBase {
} else {
// @TODO : test if embeddings should be scaled when trainable
// according to paper embeddings are scaled up by \sqrt(d_m)
embeddings = std::sqrt((float)dimEmb) * embeddings; // embeddings were initialized to unit length; so norms will be in order of sqrt(dimEmb)

float scaleFactor = std::sqrt((float)dimEmb);
#ifdef USE_ONNX // TODO 'Sin' op and constant sine generate different result. So, use constant when 'USE_ONNX' is not defined for now.
// precompute the arguments to sin() (the cos(x) are expressed as sin(x+pi/2))
if (sinusoidalEmbeddingsFreq_.empty()) {
Expand All @@ -101,12 +102,10 @@ class Transformer : public EncoderOrDecoderBase {
auto positionRange = graph_->constant({ dimWords, 1, 1 }, inits::range((float)start, (float)start + (float)dimWords));
positionRange->set_name("data_" + std::to_string(batchIndex_) + "_posrange");
auto signal = sin(positionRange * frequencies + cosOffsets);
embeddings = scaleFactor * embeddings + signal;
#else // USE_ONNX
auto signal = graph_->constant({dimWords, 1, dimEmb},
inits::sinusoidalPositionEmbeddings(start));
embeddings = addPosEmbedding(embeddings, scaleFactor, start);
#endif // USE_ONNX

embeddings = embeddings + signal;
}

return embeddings;
Expand All @@ -117,13 +116,27 @@ class Transformer : public EncoderOrDecoderBase {
return addPositionalEmbeddings(input, start, trainPosEmbeddings);
}

Expr triangleMask(int length) const {
// fill triangle mask
std::vector<float> vMask(length * length, 0);
for(int i = 0; i < length; ++i)
for(int j = 0; j <= i; ++j)
vMask[i * length + j] = 1.f;
return graph_->constant({1, length, length}, inits::fromVector(vMask));
Expr triangleMask(int length) {
if (inference_ && constantCache_.count("triangleMask") > 0 &&
constantCache_["triangleMask"]->shape().elements() == length * length) {
return constantCache_["triangleMask"];
} else {
// fill triangle mask
std::vector<float> vMask(length * length, 0);
for(int i = 0; i < length; ++i)
for(int j = 0; j <= i; ++j)
vMask[i * length + j] = 1.f;

Expr e = graph_->constant({1, length, length}, inits::fromVector(vMask));
if(inference_) {
e->setMemoize(true);
if (constantCache_.count("triangleMask") > 0) {
constantCache_["triangleMask"]->setMemoize(false);
}
constantCache_["triangleMask"] = e;
}
return e;
}
}

// convert multiplicative 1/0 mask to additive 0/-inf log mask, and transpose to match result of bdot() op in Attention()
Expand Down Expand Up @@ -714,11 +727,11 @@ class DecoderTransformer : public Transformer<DecoderBase> {
// Used for position embeddings and creating new decoder states.
int startPos = (int)state->getPosition();

auto scaledEmbeddings = addSpecialEmbeddings(embeddings, startPos);
auto scaledEmbeddings = addSpecialEmbeddings(embeddings, startPos); // TODO1
scaledEmbeddings = atleast_nd(scaledEmbeddings, 4);

// reorganize batch and timestep
auto query = transposeTimeBatch(scaledEmbeddings); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
auto query = transposeTimeBatch(scaledEmbeddings); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] TODO1

auto prevQuery = query; // keep handle to untransformed embeddings, potentially used for a final skip connection

Expand Down
4 changes: 4 additions & 0 deletions src/tensors/cpu/tensor_operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ namespace cpu {
ABORT("Not implemented");
}

void AddPosEmbeddings(marian::Tensor result, const marian::Tensor& embeddings, float scaleFactor, int startPos) {
ABORT("Not implemented");
}

template <typename To, typename From>
void CopyCastTo(To* out, const From* in, int length) {
for(int i = 0; i < length; ++i)
Expand Down
50 changes: 50 additions & 0 deletions src/tensors/gpu/tensor_operators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1958,6 +1958,7 @@ __global__ void gLNormalization(T* out,
}
}


void LayerNormalization(Tensor out,
Tensor in,
Tensor gamma,
Expand Down Expand Up @@ -2917,5 +2918,54 @@ void PoolingWithMaskingBackward(Tensor adj,
width,
lastWidth);
}

template<typename T>
__global__ void gComputeSinusoidalPosEmb(T* result, const T* const embedding, float scaleFactor, int startPos,
int dimWords, int dim2, int dimEmb, int elts) {

const float numTimescales = (float)dimEmb / 2;
const float logTimescaleIncrement = log(10000.f) / (numTimescales - 1.f);

const size_t offset = blockDim.x * blockIdx.x + threadIdx.x;

if (offset < elts) {
// Dim2 of signal is 1 so we ignore it
const int colInSignal = offset % dimEmb;
const int rowInSignal = (offset / (dimEmb * dim2)) % dimWords;

const int p = rowInSignal + startPos;
const int i = colInSignal % (int) numTimescales;
const float v = p * exp(i * -logTimescaleIncrement);
const T posSignal = colInSignal < (int) numTimescales? sin(v) : cos(v);
const T scaledEmbedding = (T)scaleFactor * embedding[offset];
result[offset] = scaledEmbedding + posSignal;
}
}

void AddPosEmbeddings(marian::Tensor result, const marian::Tensor& embeddings, float scaleFactor, int startPos) {

int dimEmb = embeddings->shape()[-1];
int broadcastDim = embeddings->shape()[-2];
int dimWords = embeddings->shape()[-3];
int elts = embeddings->shape().elements();

int threads = std::min(elts, MAX_THREADS);
int blocks = (elts + threads - 1) / threads;

cudaSetDevice(result->getDeviceId().no);

if(result->type() == Type::float32) {
gComputeSinusoidalPosEmb<<<blocks, threads>>>(
result->data<float>(), embeddings->data<float>(), scaleFactor, startPos, dimWords, broadcastDim, dimEmb, elts);
#if COMPILE_FP16
} else if(result->type() == Type::float16) {
gComputeSinusoidalPosEmb<<<blocks, threads>>>(
result->data<half>(), embeddings->data<half>(), scaleFactor, startPos, dimWords, broadcastDim, dimEmb, elts);
#endif
} else {
ABORT("gComputeSinusoidalPosEmb not implemented for type {}", result->type());
}
}

} // namespace gpu
} // namespace marian
1 change: 1 addition & 0 deletions src/tensors/tensor_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ void Reduce(Functor functor, AggFunctor aggFunctor, float aggInit,
}

// clang-format off
DISPATCH4(AddPosEmbeddings, marian::Tensor, const marian::Tensor&, float, int);
DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float)
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
Expand Down

0 comments on commit 5b889da

Please sign in to comment.