diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py
new file mode 100644
index 0000000000000..97516bd3052cf
--- /dev/null
+++ b/tests/kernels/conftest.py
@@ -0,0 +1,43 @@
+from typing import List, Tuple
+
+import pytest
+import torch
+
+
+def create_kv_caches(
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = head_size**-0.5
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.empty(size=key_cache_shape,
+                                dtype=dtype,
+                                device='cuda')
+        key_cache.uniform_(-scale, scale)
+        key_caches.append(key_cache)
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.empty(size=value_cache_shape,
+                                  dtype=dtype,
+                                  device='cuda')
+        value_cache.uniform_(-scale, scale)
+        value_caches.append(value_cache)
+    return key_caches, value_caches
+
+
+@pytest.fixture()
+def kv_cache_factory():
+    return create_kv_caches
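For orientation, a minimal sketch (hypothetical test name, assumes a CUDA device since the caches are allocated with device='cuda') of how a test consumes the fixture above: pytest injects kv_cache_factory, and the test calls it like a plain function.

import torch

def test_kv_cache_factory_shapes(kv_cache_factory):
    # Build a single-layer cache pair; for torch.half, element_size() == 2,
    # so x = 16 // 2 = 8 and the key cache shape is
    # (num_blocks, num_heads, head_size // x, block_size, x).
    key_caches, value_caches = kv_cache_factory(1024, 16, 1, 8, 128,
                                                torch.half, 0)
    assert key_caches[0].shape == (1024, 8, 16, 16, 8)
    assert value_caches[0].shape == (1024, 8, 128, 16)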
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index b4ddd3e5588aa..8aa35d2b2340f 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -1,20 +1,34 @@
+import pytest
 import torch
 import torch.nn.functional as F
 from transformers.activations import get_activation
+
 from vllm import activation_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
+SEEDS = [0]
+
 
 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_silu_and_mul(
+def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.silu_and_mul(out, x)
@@ -22,20 +36,19 @@ def run_silu_and_mul(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-def test_silu_and_mul() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_silu_and_mul(num_tokens, d, dtype)
-
-
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_gelu_new(
+def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_new(out, x)
@@ -43,30 +56,20 @@ def run_gelu_new(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-def test_gelu_new() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_new(num_tokens, d, dtype)
-
-
-@torch.inference_mode()
-def run_gelu_fast(
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
-
-
-def test_gelu_fast() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_fast(num_tokens, d, dtype)
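The silu_and_mul kernel exercised above fuses SwiGLU-style gating: the input packs both halves along the last dimension, so x is (num_tokens, 2 * d) and the output is (num_tokens, d). A reference-only sketch in plain PyTorch (toy sizes, no custom kernel), mirroring ref_silu_and_mul:

import torch
import torch.nn.functional as F

x = torch.randn(4, 2 * 8)          # (num_tokens, 2 * d)
x1, x2 = x.chunk(chunks=2, dim=1)  # two (4, 8) halves
out = F.silu(x1) * x2              # gate the second half with SiLU of the first
assert out.shape == (4, 8)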
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 452ac4c61853e..6a694f753f259 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -1,14 +1,24 @@
 import random
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
+import pytest
 import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import attention_ops
 
-MAX_SEQ_LEN = 4096
-TEST_SEED = 0
+MAX_SEQ_LEN = 8192
+NUM_BLOCKS = 128  # Arbitrary values for testing
+
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_GEN_SEQS = [7]  # Arbitrary values for testing
+NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
+NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+BLOCK_SIZES = [8, 16, 32]
+USE_ALIBI = [False]  # TODO(woosuk): Add USE_ALIBI=True
+SEEDS = [0]
 
 
 def ref_masked_attention(
@@ -18,29 +28,34 @@
     scale: float,
     attn_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    query = query * scale
-    attn = torch.einsum('qhd,khd->hqk', query, key)
+    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     if attn_mask is not None:
-        attn = attn + attn_mask
-    attn = torch.softmax(attn, dim=-1)
-    out = torch.einsum('hqk,khd->qhd', attn, value)
+        attn_weights = attn_weights + attn_mask.float()
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
     return out
 
 
 def ref_single_query_cached_kv_attention(
     output: torch.Tensor,
     query: torch.Tensor,
+    num_queries_per_kv: int,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     block_tables: torch.Tensor,
     context_lens: torch.Tensor,
+    scale: float,
+    alibi_slopes: Optional[torch.Tensor],
 ) -> None:
-    num_heads = value_cache.shape[1]
+    num_query_heads = query.shape[1]
+    num_kv_heads = value_cache.shape[1]
     head_size = value_cache.shape[2]
     block_size = value_cache.shape[3]
+    num_seqs = query.shape[0]
 
-    num_input_tokens = query.shape[0]
-    for i in range(num_input_tokens):
+    block_tables = block_tables.cpu().tolist()
+    context_lens = context_lens.cpu().tolist()
+    for i in range(num_seqs):
         q = query[i].unsqueeze(0)
         block_table = block_tables[i]
         context_len = int(context_lens[i])
@@ -52,170 +67,96 @@ def ref_single_query_cached_kv_attention(
             block_offset = j % block_size
 
             k = key_cache[block_number, :, :, block_offset, :]
-            k = k.reshape(num_heads, head_size)
+            k = k.reshape(num_kv_heads, head_size)
             keys.append(k)
 
             v = value_cache[block_number, :, :, block_offset]
             values.append(v)
         keys = torch.stack(keys, dim=0)
         values = torch.stack(values, dim=0)
-
-        scale = 1.0 / (head_size**0.5)
-        out = ref_masked_attention(q, keys, values, scale)
-        out = out.view(num_heads, head_size)
+        if num_queries_per_kv > 1:
+            # Handle MQA and GQA
+            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
+            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
+
+        alibi_bias = None
+        if alibi_slopes is not None:
+            # Create the ALiBi bias used in the paged attention kernel.
+            position_ids = torch.arange(context_len, device="cuda").int()
+            alibi_bias = (context_len - position_ids).float()
+            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
+                1, 1, -1)
+
+        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
+        out = out.view(num_query_heads, head_size)
         output[i].copy_(out, non_blocking=True)
 
 
-def ref_multi_query_kv_attention(
-    cu_seq_lens: List[int],
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    head_size = query.shape[-1]
-    scale = 1.0 / (head_size**0.5)
-
-    num_seqs = len(cu_seq_lens) - 1
-    ref_outputs = []
-    for i in range(num_seqs):
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        seq_len = end_idx - start_idx
-
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-    return ref_output
-
-
-def ref_multi_query_cached_kv_attention(
-    cu_query_lens: List[int],
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    block_tables: torch.Tensor,
-    context_lens: torch.Tensor,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    num_heads = value_cache.shape[1]
-    head_size = value_cache.shape[2]
-    block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size**0.5)
-
-    num_queries = len(cu_query_lens) - 1
-    ref_outputs = []
-    for i in range(num_queries):
-        start_idx = cu_query_lens[i]
-        end_idx = cu_query_lens[i + 1]
-        query_len = end_idx - start_idx
-        context_len = int(context_lens[i])
-        block_table = block_tables[i]
-
-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(query_len, context_len),
-                               diagonal=context_len - query_len + 1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-
-        keys = []
-        values = []
-        for j in range(context_len):
-            block_number = int(block_table[j // block_size])
-            block_offset = j % block_size
-
-            k = key_cache[block_number, :, :, block_offset, :]
-            k = k.reshape(num_heads, head_size)
-            keys.append(k)
-
-            v = value_cache[block_number, :, :, block_offset]
-            values.append(v)
-        keys = torch.stack(keys, dim=0)
-        values = torch.stack(values, dim=0)
-
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            keys,
-            values,
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-    return ref_output
-
-
+@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("use_alibi", USE_ALIBI)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_single_query_cached_kv_attention(
-    num_tokens: int,
-    num_heads: int,
+def test_single_query_cached_kv_attention(
+    kv_cache_factory,
+    num_seqs: int,
+    num_heads: Tuple[int, int],
     head_size: int,
+    use_alibi: bool,
     block_size: int,
-    num_blocks: int,
     dtype: torch.dtype,
-    num_kv_heads: int = None,
+    seed: int,
 ) -> None:
-    qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, _, _ = qkv.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
-                            dtype=dtype,
-                            device='cuda')
-    key_cache.uniform_(-1e-3, 1e-3)
-    value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
-                              dtype=dtype,
-                              device='cuda')
-    value_cache.uniform_(-1e-3, 1e-3)
-
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
+    query = torch.empty(num_seqs,
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype,
+                        device="cuda")
+    query.uniform_(-scale, scale)
+
+    assert num_query_heads % num_kv_heads == 0
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads,
+                                   dtype=torch.float,
+                                   device="cuda")
+
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
 
+    # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
     block_tables = []
-    for _ in range(num_tokens):
+    for _ in range(num_seqs):
         block_table = [
-            random.randint(0, num_blocks - 1)
+            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
         ]
         block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
-    head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
 
-    scale = float(1.0 / (head_size**0.5))
-
-    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-    assert num_heads % num_kv_heads == 0
-    num_queries_per_kv = num_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
+                                                num_kv_heads, head_size, dtype,
+                                                seed)
+    key_cache, value_cache = key_caches[0], value_caches[0]
 
-    output = torch.empty(num_tokens,
-                         num_heads,
-                         head_size,
-                         dtype=dtype,
-                         device='cuda')
+    # Call the paged attention kernel.
+    output = torch.empty_like(query)
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -227,45 +168,98 @@ def run_single_query_cached_kv_attention(
         context_lens,
         block_size,
         max_context_len,
-        None,  # ALiBi slopes.
+        alibi_slopes,
     )
 
+    # Run the reference implementation.
     ref_output = torch.empty_like(query)
     ref_single_query_cached_kv_attention(
         ref_output,
         query,
+        num_queries_per_kv,
        key_cache,
         value_cache,
         block_tables,
         context_lens,
+        scale,
+        alibi_slopes,
     )
-    # NOTE(woosuk): Due to the difference in the data types the two
-    # implementations use for attention softmax logits and accumulation,
-    # there is a small difference in the final outputs.
-    # We should use a relaxed tolerance for the test.
+
+    # NOTE(woosuk): Due to the kernel-level differences in the two
+    # implementations, there is a small numerical difference in the two
+    # outputs. Thus, we use a relaxed tolerance for the test.
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: float,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask.
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(dtype).min
+        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_multi_query_kv_attention(
+def test_multi_query_kv_attention(
     num_seqs: int,
-    num_heads: int,
+    num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
     num_tokens = sum(seq_lens)
     scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
     qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
+                      num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, key, value = qkv.unbind(dim=1)
-
+                      device="cuda")
+    qkv.uniform_(-scale, scale)
+    query, key, value = qkv.split(
+        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
+
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    if num_queries_per_kv > 1:
+        # Handle MQA and GQA
+        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
+        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
@@ -285,40 +279,7 @@
         query,
         key,
         value,
+        scale,
         dtype,
     )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
-
-
-def test_single_query_cached_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for block_size in [8, 16, 32]:
-            for head_size in [64, 80, 96, 112, 128, 256]:
-                print(f'Testing single_query_cached_kv_attention with '
-                      f'dtype={dtype}, block_size={block_size}, '
-                      f'head_size={head_size}')
-                run_single_query_cached_kv_attention(
-                    num_tokens=37,
-                    num_heads=3,
-                    head_size=head_size,
-                    block_size=block_size,
-                    num_blocks=1024,
-                    dtype=dtype,
-                )
-
-
-def test_multi_query_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [64, 80, 96, 112, 128, 256]:
-            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
-                  f'head_size={head_size}')
-            run_multi_query_kv_attention(
-                num_seqs=5,
-                num_heads=3,
-                head_size=head_size,
-                dtype=dtype,
-            )
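The (64, 8) entry in NUM_HEADS exercises grouped-query attention, where several query heads share one KV head. A standalone sketch (toy sizes, CPU tensors) of the two expansions used above: head_mapping for the kernel path and repeat_interleave on the KV tensors for the reference path.

import torch

num_query_heads, num_kv_heads = 8, 2
num_queries_per_kv = num_query_heads // num_kv_heads  # 4

# Kernel side: map each query head to its KV head -> [0, 0, 0, 0, 1, 1, 1, 1].
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
assert head_mapping.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]

# Reference side: expand the KV tensors so every query head has its own copy.
keys = torch.randn(5, num_kv_heads, 64)  # (context_len, num_kv_heads, head_size)
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
assert keys.shape == (5, num_query_heads, 64)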
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 6b309ce529ecc..cca037df235dc 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -1,12 +1,32 @@
 import random
 
+import pytest
 import torch
 
 from vllm import cache_ops
 
-
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_HEADS = [8]  # Arbitrary values for testing
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+BLOCK_SIZES = [8, 16, 32]
+NUM_BLOCKS = [1024]  # Arbitrary values for testing
+NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+SEEDS = [0]
+
+
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_layers", NUM_LAYERS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_copy_blocks(
+def test_copy_blocks(
+    kv_cache_factory,
     num_mappings: int,
     num_layers: int,
     num_heads: int,
@@ -14,48 +34,43 @@
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    # Generate random block mappings.
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Generate random block mappings where each source block is mapped to two
+    # destination blocks.
+    assert 2 * num_mappings <= num_blocks
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
-    dst_blocks = random.sample(remainig_blocks, num_mappings)
-    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
-
-    # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.randn(size=key_cache_shape,
-                                dtype=dtype,
-                                device='cuda')
-        key_caches.append(key_cache)
-    cloned_key_caches = []
-    for key_cache in key_caches:
-        cloned_key_caches.append(key_cache.clone())
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.randn(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device='cuda')
-        value_caches.append(value_cache)
-    cloned_value_caches = []
-    for value_cache in value_caches:
-        cloned_value_caches.append(value_cache.clone())
+    dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
+    block_mapping = {}
+    for i in range(num_mappings):
+        src = src_blocks[i]
+        dst1 = dst_blocks[2 * i]
+        dst2 = dst_blocks[2 * i + 1]
+        block_mapping[src] = [dst1, dst2]
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
+                                                num_layers, num_heads,
+                                                head_size, dtype, seed)
+
+    # Clone the KV caches.
+    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
+    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
 
     # Call the copy blocks kernel.
     cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
 
-    # Reference implementation.
+    # Run the reference implementation.
     for src, dsts in block_mapping.items():
         for dst in dsts:
-            for key_cache, cloned_key_cache in zip(key_caches,
-                                                   cloned_key_caches):
+            for cloned_key_cache in cloned_key_caches:
                 cloned_key_cache[dst] = cloned_key_cache[src]
-            for value_cache, cloned_value_cache in zip(value_caches,
-                                                       cloned_value_caches):
+            for cloned_value_cache in cloned_value_caches:
                 cloned_value_cache[dst] = cloned_value_cache[src]
 
     # Compare the results.
@@ -66,15 +81,29 @@
         assert torch.allclose(value_cache, cloned_value_cache)
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_reshape_and_cache(
+def test_reshape_and_cache(
+    kv_cache_factory,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
@@ -87,110 +116,31 @@
                       device='cuda')
     _, key, value = qkv.unbind(dim=1)
 
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
-    cloned_key_cache = key_cache.clone()
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
+                                                num_heads, head_size, dtype,
+                                                seed)
+    key_cache, value_cache = key_caches[0], value_caches[0]
 
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
+    # Clone the KV caches.
+    cloned_key_cache = key_cache.clone()
     cloned_value_cache = value_cache.clone()
 
+    # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                 slot_mapping)
 
+    # Run the reference implementation.
+    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = block_indicies.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets = block_offsets.cpu().tolist()
     for i in range(num_tokens):
-        reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
+        block_idx = block_indicies[i]
+        block_offset = block_offsets[i]
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
 
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
-
-
-@torch.inference_mode()
-def run_gather_cached_kv(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-) -> None:
-    num_slots = block_size * num_blocks
-    slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
-
-    qkv = torch.randn(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    _, key, value = qkv.unbind(dim=1)
-
-    qkv_clone = qkv.clone()
-    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
-
-    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
-                               slot_mapping)
-
-    # Reference implementation.
-    for i in range(num_tokens):
-        reshaped_key = cloned_key.reshape(num_tokens, num_heads,
-                                          head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
-        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
-        cloned_value[i] = value_cache[block_idx, :, :, block_offset]
-
-    assert torch.allclose(key, cloned_key)
-    assert torch.allclose(value, cloned_value)
-
-
-def test_copy_blocks() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_copy_blocks(num_mappings=23,
-                        num_layers=7,
-                        num_heads=17,
-                        head_size=16,
-                        block_size=8,
-                        num_blocks=1024,
-                        dtype=dtype)
-
-
-def test_reshape_and_cache() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_reshape_and_cache(num_tokens=3,
-                              num_heads=2,
-                              head_size=16,
-                              block_size=8,
-                              num_blocks=2,
-                              dtype=dtype)
-
-
-def test_gather_cached_kv() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_gather_cached_kv(num_tokens=3,
-                             num_heads=2,
-                             head_size=16,
-                             block_size=8,
-                             num_blocks=2,
-                             dtype=dtype)
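reshape_and_cache addresses the caches through a flat slot index; test_reshape_and_cache above decomposes each slot into a block index and an in-block offset. A small worked example (toy values) of that decomposition:

import torch

block_size = 16
slot_mapping = torch.tensor([3, 16, 35], dtype=torch.int)
block_indices = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_offsets = slot_mapping % block_size
# slot 3  -> block 0, offset 3
# slot 16 -> block 1, offset 0
# slot 35 -> block 2, offset 3
assert block_indices.tolist() == [0, 1, 2]
assert block_offsets.tolist() == [3, 0, 3]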
diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py
index b1309669e82e5..a63ef5cc76ffd 100644
--- a/tests/kernels/test_layernorm.py
+++ b/tests/kernels/test_layernorm.py
@@ -1,35 +1,50 @@
+import pytest
 import torch
 import torch.nn as nn
 
 from vllm import layernorm_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+SEEDS = [0]
+
 
 class RefRMSNorm(nn.Module):
 
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         weight = torch.empty(hidden_size)
-        weight.uniform_(-1e-3, 1e-3)
+        weight.normal_(mean=1.0, std=0.1)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
-                                                               keepdim=True)
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance +
                                                     self.variance_epsilon)
-        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
+        return self.weight * hidden_states.to(input_dtype)
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rms_norm(
+def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(hidden_size**-0.5)
+    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x.uniform_(-scale, scale)
     ref = RefRMSNorm(hidden_size).to(dtype).cuda()
 
     out = torch.empty_like(x)
@@ -40,17 +55,4 @@
         ref.variance_epsilon,
     )
     ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
-
-
-def test_rms_norm() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 128, 2048]:
-            for hidden_size in [13, 64, 1024, 5120]:
-                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
-                      f'{num_tokens}, hidden_size={hidden_size}')
-                run_rms_norm(
-                    num_tokens=num_tokens,
-                    hidden_size=hidden_size,
-                    dtype=dtype,
-                )
+    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 99385baa2d623..d830b268d8bbc 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -1,11 +1,19 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from vllm import pos_encoding_ops
 
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
+NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+SEEDS = [0]
+
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
@@ -74,16 +82,28 @@ def forward(
         return query, key
 
 
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rotary_embedding_neox(
+def test_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
-    max_position: int,
-    rotary_dim: int,
+    rotary_dim: Optional[int],
     dtype: torch.dtype,
+    seed: int,
+    max_position: int = 8192,
     base: int = 10000,
 ) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
     positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
     query = torch.randn(num_tokens,
                         num_heads * head_size,
@@ -97,7 +117,7 @@
     # Create the rotary embedding.
     inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
@@ -129,19 +149,5 @@
     ref_key = ref_key.view(num_tokens, num_heads * head_size)
 
     # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
-
-
-def test_rotary_embedding_neox() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
-            print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            run_rotary_embedding_neox(
-                num_tokens=2145,
-                num_heads=5,
-                head_size=head_size,
-                max_position=8192,
-                rotary_dim=head_size,
-                dtype=dtype,
-            )
+    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
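A standalone sketch (toy sizes) of the cos/sin cache construction mirrored from the rotary-embedding test above; it just restates the inv_freq / freqs / cos_sin_cache lines with small constants so the resulting shape is easy to see.

import torch

base, rotary_dim, max_position = 10000, 8, 4
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())  # (max_position, rotary_dim // 2)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
# One row per position: cos terms first, then the matching sin terms.
assert cos_sin_cache.shape == (max_position, rotary_dim)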