
Improve checks on transformer cache #881

Merged: 5 commits into marian-nmt:master on Jan 24, 2022

Conversation

@graemenail (Member) commented Sep 15, 2021

Description

This PR strengthens the check on the cache in transformer attention, and improves cache access.

As background: @fiqas and I were encountering segfaults in some transformer models, on a custom branch, during decoding on both GPU and CPU. The crash emerged during a call to suppressWords, in which suppressedWordIndices had been corrupted and pointed to an index beyond the vocabulary size. We traced the issue back to a bdot(q,k,...) operation in Attention where the batch dimension of q is smaller than that of k. In MultiHead, the former input is not cached, while the latter is cached and is currently only recomputed when the total number of elements of the relevant input changes. We were unlucky enough to hit a situation in which a stale cached value was kept.
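To illustrate the problem, here is a minimal sketch of an element-count-based reuse condition; it is not the actual transformer.h code, and the names (cache_, prefix, the shape layout) are invented for the example:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for the shapes and cache discussed above;
// not the marian-dev implementation.
struct Shape {
  std::vector<int> dims;
  int elements() const {
    int n = 1;
    for (int d : dims) n *= d;
    return n;
  }
};

struct CachedExpr { Shape shape; /* cached tensor omitted */ };

std::map<std::string, CachedExpr> cache_;

// Old-style reuse condition (as described above): the cached keys are kept
// as long as the total element count matches, even if the batch dimension
// of the input has changed in the meantime.
bool reuseByElementCount(const std::string& prefix, const Shape& keysShape) {
  auto it = cache_.find(prefix + "_keys");
  return it != cache_.end() &&
         it->second.shape.elements() == keysShape.elements();
}

int main() {
  cache_["l1_keys"] = CachedExpr{Shape{{8, 4, 512}}};  // batch 8, length 4
  Shape newKeys{{4, 8, 512}};                          // batch 4, length 8
  // Same element count, different batch dimension: the stale entry is reused,
  // which is exactly the situation that fed a mismatched k into bdot.
  assert(reuseByElementCount("l1", newKeys));
  return 0;
}
```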

The memory corruption seems to have occurred because the node operation for bdot(a,b,...) implicitly assumes that a has the larger (or equal) batch dimension and uses it when setting the resulting shape, while the tensor operation ProdBatched takes the maximum of the two batch sizes. Ideally I would add some logic here to handle mismatched batch dimensions, but I notice that these operations have been moved to _legacy in favour of a new implementation.
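The shape assumption can be sketched like this (illustrative only, not the marian-dev bdot/ProdBatched code):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative batched-product skeleton, not the real implementation.
struct Batched { int batch, rows, cols; };

std::vector<float> prodBatchedSketch(const Batched& a, const Batched& b) {
  // The node operation sizes the result from a's batch dimension
  // (the implicit "a has the larger or equal batch" assumption)...
  std::vector<float> c(static_cast<std::size_t>(a.batch) * a.rows * b.cols, 0.f);
  // ...while the tensor operation iterates over the maximum of the two
  // batch sizes, as described above for ProdBatched.
  int loops = std::max(a.batch, b.batch);
  for (int i = 0; i < loops; ++i) {
    std::size_t offset = static_cast<std::size_t>(i) * a.rows * b.cols;
    // When b.batch > a.batch (smaller q, larger cached k), offsets for
    // i >= a.batch lie outside c: an out-of-bounds write that can corrupt
    // neighbouring buffers such as suppressedWordIndices.
    (void)offset;  // per-batch matrix product omitted
  }
  return c;
}
```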

While I have not seen this arise on master, replacing the check on the number of elements with a check on the input shape from which the cache entry was computed is a stronger requirement, and should eliminate such calls to bdot.
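One way to express the strengthened requirement is sketched below; this is not the literal patch (the later discussion in this thread suggests the merged change effectively folds the shape into the cache key), and the helper names are invented:

```cpp
#include <map>
#include <string>
#include <vector>

// Illustrative shape-keyed cache; not the actual patch.
std::string shapeKey(const std::string& prefix, const std::vector<int>& dims) {
  std::string key = prefix + "_keys";
  for (int d : dims) key += "_" + std::to_string(d);
  return key;
}

struct CachedExpr { /* cached projection omitted */ };
std::map<std::string, CachedExpr> cache_;

// Strengthened requirement: a cache entry is reused only if it was computed
// from an input with exactly this shape; any dimension change forces a
// recompute, so bdot never sees a k left over from a different batch layout.
CachedExpr& cachedKeys(const std::string& prefix, const std::vector<int>& dims) {
  auto key = shapeKey(prefix, dims);
  auto it = cache_.find(key);
  if (it == cache_.end())
    it = cache_.emplace(key, CachedExpr{}).first;  // recompute projection here
  return it->second;
}
```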

List of changes:

  • Transformer attention cache is checked against the shape of the input
  • Improved access to the cache
  • Remove trailing whitespaces

Added dependencies: none

How to test

I tested the fix on the impacted model on our branch. I also ran the regression tests on this PR, after updating the expected outputs to match those obtained from current master on the same machine.

Checklist

  • I have tested the code manually
  • I have run regression tests
  • I have read and followed CONTRIBUTING.md
  • I have updated CHANGELOG.md

@kpu (Member) commented Oct 4, 2021

Brought this up with @emjotde today; he says he will take a look. Might want a hash specialization in hash.h.

@kpu kpu requested a review from emjotde November 1, 2021 17:24
@snukky (Member) commented Nov 1, 2021

The automatic check with Ubuntu 16.04 is just an invalid artifact of previous runs. Ubuntu 16.04 has already been removed and this PR passes all required checks.

@emjotde emjotde self-assigned this Nov 1, 2021
@snukky snukky merged commit 894a07a into marian-nmt:master Jan 24, 2022
@emjotde (Member) commented May 29, 2022

@graemenail Hi, I have to revert this PR internally. It's actually causing a ton of "memory leaks" in the memory allocator during decoding. I wonder how you guys never ran into that.

@emjotde (Member) commented May 29, 2022

I can take a look later at how to get this back, but for now it is causing many more bugs than it's solving.

@graemenail (Member, Author) commented:

Hi @emjotde; that's fine - I have no strong feelings about this code. It was only ever meant to be a stopgap until a memoized solution was implemented.

About the memory leaks, was it specifically the changes in this PR, or did the previous implementation also suffer? I think this PR will keep more objects in the cache as the tensor shape is now part of the cache key.
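For illustration only (a made-up sketch, not the merged code): if the shape is folded into the key and nothing is ever evicted, every distinct batch/length combination seen during decoding leaves its own entry behind.

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  // Value is a stand-in for a cached tensor; shapes and ranges are invented.
  std::map<std::string, int> cache_;
  for (int batch = 1; batch <= 64; ++batch)
    for (int length = 1; length <= 200; ++length)
      cache_["l1_keys_" + std::to_string(batch) + "_" + std::to_string(length)] = 0;
  // 12800 live entries for a single layer, versus one slot before the change;
  // with real tensors behind them this is the kind of growth that ends in OOM.
  std::cout << cache_.size() << " cached entries\n";
  return 0;
}
```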

@emjotde (Member) commented May 30, 2022

I did a git bisect, and it was this commit. During decoding with a large ensemble there was a growing memory allocation, not really a leak, but it would eventually result in an OOM.

@graemenail (Member, Author) commented:

That sounds like it's the cache key. The cache is internal to the transformer and persists until it goes out of scope, which is seemingly too long.

Is the sync from internal coming soon? Otherwise I'll patch this today to be more like the old implementation, but retain the check on shape.
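A sketch of that idea (hypothetical, not the eventual fix): keep a single slot per cache name as before, but store the shape it was computed from and recompute on any mismatch, so memory stays bounded while the shape check is retained.

```cpp
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical single-slot cache: one entry per name, overwritten rather than
// accumulated when the source shape changes.
struct CacheEntry {
  std::vector<int> srcShape;   // shape the cached value was computed from
  std::vector<float> value;    // stand-in for the cached projection
};

std::map<std::string, CacheEntry> cache_;

const std::vector<float>& cachedOrRecompute(
    const std::string& name, const std::vector<int>& srcShape,
    const std::function<std::vector<float>()>& compute) {
  auto& entry = cache_[name];          // at most one slot per name
  if (entry.srcShape != srcShape) {    // shape check retained from this PR
    entry.srcShape = srcShape;
    entry.value = compute();           // recompute and overwrite in place
  }
  return entry.value;
}
```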

@emjotde (Member) commented May 30, 2022

Yes, about to sync now. It was this issue that made me delay the sync, since I thought I had introduced it with something internal. I have already reverted it; that was easy enough considering how local this PR is. We can then just re-open it and see if we get the memory growth under control. I have a good test case now (I cannot share it, unfortunately, but I can run it).

emjotde added a commit that referenced this pull request May 30, 2022
This PR reverts changes to transformer caching (public PR #881)

It seems to cause catastrophic memory leaks or incorrect de-allocation during decoding.
@emjotde (Member) commented May 30, 2022

Synced. I think I will do a release now too.

@graemenail (Member, Author) commented:

Thanks @emjotde - sorry for the headache! We can revisit the caching; I'll try to dig up the model we had the issue with.

@emjotde (Member) commented May 30, 2022

No biggie. I will actually hold off on the release until internal engineering confirms that all the production test cases run smoothly, i.e. in a day or two.

graemenail added a commit to graemenail/marian-dev that referenced this pull request Jun 8, 2022