Improve Backward Performance and Navi31 Support #39

Merged: 66 commits, Aug 14, 2024
Conversation

@xinyazhang (Collaborator) commented Aug 5, 2024

What's Changed

  1. A whole new tuning system (referred to as cpptune/cpp_tune/cpptuning) based on pre-compiling all GPU kernels with the CMake option AOTRITON_BUILD_FOR_TUNING and on kernel selection parameters provided by all AOTriton APIs
  2. GPU kernel compilation can time out (the default limit is 8 minutes), to avoid excessively long Navi31 kernel builds
  3. Migrate the backward kernel away from block pointers
  4. Improve backward kernel performance by using a better tuning database generated from cpptune
  5. Add Navi31 to the tuning database
  6. Enable Navi31 by default
  7. Default to AOTRITON_COMPRESS_KERNEL=ON, which consequently requires zstd as a runtime dependency
  8. Use pkg-config to search for zstd, since find_package(zstd) is not officially supported.

Known problems

  1. No official Navi32 support. Users may want to duplicate the Navi31 tuning database entries to accomplish Navi32 support in AOTriton (a rough sketch of that idea follows below).
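
For illustration only, a minimal sketch of that duplication idea. The actual schema of the tuning database is not shown in this thread, so the 'arch' field and the gfx names below are assumptions to be adapted to the real file:

# Hypothetical sketch: clone Navi31 tuning entries for Navi32.
# Assumes the database is a JSON list of entries carrying an 'arch' field
# ("gfx1100" for Navi31, "gfx1101" for Navi32); adjust to the real schema.
import json

with open("tuning_database.json") as f:
    db = json.load(f)

clones = []
for entry in db:
    if entry.get("arch") == "gfx1100":      # Navi31 entry
        clone = dict(entry)
        clone["arch"] = "gfx1101"           # re-label the copy as Navi32
        clones.append(clone)

db.extend(clones)

with open("tuning_database.json", "w") as f:
    json.dump(db, f, indent=2)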

This fixes #16

(The current kernel does not perform any boundary checks)
No performance loss for case fused-attention-batch512-head32-d64-bwd-causal=False
db part is still based on block pointers.
Also remove --out_name and make --out_path (-o) mandatory
1. Handle compiler errors (technically, exceptions thrown by Triton) in
   compile.py, and record them as 'Exception'. Meanwhile, record a non-zero
   exit code of the compiler process as 'ExitWithError'
    * When an error occurs, a zero-sized hsaco file is written as a
      placeholder
2. cpp_autotune reports non-existent kernels (whether due to timeouts or
   compiler errors)
3. Add the option AOTRITON_GPU_BUILD_TIMEOUT to the CMake build system; a
   non-zero value enables compiler timeout support.
4. The TritonKernel class now detects zero-sized images so the autotune
   process can skip such kernels. In this case hipErrorInvalidImage is
   returned.
5. Add gen_autotune_configs() to backward kernels, notably enumerating
   num_warps=1 or 2.
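
For intuition, a rough Python sketch of the timeout-and-placeholder flow described in the commit messages above; the function and status names are illustrative only, not AOTriton's actual compile.py code.

import subprocess
from pathlib import Path

def compile_kernel(cmd, out_path, timeout_s=8 * 60):
    """Illustrative only: run a compiler command with a timeout and leave a
    zero-sized placeholder on failure so later stages can detect and skip it."""
    out = Path(out_path)
    try:
        proc = subprocess.run(cmd, capture_output=True, timeout=timeout_s)
        if proc.returncode == 0:
            return "Success"
        status = "ExitWithError"          # compiler process exited non-zero
    except subprocess.TimeoutExpired:
        status = "Timeout"                # exceeded the build timeout
    except Exception:
        status = "Exception"              # e.g. an exception raised by Triton itself
    out.write_bytes(b"")                  # zero-sized hsaco placeholder
    return status
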
@evshiron

@Beinsezii

Only this patch is applied to my PyTorch v2.4.0:

And some changes to cmake/External/aotriton.cmake to use my repo of AOTriton.

I just recompiled several times and confirmed that my new tuning records are causing the noise issue in Stable Diffusion 1.5, with only SDPBackend.FLASH_ATTENTION and SDPBackend.MATH enabled:

(two screenshots attached)

The image ends up as pure noise if any N_CTX > 4096 appears in the computation.

Test script:

import torch

for N_CTX in [1024, 4096, 4097]:
  query = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
  key = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)
  value = torch.randn(2, 8, N_CTX, 40, device='cuda:0', dtype=torch.float16)

  with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
      r1 = torch.nn.functional.scaled_dot_product_attention(query=query, key=key, value=value)

  with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
      r2 = torch.nn.functional.scaled_dot_product_attention(query=query, key=key, value=value)

  atol = 0.0125
  rtol = 0
  print(f'N_CTX={N_CTX} allclose={torch.allclose(r1, r2, atol=atol, rtol=rtol)}')
  print(f'r1={r1[0][0][0][:4]}')
  print(f'r2={r2[0][0][0][:4]}')
  print('')

Output:

N_CTX=1024 allclose=True
r1=tensor([-0.0812, -0.0159, -0.0928, -0.0147], device='cuda:0',
       dtype=torch.float16)
r2=tensor([-0.0812, -0.0159, -0.0929, -0.0147], device='cuda:0',
       dtype=torch.float16)

N_CTX=4096 allclose=True
r1=tensor([-0.0171, -0.0166, -0.0460,  0.0142], device='cuda:0',
       dtype=torch.float16)
r2=tensor([-0.0171, -0.0166, -0.0460,  0.0142], device='cuda:0',
       dtype=torch.float16)

N_CTX=4097 allclose=False
r1=tensor([ 0.0519,  0.0227,  0.0270, -0.0340], device='cuda:0',
       dtype=torch.float16)
r2=tensor([1.5318e-05, 1.5318e-05, 1.5318e-05, 1.5318e-05], device='cuda:0',
       dtype=torch.float16)

This is my tuning_database.json if you are interested:

tuning_database.zip

@Beinsezii

Beinsezii commented Aug 23, 2024

I just recompiled several times and confirmed that my new tuning records are causing the noise issue in Stable Diffusion 1.5, with only SDPBackend.FLASH_ATTENTION and SDPBackend.MATH enabled

Yes, I just ran SD15 up to 2048x2048, which peaks at seqlen 36864 on the outer layers, and both FLASH and EFFICIENT attentions were completely stable. PixArt is still the only model I managed to break with EFFICIENT attention enabled on the stock tunes. One of my kernels timed out during compilation; I wonder if that's related?

Interestingly, enabling either FLASH or EFFICIENT drops my SD15 speed by well over 30%, similar to the DiT models. Maybe it's more that SDXL is unusually fast with the new kernels?

attention speed @ 512
Math                               19.1 it/s
Efficient or Flash                 13.9 it/s
Math + howiejay/navi_support       23.8 it/s
All SDPA + howiejay/navi_support   21.6 it/s

Losses apply to all resolutions, though the all-SDPA + howiejay combo might be faster for big upscale jobs since it doesn't need VAE tiling.

Update: The philox branch was merged. Do I dare rebuild..?

@evshiron

The philox branch was merged. Do I dare rebuild..?

No idea if these changes optimize performance for Navi 31.

In fact, I am going to do some Frankenstein things by combining these to create a pip package:

@Beinsezii

No idea if these changes optimize performance for Navi 31.

I'll let it build while I'm at the store tomorrow. I'll increase the kernel build timeout too, since in theory the ones torch pulls are already validated, if I'm understanding AOTriton's MO correctly.

In fact, I am going to do some Frankenstein things by combining these to create a pip package:

A lot of the autotunes in aotriton/tritonsrc can page fault or even reset your GPU. I didn't find it worth tinkering with personally. If you do, build Triton from master to hopefully have the newest compiler fixes.

main_perf is interesting. It does pretty well for diffusion and almost works for llama, but sometimes it likes to barf out bad tokens. I don't think it's intended for end use as it's got lots of debug printouts lol. It's actually a lot faster than Navi's torch SDPA as well. I think it's the closest thing to a truly cross-platform flash attention that exists right now. I haven't tried backwards yet, but since it's Triton it should work?

@evshiron

@Beinsezii

Previously I made AOTriton's tritonsrc work in SD.Next by replacing the existing Navi 31 hack with a custom adapter that employs Triton's auto-tuning. The result is a bit slower than the Math impl, but AOTriton's impl is more complete than others written in Triton.

ROCm/flash-attention@main_perf can provide a widely used interface for kernels written in Triton, which is what I am interested in.

Regarding the performance, I think it's more a matter of Triton compiler work, but if ROCm/flash-attention@main_perf performs better, AOTriton's might be able to achieve that too.

@evshiron

evshiron commented Sep 4, 2024

@Beinsezii

May I ask how you use ROCm/flash-attention@main_perf?

It looks blazing fast at first glance, but I notice there is a hardcoded layout (input_metadata.layout = "bshd"), and I have to transpose(1, 2) before replacing SDPA with it; after that the performance becomes ordinary.
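
For concreteness, a minimal sketch (not from either repo) of what replacing SDPA with flash_attn_func looks like, including that transpose. It assumes the upstream flash_attn_func signature (dropout_p, softmax_scale, causal) and only takes the fast path for the plain no-mask case:

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # howiejay/navi_support or main_perf build

_orig_sdpa = F.scaled_dot_product_attention

def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                 is_causal=False, scale=None, **kwargs):
    # Only take the fast path for plain fp16/bf16 attention without a mask.
    if attn_mask is not None or kwargs or query.dtype not in (torch.float16, torch.bfloat16):
        return _orig_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,
                          is_causal=is_causal, scale=scale, **kwargs)
    # SDPA uses (batch, heads, seq, dim); flash_attn_func expects bshd layout.
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=is_causal)
    return out.transpose(1, 2)

# Monkey patch: every SDPA call in the model now routes through the wrapper.
F.scaled_dot_product_attention = patched_sdpa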

@Beinsezii

Beinsezii commented Sep 4, 2024

@evshiron

I believe the transpose is negligible. At one point I made a custom Diffusers attention processor and it didn't seem any faster than just monkey patching SDPA, so I removed it.

If you check out my quickdif repo at 3f832df6fccb7488ad7ed203d1dcadd820548965 (before I removed the attention processors), there's a file containing a Diffusers attention processor for Flash Attention that works with both the howiejay and triton branches. Quickdif has the SDPA monkey patch behind a flag so you can easily compare speeds.

Back when I first found main_perf I built Triton from source and it was something like 3.4 it/s at SDXL 1024. Whether it's changes to upstream Triton or the kvpacked branch being merged into main_perf, it does seem slower now.

I also tried using main_perf on llama at some point, but it must have an autotune keyed on sequence length, which causes a kernel fetch for every new token, effectively making it unusable.

@evshiron

evshiron commented Sep 5, 2024

I ran a benchmark script:

#!/usr/bin/env python
# Copyright © 2023-2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

import torch

import triton

TEST_TRITON = False  # the 'triton' provider calls AOTriton's tritonsrc attention(), which is not imported here
TEST_FLASH = True
TEST_FLASH_TRITON = True
TEST_TORCH = True
TEST_TORCH_MATH = True

# BATCH, N_HEADS, D_HEAD = 4, 32, 64
# BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BATCH, N_HEADS, N_CTX, D_HEAD = 2, 8, 4096, 128
# BATCH, N_HEADS, N_CTX, D_HEAD = 32, 32, 1024, 32
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
    # for causal in [False, True]:
    for causal in [False]:
        configs.append(triton.testing.Benchmark(
            x_names=['N_CTX'],
            # lower to allow torch sdpa to pass the benchmark
            # x_vals=[2**i for i in range(10, 15)] if not TEST_TORCH_MATH else [2**i for i in range(8, 15)],
            x_vals=[
                77, 256,
                1024,
                4096, 8192,
                9216,
            ],
            line_arg='provider',
            line_vals=(['triton'] if TEST_TRITON else []) + (['flash'] if TEST_FLASH else []) + (['flash-triton'] if TEST_FLASH_TRITON else []) + (['torch'] if TEST_TORCH else []) + (['torch-math'] if TEST_TORCH_MATH else []),
            line_names=(['Triton'] if TEST_TRITON else []) + ([f'Flash'] if TEST_FLASH else []) + ([f'Flash Triton'] if TEST_FLASH_TRITON else []) + (['Torch'] if TEST_TORCH else []) + (['Torch Math'] if TEST_TORCH_MATH else []),
            styles=[('red', '-'), ('orange', '-'), ('yellow', '-'), ('green', '-'), ('blue', '-'), ('indigo', '-'), ('violet', '-')],
            ylabel='flops',
            plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
            args={
                'H': N_HEADS,
                'BATCH': BATCH,
                'D_HEAD': D_HEAD,
                'dtype': torch.float16,
                'mode': mode,
                'causal': causal,
                })
        )


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
    print(f"{N_CTX=}")
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    split_kernel = False
    requires_grad=True if mode == 'bwd' else False
    # Bwd pass only supports causal=True right now
    if mode == 'bwd':
        split_kernel = True if causal else split_kernel
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        b = None
        sm_scale = 1.3
        return_encoded_softmax = False
        autotune = True
        return_autotune = True
        fn = lambda: attention(q, k, v, b, causal, sm_scale, split_kernel, return_encoded_softmax, autotune, return_autotune)[0]
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        from flash_attn import flash_attn_func
        # transpose is needed because metadata.layout is set to bshd
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        fn = lambda: flash_attn_func(q, k, v, causal=causal).transpose(1, 2)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash-triton":
        from flash_attn_rocm import flash_attn_func
        # transpose is needed because metadata.layout is set to bshd
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad).transpose(1, 2)
        fn = lambda: flash_attn_func(q, k, v, causal=causal).transpose(1, 2)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "torch":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
        b = None
        sm_scale = 1.3
        fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "torch-math":
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=requires_grad)
            b = None
            sm_scale = 1.3
            fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == 'bwd':
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    return total_flops / ms * 1e-9


# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path='.', print_data=True)

RX 7900 XTX (WSL, Ubuntu 22.04, ROCm 6.1.3):

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX      Flash  Flash Triton      Torch  Torch Math
0    77.0   0.589959      0.760981   0.464005    0.488695
1   256.0   5.817600      3.338110   4.426412    2.501373
2  1024.0  23.633466      4.769343   8.805251    6.957538
3  4096.0  31.725495      5.340588  10.019228    9.271866
4  8192.0  33.280974      5.433053  10.138188    9.397304
5  9216.0  29.864885      5.443405  10.197588    9.261768

RTX 3090:
fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX      Flash      Torch  Torch Math
0    77.0   3.351398   3.307981    2.140681
1   256.0  17.674617  20.356401   14.038236
2  1024.0  55.155258  53.944516   29.788766
3  4096.0  66.690431  65.878503   31.758032
4  8192.0  70.919531  71.069738   29.774254
5  9216.0  70.319972  69.594797   29.259435

RTX 4090 D:
fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX       Flash       Torch  Torch Math
0    77.0    4.211498    4.174075    2.549500
1   256.0   23.667371   24.601858   17.446561
2  1024.0   73.363291   72.073351   44.259957
3  4096.0  129.738712  127.396387   53.866077
4  8192.0  146.490523  143.630641   54.301182
5  9216.0  136.064218  133.215636   53.720898

@Beinsezii

I ran a benchmark script:
{snip}
RX 7900 XTX (WSL, Ubuntu 22.04, ROCm 6.1.3):

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX      Flash  Flash Triton      Torch  Torch Math
0    77.0   0.589959      0.760981   0.464005    0.488695
1   256.0   5.817600      3.338110   4.426412    2.501373
2  1024.0  23.633466      4.769343   8.805251    6.957538
3  4096.0  31.725495      5.340588  10.019228    9.271866
4  8192.0  33.280974      5.433053  10.138188    9.397304
5  9216.0  29.864885      5.443405  10.197588    9.261768

Is that using a 2.5 nightly wheel? Those have severe performance regressions for me. On my patched 2.4 I got:

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX      Flash     Torch  Torch Math
0    77.0   1.048788  0.582226    1.585659
1   256.0   8.206998  4.299669   11.105464
2  1024.0  22.326422  6.281815   23.962240
3  4096.0  28.528275  8.420959   30.622303
4  8192.0  30.290357  8.520634   36.732278
5  9216.0  28.182249  8.543251   33.984900

Notice in particular the MATH backend.

@evshiron

evshiron commented Sep 6, 2024

@Beinsezii

Yes. The branch to merge has already updated version.txt to 2.5.0a0. But it's strange that it performs better than the CK-based one in your case, isn't it?

@Beinsezii

Beinsezii commented Sep 6, 2024

Yes. The branch to merge has already updated version.txt to 2.5.0a0. But it's strange that it performs better than the CK-based one in your case, isn't it?

Depending on the model, that's actually true. I think it was PixArt or something that actually performs slightly better with Math than with the CK flash, ignoring the memory usage. SDXL is still 30-50% faster with CK flash though.

I'm pretty sure the Navi flash was made with SDXL specifically in mind, because that's what AMD's CK tune benchmarks were using at the time.

@evshiron

evshiron commented Sep 6, 2024

@Beinsezii

Have you tried https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal?

The performance drops to the level of the CK-based one after transposing. However, it might be the only one that implements a backward pass while maintaining good performance, I guess?

@Beinsezii

Beinsezii commented Sep 6, 2024

Have you tried https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal?

That one sketched me out because it's the first time I've ever gotten an extreme content warning on GitHub before accessing a repository, on GPU kernels of all things...

Right now I just use the howiejay CK flash with the aotriton 0.7 kernels as a fallback for Diffusers. On the occasion I want to monkey with an LLM I just use llama.cpp's built-in FA, which also supports ROCm. In theory the aotriton 0.7 kernels should work for exl2, but if the DiT performance is anything to go by it won't be fast by any measure.

Shame they don't have Discussions open anywhere so this info could be more accessible.

@Beinsezii

@evshiron I bet most of the massive torch 2.5 performance regression is from pytorch/pytorch#128922
So in theory it shouldn't affect the AOTriton kernels.
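
As a quick, non-authoritative way to check a regression like that on one's own builds (just a timing probe with made-up shapes, unrelated to that PR's internals), something like this can be run on both torch versions and compared:

import time
import torch

q = torch.randn(2, 8, 4096, 128, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Force the Math backend, which is where the suspected upcast cost lands.
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
    for _ in range(5):  # warmup
        torch.nn.functional.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(20):
        torch.nn.functional.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()
    print(f'{(time.perf_counter() - start) / 20 * 1e3:.2f} ms per call')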

@Kademo15

Hello @evshiron, I am a novice in all of this and have seen that you seem to understand this a whole lot better than I do. I mainly use Stable Diffusion with my RX 7900 XTX and am wondering if you could tell me which version of Flash Attention would be best for VRAM usage. I know that there are CK, AOTriton, and WMMA versions of Flash Attention, but I don't understand the graphs, so I don't know which would be best for my case. Could you please help me out? That would be very kind.

@evshiron

@Beinsezii

I bet most of the massive torch 2.5 performance regression is from pytorch/pytorch#128922

Glad to know! The fp32 performance is kind of poor, but thankfully AOTriton is not affected. I hope it's fixed when PyTorch 2.5.0 is released.

@Kademo15

I mainly use Stable Diffusion with my RX 7900 XTX

Generating images with Stable Diffusion doesn't require a backward pass implementation, so I will recommend CK-based ones:

@Beinsezii

Beinsezii commented Sep 10, 2024

I bet most of the massive torch 2.5 performance regression is from pytorch/pytorch#128922

Glad to know! The fp32 performance is kind of poor, but thankfully AOTriton is not affected. I hope it's fixed when PyTorch 2.5.0 is released.

I think it does still affect the MIOpen tuning that happens automatically. The first time a resolution/batch size is used on torch 2.5/ROCm 6.2 takes probably 3x as long compared to 2.4/ROCm 6.1, even if the flash kernels are available. For large models or dimensions it's easily multiple minutes of overhead for me. You can see it yourself by clearing ~/.config/miopen/ or whatever the WSL equivalent is.

If it's not that PR I'm not sure what else it'd be. Nightly torch is a mess for ROCm right now.

...so I will recommend...

Wonder if it'd be worth consolidating the discussion of everything relating to the different AOTriton, CK, and upstream Triton flash implementations somewhere with open threads, like https://github.com/ROCm/ROCm/discussions/

I did a brief overview of CK flash a while ago at huggingface/diffusers#7172

@Beinsezii

Glad to know! The fp32 performance is kind of poor, but thankfully AOTriton is not affected. I hope it's fixed when PyTorch 2.5.0 is released.

@evshiron It's made it into 2.5.0-rc1. Maybe one of us should open an issue, because won't that affect literally everyone with hardware fp16 that's not NVIDIA?

@evshiron

@Beinsezii

I am currently using this branch:

It is PyTorch 2.4.0 with AOTriton updated. And I locally made a Flash Attention with different (currently two) backends.

For the same configuration we have tested, the results are:

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX  Triton Flash Attention  Repeerc's Flash Attention  Torch - (AOTriton?)  Torch - Math
0    77.0                0.075390                   0.696945             0.519582      1.089998
1   256.0                0.929930                  11.258369             4.649597      9.788326
2  1024.0                7.258171                  26.519371             8.881320     26.203581
3  4096.0               12.973244                  33.084101            10.138325     34.008212
4  8192.0               13.660235                  32.934404            10.160260     38.635985
5  9216.0               13.553775                  33.048324            10.265818     33.515431

And the numbers for ROCm/flash-attention@howiejay/navi_support, taken from the previous comment:

fused-attention-batch2-head8-d128-fwd-causal=False:
    N_CTX      Flash
0    77.0   0.589959
1   256.0   5.817600
2  1024.0  23.633466
3  4096.0  31.725495
4  8192.0  33.280974
5  9216.0  29.864885

@Beinsezii

@evshiron I wonder how they scale across different architectures? None of the implementations are super optimized, so they swing wildly. For SDXL, howiejay is the fastest by +20%; for SD15, the Triton attention absolutely bombed by like -70%; for one of the DiTs (PixArt or Hunyuan?) I think Math was still the fastest as long as it didn't OOM. Eventually I just gave up and added CLI parameters to my app for setting the SDPA backend and flash attention monkey patches at runtime...

I really don't want to build torch from source a 12th time, so I'll wait for Meta to figure out what they're doing with the 2.5 SDPA math casting before doing significant monkeying myself.

I don't know if you've seen, but the main_perf branch was cut into an upstream PR too. Looks like they cut backward again because it was slow, but it seems like with further development Navi will be able to depend on flash-attn somewhat normally.

@evshiron

evshiron commented Sep 15, 2024

@Beinsezii

The performance varies between FA implementations. If you are seeking the best performance, routing to different backends based on shapes and parameters might be a good approach; it is planned for my Flash Attention library, but I don't know if it's worth it.
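
A rough sketch of that routing idea (the thresholds and the fallback chain are made up purely for illustration; flash_attn_func is the same interface used in the benchmarks above):

import torch
import torch.nn.functional as F

def routed_attention(q, k, v, is_causal=False, scale=None):
    """Illustrative shape-based routing between attention backends."""
    batch, heads, seqlen, head_dim = q.shape
    # Short sequences, unusual head dims, or non-half dtypes: stock SDPA.
    if seqlen <= 1024 or head_dim > 128 or q.dtype not in (torch.float16, torch.bfloat16):
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=scale)
    try:
        from flash_attn import flash_attn_func  # CK howiejay or Triton main_perf build
    except ImportError:
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=scale)
    # flash_attn_func expects bshd layout, hence the transposes.
    out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                          softmax_scale=scale, causal=is_causal)
    return out.transpose(1, 2)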

The biggest problem for Triton (including AOTriton) is its performance on AMD GPUs. Currently the performance of Triton matmul is about 70% of hipBLAS on the RX 7900 XTX [1], and Flash Attention performance is even worse [2]. For CDNA GPUs like MI250/MI300, the performance of AOTriton is about 70% of the CK one too [3].

A year and a half after purchasing the RX 7900 XTX, I begin to wonder whether the performance improvement of Flash Attention on the RX 7900 XTX could be as significant as it was on the RTX 3090 [4] (or MI250 lol). Regardless, the VRAM usage does go down.

Footnotes

  1. You can run Triton's matmul tutorial locally

  2. As we have tested for both Triton and AOTriton ones

  3. https://github.com/ROCm/flash-attention/issues/82#issuecomment-2340064215

  4. At the end of https://github.com/ROCm/aotriton/pull/39#issuecomment-2330669978

@Beinsezii

Beinsezii commented Sep 16, 2024

A year and a half after purchasing the RX 7900 XTX, I begin to wonder whether the performance improvement of Flash Attention on RX 7900 XTX could be as significant as it was on RTX 3090

When llama 3 came out, one of my friends borked their venv and didn't have Flash Attention. On Exllama2, which compiles torch C extensions using the wheel-bundled CUDA/ROCm compilers, my XTX actually outperformed his 3090. The GPUs themselves are completely capable; it's just the software holding them back. It's a Sisyphean situation where all the cutting-edge projects are built and optimized for CUDA specifically, so ROCm is perpetually in catch-up mode. I think that's why they're taking the approach of using Triton everywhere they can, to reduce the number of places where they need to optimize kernels.

That said, yes, it seems upstream Triton is particularly slow right now. I remember having a Triton flash running at over 90% of the CK branch's speed in SDXL, but I can't seem to reproduce that right now. Maybe I was using the ROCm fork? It was definitely unstable though; I had to be really picky with the autotune configs to keep it from causing GPU resets.

Edit: the Navi configs might be worth looking at: ROCm/triton#640

@evshiron

evshiron commented Sep 16, 2024

@Beinsezii

The GPUs themselves are completely capable it's just the software holding them back.

Yeah. I believe the raw performance is comparable to the RTX 3090 and RTX 4080, and we do have projects like MLC that achieve higher CPR than the RTX 4090. In my opinion, LLM applications like llama.cpp (Ollama) are already satisfying on the RX 7900 XTX. Diffusion models are more VRAM-hungry, and having an FA implementation with good performance would really benefit them.

Maybe I was using the ROCm fork? It was definitely unstable though, I had to be really picky with the autotune configs to have it not cause GPU resets.

I am using the official Triton repo and I believe most development for AMD happens there now. I haven't experienced any auto-tuning crashes recently. Maybe it's a difference on WSL? The performance remains ordinary across different tuning configurations (including the ones mentioned), though.

@Beinsezii

@evshiron

Yeah. I believe the raw performance is at the same level as RTX 3090 and RTX 4080

The XTX should by every metric be faster than a 3090. I think the vector/AI/whatever FLOPS were a bit above the 3090 Ti's.

I am using the official repo of Triton and I believe most developments for AMD happen there now.

Out of curiosity I just built git+https://github.com/ROCm/triton.git@micmelesse/cache_fix#subdirectory=python and gained +6% for SDXL over upstream triton-lang using the flash branch from Dao-AILab/flash-attention#1203. It ain't much, but it's honest work. I think that puts it a bit above the aotriton SDPA now?

I haven't experienced any auto-tuning crashes recently. Maybe it's a difference on WSL? Though the performance remains ordinary across different tuning configurations (including the mentioned ones).

I haven't had any resets recently either. This was quite a few months ago. Though the original triton.ops.flash_attention.flash_attn_func still causes a page fault.

@evshiron

@Beinsezii

I think that puts it a bit above aotriton sdpa now?

It's hard to tell.

AOTriton uses an old, customized version of Triton for kernel generation, which might include some specific optimizations, while the official Triton repo is updated frequently, and bumping the LLVM version might also affect performance.

As you can see from the previous benchmarks, few of those implementations are able to consistently win when the input shapes change.

@xinyazhang
Collaborator Author

AOTriton is using an old & custom version of Triton for kernel generation, which might include some specified optimizations, while the official repo of Triton are updated frequently, and bumping the LLVM version might also affect the performance.

To clarify, the customization of Triton is mainly about removing the CUDA bits, which otherwise only increase the download/build time.

Successfully merging this pull request may close these issues.

[Feature]: Memory Efficient Flash Attention for gfx1100 (7900xtx)