Refactor the test code for attention kernels (vllm-project#13)
WoosukKwon authored Mar 30, 2023
1 parent 0b05fa5 commit 54bcc0f
1 changed file: tests/kernels/attention.py (72 changes: 53 additions, 19 deletions)
@@ -1,5 +1,5 @@
 import random
-from typing import Optional
+from typing import List, Optional
 
 from flash_attn.flash_attention import FlashAttention
 import torch
@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)
 
 
+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    head_size = query.shape[-1]
+    scale = 1.0 / (head_size ** 0.5)
+
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
 def test_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
         causal=True,
     )[0]
 
-    ref_outputs = []
-    for i, seq_len in enumerate(seq_lens):
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-
+    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    ref_output = ref_multi_query_kv_attention(
+        cu_seq_lens,
+        query,
+        key,
+        value,
+        dtype,
+    )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
 @torch.inference_mode()
-def test_attention() -> None:
+def test_attention(seed: int) -> None:
+    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
+    # the test fails due to the precision issue. Re-run the test if it fails.
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+                print(f'Testing single_query_cached_kv_attention with '
+                      f'dtype={dtype}, block_size={block_size}, '
+                      f'head_size={head_size}')
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -193,6 +225,8 @@ def test_attention() -> None:
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
        for head_size in [64, 80, 96, 128]:
+            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
+                  f'head_size={head_size}')
             test_multi_query_kv_attention(
                 num_seqs=11,
                 num_heads=3,
@@ -202,4 +236,4 @@ def test_attention() -> None:
 
 
 if __name__ == '__main__':
-    test_attention()
+    test_attention(seed=0)
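
For context on the cu_seq_lens convention that the new ref_multi_query_kv_attention helper relies on: the list holds cumulative sequence lengths, so consecutive entries give the start and end offsets of each sequence inside the packed query/key/value tensors. The standalone sketch below is not part of this commit; the sequence lengths and tensor shapes are illustrative assumptions.

import torch

# Hypothetical per-sequence lengths; any positive integers work.
seq_lens = [7, 3, 5]

# cu_seq_lens[i] is the start offset of sequence i in the packed tensors,
# and cu_seq_lens[-1] is the total number of tokens across all sequences.
cu_seq_lens = [0]
for seq_len in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + seq_len)

num_heads, head_size = 3, 64  # illustrative shapes, not taken from the diff
total_tokens = cu_seq_lens[-1]
query = torch.randn(total_tokens, num_heads, head_size)

# Slicing out sequence i, as the refactored helper does with
# query[start_idx:end_idx], key[start_idx:end_idx], value[start_idx:end_idx].
for i in range(len(seq_lens)):
    start_idx, end_idx = cu_seq_lens[i], cu_seq_lens[i + 1]
    assert query[start_idx:end_idx].shape[0] == seq_lens[i]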
