forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Kernel] Use flash-attn for decoding (vllm-project#3648)
Co-authored-by: Woosuk Kwon <[email protected]> Co-authored-by: LiuXiaoxuanPKU <[email protected]>
- Loading branch information
1 parent
ce532ff
commit 1356df5
Showing
6 changed files
with
313 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
from typing import List, Optional, Tuple | ||
|
||
import pytest | ||
import torch | ||
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache | ||
|
||
NUM_HEADS = [(16, 16), (32, 8), (64, 8)] | ||
HEAD_SIZES = [128, 256] | ||
BLOCK_SIZES = [16, 32] | ||
DTYPES = [torch.float16, torch.bfloat16] | ||
|
||
|
||
def ref_paged_attn( | ||
query: torch.Tensor, | ||
key_cache: torch.Tensor, | ||
value_cache: torch.Tensor, | ||
query_lens: List[int], | ||
kv_lens: List[int], | ||
block_tables: torch.Tensor, | ||
scale: float, | ||
sliding_window: Optional[int] = None, | ||
) -> torch.Tensor: | ||
num_seqs = len(query_lens) | ||
block_tables = block_tables.cpu().numpy() | ||
_, block_size, num_kv_heads, head_size = key_cache.shape | ||
|
||
outputs = [] | ||
start_idx = 0 | ||
for i in range(num_seqs): | ||
query_len = query_lens[i] | ||
kv_len = kv_lens[i] | ||
q = query[start_idx:start_idx + query_len] | ||
q *= scale | ||
|
||
num_kv_blocks = (kv_len + block_size - 1) // block_size | ||
block_indices = block_tables[i, :num_kv_blocks] | ||
|
||
k = key_cache[block_indices].view(-1, num_kv_heads, head_size) | ||
k = k[:kv_len] | ||
v = value_cache[block_indices].view(-1, num_kv_heads, head_size) | ||
v = v[:kv_len] | ||
|
||
if q.shape[1] != k.shape[1]: | ||
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) | ||
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) | ||
attn = torch.einsum("qhd,khd->hqk", q, k).float() | ||
empty_mask = torch.ones(query_len, kv_len) | ||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() | ||
if sliding_window is not None: | ||
sliding_window_mask = torch.triu(empty_mask, | ||
diagonal=kv_len - | ||
(query_len + sliding_window) + | ||
1).bool().logical_not() | ||
mask |= sliding_window_mask | ||
attn.masked_fill_(mask, float("-inf")) | ||
attn = torch.softmax(attn, dim=-1).to(v.dtype) | ||
out = torch.einsum("hqk,khd->qhd", attn, v) | ||
|
||
outputs.append(out) | ||
start_idx += query_len | ||
|
||
return torch.cat(outputs, dim=0) | ||
|
||
|
||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) | ||
@pytest.mark.parametrize("num_heads", NUM_HEADS) | ||
@pytest.mark.parametrize("head_size", HEAD_SIZES) | ||
@pytest.mark.parametrize("block_size", BLOCK_SIZES) | ||
@pytest.mark.parametrize("dtype", DTYPES) | ||
@torch.inference_mode | ||
def test_flash_attn_with_paged_kv( | ||
kv_lens: List[Tuple[int, int]], | ||
num_heads: Tuple[int, int], | ||
head_size: int, | ||
dtype: torch.dtype, | ||
block_size: int, | ||
) -> None: | ||
torch.set_default_device("cuda") | ||
torch.cuda.manual_seed_all(0) | ||
num_blocks = 128 | ||
num_seqs = len(kv_lens) | ||
num_query_heads = num_heads[0] | ||
num_kv_heads = num_heads[1] | ||
assert num_query_heads % num_kv_heads == 0 | ||
max_kv_len = max(kv_lens) | ||
scale = head_size**-0.5 | ||
|
||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) | ||
key_cache = torch.randn(num_blocks, | ||
block_size, | ||
num_kv_heads, | ||
head_size, | ||
dtype=dtype) | ||
value_cache = torch.randn_like(key_cache) | ||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) | ||
|
||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size | ||
block_tables = torch.randint(0, | ||
num_blocks, | ||
(num_seqs, max_num_blocks_per_seq), | ||
dtype=torch.int32) | ||
|
||
output = flash_attn_with_kvcache( | ||
q=query.unsqueeze(1), | ||
k_cache=key_cache, | ||
v_cache=value_cache, | ||
softmax_scale=scale, | ||
causal=True, | ||
block_table=block_tables, | ||
cache_seqlens=kv_lens_tensor, | ||
).squeeze(1) | ||
|
||
ref_output = ref_paged_attn( | ||
query=query, | ||
key_cache=key_cache, | ||
value_cache=value_cache, | ||
query_lens=[1] * num_seqs, | ||
kv_lens=kv_lens, | ||
block_tables=block_tables, | ||
scale=scale, | ||
) | ||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ | ||
f"{torch.max(torch.abs(output - ref_output))}" | ||
|
||
|
||
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) | ||
@pytest.mark.parametrize("num_heads", NUM_HEADS) | ||
@pytest.mark.parametrize("head_size", HEAD_SIZES) | ||
@pytest.mark.parametrize("block_size", BLOCK_SIZES) | ||
@pytest.mark.parametrize("sliding_window", [None]) | ||
@pytest.mark.parametrize("dtype", DTYPES) | ||
@torch.inference_mode | ||
def test_varlen_with_paged_kv( | ||
seq_lens: List[Tuple[int, int]], | ||
num_heads: Tuple[int, int], | ||
head_size: int, | ||
sliding_window: Optional[int], | ||
dtype: torch.dtype, | ||
block_size: int, | ||
) -> None: | ||
torch.set_default_device("cuda") | ||
torch.cuda.manual_seed_all(0) | ||
num_blocks = 128 | ||
num_seqs = len(seq_lens) | ||
query_lens = [x[0] for x in seq_lens] | ||
kv_lens = [x[1] for x in seq_lens] | ||
num_query_heads = num_heads[0] | ||
num_kv_heads = num_heads[1] | ||
assert num_query_heads % num_kv_heads == 0 | ||
max_query_len = max(query_lens) | ||
max_kv_len = max(kv_lens) | ||
window_size = ((sliding_window, | ||
sliding_window) if sliding_window is not None else | ||
(-1, -1)) | ||
scale = head_size**-0.5 | ||
|
||
query = torch.randn(sum(query_lens), | ||
num_query_heads, | ||
head_size, | ||
dtype=dtype) | ||
key_cache = torch.randn(num_blocks, | ||
block_size, | ||
num_kv_heads, | ||
head_size, | ||
dtype=dtype) | ||
value_cache = torch.randn_like(key_cache) | ||
# Normalize the scale of the key and value caches to mitigate | ||
# numerical instability. | ||
key_cache /= head_size**0.5 | ||
value_cache /= head_size**0.5 | ||
cu_query_lens = torch.tensor([0] + query_lens, | ||
dtype=torch.int32).cumsum(dim=0, | ||
dtype=torch.int32) | ||
cu_kv_lens = torch.tensor([0] + kv_lens, | ||
dtype=torch.int32).cumsum(dim=0, | ||
dtype=torch.int32) | ||
|
||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size | ||
block_tables = torch.randint(0, | ||
num_blocks, | ||
(num_seqs, max_num_blocks_per_seq), | ||
dtype=torch.int32) | ||
|
||
output = flash_attn_varlen_func( | ||
q=query, | ||
k=key_cache, | ||
v=value_cache, | ||
cu_seqlens_q=cu_query_lens, | ||
cu_seqlens_k=cu_kv_lens, | ||
max_seqlen_q=max_query_len, | ||
max_seqlen_k=max_kv_len, | ||
softmax_scale=scale, | ||
causal=True, | ||
window_size=window_size, | ||
block_table=block_tables, | ||
) | ||
|
||
ref_output = ref_paged_attn( | ||
query=query, | ||
key_cache=key_cache, | ||
value_cache=value_cache, | ||
query_lens=query_lens, | ||
kv_lens=kv_lens, | ||
block_tables=block_tables, | ||
scale=scale, | ||
sliding_window=sliding_window, | ||
) | ||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ | ||
f"{torch.max(torch.abs(output - ref_output))}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.