From 980fabc670b1fd914e8f3771f5110247d143245e Mon Sep 17 00:00:00 2001
From: Ilya Lavrenov
Date: Mon, 12 Feb 2024 05:39:43 +0800
Subject: [PATCH] Passes accuracy

---
 .../paged_attention/paged_attention.cpp       | 93 ++++++-------
 .../paged_attention/paged_attention.hpp       |  7 +-
 2 files changed, 32 insertions(+), 68 deletions(-)

diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
index 986569ebb..4b677e39b 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
@@ -15,7 +15,7 @@

 namespace {

-std::shared_ptr<ov::Model> make_prefill_subgraph(std::size_t num_heads, std::size_t num_kv_heads, std::size_t head_size) {
+std::shared_ptr<ov::Model> make_prefill_subgraph(std::int64_t num_heads = -1, std::int64_t num_kv_heads = -1, std::int64_t head_size = -1) {
     ov::element::Type_t type = ov::element::f32, attention_mask_type = ov::element::f32;
     auto query = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_heads, head_size}));
     auto key = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_kv_heads, head_size}));
@@ -39,9 +39,20 @@

 }

+ov::InferRequest TemplateExtension::PagedAttention::m_prefill_request;
+std::once_flag TemplateExtension::PagedAttention::m_once;
+
 TemplateExtension::PagedAttention::PagedAttention(const ov::OutputVector& inputs)
     : ov::op::Op(inputs) {
     constructor_validate_and_infer_types();
+
+    // compile model for prefill stage
+    std::call_once(m_once, [_this=this] () {
+        ov::Core core;
+        core.register_plugin("/mnt/data3_1878/ilya/Documents/Programming/git_repo/openvino/bin/intel64/Release/libopenvino_intel_cpu_plugin.so", "CPU2");
+        auto compiled_model = core.compile_model(make_prefill_subgraph(), "CPU2");
+        _this->m_prefill_request = compiled_model.create_infer_request();
+    });
 }

 TemplateExtension::PagedAttention::PagedAttention(const ov::Output<ov::Node>& query,
@@ -168,11 +179,6 @@ bool TemplateExtension::PagedAttention::has_evaluate() const {
     return get_input_element_type(0) == ov::element::f32;
 }

-// puts current K, V values into key_cache and value_cache
-void reshape_and_cache(ov::Tensor key, ov::Tensor value,
-                       ov::Tensor key_cache, ov::Tensor value_cache,
-                       ov::Tensor slot_mapping);
-
 // generate bottom diagonal boolean attention mask for a prefill stage
 ov::Tensor generate_attention_mask(const std::size_t batch_size, const std::size_t seq_len) {
     ov::Shape attention_mask_shape({batch_size, 1, seq_len, seq_len});
@@ -182,13 +188,11 @@ ov::Tensor generate_attention_mask(const std::size_t batch_size, const std::siz
     static_assert(std::numeric_limits<float>::is_iec559, "IEEE 754 required");
     float negative_inf = -std::numeric_limits<float>::infinity();

-    std::fill_n(attention_mask.data<float>(), attention_mask.get_size(), 0);
-
     for (int batch_id = 0; batch_id < batch_size; ++batch_id) {
         float * attention_mask_data = attention_mask.data<float>() + batch_id * attention_mask_stride;
-        for (int x = 0; x < seq_len; ++x) {
-            for (int y = 0; y < seq_len; ++y) {
-                attention_mask_data[x * seq_len + y] = x < y ? negative_inf : 0.0f;
+        for (int y = 0; y < seq_len; ++y) {
+            for (int x = 0; x < seq_len; ++x) {
+                attention_mask_data[y * seq_len + x] = x > y ? negative_inf : 0.0f;
             }
         }
     }
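A note on the hunk above: the rewritten loop is behaviorally equivalent to the old one (the indices are renamed so that y is the query/row position and x the key/column position), so this part is cleanup rather than the accuracy fix; it also drops the std::fill_n, which was redundant because the loop assigns every element. For reference, a standalone sketch (plain C++, not part of the patch) that prints the causal mask this loop produces for seq_len = 4 — zeros on and below the diagonal, -inf above it:

// Standalone illustration, independent of the patch: rows are query positions,
// columns are key positions; any key position later than the query position
// is masked with -inf, matching generate_attention_mask() above.
#include <cstdio>
#include <limits>

int main() {
    const int seq_len = 4;
    const float negative_inf = -std::numeric_limits<float>::infinity();
    for (int y = 0; y < seq_len; ++y) {       // query (row) position
        for (int x = 0; x < seq_len; ++x) {   // key (column) position
            std::printf("%6.1f ", x > y ? negative_inf : 0.0f);
        }
        std::printf("\n");
    }
    return 0;
}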
@@ -196,18 +200,6 @@ ov::Tensor generate_attention_mask(const std::size_t batch_size, const std::siz
     return attention_mask;
 }

-void print_tensor(const std::string& title, ov::Tensor tensor) {
-    std::cout << title << std::endl;
-    size_t size = std::min<size_t>(40, tensor.get_size());
-    for (int x = 0; x < size; ++x) {
-        if (tensor.get_element_type() == ov::element::f32)
-            std::cout << tensor.data<float>()[x] << " ";
-        else if (tensor.get_element_type() == ov::element::i32)
-            std::cout << tensor.data<int>()[x] << " ";
-    }
-    std::cout << std::endl;
-}
-
 bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
     ov::Tensor query = inputs[0], key = inputs[1], value = inputs[2];
     ov::Tensor key_cache = inputs[3], value_cache = inputs[4];
@@ -226,84 +218,53 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
     const std::size_t num_kv_heads = value_cache_shape[1], head_size = value_cache_shape[2],
         num_heads = hidden_size / head_size, block_size = value_cache_shape[3];

-    if (!m_prefill_request) {
-        // compile model for prefill stage
-        auto compiled_model = ov::Core().compile_model(make_prefill_subgraph(num_heads, num_kv_heads, head_size), "CPU");
-        m_prefill_request = compiled_model.create_infer_request();
-    }
-
     // reshape to [batch_size * seq_len, m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
-    void * query_data = query.data(), * key_data = key.data(), * value_data = value.data();
     query.set_shape({batch_size * seq_len, num_heads, head_size});
-    OPENVINO_ASSERT(query_data == query.data());
     key.set_shape({batch_size * seq_len, num_kv_heads, head_size});
-    OPENVINO_ASSERT(key_data == key.data());
     value.set_shape(key.get_shape());
-    OPENVINO_ASSERT(value_data == value.data());

     // put current K, V values into key_cache and value_cache
-    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping);
+    reshape_and_cache_cpu(key, value, key_cache, value_cache, slot_mapping);

     // set output shape
     OPENVINO_ASSERT(outputs.size() == 1);
     outputs[0].set_shape(query.get_shape());
-    void * output_data = outputs[0].data();
-
-    // std::cout << "key_cache shape " << key_cache.get_shape() << std::endl;
-    // std::cout << "value_cache shape " << value_cache.get_shape() << std::endl;
-    // std::cout << "num_kv_heads " << num_kv_heads << std::endl;
-    // std::cout << "block_tables shape " << block_tables.get_shape() << std::endl;
-    // std::cout << "context_lens shape " << context_lens.get_shape() << std::endl;
-    // std::cout << "block_size " << block_size << std::endl;
-    // std::cout << "max_context_len " << max_context_len << std::endl;
-    // std::cout << "outputs[0] shape " << outputs[0].get_shape() << std::endl;
-    // std::cout << "num_kv_heads " << num_kv_heads << std::endl;
-    // std::cout << "num_heads " << num_heads << std::endl;
-    // std::cout << "head_size " << head_size << std::endl;
-

     if (is_prompt) {
         // reshape to [batch_size, seq_len, num_kv_heads, head_size]
         query.set_shape({batch_size, seq_len, num_heads, head_size});
-        OPENVINO_ASSERT(query_data == query.data());
-        outputs[0].set_shape(query.get_shape());
-        OPENVINO_ASSERT(output_data == outputs[0].data());
         key.set_shape({batch_size, seq_len, num_kv_heads, head_size});
-        OPENVINO_ASSERT(key_data == key.data());
         value.set_shape(key.get_shape());
-        OPENVINO_ASSERT(value_data == value.data());
+        outputs[0].set_shape(query.get_shape());

         auto attention_mask = generate_attention_mask(batch_size, seq_len);
-        scale = 1.0f; // TODO
-        ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&scale);
+        ov::Tensor scale_tensor(ov::element::f32, ov::Shape{1}, &scale);

         m_prefill_request.set_input_tensor(0, query);
         m_prefill_request.set_input_tensor(1, key);
         m_prefill_request.set_input_tensor(2, value);
         m_prefill_request.set_input_tensor(3, attention_mask);
-        m_prefill_request.set_input_tensor(4, scale);
+        m_prefill_request.set_input_tensor(4, scale_tensor);
         m_prefill_request.set_output_tensor(outputs[0]);

         m_prefill_request.infer();
-    } else {
-        std::fill_n(outputs[0].data<float>(), outputs[0].get_size(), 0.0f);

+        for (const auto& profile_info : m_prefill_request.get_profiling_info()) {
+            std::cout << "node_type " << profile_info.node_type
+                      << ", exec_type = " << profile_info.exec_type
+                      << ", cpu_time = " << profile_info.cpu_time.count()
+                      << ", real_time = " << profile_info.real_time.count() << std::endl;
+        }
+    } else {
         // 'query' and 'output' are expected to be [batch_size * seq_len, m_num_kv_heads, head_size]
         paged_attention_v1_cpu(outputs[0], query,
                                key_cache, value_cache,
                                num_kv_heads, scale,
                                block_tables, context_lens,
                                block_size, max_context_len);
-
-        // print_tensor("query", query);
-        // print_tensor("block_tables", block_tables);
-        // print_tensor("context_lens", context_lens);
-        // print_tensor("output", outputs[0]);
-        // exit(1);
     }

     // reshape back to [batch_size, seq_len, num_heads * head_size]
-    outputs[0].set_shape(query_shape); // works like reshape
-    OPENVINO_ASSERT(output_data == outputs[0].data());
+    outputs[0].set_shape(query_shape);

     return true;
 }
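Two details in the evaluate() hunk above appear to carry the accuracy fix named in the subject line: the scale = 1.0f; // TODO override is removed, so the model's real attention scale now reaches the prefill subgraph through scale_tensor, and the stale forward declaration of reshape_and_cache() removed earlier is replaced at the call site by reshape_and_cache_cpu(). That kernel's body is not part of this patch; the sketch below shows only the conventional vLLM-style semantics it is assumed to implement — scattering each token's K and V vectors into the cache block addressed by slot_mapping — with the cache layout [num_blocks, num_kv_heads, head_size, block_size] inferred from the shape arithmetic above. Function and parameter names here are illustrative:

// Hedged sketch of the assumed reshape_and_cache_cpu() semantics; not the
// actual kernel called by this patch. slot_mapping[token] names a global
// cache slot; block = slot / block_size picks the cache block, and
// offset = slot % block_size picks the row inside that block.
void reshape_and_cache_sketch(const float* key,          // [num_tokens, num_kv_heads, head_size]
                              const float* value,        // same layout as key
                              float* key_cache,          // [num_blocks, num_kv_heads, head_size, block_size] (assumed)
                              float* value_cache,        // same layout as key_cache
                              const int* slot_mapping,   // [num_tokens]
                              int num_tokens, int num_kv_heads,
                              int head_size, int block_size) {
    for (int token = 0; token < num_tokens; ++token) {
        const int slot = slot_mapping[token];
        const int block = slot / block_size;
        const int offset = slot % block_size;
        for (int head = 0; head < num_kv_heads; ++head) {
            for (int i = 0; i < head_size; ++i) {
                const int src = (token * num_kv_heads + head) * head_size + i;
                const int dst = ((block * num_kv_heads + head) * head_size + i) * block_size + offset;
                key_cache[dst] = key[src];
                value_cache[dst] = value[src];
            }
        }
    }
}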
diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
index 788cded5a..835322978 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
@@ -4,6 +4,8 @@

 #pragma once

+#include <mutex>
+
 #include "openvino/op/op.hpp"
 #include "openvino/runtime/infer_request.hpp"
 #include "openvino/frontend/pytorch/extension/op.hpp"
@@ -23,7 +25,7 @@ namespace TemplateExtension {

 class PagedAttention : public ov::op::Op {
 public:
-    OPENVINO_OP("PagedAttentionExtension");
+    OPENVINO_OP("PagedAttentionExtension", "extension");
     OPENVINO_FRAMEWORK_MAP(pytorch, "vllm.model_executor.layers.attention.PagedAttention");

     PagedAttention() = default;
@@ -56,7 +58,8 @@ class PagedAttention : public ov::op::Op {
     bool has_evaluate() const override;

 private:
-    mutable ov::InferRequest m_prefill_request;
+    static ov::InferRequest m_prefill_request;
+    static std::once_flag m_once;
 };

 } // namespace TemplateExtension
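The header change mirrors the new definitions at the top of the .cpp diff: the prefill InferRequest stops being a per-node mutable member and becomes a single static instance, built exactly once under std::call_once no matter how many PagedAttention nodes the model contains. A minimal self-contained sketch of that initialize-once pattern (class and member names are illustrative, not taken from the patch):

#include <iostream>
#include <mutex>
#include <string>

// Illustrative only: all instances share one expensive resource, and
// std::call_once guarantees a single thread performs the setup even when
// many constructors run concurrently.
class SharedResourceOp {
public:
    SharedResourceOp() {
        std::call_once(m_once, [] {
            m_resource = "compiled model";  // stands in for compile_model() + create_infer_request()
            std::cout << "one-time setup ran" << std::endl;
        });
    }
    static const std::string& resource() { return m_resource; }

private:
    static std::string m_resource;
    static std::once_flag m_once;
};

std::string SharedResourceOp::m_resource;
std::once_flag SharedResourceOp::m_once;

int main() {
    SharedResourceOp a, b, c;   // the setup message prints only once
    std::cout << SharedResourceOp::resource() << std::endl;
    return 0;
}

The trade-offs are visible in the patch itself: a single shared request serializes prefill inference across all nodes, and the absolute libopenvino_intel_cpu_plugin.so path registered under the ad-hoc "CPU2" device name reads as local debugging scaffolding that would need generalizing before merge.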