[Kernel] Use flash-attn for decoding #3648

Merged on May 13, 2024 (98 commits)
Commits
bbf023b  vendor flash-attention (skrider, Mar 8, 2024)
38d422d  update vendored flash-attention (skrider, Mar 27, 2024)
00dc8ad  add reshape_and_cache_flash (skrider, Mar 8, 2024)
4ffe256  refactor reshape_and_cache_flash (skrider, Mar 8, 2024)
6a2ddf4  refactor reshape_and_cache_flash (skrider, Mar 8, 2024)
0e45f5d  implement flash attention decode backend (skrider, Mar 27, 2024)
45c3662  FlashAttentionDecode -> FlashAttention (WoosukKwon, Mar 27, 2024)
18132d2  Remove gitmodule (WoosukKwon, Mar 27, 2024)
3cc5ebd  Minor (WoosukKwon, Mar 27, 2024)
70f6b16  Minor (WoosukKwon, Mar 27, 2024)
1ebf12e  Remove submodule (WoosukKwon, Mar 27, 2024)
8a209ff  Use prefix-enabled attention (WoosukKwon, Mar 28, 2024)
31f741d  Disable flash-attn backend (WoosukKwon, Mar 28, 2024)
70efebb  Minor (WoosukKwon, Mar 28, 2024)
56b78e6  Remove __ldg for AMD portability (WoosukKwon, Mar 28, 2024)
f119396  Remove assert (WoosukKwon, Mar 28, 2024)
b6a1833  Add causal=True (WoosukKwon, Mar 28, 2024)
da50678  Enable when vllm_flash_attn (WoosukKwon, Mar 28, 2024)
6d5b4ec  Merge branch 'main' into flash-attention-decode (WoosukKwon, Mar 28, 2024)
37cb5a9  Add vllm-flash-attn as dependency (WoosukKwon, Mar 28, 2024)
df13824  Fix prefix attention (WoosukKwon, Mar 28, 2024)
9a02294  Fix prefix attention (WoosukKwon, Mar 28, 2024)
4553846  add test (LiuXiaoxuanPKU, Mar 28, 2024)
2669902  Merge branch 'flash-attention-decode' of github.com:skrider/vllm into… (LiuXiaoxuanPKU, Mar 28, 2024)
ff19304  fix (LiuXiaoxuanPKU, Mar 28, 2024)
dbeeb8a  fix test (skrider, Apr 2, 2024)
f489dee  Merge branch 'main' into flash-attention-decode (WoosukKwon, Apr 22, 2024)
358886f  Add vllm-flash-attn to requirements-cuda (WoosukKwon, Apr 22, 2024)
0013859  Minor (WoosukKwon, Apr 22, 2024)
45cb2d6  Fix (WoosukKwon, Apr 22, 2024)
2800f2e  yapf (WoosukKwon, Apr 22, 2024)
0fe3ec5  isort (WoosukKwon, Apr 22, 2024)
977afb6  Fix (WoosukKwon, Apr 22, 2024)
c55627a  Fix (WoosukKwon, Apr 22, 2024)
d7767ab  Fix (WoosukKwon, May 6, 2024)
f9d200b  Add test (WoosukKwon, May 6, 2024)
5b5cfae  Fix test (WoosukKwon, May 6, 2024)
738f7fc  Upgrade vllm-flash-attn (WoosukKwon, May 6, 2024)
2c0984f  Merge branch 'main' into flash-attention-decode (WoosukKwon, May 6, 2024)
b45cfb6  Minor (WoosukKwon, May 6, 2024)
60648fd  Fix (WoosukKwon, May 6, 2024)
eeb8050  Fix (WoosukKwon, May 7, 2024)
5bbd2d3  Remove flash-attn from dockerfile (WoosukKwon, May 7, 2024)
b72cd13  Add test for flash_attn_with_kv_cache (WoosukKwon, May 7, 2024)
4230040  Bump up vllm-flash-attn (WoosukKwon, May 7, 2024)
9cb226f  Merge branch 'main' into flash-attention-decode (WoosukKwon, May 7, 2024)
7cd9b73  Handle FP8 KV cache (WoosukKwon, May 7, 2024)
5370f86  Add docstring (WoosukKwon, May 7, 2024)
2f5b9b7  Fix (WoosukKwon, May 7, 2024)
4b05153  Fix (WoosukKwon, May 7, 2024)
848a1d7  Fix (WoosukKwon, May 7, 2024)
9c11e15  Merge branch 'main' into flash-attention-decode (WoosukKwon, May 8, 2024)
1f5baf4  Merge branch 'main' into flash-attention-decode (WoosukKwon, May 8, 2024)
cf86a48  Merge branch 'main' into flash-attention-decode (WoosukKwon, May 9, 2024)
d6996c1  Set block size from beginning (WoosukKwon, May 9, 2024)
6b45dfb  Remove model runner from test_sampler (WoosukKwon, May 9, 2024)
6137ad4  [Misc] Remove unnecessary ModelRunner import (WoosukKwon, May 9, 2024)
7569137  Fix test_logits_processor (WoosukKwon, May 9, 2024)
81be0af  Merge branch 'remove-model-runner' into set-block-size (WoosukKwon, May 9, 2024)
6bcf10f  Fix test_model_runner (WoosukKwon, May 9, 2024)
9c5a51d  Merge branch 'main' into set-block-size (WoosukKwon, May 9, 2024)
c5aaee2  Merge branch 'set-block-size' into flash-attention-decode (WoosukKwon, May 9, 2024)
6a70f87  Merge branch 'main' into attn-layer (WoosukKwon, May 10, 2024)
9092bb4  Enhance attention backend selector (WoosukKwon, May 10, 2024)
7a10755  Rever flash-attn (WoosukKwon, May 10, 2024)
0359113  Remove test (WoosukKwon, May 10, 2024)
21945e3  Enhance attention selector (WoosukKwon, May 10, 2024)
72d5155  Fix (WoosukKwon, May 10, 2024)
c49d015  Revert (WoosukKwon, May 10, 2024)
8a8bb1c  Fix CPU (WoosukKwon, May 10, 2024)
1ff4fbd  Fix (WoosukKwon, May 10, 2024)
49938d8  Merge branch 'attn-selector' into flash-attention-decode (WoosukKwon, May 10, 2024)
f1ebae5  Merge branch 'attn-layer' into flash-attention-decode (WoosukKwon, May 10, 2024)
950bc82  Fix (WoosukKwon, May 10, 2024)
adf545a  Fix (WoosukKwon, May 10, 2024)
250eac4  Fix CPU (WoosukKwon, May 11, 2024)
974a4f8  Fix CPU (WoosukKwon, May 11, 2024)
8a629e5  Fix (WoosukKwon, May 11, 2024)
d622b3e  Fix (WoosukKwon, May 11, 2024)
d27c139  Fix Llama (WoosukKwon, May 11, 2024)
e4fa494  Fix (WoosukKwon, May 11, 2024)
e2a4ba0  yapf (WoosukKwon, May 11, 2024)
ee71445  Update models (WoosukKwon, May 11, 2024)
974ed4d  Fix (WoosukKwon, May 11, 2024)
ec72063  Fix (WoosukKwon, May 11, 2024)
180acaa  Fix (WoosukKwon, May 11, 2024)
4a19d96  Merge branch 'main' into attn-selector (WoosukKwon, May 13, 2024)
1c2ad0a  Add comment (WoosukKwon, May 13, 2024)
8cfb402  Remove kv_cache_dtype (WoosukKwon, May 13, 2024)
7c23fe3  Merge branch 'attn-selector' into flash-attention-decode (WoosukKwon, May 13, 2024)
9a1b8f3  yapf (WoosukKwon, May 13, 2024)
304c9e3  Sliding window (WoosukKwon, May 13, 2024)
5789be5  Merge branch 'main' into flash-attention-decode (WoosukKwon, May 13, 2024)
1be2eb3  yapf (WoosukKwon, May 13, 2024)
d544611  Use fp32 in ref attn softmax (WoosukKwon, May 13, 2024)
ddd9e35  Fix broken tests (WoosukKwon, May 13, 2024)
7e0da78  Address comments (WoosukKwon, May 13, 2024)
cd22037  Fix CI (WoosukKwon, May 13, 2024)
tests/kernels/test_flash_attn.py (new file, 205 additions, 0 deletions)
from typing import List, Optional, Tuple

import pytest
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        q = query[start_idx:start_idx + query_len]
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        mask = torch.triu(torch.ones(query_len, kv_len),
                          diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(torch.ones(query_len, kv_len),
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_flash_attn_with_paged_kv(
kv_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_blocks = 128
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
).squeeze(1)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
block_size: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_blocks = 128
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window,
sliding_window) if sliding_window is not None else
(-1, -1))
scale = head_size**-0.5

query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"