From 6a1d67aaf43e08a6c927cdd6d4504d38bf476ced Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Wed, 3 Apr 2024 10:04:55 +0000 Subject: [PATCH 1/3] Op shell for PagedAttentionExtension --- .../src/openvino/runtime/op/__init__.py | 1 + .../graph/ops/paged_attention_extension.cpp | 138 ++++++++++++++++++ .../graph/ops/paged_attention_extension.hpp | 11 ++ .../python/src/pyopenvino/pyopenvino.cpp | 2 + 4 files changed, 152 insertions(+) create mode 100644 src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp create mode 100644 src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.hpp diff --git a/src/bindings/python/src/openvino/runtime/op/__init__.py b/src/bindings/python/src/openvino/runtime/op/__init__.py index a5ae58ad365a20..0f3a15b176b2ca 100644 --- a/src/bindings/python/src/openvino/runtime/op/__init__.py +++ b/src/bindings/python/src/openvino/runtime/op/__init__.py @@ -10,6 +10,7 @@ from openvino._pyopenvino.op import Constant from openvino._pyopenvino.op import assign +from openvino._pyopenvino.op import _PagedAttentionExtension from openvino._pyopenvino.op import Parameter from openvino._pyopenvino.op import if_op from openvino._pyopenvino.op import loop diff --git a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp new file mode 100644 index 00000000000000..c5f18e9a7fc4ae --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp @@ -0,0 +1,138 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/op.hpp" +#include "pyopenvino/graph/ops/paged_attention_extension.hpp" + +namespace py = pybind11; + +namespace { + +// This is an experimental operation that is implemented in the plugins. +// Do not use in user applications, backward compatibility is not guaranteed in future releases. 
+class PagedAttentionExtension : public ov::op::Op { +public: + OPENVINO_OP("PagedAttentionExtension"); + + PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {} + + void validate_and_infer_types() override { + auto value_cache_shape = get_input_partial_shape(4); + // m_num_kv_heads = value_cache_shape[1]; + // m_head_size = value_cache_shape[2]; + // m_block_size = value_cache_shape[3]; + NODE_VALIDATION_CHECK(this, + value_cache_shape.size() == 4, + "Value cache shape must be 4 dims"); + + // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x] + auto key_cache_shape = get_input_partial_shape(3); + NODE_VALIDATION_CHECK(this, + value_cache_shape.size() == 4, + // value_cache_shape[0] == key_cache_shape[0] && // num_blocks + // key_cache_shape[1] == m_num_kv_heads && + // key_cache_shape[2] * key_cache_shape[4] == m_head_size && + // m_block_size == key_cache_shape[3], // block_size, + "Key cache shape must be 4 dims"); + + // query: shape [batch_size, seq_len, num_heads * head_size] + auto query_type = get_input_element_type(0); + auto query_shape = get_input_partial_shape(0); + NODE_VALIDATION_CHECK(this, + // query_type.is_real() && + query_shape.size() == 3, + // query_shape[2] == m_num_heads * m_head_size, + "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ", + "Got element type ", query_type, ", shape ", query_shape); + + // key: shape [batch_size, seq_len, num_kv_heads * head_size] + auto key_type = get_input_element_type(1); + auto key_shape = get_input_partial_shape(1); + NODE_VALIDATION_CHECK(this, + // query_type == key_type && + key_shape.size() == 3, + "Key type must be the same as query, shape must be the same as query. 
" + "Got element type ", key_type, ", shape ", key_shape); + + // value: shape [batch_size, seq_len, num_kv_heads * head_size] + // auto value_type = get_input_element_type(2); + auto value_shape = get_input_partial_shape(2); + + // is_prompt: boolean scalar + NODE_VALIDATION_CHECK(this, + // get_input_element_type(5) == ov::element::boolean && + get_input_shape(5) == ov::Shape({}), + "is_prompt validation failed. ", + "Got element type ", get_input_element_type(5), ", shape ", get_input_shape(5)); + + // slot_mapping: shape [batch_size, max_context_len] + auto slot_mapping_shape = get_input_partial_shape(6); + NODE_VALIDATION_CHECK(this, + // get_input_element_type(6) == ov::element::i64 && + slot_mapping_shape.size() == 2, + "slot_mapping validation failed. ", + "Got element type ", get_input_element_type(6), ", shape ", slot_mapping_shape); + + // max_context_len: integer scalar + NODE_VALIDATION_CHECK(this, + // get_input_element_type(7) == ov::element::i32 && + get_input_shape(7) == ov::Shape({}), + "max_context_len validation failed. ", + "Got element type ", get_input_element_type(7), ", shape ", get_input_shape(7)); + + // context_lens: shape [batch_size] + auto context_lens_shape = get_input_partial_shape(8); + NODE_VALIDATION_CHECK(this, + // get_input_element_type(8) == ov::element::i32 && + context_lens_shape.size() == 1, + "context_lens validation failed. ", + "Got element type ", get_input_element_type(8), ", shape ", context_lens_shape); + + // block_tables: shape [batch_size, max_block_per_request] + NODE_VALIDATION_CHECK(this, + // get_input_element_type(9) == ov::element::i32 && + get_input_partial_shape(9).size() == 2, + "block_tables validation failed. ", + "Got element type ", get_input_element_type(9), ", shape ", get_input_partial_shape(9)); + + // scale: float scalar + NODE_VALIDATION_CHECK(this, + // get_input_element_type(10) == ov::element::f32 && + get_input_shape(10) == ov::Shape({}), + "block_tables validation failed. 
",
+                      "Got element type ", get_input_element_type(10), ", shape ", get_input_shape(10));
+
+        // alibi_slopes: 1D float tensor
+        NODE_VALIDATION_CHECK(this,
+                      // get_input_element_type(11) == ov::element::f32 &&
+                      get_input_partial_shape(11).rank().get_length() == 1,
+                      "alibi_slopes should be a 1D float tensor. ",
+                      "Got element type ", get_input_element_type(11), ", shape ", get_input_partial_shape(11));
+
+        // sliding_window: int scalar
+        NODE_VALIDATION_CHECK(this,
+                      // get_input_element_type(12) == ov::element::i32 &&
+                      get_input_partial_shape(12).rank().get_length() == 0,
+                      "sliding_window argument should be an i32 scalar. ",
+                      "Got element type ", get_input_element_type(12), ", shape ", get_input_partial_shape(12));
+
+        set_output_type(0, query_type, query_shape);
+    }
+
+    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
+        return std::make_shared<PagedAttentionExtension>(new_args);
+    }
+
+    bool has_evaluate() const override {
+        return true;
+    }
+};
+
+}
+
+void regclass_graph_op_PagedAttentionExtension(py::module m) {
+    py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(m, "_PagedAttentionExtension");
+    cls.doc() = "Experimental extention for PagedAttention operation. 
Use with care: no backward compatibility is guaranteed in future releases.";
+    cls.def(py::init<const ov::OutputVector&>());
+}
diff --git a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.hpp b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.hpp
new file mode 100644
index 00000000000000..1c35f8b7ce1eb2
--- /dev/null
+++ b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.hpp
@@ -0,0 +1,11 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+void regclass_graph_op_PagedAttentionExtension(py::module m);
diff --git a/src/bindings/python/src/pyopenvino/pyopenvino.cpp b/src/bindings/python/src/pyopenvino/pyopenvino.cpp
index 70a8dca1ccba70..cc87433f122246 100644
--- a/src/bindings/python/src/pyopenvino/pyopenvino.cpp
+++ b/src/bindings/python/src/pyopenvino/pyopenvino.cpp
@@ -52,6 +52,7 @@
 #include "pyopenvino/graph/ops/constant.hpp"
 #include "pyopenvino/graph/ops/if.hpp"
 #include "pyopenvino/graph/ops/loop.hpp"
+#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
 #include "pyopenvino/graph/ops/parameter.hpp"
 #include "pyopenvino/graph/ops/result.hpp"
 #include "pyopenvino/graph/ops/tensor_iterator.hpp"
@@ -234,6 +235,7 @@ PYBIND11_MODULE(_pyopenvino, m) {
     py::module m_op = m.def_submodule("op", "Package ngraph.impl.op that wraps ov::op");  // TODO(!) 
regclass_graph_op_Assign(m_op); regclass_graph_op_Constant(m_op); + regclass_graph_op_PagedAttentionExtension(m_op); regclass_graph_op_Parameter(m_op); regclass_graph_op_Result(m_op); regclass_graph_op_If(m_op); From 2332579e96838bef4a6c889d7d8a8330fae0e411 Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Wed, 3 Apr 2024 10:40:26 +0000 Subject: [PATCH 2/3] Added evaluate in PA ctor, missing include for pybind type identification --- .../pyopenvino/graph/ops/paged_attention_extension.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp index c5f18e9a7fc4ae..345235d1a79b1e 100644 --- a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp +++ b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp @@ -3,6 +3,8 @@ // #include "openvino/op/op.hpp" + +#include "pyopenvino/core/common.hpp" #include "pyopenvino/graph/ops/paged_attention_extension.hpp" namespace py = pybind11; @@ -15,7 +17,9 @@ class PagedAttentionExtension : public ov::op::Op { public: OPENVINO_OP("PagedAttentionExtension"); - PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {} + PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) { + constructor_validate_and_infer_types(); + } void validate_and_infer_types() override { auto value_cache_shape = get_input_partial_shape(4); @@ -123,10 +127,6 @@ class PagedAttentionExtension : public ov::op::Op { std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override { return std::make_shared(new_args); } - - bool has_evaluate() const override { - return true; - } }; } From 704b7224466b5b8c888a942ca31229af4aa1a0e5 Mon Sep 17 00:00:00 2001 From: Sergey Lyalin Date: Wed, 3 Apr 2024 11:02:28 +0000 Subject: [PATCH 3/3] Code style fixes --- .../graph/ops/paged_attention_extension.cpp 
| 136 +++++++++++------- 1 file changed, 84 insertions(+), 52 deletions(-) diff --git a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp index 345235d1a79b1e..608f4fe2b61a09 100644 --- a/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp +++ b/src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp @@ -2,10 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "openvino/op/op.hpp" +#include "pyopenvino/graph/ops/paged_attention_extension.hpp" +#include "openvino/op/op.hpp" #include "pyopenvino/core/common.hpp" -#include "pyopenvino/graph/ops/paged_attention_extension.hpp" namespace py = pybind11; @@ -26,38 +26,43 @@ class PagedAttentionExtension : public ov::op::Op { // m_num_kv_heads = value_cache_shape[1]; // m_head_size = value_cache_shape[2]; // m_block_size = value_cache_shape[3]; - NODE_VALIDATION_CHECK(this, - value_cache_shape.size() == 4, - "Value cache shape must be 4 dims"); + NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims"); // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x] auto key_cache_shape = get_input_partial_shape(3); NODE_VALIDATION_CHECK(this, - value_cache_shape.size() == 4, - // value_cache_shape[0] == key_cache_shape[0] && // num_blocks - // key_cache_shape[1] == m_num_kv_heads && - // key_cache_shape[2] * key_cache_shape[4] == m_head_size && - // m_block_size == key_cache_shape[3], // block_size, - "Key cache shape must be 4 dims"); + value_cache_shape.size() == 4, + // value_cache_shape[0] == key_cache_shape[0] && // num_blocks + // key_cache_shape[1] == m_num_kv_heads && + // key_cache_shape[2] * key_cache_shape[4] == m_head_size && + // m_block_size == key_cache_shape[3], // block_size, + "Key cache shape must be 4 dims"); // query: shape [batch_size, seq_len, num_heads * head_size] auto query_type = 
get_input_element_type(0); auto query_shape = get_input_partial_shape(0); - NODE_VALIDATION_CHECK(this, + NODE_VALIDATION_CHECK( + this, // query_type.is_real() && query_shape.size() == 3, // query_shape[2] == m_num_heads * m_head_size, "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ", - "Got element type ", query_type, ", shape ", query_shape); + "Got element type ", + query_type, + ", shape ", + query_shape); // key: shape [batch_size, seq_len, num_kv_heads * head_size] auto key_type = get_input_element_type(1); auto key_shape = get_input_partial_shape(1); NODE_VALIDATION_CHECK(this, - // query_type == key_type && - key_shape.size() == 3, - "Key type must be the same as query, shape must be the same as query. " - "Got element type ", key_type, ", shape ", key_shape); + // query_type == key_type && + key_shape.size() == 3, + "Key type must be the same as query, shape must be the same as query. " + "Got element type ", + key_type, + ", shape ", + key_shape); // value: shape [batch_size, seq_len, num_kv_heads * head_size] // auto value_type = get_input_element_type(2); @@ -65,61 +70,85 @@ class PagedAttentionExtension : public ov::op::Op { // is_prompt: boolean scalar NODE_VALIDATION_CHECK(this, - // get_input_element_type(5) == ov::element::boolean && - get_input_shape(5) == ov::Shape({}), - "is_prompt validation failed. ", - "Got element type ", get_input_element_type(5), ", shape ", get_input_shape(5)); + // get_input_element_type(5) == ov::element::boolean && + get_input_shape(5) == ov::Shape({}), + "is_prompt validation failed. ", + "Got element type ", + get_input_element_type(5), + ", shape ", + get_input_shape(5)); // slot_mapping: shape [batch_size, max_context_len] auto slot_mapping_shape = get_input_partial_shape(6); NODE_VALIDATION_CHECK(this, - // get_input_element_type(6) == ov::element::i64 && - slot_mapping_shape.size() == 2, - "slot_mapping validation failed. 
", - "Got element type ", get_input_element_type(6), ", shape ", slot_mapping_shape); + // get_input_element_type(6) == ov::element::i64 && + slot_mapping_shape.size() == 2, + "slot_mapping validation failed. ", + "Got element type ", + get_input_element_type(6), + ", shape ", + slot_mapping_shape); // max_context_len: integer scalar NODE_VALIDATION_CHECK(this, - // get_input_element_type(7) == ov::element::i32 && - get_input_shape(7) == ov::Shape({}), - "max_context_len validation failed. ", - "Got element type ", get_input_element_type(7), ", shape ", get_input_shape(7)); + // get_input_element_type(7) == ov::element::i32 && + get_input_shape(7) == ov::Shape({}), + "max_context_len validation failed. ", + "Got element type ", + get_input_element_type(7), + ", shape ", + get_input_shape(7)); // context_lens: shape [batch_size] auto context_lens_shape = get_input_partial_shape(8); NODE_VALIDATION_CHECK(this, - // get_input_element_type(8) == ov::element::i32 && - context_lens_shape.size() == 1, - "context_lens validation failed. ", - "Got element type ", get_input_element_type(8), ", shape ", context_lens_shape); + // get_input_element_type(8) == ov::element::i32 && + context_lens_shape.size() == 1, + "context_lens validation failed. ", + "Got element type ", + get_input_element_type(8), + ", shape ", + context_lens_shape); // block_tables: shape [batch_size, max_block_per_request] NODE_VALIDATION_CHECK(this, - // get_input_element_type(9) == ov::element::i32 && - get_input_partial_shape(9).size() == 2, - "block_tables validation failed. ", - "Got element type ", get_input_element_type(9), ", shape ", get_input_partial_shape(9)); + // get_input_element_type(9) == ov::element::i32 && + get_input_partial_shape(9).size() == 2, + "block_tables validation failed. 
",
+                              "Got element type ",
+                              get_input_element_type(9),
+                              ", shape ",
+                              get_input_partial_shape(9));
 
         // scale: float scalar
         NODE_VALIDATION_CHECK(this,
-                      // get_input_element_type(10) == ov::element::f32 &&
-                      get_input_shape(10) == ov::Shape({}),
-                      "block_tables validation failed. ",
-                      "Got element type ", get_input_element_type(10), ", shape ", get_input_shape(10));
+                              // get_input_element_type(10) == ov::element::f32 &&
+                              get_input_shape(10) == ov::Shape({}),
+                              "scale validation failed. ",
+                              "Got element type ",
+                              get_input_element_type(10),
+                              ", shape ",
+                              get_input_shape(10));
 
         // alibi_slopes: 1D float tensor
         NODE_VALIDATION_CHECK(this,
-                      // get_input_element_type(11) == ov::element::f32 &&
-                      get_input_partial_shape(11).rank().get_length() == 1,
-                      "alibi_slopes should be a 1D float tensor. ",
-                      "Got element type ", get_input_element_type(11), ", shape ", get_input_partial_shape(11));
+                              // get_input_element_type(11) == ov::element::f32 &&
+                              get_input_partial_shape(11).rank().get_length() == 1,
+                              "alibi_slopes should be a 1D float tensor. ",
+                              "Got element type ",
+                              get_input_element_type(11),
+                              ", shape ",
+                              get_input_partial_shape(11));
 
         // sliding_window: int scalar
         NODE_VALIDATION_CHECK(this,
-                      // get_input_element_type(12) == ov::element::i32 &&
-                      get_input_partial_shape(12).rank().get_length() == 0,
-                      "sliding_window argument should be an i32 scalar. ",
-                      "Got element type ", get_input_element_type(12), ", shape ", get_input_partial_shape(12));
+                              // get_input_element_type(12) == ov::element::i32 &&
+                              get_input_partial_shape(12).rank().get_length() == 0,
+                              "sliding_window argument should be an i32 scalar. 
",
+                              "Got element type ",
+                              get_input_element_type(12),
+                              ", shape ",
+                              get_input_partial_shape(12));
 
         set_output_type(0, query_type, query_shape);
     }
@@ -129,10 +158,13 @@ class PagedAttentionExtension : public ov::op::Op {
     }
 };
 
-}
+}  // namespace
 
 void regclass_graph_op_PagedAttentionExtension(py::module m) {
-    py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(m, "_PagedAttentionExtension");
-    cls.doc() = "Experimental extention for PagedAttention operation. Use with care: no backward compatibility is guaranteed in future releases.";
+    py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
+        m,
+        "_PagedAttentionExtension");
+    cls.doc() = "Experimental extension for PagedAttention operation. Use with care: no backward compatibility is "
+                "guaranteed in future releases.";
     cls.def(py::init<const ov::OutputVector&>());
 }