Fast Multi-head Attention support on AMD ROCM #978
Conversation
… numerics in reduction
…erformance on ck-tiled fmha
…a-pad-support branch
…l unit_tests passed
…inner_product bhalf_t overloading in ck_attention_forward_decoder.h
…bugging must go on
I think the related part is
But yeah, pastebin works better for such walls of text
Meaning I wouldn't be able to use the end result?
Yes, meaning xformers is not yet supported on your device. Depending on your application, there may exist some other solution for using GPU acceleration on your device.
Do you have any recommendations to try?
Now win8-build fails with
Is there any specific model you're trying to run?
PyTorch SD
You could try your luck with https://github.com/ROCm/AITemplate/tree/navi3_rel_ver_1.1/examples/05_stable_diffusion
xformers/ops/fmha/common.py (Outdated)

@@ -180,11 +180,13 @@ def validate_inputs(self) -> None:
            and self.value.shape == (B, Mkv, Kv)
        )
        H = self.query.shape[-2]
        Hkv = self.key.shape[-2]
The current MHA interface requires H to always match between q, k, and v. If you want to do GQA - e.g. one kv-head for every n q-heads - you have to send 5D inputs. (Thus we're forcing the user to be very explicit.) Do we really want to relax that rule in this PR?
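For context, a minimal sketch of the rank-5 convention mentioned above, assuming the [B, M, G, H, K] layout with the key/value heads broadcast across each group; the shapes and the expand() calls are illustrative, not code from this PR, and op support for the 5D layout may vary:

import torch
import xformers.ops as xops

B, M, K = 2, 256, 64
G, H = 4, 2  # 4 kv-heads, each shared by 2 q-heads (8 q-heads in total)

q = torch.randn(B, M, G, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, G, 1, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, G, 1, K, device="cuda", dtype=torch.float16)

# Broadcast each kv-head over its group of q-heads without copying memory.
out = xops.memory_efficient_attention(
    q, k.expand(-1, -1, -1, H, -1), v.expand(-1, -1, -1, H, -1)
)
print(out.shape)  # torch.Size([2, 256, 4, 2, 64])

Sending the expanded tensors explicitly is what makes the GQA intent visible, rather than inferring it from mismatched head counts.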
@@ -22,13 +22,25 @@ Available implementations
   :member-order: bysource
(Looks like decoder and triton_splitk should have been added here months ago. 🫢)
rocm_only = pytest.mark.skipif(
    not torch.cuda.is_available() or not torch.version.hip, reason="requires ROCM"
)
disable_on_rocm = pytest.mark.skipif(
    not not torch.version.hip, reason="could not be done on ROCM"
)
Can you help me understand these two decorators? Is not not torch.version.hip equivalent to torch.version.hip is not None? And what does the rocm_only condition really mean?
Also, perhaps those new tests which only apply to AMD could be in a new file?
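For what it's worth, a small sketch of the difference being asked about, assuming torch.version.hip is None on CUDA builds and a non-empty version string on ROCm builds (the string below is illustrative):

hip = "5.6.31061"  # ROCm build
assert (not not hip) is True and (hip is not None) is True

hip = None  # CUDA build
assert (not not hip) is False and (hip is not None) is False

hip = ""  # the only case where the two spellings disagree
assert (not not hip) is False and (hip is not None) is True

So the two are interchangeable in practice as long as torch.version.hip is never an empty string.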
tests/test_mem_eff_attention.py (Outdated)

    "attn_bias_type", [type(None), torch.Tensor, fmha.attn_bias.LowerTriangularMask]
)
@pytest.mark.parametrize("op", [fmha.ck.FwOp])
def test_mqa_forward(
To follow up on @bottler's question about @rocm_only - it'd be better for fmha.ck.FwOp to be covered by the generic test_forward and test_mqa_decoding. Then we don't need a separate test function (and eventually won't need @rocm_only after all such cases are refactored).
Not sure if this should be blocking the merge or can be done as a follow-up.
I agree with that - let's try to factor as much code as possible :)
This is looking good! A few more comments on the PR - mostly to simplify the test file, but otherwise we could merge.
Let's not worry about the Windows build on CI, which is already broken on master (and not even building the ROCm stuff anyway).
tests/test_mem_eff_attention.py (Outdated)

@@ -310,6 +318,185 @@ def T(t):
    return out.permute((0, 2, 1, 3))


def ref_attention_splitk_bmhk(
I understand that this is useful for debugging the kernel, but then maybe include it as a standalone Python file in the repo with the C++ files? It does not need to be part of the test file (which is already way too long! :) )
(same goes for the function below, ref_attention_mqa)
    if op is fmha.triton.FwOp:
        pytest.skip("Triton Flash Attention 2 doesn't support backward pass yet")
Oh I guess it's not supported on AMD because we don't have any backward pass. Fine for me to exclude it for NVIDIA as well.
tests/test_mem_eff_attention.py (Outdated)

    if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, q, q)):
        pytest.skip("; ".join(skip_reasons))
This defeats the point of this test. None of the operators support these sorts of strides.
What was the failure?
Thanks for catching this! It got here as a refactor of skipping the test when the op uses Triton and the Python version is 3.8 or older, and I missed the context when refactoring. I think we can drop this check and the next one.
tests/test_mem_eff_attention.py (Outdated)

    if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, q, q)):
        pytest.skip("; ".join(skip_reasons))
Same here
tests/test_mem_eff_attention.py (Outdated)

@pytest.mark.parametrize("padding, bsz", [(32, 8), (4096, 1)])
@pytest.mark.parametrize("split_k", [1, 2, 4])
@pytest.mark.parametrize("device", ["cpu"])
def test_splitk_reference(
This test should not be needed if we always use ref_attention instead of ref_attention_splitk … so users are forced to provide rank-5 inputs for mqa/gqa.
roll back fmha/common.py
This LGTM to merge in xFormers! Thanks a lot for the huge effort in making the entire codebase compatible with AMD GPUs - that was not an easy thing :o
Probably still a few things to improve, but we leave that to the future :)
> Triton does not support if expressions (ternary operators) with dynamic conditions, use if statements instead
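A hedged sketch of the workaround that error message points to; the kernel, its arguments, and the launch below are purely illustrative (they assume a GPU with Triton installed) and are not code from this PR:

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, y_ptr, n, flag, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # Rejected when `flag` is a runtime value:
    #     scale = 2.0 if flag else 1.0
    # Use an if statement (scalar condition) or tl.where (tensor condition) instead.
    if flag:
        scale = 2.0
    else:
        scale = 1.0
    tl.store(y_ptr + offs, x * scale, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
_scale_kernel[(triton.cdiv(x.numel(), 128),)](x, y, x.numel(), 1, BLOCK=128)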
…nto dev_upstream
OPS = [
    (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
    (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
    # TODO: Triton is not stable: it can trigger Illegal Memory Accesses
Keep this comment?
def _minimum_gemm_alignment(inp: Inputs) -> int:
    return 1
For CUDA/NV GPUs, we have a GEMM alignment requirement like https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L57-L60. I wonder why it is set to 1 here?
dan, I don't think this pull was ready yet, but OK.
@HinaHyugaHime this is working well enough for internal users, and we wanted to get this merged ASAP: given the size of the change, it would have been a nightmare to constantly rebase/merge new changes.
This PR adds three flash-attention implementations for AMD ROCm.
In more detail, the following code is added in this PR:
The following scripts are used to verify the implementation:
The following scripts are used to benchmark the performance of the implementation:
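As an aside, a minimal sketch of what a quick correctness check could look like on a ROCm machine, assuming the fmha.ck.FwOp added by this PR; the shapes, tolerances, and naive reference are illustrative and are not the verification scripts referenced above:

import torch
import xformers.ops.fmha as fmha

q, k, v = (torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Route the forward pass explicitly through the ROCm/CK operator.
out = fmha.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)

def ref_attention_naive(q, k, v):
    # Naive attention in float32: (B, M, H, K) -> (B, H, M, K) and back.
    q, k, v = (t.float().transpose(1, 2) for t in (q, k, v))
    attn = (q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5).softmax(-1)
    return (attn @ v).transpose(1, 2)

torch.testing.assert_close(out.float(), ref_attention_naive(q, k, v), atol=2e-2, rtol=2e-2)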