Fast Multi-head Attention support on AMD ROCm #978

Merged
540 commits merged into facebookresearch:main on Mar 4, 2024

Conversation

qianfengz (Contributor)

This PR adds three flash-attention implementations for AMD ROCm:

  1. A generic FMHA forward based on composable_kernel components, accelerated on AMD MI2xx/MI3xx GPUs
  2. A decoder FMHA forward implemented directly as a HIP kernel
  3. A Triton-based FMHA forward operator

In more detail, the following code is added in this PR (a minimal usage sketch follows the list below):

  1. xFormers operator and its C++ implementation for the generic FMHA forward, as well as the underlying composable_kernel_tiled submodule

    xformers/ops/fmha/ck.py
    xformers/csrc/attention/hip_fmha/
    third_party/composable_kernel_tiled/

  2. xFormers operators and their C++ implementation for the decoder FMHA forward

    xformers/ops/fmha/ck_decoder.py, ck_splitk.py
    xformers/csrc/attention/hip_fmha/

  3. xFormers operator for the Triton FMHA forward

    xformers/ops/fmha/triton.py
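As a quick orientation (not part of the PR itself), a minimal forward-only usage sketch for the new CK operator could look like the following; it assumes a ROCm build of PyTorch and xFormers with this PR applied, and goes through the public memory_efficient_attention_forward entry point:

    import torch
    import xformers.ops as xops
    from xformers.ops import fmha

    # BMHK inputs: batch=2, seqlen=1024, heads=8, head_dim=64, fp16 on the ROCm GPU
    q, k, v = (
        torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
        for _ in range(3)
    )

    # forward-only dispatch through the generic CK operator added by this PR
    out = xops.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)
    print(out.shape)  # torch.Size([2, 1024, 8, 64])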

The following commands are used to verify the implementation:

#> pytest tests/test_mem_eff_attention.py::test_forward
#> pytest tests/test_mem_eff_attention.py::test_mqa_forward
#> pytest tests/test_mem_eff_attention.py::test_decoder
#> pytest tests/test_mem_eff_attention.py::test_splitk_decoder
#> pytest tests/test_mem_eff_attention.py::test_splitk_reference
#> pytest tests/test_mem_eff_attention.py::test_triton_splitk_decoder

The following scripts are used to benchmark the performance of the implementation:

#> python xformers/benchmarks/benchmark_mem_eff_attention.py
#> python xformers/benchmarks/benchmark_mem_eff_attention_mqa.py
#> python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
#> python xformers/benchmarks/benchmark_attn_decoding.py

tenpercent and others added 30 commits (December 7, 2023)
@tenpercent (Contributor)

> should I just do a pastebin? there are quite a few errors like part 1. Edit: side note: my default GFX isn't 1030, but it has to be overridden (exported) to that for my GPU to work on ROCm; before, on the original GFX, there were 2 persisting errors on part 2.

I think the related part is

/home/hina/xformers/third_party/composable_kernel_tiled/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl_hip.hpp:116:17: error: '__builtin_amdgcn_mfma_f32_32x32x8bf16_1k' needs target feature mai-insts
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
                    ^
    1 error generated when compiling for gfx1030.

But yeah, a pastebin works better for such walls of text.
And it looks like we are using a compiler intrinsic, __builtin_amdgcn_mfma_f32_32x32x8bf16_1k, which cannot be compiled for gfx1030, as this arch doesn't support it.
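For reference, a rough runtime check for MFMA-capable hardware could look like the sketch below; the gcnArchName attribute on ROCm builds of PyTorch and the exact architecture list are assumptions here, not something this PR provides:

    import torch

    # MFMA matrix instructions (used by the CK kernels here) exist on CDNA GPUs
    # such as MI100/MI200/MI300; RDNA parts like gfx1030 do not have them.
    MFMA_ARCHS = ("gfx908", "gfx90a", "gfx940", "gfx941", "gfx942")

    def has_mfma(device_index: int = 0) -> bool:
        if not (torch.cuda.is_available() and torch.version.hip):
            return False
        # gcnArchName looks like "gfx90a:sramecc+:xnack-" on ROCm builds of PyTorch
        arch = torch.cuda.get_device_properties(device_index).gcnArchName
        return arch.split(":")[0] in MFMA_ARCHS

    print(has_mfma())  # False on a gfx1030 system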

@HinaHyugaHime

meaning I wouldn't be able to use the end result?

@tenpercent (Contributor) commented Feb 26, 2024

Yes, meaning xformers is not yet supported on your device. Depending on your application, there may be some other solution for using GPU acceleration on your device.

@HinaHyugaHime

Do you have any recommendations to try?

@tenpercent (Contributor)

Now the win8-build CI job fails with:

C:/Users/runneradmin/AppData/Local/Temp/pip-req-build-o3otijw7/third_party/flash-attention/csrc/cutlass/include\cute/int_tuple.hpp(85): error C2665: 'cute::get': no overloaded function could convert all the argument types

@tenpercent (Contributor)

Is there any specific model you're trying to run?

@HinaHyugaHime

PyTorch SD (Stable Diffusion)

@tenpercent (Contributor)

You could try your luck with https://github.com/ROCm/AITemplate/tree/navi3_rel_ver_1.1/examples/05_stable_diffusion

@@ -180,11 +180,13 @@ def validate_inputs(self) -> None:
and self.value.shape == (B, Mkv, Kv)
)
H = self.query.shape[-2]
Hkv = self.key.shape[-2]
Contributor

The current interface of MHA is that H always has to match between q, k and v. If you want to do GQA - e.g. one kv-head for every n q-heads - you have to send 5D inputs. (Thus we're forcing the user to be very explicit.) Do we really want to relax that rule in this PR?
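For context, the existing 5D (BMGHK) convention referred to here can be sketched roughly as follows; the shapes and the expand trick reflect my understanding of the documented GQA/MQA usage, not code from this PR, and assume a GPU device is available:

    import torch
    import xformers.ops as xops

    B, M, K = 2, 1024, 64
    G, H = 2, 4  # G kv-head groups, H query heads per group (8 query heads total)

    q = torch.randn(B, M, G, H, K, device="cuda", dtype=torch.float16)
    # one kv head per group, broadcast across the H query heads of that group
    k = torch.randn(B, M, G, 1, K, device="cuda", dtype=torch.float16).expand(B, M, G, H, K)
    v = torch.randn(B, M, G, 1, K, device="cuda", dtype=torch.float16).expand(B, M, G, H, K)

    out = xops.memory_efficient_attention(q, k, v)  # output shape [B, M, G, H, K]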

xformers/ops/fmha/triton.py (resolved discussion)
@@ -22,13 +22,25 @@ Available implementations
:member-order: bysource
Contributor

(Looks like decoder and triton_splitk should have been added here months ago. 🫢)

Comment on lines +29 to +34
rocm_only = pytest.mark.skipif(
not torch.cuda.is_available() or not torch.version.hip, reason="requires ROCM"
)
disable_on_rocm = pytest.mark.skipif(
not not torch.version.hip, reason="could not be done on ROCM"
)
@bottler (Contributor) commented Feb 28, 2024

Can you help me understand these two decorators? Is not not torch.version.hip equivalent to torch.version.hip is not None? And what does the rocm_only condition really mean?

Also perhaps those new tests which only apply to AMD could be in a new file?
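One possible more explicit spelling of these two markers, assuming torch.version.hip is either None (CUDA builds) or a non-empty version string (ROCm builds), would be:

    import pytest
    import torch

    # torch.version.hip is None on CUDA builds and a version string such as
    # "5.7.31921" on ROCm builds, so truthiness and `is not None` agree here.
    rocm_only = pytest.mark.skipif(
        not (torch.cuda.is_available() and torch.version.hip),
        reason="requires a ROCm build with a visible GPU",
    )
    disable_on_rocm = pytest.mark.skipif(
        torch.version.hip is not None,
        reason="not supported on ROCm",
    )

This keeps the behaviour of the originals while making the intent (ROCm build present vs. absent) explicit.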

"attn_bias_type", [type(None), torch.Tensor, fmha.attn_bias.LowerTriangularMask]
)
@pytest.mark.parametrize("op", [fmha.ck.FwOp])
def test_mqa_forward(
Contributor

To follow up on @bottler's question about @rocm_only - it'd be better for fmha.ck.FwOp to be covered by the generic test_forward and test_mqa_decoding. Then we don't need a separate test function (and eventually won't need @rocm_only after all such cases are refactored)

Not sure if this should be blocking the merge or can be done as a follow-up

Contributor

I agree with that - let's try to factor as much code as possible :)

@danthe3rd (Contributor) left a comment

This is looking good! A few more comments on the PR - mostly to simplify the test file, but otherwise we could merge.
Let's not worry about the Windows build on CI, which is already broken on master (and not even building the ROCm stuff anyway).

@@ -310,6 +318,185 @@ def T(t):
return out.permute((0, 2, 1, 3))


def ref_attention_splitk_bmhk(
Contributor

I understand that this is useful for debugging the kernel, but then maybe include it as a standalone python file in the repo with the C++ files? It does not need to be part of the test file (which is already way too long! :) )
(same goes for the function below ref_attention_mqa)
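For readers unfamiliar with the split-k reference under discussion, here is a self-contained sketch of the idea (processing the key/value sequence in chunks and merging per-chunk partial results with a running log-sum-exp); this is an illustration only, not the PR's ref_attention_splitk:

    import torch

    def splitk_reference(q, k, v, split_k: int = 2):
        # q, k, v: [batch * heads, seq, head_dim]. Process k/v in split_k chunks
        # and merge the partial results with a running log-sum-exp, which is what
        # a split-k decoder kernel does across its workgroups.
        scale = q.shape[-1] ** -0.5
        acc = torch.zeros(q.shape, dtype=torch.float32, device=q.device)
        lse = torch.full(q.shape[:-1], float("-inf"), dtype=torch.float32, device=q.device)
        for k_i, v_i in zip(k.chunk(split_k, dim=1), v.chunk(split_k, dim=1)):
            s = torch.einsum("bmk,bnk->bmn", q.float() * scale, k_i.float())
            m_i = s.max(dim=-1).values
            p = torch.exp(s - m_i[..., None])
            lse_i = m_i + torch.log(p.sum(dim=-1))
            new_lse = torch.logaddexp(lse, lse_i)
            acc = acc * torch.exp(lse - new_lse)[..., None] + torch.einsum(
                "bmn,bnk->bmk", p, v_i.float()
            ) * torch.exp(m_i - new_lse)[..., None]
            lse = new_lse
        return acc.to(q.dtype)

    # sanity check against plain softmax attention
    q, k, v = (torch.randn(4, 128, 64) for _ in range(3))
    ref = torch.softmax(q @ k.transpose(-1, -2) * 64 ** -0.5, dim=-1) @ v
    assert torch.allclose(splitk_reference(q, k, v, split_k=4), ref, atol=1e-4)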

"attn_bias_type", [type(None), torch.Tensor, fmha.attn_bias.LowerTriangularMask]
)
@pytest.mark.parametrize("op", [fmha.ck.FwOp])
def test_mqa_forward(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with that - let's try to factor as much code as possible :)

Comment on lines +1484 to +1485
if op is fmha.triton.FwOp:
pytest.skip("Triton Flash Attention 2 doesn't support backward pass yet")
Contributor

Oh I guess it's not supported on AMD because we don't have any backward pass. Fine for me to exclude it for NVIDIA as well.

Comment on lines 1562 to 1565

if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, q, q)):
pytest.skip("; ".join(skip_reasons))

Contributor

This defeats the point of this test. None of the operators support these sorts of strides.
What was the failure?

Contributor

Thanks for catching this! It got here as a refactor of skipping the test when the op uses Triton and the Python version is 3.8 or older, and I missed the context when refactoring. I think we can drop this check and the next one.

Comment on lines 1581 to 1584

if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, q, q)):
pytest.skip("; ".join(skip_reasons))

Contributor

Same here

@pytest.mark.parametrize("padding, bsz", [(32, 8), (4096, 1)])
@pytest.mark.parametrize("split_k", [1, 2, 4])
@pytest.mark.parametrize("device", ["cpu"])
def test_splitk_reference(
Contributor

this test should not be needed if we always use ref_attention instead of ref_attention_splitk

tests/test_mem_eff_attention.py (resolved discussion)
@danthe3rd (Contributor) left a comment

This LGTM to merge in xFormers! Thanks a lot for the huge effort in making the entire codebase compatible with AMD GPUs - that was not an easy thing :o
Probably still a few things to improve, but we leave that to the future :)

> Triton does not support if expressions (ternary operators) with dynamic conditions, use if statements instead
OPS = [
(xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
(xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
# TODO: Triton is not stable: it can trigger Illegal Memory Accesses
Member

Keep this comment?



def _minimum_gemm_alignment(inp: Inputs) -> int:
return 1
Member

For CUDA/NVIDIA GPUs, we have GEMM alignment logic like https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L57-L60 . I wonder why it is set to 1 here?

@danthe3rd merged commit 44b0d07 into facebookresearch:main on Mar 4, 2024
27 of 32 checks passed
@HinaHyugaHime

dan, I don't think this pull was ready yet, but ok

@danthe3rd (Contributor)

@HinaHyugaHime this is working well enough for internal users, and we wanted to get it merged ASAP: given the size of the change, it would have been a nightmare to keep rebasing/merging new changes.
xFormers still mainly supports NVIDIA, and AMD support is experimental / "best-effort". Feel free to open new issues and tag the people from this PR, though.

Labels: CLA Signed, module: rocm
10 participants