diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/attention_impl.cpp b/modules/custom_operations/user_ie_extensions/paged_attention/attention_impl.cpp
index 826a19066..cafd10bd5 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/attention_impl.cpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/attention_impl.cpp
@@ -3,6 +3,131 @@
 namespace {
 
+// template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
+// struct paged_attention_v1_impl {
+//   static void
+//   call(scalar_t *__restrict__ out,           // [num_seqs, num_heads, head_size]
+//        const scalar_t *__restrict__ q,       // [num_seqs, num_heads, head_size]
+//        const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
+//                                              //  head_size/x, block_size, x]
+//        const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
+//                                              //  head_size, block_size]
+//        const int num_kv_heads, const float scale,
+//        const int
+//            *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+//        const int *__restrict__ context_lens, // [num_seqs]
+//        const int max_num_blocks_per_seq,
+//        const int q_stride, const int kv_block_stride, const int kv_head_stride,
+//        const int num_seqs, const int num_heads) {
+//     OPENVINO_ASSERT(HEAD_SIZE % 16 == 0);
+//     constexpr int x = 16 / sizeof(scalar_t);
+//     const int num_queries_per_kv = num_heads / num_kv_heads;
+
+//     int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
+//     int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
+//     OPENVINO_ASSERT((max_context_len_padded * sizeof(float)) % 64 == 0);
+
+//     size_t logits_bytes = num_heads * max_context_len_padded * sizeof(float);
+//     float *logits = (float *)std::aligned_alloc(
+//         64, logits_bytes); // Cacheline alignment for each context token.
+//                            // [head_num, max_context_len_padded]
+
+//     std::memset(out, 0, num_seqs * num_heads * HEAD_SIZE * sizeof(scalar_t));
+
+//     for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
+//       int context_len = context_lens[seq_idx];
+//       const int *seq_block_table =
+//           block_tables + max_num_blocks_per_seq * seq_idx;
+//       const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
+//       std::memset(logits, 0, logits_bytes);
+
+//       // Compute attention logits
+// #pragma omp parallel for collapse(2)
+//       for (int block_idx = 0; block_idx < block_num; ++block_idx) {
+//         for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+//           const int64_t kv_head_idx = head_idx / num_queries_per_kv;
+//           const int64_t physical_block_idx = seq_block_table[block_idx];
+//           const scalar_t *__restrict__ q_vec_ptr =
+//               q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+//           const scalar_t *__restrict__ k_block_cache_ptr =
+//               k_cache + physical_block_idx * kv_block_stride +
+//               kv_head_idx * kv_head_stride;
+//           float *__restrict__ head_block_logits =
+//               logits + head_idx * max_context_len_padded +
+//               block_idx * BLOCK_SIZE;
+
+//           for (int q_offset = 0; q_offset < HEAD_SIZE;
+//                q_offset += x, q_vec_ptr += x) {
+//             for (int token_idx = 0; token_idx < BLOCK_SIZE;
+//                  ++token_idx, k_block_cache_ptr += x) {
+//               for (int i = 0; i < x; ++i) {
+//                 head_block_logits[token_idx] +=
+//                     q_vec_ptr[i] * k_block_cache_ptr[i] * scale;
+//               }
+//             }
+//           }
+//         }
+//       }
+
+//       // Compute softmax
+// #pragma omp parallel for
+//       for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+//         float *head_logit_ptr = logits + head_idx * max_context_len_padded;
+//         float max_logit = head_logit_ptr[0];
+//         for (int i = 1; i < context_len; ++i) {
+//           max_logit =
+//               max_logit >= head_logit_ptr[i] ? max_logit : head_logit_ptr[i];
+//         }
+
+//         float sum = 0;
+//         for (int i = 0; i < context_len; ++i) {
+//           head_logit_ptr[i] = std::exp(head_logit_ptr[i] - max_logit);
+//           sum += head_logit_ptr[i];
+//         }
+
+//         for (int i = 0; i < context_len; ++i) {
+//           head_logit_ptr[i] /= sum;
+//         }
+
+//         int remaining_seq_upper = block_num * BLOCK_SIZE;
+//         for (int i = context_len; i < remaining_seq_upper; ++i) {
+//           head_logit_ptr[i] = 0;
+//         }
+//       }
+
+//       // Compute value
+//       constexpr int head_partition_num = HEAD_SIZE / 16;
+// #pragma omp parallel for collapse(2)
+//       for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
+//         for (int head_part_idx = 0; head_part_idx < head_partition_num;
+//              ++head_part_idx) {
+//           for (int block_idx = 0; block_idx < block_num; ++block_idx) {
+//             const int64_t kv_head_idx = head_idx / num_queries_per_kv;
+//             const int64_t physical_block_idx = seq_block_table[block_idx];
+//             const float *__restrict__ prob_vec_ptr =
+//                 logits + head_idx * max_context_len_padded +
+//                 block_idx * BLOCK_SIZE;
+//             const scalar_t *__restrict__ v_block_cache_ptr =
+//                 v_cache + physical_block_idx * kv_block_stride +
+//                 kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * 16;
+//             scalar_t *__restrict__ out_ptr =
+//                 out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
+//                 head_part_idx * 16;
+
+//             for (int i = 0; i < 16; ++i, v_block_cache_ptr += BLOCK_SIZE) {
+//               for (int j = 0; j < BLOCK_SIZE; ++j) {
+//                 out_ptr[i] += prob_vec_ptr[j] * v_block_cache_ptr[j];
+//               }
+//             }
+//           }
+//         }
+//       }
+//     }
+//     std::free(logits);
+//   }
+// };
+
 
 template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
 struct paged_attention_v1_impl {
   static void
@@ -19,13 +144,14 @@ struct paged_attention_v1_impl {
        const int max_num_blocks_per_seq,
        const int q_stride, const int kv_block_stride, const int kv_head_stride,
        const int num_seqs, const int num_heads) {
-    OPENVINO_ASSERT(HEAD_SIZE % 16 == 0);
+    // TORCH_CHECK(HEAD_SIZE % 16 == 0);
+    // TORCH_CHECK(alibi_slopes == nullptr, "Unsupport alibi_slopes for CPU");
     constexpr int x = 16 / sizeof(scalar_t);
     const int num_queries_per_kv = num_heads / num_kv_heads;
 
     int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
     int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
-    OPENVINO_ASSERT((max_context_len_padded * sizeof(float)) % 64 == 0);
+    // TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
 
     size_t logits_bytes = num_heads * max_context_len_padded * sizeof(float);
     float *logits = (float *)std::aligned_alloc(
@@ -69,6 +195,11 @@ struct paged_attention_v1_impl {
       }
     }
 
+    // std::cout << std::endl;
+    // for (int i = 0; i < 40; ++i)
+    //   std::cout << logits[i] << " ";
+    // exit(1);
+
     // Compute softmax
 #pragma omp parallel for
     for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
@@ -127,179 +258,6 @@ struct paged_attention_v1_impl {
   }
 };
 
-// template <int HEAD_SIZE, int BLOCK_SIZE>
-// struct paged_attention_v1_impl<c10::BFloat16, HEAD_SIZE, BLOCK_SIZE> {
-//   using scalar_t = c10::BFloat16;
-
-//   static void
-//   call(scalar_t *__restrict__ out,           // [num_seqs, num_heads, head_size]
-//        const scalar_t *__restrict__ q,       // [num_seqs, num_heads, head_size]
-//        const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
-//                                              //  head_size/x, block_size, x]
-//        const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
-//                                              //  head_size, block_size]
-//        const int num_kv_heads, const float scale,
-//        const int
-//            *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-//        const int *__restrict__ context_lens, // [num_seqs]
-//        const int max_num_blocks_per_seq,
-//        const float *__restrict__ alibi_slopes, // [num_heads]
-//        const int q_stride, const int kv_block_stride, const int kv_head_stride,
-//        const int num_seqs, const int num_heads) {
-//     OPENVINO_ASSERT(alibi_slopes == nullptr, "Unsupport alibi_slopes for CPU");
-//     constexpr int x = 16 / sizeof(scalar_t);
-//     const int num_queries_per_kv = num_heads / num_kv_heads;
-
-//     using scalar_vec_t = vec_op::vec_t<scalar_t>;
-//     constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
-
-//     static_assert(x == VEC_ELEM_NUM);
-//     static_assert(BLOCK_SIZE == 16);
-//     static_assert(BLOCK_SIZE % VEC_ELEM_NUM == 0);
-
-//     int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
-//     int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
-//     TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
-
-//     const int parallel_work_item_num = omp_get_max_threads();
-
-//     size_t logits_bytes =
-//         parallel_work_item_num * max_context_len_padded * sizeof(float);
-//     float *logits = (float *)std::aligned_alloc(
-//         64, logits_bytes); // Cacheline alignment for each context token.
-//                            // [parallel_work_item_num, max_context_len_padded]
-
-// #pragma omp parallel for schedule(dynamic) collapse(2)
-//     for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
-//       for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
-//         int context_len = context_lens[seq_idx];
-//         const int *seq_block_table =
-//             block_tables + max_num_blocks_per_seq * seq_idx;
-//         const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
-//         const int64_t kv_head_idx = head_idx / num_queries_per_kv;
-//         const scalar_t *__restrict__ q_vec_ptr =
-//             q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-//         float *__restrict__ thread_block_logits =
-//             logits + omp_get_thread_num() * max_context_len_padded;
-
-//         // Compute logits
-//         for (int block_idx = 0; block_idx < block_num; ++block_idx) {
-//           const int64_t physical_block_idx = seq_block_table[block_idx];
-//           const scalar_t *__restrict__ k_block_cache_ptr =
-//               k_cache + physical_block_idx * kv_block_stride +
-//               kv_head_idx * kv_head_stride;
-//           float *__restrict__ head_block_logits =
-//               thread_block_logits + block_idx * BLOCK_SIZE;
-
-//           static_assert(vec_op::BF16Vec32::get_elem_num() % x == 0);
-//           constexpr int TOKEN_PER_GROUP = vec_op::BF16Vec32::get_elem_num() / x;
-//           static_assert(BLOCK_SIZE % TOKEN_PER_GROUP == 0);
-//           constexpr int TOKEN_GROUPS = BLOCK_SIZE / TOKEN_PER_GROUP;
-
-//           // vec_op::FP32Vec8 accums[BLOCK_SIZE];
-//           vec_op::FP32Vec16 group_accums[TOKEN_GROUPS];
-
-//           for (int q_offset = 0; q_offset < HEAD_SIZE;
-//                q_offset += x, k_block_cache_ptr += x * BLOCK_SIZE) {
-//             scalar_vec_t q_vec(q_vec_ptr + q_offset);
-//             vec_op::BF16Vec32 q_group_vec(q_vec);
-
-//             vec_op::unroll_loop<int, TOKEN_GROUPS>(
-//                 [k_block_cache_ptr, &q_group_vec,
-//                  &group_accums](int token_group_idx) {
-//                   vec_op::BF16Vec32 k_group_vec(k_block_cache_ptr +
-//                                                 token_group_idx * x *
-//                                                     TOKEN_PER_GROUP);
-
-//                   group_accums[token_group_idx] = vec_op::fma(
-//                       q_group_vec, k_group_vec, group_accums[token_group_idx]);
-//                 });
-//           }
-
-//           vec_op::unroll_loop<int, TOKEN_GROUPS>([&group_accums,
-//                                                   head_block_logits,
-//                                                   scale](int token_group_idx) {
-//             vec_op::unroll_loop<int, TOKEN_PER_GROUP>([&group_accums,
-//                                                        head_block_logits, scale,
-//                                                        token_group_idx](
-//                                                           int token_idx) {
-//               float dot_v =
-//                   group_accums[token_group_idx]
-//                       .template reduce_sub_sum<
-//                           vec_op::FP32Vec16::get_elem_num() / TOKEN_PER_GROUP>(
-//                           token_idx);
-//               head_block_logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
-//                   dot_v * scale;
-//             });
-//           });
-//         }
-
-//         // Compute softmax
-//         float max_logit = thread_block_logits[0];
-//         for (int i = 1; i < context_len; ++i) {
-//           max_logit = max_logit >= thread_block_logits[i]
-//                           ? max_logit
-//                           : thread_block_logits[i];
-//         }
-
-//         float sum = 0;
-//         for (int i = 0; i < context_len; ++i) {
-//           thread_block_logits[i] = std::exp(thread_block_logits[i] - max_logit);
-//           sum += thread_block_logits[i];
-//         }
-
-//         for (int i = 0; i < context_len; ++i) {
-//           thread_block_logits[i] /= sum;
-//         }
-
-//         int remaining_seq_upper = block_num * BLOCK_SIZE;
-//         for (int i = context_len; i < remaining_seq_upper; ++i) {
-//           thread_block_logits[i] = 0;
-//         }
-
-//         // Compute value
-//         constexpr int head_elem_num_per_partition = 16;
-//         constexpr int head_partition_num =
-//             HEAD_SIZE / head_elem_num_per_partition;
-//         for (int head_part_idx = 0; head_part_idx < head_partition_num;
-//              ++head_part_idx) {
-//           vec_op::FP32Vec16 accums[head_elem_num_per_partition];
-//           scalar_t *__restrict__ out_ptr =
-//               out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
-//               head_part_idx * head_elem_num_per_partition;
-//           for (int block_idx = 0; block_idx < block_num; ++block_idx) {
-//             const int64_t physical_block_idx = seq_block_table[block_idx];
-//             const float *__restrict__ prob_vec_ptr =
-//                 thread_block_logits + block_idx * BLOCK_SIZE;
-//             const scalar_t *__restrict__ v_block_cache_ptr =
-//                 v_cache + physical_block_idx * kv_block_stride +
-//                 kv_head_idx * kv_head_stride +
-//                 BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
-
-//             vec_op::FP32Vec16 prob_vec(prob_vec_ptr);
-
-//             vec_op::unroll_loop<int, head_elem_num_per_partition>(
-//                 [&](int head_elem_idx) {
-//                   vec_op::BF16Vec16 v_vec(v_block_cache_ptr +
-//                                           BLOCK_SIZE * head_elem_idx);
-//                   vec_op::FP32Vec16 fp32_v_vec(v_vec.reg);
-//                   accums[head_elem_idx] =
-//                       accums[head_elem_idx] + prob_vec * fp32_v_vec;
-//                 });
-//           }
-
-//           vec_op::unroll_loop<int, head_elem_num_per_partition>(
-//               [&](int head_elem_idx) {
-//                 float value = accums[head_elem_idx].reduce_sum();
-//                 vec_op::storeFP32ToT(value, out_ptr + head_elem_idx);
-//               });
-//         }
-//       }
-//     }
-//     std::free(logits);
-//   }
-// };
-
 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                      \
   paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call(                     \
       out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
@@ -317,14 +275,23 @@ void paged_attention_v1_impl_launcher(
   int num_heads = query.get_shape()[1];
   int head_size = query.get_shape()[2];
   int max_num_blocks_per_seq = block_tables.get_shape()[1];
-  int q_stride = query.get_strides()[0];
-  int kv_block_stride = key_cache.get_strides()[0];
-  int kv_head_stride = key_cache.get_strides()[1];
-
-  T *out_ptr = reinterpret_cast<T *>(out.data());
-  T *query_ptr = reinterpret_cast<T *>(query.data());
-  T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data());
-  T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data());
+  int q_stride = query.get_strides()[0] / query.get_element_type().size();
+  int kv_block_stride = key_cache.get_strides()[0] / key_cache.get_element_type().size();
+  int kv_head_stride = key_cache.get_strides()[1] / key_cache.get_element_type().size();
+  OPENVINO_ASSERT(sizeof(float) == key_cache.get_element_type().size());
+
+  // std::cout << "num_seqs " << num_seqs << std::endl;
+  // std::cout << "num_heads " << num_heads << std::endl;
+  // std::cout << "head_size " << head_size << std::endl;
+  // std::cout << "max_num_blocks_per_seq " << max_num_blocks_per_seq << std::endl;
+  // std::cout << "q_stride " << q_stride << std::endl;
+  // std::cout << "kv_block_stride " << kv_block_stride << std::endl;
+  // std::cout << "kv_head_stride " << kv_head_stride << std::endl;
+
+  T *out_ptr = out.data<T>();
+  T *query_ptr = query.data<T>();
+  T *key_cache_ptr = key_cache.data<T>();
+  T *value_cache_ptr = value_cache.data<T>();
 
   int *block_tables_ptr = block_tables.data<int>();
   int *context_lens_ptr = context_lens.data<int>();
diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/cache_impl.cpp b/modules/custom_operations/user_ie_extensions/paged_attention/cache_impl.cpp
index 2c62a3b9b..90d8fd73a 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/cache_impl.cpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/cache_impl.cpp
@@ -10,15 +10,15 @@ template <typename scalar_t>
 void reshape_and_cache_cpu_impl(
     const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
     scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
-    const int64_t *__restrict__ slot_mapping, const int num_tokens,
+    const int32_t *__restrict__ slot_mapping, const int num_tokens,
     const int key_stride, const int value_stride, const int num_heads,
     const int head_size, const int block_size, const int x) {
   const int block_elem_num = num_heads * head_size * block_size;
 
-#pragma omp parallel for collapse(2)
+// #pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
-      const int64_t slot_idx = slot_mapping[token_idx];
+      const int32_t slot_idx = slot_mapping[token_idx];
      if (slot_idx >= 0) {
        int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
        int src_value_head_idx =
@@ -33,6 +33,8 @@ void reshape_and_cache_cpu_impl(
        scalar_t *target_value_head_ptr = value_cache +
                                          block_elem_num * block_index +
                                          head_idx * block_size * head_size;
+
+        // std::cout << (block_elem_num * block_index + head_idx * block_size * head_size) << " ";
 
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
@@ -67,9 +69,12 @@ void reshape_and_cache(ov::Tensor key, ov::Tensor value,
   int x = key_cache_shape[4];
   ov::Strides key_strides = key.get_strides();
-  int key_stride = key_strides[0];
+  int key_stride = key_strides[0] / key.get_element_type().size();
 
   ov::Strides value_strides = value.get_strides();
-  int value_stride = value_strides[0];
+  int value_stride = value_strides[0] / value.get_element_type().size();
+
+  OPENVINO_ASSERT(slot_mapping.get_element_type() == ov::element::i32,
+                  "slot_mapping must be of type i32, given ", slot_mapping.get_element_type());
 
  switch (key.get_element_type()) {
  case ov::element::f32:
@@ -78,7 +83,7 @@ void reshape_and_cache(ov::Tensor key, ov::Tensor value,
    reshape_and_cache_cpu_impl(
        key.data<float>(), value.data<float>(), key_cache.data<float>(),
        value_cache.data<float>(),
-        slot_mapping.data<int64_t>(), num_tokens, key_stride,
+        slot_mapping.data<int32_t>(), num_tokens, key_stride,
        value_stride, num_heads, head_size, block_size, x);
    break;
  case ov::element::f16:
@@ -87,7 +92,7 @@ void reshape_and_cache(ov::Tensor key, ov::Tensor value,
    reshape_and_cache_cpu_impl(
        key.data<ov::float16>(), value.data<ov::float16>(), key_cache.data<ov::float16>(),
        value_cache.data<ov::float16>(),
-        slot_mapping.data<int64_t>(), num_tokens, key_stride,
+        slot_mapping.data<int32_t>(), num_tokens, key_stride,
        value_stride, num_heads, head_size, block_size, x);
    break;
  default:
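For readers of the cache changes above: reshape_and_cache_cpu_impl places each incoming token into the paged caches by splitting the flat slot index from slot_mapping into a physical block index and an in-block offset; the key cache additionally tiles head_size into chunks of x elements so the x-wide inner copies stay contiguous. A minimal standalone sketch of that index arithmetic, under the layouts stated in the code (the helper and its names are illustrative only, not part of the patch):

#include <cstdint>

// Illustrative only: where one (token, head, element) lands in the paged caches,
// following the layouts used above:
//   key_cache   [num_blocks, num_kv_heads, head_size/x, block_size, x]
//   value_cache [num_blocks, num_kv_heads, head_size, block_size]
struct CacheCoords {
    int64_t key_offset;
    int64_t value_offset;
};

CacheCoords cache_coords(int32_t slot_idx, int head_idx, int elem_idx,
                         int num_heads, int head_size, int block_size, int x) {
    const int block_index  = slot_idx / block_size;   // which physical block
    const int block_offset = slot_idx % block_size;   // position inside the block
    const int64_t block_elem_num = int64_t(num_heads) * head_size * block_size;
    const int64_t head_base = block_elem_num * block_index +
                              int64_t(head_idx) * head_size * block_size;

    // key cache: head_size is split into head_size/x chunks of x elements
    const int chunk = elem_idx / x, lane = elem_idx % x;
    const int64_t key_offset = head_base +
                               int64_t(chunk) * block_size * x +
                               int64_t(block_offset) * x + lane;

    // value cache: plain [head_size, block_size] layout per head
    const int64_t value_offset = head_base +
                                 int64_t(elem_idx) * block_size + block_offset;
    return {key_offset, value_offset};
}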
diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
index 8ee5dee94..986569ebb 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
@@ -13,12 +13,14 @@
 
 #include "cpu_ops.hpp"
 
-std::shared_ptr<ov::Model> TemplateExtension::PagedAttention::make_prefill_subgraph() {
-    ov::element::Type_t type = ov::element::f32, attention_mask_type = ov::element::boolean;
-    auto query = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_heads, m_head_size}));
-    auto key = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_kv_heads, m_head_size}));
-    auto value = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, m_num_kv_heads, m_head_size}));
-    auto attention_mask = std::make_shared<ov::op::v0::Parameter>(attention_mask_type, ov::PartialShape({-1, -1, -1}));
+namespace {
+
+std::shared_ptr<ov::Model> make_prefill_subgraph(std::size_t num_heads, std::size_t num_kv_heads, std::size_t head_size) {
+    ov::element::Type_t type = ov::element::f32, attention_mask_type = ov::element::f32;
+    auto query = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_heads, head_size}));
+    auto key = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_kv_heads, head_size}));
+    auto value = std::make_shared<ov::op::v0::Parameter>(type, ov::PartialShape({-1, -1, num_kv_heads, head_size}));
+    auto mask = std::make_shared<ov::op::v0::Parameter>(attention_mask_type, ov::PartialShape({-1, -1, -1, -1}));
     auto scale = std::make_shared<ov::op::v0::Parameter>(type, ov::Shape({1}));
 
     // transpose Q, K and V to swap num_heads and seq_len dimensions
@@ -27,23 +29,19 @@ std::shared_ptr<ov::Model> TemplateExtension::PagedAttention::make_prefill_subgr
     auto key_transposed = std::make_shared<ov::op::v1::Transpose>(key, permute_const);
     auto value_transposed = std::make_shared<ov::op::v1::Transpose>(value, permute_const);
 
-    auto spda = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query_transposed, key_transposed, value_transposed, attention_mask, scale, false);
+    auto spda = std::make_shared<ov::op::v13::ScaledDotProductAttention>(query_transposed, key_transposed, value_transposed, mask, scale, false);
 
     // transpose SPDA output to [batch, seq_len, num_heads, head_size] back
     auto spda_transposed = std::make_shared<ov::op::v1::Transpose>(spda, permute_const);
 
-    return std::make_shared<ov::Model>(spda_transposed, ov::ParameterVector{query, key, value, attention_mask, scale}, "spda_prefill_model");
+    return std::make_shared<ov::Model>(spda_transposed, ov::ParameterVector{query, key, value, mask, scale}, "spda_prefill_model");
 }
 
-TemplateExtension::PagedAttention::PagedAttention(const ov::OutputVector& inputs,
-                                                  const float scale)
-    : ov::op::Op(inputs),
-      m_scale(scale) {
-    constructor_validate_and_infer_types();
+}
 
-    // compile model for prefill stage
-    auto compiled_model = ov::Core().compile_model(make_prefill_subgraph(), "CPU");
-    m_prefill_request = compiled_model.create_infer_request();
+TemplateExtension::PagedAttention::PagedAttention(const ov::OutputVector& inputs)
+    : ov::op::Op(inputs) {
+    constructor_validate_and_infer_types();
 }
 
 TemplateExtension::PagedAttention::PagedAttention(const ov::Output<ov::Node>& query,
                                                   const ov::Output<ov::Node>& key,
@@ -63,95 +61,107 @@ TemplateExtension::PagedAttention::PagedAttention(const ov::Output<ov::Node>& qu
                                                   // const ov::Output<ov::Node>& use_cuda_graph,
                                                   // const ov::Output<ov::Node>& attn_bias
                                                   // end of arguments from InputMetadata
-                                                  const float scale)
-    : PagedAttention({query, key, value, key_cache, value_cache,
-                      is_prompt, slot_mapping, max_context_len, context_lens, block_tables
-                     }, scale) {}
+                                                  const ov::Output<ov::Node>& scale)
+    : PagedAttention(ov::OutputVector{query, key, value, key_cache, value_cache,
+                     is_prompt, slot_mapping, max_context_len, context_lens, block_tables, scale
+                     }) {}
 
 void TemplateExtension::PagedAttention::validate_and_infer_types() {
     // value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
-    auto value_cache_shape = get_input_shape(4);
-    m_num_kv_heads = value_cache_shape[1];
-    m_head_size = value_cache_shape[2];
-    m_block_size = value_cache_shape[3];
+    auto value_cache_shape = get_input_partial_shape(4);
+    // m_num_kv_heads = value_cache_shape[1];
+    // m_head_size = value_cache_shape[2];
+    // m_block_size = value_cache_shape[3];
+    NODE_VALIDATION_CHECK(this,
+                          value_cache_shape.size() == 4,
+                          "Value cache shape must be 4 dims");
 
     // key_cache: shape [num_blocks, num_kv_heads, head_size/x, block_size, x]
-    auto key_cache_shape = get_input_shape(3);
+    auto key_cache_shape = get_input_partial_shape(3);
     NODE_VALIDATION_CHECK(this,
-                          value_cache_shape[0] == key_cache_shape[0] && // num_blocks
-                          key_cache_shape[1] == m_num_kv_heads &&
-                          key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
-                          m_block_size == key_cache_shape[3], // block_size,
-                          "Key cache validation failed");
+                          value_cache_shape.size() == 4,
+                          // value_cache_shape[0] == key_cache_shape[0] && // num_blocks
+                          // key_cache_shape[1] == m_num_kv_heads &&
+                          // key_cache_shape[2] * key_cache_shape[4] == m_head_size &&
+                          // m_block_size == key_cache_shape[3], // block_size,
+                          "Key cache shape must be 4 dims");
 
     // query: shape [batch_size, seq_len, num_heads * head_size]
     auto query_type = get_input_element_type(0);
     auto query_shape = get_input_partial_shape(0);
-    m_num_heads = query_shape[2].get_length();
     NODE_VALIDATION_CHECK(this,
-                          query_type.is_real() &&
-                          query_shape.size() == 3 &&
-                          query_shape[2] == m_num_heads * m_head_size,
-                          "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]");
+                          // query_type.is_real() &&
+                          query_shape.size() == 3,
+                          // query_shape[2] == m_num_heads * m_head_size,
+                          "Query type must be real, shape must be like [batch_size, seq_len, num_heads * head_size]. ",
+                          "Got element type ", query_type, ", shape ", query_shape);
 
     // key: shape [batch_size, seq_len, num_kv_heads * head_size]
     auto key_type = get_input_element_type(1);
     auto key_shape = get_input_partial_shape(1);
     NODE_VALIDATION_CHECK(this,
                           query_type == key_type &&
-                          key_shape.size() == query_shape.size() &&
-                          key_shape[2] == m_num_kv_heads * m_head_size,
-                          "Key type must be the same as query, shape must be the same as query");
+                          key_shape.size() == 3,
+                          "Key type must be the same as query, shape must be the same as query. "
+                          "Got element type ", key_type, ", shape ", key_shape);
 
     // value: shape [batch_size, seq_len, num_kv_heads * head_size]
     auto value_type = get_input_element_type(2);
     auto value_shape = get_input_partial_shape(2);
     NODE_VALIDATION_CHECK(this,
                           key_type == value_type &&
-                          key_shape == value_shape, "Value type must be the same as key, shape must be the same as key");
+                          key_shape == value_shape, "Value type must be the same as key, shape must be the same as key."
+                          "Got element type ", value_type, ", shape ", value_shape);
 
     // is_prompt: boolean scalar
     NODE_VALIDATION_CHECK(this,
-                          get_input_element_type(5) == ov::element::boolean &&
-                          get_input_shape(5) == ov::Shape({1}),
-                          "is_prompt validation failed");
+                          // get_input_element_type(5) == ov::element::boolean &&
+                          get_input_shape(5) == ov::Shape({}),
+                          "is_prompt validation failed. ",
+                          "Got element type ", get_input_element_type(5), ", shape ", get_input_shape(5));
 
     // slot_mapping: shape [batch_size, max_context_len]
     auto slot_mapping_shape = get_input_partial_shape(6);
     NODE_VALIDATION_CHECK(this,
-                          get_input_element_type(6) == ov::element::i64 &&
+                          // get_input_element_type(6) == ov::element::i64 &&
                           slot_mapping_shape.size() == 2,
-                          "slot_mapping validation failed");
+                          "slot_mapping validation failed. ",
+                          "Got element type ", get_input_element_type(6), ", shape ", slot_mapping_shape);
 
     // max_context_len: integer scalar
     NODE_VALIDATION_CHECK(this,
-                          get_input_element_type(7) == ov::element::i32 &&
-                          get_input_shape(7) == ov::Shape({1}),
-                          "max_context_len validation failed");
+                          // get_input_element_type(7) == ov::element::i32 &&
+                          get_input_shape(7) == ov::Shape({}),
+                          "max_context_len validation failed. ",
+                          "Got element type ", get_input_element_type(7), ", shape ", get_input_shape(7));
 
     // context_lens: shape [batch_size]
-    auto context_lens_shape = get_input_shape(8);
+    auto context_lens_shape = get_input_partial_shape(8);
     NODE_VALIDATION_CHECK(this,
-                          get_input_element_type(8) == ov::element::i32 &&
+                          // get_input_element_type(8) == ov::element::i32 &&
                           context_lens_shape.size() == 1,
-                          "context_lens validation failed");
+                          "context_lens validation failed. ",
+                          "Got element type ", get_input_element_type(8), ", shape ", context_lens_shape);
 
     // block_tables: shape [batch_size, max_block_per_request]
     NODE_VALIDATION_CHECK(this,
-                          get_input_element_type(9) == ov::element::i32 &&
+                          // get_input_element_type(9) == ov::element::i32 &&
                           get_input_partial_shape(9).size() == 2,
-                          "block_tables validation failed");
+                          "block_tables validation failed. ",
+                          "Got element type ", get_input_element_type(9), ", shape ", get_input_partial_shape(9));
+
+    // scale: float scalar
+    NODE_VALIDATION_CHECK(this,
+                          // get_input_element_type(10) == ov::element::f32 &&
+                          get_input_shape(10) == ov::Shape({}),
+                          "scale validation failed. ",
+                          "Got element type ", get_input_element_type(10), ", shape ", get_input_shape(10));
 
     set_output_type(0, query_type, query_shape);
 }
 
 std::shared_ptr<ov::Node> TemplateExtension::PagedAttention::clone_with_new_inputs(const ov::OutputVector& new_args) const {
-    return std::make_shared<PagedAttention>(new_args, m_scale);
-}
-
-bool TemplateExtension::PagedAttention::visit_attributes(ov::AttributeVisitor& visitor) {
-    visitor.on_attribute("scale", m_scale);
-    return true;
+    return std::make_shared<PagedAttention>(new_args);
 }
 
 bool TemplateExtension::PagedAttention::has_evaluate() const {
@@ -164,43 +174,72 @@ void reshape_and_cache(ov::Tensor key, ov::Tensor value,
                        ov::Tensor slot_mapping);
 
 // generate buttom diagonal boolean attention mask for a prefill stage
-ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32_t max_context_len, ov::Tensor context_lens) {
-    OPENVINO_ASSERT(num_seqs == context_lens.get_size());
+ov::Tensor generate_attention_mask(const std::size_t batch_size, const std::size_t seq_len) {
+    ov::Shape attention_mask_shape({batch_size, 1, seq_len, seq_len});
+    ov::Tensor attention_mask(ov::element::f32, attention_mask_shape);
+    int attention_mask_stride = attention_mask.get_strides()[0] / sizeof(float);
 
-    ov::Shape attention_mask_shape({num_seqs, 1, max_context_len, max_context_len});
-    ov::Tensor attention_mask(ov::element::boolean, attention_mask_shape);
-    int attention_mask_stride = attention_mask.get_strides()[0];
+    static_assert(std::numeric_limits<float>::is_iec559, "IEEE 754 required");
+    float negative_inf = -std::numeric_limits<float>::infinity();
 
-    std::fill_n(attention_mask.data<bool>(), attention_mask.get_size(), false);
+    std::fill_n(attention_mask.data<float>(), attention_mask.get_size(), 0);
 
-    for (int current_seq = 0; current_seq < num_seqs; ++current_seq) {
-        std::int32_t context_len = context_lens.data<std::int32_t>()[current_seq];
-        OPENVINO_ASSERT(context_len <= max_context_len);
-
-        bool * attention_mask_data = attention_mask.data<bool>() + current_seq * attention_mask_stride;
-        for (int x = 0; x < context_len; ++x) {
-            for (int y = 0; y < context_len; ++y) {
-                attention_mask_data[x * max_context_len + y] = x >= y;
+    for (int batch_id = 0; batch_id < batch_size; ++batch_id) {
+        float * attention_mask_data = attention_mask.data<float>() + batch_id * attention_mask_stride;
+        for (int x = 0; x < seq_len; ++x) {
+            for (int y = 0; y < seq_len; ++y) {
+                attention_mask_data[x * seq_len + y] = x < y ? negative_inf : 0.0f;
             }
         }
     }
+
+    return attention_mask;
+}
+
+void print_tensor(const std::string& title, ov::Tensor tensor) {
+    std::cout << title << std::endl;
+    size_t size = std::min<size_t>(40, tensor.get_size());
+    for (int x = 0; x < size; ++x) {
+        if (tensor.get_element_type() == ov::element::f32)
+            std::cout << tensor.data<float>()[x] << " ";
+        else if (tensor.get_element_type() == ov::element::i32)
+            std::cout << tensor.data<int>()[x] << " ";
+    }
+    std::cout << std::endl;
 }
 
 bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
     ov::Tensor query = inputs[0], key = inputs[1], value = inputs[2];
-    ov::Shape query_shape = query.get_shape();
-    const std::int32_t batch_size = query_shape[0], seq_len = query_shape[1], hidden_size = query_shape[2];
     ov::Tensor key_cache = inputs[3], value_cache = inputs[4];
     const bool is_prompt = inputs[5].data<bool>()[0];
     ov::Tensor slot_mapping = inputs[6];
     const std::int32_t max_context_len = inputs[7].data<std::int32_t>()[0];
     ov::Tensor context_lens = inputs[8];
     ov::Tensor block_tables = inputs[9];
+    float scale = inputs[10].data<float>()[0];
+
+    // Shapes
+    ov::Shape query_shape = query.get_shape();
+    const std::size_t batch_size = query_shape[0], seq_len = query_shape[1], hidden_size = query_shape[2];
+
+    ov::Shape value_cache_shape = value_cache.get_shape();
+    const std::size_t num_kv_heads = value_cache_shape[1], head_size = value_cache_shape[2],
+                      num_heads = hidden_size / head_size, block_size = value_cache_shape[3];
+
+    if (!m_prefill_request) {
+        // compile model for prefill stage
+        auto compiled_model = ov::Core().compile_model(make_prefill_subgraph(num_heads, num_kv_heads, head_size), "CPU");
+        m_prefill_request = compiled_model.create_infer_request();
+    }
 
     // reshape to [batch_size * seq_len, m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
-    query.set_shape({batch_size * seq_len, m_num_heads, m_head_size});
-    key.set_shape({batch_size * seq_len, m_num_kv_heads, m_head_size});
+    void * query_data = query.data(), * key_data = key.data(), * value_data = value.data();
+    query.set_shape({batch_size * seq_len, num_heads, head_size});
+    OPENVINO_ASSERT(query_data == query.data());
+    key.set_shape({batch_size * seq_len, num_kv_heads, head_size});
+    OPENVINO_ASSERT(key_data == key.data());
     value.set_shape(key.get_shape());
+    OPENVINO_ASSERT(value_data == value.data());
 
     // put current K, V values into key_cache and value_cache
     reshape_and_cache(key, value, key_cache, value_cache, slot_mapping);
@@ -208,16 +247,34 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
     // set output shape
     OPENVINO_ASSERT(outputs.size() == 1);
     outputs[0].set_shape(query.get_shape());
+    void * output_data = outputs[0].data();
+
+    // std::cout << "key_cache shape " << key_cache.get_shape() << std::endl;
+    // std::cout << "value_cache shape " << value_cache.get_shape() << std::endl;
+    // std::cout << "num_kv_heads " << num_kv_heads << std::endl;
+    // std::cout << "block_tables shape " << block_tables.get_shape() << std::endl;
+    // std::cout << "context_lens shape " << context_lens.get_shape() << std::endl;
+    // std::cout << "block_size " << block_size << std::endl;
+    // std::cout << "max_context_len " << max_context_len << std::endl;
+    // std::cout << "outputs[0] shape " << outputs[0].get_shape() << std::endl;
+    // std::cout << "num_kv_heads " << num_kv_heads << std::endl;
+    // std::cout << "num_heads " << num_heads << std::endl;
+    // std::cout << "head_size " << head_size << std::endl;
 
     if (is_prompt) {
-        // reshape to [batch_size, seq_len, m_num_kv_heads, head_size]
-        query.set_shape({batch_size, seq_len, m_num_heads, m_head_size});
+        // reshape to [batch_size, seq_len, num_kv_heads, head_size]
+        query.set_shape({batch_size, seq_len, num_heads, head_size});
+        OPENVINO_ASSERT(query_data == query.data());
         outputs[0].set_shape(query.get_shape());
-        key.set_shape({batch_size, seq_len, m_num_kv_heads, m_head_size});
+        OPENVINO_ASSERT(output_data == outputs[0].data());
+        key.set_shape({batch_size, seq_len, num_kv_heads, head_size});
+        OPENVINO_ASSERT(key_data == key.data());
         value.set_shape(key.get_shape());
+        OPENVINO_ASSERT(value_data == value.data());
 
-        auto attention_mask = generate_attention_mask(batch_size, max_context_len, context_lens);
-        ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&m_scale);
+        auto attention_mask = generate_attention_mask(batch_size, seq_len);
+        scale = 1.0f; // TODO
+        ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&scale);
 
         m_prefill_request.set_input_tensor(0, query);
         m_prefill_request.set_input_tensor(1, key);
@@ -228,16 +285,25 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
 
         m_prefill_request.infer();
     } else {
+        std::fill_n(outputs[0].data<float>(), outputs[0].get_size(), 0.0f);
+
        // 'query' and 'output' are expected to be [batch_size * seq_len, m_num_kv_heads, head_size]
        paged_attention_v1_cpu(outputs[0],
                               query, key_cache, value_cache,
-                               m_num_kv_heads, m_scale,
+                               num_kv_heads, scale,
                               block_tables, context_lens,
-                               m_block_size, max_context_len);
+                               block_size, max_context_len);
+
+        // print_tensor("query", query);
+        // print_tensor("block_tables", block_tables);
+        // print_tensor("context_lens", context_lens);
+        // print_tensor("output", outputs[0]);
+        // exit(1);
     }
 
     // reshape back to [batch_size, seq_len, num_heads * head_size]
     outputs[0].set_shape(query_shape); // works like reshape
+    OPENVINO_ASSERT(output_data == outputs[0].data());
 
     return true;
 }
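One consequence of the paged_attention.cpp changes above is that scale is no longer a node attribute but the operation's 11th input, so a caller now feeds it as a scalar f32 constant. A hedged sketch of how a node could be wired up after this change (the helper name and the pre-existing ov::Output values are assumptions for illustration, not part of the patch):

#include <memory>
#include <openvino/op/constant.hpp>
#include "paged_attention.hpp"

// Sketch only: build a PagedAttention node with scale passed as input 10.
std::shared_ptr<ov::Node> make_paged_attention_node(
        const ov::Output<ov::Node>& query, const ov::Output<ov::Node>& key,
        const ov::Output<ov::Node>& value, const ov::Output<ov::Node>& key_cache,
        const ov::Output<ov::Node>& value_cache, const ov::Output<ov::Node>& is_prompt,
        const ov::Output<ov::Node>& slot_mapping, const ov::Output<ov::Node>& max_context_len,
        const ov::Output<ov::Node>& context_lens, const ov::Output<ov::Node>& block_tables,
        float scale_value) {
    // scalar f32 constant, matching the new "scale: float scalar" validation above
    auto scale = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {scale_value});
    return std::make_shared<TemplateExtension::PagedAttention>(
        query, key, value, key_cache, value_cache,
        is_prompt, slot_mapping, max_context_len, context_lens, block_tables, scale);
}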
diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
index f996036b1..788cded5a 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.hpp
@@ -23,13 +23,12 @@ namespace TemplateExtension {
 
 class PagedAttention : public ov::op::Op {
 public:
-    OPENVINO_OP("PagedAttention");
+    OPENVINO_OP("PagedAttentionExtension");
     OPENVINO_FRAMEWORK_MAP(pytorch, "vllm.model_executor.layers.attention.PagedAttention");
 
     PagedAttention() = default;
 
-    PagedAttention(const ov::OutputVector& inputs,
-                   const float scale);
+    PagedAttention(const ov::OutputVector& inputs);
 
     PagedAttention(const ov::Output<ov::Node>& query,
                    const ov::Output<ov::Node>& key,
@@ -48,20 +47,15 @@ class PagedAttention : public ov::op::Op {
                    // const ov::Output<ov::Node>& use_cuda_graph,
                    // const ov::Output<ov::Node>& attn_bias
                    // end of arguments from InputMetadata
-                   const float scale);
+                   const ov::Output<ov::Node>& scale);
 
     std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
 
-    bool visit_attributes(ov::AttributeVisitor& visitor) override;
-
     void validate_and_infer_types() override;
 
     bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override;
 
     bool has_evaluate() const override;
 
 private:
-    std::shared_ptr<ov::Model> make_prefill_subgraph();
-
-    std::uint32_t m_num_heads, m_num_kv_heads, m_head_size, m_block_size;
-    float m_scale;
     mutable ov::InferRequest m_prefill_request;
 };
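As a reading aid for the scalar kernel in attention_impl.cpp: for every sequence and head, paged_attention_v1_impl scores the query against each cached key block listed in block_tables, normalizes the scores with a max-subtracted softmax, and accumulates the probability-weighted cached values into the output. A minimal single-sequence, single-head reference of the same math on contiguous (non-paged) buffers, with illustrative names:

#include <algorithm>
#include <cmath>
#include <vector>

// Reference math only: q is [head_size], k and v are [context_len, head_size]
// stored contiguously (no block_tables / paged layout), out is [head_size].
void single_head_attention_ref(const float* q, const float* k, const float* v,
                               float* out, int context_len, int head_size, float scale) {
    std::vector<float> logits(context_len, 0.0f);
    for (int t = 0; t < context_len; ++t)
        for (int i = 0; i < head_size; ++i)
            logits[t] += q[i] * k[t * head_size + i] * scale;    // scaled dot product

    float max_logit = logits[0];
    for (float l : logits) max_logit = std::max(max_logit, l);

    float sum = 0.0f;
    for (float& l : logits) { l = std::exp(l - max_logit); sum += l; }  // softmax numerator
    for (float& l : logits) l /= sum;

    for (int i = 0; i < head_size; ++i) out[i] = 0.0f;
    for (int t = 0; t < context_len; ++t)
        for (int i = 0; i < head_size; ++i)
            out[i] += logits[t] * v[t * head_size + i];          // weighted value sum
}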