Add support for Flash Attention for AMD/ROCm #112997
Comments
We are working on it internally. We will update upstream after that.
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency that serves as the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:
- [ ] Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.

Fixes #112997

Pull Request resolved: #114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <[email protected]>
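For readers who want to try the feature as it shipped in this initial PR, a minimal sketch is below (not taken from the PR itself). It forces the flash backend through the public SDPA API with shapes that respect the limitations listed above: fp16 inputs, a power-of-two sequence length, and a head dimension of 64. The `torch.backends.cuda.sdp_kernel` context manager was the backend-selection API at the time; all shapes and values here are illustrative assumptions.

```python
# Sketch (not from the PR): exercise the initial ROCm flash-attention path
# within the limitations listed above -- fp16, power-of-two sequence length,
# head dimension in {16, 32, 64, 128}. All shapes are illustrative.
import torch
import torch.nn.functional as F

device = "cuda"  # ROCm GPUs are exposed through the CUDA device API
batch, heads, seq_len, head_dim = 2, 8, 1024, 64  # seq_len is a power of two

q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash backend so an unsupported configuration fails loudly
# instead of silently falling back to the math path.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 1024, 64])
```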
Note about the updates. This PR:
1. Skips more Flash Attention related UTs on MI200.
2. Fixes additional ATen compilation errors after hipification.
3. Fixes the author "root" of a specific commit.
4. Includes the patch from Nikita in favor of block-level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merge.

Original PR (#114309) note:

This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency that serves as the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:
- Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: #115981
Approved by: https://github.com/malfet
This patch addresses the major limitations of our previous [PR #115981](#115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):

- [x] Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
  * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
  * Arbitrary sequence lengths are now supported.
- [ ] No support for varlen APIs.
  * The varlen API will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
  * Arbitrary head dimensions <= 256 are now supported.
- [x] Performance is still being optimized.
  * Kernels are selected according to autotune information from Triton.

Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix for #112997.

Pull Request resolved: #121561
Approved by: https://github.com/malfet, https://github.com/atalman
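As a rough illustration of the lifted restrictions (a sketch, not code from the PR), the snippet below runs SDPA with a non-power-of-two sequence length and a head dimension outside the old {16, 32, 64, 128} set, using the newer `torch.nn.attention.sdpa_kernel` selection API available in PyTorch 2.3+. The shapes and dtype are assumptions for illustration.

```python
# Sketch: arbitrary sequence length and head dimension (<= 256) with the
# AOTriton-backed flash attention, selected via the PyTorch 2.3+ API.
# Shapes and dtype are illustrative assumptions.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 12, 1000, 96, device="cuda", dtype=torch.bfloat16)  # seq_len=1000, head_dim=96
k = torch.randn_like(q)
v = torch.randn_like(q)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # restrict SDPA to the flash backend
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 12, 1000, 96])
```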
@jeffdaily @malfet Are the torch nightlies for ROCm being built with the correct flags for SDPA support? I am still getting this error with the latest nightly.
Hi @chauhang, if you're referring to memory-efficient attention, this is expected. This PR only adds Flash Attention to PyTorch; we are working on memory-efficient attention support now.
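One way to see which SDPA backends are currently enabled is sketched below. These toggles reflect user-level enablement rather than whether the kernels were actually compiled into the build, and the `gcnArchName` check assumes a ROCm build that exposes that device property.

```python
# Sketch: query the user-level SDPA backend toggles. These report whether a
# backend is *enabled*, not whether it was compiled into the wheel; a build
# without memory-efficient attention will still warn or error when that
# kernel is actually requested.
import torch

print("flash enabled:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math enabled:         ", torch.backends.cuda.math_sdp_enabled())

# On ROCm builds, flash support is also gated on the GPU architecture
# (e.g. gfx90a / gfx942); gcnArchName may not be exposed on every build.
props = torch.cuda.get_device_properties(0)
print("arch:", getattr(props, "gcnArchName", "<not exposed>"))
```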
This patch addresses the major limitations of our previous [PR #115981](#115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):

- [x] Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
  * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
  * Arbitrary sequence lengths are now supported.
- [ ] No support for varlen APIs.
  * The varlen API will be supported in a future release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
  * Arbitrary head dimensions <= 256 are now supported.
- [x] Performance is still being optimized.
  * Kernels are selected according to autotune information from Triton.

Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix for #112997.

Pull Request resolved: #121561
Approved by: https://github.com/huydhn
@xinyazhang @jithunnair-amd Hi, the default flash_attention backward now uses the bwd_kernel_dk_dv kernel from AOTriton (https://github.com/ROCm/aotriton/blob/6035fb4aa021d39a18adce883e059d5b99a8192c/v2python/rules/flash/bwd_kernel_dk_dv.py#L7), and its efficiency is very low. Is there any optimization configuration or usage we should apply to improve the performance?
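To put numbers on the backward-pass cost raised here, a simple micro-benchmark sketch like the one below can help; timings depend entirely on the GPU, ROCm, and AOTriton versions, and all shapes and iteration counts are assumptions.

```python
# Sketch: micro-benchmark the SDPA forward + backward to measure the backward
# (dK/dV) kernel cost discussed above. CUDA events work on ROCm through HIP.
# Shapes and iteration counts are arbitrary assumptions.
import torch
import torch.nn.functional as F

q, k, v = (
    torch.randn(4, 16, 2048, 128, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

def step():
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()
    q.grad = k.grad = v.grad = None  # drop grads so they don't accumulate

for _ in range(10):  # warm-up
    step()
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(50):
    step()
end.record()
torch.cuda.synchronize()
print(f"fwd+bwd: {start.elapsed_time(end) / 50:.3f} ms/iter")
```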
@xinyazhang Any updates on implementing memory-efficient attention?
ME (memory-efficient attention) and FA (Flash Attention) are fundamentally the same algorithm but provide different feature sets, and on ROCm they are both implemented in AOTriton. The support is already included in PyTorch 2.4 (#124885).
AOTriton 0.7.1 (https://github.com/ROCm/aotriton/releases/tag/0.7.1b) makes some improvements to the bwd kernel. It is ABI-compatible (and only compatible) with AOTriton 0.7b and can be used as a drop-in replacement for the official PyTorch 2.5 wheel.
Closing this, as it is supported now.
🚀 The feature, motivation and pitch
Enable support for the Flash Attention, Memory-Efficient Attention, and SDPA kernels for AMD GPUs.
At present, using these with the latest nightlies (torch==2.2.0.dev20231105+rocm5.6, pytorch-triton-rocm==2.1.0+34f8189eae) produces a warning that the kernels are unavailable.
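A minimal reproduction along these lines (a sketch with assumed shapes; the exact warning text depends on the build) is:

```python
# Sketch of a minimal reproduction (assumed shapes). On a ROCm build without
# the flash / memory-efficient kernels, requesting them is expected to produce
# the warning and fall back to the math implementation.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=True):
    out = F.scaled_dot_product_attention(q, k, v)  # falls back to the math path
```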
ROCm already has an implementation of Tri's FA here: https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support
Alternatives
Users have to manually install the ROCm version of FA and use it in their code, instead of using the native PyTorch APIs.
Additional context
The ROCm build currently has the FA-related flags turned off by default: https://github.com/pytorch/pytorch/blob/main/CMakeLists.txt#L741-L750
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang