Improve checks on transformer cache #881
Conversation
Brought up with @emjotde today; he says he will take a look. Might want a hash specialization in hash.h.
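Roughly what I have in mind, as an illustrative sketch only (the `Shape` stand-in and the `hashCombine` helper are placeholders for this example, not the actual contents of hash.h):

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a tensor shape type; the real class and its
// accessors may differ.
struct Shape {
  std::vector<int> dims;
};

// Boost-style hash combine, the kind of helper one might add to hash.h.
inline void hashCombine(std::size_t& seed, std::size_t value) {
  seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Example specialization so that Shape can be used directly as a key in
// std::unordered_map-based caches.
namespace std {
template <>
struct hash<Shape> {
  std::size_t operator()(const Shape& shape) const noexcept {
    std::size_t seed = shape.dims.size();
    for (int d : shape.dims)
      hashCombine(seed, std::hash<int>{}(d));
    return seed;
  }
};
}  // namespace std
```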
Branch updated from f41fb00 to a2941ff.
The automatic check with Ubuntu 16.04 is just an invalid artifact of previous runs. Ubuntu 16.04 has already been removed, and this PR passes all required checks.
@graemenail Hi, I have to revert this PR internally. It's actually causing a ton of "memory leaks" in the memory allocator during decoding. I wonder how you guys never ran into that.
I can take a look later at how to get that back, but for now this is causing many more bugs than it's solving.
Hi @emjotde; that's fine - I have no strong feelings about this code. It was only ever meant to be a stopgap until a memoized solution was implemented. About the memory leaks, was it specifically the changes in this PR, or did the previous implementation also suffer? I think this PR will keep more objects in the cache as the tensor shape is now part of the cache key.
I did a git bisect, and it was this commit. During decoding with a large ensemble there was a growing memory allocation happening, not really a leak, but it would result in OOM eventually.
That sounds like it's the cache key. This cache is internal to transformer, and persists until it leaves scope, which is seemingly too long. Is the sync from internal coming soon? Otherwise I'll patch this today to be more like the old implementation, but retain the check on shape.
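For concreteness, the kind of patch I have in mind is a single-slot cache along these lines (illustrative names and types only, not the actual transformer code): keep at most one cached entry and recompute it whenever the input shape changes, so the cache cannot grow with the number of distinct shapes seen during decoding.

```cpp
#include <functional>
#include <optional>
#include <vector>

// Illustrative stand-ins: the real code would use Marian's Expr and Shape types.
using Shape = std::vector<int>;
struct Expr { Shape shape; };

// Single-slot cache: holds at most one entry and recomputes it whenever the
// input shape differs from the shape the entry was built from. Unlike a map
// keyed by shape, it cannot grow as decoding encounters new shapes.
class ShapeCheckedCache {
public:
  Expr get(const Expr& input, const std::function<Expr(const Expr&)>& compute) {
    if(!cached_ || input.shape != cachedShape_) {
      cached_      = compute(input);   // recompute, e.g. the cached key/value projection
      cachedShape_ = input.shape;      // remember the shape the entry was computed from
    }
    return *cached_;
  }

private:
  std::optional<Expr> cached_;
  Shape cachedShape_;
};
```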
Yes, about to sync now. It was this issue that made me delay the sync, since I thought I had introduced it with something internal. I reverted already; it was easy enough considering how local this PR is. We can then just re-open it and see if we get that under control. I have a good test case now (cannot share it unfortunately, but I can run it).
This PR reverts changes to transformer caching (public PR #881). It seems to cause catastrophic memory leaks or incorrect de-allocation during decoding.
Synced. I think I will do a release now too.
Thanks @emjotde - sorry for the headache! We can revisit the caching; I'll try to dig up the model we had the issue with.
No biggie. I will actually hold off on the release until internal engineering confirms all the production test cases run smoothly, i.e. in a day or two.
Description
This PR strengthens the check on the cache in transformer attention, and improves cache access.
As background to this, @fiqas and I were encountering segfaults in some transformer models, on a custom branch, during decoding on both GPU and CPU. The crash emerged during a call to `suppressWords`, in which `suppressedWordIndices` had been corrupted and pointed to an index beyond the vocabulary size. We traced this issue back to a `bdot(q, k, ...)` operation in Attention where the batch dimension of `q` is smaller than that of `k`. In MultiHead, the former input is not cached, while the latter input is cached and is currently only recomputed when the total number of elements of the relevant input changes. We were unlucky enough to encounter a situation in which a stale cache value was kept.

The memory corruption seems to have occurred because the node operation for `bdot(a, b, ...)` implicitly assumes that `a` has the larger (or equal) batch dimension, and this is used when setting the resulting shape. The tensor operation ProdBatched takes the maximum of the batch sizes. Ideally, I would add some logic there to handle this case, but I notice that these operations have been moved to `_legacy` in favour of a new implementation.

While I have not seen this arise on master, replacing the check on the number of elements with a check on the input shape from which the cache entry was computed is a stronger requirement, and should remove such calls to `bdot`.
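To make the failure mode concrete, here is a toy sketch with made-up dimensions (not Marian code) of what goes wrong when the output is sized from `a`'s batch dimension while the batched kernel iterates over the maximum of the two batch sizes:

```cpp
#include <algorithm>
#include <cstdio>

// Toy numbers only: the bdot node sizes its output from a's batch dimension,
// while a ProdBatched-style kernel iterates over the maximum of the two
// batch sizes.
int main() {
  const int batchA = 2, batchB = 8;   // q ends up with a smaller batch than the cached k
  const int rows = 4, cols = 4;

  const long outAllocated = static_cast<long>(batchA) * rows * cols;                    // what the node allocates
  const long outWritten   = static_cast<long>(std::max(batchA, batchB)) * rows * cols;  // what the kernel writes

  // outWritten > outAllocated: the extra writes land past the end of the output
  // buffer and can corrupt neighbouring allocations, which is consistent with
  // the corrupted suppressedWordIndices we observed in suppressWords.
  std::printf("allocated %ld elements, kernel writes %ld\n", outAllocated, outWritten);
  return 0;
}
```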
List of changes:
Added dependencies: none
How to test
I tested the fix on the impacted model on our branch. I also ran the regression tests on this PR, after updating the expected outputs to match those obtained from current master on the same machine.
Checklist