Skip to content

Commit

Permalink
PagedAttention experimental operation (openvinotoolkit#23837)
Browse files Browse the repository at this point in the history
PagedAttention operation exposed in Python API only for easier vLLM
openvino integration. It is not intended to be used outside our
integration work in vLLM and similar applications where we can use
PagedAttention. Exposed as a hidden part of API, will not be documented.
Connected to already existing implementation in CPU plugin.

Operation is not a part of any public opset.
  • Loading branch information
slyalin authored and alvoron committed Apr 29, 2024
1 parent edc152e commit 5c80514
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/bindings/python/src/openvino/runtime/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from openvino._pyopenvino.op import Constant
from openvino._pyopenvino.op import assign
from openvino._pyopenvino.op import _PagedAttentionExtension
from openvino._pyopenvino.op import Parameter
from openvino._pyopenvino.op import if_op
from openvino._pyopenvino.op import loop
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "pyopenvino/graph/ops/paged_attention_extension.hpp"

#include "openvino/op/op.hpp"
#include "pyopenvino/core/common.hpp"

namespace py = pybind11;

namespace {

// This is an experimental operation that is implemented in the plugins.
// Do not use in user applications, backward compatibility is not guaranteed in future releases.
class PagedAttentionExtension : public ov::op::Op {
public:
OPENVINO_OP("PagedAttentionExtension");

PagedAttentionExtension(const ov::OutputVector& args) : ov::op::Op(args) {
constructor_validate_and_infer_types();
}

void validate_and_infer_types() override {
auto value_cache_shape = get_input_partial_shape(4);
// m_num_kv_heads = value_cache_shape[1];
// m_head_size = value_cache_shape[2];
// m_block_size = value_cache_shape[3];
NODE_VALIDATION_CHECK(this, value_cache_shape.size() == 4, "Value cache shape must be 4 dims");

// key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
auto key_cache_shape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
value_cache_shape.size() == 4,
// value_cache_shape[0] == key_cache_shape[0] && // num_blocks
// key_cache_shape[1] == m_num_kv_heads &&
// key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
// m_block_size == key_cache_shape[3], // block_size,
"Key cache shape must be 4 dims");

// query: shape [batch_size, seq_len, num_heads * head_size]
auto query_type = get_input_element_type(0);
auto query_shape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(
this,
// query_type.is_real() &&
query_shape.size() == 3,
// query_shape[2] == m_num_heads * m_head_size,
"Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
"Got element type ",
query_type,
", shape ",
query_shape);

// key: shape [batch_size, seq_len, num_kv_heads * head_size]
auto key_type = get_input_element_type(1);
auto key_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
// query_type == key_type &&
key_shape.size() == 3,
"Key type must be the same as query, shape must be the same as query. "
"Got element type ",
key_type,
", shape ",
key_shape);

// value: shape [batch_size, seq_len, num_kv_heads * head_size]
// auto value_type = get_input_element_type(2);
auto value_shape = get_input_partial_shape(2);

// is_prompt: boolean scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(5) == ov::element::boolean &&
get_input_shape(5) == ov::Shape({}),
"is_prompt validation failed. ",
"Got element type ",
get_input_element_type(5),
", shape ",
get_input_shape(5));

// slot_mapping: shape [batch_size, max_context_len]
auto slot_mapping_shape = get_input_partial_shape(6);
NODE_VALIDATION_CHECK(this,
// get_input_element_type(6) == ov::element::i64 &&
slot_mapping_shape.size() == 2,
"slot_mapping validation failed. ",
"Got element type ",
get_input_element_type(6),
", shape ",
slot_mapping_shape);

// max_context_len: integer scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(7) == ov::element::i32 &&
get_input_shape(7) == ov::Shape({}),
"max_context_len validation failed. ",
"Got element type ",
get_input_element_type(7),
", shape ",
get_input_shape(7));

// context_lens: shape [batch_size]
auto context_lens_shape = get_input_partial_shape(8);
NODE_VALIDATION_CHECK(this,
// get_input_element_type(8) == ov::element::i32 &&
context_lens_shape.size() == 1,
"context_lens validation failed. ",
"Got element type ",
get_input_element_type(8),
", shape ",
context_lens_shape);

// block_tables: shape [batch_size, max_block_per_request]
NODE_VALIDATION_CHECK(this,
// get_input_element_type(9) == ov::element::i32 &&
get_input_partial_shape(9).size() == 2,
"block_tables validation failed. ",
"Got element type ",
get_input_element_type(9),
", shape ",
get_input_partial_shape(9));

// scale: float scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(10) == ov::element::f32 &&
get_input_shape(10) == ov::Shape({}),
"block_tables validation failed. ",
"Got element type ",
get_input_element_type(10),
", shape ",
get_input_shape(10));

// alibi_slopes: 1D float tensor
NODE_VALIDATION_CHECK(this,
// get_input_element_type(11) == ov::element::f32 &&
get_input_partial_shape(11).rank().get_length() == 1,
"alibi_slopes should be a 1D float tensor. ",
"Got element type ",
get_input_element_type(11),
", shape ",
get_input_partial_shape(11));

// sliding_window: int scalar
NODE_VALIDATION_CHECK(this,
// get_input_element_type(12) == ov::element::i32 &&
get_input_partial_shape(12).rank().get_length() == 0,
"sliding_window argument should be an i32 scalar. ",
"Got element type ",
get_input_element_type(12),
", shape ",
get_input_partial_shape(12));

set_output_type(0, query_type, query_shape);
}

std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
return std::make_shared<PagedAttentionExtension>(new_args);
}
};

} // namespace

void regclass_graph_op_PagedAttentionExtension(py::module m) {
py::class_<PagedAttentionExtension, std::shared_ptr<PagedAttentionExtension>, ov::Node> cls(
m,
"_PagedAttentionExtension");
cls.doc() = "Experimental extention for PagedAttention operation. Use with care: no backward compatibility is "
"guaranteed in future releases.";
cls.def(py::init<const ov::OutputVector&>());
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

void regclass_graph_op_PagedAttentionExtension(py::module m);
2 changes: 2 additions & 0 deletions src/bindings/python/src/pyopenvino/pyopenvino.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "pyopenvino/graph/ops/constant.hpp"
#include "pyopenvino/graph/ops/if.hpp"
#include "pyopenvino/graph/ops/loop.hpp"
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
#include "pyopenvino/graph/ops/parameter.hpp"
#include "pyopenvino/graph/ops/result.hpp"
#include "pyopenvino/graph/ops/tensor_iterator.hpp"
Expand Down Expand Up @@ -234,6 +235,7 @@ PYBIND11_MODULE(_pyopenvino, m) {
py::module m_op = m.def_submodule("op", "Package ngraph.impl.op that wraps ov::op"); // TODO(!)
regclass_graph_op_Assign(m_op);
regclass_graph_op_Constant(m_op);
regclass_graph_op_PagedAttentionExtension(m_op);
regclass_graph_op_Parameter(m_op);
regclass_graph_op_Result(m_op);
regclass_graph_op_If(m_op);
Expand Down

0 comments on commit 5c80514

Please sign in to comment.