Check size on transformer cache
graemenail committed Jun 8, 2022
1 parent e27da62 commit a817edd
4 changes: 2 additions & 2 deletions src/models/transformer.h
@@ -290,7 +290,7 @@ class Transformer : public EncoderOrDecoderBase {
   // memoization propagation (short-term)
   if (cache // if caching
       && cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
-      && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
+      && cache_[prefix + "_keys"]->shape() == keys->shape()) { // and the underlying shape did not change
     kh = cache_[prefix + "_keys"]; // then return cached tensor
   }
   else {
@@ -306,7 +306,7 @@ class Transformer : public EncoderOrDecoderBase {
   Expr vh;
   if (cache
       && cache_.count(prefix + "_values") > 0
-      && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
+      && cache_[prefix + "_values"]->shape() == values->shape()) {
     vh = cache_[prefix + "_values"];
   } else {
     int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
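A note on why the stricter check matters: two tensors can hold the same number of elements while having different layouts (for example when batch and sequence dimensions are regrouped between decoding steps), so comparing shape().elements() can report a false cache hit and return a stale tensor with the wrong layout. Comparing the full shape only reuses the cached projection when the dimensions match exactly. Below is a minimal C++ sketch of the two comparison semantics; the Shape struct here is a simplified stand-in invented for illustration, not Marian's actual marian::Shape class.

    #include <cassert>
    #include <vector>

    // Simplified stand-in for a tensor shape, for illustration only;
    // Marian's real marian::Shape type differs.
    struct Shape {
      std::vector<int> dims;

      // Total number of elements: the product of all dimensions.
      int elements() const {
        int n = 1;
        for (int d : dims) n *= d;
        return n;
      }

      // Shape equality requires identical dimensions, in order.
      bool operator==(const Shape& other) const { return dims == other.dims; }
    };

    int main() {
      Shape cached{{8, 64, 1, 32}};  // hypothetical cached keys layout
      Shape current{{8, 32, 1, 64}}; // same element count, different layout

      // Old check: passes, so the stale cached tensor would be reused.
      assert(cached.elements() == current.elements());

      // New check: fails, so the cache entry is recomputed instead.
      assert(!(cached == current));
      return 0;
    }

The element-count comparison was only a weaker necessary condition for cache validity; the full shape comparison is the sufficient one, at no extra cost beyond comparing a short dimension vector.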
