
Fast Multi-head Attention support on AMD ROCm #977

Closed
wants to merge 480 commits

Conversation

@qianfengz (Contributor) commented Feb 7, 2024

This PR adds three flash-attention implementations for AMD ROCm

  1. A generic FMHA forward based on composable_kernel kernel components
  2. A decoder FMHA forward implemented directly in a HIP kernel
  3. An FMHA forward operation based on Triton

In more detail, the following code is added in this PR (a usage sketch follows this list)

  1. The Xformers operator and its C++ implementation for the generic FMHA forward, as well as the underlying composable_kernel_tiled submodule

    xformers/ops/fmha/ck.py
    xformers/csrc/attention/hip_fmha/
    third_party/composable_kernel_tiled/

  2. Xformers operators and their C++ implementations for the decoder FMHA forward

    xformers/ops/fmha/ck_decoder.py, ck_splitk.py
    xformers/csrc/attention/hip_fmha/

  3. Xformers operator for the Triton FMHA forward

    xformers/ops/fmha/triton.py

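For illustration only (not a file in this PR), here is a minimal usage sketch of the new forward-only path through the Xformers dispatch API. It assumes the ck.FwOp operator added in xformers/ops/fmha/ck.py and a ROCm build of PyTorch; tensors use the [batch, seqlen, heads, head_dim] layout that Xformers expects.

    # Hypothetical usage sketch; ck.FwOp is the generic forward operator added here
    import torch
    import xformers.ops as xops
    from xformers.ops import fmha

    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

    # Generic FMHA forward via the composable_kernel-backed operator
    out = xops.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)

    # Causal attention through the usual attention-bias mechanism
    out_causal = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=fmha.attn_bias.LowerTriangularMask(), op=fmha.ck.FwOp
    )
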
The following commands are used to verify the implementation (a minimal correctness sketch follows the list)

#> pytest tests/test_mem_eff_attention.py::test_forward
#> pytest tests/test_mem_eff_attention.py::test_mqa_forward
#> pytest tests/test_mem_eff_attention.py::test_decoder
#> pytest tests/test_mem_eff_attention.py::test_splitk_decoder
#> pytest tests/test_mem_eff_attention.py::test_splitk_reference
#> pytest tests/test_mem_eff_attention.py::test_triton_splitk_decoder

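Beyond the pytest entry points above, a minimal hand-rolled correctness check (illustrative only, not one of the PR's tests) could compare the ck forward output against PyTorch's reference scaled-dot-product attention; the transposes convert between the Xformers [batch, seqlen, heads, head_dim] layout and PyTorch's [batch, heads, seqlen, head_dim].

    # Illustrative correctness sketch; the fp16 tolerances are assumptions
    import torch
    import xformers.ops as xops
    from xformers.ops import fmha

    q, k, v = (torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))

    out_ck = xops.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)
    ref = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    torch.testing.assert_close(out_ck, ref, atol=2e-2, rtol=2e-2)
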
The following scripts are used to benchmark the performance of the implementation (a manual timing sketch follows the list)

#> python xformers/benchmarks/benchmark_mem_eff_attention.py
#> python xformers/benchmarks/benchmark_mem_eff_attention_mqa.py
#> python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
#> python xformers/benchmarks/benchmark_attn_decoding.py

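For a quick manual throughput check (purely illustrative, separate from the benchmark scripts above), CUDA/HIP events can time the ck forward operator directly.

    # Illustrative timing sketch; shapes and iteration counts are arbitrary
    import torch
    import xformers.ops as xops
    from xformers.ops import fmha

    q, k, v = (torch.randn(8, 2048, 16, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))

    def run():
        return xops.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)

    for _ in range(10):  # warm-up iterations
        run()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        run()
    end.record()
    torch.cuda.synchronize()
    print(f"avg forward time: {start.elapsed_time(end) / 100:.3f} ms")
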
qianfengz and others added 30 commits November 20, 2023 19:01
…mark-attn-decoding

add benchmark_attn_decoding from upstream xformers; run ck fw ops for decoding
@facebook-github-bot added the CLA Signed and module: rocm labels Feb 7, 2024
@@ -0,0 +1,12 @@
/*
A Contributor commented on this diff:

rename this file?

@@ -3,63 +3,430 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
A Contributor commented on this diff:

IIUC, this file implements flash attention here rather than importing it from third_party/flash-attention/flash_attn/flash_attn_triton.py, is that right? Is there any AMD-specific optimization here (by the way, it's great to support more masks than just LowerTriangular), or is it similar?

The Contributor Author replied:

Yes, this is from a patch submitted as a PR to http://github.com/ROCmSoftwarePlatform/xformers by [email protected].

But we did not keep the PR records because we just re-built the repo.

I will re-submit the PR from a http://github.com/ROCmSoftwarePlatform/xformers branch rather than from my personal repo.

Thanks

@qianfengz qianfengz closed this Feb 8, 2024
@qianfengz qianfengz deleted the dev_to_upstream branch February 8, 2024 05:19
@HinaHyugaHime

was this scrapped or something?
