
Implementation of flash attention for native WebGPU EP #22932

Open · sushraja-msft wants to merge 3 commits into main from user/sushraja/fa_attempt_3

Conversation

@sushraja-msft (Contributor) commented Nov 24, 2024

Description

This change implements flash attention in WebGPU to improve prefill speed.
Perf numbers below are from an Intel Alder Lake device.

Baseline MHA

Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       2.26746e+07
        avg (tokens/s): 22.0952              <<<
        p50 (us):       2.34637e+07
        stddev (us):    3.92912e+06
        n:              5 * 501 token(s)
Token generation:
        avg (us):       96519.8
        avg (tokens/s): 10.3606              <<<
        p50 (us):       98061.5
        stddev (us):    9220.87
        n:              635 * 1 token(s)

With FA

Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.69236e+07
        avg (tokens/s): 29.6036             <<<
        p50 (us):       1.63162e+07
        stddev (us):    960417
        n:              5 * 501 token(s)
Token generation:
        avg (us):       91436.7
        avg (tokens/s): 10.9365             <<<
        p50 (us):       90397.1
        stddev (us):    5349.19
        n:              635 * 1 token(s)

Motivation and Context

On integrated GPUs, memory bandwidth is at a premium. Flash attention turns the softmax computation (and therefore the output attention vector computation) into a running operation instead of keeping the full QK^T attention score matrix in memory. As a result, we see a significant improvement in prefill speed: a 30% speedup measured here.
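
To illustrate the running-softmax idea, here is a minimal scalar sketch of streaming attention for a single query vector. This is a simplified illustration, not the WGSL shader in this change; the function and variable names are assumptions.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Minimal scalar sketch of streaming (online) softmax attention for one query.
// Each K/V pair updates a running max, a running sum, and a running output
// accumulator, so the full QK^T score row is never materialized in memory.
std::vector<float> StreamingAttention(const std::vector<float>& q,
                                      const std::vector<std::vector<float>>& keys,
                                      const std::vector<std::vector<float>>& values,
                                      float scale) {
  const size_t head_size = q.size();
  std::vector<float> out(head_size, 0.0f);
  float running_max = -std::numeric_limits<float>::infinity();
  float running_sum = 0.0f;
  for (size_t t = 0; t < keys.size(); ++t) {
    float score = 0.0f;  // dot(q, k_t) * scale, computed and consumed in place
    for (size_t d = 0; d < head_size; ++d) score += q[d] * keys[t][d];
    score *= scale;
    const float new_max = std::max(running_max, score);
    const float correction = std::exp(running_max - new_max);  // rescales old state
    const float p = std::exp(score - new_max);
    running_sum = running_sum * correction + p;
    for (size_t d = 0; d < head_size; ++d) {
      out[d] = out[d] * correction + p * values[t][d];
    }
    running_max = new_max;
  }
  for (size_t d = 0; d < head_size; ++d) out[d] /= running_sum;  // final normalization
  return out;
}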

This implementation also uses the new WebGPU subgroups feature to further accelerate the attention computation.

  • Tested on Intel Alder Lake (subgroup size 16) with Phi-3.5-mini.
  • Tested on NVIDIA 2060 (subgroup size 32) with Phi-3.5-mini.
  • Tested with Llama 3.2 1B; FlashAttention does not activate because past/present keys are always null. This needs investigation into the model to understand why.

Remaining work

  • Algorithm specialization for the generation phase: here, the shared-memory tiles for K/V can be removed because each K/V value is used just once, freeing up shared memory for a larger tile size.
  • Algorithm specialization for the no-past-KV case (prefill). The CopyKVCache operation can likely be eliminated in this case: since there are no past KV values to copy over, the new KV values can be written into the present KV buffers as part of flash attention. PIX profiling shows CopyKVCache is almost as expensive as the FlashAttention implementation itself. A static KV cache would also eliminate this copy and yield further performance wins.

How to enable

Currently flash attention is off by default. To enable it, add the following to genai_config.json:

"provider_options": [
  {
    "webgpu": { "enableFlashAttention": "1" }
  }
]
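
For comparison only, the same option could presumably also be passed through the ONNX Runtime C++ API instead of genai_config.json. The sketch below assumes the native WebGPU EP is registered under the provider name "WebGPU" and accepts the same "enableFlashAttention" key; both are assumptions, not confirmed by this PR, and "model.onnx" is a placeholder path.

#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "webgpu-flash-attention"};
  Ort::SessionOptions session_options;

  // Assumption: provider name "WebGPU" and this option key match what the EP expects.
  std::unordered_map<std::string, std::string> webgpu_options{
      {"enableFlashAttention", "1"}};
  session_options.AppendExecutionProvider("WebGPU", webgpu_options);

  // "model.onnx" is a placeholder model path.
  Ort::Session session{env, ORT_TSTR("model.onnx"), session_options};
  return 0;
}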


Comment on lines +121 to +125
wgpu::DawnExperimentalSubgroupLimits subgroup_limits;
device_supported_limits.nextInChain = &subgroup_limits;
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
device_limits_ = device_supported_limits.limits;
min_subgroup_size_ = subgroup_limits.minSubgroupSize;

@fs-eire (Contributor) commented Nov 26, 2024:

What is the expected behavior when subgroups are not supported?

I didn't test this code, but it looks like it will abort. It would be better to disable the features that use subgroups if they are not available.

@sushraja-msft (Author) replied:

Hello, on devices where subgroups are not supported, minSubgroupSize will be 0, AFAIK. This is because:

https://github.com/google/dawn/blob/5d28e25927778b028473c4aa7af11fd5a5c9f76b/src/dawn/wire/server/ServerAdapter.cpp#L129 initializes the struct to 0, and

https://github.com/google/dawn/blob/5d28e25927778b028473c4aa7af11fd5a5c9f76b/src/dawn/wire/client/LimitsAndFeatures.cpp#L53 reads from the struct when we call GetLimits; if nothing has set the limits for this feature, it remains 0.

Later, CanApplyFlashAttention in flash_attention.cc only enables FlashAttention when the minimum subgroup size is >= 16. So I expect no crashes; FlashAttention will simply not be activated on machines without subgroup support.

I don't have a machine without subgroups to test this, though.
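
A minimal sketch of the kind of guard described above; the real check lives in CanApplyFlashAttention in flash_attention.cc, and the constant and function names here are illustrative assumptions:

#include <cstdint>

// Illustrative guard: on devices without subgroup support, Dawn leaves
// minSubgroupSize at its zero-initialized value, so this check fails and the
// regular MHA path is used instead of aborting.
constexpr uint32_t kMinRequiredSubgroupSize = 16;

bool SubgroupsUsableForFlashAttention(uint32_t min_subgroup_size) {
  return min_subgroup_size >= kMinRequiredSubgroupSize;
}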


@guschmue (Contributor) commented:

Very cool, I can give it a test drive on some other GPUs and macOS.


@sushraja-msft force-pushed the user/sushraja/fa_attempt_3 branch from d1e442e to 0549374 on December 3, 2024 00:14

std::to_string(tile_size) +
std::to_string(parameters.head_size_) +
std::to_string(parameters.num_heads_);
program.SetDispatchGroupSize(parameters.num_heads_, (parameters.sequence_length_ + tile_size - 1) / tile_size, 1)

@xhcao (Contributor) commented on the code above:


Hi @sushraja-msft, I have two questions about applying the algorithm to the Phi-3 model:

  1. From the second generated token onward, sequence_length_ is always 1, so the dispatch group size is always 32 (num_heads_). Because the invocations in the same workgroup must execute on the same EU (Intel GPU) / SM (NVIDIA GPU), only 32 EUs/SMs are used by this algorithm. If the driver schedules several workgroups on one EU/SM, or if num_heads_ is smaller in other models, even fewer EUs/SMs are used. Does this mean the algorithm performs the same on a GPU with 32 EUs/SMs as on one with 128 EUs/SMs?
  2. The algorithm is mainly a loop that tiles the KVs and computes the results. When generating the 256th token, the loop count is 8 (256 / 32 (tile_size)); it is 16 for the 512th token and 32 for the 1024th token. Does this mean generating the 1024th token costs 4x as much time as generating the 256th token?

I also drafted a FlashAttention-V2 in the JS EP (#22915) and want to apply it to the SD model. I think the FlashAttention (V1/V2) algorithm applies when sequence_length_ and kv_sequence_length_ are both large; because the prompt is 501 tokens in this test, we see a good prefill speed. If only kv_sequence_length_ is large, FlashAttention decoding may be better; some details are here: https://crfm.stanford.edu/2023/10/12/flashdecoding.html
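
For reference, a small sketch of the dispatch and tile-loop arithmetic behind the two questions above, based on the quoted SetDispatchGroupSize call. The head count and tile size are illustrative values, not taken from the shader.

#include <cstdio>
#include <initializer_list>

int main() {
  const int num_heads = 32;  // illustrative Phi-3-style head count
  const int tile_size = 32;  // illustrative K/V tile size

  // Question 1: dispatch is (num_heads, ceil(sequence_length / tile_size), 1).
  // In decode, sequence_length == 1, so the dispatch collapses to num_heads groups.
  for (int sequence_length : {501, 1}) {
    const int groups_y = (sequence_length + tile_size - 1) / tile_size;
    std::printf("seq_len=%4d -> %d x %d = %d workgroups\n",
                sequence_length, num_heads, groups_y, num_heads * groups_y);
  }

  // Question 2: the tile loop walks the present KV length, so the per-token
  // iteration count grows linearly with the generated position.
  for (int kv_length : {256, 512, 1024}) {
    std::printf("kv_len=%5d -> %d tile iterations\n",
                kv_length, (kv_length + tile_size - 1) / tile_size);
  }
  return 0;
}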


@sushraja-msft (Author) commented Dec 5, 2024:


@xhcao, thank you for contributing an FA2 implementation in WebGPU. I'll take a look at your implementation this weekend, but getting FA to work in WebGPU without subgroup support is a remarkable feat.

I think the points you are making in 1 and 2 are the same: the dispatch does not parallelize beyond batch/head size. Yes, this is true and is a limitation of FlashAttention 1. This specific issue was addressed in FA2 (https://arxiv.org/pdf/2307.08691, Section 3.2).

My goal with this change was to land changes incrementally: first a basic implementation for MHA, then GQA, then an upgrade to FA2.

Finally, though I did observe the speedup earlier with this change, baseline MHA got faster after a rebase this week (MHA was recently refactored into attention.cc), and I now see FA being just 1 tk/s faster than regular MHA for a sequence length of 501 tokens. This needs further investigation, which I am looking into.

Per-operator performance numbers show that during prefill, Attention is just 8% of the time spent; at 22 tk/s, each token takes 45 ms on an Intel Alder Lake device. Even with zero-cost attention, 8% of 45 ms would save about 4 ms and result in only a ~2 tk/s speedup. Net, I am not sure what the authors meant by "This yields 2-4× wall-clock time speedup over optimized baselines, up to 10-20× memory saving" in https://arxiv.org/pdf/2307.08691: is it a 2-4× speedup of total inference time, or of attention wall-clock time? I am investigating whether this change, combined with a static KV cache, will show more performance wins before I pursue it further.

Do let me know what perf improvements you are seeing with FA2 for Stable Diffusion.


@xhcao (Contributor) commented Dec 6, 2024:


In the FA2 paper, the authors say in Section 3.2: "We also parallelize over the batch dimension and number of heads dimension, as done in FlashAttention. The increased parallelism over sequence length helps improve occupancy (fraction of GPU resources being used) when the batch size and number of heads are small, leading to speedup in this case."

But for the Phi-3 model, in the decoding stage the sequence length is 1, so the parallelization issue is still not fixed. In the prefill stage in your test, the sequence length is 501, so there is no parallelization issue. But if the prompt is only 20 tokens, the parallelization issue also exists; if you test that situation, you may get an unexpected result.

You referred to the conclusion in the paper, "This yields 2-4× wall-clock time speedup over optimized baselines, up to 10-20× memory saving"; I think that was about FA1, compared with a standard attention implementation using the naive algorithm.

Unlike language models, the sequence length in SD is always very large. Currently, the upstream code cannot run SD 2.1 with the attention operators; Chrome will run out of memory.
