Skip to content

Commit

Permalink
[Kernel] Use flash-attn for decoding (vllm-project#3648)
Browse files Browse the repository at this point in the history
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: LiuXiaoxuanPKU <[email protected]>
  • Loading branch information
3 people authored May 13, 2024
1 parent ce532ff commit 1356df5
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 65 deletions.
209 changes: 209 additions & 0 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from typing import List, Optional, Tuple

import pytest
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]


def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape

outputs = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q *= scale

num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]

k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len]

if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)

outputs.append(out)
start_idx += query_len

return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_flash_attn_with_paged_kv(
kv_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_blocks = 128
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
).squeeze(1)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
block_size: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_blocks = 128
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window,
sliding_window) if sliding_window is not None else
(-1, -1))
scale = head_size**-0.5

query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache /= head_size**0.5
value_cache /= head_size**0.5
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
2 changes: 1 addition & 1 deletion tests/models/test_big_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# "Deci/DeciLM-7b", # Broken
# "tiiuae/falcon-7b", # Broken
"EleutherAI/gpt-j-6b",
"mosaicml/mpt-7b",
# "mosaicml/mpt-7b", # Broken
# "Qwen/Qwen1.5-0.5B" # Broken,
]

Expand Down
10 changes: 5 additions & 5 deletions tests/models/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**'
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
],
"meta-llama/Meta-Llama-3-8B-Instruct": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
Expand Down
Loading

0 comments on commit 1356df5

Please sign in to comment.