Implementation of flash attention for native webgpu ep #22932
base: main
Conversation
wgpu::DawnExperimentalSubgroupLimits subgroup_limits;
device_supported_limits.nextInChain = &subgroup_limits;
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
device_limits_ = device_supported_limits.limits;
min_subgroup_size_ = subgroup_limits.minSubgroupSize;
What is the expected behavior when subgroup is not supported?
I didn't test this code, but it looks like it will abort. It would be better to disable the features that use subgroups if they are not available.
Hello, on devices where subgroups are not supported, minSubgroupSize will be 0, AFAIK. This is because:
https://github.com/google/dawn/blob/5d28e25927778b028473c4aa7af11fd5a5c9f76b/src/dawn/wire/server/ServerAdapter.cpp#L129 initializes the struct to 0.
https://github.com/google/dawn/blob/5d28e25927778b028473c4aa7af11fd5a5c9f76b/src/dawn/wire/client/LimitsAndFeatures.cpp#L53 reads from the struct when we call GetLimits; if nothing ever set the limits for this feature, it remains 0.
Later, in CanApplyFlashAttention in flash_attention.cc, I only enable FlashAttention if the minimum subgroup size is >= 16. So I expect no crashes, just that FlashAttention will not be activated on machines without subgroup support.
I don't have a machine without subgroups to test this, though.
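(Purely as an illustration, a minimal C++ sketch of the gating described above; the constant and the exact signature of CanApplyFlashAttention are assumptions, not the actual flash_attention.cc code.)

```cpp
// Sketch only: gate flash attention on the reported minimum subgroup size.
// On devices without subgroup support, Dawn leaves minSubgroupSize at its
// zero-initialized value, so the check below fails and the regular attention
// path is used instead of aborting.
#include <cstdint>

// Assumed threshold taken from the discussion above.
constexpr uint32_t kMinSubgroupSizeForFlashAttention = 16;

bool CanApplyFlashAttention(uint32_t min_subgroup_size) {
  // 0 means the limit was never populated, i.e. subgroups are unsupported.
  return min_subgroup_size >= kMinSubgroupSizeForFlashAttention;
}
```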
Very cool, I can give it a test drive on some other GPUs and macOS.
Force-pushed from d1e442e to 0549374
std::to_string(tile_size) +
std::to_string(parameters.head_size_) +
std::to_string(parameters.num_heads_);
program.SetDispatchGroupSize(parameters.num_heads_, (parameters.sequence_length_ + tile_size - 1) / tile_size, 1)
Hi @sushraja-msft, I have two questions about applying the algorithm to the Phi3 model:
- From the second generated token onward, sequence_length_ is always 1, so the DispatchGroupSize is always 32 (num_heads_). Because the work items in the same block must execute on the same EU (Intel GPU) / SM (NVIDIA GPU), only 32 EUs/SMs are used by this algorithm. If the driver schedules several blocks onto one EU/SM, or num_heads_ is smaller in other models, even fewer EUs/SMs are used. Does that mean the algorithm performs the same on a GPU with 32 EUs/SMs as on one with 128 EUs/SMs?
- The algorithm is mainly a loop that tiles the KVs and computes the results. When generating the 256th token, the loop count is 8 (256 / 32 (tile_size)); it is 16 for the 512th token, and 32 for the 1024th token. Does that mean generating the 1024th token costs 4x the time of generating the 256th token?

I also drafted a FlashAttention-V2 in the JS EP, #22915, and want to apply it to the SD model. I think the FlashAttention [V1|V2] algorithm pays off when sequence_length_ and kv_sequence_length_ are both large. Because the prompt is 501 tokens in the test, we see a good prefill speed. If only kv_sequence_length_ is large, Flash-Decoding may be better; some details are here: https://crfm.stanford.edu/2023/10/12/flashdecoding.html
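(To make the arithmetic in the questions above concrete, a small standalone sketch; the tile size of 32 and head count of 32 are taken from this discussion, not from the PR code.)

```cpp
// Sketch of the dispatch/loop arithmetic discussed above.
#include <cstdio>

int main() {
  const int tile_size = 32;      // tile size assumed in the discussion
  const int num_heads = 32;      // Phi3-style head count used in the discussion
  const int new_seq_length = 1;  // decode: one new token per step

  // Workgroups launched per decode step: num_heads * ceil(new_seq_length / tile_size).
  const int groups = num_heads * ((new_seq_length + tile_size - 1) / tile_size);
  std::printf("decode dispatch groups: %d\n", groups);  // 32

  // KV-tile loop iterations inside each workgroup grow with the total sequence length.
  const int total_lengths[] = {256, 512, 1024};
  for (int total_len : total_lengths) {
    std::printf("token %4d -> %2d KV tiles\n", total_len,
                (total_len + tile_size - 1) / tile_size);  // 8, 16, 32
  }
  return 0;
}
```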
xhcao - Thank you for contributing an FA2 implementation in WebGPU. I'll take a look at your implementation this weekend, but getting FA to work in WebGPU without subgroup support is a remarkable feat.
I think the points you are making in 1 and 2 are the same: the dispatch does not parallelize beyond batch/head size. Yes, this is true and is a limitation of Flash Attention 1. This specific issue was addressed in FA2 (https://arxiv.org/pdf/2307.08691, Paragraph 3.2).
My goal with this change was to land changes incrementally: first a basic implementation for MHA, then GQA, then an upgrade to FA2.
Finally, though I did observe the speed-up earlier with this change, this week after a rebase the baseline MHA got faster (MHA was refactored into attention.cc recently), and I now see FA being just 1 tk/s faster than regular MHA for a sequence length of 501 tokens. This needs further investigation, which I am looking into.
Per-operator performance numbers show that during prefill, Attention is just 8% of the time spent; at 22 tk/s each token takes 45 ms on an Intel Alderlake device. Even with zero-cost attention, 8% of 45 ms saves about 4 ms, which yields only a ~2 tk/s speed-up. Net, I am not sure what the authors meant by "This yields 2-4× wall-clock time speedup over optimized baselines, up to 10-20× memory saving" in https://arxiv.org/pdf/2307.08691 - is it a 2-4x total inference time speedup, or an attention wall-clock time speedup? I am investigating whether, combined with static KV cache, this change shows more performance wins before I pursue it further.
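(A rough worked version of the estimate above, using only the numbers quoted in this comment.)

```cpp
// Back-of-the-envelope check: prefill at ~22 tk/s (~45 ms/token), attention ~8% of the time.
#include <cstdio>

int main() {
  const double ms_per_token = 1000.0 / 22.0;                  // ~45.5 ms
  const double attention_fraction = 0.08;                     // share of prefill time in Attention
  const double saved_ms = ms_per_token * attention_fraction;  // ~3.6 ms saved with zero-cost attention
  const double best_case_tps = 1000.0 / (ms_per_token - saved_ms);
  std::printf("best case with free attention: %.1f tk/s (vs 22.0)\n", best_case_tps);  // ~23.9 tk/s
  return 0;
}
```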
Do let me know what perf improvements you are seeing with FA2 for stable diffusion.
In the FA2 paper, the author says in Paragraph 3.2:
"We also parallelize over the batch dimension and number of heads dimension, as done in FlashAttention. The increased parallelism over sequence length helps improve occupancy (fraction of GPU resources being used) when the batch size and number of heads are small, leading to speedup in this case."
But for the Phi3 model, in the decoding stage the sequence length is 1, so the parallelization issue is still not fixed. And in the prefill stage, in your test, the sequence length is 501, so there is no parallelization issue there.
But if the prompt is only 20 tokens, the parallelization issue also exists; if you test this situation, you may get an unexpected result.
You referred to the conclusion in the paper, "This yields 2-4× wall-clock time speedup over optimized baselines, up to 10-20× memory saving"; I think it relates to FA1 compared against a standard attention implementation using the naive algorithm.
Unlike language models, the sequence length in SD is always very large. Currently, the upstream code cannot run SD 2.1 with attention operators; Chrome will run out of memory.
Description
This change implements flash attention in WebGPU to improve prefill speed.
Perf numbers from an Intel Alderlake device:
Baseline MHA
With FA
Motivation and Context
On integrated GPUs, memory bandwidth is at a premium. Flash attention makes the softmax computation (and therefore the output attention vector computation) a running operation instead of maintaining the full QKt attention scores in memory. As a result, we see significant improvements in prefill speed - a 30% speed-up measured here.
This implementation also uses the new WebGPU subgroups feature to further accelerate attention computation.
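(For readers unfamiliar with the technique, a minimal scalar sketch of the running "online" softmax that flash attention builds on; it is a generic illustration, not the WGSL shader in this change.)

```cpp
// Online softmax: consume attention scores tile by tile while keeping a running
// max, running sum, and running weighted output, so the full QKt score matrix
// is never materialized in memory.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical scores and (scalar) values for a single query position.
  const std::vector<float> scores = {0.1f, 2.0f, -1.0f, 0.5f};
  const std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f};

  float running_max = -INFINITY;
  float running_sum = 0.0f;
  float running_out = 0.0f;

  for (std::size_t i = 0; i < scores.size(); ++i) {
    const float new_max = std::max(running_max, scores[i]);
    const float correction = std::exp(running_max - new_max);  // rescale previous state
    const float p = std::exp(scores[i] - new_max);
    running_sum = running_sum * correction + p;
    running_out = running_out * correction + p * values[i];
    running_max = new_max;
  }
  std::printf("attention output for this query: %f\n", running_out / running_sum);
  return 0;
}
```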
Remaining work
How to enable
Currently flash attention is off by default. To enable it, add
"provider_options": [
  {
    "webgpu": { "enableFlashAttention": "1" }
  }
]
to genai_config.json.