PagedAttention experimental operation #23837

Merged
1 change: 1 addition & 0 deletions src/bindings/python/src/openvino/runtime/op/__init__.py
@@ -10,6 +10,7 @@

from openvino._pyopenvino.op import Constant
from openvino._pyopenvino.op import assign
from openvino._pyopenvino.op import _PagedAttentionExtension
from openvino._pyopenvino.op import Parameter
from openvino._pyopenvino.op import if_op
from openvino._pyopenvino.op import loop
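With this line in the diff, the experimental class becomes importable from Python; a minimal sketch, assuming a build that contains this PR:

from openvino.runtime.op import _PagedAttentionExtension  # leading underscore: private, experimental API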
170 changes: 170 additions & 0 deletions src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.cpp
@@ -0,0 +1,170 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "pyopenvino/graph/ops/paged_attention_extension.hpp"

#include "openvino/op/op.hpp"
#include "pyopenvino/core/common.hpp"

namespace py = pybind11;

namespace {

// This is an experimental operation that is implemented in the plugins.
// Do not use in user applications; backward compatibility is not guaranteed in future releases.
class PagedAttentionExtension : public ov::op::Op {
public:
    OPENVINO_OP("PagedAttentionExtension");

    PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {
        constructor_validate_and_infer_types();
    }

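    // Expected inputs by index, as read from the checks and comments below
    // (names here are descriptive only; the op itself is positional):
    //   0: query            [batch_size, seq_len, num_heads * head_size]
    //   1: key              [batch_size, seq_len, num_kv_heads * head_size]
    //   2: value            [batch_size, seq_len, num_kv_heads * head_size]
    //   3: key_cache        [num_blocks, num_kv_heads, head_size/x, block_size, x]
    //   4: value_cache      [num_blocks, num_kv_heads, head_size, block_size]
    //   5: is_prompt        boolean scalar
    //   6: slot_mapping     [batch_size, max_context_len]
    //   7: max_context_len  integer scalar
    //   8: context_lens     [batch_size]
    //   9: block_tables     [batch_size, max_block_per_request]
    //  10: scale            float scalar
    //  11: alibi_slopes     1D float tensor
    //  12: sliding_window   integer scalar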
    void validate_and_infer_types() override {
        auto value_cache_shape = get_input_partial_shape(4);
        // m_num_kv_heads = value_cache_shape[1];
        // m_head_size = value_cache_shape[2];
        // m_block_size = value_cache_shape[3];
        NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");

        // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
        auto key_cache_shape = get_input_partial_shape(3);
        NODE_VALIDATION_CHECK(this,
                              key_cache_shape.size() == 4,
                              // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
                              // key_cache_shape[1] == m_num_kv_heads &&
                              // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
                              // m_block_size == key_cache_shape[3], // block_size,
                              "Key cache shape must be 4 dims");

        // query: shape [batch_size, seq_len, num_heads * head_size]
        auto query_type = get_input_element_type(0);
        auto query_shape = get_input_partial_shape(0);
        NODE_VALIDATION_CHECK(
            this,
            // query_type.is_real() &&
            query_shape.size() == 3,
            // query_shape[2] == m_num_heads * m_head_size,
            "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
            "Got element type ",
            query_type,
            ", shape ",
            query_shape);

        // key: shape [batch_size, seq_len, num_kv_heads * head_size]
        auto key_type = get_input_element_type(1);
        auto key_shape = get_input_partial_shape(1);
        NODE_VALIDATION_CHECK(this,
                              // query_type == key_type &&
                              key_shape.size() == 3,
                              "Key type must be the same as query, shape must be the same as query. "
                              "Got element type ",
                              key_type,
                              ", shape ",
                              key_shape);

        // value: shape [batch_size, seq_len, num_kv_heads * head_size]
        // auto value_type = get_input_element_type(2);
        auto value_shape = get_input_partial_shape(2);
        NODE_VALIDATION_CHECK(this,
                              value_shape.size() == 3,
                              "Value shape must be like [batch_size, seq_len, num_kv_heads * head_size]. ",
                              "Got shape ",
                              value_shape);

        // is_prompt: boolean scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(5) == ov::element::boolean &&
                              get_input_shape(5) == ov::Shape({}),
                              "is_prompt validation failed. ",
                              "Got element type ",
                              get_input_element_type(5),
                              ", shape ",
                              get_input_shape(5));

        // slot_mapping: shape [batch_size, max_context_len]
        auto slot_mapping_shape = get_input_partial_shape(6);
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(6) == ov::element::i64 &&
                              slot_mapping_shape.size() == 2,
                              "slot_mapping validation failed. ",
                              "Got element type ",
                              get_input_element_type(6),
                              ", shape ",
                              slot_mapping_shape);

        // max_context_len: integer scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(7) == ov::element::i32 &&
                              get_input_shape(7) == ov::Shape({}),
                              "max_context_len validation failed. ",
                              "Got element type ",
                              get_input_element_type(7),
                              ", shape ",
                              get_input_shape(7));

        // context_lens: shape [batch_size]
        auto context_lens_shape = get_input_partial_shape(8);
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(8) == ov::element::i32 &&
                              context_lens_shape.size() == 1,
                              "context_lens validation failed. ",
                              "Got element type ",
                              get_input_element_type(8),
                              ", shape ",
                              context_lens_shape);

        // block_tables: shape [batch_size, max_block_per_request]
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(9) == ov::element::i32 &&
                              get_input_partial_shape(9).size() == 2,
                              "block_tables validation failed. ",
                              "Got element type ",
                              get_input_element_type(9),
                              ", shape ",
                              get_input_partial_shape(9));

        // scale: float scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(10) == ov::element::f32 &&
                              get_input_shape(10) == ov::Shape({}),
                              "scale validation failed. ",
                              "Got element type ",
                              get_input_element_type(10),
                              ", shape ",
                              get_input_shape(10));

        // alibi_slopes: 1D float tensor
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(11) == ov::element::f32 &&
                              get_input_partial_shape(11).rank().get_length() == 1,
                              "alibi_slopes should be a 1D float tensor. ",
                              "Got element type ",
                              get_input_element_type(11),
                              ", shape ",
                              get_input_partial_shape(11));

        // sliding_window: int scalar
        NODE_VALIDATION_CHECK(this,
                              // get_input_element_type(12) == ov::element::i32 &&
                              get_input_partial_shape(12).rank().get_length() == 0,
                              "sliding_window argument should be an i32 scalar. ",
                              "Got element type ",
                              get_input_element_type(12),
                              ", shape ",
                              get_input_partial_shape(12));

        set_output_type(0, query_type, query_shape);
    }

    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
        return std::make_shared<PagedAttentionExtension>(new_args);
    }
};

} // namespace

void regclass_graph_op_PagedAttentionExtension(py::module m) {
    py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
        m,
        "_PagedAttentionExtension");
    cls.doc() = "Experimental extension for PagedAttention operation. Use with care: no backward compatibility is "
                "guaranteed in future releases.";
    cls.def(py::init<const ov::OutputVector&>());
}
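For context, a hedged Python sketch of how the new binding could be exercised once this lands. The 13 inputs, their order, and their ranks follow the checks in validate_and_infer_types above; the concrete sizes (4096, 32, 128, 16) are made-up placeholders, not values from this PR.

# Hypothetical usage sketch, not a supported API; all tensor sizes are invented.
import openvino as ov
from openvino.runtime import opset13 as ops
from openvino.runtime.op import _PagedAttentionExtension

# 13 parameters, in the order the C++ validation expects (indices 0-12).
query           = ops.parameter([-1, -1, 4096], ov.Type.f32)  # [batch, seq_len, num_heads*head_size]
key             = ops.parameter([-1, -1, 4096], ov.Type.f32)
value           = ops.parameter([-1, -1, 4096], ov.Type.f32)
key_cache       = ops.parameter([-1, 32, 128, 16], ov.Type.f32)  # rank-4, per the check above
value_cache     = ops.parameter([-1, 32, 128, 16], ov.Type.f32)
is_prompt       = ops.parameter([], ov.Type.boolean)             # scalar
slot_mapping    = ops.parameter([-1, -1], ov.Type.i64)
max_context_len = ops.parameter([], ov.Type.i32)
context_lens    = ops.parameter([-1], ov.Type.i32)
block_tables    = ops.parameter([-1, -1], ov.Type.i32)
scale           = ops.parameter([], ov.Type.f32)
alibi_slopes    = ops.parameter([32], ov.Type.f32)               # 1D float tensor
sliding_window  = ops.parameter([], ov.Type.i32)

inputs = [query, key, value, key_cache, value_cache, is_prompt, slot_mapping,
          max_context_len, context_lens, block_tables, scale, alibi_slopes,
          sliding_window]
pa = _PagedAttentionExtension([p.output(0) for p in inputs])
print(pa.get_output_partial_shape(0))  # mirrors the query shape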
11 changes: 11 additions & 0 deletions src/bindings/python/src/pyopenvino/graph/ops/paged_attention_extension.hpp
@@ -0,0 +1,11 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

void regclass_graph_op_PagedAttentionExtension(py::module m);
2 changes: 2 additions & 0 deletions src/bindings/python/src/pyopenvino/pyopenvino.cpp
@@ -52,6 +52,7 @@
#include "pyopenvino/graph/ops/constant.hpp"
#include "pyopenvino/graph/ops/if.hpp"
#include "pyopenvino/graph/ops/loop.hpp"
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
#include "pyopenvino/graph/ops/parameter.hpp"
#include "pyopenvino/graph/ops/result.hpp"
#include "pyopenvino/graph/ops/tensor_iterator.hpp"
@@ -234,6 +235,7 @@ PYBIND11_MODULE(_pyopenvino, m) {
    py::module m_op = m.def_submodule("op", "Package ngraph.impl.op that wraps ov::op");  // TODO(!)
    regclass_graph_op_Assign(m_op);
    regclass_graph_op_Constant(m_op);
    regclass_graph_op_PagedAttentionExtension(m_op);
    regclass_graph_op_Parameter(m_op);
    regclass_graph_op_Result(m_op);
    regclass_graph_op_If(m_op);