[CUDA] Lean Attention #22352

Merged
merged 10 commits into main from tlwu/lean_att on Oct 14, 2024

Conversation

tianleiwu
Contributor

@tianleiwu tianleiwu commented Oct 8, 2024

Description

Add [Lean Attention](https://arxiv.org/abs/2405.10480) and its integration with the MultiHeadAttention operator for LLMs on GPU.

LeanAttention speeds up self-attention for the token-generation phase (decode phase) of decoder-only transformer models, especially at long context lengths.
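
The kernel details are in the linked paper; the rough intuition behind this family of decode kernels is that the past KV sequence is partitioned across compute units, each partition produces a partial attention output plus its softmax statistics, and the partials are merged by rescaling to a common max. Below is a minimal numpy sketch of that merge step; the two-way split, shapes, and function names are illustrative, not the kernel's actual tiling.

```python
import numpy as np

def partial_attention(q, k, v):
    # Attention over one chunk of the KV cache: returns the unnormalized
    # output together with the softmax statistics needed for merging.
    s = k @ q / np.sqrt(q.shape[0])  # scores for this chunk, shape (chunk_len,)
    m = s.max()                      # chunk-local max, for numerical stability
    p = np.exp(s - m)
    return p @ v, m, p.sum()

def merge(o1, m1, l1, o2, m2, l2):
    # Combine two partial results by rescaling both to a common max.
    m = max(m1, m2)
    l = l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)
    o = o1 * np.exp(m1 - m) + o2 * np.exp(m2 - m)
    return o / l

rng = np.random.default_rng(0)
q = rng.standard_normal(64)          # one query token (head_size=64)
k = rng.standard_normal((512, 64))   # past keys for one head
v = rng.standard_normal((512, 64))   # past values for one head

out = merge(*partial_attention(q, k[:256], v[:256]),
            *partial_attention(q, k[256:], v[256:]))

# Parity check against single-pass softmax attention.
s = k @ q / np.sqrt(64)
w = np.exp(s - s.max()); w /= w.sum()
assert np.allclose(out, w @ v)
```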

  • Initial implementation of Lean Attention (by Srikant Bharadwaj)
  • Integration with MultiHeadAttention operator
  • Add parity tests
  • Add benchmark

Implementation Details

(1) Lean Attention is enabled in the build for Linux and disabled for Windows.
(2) Lean Attention is disabled by default. Enable it through the CUDA provider option `sdpa_kernel`, or set the environment variable `ORT_ENABLE_LEAN_ATTENTION=1` (see the sketch below).
(3) It only works for token generation (sequence_length == 1, past_sequence_length > 0).
(4) Like flash attention, it only works on Ampere or newer GPUs.

We can revisit (1) and (2) after comparing with the DecoderMaskedMultiHeadAttention and XQA kernels.
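
As a concrete illustration of (2), here is a minimal sketch of enabling the kernel through the environment variable. The model path is a placeholder; the variable must be set before the session is created, and the exact value accepted by the `sdpa_kernel` provider option is not shown in this PR, so only the environment-variable route is sketched.

```python
import os

# Must be set before the InferenceSession is created.
os.environ["ORT_ENABLE_LEAN_ATTENTION"] = "1"

import onnxruntime as ort

# "decoder_model.onnx" is a placeholder for a decoder model that uses
# the MultiHeadAttention contrib operator.
session = ort.InferenceSession(
    "decoder_model.onnx",
    providers=["CUDAExecutionProvider"],
)
```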

Benchmark

```
cd onnxruntime/test/python/transformers
/bin/bash benchmark_mha.sh lean
```

Example outputs on an H100 GPU:

Note that the past and present KV caches do not share a buffer in MHA for now, which is why the tflops numbers look low. The relative ratios will change once buffer sharing is enabled, but we expect the ordering (kernel A faster than kernel B) to stay the same.

Note that the settings common to all rows (`sequence_length=1; causal=True; attn_bias=None; cuda_graph=False`) are omitted from the table below.

batch_size | past_sequence_length | num_heads | head_size | average_latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | --
1 | 512 | 16 | 64 | 0.000059 | 0.0178 | ort:flash
1 | 512 | 16 | 64 | 0.000068 | 0.0155 | ort:efficient
1 | 512 | 16 | 64 | 0.000065 | 0.0161 | ort:math
1 | 512 | 16 | 64 | 0.000060 | 0.0176 | ort:lean
1 | 512 | 32 | 128 | 0.000062 | 0.0674 | ort:flash
1 | 512 | 32 | 128 | 0.000064 | 0.0661 | ort:efficient
1 | 512 | 32 | 128 | 0.000067 | 0.0625 | ort:math
1 | 512 | 32 | 128 | 0.000062 | 0.0678 | ort:lean
1 | 1024 | 16 | 64 | 0.000061 | 0.0345 | ort:flash
1 | 1024 | 16 | 64 | 0.000086 | 0.0244 | ort:efficient
1 | 1024 | 16 | 64 | 0.000065 | 0.0322 | ort:math
1 | 1024 | 16 | 64 | 0.000063 | 0.0332 | ort:lean
1 | 1024 | 32 | 128 | 0.000075 | 0.1125 | ort:flash
1 | 1024 | 32 | 128 | 0.000088 | 0.0951 | ort:efficient
1 | 1024 | 32 | 128 | 0.000079 | 0.1068 | ort:math
1 | 1024 | 32 | 128 | 0.000072 | 0.1171 | ort:lean
1 | 2048 | 16 | 64 | 0.000069 | 0.0606 | ort:flash
1 | 2048 | 16 | 64 | 0.000125 | 0.0336 | ort:efficient
1 | 2048 | 16 | 64 | 0.000064 | 0.0655 | ort:lean
1 | 2048 | 32 | 128 | 0.000098 | 0.1720 | ort:flash
1 | 2048 | 32 | 128 | 0.000132 | 0.1270 | ort:efficient
1 | 2048 | 32 | 128 | 0.000092 | 0.1828 | ort:lean
1 | 4096 | 16 | 64 | 0.000076 | 0.1097 | ort:flash
1 | 4096 | 16 | 64 | 0.000207 | 0.0406 | ort:efficient
1 | 4096 | 16 | 64 | 0.000069 | 0.1209 | ort:lean
1 | 4096 | 32 | 128 | 0.000140 | 0.2394 | ort:flash
1 | 4096 | 32 | 128 | 0.000213 | 0.1575 | ort:efficient
1 | 4096 | 32 | 128 | 0.000139 | 0.2419 | ort:lean
1 | 8192 | 16 | 64 | 0.000104 | 0.1609 | ort:flash
1 | 8192 | 16 | 64 | 0.000392 | 0.0428 | ort:efficient
1 | 8192 | 16 | 64 | 0.000093 | 0.1809 | ort:lean
1 | 8192 | 32 | 128 | 0.000212 | 0.3160 | ort:flash
1 | 8192 | 32 | 128 | 0.000360 | 0.1866 | ort:efficient
1 | 8192 | 32 | 128 | 0.000212 | 0.3162 | ort:lean
1 | 16384 | 16 | 64 | 0.000139 | 0.2410 | ort:flash
1 | 16384 | 16 | 64 | 0.000731 | 0.0459 | ort:efficient
1 | 16384 | 16 | 64 | 0.000136 | 0.2465 | ort:lean
1 | 16384 | 32 | 128 | 0.000361 | 0.3722 | ort:flash
1 | 16384 | 32 | 128 | 0.000667 | 0.2014 | ort:efficient
1 | 16384 | 32 | 128 | 0.000357 | 0.3765 | ort:lean
1 | 32768 | 16 | 64 | 0.000210 | 0.3194 | ort:flash
1 | 32768 | 16 | 64 | 0.001428 | 0.0470 | ort:efficient
1 | 32768 | 16 | 64 | 0.000209 | 0.3211 | ort:lean
1 | 32768 | 32 | 128 | 0.000659 | 0.4074 | ort:flash
1 | 32768 | 32 | 128 | 0.001270 | 0.2114 | ort:efficient
1 | 32768 | 32 | 128 | 0.000651 | 0.4123 | ort:lean
1 | 65536 | 16 | 64 | 0.000355 | 0.3785 | ort:flash
1 | 65536 | 16 | 64 | 0.002736 | 0.0491 | ort:efficient
1 | 65536 | 16 | 64 | 0.000349 | 0.3845 | ort:lean
1 | 65536 | 32 | 128 | 0.001251 | 0.4290 | ort:flash
1 | 65536 | 32 | 128 | 0.002480 | 0.2165 | ort:efficient
1 | 65536 | 32 | 128 | 0.001239 | 0.4333 | ort:lean
4 | 512 | 16 | 64 | 0.000063 | 0.0665 | ort:flash
4 | 512 | 16 | 64 | 0.000069 | 0.0607 | ort:efficient
4 | 512 | 16 | 64 | 0.000066 | 0.0634 | ort:math
4 | 512 | 16 | 64 | 0.000062 | 0.0674 | ort:lean
4 | 512 | 32 | 128 | 0.000100 | 0.1677 | ort:flash
4 | 512 | 32 | 128 | 0.000099 | 0.1703 | ort:efficient
4 | 512 | 32 | 128 | 0.000108 | 0.1557 | ort:math
4 | 512 | 32 | 128 | 0.000092 | 0.1818 | ort:lean
4 | 1024 | 16 | 64 | 0.000077 | 0.1094 | ort:flash
4 | 1024 | 16 | 64 | 0.000099 | 0.0850 | ort:efficient
4 | 1024 | 16 | 64 | 0.000081 | 0.1038 | ort:math
4 | 1024 | 16 | 64 | 0.000072 | 0.1161 | ort:lean
4 | 1024 | 32 | 128 | 0.000143 | 0.2343 | ort:flash
4 | 1024 | 32 | 128 | 0.000137 | 0.2447 | ort:efficient
4 | 1024 | 32 | 128 | 0.000150 | 0.2245 | ort:math
4 | 1024 | 32 | 128 | 0.000135 | 0.2496 | ort:lean
4 | 2048 | 16 | 64 | 0.000096 | 0.1757 | ort:flash
4 | 2048 | 16 | 64 | 0.000156 | 0.1078 | ort:efficient
4 | 2048 | 16 | 64 | 0.000089 | 0.1892 | ort:lean
4 | 2048 | 32 | 128 | 0.000223 | 0.3010 | ort:flash
4 | 2048 | 32 | 128 | 0.000217 | 0.3101 | ort:efficient
4 | 2048 | 32 | 128 | 0.000209 | 0.3209 | ort:lean
4 | 4096 | 16 | 64 | 0.000137 | 0.2448 | ort:flash
4 | 4096 | 16 | 64 | 0.000256 | 0.1312 | ort:efficient
4 | 4096 | 16 | 64 | 0.000133 | 0.2530 | ort:lean
4 | 4096 | 32 | 128 | 0.000389 | 0.3450 | ort:flash
4 | 4096 | 32 | 128 | 0.000376 | 0.3574 | ort:efficient
4 | 4096 | 32 | 128 | 0.000354 | 0.3794 | ort:lean
4 | 8192 | 16 | 64 | 0.000210 | 0.3198 | ort:flash
4 | 8192 | 16 | 64 | 0.000453 | 0.1480 | ort:efficient
4 | 8192 | 16 | 64 | 0.000206 | 0.3260 | ort:lean
4 | 8192 | 32 | 128 | 0.000725 | 0.3705 | ort:flash
4 | 8192 | 32 | 128 | 0.000693 | 0.3874 | ort:efficient
4 | 8192 | 32 | 128 | 0.000653 | 0.4114 | ort:lean
4 | 16384 | 16 | 64 | 0.000355 | 0.3782 | ort:flash
4 | 16384 | 16 | 64 | 0.000849 | 0.1581 | ort:efficient
4 | 16384 | 16 | 64 | 0.000346 | 0.3874 | ort:lean
4 | 16384 | 32 | 128 | 0.001395 | 0.3848 | ort:flash
4 | 16384 | 32 | 128 | 0.001337 | 0.4017 | ort:efficient
4 | 16384 | 32 | 128 | 0.001252 | 0.4288 | ort:lean
4 | 32768 | 16 | 64 | 0.000647 | 0.4146 | ort:flash
4 | 32768 | 16 | 64 | 0.001649 | 0.1628 | ort:efficient
4 | 32768 | 16 | 64 | 0.000639 | 0.4204 | ort:lean
4 | 32768 | 32 | 128 | 0.002721 | 0.3947 | ort:flash
4 | 32768 | 32 | 128 | 0.002601 | 0.4128 | ort:efficient
4 | 32768 | 32 | 128 | 0.002434 | 0.4411 | ort:lean
4 | 65536 | 16 | 64 | 0.001231 | 0.4361 | ort:flash
4 | 65536 | 16 | 64 | 0.003238 | 0.1658 | ort:efficient
4 | 65536 | 16 | 64 | 0.001217 | 0.4412 | ort:lean
4 | 65536 | 32 | 128 | 0.005357 | 0.4009 | ort:flash
4 | 65536 | 32 | 128 | 0.005118 | 0.4196 | ort:efficient
4 | 65536 | 32 | 128 | 0.004781 | 0.4492 | ort:lean
16 | 512 | 16 | 64 | 0.000098 | 0.1724 | ort:flash
16 | 512 | 16 | 64 | 0.000104 | 0.1616 | ort:efficient
16 | 512 | 16 | 64 | 0.000118 | 0.1420 | ort:math
16 | 512 | 16 | 64 | 0.000087 | 0.1926 | ort:lean
16 | 512 | 32 | 128 | 0.000220 | 0.3062 | ort:flash
16 | 512 | 32 | 128 | 0.000208 | 0.3237 | ort:efficient
16 | 512 | 32 | 128 | 0.000237 | 0.2838 | ort:math
16 | 512 | 32 | 128 | 0.000209 | 0.3216 | ort:lean
16 | 1024 | 16 | 64 | 0.000136 | 0.2465 | ort:flash
16 | 1024 | 16 | 64 | 0.000150 | 0.2235 | ort:efficient
16 | 1024 | 16 | 64 | 0.000148 | 0.2266 | ort:math
16 | 1024 | 16 | 64 | 0.000129 | 0.2611 | ort:lean
16 | 1024 | 32 | 128 | 0.000367 | 0.3663 | ort:flash
16 | 1024 | 32 | 128 | 0.000351 | 0.3829 | ort:efficient
16 | 1024 | 32 | 128 | 0.000400 | 0.3357 | ort:math
16 | 1024 | 32 | 128 | 0.000349 | 0.3853 | ort:lean
16 | 2048 | 16 | 64 | 0.000209 | 0.3206 | ort:flash
16 | 2048 | 16 | 64 | 0.000243 | 0.2762 | ort:efficient
16 | 2048 | 16 | 64 | 0.000201 | 0.3338 | ort:lean
16 | 2048 | 32 | 128 | 0.000671 | 0.4002 | ort:flash
16 | 2048 | 32 | 128 | 0.000645 | 0.4163 | ort:efficient
16 | 2048 | 32 | 128 | 0.000642 | 0.4185 | ort:lean
16 | 4096 | 16 | 64 | 0.000360 | 0.3732 | ort:flash
16 | 4096 | 16 | 64 | 0.000425 | 0.3162 | ort:efficient
16 | 4096 | 16 | 64 | 0.000341 | 0.3933 | ort:lean
16 | 4096 | 32 | 128 | 0.001292 | 0.4156 | ort:flash
16 | 4096 | 32 | 128 | 0.001251 | 0.4291 | ort:efficient
16 | 4096 | 32 | 128 | 0.001241 | 0.4327 | ort:lean
16 | 8192 | 16 | 64 | 0.000666 | 0.4030 | ort:flash
16 | 8192 | 16 | 64 | 0.000804 | 0.3339 | ort:efficient
16 | 8192 | 16 | 64 | 0.000627 | 0.4283 | ort:lean
16 | 8192 | 32 | 128 | 0.002541 | 0.4226 | ort:flash
16 | 8192 | 32 | 128 | 0.002454 | 0.4376 | ort:efficient
16 | 8192 | 32 | 128 | 0.002438 | 0.4405 | ort:lean
16 | 16384 | 16 | 64 | 0.001292 | 0.4156 | ort:flash
16 | 16384 | 16 | 64 | 0.001571 | 0.3417 | ort:efficient
16 | 16384 | 16 | 64 | 0.001217 | 0.4411 | ort:lean
16 | 16384 | 32 | 128 | 0.005042 | 0.4260 | ort:flash
16 | 16384 | 32 | 128 | 0.004859 | 0.4420 | ort:efficient
16 | 16384 | 32 | 128 | 0.004827 | 0.4449 | ort:lean
16 | 32768 | 16 | 64 | 0.002537 | 0.4233 | ort:flash
16 | 32768 | 16 | 64 | 0.003103 | 0.3461 | ort:efficient
16 | 32768 | 16 | 64 | 0.002385 | 0.4501 | ort:lean
16 | 32768 | 32 | 128 | 0.009961 | 0.4312 | ort:flash
16 | 32768 | 32 | 128 | 0.009605 | 0.4472 | ort:efficient
16 | 32768 | 32 | 128 | 0.009524 | 0.4510 | ort:lean
16 | 65536 | 16 | 64 | 0.005019 | 0.4279 | ort:flash
16 | 65536 | 16 | 64 | 0.006133 | 0.3502 | ort:efficient
16 | 65536 | 16 | 64 | 0.004703 | 0.4566 | ort:lean
16 | 65536 | 32 | 128 | 0.019746 | 0.4350 | ort:flash
16 | 65536 | 32 | 128 | 0.019027 | 0.4515 | ort:efficient
16 | 65536 | 32 | 128 | 0.018864 | 0.4554 | ort:lean
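
As a sanity check on the table (an observation reverse-engineered from the numbers, not something stated in this PR), the tflops column is consistent with 2 * batch_size * num_heads * head_size * total_sequence_length flops per decode step, with latency measured in seconds:

```python
# First ort:flash row: batch=1, past_sequence_length=512, 16 heads, head_size=64.
batch, past_seq, num_heads, head_size = 1, 512, 16, 64
latency = 0.000059  # seconds
flops = 2 * batch * num_heads * head_size * (past_seq + 1)
print(f"{flops / latency / 1e12:.4f}")  # 0.0178, matching the table
```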

Motivation and Context

@tianleiwu tianleiwu marked this pull request as draft October 8, 2024 21:08
@tianleiwu tianleiwu marked this pull request as ready for review October 13, 2024 07:14
@tianleiwu tianleiwu merged commit de93f40 into main Oct 14, 2024
91 of 92 checks passed
@tianleiwu tianleiwu deleted the tlwu/lean_att branch October 14, 2024 21:49
guschmue pushed a commit that referenced this pull request Oct 18, 2024
ishwar-raut1 pushed a commit to ishwar-raut1/onnxruntime that referenced this pull request Nov 19, 2024