From 441aebb2ba8646b363d4e173eb534f072b31ce94 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 21 Mar 2024 15:33:46 +0800 Subject: [PATCH] [CPU] Add PagedAttention support (#23524) ### Details: - *Support PagedAttention support, depends on:* - openvino_contrib: https://github.com/openvinotoolkit/openvino_contrib/pull/867 - vLLM: https://github.com/ilya-lavrenov/vllm/pull/4 - *TODO* - Models with alibi feature ### Tickets: - *[134329](https://jira.devtools.intel.com/browse/CVS-134329)* - *[134327](https://jira.devtools.intel.com/browse/CVS-134327)* --- .../src/transformations/convert_precision.cpp | 6 +- src/plugins/intel_cpu/CMakeLists.txt | 2 +- src/plugins/intel_cpu/src/cpu_types.cpp | 1 + src/plugins/intel_cpu/src/graph.cpp | 4 + .../nodes/kernels/scaled_attn/attn_memcpy.cpp | 54 ++ .../nodes/kernels/scaled_attn/attn_memcpy.hpp | 6 + .../kernels/scaled_attn/mha_single_token.cpp | 170 ++++--- .../kernels/scaled_attn/mha_single_token.hpp | 1 + .../src/nodes/kernels/x64/brgemm_kernel.cpp | 4 +- .../intel_cpu/src/nodes/scaled_attn.cpp | 476 +++++++++++++----- src/plugins/intel_cpu/src/nodes/scaled_attn.h | 18 + 11 files changed, 544 insertions(+), 198 deletions(-) diff --git a/src/common/transformations/src/transformations/convert_precision.cpp b/src/common/transformations/src/transformations/convert_precision.cpp index d5c8204663c242..b29ab74981a483 100644 --- a/src/common/transformations/src/transformations/convert_precision.cpp +++ b/src/common/transformations/src/transformations/convert_precision.cpp @@ -607,7 +607,11 @@ bool fuse_type_to_parameter(const std::shared_ptr& node, auto convert = std::make_shared(param, to); for (auto& input : param_consumers) { const auto consumer = input.get_node(); - if (ov::is_type(consumer) || ov::is_type(consumer)) { + if (ov::is_type(consumer) || ov::is_type(consumer) || + // TODO: refactor after ngraph op defined + // The fourth and fifth inputs are kvcache and should be directly connected to parameters + (consumer->get_type_name() == std::string("PagedAttentionExtension") && + (input.get_index() == 3 || input.get_index() == 4))) { continue; } input.replace_source_output(convert); diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt index 8efd078836275c..70da87819f03e5 100644 --- a/src/plugins/intel_cpu/CMakeLists.txt +++ b/src/plugins/intel_cpu/CMakeLists.txt @@ -176,7 +176,7 @@ cross_compiled_file(${TARGET_NAME} ARCH AVX512F AVX2 ANY src/nodes/kernels/scaled_attn/attn_memcpy.cpp API src/nodes/kernels/scaled_attn/attn_memcpy.hpp - NAME attn_memcpy + NAME attn_memcpy paged_attn_memcpy NAMESPACE ov::Extensions::Cpu::XARCH ) cross_compiled_file(${TARGET_NAME} diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index 15a4edb5392bab..629fc5b0db2466 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -217,6 +217,7 @@ static const TypeToNameMap& get_type_to_name_tbl() { {"Ngram", Type::Ngram}, {"ScaledDotProductAttention", Type::ScaledDotProductAttention}, {"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention}, + {"PagedAttentionExtension", Type::ScaledDotProductAttention}, {"RoPE", Type::RoPE}, }; return type_to_name_tbl; diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp index e97c51d4322b73..3ffff01c6da6e7 100644 --- a/src/plugins/intel_cpu/src/graph.cpp +++ b/src/plugins/intel_cpu/src/graph.cpp @@ -1680,6 +1680,10 @@ void Graph::EnforceInferencePrecision() { if (node->getOriginalInputPrecisionAtPort(inPort) != ov::element::f32) return true; + // kvcache of PagedAttention should be written directly + if (node->getType() == Type::ScaledDotProductAttention && node->getOriginalInputsNumber() == 13 && + (inPort == 3 || inPort == 4)) + return true; const auto &parent = node->getParentEdgeAt(inPort)->getParent(); /* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing. * Element type conversion to bf16 is done automatically, if convolution follows up after Constant Inputs diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp index 08d9635da9ffd9..c170464eeb47ee 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp @@ -76,6 +76,43 @@ static void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, }); } +template +static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const ov::intel_cpu::PlainTensor& past_v_output, + const ov::intel_cpu::PlainTensor& slot_mapping) { + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3]; + parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) { + auto block_idx = slot_mapping.ptr(b)[m]; + if (block_idx < 0) return; + attn_copy(past_k_output.ptr(block_idx, h, 0), + k_input.ptr(b, h, m, 0), + S); + attn_copy(past_v_output.ptr(block_idx, h, 0), + v_input.ptr(b, h, m, 0), + S); + }); +} + +static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const ov::intel_cpu::PlainTensor& past_v_output, + const ov::intel_cpu::PlainTensor& slot_mapping) { + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3]; + parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) { + auto block_idx = slot_mapping.ptr(b)[m]; + if (block_idx < 0) return; + std::memcpy(past_k_output.ptr_v(block_idx, h, 0), + k_input.ptr_v(b, h, m, 0), + S * k_input.m_element_size); + std::memcpy(past_v_output.ptr_v(block_idx, h, 0), + v_input.ptr_v(b, h, m, 0), + S * v_input.m_element_size); + }); +} + void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, const ov::intel_cpu::PlainTensor& v_input, const ov::intel_cpu::PlainTensor& past_k_output, @@ -90,6 +127,23 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in attn_memcpy"); } } + +void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const ov::intel_cpu::PlainTensor& past_v_output, + const ov::intel_cpu::PlainTensor& slot_mapping) { + if (past_k_output.get_precision() == k_input.get_precision()) { + paged_attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output, slot_mapping); + } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::f16) { + paged_attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output, slot_mapping); + } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::bf16) { + paged_attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output, slot_mapping); + } else { + OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in paged_attn_memcpy"); + } +} + } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp index 68bf517475888b..2c44534a8462d7 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp @@ -20,6 +20,12 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, const ov::intel_cpu::PlainTensor& past_k_output, const ov::intel_cpu::PlainTensor& past_v_output); +void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, + const ov::intel_cpu::PlainTensor& v_input, + const ov::intel_cpu::PlainTensor& past_k_output, + const ov::intel_cpu::PlainTensor& past_v_output, + const ov::intel_cpu::PlainTensor& slot_mapping); + } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp index 3121a5852d19da..d16f85f154b685 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -594,6 +594,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& alibi_mask, const ov::intel_cpu::PlainTensor& attention_mask, const ov::intel_cpu::PlainTensor& beams, + const ov::intel_cpu::PlainTensor& context_lens, ov::intel_cpu::PlainTensor& output_emb, ov::intel_cpu::PlainTensor& buf_attn_w, ov::intel_cpu::PlainTensor& buf_attn_score, @@ -609,15 +610,17 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, auto H = query.size(1); auto q_len = query.size(2); auto S = query.size(3); - auto kv_len = present_key.size(2); auto h_group_num = present_key.size(1); size_t h_each_group_len = 1; + bool is_pagedattn = context_lens; if (h_group_num != H) { h_each_group_len = H / h_group_num; } if (d_scale == 0.0f) d_scale = 1.0f / sqrt(S); auto nthr = parallel_get_max_threads(); + // max kv len + auto kv_len = beams.size(1); // use per-token kernel, for each k,v token // attn mask is a matrix of q_len(kv_len) @@ -642,53 +645,79 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, size_t b, h_group, pk; if (start < end) { parallel_it_init(start, b, B, h_group, h_group_num, pk, kv_len); - if (q_len == 1 && h_each_group_len == 1) { - if (B == 1) { - // the memory will be continuous when b==1 - for (size_t iwork = start; iwork < end; ++iwork) { - auto p = past_k_scale_zp.ptr(0, h_group, pk); - auto p_k = present_key.ptr(0, h_group, pk); - prefetch_bytes(S, _MM_HINT_T0, 4096, p_k); - buf_attn_w.ptr(0, h_group, 0)[pk] = - dot_product(query.ptr(0, h_group), p_k, - S, p, p + 1, head_sum.ptr(0, h_group)); - parallel_it_step(b, B, h_group, h_group_num, pk, kv_len); + if (is_pagedattn) { + for (size_t iwork = start; iwork < end; ++iwork) { + auto context_len = static_cast(context_lens.ptr()[b]); + // kv_len must be valid + if (pk < context_len) { + auto block_idx = beams.ptr(b)[pk]; + OPENVINO_ASSERT(block_idx >= 0, "block idx must be greater or equal than 0"); + + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { + buf_attn_w.ptr(b, h, pq)[pk] = + dot_product(query.ptr(b, h, pq), present_key.ptr(block_idx, h_group), + S, nullptr, nullptr, nullptr); + } + } + } + parallel_it_step(b, B, h_group, h_group_num, pk, kv_len); + } + } else { + if (q_len == 1 && h_each_group_len == 1) { + if (B == 1) { + // the memory will be continuous when b==1 + for (size_t iwork = start; iwork < end; ++iwork) { + auto p = past_k_scale_zp.ptr(0, h_group, pk); + auto p_k = present_key.ptr(0, h_group, pk); + prefetch_bytes(S, _MM_HINT_T0, 4096, p_k); + buf_attn_w.ptr(0, h_group, 0)[pk] = + dot_product(query.ptr(0, h_group), p_k, + S, p, p + 1, head_sum.ptr(0, h_group)); + parallel_it_step(b, B, h_group, h_group_num, pk, kv_len); + } + } else { + for (size_t iwork = start; iwork < end; ++iwork) { + auto b_kv = beams ? beams.ptr(b)[pk] : b; + auto p = past_k_scale_zp.ptr(b_kv, h_group, pk); + auto p_k = present_key.ptr(b_kv, h_group, pk); + buf_attn_w.ptr(b, h_group, 0)[pk] = + dot_product(query.ptr(b, h_group), p_k, + S, p, p + 1, head_sum.ptr(b, h_group)); + parallel_it_step(b, B, h_group, h_group_num, pk, kv_len); + } } } else { for (size_t iwork = start; iwork < end; ++iwork) { auto b_kv = beams ? beams.ptr(b)[pk] : b; - auto p = past_k_scale_zp.ptr(b_kv, h_group, pk); - auto p_k = present_key.ptr(b_kv, h_group, pk); - buf_attn_w.ptr(b, h_group, 0)[pk] = - dot_product(query.ptr(b, h_group), p_k, - S, p, p + 1, head_sum.ptr(b, h_group)); - parallel_it_step(b, B, h_group, h_group_num, pk, kv_len); - } - } - } else { - for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? beams.ptr(b)[pk] : b; - for (size_t pq = 0; pq < q_len; pq++) { - auto p = past_k_scale_zp.ptr(b_kv, h_group, pk); - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - buf_attn_w.ptr(b, h, pq)[pk] = - dot_product(query.ptr(b, h, pq), present_key.ptr(b_kv, h_group, pk), - S, p, p + 1, head_sum.ptr(b, h, pq)); + for (size_t pq = 0; pq < q_len; pq++) { + auto p = past_k_scale_zp.ptr(b_kv, h_group, pk); + for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { + buf_attn_w.ptr(b, h, pq)[pk] = + dot_product(query.ptr(b, h, pq), present_key.ptr(b_kv, h_group, pk), + S, p, p + 1, head_sum.ptr(b, h, pq)); + } } + parallel_it_step(b, B, h_group, h_group_num, pk, kv_len); } - parallel_it_step(b, B, h_group, h_group_num, pk, kv_len); } } } }); parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { + auto cur_kv_len = kv_len; + auto ncausal = auto_causal ? (cur_kv_len - q_len + pq + 1) : cur_kv_len; + if (is_pagedattn) { + cur_kv_len = static_cast(context_lens.ptr()[b]); + ncausal = cur_kv_len; + } // apply attention mask & sofmax - auto ncausal = auto_causal ? (kv_len - q_len + pq + 1) : kv_len; float* alibi_ptr = alibi_mask ? &alibi_mask.at({b, h, pq, 0}, true) : nullptr; uint8_t* attn_mask_ptr = nullptr; auto attn_mask_prec = attention_mask.get_precision(); - attn_mask_ptr = reinterpret_cast(&attention_mask.at({b, h, pq, 0}, true)); + if (attention_mask) + attn_mask_ptr = reinterpret_cast(&attention_mask.at({b, h, pq, 0}, true)); uint8_t* cmask_ptr = causal_mask ? &causal_mask.at({b, h, pq, 0}, true) : nullptr; attn_softmax_kernel(buf_attn_w.ptr(b, h, pq), buf_attn_w.ptr(b, h, pq), @@ -698,7 +727,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, cmask_ptr, select_nfltmax_at_0, ncausal, - kv_len, + cur_kv_len, attn_mask_prec, ov::element::f32); }); @@ -715,35 +744,58 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, size_t b, h_group, pv; if (start < end) { parallel_it_init(start, b, B, h_group, h_group_num, pv, kv_len); - if (q_len == 1 && h_each_group_len == 1) { + if (is_pagedattn) { for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? beams.ptr(b)[pv] : b; - auto* v = present_value.ptr(b_kv, h_group, pv); - auto p = past_v_scale_zp.ptr(b_kv, h_group, pv); - attn_acc_value(buf_attn_score.ptr(ithr, b, 0, h_group), - buf_attn_w.ptr(b, h_group, 0, pv)[0], - v, - S, - p + 0, - p + 1); + auto context_len = static_cast(context_lens.ptr()[b]); + // kv_len must be valid + if (pv < context_len) { + auto block_idx = beams.ptr(b)[pv]; + OPENVINO_ASSERT(block_idx >= 0, "block idx in vcache must be greater or equal than 0"); + auto* v = present_value.ptr(block_idx, h_group); + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { + attn_acc_value(buf_attn_score.ptr(ithr, b, pq, h), + buf_attn_w.ptr(b, h, pq)[pv], + v, + S, + nullptr, + nullptr); + } + } + } parallel_it_step(b, B, h_group, h_group_num, pv, kv_len); } } else { - for (size_t iwork = start; iwork < end; ++iwork) { - auto b_kv = beams ? beams.ptr(b)[pv] : b; - auto* v = present_value.ptr(b_kv, h_group, pv); - auto p = past_v_scale_zp.ptr(b_kv, h_group, pv); - for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { - attn_acc_value(buf_attn_score.ptr(ithr, b, pq, h), - buf_attn_w.ptr(b, h, pq)[pv], - v, - S, - p + 0, - p + 1); + if (q_len == 1 && h_each_group_len == 1) { + for (size_t iwork = start; iwork < end; ++iwork) { + auto b_kv = beams ? beams.ptr(b)[pv] : b; + auto* v = present_value.ptr(b_kv, h_group, pv); + auto p = past_v_scale_zp.ptr(b_kv, h_group, pv); + attn_acc_value(buf_attn_score.ptr(ithr, b, 0, h_group), + buf_attn_w.ptr(b, h_group, 0, pv)[0], + v, + S, + p + 0, + p + 1); + parallel_it_step(b, B, h_group, h_group_num, pv, kv_len); + } + } else { + for (size_t iwork = start; iwork < end; ++iwork) { + auto b_kv = beams ? beams.ptr(b)[pv] : b; + auto* v = present_value.ptr(b_kv, h_group, pv); + auto p = past_v_scale_zp.ptr(b_kv, h_group, pv); + for (size_t pq = 0; pq < q_len; pq++) { + for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { + attn_acc_value(buf_attn_score.ptr(ithr, b, pq, h), + buf_attn_w.ptr(b, h, pq)[pv], + v, + S, + p + 0, + p + 1); + } } + parallel_it_step(b, B, h_group, h_group_num, pv, kv_len); } - parallel_it_step(b, B, h_group, h_group_num, pv, kv_len); } } } @@ -763,6 +815,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& alibi_mask, const ov::intel_cpu::PlainTensor& attention_mask, const ov::intel_cpu::PlainTensor& beams, + const ov::intel_cpu::PlainTensor& context_lens, ov::intel_cpu::PlainTensor& output_emb, ov::intel_cpu::PlainTensor& buf_attn_w, ov::intel_cpu::PlainTensor& buf_attn_score, @@ -780,6 +833,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, + context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -796,6 +850,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, + context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -814,6 +869,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, + context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -830,6 +886,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, + context_lens, output_emb, buf_attn_w, buf_attn_score, @@ -846,6 +903,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, alibi_mask, attention_mask, beams, + context_lens, output_emb, buf_attn_w, buf_attn_score, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp index e29e2bae0aa07a..07edc33d914a69 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp @@ -21,6 +21,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& alibi_mask, const ov::intel_cpu::PlainTensor& attention_mask, const ov::intel_cpu::PlainTensor& beams, + const ov::intel_cpu::PlainTensor& context_lens, ov::intel_cpu::PlainTensor& output_emb, ov::intel_cpu::PlainTensor& buf_attn_w, ov::intel_cpu::PlainTensor& buf_attn_score, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp index 6aea1119788e92..a1aeaaecf9ae17 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp @@ -45,8 +45,9 @@ BrgemmKernel::BrgemmKernel(size_t M, THROW_ERROR("brgemm bf16 kernel could only be used above avx512_bf16"); bool isAMXSupported = is_bf16 && mayiuse(avx512_core_amx); + size_t vlen = cpu_isa_traits::vlen; // blocking N - N_blk = is_bf16 ? 32 : N; + N_blk = is_bf16 ? 32 : std::max(N, vlen / inType.size()); N_tail = N % N_blk; // blocking K @@ -55,7 +56,6 @@ BrgemmKernel::BrgemmKernel(size_t M, if (isAMXSupported && K_tail) { K_tail = rnd_up(K_tail, 2); } - size_t vlen = cpu_isa_traits::vlen; // copied K must be round up by vlen / inType.size(), otherwise copy B kernel may access wrong memory packedBSize = rnd_up(K, vlen / inType.size()) * rnd_up(N, N_blk) * inType.size(); size_t brg0BaseIdx = std::numeric_limits::max(); diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index ffe1f6b39cf412..a5ca0425c50bc7 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -125,7 +125,8 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f) { + float d_scale = 0.0f, + size_t sliding_window = 0) { auto B = query.size(0); auto H = query.size(1); auto q_len = query.size(2); @@ -177,7 +178,18 @@ struct MHAKernel { } // softmax - softmax(&attn_score[0], ncausal); + if (sliding_window) { + size_t start_idx = 0; + auto new_causal = ncausal; + if (ncausal > sliding_window) { + start_idx = ncausal - static_cast(sliding_window); + new_causal = sliding_window; + } + softmax(&attn_score[start_idx], new_causal); + memset(&attn_score[0], 0, sizeof(float) * start_idx); + } else { + softmax(&attn_score[0], ncausal); + } // linearly combine value word_vec.assign(head_size, 0.0f); @@ -292,7 +304,10 @@ struct MHAKernel { qk_gemm_ptr = qk_result.first; dnnl::memory::desc attn_md(make_dnnl_dims({B, H, q_len, kv_len}), dt::f32, tag::abcd); weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd); - out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, head_size}), qkv_dt, tag::abcd); + if (has_out_transpose) + out_md = dnnl::memory::desc(make_dnnl_dims({B, q_len, H, head_size}), qkv_dt, tag::abcd); + else + out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, head_size}), qkv_dt, tag::abcd); size_t ldc_index = 2; if (has_out_transpose) { @@ -347,7 +362,8 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f) { + float d_scale = 0.0f, + size_t sliding_window = 0) { const auto B = query.size(0); const auto H = query.size(1); const auto q_len = query.size(2); @@ -410,22 +426,47 @@ struct MHAKernel { for (size_t m = m_start; m < m_end; m++) { // apply attention mask & sofmax auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len; - attn_softmax(&score.at({b, h, m, 0}), - &weight.at({b, h, m, 0}), - d_scale, - alibi_ptr + m * alibi_stride, - attn_mask_ptr + m * attn_mask_stride, - cmask_ptr + m * cmask_stride, - select_nfltmax_at_0, - ncausal, - kv_len, - precision_of::value, - precision_of::value); + if (sliding_window) { + size_t start_idx = 0; + auto new_causal = ncausal; + if (ncausal > sliding_window) { + start_idx = ncausal - static_cast(sliding_window); + new_causal = sliding_window; + } + attn_softmax(&score.at({b, h, m, start_idx}), + &weight.at({b, h, m, start_idx}), + d_scale, + alibi_ptr + m * alibi_stride, + attn_mask_ptr + m * attn_mask_stride, + cmask_ptr + m * cmask_stride, + select_nfltmax_at_0, + new_causal, + kv_len - start_idx, + precision_of::value, + precision_of::value); + + memset(&weight.at({b, h, m, 0}), 0, sizeof(T) * start_idx); + } else { + attn_softmax(&score.at({b, h, m, 0}), + &weight.at({b, h, m, 0}), + d_scale, + alibi_ptr + m * alibi_stride, + attn_mask_ptr + m * attn_mask_stride, + cmask_ptr + m * cmask_stride, + select_nfltmax_at_0, + ncausal, + kv_len, + precision_of::value, + precision_of::value); + } } T* w_ptr = &weight.at({b, h, m_start, 0}); - PlainTensor& sdpa_out = is_bf16 ? fp32_out : output_emb; - float* fp32_out_ptr = - has_out_transpose ? &sdpa_out.at({b, m_start, h, 0}) : &sdpa_out.at({b, h, m_start, 0}); + float* fp32_out_ptr; + if (is_bf16) { + fp32_out_ptr = has_out_transpose ? &fp32_out.at({b, m_start, h, 0}) : &fp32_out.at({b, h, m_start, 0}); + } else { + fp32_out_ptr = has_out_transpose ? &output_emb.at({b, m_start, h * head_size}) : &output_emb.at({b, h, m_start, 0}); + } T* v_ptr = is_bf16 ? &wv_scratch_b.at({b, h / h_each_group_len, 0}) : &present_value.at({b, h / h_each_group_len, 0, 0}); wv_gemm_ptr->executeGemm(m_cnt < m_block_size, @@ -435,11 +476,21 @@ struct MHAKernel { wsp.data() + tid * wsp_size_per_thread, wv_scratch_a ? &wv_scratch_a.at({tid, 0}) : nullptr); if (is_bf16) { - cpu_convert(&fp32_out.at({b, h, m_start, 0}), - &output_emb.at({b, h, m_start, 0}), - ov::element::f32, - ov::element::bf16, - m_cnt * head_size); + if (has_out_transpose) { + for (size_t m = m_start; m < m_end; m++) { + cpu_convert(&fp32_out.at({b, m, h, 0}), + &output_emb.at({b, m, h * head_size}), + ov::element::f32, + ov::element::bf16, + head_size); + } + } else { + cpu_convert(&fp32_out.at({b, h, m_start, 0}), + &output_emb.at({b, h, m_start, 0}), + ov::element::f32, + ov::element::bf16, + m_cnt * head_size); + } } }); } @@ -467,7 +518,8 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f) { + float d_scale = 0.0f, + size_t sliding_window = 0) { auto head_size = query.size(3); if (d_scale == 0.0f) d_scale = 1.0f / sqrt(head_size); @@ -481,7 +533,8 @@ struct MHAKernel { output_emb, has_out_transpose, auto_causal, - d_scale); + d_scale, + sliding_window); } }; @@ -523,7 +576,8 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, - float d_scale = 0.0f) { + float d_scale = 0.0f, + size_t sliding_window = 0) { auto B = query.size(0); auto H = query.size(1); auto q_len = query.size(2); @@ -613,17 +667,39 @@ struct MHAKernel { for (size_t m = m_start; m < m_end; m++) { // apply attention mask & sofmax auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len; - attn_softmax(qk + (m - m_start) * qk_m_stride, - qk + (m - m_start) * qk_m_stride, - d_scale, - alibi_ptr + m * alibi_stride, - attn_mask_ptr + m * attn_mask_stride, - cmask_ptr + m * cmask_stride, - select_nfltmax_at_0, - ncausal, - kv_len, - ov::element::f32, - ov::element::f32); + if (sliding_window) { + size_t start_idx = 0; + auto new_causal = ncausal; + if (ncausal > sliding_window) { + start_idx = ncausal - static_cast(sliding_window); + new_causal = sliding_window; + } + attn_softmax(qk + (m - m_start) * qk_m_stride + start_idx, + qk + (m - m_start) * qk_m_stride + start_idx, + d_scale, + alibi_ptr + m * alibi_stride, + attn_mask_ptr + m * attn_mask_stride, + cmask_ptr + m * cmask_stride, + select_nfltmax_at_0, + new_causal, + kv_len - start_idx, + ov::element::f32, + ov::element::f32); + + memset(qk + (m - m_start) * qk_m_stride, 0, sizeof(float) * start_idx); + } else { + attn_softmax(qk + (m - m_start) * qk_m_stride, + qk + (m - m_start) * qk_m_stride, + d_scale, + alibi_ptr + m * alibi_stride, + attn_mask_ptr + m * attn_mask_stride, + cmask_ptr + m * cmask_stride, + select_nfltmax_at_0, + ncausal, + kv_len, + ov::element::f32, + ov::element::f32); + } } mlas_sgemm("N", "N", @@ -666,12 +742,13 @@ struct MHASingleToken { const PlainTensor& attention_mask, PlainTensor& output_emb, const PlainTensor& beams, + const PlainTensor& context_lens, bool has_out_transpose, bool auto_causal, float d_scale, const PlainTensor& k_scale_zp, const PlainTensor& v_scale_zp) { - mha_single_token(query, present_key, present_value, alibi_mask, attention_mask, beams, output_emb, + mha_single_token(query, present_key, present_value, alibi_mask, attention_mask, beams, context_lens, output_emb, m_attn_w, m_temp, has_out_transpose, auto_causal, d_scale, k_scale_zp, v_scale_zp, m_head_sum); } }; @@ -700,66 +777,108 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt bool fuse_causal_attn = config.config.fuse_causal_attn; bool is_causal = config.config.is_causal; bool fuse_concat = config.config.fuse_concat; + bool is_pagedattn = config.is_pageattn; auto input_num = inputs.size(); + bool is_prompt = false; PlainTensor present_key, present_value; PlainTensor q_input; // f32[B, H, L1, S] PlainTensor k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] PlainTensor v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] PlainTensor beam_table; // i32[B, max_kvLen] + PlainTensor context_lens; + PlainTensor attn_mask; + PlainTensor output_emb(output); float scale_input = 0.0f; size_t B, L1, L0, S; + size_t sliding_window = 0; q_input.reset(inputs[0]); k_input.reset(inputs[1]); v_input.reset(inputs[2]); present_key.reset(presentk_input); present_value.reset(presentv_input); - if (beam_input) - beam_table.reset(beam_input); - PlainTensor attn_mask; - if (input_num > 3) { - // attn_mask - if (inputs[3]->getDesc().getPrecision() == ov::element::u8) { - // bool->f32 - prepare_attn_mask(inputs[3]); - attn_mask = attn_buf; + if (is_pagedattn) { + is_prompt = *inputs[ID_IS_PROMPT]->getDataAs() == 1; + //auto max_context_len = static_cast(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs()); + context_lens.reset(inputs[ID_CONTEXT_LENS]); + beam_table.reset(inputs[ID_BLOCK_TABLES]); + scale_input = *inputs[ID_SCALE]->getDataAs(); + // TODO: alibi and sliding window + // no attn mask, auto-generated casual mask + is_causal = true; + has_out_transpose = true; + + // q: [B, L1, H*S], kv: [B, L1, Hk*S] + // k_cache: [NUM_BLOCKS, Hk, S / 4, BLOCK_SIZE, 4] + // v_cache: [NUM_BLOCKS, Hk, S, BLOCK_SIZE] + // context_lens: [B] + // block_tables: [B, max_block_per_request] + B = k_input.size(0); + L1 = k_input.size(1); + auto Hk = present_key.size(1); + S = present_value.size(2); + auto H = q_input.size(2) / S; + // L0 in each batch may be different + L0 = 0; + + q_input.assert_dims({B, L1, H * S}); + if (!is_prompt) { + context_lens.assert_dims({B}); + beam_table.assert_dims({B, 0}, true); } else { - attn_mask.reset(inputs[3]); + sliding_window = static_cast(*inputs[ID_SLIDING_WINDOW]->getDataAs()); } - // if has scale, attn_mask must be present - if (input_num > 4) { - scale_input = *inputs[4]->getDataAs(); + output_emb.assert_dims({B, L1, H * S}); + q_input = q_input.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); + k_input = k_input.reshape({B, L1, Hk, S}).permute({0, 2, 1, 3}); + v_input = v_input.reshape({B, L1, Hk, S}).permute({0, 2, 1, 3}); + present_key = present_key.reshape({present_key.size(0), Hk, S}); + present_value = present_value.reshape({present_value.size(0), Hk, S}); + } else { + if (beam_input) + beam_table.reset(beam_input); + if (input_num > 3) { + // attn_mask + if (inputs[3]->getDesc().getPrecision() == ov::element::u8) { + // bool->f32 + prepare_attn_mask(inputs[3]); + attn_mask = attn_buf; + } else { + attn_mask.reset(inputs[3]); + } + // if has scale, attn_mask must be present + if (input_num > 4) { + scale_input = *inputs[4]->getDataAs(); + } } - } - // q: [B, H, L1, S] - const auto & permute_axes = config.config.permute_axes; - if (!permute_axes.empty()) { - q_input = q_input.permute(permute_axes); - k_input = k_input.permute(permute_axes); - v_input = v_input.permute(permute_axes); - present_key = present_key.permute(permute_axes); - present_value = present_value.permute(permute_axes); - } - B = q_input.size(0); - L1 = q_input.size(2); - S = q_input.size(3); - L0 = present_key.size(2) - L1; - auto Hk = k_input.size(1); - - if (fuse_concat) { - k_input.assert_dims({B, Hk, L1, S}); - v_input.assert_dims({B, Hk, L1, S}); - } else { - k_input.assert_dims({B, Hk, L0 + L1, S}); - v_input.assert_dims({B, Hk, L0 + L1, S}); + // q: [B, H, L1, S] + const auto & permute_axes = config.config.permute_axes; + if (!permute_axes.empty()) { + q_input = q_input.permute(permute_axes); + k_input = k_input.permute(permute_axes); + v_input = v_input.permute(permute_axes); + present_key = present_key.permute(permute_axes); + present_value = present_value.permute(permute_axes); + } + B = q_input.size(0); + L1 = q_input.size(2); + S = q_input.size(3); + L0 = present_key.size(2) - L1; + auto Hk = k_input.size(1); + + if (fuse_concat) { + k_input.assert_dims({B, Hk, L1, S}); + v_input.assert_dims({B, Hk, L1, S}); + } else { + k_input.assert_dims({B, Hk, L0 + L1, S}); + v_input.assert_dims({B, Hk, L0 + L1, S}); + } + present_key.assert_dims({B, Hk, L0 + L1, S}); + present_value.assert_dims({B, Hk, L0 + L1, S}); + if (beam_table) + beam_table.assert_dims({B, L0 + L1}); } - present_key.assert_dims({B, Hk, L0 + L1, S}); - present_value.assert_dims({B, Hk, L0 + L1, S}); - if (beam_table) - beam_table.assert_dims({B, L0 + L1}); - - ov::intel_cpu::PlainTensor output_emb(output); bool auto_causal; bool use_attn_mask; @@ -791,11 +910,15 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt } // second token, or first token with pastkv fusing - bool use_one_token = L1 == 1 || (fuse_concat && L0 > 0); + bool use_one_token; + if (is_pagedattn) + use_one_token = !is_prompt; + else + use_one_token = L1 == 1 || (fuse_concat && L0 > 0); if (!use_one_token) { // multi-token version kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor(), - output_emb, has_out_transpose, auto_causal, scale_input); + output_emb, has_out_transpose, auto_causal, scale_input, sliding_window); } else { // 1-token version // for second token, using a special AVX2/AVX512 float path: @@ -803,7 +926,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt // 2, using float will save the repack cost which typically is required for bf16/int8 opt // 3, using dot product can leverage the SIMD while easily adapt to indirect kv cache kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(), - output_emb, beam_table, has_out_transpose, auto_causal, scale_input, k_scale_zp, v_scale_zp); + output_emb, beam_table, context_lens, has_out_transpose, auto_causal, scale_input, k_scale_zp, v_scale_zp); } } }; @@ -815,12 +938,18 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr(op); - if (node) { - m_config.config.is_causal = node->get_causal(); + if (op->get_type_name() == std::string("PagedAttentionExtension")) { + m_is_pageattn = true; + m_config.is_pageattn = true; } else { - const auto node = std::dynamic_pointer_cast(op); - m_config.config = node->get_config(); + m_is_pageattn = false; + const auto node = std::dynamic_pointer_cast(op); + if (node) { + m_config.config.is_causal = node->get_causal(); + } else { + const auto node = std::dynamic_pointer_cast(op); + m_config.config = node->get_config(); + } } } @@ -840,49 +969,83 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { rtPrecision, getInputShapeAtPort(1))); config.inConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( rtPrecision, getInputShapeAtPort(2))); - auto nextPortIdx = 3; - if (orginSDPInputNumber > 3) { - // attn_mask - if (getOriginalInputPrecisionAtPort(nextPortIdx) == ov::element::u8) { - config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::u8, getInputShapeAtPort(nextPortIdx))); - } else { + if (m_is_pageattn) { + OPENVINO_ASSERT(getOriginalInputsNumber() == 13, "The input number of PagedAttention should be 13."); + // kvcache, float, [] + auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(ID_KCACHE); + config.inConfs[ID_KCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_kv_input_mem_precision, getInputShapeAtPort(ID_KCACHE))); + config.inConfs[ID_VCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_kv_input_mem_precision, getInputShapeAtPort(ID_VCACHE))); + // is_prompt, bool, [] + config.inConfs[ID_IS_PROMPT].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::u8, getInputShapeAtPort(ID_IS_PROMPT))); + // slot_mapping, int, [batch_size, max_context_len] + config.inConfs[ID_SLOT_MAPPING].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(ID_SLOT_MAPPING))); + // max_context_len, int, [] + config.inConfs[ID_MAX_CONTEXT_LEN].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(ID_MAX_CONTEXT_LEN))); + // context_lens, int, [batch_size] + config.inConfs[ID_CONTEXT_LENS].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(ID_CONTEXT_LENS))); + // block_tables, int, [batch_size, max_block_per_request] + config.inConfs[ID_BLOCK_TABLES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(ID_BLOCK_TABLES))); + // scale, float, [] + config.inConfs[ID_SCALE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::f32, getInputShapeAtPort(ID_SCALE))); + // alibi_slopes, float, [?] or nullptr + config.inConfs[ID_ALIBI_SLOPES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::f32, getInputShapeAtPort(ID_ALIBI_SLOPES))); + // sliding_window, int, [] + config.inConfs[ID_SLIDING_WINDOW].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(ID_SLIDING_WINDOW))); + } else { + auto nextPortIdx = 3; + if (orginSDPInputNumber > 3) { + // attn_mask + if (getOriginalInputPrecisionAtPort(nextPortIdx) == ov::element::u8) { + config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::u8, getInputShapeAtPort(nextPortIdx))); + } else { + config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(nextPortIdx))); + } + nextPortIdx++; + } + if (orginSDPInputNumber > 4) { config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - rtPrecision, getInputShapeAtPort(nextPortIdx))); + ov::element::f32, getInputShapeAtPort(nextPortIdx))); } - nextPortIdx++; - } - if (orginSDPInputNumber > 4) { - config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::f32, getInputShapeAtPort(nextPortIdx))); - } - if (m_config.config.fuse_concat) { - // beam_idx - config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - ov::element::i32, getInputShapeAtPort(orginSDPInputNumber + 0))); - - // Since the InputMemory nodes are simple proxy for the state memory as well as the init subgraph memory, - // it doesn't make sense to set the real KV cache precision, since we don't need any precision conversions - // provided by the common graph logic. We set precisions equal to the precisions of the state nodes to avoid - // reorder insertion in between MemoryInputSDPA and SDPA nodes. - - auto past_k_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 1)->getParent()->getOriginalOutputPrecisionAtPort(0); - // pastk - config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_k_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 1))); - - auto past_v_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 2)->getParent()->getOriginalOutputPrecisionAtPort(0); - // pastv - config.inConfs[orginSDPInputNumber + 2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_v_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 2))); - - config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_k_input_mem_precision, getOutputShapeAtPort(1))); - config.outConfs[1].inPlace(-1); - config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_v_input_mem_precision, getOutputShapeAtPort(2))); - config.outConfs[2].inPlace(-1); + if (m_config.config.fuse_concat) { + // beam_idx + config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::i32, getInputShapeAtPort(orginSDPInputNumber + 0))); + + // Since the InputMemory nodes are simple proxy for the state memory as well as the init subgraph memory, + // it doesn't make sense to set the real KV cache precision, since we don't need any precision conversions + // provided by the common graph logic. We set precisions equal to the precisions of the state nodes to avoid + // reorder insertion in between MemoryInputSDPA and SDPA nodes. + + auto past_k_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 1)->getParent()->getOriginalOutputPrecisionAtPort(0); + // pastk + config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_k_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 1))); + + auto past_v_input_mem_precision = getParentEdgeAt(orginSDPInputNumber + 2)->getParent()->getOriginalOutputPrecisionAtPort(0); + // pastv + config.inConfs[orginSDPInputNumber + 2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_v_input_mem_precision, getInputShapeAtPort(orginSDPInputNumber + 2))); + + config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_k_input_mem_precision, getOutputShapeAtPort(1))); + config.outConfs[1].inPlace(-1); + config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + past_v_input_mem_precision, getOutputShapeAtPort(2))); + config.outConfs[2].inPlace(-1); + } } config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( @@ -941,24 +1104,34 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) { } PlainTensor k_scale_zp, v_scale_zp; - if (m_config.config.fuse_concat) { - // initialization will be also completed in this func - gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber)); - - presentk_input = m_k_state->internal_state_mem(); - presentv_input = m_v_state->internal_state_mem(); - beam_input = m_k_state->hidden_state_mem(); - k_scale_zp = m_k_state->get_scale_zp(); - v_scale_zp = m_v_state->get_scale_zp(); + if (m_is_pageattn) { + gatherConcatPastkvForPagedAttn(inputs); + + presentk_input = inputs[ID_KCACHE]; + presentv_input = inputs[ID_VCACHE]; } else { - presentk_input = inputs[1]; - presentv_input = inputs[2]; + if (m_config.config.fuse_concat) { + // initialization will be also completed in this func + gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber)); + + presentk_input = m_k_state->internal_state_mem(); + presentv_input = m_v_state->internal_state_mem(); + beam_input = m_k_state->hidden_state_mem(); + k_scale_zp = m_k_state->get_scale_zp(); + v_scale_zp = m_v_state->get_scale_zp(); + } else { + presentk_input = inputs[1]; + presentv_input = inputs[2]; + } } m_executor->execute(strm, m_config, inputs, output, presentk_input, presentv_input, beam_input, k_scale_zp, v_scale_zp); } bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { + if (op->get_type_name() == std::string("PagedAttentionExtension")) { + return true; + } if (!std::dynamic_pointer_cast(op) && !std::dynamic_pointer_cast(op)) { errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionWithKVCache operation are supported"; @@ -1161,6 +1334,33 @@ void ScaledDotProductAttention::resetBeamTablePastkv(const MemoryPtr& mem_cur_k, } } +void ScaledDotProductAttention::gatherConcatPastkvForPagedAttn(const std::vector& inputs) { + PlainTensor k, v, k_cache, v_cache, slot_mapping; + + k.reset(inputs[ID_K]); // [B, L1, H * S] + v.reset(inputs[ID_V]); + k_cache.reset(inputs[ID_KCACHE]); // [NUM_BLOCKS, H, S / 4, BLOCK_SIZE, 4] + v_cache.reset(inputs[ID_VCACHE]); // [NUM_BLOCKS, H, S, BLOCK_SIZE] + slot_mapping.reset(inputs[ID_SLOT_MAPPING]); // [B, max_context_len] + + auto B = k.size(0); + auto L1 = k.size(1); + auto H = k_cache.size(1); + auto S = v_cache.size(2); + + k.assert_dims({B, L1, H * S}); + v.assert_dims({B, L1, H * S}); + k_cache.assert_dims({0, H, 0, 1, 0}, true); + v_cache.assert_dims({0, H, S, 1}, true); + slot_mapping.assert_dims({B, 0}, true); + k = k.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); + v = v.reshape({B, L1, H, S}).permute({0, 2, 1, 3}); + k_cache = k_cache.reshape({k_cache.size(0), H, S}); + v_cache = v_cache.reshape({v_cache.size(0), H, S}); + paged_attn_memcpy(k, v, k_cache, v_cache, slot_mapping); + // TODO: add u8 kvcache support +} + void ScaledDotProductAttention::gatherConcatPastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx) { PlainTensor cur_k; cur_k.reset(mem_cur_k); diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index d4ae2df8c7688a..38980d07e131e0 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -48,6 +48,7 @@ class ScaledDotProductAttention : public Node { private: void gatherConcatPastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v, const MemoryPtr& mem_beam_idx); + void gatherConcatPastkvForPagedAttn(const std::vector& inputs); void updateBeamTable(const MemoryPtr& mem_beam_idx, size_t new_q_len); void updatePastkv(const MemoryPtr& mem_cur_k, const MemoryPtr& mem_cur_v); ov::element::Type getRuntimePrecision() const override; @@ -55,6 +56,7 @@ class ScaledDotProductAttention : public Node { struct Config { ScaledDotProductAttentionWithKVCache::Config config; + bool is_pageattn = false; }; struct Executor { @@ -63,6 +65,7 @@ class ScaledDotProductAttention : public Node { const PlainTensor& k_scale_zp, const PlainTensor& v_scale_zp) = 0; }; + bool m_is_pageattn; Config m_config; std::shared_ptr m_executor; template struct AttentionExecutor; @@ -70,6 +73,21 @@ class ScaledDotProductAttention : public Node { std::shared_ptr m_k_state; std::shared_ptr m_v_state; + + // PagedAttention input index + static const size_t ID_Q = 0; + static const size_t ID_K = 1; + static const size_t ID_V = 2; + static const size_t ID_KCACHE = 3; + static const size_t ID_VCACHE = 4; + static const size_t ID_IS_PROMPT = 5; + static const size_t ID_SLOT_MAPPING = 6; + static const size_t ID_MAX_CONTEXT_LEN = 7; + static const size_t ID_CONTEXT_LENS = 8; + static const size_t ID_BLOCK_TABLES = 9; + static const size_t ID_SCALE = 10; + static const size_t ID_ALIBI_SLOPES = 11; + static const size_t ID_SLIDING_WINDOW = 12; }; } // namespace node