Commit

Finalize PagedAttention implementation

ilya-lavrenov committed Feb 5, 2024
1 parent b340cb3 commit b6c35d6
Showing 2 changed files with 78 additions and 28 deletions.
@@ -6,20 +6,43 @@

#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/constants.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/runtime/core.hpp"

#include "cpu_ops.hpp"

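+// Build a small ov::Model that computes standard scaled-dot-product attention for the
+// prefill (prompt) phase. Inputs arrive as [batch, seq_len, heads, head_size]; Q, K and V
+// are transposed to [batch, heads, seq_len, head_size] for SDPA, and the output is
+// transposed back before being returned.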
+std::shared_ptr<ov::Model> TemplateExtension::PagedAttention::make_prefill_subgraph() {
+    ov::element::Type_t type = ov::element::f32, attention_mask_type = ov::element::boolean;
+    auto query = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_heads, m_head_size}));
+    auto key = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_kv_heads, m_head_size}));
+    auto value = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_kv_heads, m_head_size}));
+    auto attention_mask = std::make_shared<ov::op::v0::Parameter>(attention_mask_type, ov::PartialShape({-1, -1, -1}));
+    auto scale = std::make_shared<ov::op::v0::Parameter>(type, ov::Shape({1}));
+
+    // transpose Q, K and V to swap num_heads and seq_len dimensions
+    auto permute_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape({4}), {0, 2, 1, 3});
+    auto query_transposed = std::make_shared<ov::op::v1::Transpose>(query, permute_const);
+    auto key_transposed = std::make_shared<ov::op::v1::Transpose>(key, permute_const);
+    auto value_transposed = std::make_shared<ov::op::v1::Transpose>(value, permute_const);
+
+    auto spda = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query_transposed, key_transposed, value_transposed, attention_mask, scale, false);
+
+    // transpose SDPA output back to [batch, seq_len, num_heads, head_size]
+    auto spda_transposed = std::make_shared<ov::op::v1::Transpose>(spda, permute_const);
+
+    return std::make_shared<ov::Model>(spda_transposed, ov::ParameterVector{query, key, value, attention_mask, scale}, "spda_prefill_model");
+}

TemplateExtension::PagedAttention::PagedAttention(const ov::OutputVector& inputs,
                                                  const float scale)
    : ov::op::Op(inputs),
      m_scale(scale) {
    constructor_validate_and_infer_types();

    // compile model for prefill stage
-    auto model = make_spda(m_num_heads, m_num_kv_heads, m_head_size, m_scale);
-    auto compiled_model = ov::Core().compile_model(model, "CPU");
+    auto compiled_model = ov::Core().compile_model(make_prefill_subgraph(), "CPU");
    m_prefill_request = compiled_model.create_infer_request();
}

@@ -140,54 +163,79 @@ void reshape_and_cache(ov::Tensor key, ov::Tensor value,
                       ov::Tensor key_cache, ov::Tensor value_cache,
                       ov::Tensor slot_mapping);

-// generate block diagonal attention mask for a prefill stage
-ov::Tensor generate_attention_mask(ov::Tensor context_lens);
+// generate a lower-triangular (causal) boolean attention mask for the prefill stage
+ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32_t max_context_len, ov::Tensor context_lens) {
+    OPENVINO_ASSERT(num_seqs == context_lens.get_size());
+
+    ov::Shape attention_mask_shape({num_seqs, max_context_len, max_context_len});
+    ov::Tensor attention_mask(ov::element::boolean, attention_mask_shape);
+    int attention_mask_stride = attention_mask.get_strides()[0];
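+    // get_strides() reports byte strides; for a boolean tensor (1 byte per element)
+    // this equals the per-sequence element stride of max_context_len * max_context_len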

+    std::fill_n(attention_mask.data<bool>(), attention_mask.get_size(), false);

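+    // within each sequence, token x may attend to token y iff y <= x (causal mask)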
+    for (int current_seq = 0; current_seq < num_seqs; ++current_seq) {
+        std::int32_t context_len = context_lens.data<std::int32_t>()[current_seq];
+        OPENVINO_ASSERT(context_len <= max_context_len);
+
+        bool * attention_mask_data = attention_mask.data<bool>() + current_seq * attention_mask_stride;
+        for (int x = 0; x < context_len; ++x) {
+            for (int y = 0; y < context_len; ++y) {
+                attention_mask_data[x * max_context_len + y] = x >= y;
+            }
+        }
+    }
+
+    return attention_mask;
+}
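+// e.g. the mask above for max_context_len = 4 and context_lens = {3} (sequence 0):
+//   1 0 0 0
+//   1 1 0 0
+//   1 1 1 0
+//   0 0 0 0   (row 3 stays false because context_len == 3)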

-std::shared_ptr<ov::Model> make_spda(std::int32_t num_heads, std::uint32_t num_kv_heads, std::uint32_t head_size, float scale) {
-    ov::element::Type_t type = ov::element::f32;
-    auto query = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_heads, head_size}));
-    auto key = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_kv_heads, head_size}));
-    auto value = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_kv_heads, head_size}));
-    auto attention_mask = generate_attention_mask({}); // TODO: fill in the shape
-    auto scale_const = std::make_shared<ov::op::v0::Constant>(type, ov::Shape({1}), scale);
-    auto spda = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query, key, value, attention_mask, scale_const, false);
-    return std::make_shared<ov::Model>(spda, {query, key, value, attention_mask, scale_const});
-}
-
-ov::Tensor view(ov::Tensor tensor, std::uint32_t num_heads, std::uint32_t head_size) {
-    const std::uint32_t num_seqs = tensor.get_size() / (num_heads * head_size);
-    return ov::Tensor(tensor.get_element_type(), ov::Shape({num_seqs, num_heads, head_size}), tensor.data());
-}
+// similar to torch.Tensor.view: reinterpret the same buffer with a new shape, no copy
+ov::Tensor view_as_3d(ov::Tensor tensor) {
+    ov::Shape shape = tensor.get_shape();
+    OPENVINO_ASSERT(shape.size() == 4);
+    const std::uint32_t batch_size = shape[0], seq_len = shape[1], num_heads = shape[2], head_size = shape[3];
+    return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads * head_size}), tensor.data());
+}
+
+ov::Tensor view_as_4d(ov::Tensor tensor, std::uint32_t num_heads, std::uint32_t head_size) {
+    ov::Shape shape = tensor.get_shape();
+    OPENVINO_ASSERT(shape.size() == 3 && num_heads * head_size == shape[2]);
+    const std::uint32_t batch_size = shape[0], seq_len = shape[1];
+    return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads, head_size}), tensor.data());
+}
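+// e.g. view_as_4d reshapes a [batch, seq_len, 1024] tensor to [batch, seq_len, 16, 64]
+// for 16 heads of size 64; view_as_3d inverts this, both reusing the original buffer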

bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
    ov::Tensor query = inputs[0], key = inputs[1], value = inputs[2];
-    const std::int32_t batch_size = query.get_shape()[0], seq_len = query.get_shape()[1], hidden_size = query.get_shape()[2];
+    ov::Shape query_shape = query.get_shape();
+    const std::int32_t batch_size = query_shape[0], seq_len = query_shape[1], hidden_size = query_shape[2];
    ov::Tensor key_cache = inputs[3], value_cache = inputs[4];
    const bool is_prompt = inputs[5].data<bool>()[0];
    ov::Tensor slot_mapping = inputs[6];
    const std::int32_t max_context_len = inputs[7].data<std::int32_t>()[0];
    ov::Tensor context_lens = inputs[8];
    ov::Tensor block_tables = inputs[9];

-    // reshape to [num_seq, num_heads, head_size]
-    query = view(query, m_num_heads, m_head_size);
-    key = view(key, m_num_kv_heads, m_head_size);
-    value = view(value, m_num_kv_heads, m_head_size);
+    // reshape from [batch_size, seq_len, num_heads * head_size] to [batch_size, seq_len, num_heads, head_size]
+    query = view_as_4d(query, m_num_heads, m_head_size);
+    key = view_as_4d(key, m_num_kv_heads, m_head_size);
+    value = view_as_4d(value, m_num_kv_heads, m_head_size);

    // put current K, V values into key_cache and value_cache
    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping);

    // set output shape
    OPENVINO_ASSERT(outputs.size() == 1);
    outputs[0].set_shape(query.get_shape());

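+    // prompt (prefill) phase: run dense SDPA over the whole prompt with a causal mask;
+    // otherwise (decode) attend against the paged KV cache, one new token per sequence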
    if (is_prompt) {
-        auto attention_mask = generate_attention_mask(context_lens);
+        auto attention_mask = generate_attention_mask(batch_size, max_context_len, context_lens);
        ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&m_scale);

        // run the compiled OpenVINO SDPA model to compute the first token
        m_prefill_request.set_input_tensor(0, query);
        m_prefill_request.set_input_tensor(1, key);
        m_prefill_request.set_input_tensor(2, value);
        m_prefill_request.set_input_tensor(3, attention_mask);
        m_prefill_request.set_input_tensor(4, scale);
        m_prefill_request.set_output_tensor(outputs[0]);

        m_prefill_request.infer();
        outputs[0] = m_prefill_request.get_output_tensor();
    } else {
        paged_attention_v1_cpu(outputs[0],
                               query, key_cache, value_cache,
@@ -197,7 +245,7 @@
    }

    // reshape back to [batch_size, seq_len, hidden_size]
-    outputs[0] = view(outputs[0], batch_size, seq_len, hidden_size);
+    outputs[0] = view_as_3d(outputs[0]);

    return true;
}
@@ -56,9 +56,11 @@ class PagedAttention : public ov::op::Op {
    bool has_evaluate() const override;

private:
-    std::uuint32_t m_num_heads, m_num_kv_heads, m_head_size, m_block_size;
+    std::shared_ptr<ov::Model> make_prefill_subgraph();
+
+    std::uint32_t m_num_heads, m_num_kv_heads, m_head_size, m_block_size;
    float m_scale;
-    ov::InferRequest m_prefill_request;
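+    // mutable: the op's evaluate() is const, but running inference mutates the request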
+    mutable ov::InferRequest m_prefill_request;
};

} // namespace TemplateExtension

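For reference, a minimal sketch of the input contract that evaluate() above assumes, as read from the code; the slot_mapping and block_tables descriptions follow the usual PagedAttention scheme and are assumptions, not stated in this commit:

    // inputs, in the order evaluate() reads them:
    //   0: query            f32 [batch_size, seq_len, num_heads * head_size]
    //   1: key              f32 [batch_size, seq_len, num_kv_heads * head_size]
    //   2: value            f32 [batch_size, seq_len, num_kv_heads * head_size]
    //   3: key_cache        paged cache blocks for keys
    //   4: value_cache      paged cache blocks for values
    //   5: is_prompt        boolean scalar; true selects the SDPA prefill path
    //   6: slot_mapping     per-token slot indices into the paged cache (assumed)
    //   7: max_context_len  i32 scalar
    //   8: context_lens     i32 [num_seqs], one context length per sequence
    //   9: block_tables     per-sequence lists of cache block indices (assumed, decode path)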