
Change workspace size in fmha backward #1028

Conversation

@alihassanijr (Contributor) commented Apr 16, 2024

What does this PR do?

TL;DR: I think the FMHA backward kernel uses more scratch memory than it needs, beyond the padding required by 128-bit alignment and tile sizes.

FMHA backward's scratch space for gK and gV is set up to be `num_k_splits * align_up(num_keys, kBlockSizeJ) * align_up(dim, kBlockSizeI)`, repeated over batch and heads.

Given that each CTA computes unique tiles of gK and gV, this means that for every gK/gV tile, `align_up(num_keys, kBlockSizeJ) * align_up(dim, kBlockSizeI)` accum elements are reserved.
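To make the arithmetic concrete, here is a minimal sketch of that sizing. The names (`align_up`, `gkv_scratch_before`, the tile constants) mirror the description above rather than the actual xFormers source, which differs in naming and layout:

```cpp
#include <cstdint>

// Round x up to the nearest multiple of `alignment`; mirrors the
// align_up used in the description above (hypothetical helper).
constexpr int64_t align_up(int64_t x, int64_t alignment) {
  return ((x + alignment - 1) / alignment) * alignment;
}

// Scratch elements reserved per (batch, head) for one of gK/gV before
// this PR: a full num_keys-sized plane per key split, even though each
// CTA only ever writes a single kBlockSizeJ-row tile of it.
constexpr int64_t gkv_scratch_before(int64_t num_k_splits, int64_t num_keys,
                                     int64_t dim, int64_t kBlockSizeJ,
                                     int64_t kBlockSizeI) {
  return num_k_splits * align_up(num_keys, kBlockSizeJ) *
         align_up(dim, kBlockSizeI);
}
```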

I might be totally off, but my understanding is that the gK and gV accumulator pointers aren't even offset by anything relating to `key_start`, which means the same `kBlockSizeJ` rows will be reused over and over.

This means that `align_up(num_keys, kBlockSizeJ)` can be replaced with just `kBlockSizeJ`.
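Under that reading, the fix is the one-term change below (continuing the hypothetical names from the sketch above); the worked numbers illustrate the scale of the saving:

```cpp
// After the fix: one kBlockSizeJ-row tile per key split is enough, since
// the accumulator pointer is never offset by key_start anyway.
constexpr int64_t gkv_scratch_after(int64_t num_k_splits, int64_t dim,
                                    int64_t kBlockSizeJ, int64_t kBlockSizeI) {
  return num_k_splits * kBlockSizeJ * align_up(dim, kBlockSizeI);
}

// Example: with num_keys = 4096 and kBlockSizeJ = 64, the gK/gV scratch
// shrinks by align_up(4096, 64) / 64 = 64x per (batch, head),
// independent of dim.
static_assert(gkv_scratch_before(1, 4096, 128, 64, 32) ==
                  64 * gkv_scratch_after(1, 128, 64, 32),
              "scratch shrinks by num_keys / kBlockSizeJ");
```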

My own use case works fine and passes a memcheck with this change. All unit tests passed for me locally (excluding the recent `torch.compile` test in `test_mem_eff_attention`; I think I need to be on torch nightly?).

Kernel arch tags tested:

  • SM50
    • SM61
  • SM70
  • SM75
  • SM80
    • SM80
    • SM86

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@codecov-commenter
Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 59.92%. Comparing base (5d59023) to head (071b1f0).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1028   +/-   ##
=======================================
  Coverage   59.92%   59.92%           
=======================================
  Files         113      113           
  Lines       10007    10007           
=======================================
  Hits         5997     5997           
  Misses       4010     4010           
Flag     Coverage     Δ
Python   59.92% <ø>   (ø)


@danthe3rd (Contributor) left a comment

Thanks! That's a nice catch :o
Indeed, this over-allocated memory is never used in the kernel. Let me triple-check by running some internal tests, and I'll merge this :)
cc @drisspg

@danthe3rd

All tests pass - merging
Thanks a lot for spotting and submitting a fix!

@danthe3rd danthe3rd merged commit f663712 into facebookresearch:main Apr 16, 2024
8 of 9 checks passed
@alihassanijr alihassanijr deleted the fmha-backward-gK-gV-workspace-size branch April 16, 2024 15:30
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 5, 2024
Backporting a few fixes from xFormers:
* Bug fixes for local attention (which is not exposed in PT at the moment)
* Massively reduced memory usage on the BW pass (see also facebookresearch/xformers#1028)

Essentially, this will also make the xFormers build process much easier, as we will be able to use mem-eff from PyTorch (if the user has a recent enough version) rather than building it at xFormers install time.
The goal is to have the source of truth for these files in PT moving forward, and remove them from xFormers eventually once our users have a recent-enough version of PT.
Pull Request resolved: #127090
Approved by: https://github.com/drisspg