
Add support for Flash Attention for AMD/ROCm #112997

Closed
chauhang opened this issue Nov 5, 2023 · 10 comments · Fixed by #114309
Labels
ciflow/rocm · module: rocm (AMD GPU support for PyTorch) · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

chauhang (Contributor) commented Nov 5, 2023

🚀 The feature, motivation and pitch

Enable support for the Flash Attention and Memory Efficient SDPA kernels for AMD GPUs.

At present, using these gives the warnings below with the latest nightlies (torch==2.2.0.dev20231105+rocm5.6, pytorch-triton-rocm==2.1.0+34f8189eae):

model.py:187: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:253.)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
model.py:187: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:291.)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
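
For reference, a minimal sketch that restricts SDPA to the flash backend, so a build without flash attention fails loudly instead of silently falling back to the math kernel (shapes and dtype are illustrative):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim), half precision on the GPU.
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Allow only the flash backend; on a build without flash attention this should
# raise a RuntimeError rather than emit the warning and fall back to math.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
```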

ROCm already has an implementation of Tri's FA here: https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support

Alternatives

Users have to manually install the ROCm version of FA and use it in their code, versus using the native PyTorch APIs, as sketched below.
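
A rough sketch of that manual route, assuming the ROCm fork exposes the same `flash_attn_func` entry point as upstream flash-attention (the import and tensor layout are assumptions, not verified against the fork):

```python
import torch
from flash_attn import flash_attn_func  # ROCm fork of flash-attention, installed separately

# flash-attention expects (batch, seq_len, heads, head_dim) in fp16/bf16,
# whereas F.scaled_dot_product_attention takes (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
```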

Additional context

The ROCm build currently has the FA-related flags turned off by default: https://github.com/pytorch/pytorch/blob/main/CMakeLists.txt#L741-L750

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang

@pytorch-bot pytorch-bot bot added the ciflow/rocm and module: rocm (AMD GPU support for PyTorch) labels Nov 5, 2023
@colesbury colesbury added the triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Nov 6, 2023
hongxiayang (Collaborator) commented:

We are working on it internally. We will update upstream after that.

@hongxiayang hongxiayang moved this to Todo in PyTorch on ROCm Nov 7, 2023
@github-project-automation github-project-automation bot moved this from Todo to Done in PyTorch on ROCm Dec 14, 2023
malfet pushed a commit that referenced this issue Dec 14, 2023
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped with the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:

- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.

Fixes #112997

Pull Request resolved: #114309

Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <[email protected]>
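
Concretely, the limitations above restrict the initial kernels to inputs along these lines (a sketch; the shapes are chosen to satisfy the constraints and are not taken from the PR):

```python
import torch
import torch.nn.functional as F

# MI200-only initial support: power-of-two sequence length, head dimension in
# {16, 32, 64, 128}, dense (non-varlen) batches.
batch, heads, seq_len, head_dim = 4, 16, 2048, 128  # 2048 is a power of two
q, k, v = (torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
           for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```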
guilhermeleobas pushed a commit to guilhermeleobas/pytorch that referenced this issue Dec 18, 2023
dmenig pushed a commit to dmenig/pytorch that referenced this issue Dec 21, 2023
pytorchmergebot pushed a commit that referenced this issue Jan 4, 2024
Notes about the updates:

This PR:
1. Skips more flash-attention-related unit tests on MI200.
2. Fixes additional ATen compilation errors after hipification.
3. Fixes the author "root" of a specific commit.
4. Includes the patch from Nikita in favor of block-level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merging.

Original PR (#114309) Note:

This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped with the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:

- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: #115981
Approved by: https://github.com/malfet
pytorchmergebot pushed a commit that referenced this issue Mar 12, 2024
This patch addresses the major limitations in our previous [PR #115981](#115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton).

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
    * Now supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
    * The varlen API will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
    * Now supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
    * Kernels are selected according to autotune information from Triton.

Other improvements from AOTriton include:
* More flexible tensor storage layouts
* A more flexible API

This is a more extensive fix to #112997

Pull Request resolved: #121561
Approved by: https://github.com/malfet, https://github.com/atalman
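
With those restrictions lifted, shapes that the earlier kernels rejected should now be handled, e.g. a non-power-of-two sequence length and an off-list head dimension (illustrative, not taken from the PR):

```python
import torch
import torch.nn.functional as F

# seq_len=1000 is not a power of two, and head_dim=96 is not in {16, 32, 64, 128}
# but is still <= 256, so the AOTriton-backed kernels should accept it.
q, k, v = (torch.randn(2, 8, 1000, 96, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v)
```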
chauhang (Contributor, Author) commented Mar 24, 2024

@jeffdaily @malfet Are the torch nightlies for rocm getting built with the correct flags for SDPA support? I am still getting this error with the latest torch==2.4.0.dev20240323+rocm6.0

UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:515.)
  output = nn.functional.scaled_dot_product_attention(
[pip3] pytorch-triton-rocm==3.0.0+0a22a91d04
[pip3] torch==2.4.0.dev20240323+rocm6.0
[conda] pytorch-triton-rocm       3.0.0+0a22a91d04          pypi_0    pypi
[conda] torch                     2.4.0.dev20240323+rocm6.0          pypi_0    pypi

xinyazhang (Contributor) commented:

> Are the torch nightlies for rocm getting built with the correct flags for SDPA support? I am still getting this error with the latest torch==2.4.0.dev20240323+rocm6.0

Hi @chauhang, if you are referring to memory-efficient attention, this is expected. This PR only adds Flash Attention to PyTorch; we are working on memory-efficient attention support now.

jithunnair-amd (Collaborator) commented:

@chauhang Since this issue is regarding FA support for ROCm in nightlies, and PR #121561 has been reverted, I'm reopening this issue.

pytorchmergebot pushed a commit that referenced this issue Mar 28, 2024

Pull Request resolved: #121561
Approved by: https://github.com/huydhn
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this issue Apr 22, 2024
jinsong-mao commented:

@xinyazhang @jithunnair-amd Hi, the default flash_attention backward now uses the bwd_kernel_dk_dv kernel from AOTriton (https://github.com/ROCm/aotriton/blob/6035fb4aa021d39a18adce883e059d5b99a8192c/v2python/rules/flash/bwd_kernel_dk_dv.py#L7), and it is very inefficient. Is there any optimization configuration or usage we should apply to improve its performance?

jeffdaily (Collaborator) commented:

CC @groenenboomj

maxall41 commented:

@xinyazhang Any updates on implementing memory efficient attention?

xinyazhang (Contributor) commented:

> @xinyazhang Any updates on implementing memory efficient attention?

ME and FA are fundamentally the same algorithm but provide different feature sets; on ROCm they are both implemented in AOTriton. The support is already included in PyTorch 2.4 (#124885).
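
On PyTorch 2.3+/2.4, either backend can also be requested explicitly through the `torch.nn.attention` API, which makes it easy to check what a given build supports (a sketch; shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Pin SDPA to memory-efficient attention; use SDPBackend.FLASH_ATTENTION to
# pin the flash kernel instead. An unsupported backend raises at call time.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```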

xinyazhang (Contributor) commented:

> it is very inefficient. Is there any optimization configuration or usage we should apply to improve its performance?

AOTriton 0.7.1 (https://github.com/ROCm/aotriton/releases/tag/0.7.1b) made some improvements to the bwd kernel. It is ABI-compatible (and only compatible) with AOTriton 0.7b and can be used as a drop-in replacement with the official PyTorch 2.5 wheel.

hongxiayang (Collaborator) commented:

Closing this, as it is supported now.

Projects
Status: Done
8 participants