diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4bb..d754fc46a0167 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -19,12 +19,13 @@ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)\ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) - + #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ diff --git a/requirements.txt b/requirements.txt index 05ec2e804e13b..6b46f778f57fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,6 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 pynvml == 11.5.0 -triton >= 2.1.0 +triton >= 2.2.0 outlines >= 0.0.27 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d8dc74bc7b003..0c56052c39213 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -92,9 +92,17 @@ def test_copy_blocks( # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): + # NOTE: torch.allclose has not supported + # torch.fp8_e5m2/torch.fp8_e4m3fn dtypes. + if kv_cache_dtype == "fp8_e5m2": + key_cache = key_cache.view(torch.half) + cloned_key_cache = cloned_key_cache.view(torch.half) assert torch.allclose(key_cache, cloned_key_cache) for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): + if kv_cache_dtype == "fp8_e5m2": + value_cache = value_cache.view(torch.half) + cloned_value_cache = cloned_value_cache.view(torch.half) assert torch.allclose(value_cache, cloned_value_cache) diff --git a/vllm/model_executor/layers/attention/ops/prefix_prefill.py b/vllm/model_executor/layers/attention/ops/prefix_prefill.py index 70f09224f1cf6..c50c34e076a7c 100644 --- a/vllm/model_executor/layers/attention/ops/prefix_prefill.py +++ b/vllm/model_executor/layers/attention/ops/prefix_prefill.py @@ -4,699 +4,638 @@ import torch import triton import triton.language as tl - -if triton.__version__ >= "2.1.0": - - @triton.jit - def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - @triton.jit - def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - @triton.jit - def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, allow_tf32=False) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - @torch.inference_mode() - def context_attention_fwd(q, - k, - v, - o, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=None): - - cap = torch.cuda.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 8 if Lk <= 64 else 8 - if alibi_slopes is not None: - _fwd_kernel_alibi[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - b_start_loc, - b_seq_len, - b_ctx_len, - alibi_slopes, - v_cache.shape[3], - 8, - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4 - ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] - num_queries_per_kv=num_queries_per_kv, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - _fwd_kernel[grid]( +import packaging + +assert packaging.version.parse(triton.__version__) >= packaging.version.parse( + "2.2.0"), "Triton version >= 2.2.0 is required." + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0).to(q.dtype) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0).to(q.dtype) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0).to(q.dtype) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0).to(q.dtype) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0).to(q.dtype) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0).to(q.dtype) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=None): + + cap = torch.cuda.get_device_capability() + BLOCK = 128 if cap[0] >= 8 else 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 8 if Lk <= 64 else 8 + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( q, k, v, @@ -707,6 +646,7 @@ def context_attention_fwd(q, b_start_loc, b_seq_len, b_ctx_len, + alibi_slopes, v_cache.shape[3], 8, o, @@ -743,3 +683,50 @@ def context_attention_fwd(q, num_stages=1, ) return + + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + 8, + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/vllm/utils.py b/vllm/utils.py index fe6fd27962cd3..29a8151f64889 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -29,7 +29,7 @@ "half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, - "fp8_e5m2": torch.uint8, + "fp8_e5m2": torch.float8_e5m2, } @@ -267,10 +267,8 @@ def create_kv_caches_with_random( torch_dtype = model_dtype else: raise ValueError(f"Invalid model dtype: {model_dtype}") - elif cache_dtype in ["half", "bfloat16", "float"]: + elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] - elif cache_dtype == "fp8_e5m2": - torch_dtype = torch.uint8 else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") elif isinstance(cache_dtype, torch.dtype):