Change workspace size in fmha backward #1028
Conversation
FMHA backward's scratch space for gK and gV is set up to be `num_k_splits * align_up(num_keys, kBlockSizeJ) * align_up(dim, kBlockSizeI)`, repeated over batch and heads. Given that each CTA computes unique tiles of gK and gV, this means that for every gK/gV tile, `align_up(num_keys, kBlockSizeJ) * align_up(dim, kBlockSizeI)` accumulator elements are reserved.
I might be totally off, but my understanding is that the gK and gV accumulator pointers aren't even offset by anything related to key_start, so the same `kBlockSizeJ` rows are reused over and over. This means that `align_up(num_keys, kBlockSizeJ)` can be replaced with just `kBlockSizeJ`.
My own use case works fine and passes a memcheck with this change. All unit tests pass (excluding the recent torch.compile test in test_mem_eff_attention; I think I need to be on torch nightly?).
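To make the sizing concrete, here is a minimal Python sketch of the two formulas. The actual code is C++ in the cutlass backward kernel; the helper names, block sizes, and shapes below are made up for illustration and are not taken from this PR:

```python
def align_up(x: int, alignment: int) -> int:
    # Round x up to the next multiple of `alignment`.
    return ((x + alignment - 1) // alignment) * alignment

def gk_gv_accum_elems(num_k_splits, num_keys, dim, kBlockSizeI, kBlockSizeJ, patched):
    # Accumulator elements reserved for a gK/gV tile per (batch, head).
    rows = kBlockSizeJ if patched else align_up(num_keys, kBlockSizeJ)
    return num_k_splits * rows * align_up(dim, kBlockSizeI)

# Hypothetical values: num_keys=8192, dim=128, kBlockSizeI=64, kBlockSizeJ=128, 1 key split.
before = gk_gv_accum_elems(1, 8192, 128, 64, 128, patched=False)  # 1,048,576 elements
after = gk_gv_accum_elems(1, 8192, 128, 64, 128, patched=True)    # 16,384 elements
print(before // after)  # -> 64, roughly num_keys / kBlockSizeJ less scratch per tile
```

The ratio grows with sequence length, so the over-allocation is most visible for long-context workloads.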
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@ Coverage Diff @@
## main #1028 +/- ##
=======================================
Coverage 59.92% 59.92%
=======================================
Files 113 113
Lines 10007 10007
=======================================
Hits 5997 5997
Misses 4010 4010
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Thanks! That's a nice catch :o
Indeed, this over-allocated memory is never used in the kernel. Let me triple-check by running some internal tests and I'll merge this :)
cc @drisspg
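A rough way to sanity-check the effect outside of the unit tests is to compare peak device memory around a forward + backward call on builds with and without the change. This is only a sketch under some assumptions: the tensor shapes are arbitrary, and it presumes the cutlass memory-efficient backward is dispatched for these inputs and that its workspace is allocated through PyTorch's caching allocator (so it shows up in max_memory_allocated):

```python
import torch
import xformers.ops as xops

def peak_attention_bw_bytes(B=4, M=8192, H=8, K=128, dtype=torch.float16):
    # Peak allocated bytes for one memory_efficient_attention forward + backward.
    q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=dtype, requires_grad=True)
               for _ in range(3))
    torch.cuda.reset_peak_memory_stats()
    out = xops.memory_efficient_attention(q, k, v)
    out.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

print(f"peak memory: {peak_attention_bw_bytes() / 2**20:.1f} MiB")
```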
All tests pass - merging
Backporting a few fixes from xFormers:

* Bug fixes for local attention (which is not exposed in PT at the moment)
* Massively reduced memory usage on the BW pass (see also facebookresearch/xformers#1028)

Essentially this will also make the xFormers build process much easier, as we will be able to use mem-eff from PyTorch (if the user has a recent enough version) rather than building it at xFormers install time. The goal is to have the source of truth for these files in PT moving forward, and remove them from xFormers eventually once our users have a recent-enough version of PT.

Pull Request resolved: #127090
Approved by: https://github.com/drisspg
What does this PR do?
TLDR; I think the FMHA backward kernel uses more scratch memory than it needs, aside from padding due to 128-bit alignment and tile sizes.
Kernel archtags tested:
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
@danthe3rd