From e8cc7967ff8a6f8432747a9e87ab451d36e1ff57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Moskal?=
Date: Thu, 18 Apr 2024 00:51:28 -0700
Subject: [PATCH] [Bugfix][Kernel] allow non-power-of-two head sizes in prefix
 prefill (#4128)

---
 tests/kernels/test_prefix_prefill.py |  2 +-
 vllm/attention/ops/prefix_prefill.py | 44 +++++++++++++++++-----------
 2 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index 6494fb34af98f..ad31b0a7c2a19 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -10,7 +10,7 @@
 
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128]
+HEAD_SIZES = [128, 96]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 70f09224f1cf6..4896cf3909c6e 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -47,7 +47,8 @@ def _fwd_kernel(
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,  # head size
+        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
     ):
         cur_batch = tl.program_id(0)
@@ -59,26 +60,30 @@ def _fwd_kernel(
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
 
         block_start_loc = BLOCK_M * start_m
 
         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
-        q = tl.load(
-            Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+        dim_mask = tl.where(
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+        q = tl.load(Q + off_q,
+                    mask=dim_mask[None, :] &
+                    (offs_m[:, None] < cur_batch_query_len),
+                    other=0.0)
 
         # # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
 
         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -99,7 +104,8 @@ def _fwd_kernel(
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -126,7 +132,8 @@ def _fwd_kernel(
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -142,16 +149,15 @@ def _fwd_kernel(
         k_ptrs = K + off_k
         v_ptrs = V + off_v
 
-        block_mask = tl.where(
-            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -179,8 +185,8 @@ def _fwd_kernel(
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -195,7 +201,8 @@ def _fwd_kernel(
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+                 mask=dim_mask[None, :] &
+                 (offs_m[:, None] < cur_batch_query_len))
         return
 
     @triton.jit
@@ -636,7 +643,8 @@ def context_attention_fwd(q,
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
-        assert Lk in {16, 32, 64, 128}
+        # round up Lk to a power of 2 - this is required for Triton block size
+        Lk_padded = 2**((Lk - 1).bit_length())
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
@@ -646,6 +654,7 @@ def context_attention_fwd(q,
 
         num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
+            assert Lk == Lk_padded
             _fwd_kernel_alibi[grid](
                 q,
                 k,
@@ -738,6 +747,7 @@ def context_attention_fwd(q,
             num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
+            BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
             num_warps=num_warps,
             num_stages=1,
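
Note (reviewer sketch, not part of the patch): the host-side change replaces the
hard "assert Lk in {16, 32, 64, 128}" with rounding the head size up to the next
power of two, since Triton block dimensions must be powers of two. A minimal
standalone check of that rounding in plain Python, mirroring the Lk_padded
expression above (next_power_of_two is a hypothetical helper name):

    def next_power_of_two(n: int) -> int:
        # same formula as Lk_padded in context_attention_fwd
        return 2**((n - 1).bit_length())

    assert next_power_of_two(96) == 128   # the head size newly added to HEAD_SIZES
    assert next_power_of_two(80) == 128
    assert next_power_of_two(64) == 64    # power-of-two sizes are unchanged
    assert next_power_of_two(128) == 128

The alibi path stays restricted: the new "assert Lk == Lk_padded" means
_fwd_kernel_alibi still only accepts power-of-two head sizes.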
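
Note (reviewer sketch, not part of the patch): inside the kernel, dim_mask marks
which lanes of the padded head dimension are real; padded lanes are loaded as
0.0 and masked out again on the final tl.store, so they never reach the output.
A rough PyTorch analogue of that load behaviour, using a hypothetical helper
name (load_head_padded is not part of vLLM):

    import torch

    def load_head_padded(x: torch.Tensor, padded_size: int) -> torch.Tensor:
        # x: [num_tokens, head_size] with head_size <= padded_size.
        # Mirrors tl.load(..., mask=dim_mask[None, :] & ..., other=0.0):
        # real lanes are copied, padded lanes are zero-filled.
        head_size = x.shape[-1]
        dim_mask = torch.arange(padded_size) < head_size
        out = torch.zeros(*x.shape[:-1], padded_size, dtype=x.dtype)
        out[..., dim_mask] = x
        return out

    q = torch.randn(4, 96, dtype=torch.float16)
    assert load_head_padded(q, 128).shape == (4, 128)

Zero-filling is safe here because the padded columns contribute nothing to the
q @ k dot products, and the store mask drops them from the output tensor.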