diff --git a/modules/custom_operations/tests/CMakeLists.txt b/modules/custom_operations/tests/CMakeLists.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
index 388b869b3..bf574457e 100644
--- a/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
+++ b/modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp
@@ -186,21 +186,6 @@ ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32
     }
 }
 
-// similar to torch.Tensor.view
-ov::Tensor view_as_3d(ov::Tensor tensor) {
-    ov::Shape shape = tensor.get_shape();
-    OPENVINO_ASSERT(shape.size() == 4);
-    const std::uint32_t batch_size = shape[0], seq_len = shape[1], num_heads = shape[2], head_size = shape[3];
-    return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads * head_size}), tensor.data());
-}
-
-ov::Tensor view_as_4d(ov::Tensor tensor, std::uint32_t num_heads, std::uint32_t head_size) {
-    ov::Shape shape = tensor.get_shape();
-    const std::uint32_t batch_size = shape[0], seq_len = shape[1];
-    OPENVINO_ASSERT(shape.size() == 3 && num_heads * head_size == shape[3]);
-    return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads, head_size}), tensor.data());
-}
-
 bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
     ov::Tensor query = inputs[0], key = inputs[1], value = inputs[2];
     ov::Shape query_shape = query.get_shape();
@@ -212,10 +197,10 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
     ov::Tensor context_lens = inputs[8];
     ov::Tensor block_tables = inputs[9];
 
-    // reshape to [batch_size, seq_len, num_heads/m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
-    query = view_as_4d(query, m_num_heads, m_head_size);
-    key = view_as_4d(key, m_num_kv_heads, m_head_size);
-    value = view_as_4d(value, m_num_kv_heads, m_head_size);
+    // reshape to [batch_size * seq_len, num_heads/m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
+    query.set_shape({batch_size * seq_len, m_num_heads, m_head_size});
+    key.set_shape({batch_size * seq_len, m_num_kv_heads, m_head_size});
+    value.set_shape(key.get_shape());
 
     // put current K, V values into key_cache and value_cache
     reshape_and_cache(key, value, key_cache, value_cache, slot_mapping);
@@ -225,6 +210,12 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
     outputs[0].set_shape(query.get_shape());
 
     if (is_prompt) {
+        // reshape to [batch_size, seq_len, num_heads/m_num_kv_heads, head_size]
+        query.set_shape({batch_size, seq_len, m_num_heads, m_head_size});
+        outputs[0].set_shape(query.get_shape());
+        key.set_shape({batch_size, seq_len, m_num_kv_heads, m_head_size});
+        value.set_shape(key.get_shape());
+
         auto attention_mask = generate_attention_mask(batch_size, max_context_len, context_lens);
         ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&m_scale);
 
@@ -237,6 +228,7 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
         m_prefill_request.infer();
     } else {
+        // 'query' and 'output' are expected to be [batch_size * seq_len, num_heads, head_size]
         paged_attention_v1_cpu(outputs[0],
                                query, key_cache, value_cache,
                                m_num_kv_heads, m_scale,
@@ -244,8 +236,8 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
                                m_block_size, max_context_len);
     }
 
-    // reshape
-    outputs[0] = view_as_3d(outputs[0]);
+    // reshape back to [batch_size, seq_len, num_heads * head_size]
+    outputs[0].set_shape(query_shape); // works like reshape
 
     return true;
 }
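
Note (not part of the patch): the view_as_3d/view_as_4d helpers become redundant on the assumption that ov::Tensor::set_shape keeps the existing host allocation whenever the total element count is unchanged, so it acts like torch.Tensor.view rather than a copy. A minimal standalone sketch of that assumption follows; the batch_size/seq_len/num_heads/head_size values are illustrative only.

#include <openvino/runtime/tensor.hpp>
#include <cassert>

int main() {
    const std::size_t batch_size = 2, seq_len = 3, num_heads = 4, head_size = 8;

    // [batch_size, seq_len, num_heads * head_size] -- the layout PagedAttention receives
    ov::Tensor t(ov::element::f32, ov::Shape{batch_size, seq_len, num_heads * head_size});
    float* before = t.data<float>();

    // same element count -> expected to reuse the same memory (no copy),
    // which is what the patch relies on when flattening Q/K/V
    t.set_shape(ov::Shape{batch_size * seq_len, num_heads, head_size});
    assert(t.data<float>() == before);

    // restore the original layout, as the patch does for outputs[0]
    t.set_shape(ov::Shape{batch_size, seq_len, num_heads * head_size});
    assert(t.data<float>() == before);
    return 0;
}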