Skip to content

Commit

Permalink
Use Tensor::set_shape instead of view
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Feb 5, 2024
1 parent e9d5e44 commit 48fabb9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 21 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -186,21 +186,6 @@ ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32
}
}

// Zero-copy reshape analogous to torch.Tensor.view: reinterpret a 4-D
// [batch, seq_len, num_heads, head_size] tensor as a 3-D
// [batch, seq_len, num_heads * head_size] tensor over the same underlying buffer.
ov::Tensor view_as_3d(ov::Tensor tensor) {
    const ov::Shape src = tensor.get_shape();
    OPENVINO_ASSERT(src.size() == 4);
    const std::uint32_t batch = src[0], seq = src[1], heads = src[2], head_dim = src[3];
    // Sharing tensor.data() means no allocation and no copy takes place.
    return ov::Tensor(tensor.get_element_type(), ov::Shape{batch, seq, heads * head_dim}, tensor.data());
}

// Zero-copy reshape analogous to torch.Tensor.view: reinterpret a 3-D
// [batch_size, seq_len, num_heads * head_size] tensor as a 4-D
// [batch_size, seq_len, num_heads, head_size] tensor sharing the same data pointer.
ov::Tensor view_as_4d(ov::Tensor tensor, std::uint32_t num_heads, std::uint32_t head_size) {
    ov::Shape shape = tensor.get_shape();
    // Validate the rank BEFORE indexing into the shape. The flattened head axis
    // of a 3-D tensor is shape[2]; the original code checked shape[3], which is
    // out of bounds for a rank-3 shape.
    OPENVINO_ASSERT(shape.size() == 3 && num_heads * head_size == shape[2]);
    const std::uint32_t batch_size = shape[0], seq_len = shape[1];
    return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads, head_size}), tensor.data());
}

bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
ov::Tensor query = inputs[0], key = inputs[1], value = inputs[2];
ov::Shape query_shape = query.get_shape();
Expand All @@ -212,10 +197,10 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
ov::Tensor context_lens = inputs[8];
ov::Tensor block_tables = inputs[9];

// reshape to [batch_size, seq_len, num_heads/m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
query = view_as_4d(query, m_num_heads, m_head_size);
key = view_as_4d(key, m_num_kv_heads, m_head_size);
value = view_as_4d(value, m_num_kv_heads, m_head_size);
// reshape to [batch_size * seq_len, m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
query.set_shape({batch_size * seq_len, m_num_heads, m_head_size});
key.set_shape({batch_size * seq_len, m_num_kv_heads, m_head_size});
value.set_shape(key.get_shape());

// put current K, V values into key_cache and value_cache
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping);
Expand All @@ -225,6 +210,12 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
outputs[0].set_shape(query.get_shape());

if (is_prompt) {
// reshape to [batch_size, seq_len, m_num_kv_heads, head_size]
query.set_shape({batch_size, seq_len, m_num_heads, m_head_size});
outputs[0].set_shape(query.get_shape());
key.set_shape({batch_size, seq_len, m_num_kv_heads, m_head_size});
value.set_shape(key.get_shape());

auto attention_mask = generate_attention_mask(batch_size, max_context_len, context_lens);
ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&m_scale);

Expand All @@ -237,15 +228,16 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons

m_prefill_request.infer();
} else {
// 'query' and 'output' are expected to be [batch_size * seq_len, m_num_kv_heads, head_size]
paged_attention_v1_cpu(outputs[0],
query, key_cache, value_cache,
m_num_kv_heads, m_scale,
block_tables, context_lens,
m_block_size, max_context_len);
}

// reshape
outputs[0] = view_as_3d(outputs[0]);
// reshape back to [batch_size, seq_len, num_heads * head_size]
outputs[0].set_shape(query_shape); // works like reshape

return true;
}

0 comments on commit 48fabb9

Please sign in to comment.