From b6c35d69f4d478ca83b9ea447a0f89a79a7d507e Mon Sep 17 00:00:00 2001
From: Ilya Lavrenov
Date: Mon, 5 Feb 2024 17:41:00 +0800
Subject: [PATCH] Finalize PagedAttention implementation

---
 .../paged_attention/paged_attention.cpp       | 101 +++++++++++++-----
 .../paged_attention/paged_attention.hpp       |   6 +-
 2 files changed, 79 insertions(+), 28 deletions(-)

diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
index f6f31fdad..388b869b3 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
@@ -6,11 +6,35 @@
 #include "openvino/op/scaled_dot_product_attention.hpp"
 #include "openvino/op/parameter.hpp"
-#include "openvino/op/constants.hpp"
+#include "openvino/op/constant.hpp"
 #include "openvino/op/result.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/runtime/core.hpp"
 
 #include "cpu_ops.hpp"
 
+std::shared_ptr<ov::Model> TemplateExtension::PagedAttention::make_prefill_subgraph() {
+    ov::element::Type_t type = ov::element::f32, attention_mask_type = ov::element::boolean;
+    auto query = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_heads, m_head_size}));
+    auto key = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_kv_heads, m_head_size}));
+    auto value = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_kv_heads, m_head_size}));
+    auto attention_mask = std::make_shared<ov::op::v0::Parameter>(attention_mask_type, ov::PartialShape({-1, -1, -1}));
+    auto scale = std::make_shared<ov::op::v0::Parameter>(type, ov::Shape({1}));
+
+    // transpose Q, K and V to swap num_heads and seq_len dimensions
+    auto permute_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape({4}), {0, 2, 1, 3});
+    auto query_transposed = std::make_shared<ov::op::v1::Transpose>(query, permute_const);
+    auto key_transposed = std::make_shared<ov::op::v1::Transpose>(key, permute_const);
+    auto value_transposed = std::make_shared<ov::op::v1::Transpose>(value, permute_const);
+
+    auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query_transposed, key_transposed, value_transposed, attention_mask, scale, false);
+
+    // transpose SDPA output back to [batch_size, seq_len, num_heads, head_size]
+    auto sdpa_transposed = std::make_shared<ov::op::v1::Transpose>(sdpa, permute_const);
+
+    return std::make_shared<ov::Model>(sdpa_transposed, ov::ParameterVector{query, key, value, attention_mask, scale}, "sdpa_prefill_model");
+}
+
 TemplateExtension::PagedAttention::PagedAttention(const ov::OutputVector& inputs,
                                                   const float scale)
     : ov::op::Op(inputs),
       m_scale(scale) {
@@ -18,8 +42,7 @@ TemplateExtension::PagedAttention::PagedAttention(const ov::OutputVector& inputs
     constructor_validate_and_infer_types();
 
     // compile model for prefill stage
-    auto model = make_spda(m_num_heads, m_num_kv_heads, m_head_size, m_scale);
-    auto compiled_model = ov::Core().compile_model(model, "CPU");
+    auto compiled_model = ov::Core().compile_model(make_prefill_subgraph(), "CPU");
     m_prefill_request = compiled_model.create_infer_request();
 }
 
@@ -140,29 +163,49 @@
 void reshape_and_cache(ov::Tensor key, ov::Tensor value, ov::Tensor key_cache,
                        ov::Tensor value_cache, ov::Tensor slot_mapping);
 
-// generate block diagonal attention mask for a prefill stage
-ov::Tensor generate_attention_mask(ov::Tensor context_lens);
+// generate a lower-triangular (causal) boolean attention mask for the prefill stage
+ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32_t max_context_len, ov::Tensor context_lens) {
+    OPENVINO_ASSERT(num_seqs == context_lens.get_size());
+
+    ov::Shape attention_mask_shape({static_cast<size_t>(num_seqs), static_cast<size_t>(max_context_len), static_cast<size_t>(max_context_len)});
+    ov::Tensor attention_mask(ov::element::boolean, attention_mask_shape);
+    int attention_mask_stride = attention_mask.get_strides()[0];
 
+    std::fill_n(attention_mask.data<bool>(), attention_mask.get_size(), false);
+
+    for (int current_seq = 0; current_seq < num_seqs; ++current_seq) {
+        std::int32_t context_len = context_lens.data<std::int32_t>()[current_seq];
+        OPENVINO_ASSERT(context_len <= max_context_len);
+
+        bool * attention_mask_data = attention_mask.data<bool>() + current_seq * attention_mask_stride;
+        for (int x = 0; x < context_len; ++x) {
+            for (int y = 0; y < context_len; ++y) {
+                attention_mask_data[x * max_context_len + y] = x >= y;
+            }
+        }
+    }
+    return attention_mask;
 }
 
-std::shared_ptr<ov::Model> make_spda(std::int32_t num_heads, std::uint32_t num_kv_heads, std::uint32_t head_size, float scale) {
-    ov::element::Type_t type = ov::element::f32;
-    auto query = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_heads, head_size}));
-    auto key = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_kv_heads, head_size}));
-    auto value = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_kv_heads, head_size}));
-    auto attention_mask = generate_attention_mask({}); // TODO: fill in the shape
-    auto scale_const = std::make_shared<ov::op::v0::Constant>(type, ov::Shape({1}), scale);
+// similar to torch.Tensor.view
+ov::Tensor view_as_3d(ov::Tensor tensor) {
+    ov::Shape shape = tensor.get_shape();
+    OPENVINO_ASSERT(shape.size() == 4);
+    const std::uint32_t batch_size = shape[0], seq_len = shape[1], num_heads = shape[2], head_size = shape[3];
+    return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads * head_size}), tensor.data());
+}
 
-    auto spda = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, scale_const, false);
-    return std::make_shared<ov::Model>(spda, {query, key, value, attention_mask, scale_const});
+ov::Tensor view_as_4d(ov::Tensor tensor, std::uint32_t num_heads, std::uint32_t head_size) {
+    ov::Shape shape = tensor.get_shape();
+    const std::uint32_t batch_size = shape[0], seq_len = shape[1];
+    OPENVINO_ASSERT(shape.size() == 3 && num_heads * head_size == shape[2]);
+    return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads, head_size}), tensor.data());
 }
 
 bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
     ov::Tensor query = inputs[0], key = inputs[1], value = inputs[2];
-    const std::int32_t batch_size = query.get_shape()[0], seq_len = query.get_shape()[1], hidden_size = query.get_shape()[2];
+    ov::Shape query_shape = query.get_shape();
+    const std::int32_t batch_size = query_shape[0], seq_len = query_shape[1], hidden_size = query_shape[2];
     ov::Tensor key_cache = inputs[3], value_cache = inputs[4];
     const bool is_prompt = inputs[5].data<bool>()[0];
     ov::Tensor slot_mapping = inputs[6];
@@ -170,24 +213,30 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
     ov::Tensor context_lens = inputs[8];
     ov::Tensor block_tables = inputs[9];
 
-    // reshape to [num_seq, num_heads, head_size]
-    query = view(query, m_num_heads, m_head_size);
-    key = view(key, m_num_kv_heads, m_head_size);
-    value = view(value, m_num_kv_heads, m_head_size);
+    // reshape to [batch_size, seq_len, num_heads (or num_kv_heads), head_size] from [batch_size, seq_len, num_heads (or num_kv_heads) * head_size]
+    query = view_as_4d(query, m_num_heads, m_head_size);
+    key = view_as_4d(key, m_num_kv_heads, m_head_size);
+    value = view_as_4d(value, m_num_kv_heads, m_head_size);
 
     // put current K, V values into key_cache and value_cache
     reshape_and_cache(key, value, key_cache, value_cache, slot_mapping);
 
+    // set output shape
+    OPENVINO_ASSERT(outputs.size() == 1);
+    outputs[0].set_shape(query.get_shape());
+
     if (is_prompt) {
-        auto attention_mask = generate_attention_mask(context_lens);
+        auto attention_mask = generate_attention_mask(batch_size, max_context_len, context_lens);
+        ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&m_scale);
 
-        // create a model with OpenVINO SDPA to compute first token
         m_prefill_request.set_input_tensor(0, query);
         m_prefill_request.set_input_tensor(1, key);
         m_prefill_request.set_input_tensor(2, value);
         m_prefill_request.set_input_tensor(3, attention_mask);
+        m_prefill_request.set_input_tensor(4, scale);
+        m_prefill_request.set_output_tensor(outputs[0]);
+
         m_prefill_request.infer();
-        outputs[0] = m_prefill_request.get_output_tensor();
     } else {
         paged_attention_v1_cpu(outputs[0],
                                query, key_cache, value_cache,
@@ -197,7 +246,7 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
     }
 
     // reshape
-    outputs[0] = view(outputs[0], batch_size, seq_len, hidden_size);
+    outputs[0] = view_as_3d(outputs[0]);
 
     return true;
 }

diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
index 597b6cc90..0f05e9f47 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
@@ -56,9 +56,11 @@ class PagedAttention : public ov::op::Op {
     bool has_evaluate() const override;
 
 private:
-    std::uuint32_t m_num_heads, m_num_kv_heads, m_head_size, m_block_size;
+    std::shared_ptr<ov::Model> make_prefill_subgraph();
+
+    std::uint32_t m_num_heads, m_num_kv_heads, m_head_size, m_block_size;
     float m_scale;
-    ov::InferRequest m_prefill_request;
+    mutable ov::InferRequest m_prefill_request;
 };
 
 } // namespace TemplateExtension
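-- 
A reviewer-side illustration, not part of the patch: the standalone sketch below reproduces the masking rule of generate_attention_mask() with made-up values for max_context_len and context_lens. Each sequence gets a lower-triangular context_len x context_len block (row x may attend to columns y <= x), and every position beyond context_len stays false, so padded positions are excluded from SDPA.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        const std::int32_t max_context_len = 4;                 // assumed value
        const std::vector<std::int32_t> context_lens = {2, 4};  // two sequences in the batch

        for (std::size_t seq = 0; seq < context_lens.size(); ++seq) {
            // same fill rule as the patch: true where row x may attend to column y
            std::vector<bool> mask(max_context_len * max_context_len, false);
            for (std::int32_t x = 0; x < context_lens[seq]; ++x)
                for (std::int32_t y = 0; y < context_lens[seq]; ++y)
                    mask[x * max_context_len + y] = x >= y;

            std::cout << "sequence " << seq << ":\n";
            for (std::int32_t x = 0; x < max_context_len; ++x) {
                for (std::int32_t y = 0; y < max_context_len; ++y)
                    std::cout << (mask[x * max_context_len + y] ? '1' : '0');
                std::cout << '\n';
            }
        }
    }

For context_lens = {2, 4} this prints 1000/1100/0000/0000 for the first sequence and the full 4x4 lower triangle for the second, which is the layout the boolean attention_mask input of the prefill subgraph expects.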