Improve Backward Performance and Navi31 Support #39
Conversation
…rmance is testing the right code.
(The current kernel does not perform any boundary checks)
No performance loss for case fused-attention-batch512-head32-d64-bwd-causal=False
db part is still based on block pointers.
Also remove --out_name and make --out_path (-o) mandatory
1. Handle compiler errors (technically, Exceptions thrown by Triton) in compile.py and record them as Exception. Meanwhile, record a non-zero exit code of the compiler process as 'ExitWithError'. When an error occurs, a zero-sized hsaco file is written as a placeholder.
2. cpp_autotune reports non-existing kernels (whether due to timeouts or compiler errors).
3. Add option AOTRITON_GPU_BUILD_TIMEOUT to the CMake build system; a non-zero value enables compiler timeout support.
4. The TritonKernel class now detects zero-sized images so the autotune process can skip such kernels. In this case hipErrorInvalidImage is returned.
5. Add gen_autotune_configs() to backward kernels, notably enumerating num_warps=1 or 2.
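As a rough illustration of items 1-4, here is a minimal sketch of the error/timeout handling flow; the function names and return values are hypothetical, not AOTriton's actual compile.py API:

import subprocess
from pathlib import Path

def compile_one_kernel(build_kernel, hsaco_path):
    """Try to compile a single kernel in-process; on failure leave a
    zero-sized .hsaco placeholder so later stages can detect and skip it."""
    hsaco_path = Path(hsaco_path)
    try:
        image = build_kernel()        # may raise an Exception from inside Triton
    except Exception as e:
        hsaco_path.write_bytes(b'')   # zero-sized placeholder
        return 'Exception', repr(e)
    hsaco_path.write_bytes(image)
    return 'Success', None

def run_compile_process(cmd, hsaco_path, timeout_s):
    """Build-system side: run compile.py as a subprocess, honoring a timeout
    (cf. the AOTRITON_GPU_BUILD_TIMEOUT option) and recording non-zero exits."""
    try:
        rc = subprocess.run(cmd, timeout=timeout_s).returncode
    except subprocess.TimeoutExpired:
        Path(hsaco_path).write_bytes(b'')
        return 'Timeout'
    return 'Success' if rc == 0 else 'ExitWithError'

def is_usable(hsaco_path):
    """Autotuning can skip kernels whose image is empty (the case that
    AOTriton's TritonKernel reports as hipErrorInvalidImage)."""
    p = Path(hsaco_path)
    return p.exists() and p.stat().st_size > 0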
Only this patch is applied to my PyTorch v2.4.0:
And some changes to
I just recompiled several times and confirmed that my new tuning records are causing the noise issue in Stable Diffusion 1.5, with only
The image ends up with pure noise if any
Test script:

import torch
for N_CTX in [1024, 4096, 4097]:
    query = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
    key = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
    value = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
        r1 = torch.nn.functional.scaled_dot_product_attention(query=query, key=key, value=value)
    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
        r2 = torch.nn.functional.scaled_dot_product_attention(query=query, key=key, value=value)
    atol = 0.0125
    rtol = 0
    print(f'N_CTX={N_CTX} allclose={torch.allclose(r1, r2, atol=atol, rtol=rtol)}')
    print(f'r1={r1[0][0][0][:4]}')
    print(f'r2={r2[0][0][0][:4]}')
    print('')

Output:
This is my |
Yes, I just ran SD15 up to 2048x2048 which peaks at seqlen 36864 on the outer layers and both FLASH and EFFICIENT attentions were completely stable. PixArt is still the only model I managed to break with EFFICIENT attention enabled on the stock tunes. One of my kernels timed out during compilation, I wonder if that's related? Interestingly enabling either FLASH or EFFICIENT drops my SD15 speed by well over 30%, similar to the DiT models. Maybe it's more that SDXL is unusually fast with the new kernels?
Losses apply to all resolutions, though the all-SDPA + howiejay combo might be faster for big upscale jobs since it doesn't need VAE tiling. Update: The philox branch was merged. Do I dare rebuild...? |
No idea if these changes optimize performance for Navi 31. In fact, I am going to do some Frankenstein things by combining these to create a pip package: |
I'll let it build while I'm at the store tomorrow. I'll increase the kernel build timeout too since in theory the ones torch pulls are already validated if I'm understanding aotriton's MO correctly.
A lot of the autotunes in aotriton/tritonsrc can page fault or even reset your gpu. I didn't find it worth tinkering with personally. If you do, build triton from master to hopefully have new compiler fixes.
|
Previously I made AOTriton's
ROCm/flash-attention@main_perf can provide a widely used interface for kernels written in Triton, which is what I am interested in. Regarding performance, I think it's more a matter of Triton compiler work, but if ROCm/flash-attention@main_perf performs better, AOTriton's kernels should be able to achieve that too. |
May I ask how you use ROCm/flash-attention@main_perf? It looks blazing fast at first glance, but I notice there is a hardcoded layout ( |
I believe the transpose is negligible. At one point I made a custom Diffusers attention processor and it didn't seem any faster than just monkey patching sdpa, so I removed it. If you check out my quickdif repo at 3f832df6fccb7488ad7ed203d1dcadd820548965, before I removed the attention processors, there's a file containing a Diffusers attention processor for Flash Attention that works with both the howiejay and triton branches. Quickdif has the sdpa monkey patch behind a flag so you can easily compare speeds. Back when I first found main_perf I built triton from source and it was something like 3.4 it/s at SDXL 1024. Whether it's changes to upstream triton or the kvpacked branch being merged to main_perf, it does seem slower now. I also tried using main_perf on llama at some point, but it must have an autotune keyed on sequence length, which causes a kernel fetch for every new token, effectively making it unusable. |
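For reference, a minimal sketch of the sdpa monkey-patch approach described above, assuming flash_attn's bshd layout and the standard flash_attn_func keyword arguments; this is an illustration, not the actual quickdif code:

import torch
from flash_attn import flash_attn_func   # howiejay/navi_support or main_perf build

_original_sdpa = torch.nn.functional.scaled_dot_product_attention

def _flash_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, **kwargs):
    # Route only the simple fp16/bf16, mask-free case to Flash Attention;
    # everything else falls back to the stock implementation.
    if attn_mask is None and query.dtype in (torch.float16, torch.bfloat16):
        # sdpa takes (batch, heads, seq, dim); flash_attn_func expects bshd,
        # hence the transpose on the way in and out.
        out = flash_attn_func(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
            dropout_p=dropout_p, softmax_scale=scale, causal=is_causal,
        )
        return out.transpose(1, 2)
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal,
                          scale=scale, **kwargs)

torch.nn.functional.scaled_dot_product_attention = _flash_sdpa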
I ran a benchmark script:

#!/usr/bin/env python
# Copyright © 2023-2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT
import pytest
import torch
import triton
TEST_TRITON = False
TEST_FLASH = True
TEST_FLASH_TRITON = True
TEST_TORCH = True
TEST_TORCH_MATH = True
# BATCH, N_HEADS, D_HEAD = 4, 32, 64
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BATCH, N_HEADS, N_CTX, D_HEAD = 2, 8, 4096, 128
# BATCH, N_HEADS, N_CTX, D_HEAD = 32, 32, 1024, 32
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
    # for causal in [False, True]:
    for causal in [False]:
        configs.append(triton.testing.Benchmark(
            x_names=['N_CTX'],
            # lower to allow torch sdpa to pass the benchmark
            # x_vals=[2**i for i in range(10, 15)] if not TEST_TORCH_MATH else [2**i for i in range(8, 15)],
            x_vals=[
                77, 256,
                1024,
                4096, 8192,
                9216,
            ],
            line_arg='provider',
            line_vals=(['triton'] if TEST_TRITON else []) + (['flash'] if TEST_FLASH else []) + (['flash-triton'] if TEST_FLASH_TRITON else []) + (['torch'] if TEST_TORCH else []) + (['torch-math'] if TEST_TORCH_MATH else []),
            line_names=(['Triton'] if TEST_TRITON else []) + (['Flash'] if TEST_FLASH else []) + (['Flash Triton'] if TEST_FLASH_TRITON else []) + (['Torch'] if TEST_TORCH else []) + (['Torch Math'] if TEST_TORCH_MATH else []),
            styles=[('red', '-'), ('orange', '-'), ('yellow', '-'), ('green', '-'), ('blue', '-'), ('indigo', '-'), ('violet', '-')],
            ylabel='flops',
            plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
            args={
                'H': N_HEADS,
                'BATCH': BATCH,
                'D_HEAD': D_HEAD,
                'dtype': torch.float16,
                'mode': mode,
                'causal': causal,
            })
        )
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
    print(f"{N_CTX=}")
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    split_kernel = False
    requires_grad = True if mode == 'bwd' else False
    # Bwd pass only supports causal=True right now
    if mode == 'bwd':
        split_kernel = True if causal else split_kernel
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        b = None
        sm_scale = 1.3
        return_encoded_softmax = False
        autotune = True
        return_autotune = True
        # `attention` comes from the Triton kernel implementation under test
        # (only exercised when TEST_TRITON is enabled)
        fn = lambda: attention(q, k, v, b, causal, sm_scale, split_kernel, return_encoded_softmax, autotune, return_autotune)[0]
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        from flash_attn import flash_attn_func
        # transpose is needed because metadata.layout is set to bshd
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        fn = lambda: flash_attn_func(q, k, v, causal=causal).transpose(1, 2)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash-triton":
        from flash_attn_rocm import flash_attn_func
        # transpose is needed because metadata.layout is set to bshd
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        fn = lambda: flash_attn_func(q, k, v, causal=causal).transpose(1, 2)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "torch":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        b = None
        sm_scale = 1.3
        fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "torch-math":
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
            b = None
            sm_scale = 1.3
            fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == 'bwd':
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    return total_flops / ms * 1e-9

# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path='.', print_data=True)

RX 7900 XTX (WSL, Ubuntu 22.04, ROCm 6.1.3):
RTX 3090

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX      Flash      Torch  Torch Math
0    77.0   3.351398   3.307981    2.140681
1   256.0  17.674617  20.356401   14.038236
2  1024.0  55.155258  53.944516   29.788766
3  4096.0  66.690431  65.878503   31.758032
4  8192.0  70.919531  71.069738   29.774254
5  9216.0  70.319972  69.594797   29.259435

RTX 4090 D

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX       Flash       Torch  Torch Math
0    77.0    4.211498    4.174075    2.549500
1   256.0   23.667371   24.601858   17.446561
2  1024.0   73.363291   72.073351   44.259957
3  4096.0  129.738712  127.396387   53.866077
4  8192.0  146.490523  143.630641   54.301182
5  9216.0  136.064218  133.215636   53.720898
|
Is that using a 2.5 nightly wheel? Those have severe performance regressions for me. On my patched 2.4 I got
Notice in particular the MATH backend. |
Yes. The branch to merge has already updated |
Depending on the model that's actually true. I think it was PixArt or something that actually performs slightly better with math than with the ck flash ignoring the memory usage. SDXL is still 30-50% faster with CK flash though. I'm pretty sure the navi flash was made with SDXL specifically in mind because that's what AMD's ck tune benchmarks were using at the time. |
Have you tried https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal? The performance drops to the level of the CK-based one after transposing. However, it might be the only one that implements a backward pass while maintaining good performance, I guess? |
That one sketched me out because it's the first time I've ever gotten an extreme content warning on GitHub before accessing a repository, on GPU kernels of all things... Right now I just use the howiejay CK flash with the aotriton 07 as a fallback for Diffusers. On the occasion I want to monkey with an LLM I just use llama.cpp's built-in FA, which also supports ROCm. In theory the aotriton 07 kernels should work for exl2, but if the DiT performance is anything to go by it won't be fast by any measure. Shame they don't have discussions open anywhere so this info could be more accessible |
@evshiron I bet most of the massive torch 2.5 performance regression is from pytorch/pytorch#128922 |
Hello @evshiron, I am a novice in all of this and have seen that you seem to understand this a whole lot better than I do. I mainly use Stable Diffusion with my RX 7900 XTX and am wondering which version of Flash Attention would be best for VRAM usage. I know that there is a CK, an AOTriton, and a WMMA version of Flash Attention, but I don't understand the graphs, so I don't know which would be best for my case. Could you please help me out? That would be very kind. |
Glad to know! The fp32 performance is kind of poor, but thankfully AOTriton is not affected. I hope it's fixed when PyTorch 2.5.0 is released.
Generating images with Stable Diffusion doesn't require a backward pass implementation, so I will recommend CK-based ones:
|
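As a concrete starting point for comparing backends during inference (not tied to any particular model), here is a minimal sketch that pins one SDPA backend using the same torch.nn.attention API as the test scripts above; the shapes are arbitrary placeholders:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# Toy shapes; swap FLASH_ATTENTION for EFFICIENT_ATTENTION or MATH to compare
# speed and peak VRAM for a given workload.
query = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
key = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
value = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
print(out.shape, torch.cuda.max_memory_allocated() / 2**20, 'MiB')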
I think it does still affect the MIOpen tuning that happens automatically. The first time using a resolution/batch size on torch 2.5/roc62 takes probably 3x as long compared to 2.4/roc61, even if the flash kernels are available. For large models or dimensions it's easily multiple minutes of overhead for me. You can see it yourself by clearing
If it's not that PR I'm not sure what else it'd be. Nightly torch is a mess for ROCm right now.
Wonder if it'd be worth consolidating discussion of everything relating to the different AOTriton, CK, and upstream Triton flash implementations somewhere with open threads, like https://github.com/ROCm/ROCm/discussions/
I did a brief overview of CK flash a while ago at huggingface/diffusers#7172 |
@evshiron It's made it into 2.5.0-rc1. Maybe one of us should open an issue, because won't that affect literally everyone with hardware fp16 that's not NVIDIA? |
I am currently using this branch:
which is PyTorch 2.4.0 with AOTriton updated. And I locally made a Flash Attention library with different (currently two) backends. For the same configuration we have tested, the results are:
And the numbers for ROCm/flash-attention@howiejay/navi_support, taken from previous comment:
|
@evshiron I wonder how they scale across different architectures? None of the implementations are super optimized, so they swing wildly. Like for SDXL, howiejay is the fastest by +20%; for SD15, the triton attention absolutely bombed by something like -70%; for one of the DiTs (pixart or hunyuan?) I think Math was still the fastest as long as it didn't OOM. Eventually I just gave up and added CLI parameters to my app for setting the SDPA backend at runtime and adding flash attention monkey patches... I really don't want to build torch from source a 12th time, so I'll wait for Meta to figure out what they're doing with the 2.5 sdpa math casting before doing significant monkeying myself. I don't know if you've seen, but the |
The performance varies between FA implementations. If you are seeking the best performance, routing to different backends based on shapes and parameters might be a good approach, which is planned for my Flash Attention library, but I don't know if it's worth it. The biggest problem for Triton (including AOTriton) is its performance on AMD GPUs. Currently the performance of Triton matmul is about 70% of hipBLAS on the RX 7900 XTX [1], and Flash Attention performance is even worse [2]. For CDNA GPUs like MI250/MI300, the performance of AOTriton is about 70% of the CK one too [3]. A year and a half after purchasing the RX 7900 XTX, I begin to wonder whether the performance improvement of Flash Attention on the RX 7900 XTX could be as significant as it was on the RTX 3090 [4] (or MI250 lol). Regardless, the VRAM usage does go down.
Footnotes
|
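A rough sketch of the shape-based routing idea mentioned above; the backend names and thresholds are hypothetical placeholders, not measured cut-offs:

import torch

def pick_attention_backend(q, seq_len=None, causal=False):
    """Hypothetical dispatcher: choose an attention implementation from the
    input shape and dtype. The thresholds below are placeholders."""
    head_dim = q.shape[-1]
    seq_len = q.shape[-2] if seq_len is None else seq_len
    if q.dtype not in (torch.float16, torch.bfloat16) or head_dim > 256:
        return 'torch-math'   # unsupported dtype/shape: fall back to the reference path
    if seq_len <= 512:
        return 'torch-sdpa'   # short sequences: launch overhead dominates
    return 'flash-ck'         # long sequences: hand-written CK kernels

# Example: a typical SDXL-sized attention call
q = torch.randn(2, 8, 4096, 128, dtype=torch.float16)
print(pick_attention_backend(q))   # -> 'flash-ck' under these placeholder rules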
When Llama 3 came out, one of my friends borked their venv and didn't have Flash Attention. On Exllama2, which compiles torch C extensions using the wheel-bundled cuda/rocm compilers, my XTX actually outperformed his 3090. The GPUs themselves are completely capable; it's just the software holding them back. It's a Sisyphean situation where all the cutting-edge projects are built and optimized for CUDA specifically, so ROCm is perpetually in catch-up mode. I think that's why they're taking the approach of using Triton everywhere they can, to reduce the number of places they need to optimize kernels. That said, yes, it seems upstream Triton is particularly slow right now. I remember having a Triton flash running over 90% as fast as the CK branch in SDXL, but I can't seem to reproduce that right now. Maybe I was using the ROCm fork? It was definitely unstable though; I had to be really picky with the autotune configs to have it not cause GPU resets. Edit: Navi configs might be worth looking at ROCm/triton#640 |
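To illustrate what "being picky with the autotune configs" looks like mechanically, here is a toy Triton kernel (not an attention kernel) with a deliberately small, hand-picked config list; the block sizes and warp counts are placeholders, not tuned values:

import torch
import triton
import triton.language as tl

# Keeping the search space small prevents autotuning from wandering into
# configs that hang or reset the GPU, as described above.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 1024}, num_warps=4),
        triton.Config({'BLOCK': 2048}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    add_kernel[grid](x, y, out, n)
    return out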
Yeah. I believe the raw performance is comparable to the RTX 3090 and RTX 4080, and we do have projects like MLC that achieve a higher CPR than the RTX 4090. In my opinion, LLM applications like llama.cpp (Ollama) are already satisfactory on the RX 7900 XTX. Diffusion models are more VRAM-hungry, and having an FA implementation with good performance would really help.
I am using the official repo of Triton and I believe most development for AMD happens there now. I haven't experienced any auto-tuning crashes recently. Maybe it's a difference on WSL? Though the performance remains ordinary across different tuning configurations (including the ones mentioned). |
The XTX should by every metric be faster than a 3090. I think the vector/AI/whatever FLOPS were a bit above the 3090 Ti's.
Out of curiosity I just built
I haven't had any resets recently either. This was quite a few months ago. Though the original |
It's hard to tell. AOTriton uses an old, custom version of Triton for kernel generation, which might include some specific optimizations, while the official Triton repo is updated frequently, and bumping the LLVM version might also affect performance. As you can see from the previous benchmarks, few of these implementations consistently win when the input shapes change. |
To clarify, the customization of Triton is mainly about removing CUDA bits, which only affects the download/build time. |
What's Changed
- cpptune / cpp_tune / cpptuning: based on pre-compiling all GPU kernels with the CMake option AOTRITON_BUILD_FOR_TUNING and kernel selection parameters provided by all AOTriton APIs
- Use pkg-config to search for zstd, since find_package(zstd) is not supported officially.

Known problems

This fixes #16