From 71bcaf99e2cb2c677bf3a9addb9e8039cbcab22a Mon Sep 17 00:00:00 2001 From: Tao He Date: Tue, 27 Feb 2024 17:14:31 +0800 Subject: [PATCH] Enable GQA support in the prefix prefill kernels (#3007) Signed-off-by: Tao He --- tests/kernels/test_prefix_prefill.py | 61 +++++++++++++------ vllm/model_executor/layers/attention.py | 34 ++++++----- .../layers/triton_kernel/prefix_prefill.py | 39 ++++++++---- 3 files changed, 87 insertions(+), 47 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index ac93b32588cca..c068b38a66910 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -8,7 +8,8 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -NUM_HEADS = [12] +NUM_HEADS = [64] +NUM_QUERIES_PER_KV = [1, 8, 64] HEAD_SIZES = [128] DTYPES = [torch.float16] CUDA_DEVICES = [ @@ -17,12 +18,14 @@ @pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, + num_queries_per_kv: int, head_size: int, dtype: torch.dtype, device: str, @@ -41,28 +44,29 @@ def test_contexted_kv_attention( subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv num_tokens = sum(subquery_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype) + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) kv.uniform_(-1e-3, 1e-3) key, value = kv.unbind(dim=1) k_cache = torch.zeros(cache_size, block_size, - num_heads, + num_kv_heads, head_size, dtype=dtype) v_cache = torch.zeros(cache_size, block_size, - num_heads, + num_kv_heads, head_size, dtype=dtype) - k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype) - v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype) + k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( @@ -93,19 +97,21 @@ def test_contexted_kv_attention( end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8).permute(0, 2, 3, 1, 4).contiguous() # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_heads, + v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() # Warm up the Triton kernel by calling it once before actually measuring generation time @@ -123,12 +129,29 @@ def test_contexted_kv_attention( attn_op = xops.fmha.cutlass.FwOp() + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( subquery_lens, seq_lens) output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), + query, + key, + value, attn_bias=attn_bias, p=0.0, scale=scale, @@ -137,9 +160,9 @@ def test_contexted_kv_attention( torch.cuda.synchronize() start_time = time.time() output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), + query, + key, + value, attn_bias=attn_bias, p=0.0, scale=scale, @@ -148,5 +171,5 @@ def test_contexted_kv_attention( torch.cuda.synchronize() end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - output_ref = output_ref.squeeze(0) + output_ref = output_ref.squeeze(0, 2) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 0622a54db1bc0..2a82325b80213 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -137,25 +137,27 @@ def forward( ) if input_metadata.is_prompt: - # Prompt run. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index a1a2ab0c4805c..70f09224f1cf6 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -45,6 +45,7 @@ def _fwd_kernel( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, + num_queries_per_kv: int, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -53,6 +54,8 @@ def _fwd_kernel( cur_head = tl.program_id(1) start_m = tl.program_id(2) + cur_kv_head = cur_head // num_queries_per_kv + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) @@ -85,13 +88,14 @@ def _fwd_kernel( mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) off_k = (bn[None, :] * stride_k_cache_bs + - cur_head * stride_k_cache_h + + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) off_v = ( - bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, @@ -131,9 +135,9 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -232,6 +236,7 @@ def _fwd_kernel_flash_attn_v2( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, + num_queries_per_kv: int, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -240,6 +245,8 @@ def _fwd_kernel_flash_attn_v2( cur_head = tl.program_id(1) start_m = tl.program_id(2) + cur_kv_head = cur_head // num_queries_per_kv + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) @@ -272,13 +279,14 @@ def _fwd_kernel_flash_attn_v2( mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) off_k = (bn[None, :] * stride_k_cache_bs + - cur_head * stride_k_cache_h + + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) off_v = ( - bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, @@ -317,9 +325,9 @@ def _fwd_kernel_flash_attn_v2( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -420,6 +428,7 @@ def _fwd_kernel_alibi( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, + num_queries_per_kv: int, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -429,6 +438,8 @@ def _fwd_kernel_alibi( cur_head = tl.program_id(1) start_m = tl.program_id(2) + cur_kv_head = cur_head // num_queries_per_kv + # cur_batch_seq_len: the length of prompts # cur_batch_ctx_len: the length of prefix # cur_batch_in_all_start_index: the start id of the dim=0 @@ -468,13 +479,14 @@ def _fwd_kernel_alibi( mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) off_k = (bn[None, :] * stride_k_cache_bs + - cur_head * stride_k_cache_h + + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) off_v = ( - bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, @@ -522,9 +534,9 @@ def _fwd_kernel_alibi( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -628,6 +640,7 @@ def context_attention_fwd(q, sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, @@ -674,6 +687,7 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -721,6 +735,7 @@ def context_attention_fwd(q, v_cache.stride(2), v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK,