Check shapes on transformer cache
graemenail committed Aug 24, 2022
1 parent e27da62 commit d186d94
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/models/transformer.h
@@ -28,7 +28,7 @@ class Transformer : public EncoderOrDecoderBase {

 protected:
   using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
-  std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
+  std::unordered_map<std::string, std::pair<Shape, Expr>> cache_; // caching transformation of the encoder that should not be created again
   mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings

   bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
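The hunk above widens each cache entry from a bare Expr to a std::pair<Shape, Expr>, so the shape recorded when the entry was created can be compared against the current input. For illustration only, a minimal self-contained C++ sketch of why comparing full shapes is stricter than comparing element counts; SimpleShape below is a hypothetical stand-in for Marian's Shape:

  #include <cassert>
  #include <functional>
  #include <numeric>
  #include <vector>

  // Hypothetical stand-in for marian::Shape, for illustration only.
  struct SimpleShape {
    std::vector<int> dims;
    int elements() const {
      return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>());
    }
    bool operator==(const SimpleShape& other) const { return dims == other.dims; }
  };

  int main() {
    SimpleShape cached{{2, 3}};   // shape recorded when the projection was cached
    SimpleShape incoming{{3, 2}}; // same element count, different layout

    assert(cached.elements() == incoming.elements()); // element-count check: would reuse the stale entry
    assert(!(cached == incoming));                    // shape check: forces recomputation
    return 0;
  }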
@@ -288,10 +288,10 @@ class Transformer : public EncoderOrDecoderBase {
     // Caching transformation of the encoder that should not be created again.
     // @TODO: set this automatically by memoizing encoder context and
     // memoization propagation (short-term)
-    if (cache                                                                        // if caching
-      && cache_.count(prefix + "_keys") > 0                                          // and the keys expression has been seen
-      && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
-      kh = cache_[prefix + "_keys"];                                                 // then return cached tensor
+    if (cache                                                                        // if caching
+      && cache_.count(prefix + "_keys") > 0                                          // and the keys expression has been seen
+      && cache_[prefix + "_keys"].first == keys->shape()) {                          // and the underlying shape did not change
+      kh = cache_[prefix + "_keys"].second;                                          // then return cached tensor
     }
     else {
       int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
@@ -300,22 +300,22 @@

       kh = affine(keys, Wk, bk);     // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
       kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-      cache_[prefix + "_keys"] = kh;
+      cache_[prefix + "_keys"] = std::make_pair(keys->shape(), kh);
     }

     Expr vh;
     if (cache
       && cache_.count(prefix + "_values") > 0
-      && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
-      vh = cache_[prefix + "_values"];
+      && cache_[prefix + "_values"].first == values->shape()) {
+      vh = cache_[prefix + "_values"].second;
     } else {
       int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
       auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
       auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros());

       vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
       vh = SplitHeads(vh, dimHeads);
-      cache_[prefix + "_values"] = vh;
+      cache_[prefix + "_values"] = std::make_pair(values->shape(), vh);
     }

     int dimBeam = q->shape()[-4];
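Outside the diff context, the overall store-and-validate pattern can be summarized in a condensed, self-contained C++ sketch; Shape and Expr here are simplified stand-ins for Marian's types, and cachedOrCompute is a hypothetical helper rather than a function from the repository:

  #include <memory>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  // Simplified stand-ins for marian::Shape and Expr, for illustration only.
  using Shape = std::vector<int>;
  struct Node { Shape shape; };
  using Expr  = std::shared_ptr<Node>;

  // Each cache entry pairs the input shape with the projection computed from it,
  // mirroring std::unordered_map<std::string, std::pair<Shape, Expr>> cache_.
  std::unordered_map<std::string, std::pair<Shape, Expr>> cache;

  Expr cachedOrCompute(const std::string& key, const Expr& input, bool useCache) {
    auto it = cache.find(key);
    if (useCache && it != cache.end() && it->second.first == input->shape)
      return it->second.second;                                   // shape matches: reuse the cached tensor
    auto projected = std::make_shared<Node>(Node{input->shape});  // placeholder for affine(...) + SplitHeads(...)
    cache[key] = std::make_pair(input->shape, projected);         // remember the shape alongside the result
    return projected;
  }

Because a shape mismatch simply falls through to recomputation and overwrites the stale entry, no separate invalidation step is needed.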
