Skip to content

Commit

Permalink
[CPU] Add PagedAttention support (#23524)
Browse files Browse the repository at this point in the history
### Details:
 - *Support PagedAttention, which depends on:*
- openvino_contrib:
openvinotoolkit/openvino_contrib#867
    - vLLM: ilya-lavrenov/vllm#4
 - *TODO*
    - Models with alibi feature
   
### Tickets:
 - *[134329](https://jira.devtools.intel.com/browse/CVS-134329)*
 - *[134327](https://jira.devtools.intel.com/browse/CVS-134327)*
  • Loading branch information
luo-cheng2021 authored Mar 21, 2024
1 parent 1c0ca0e commit 0b260ff
Show file tree
Hide file tree
Showing 11 changed files with 544 additions and 198 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,11 @@ bool fuse_type_to_parameter(const std::shared_ptr<ov::Node>& node,
auto convert = std::make_shared<opset4::Convert>(param, to);
for (auto& input : param_consumers) {
const auto consumer = input.get_node();
if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer)) {
if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer) ||
// TODO: refactor after ngraph op defined
// The fourth and fifth inputs are kvcache and should be directly connected to parameters
(consumer->get_type_name() == std::string("PagedAttentionExtension") &&
(input.get_index() == 3 || input.get_index() == 4))) {
continue;
}
input.replace_source_output(convert);
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 ANY
src/nodes/kernels/scaled_attn/attn_memcpy.cpp
API src/nodes/kernels/scaled_attn/attn_memcpy.hpp
NAME attn_memcpy
NAME attn_memcpy paged_attn_memcpy
NAMESPACE ov::Extensions::Cpu::XARCH
)
cross_compiled_file(${TARGET_NAME}
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"Ngram", Type::Ngram},
{"ScaledDotProductAttention", Type::ScaledDotProductAttention},
{"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
{"PagedAttentionExtension", Type::ScaledDotProductAttention},
{"RoPE", Type::RoPE},
};
return type_to_name_tbl;
Expand Down
4 changes: 4 additions & 0 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1680,6 +1680,10 @@ void Graph::EnforceInferencePrecision() {
if (node->getOriginalInputPrecisionAtPort(inPort) != ov::element::f32)
return true;

// kvcache of PagedAttention should be written directly
if (node->getType() == Type::ScaledDotProductAttention && node->getOriginalInputsNumber() == 13 &&
(inPort == 3 || inPort == 4))
return true;
const auto &parent = node->getParentEdgeAt(inPort)->getParent();
/* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
* Element type conversion to bf16 is done automatically, if convolution follows up after Constant Inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,43 @@ static void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
});
}

template <typename T, typename T2>
static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
                                     const ov::intel_cpu::PlainTensor& v_input,
                                     const ov::intel_cpu::PlainTensor& past_k_output,
                                     const ov::intel_cpu::PlainTensor& past_v_output,
                                     const ov::intel_cpu::PlainTensor& slot_mapping) {
    // Scatter-copies new K/V data into the paged KV-cache with element-type
    // conversion (T -> T2) done per element via attn_copy.
    // Input layout (from m_dims usage): [batch, head, token, head_size];
    // slot_mapping[batch][token] selects the destination cache block index.
    const size_t batches = k_input.m_dims[0];
    const size_t heads = k_input.m_dims[1];
    const size_t tokens = k_input.m_dims[2];
    const size_t head_size = k_input.m_dims[3];
    parallel_for3d(batches, heads, tokens, [&](size_t batch, size_t head, size_t tok) {
        const int32_t slot = slot_mapping.ptr<int32_t>(batch)[tok];
        // A negative slot means this position has no cache slot assigned
        // (presumably padding) — nothing to write. TODO(review): confirm.
        if (slot < 0)
            return;
        attn_copy(past_k_output.ptr<T2>(slot, head, 0), k_input.ptr<T>(batch, head, tok, 0), head_size);
        attn_copy(past_v_output.ptr<T2>(slot, head, 0), v_input.ptr<T>(batch, head, tok, 0), head_size);
    });
}

// Same-precision variant: scatter-copies new K/V data into the paged KV-cache
// with a raw byte copy (no element conversion). Selected by overload
// resolution when no template arguments are supplied.
static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
                                     const ov::intel_cpu::PlainTensor& v_input,
                                     const ov::intel_cpu::PlainTensor& past_k_output,
                                     const ov::intel_cpu::PlainTensor& past_v_output,
                                     const ov::intel_cpu::PlainTensor& slot_mapping) {
    const size_t batches = k_input.m_dims[0];
    const size_t heads = k_input.m_dims[1];
    const size_t tokens = k_input.m_dims[2];
    const size_t head_size = k_input.m_dims[3];
    parallel_for3d(batches, heads, tokens, [&](size_t batch, size_t head, size_t tok) {
        const int32_t slot = slot_mapping.ptr<int32_t>(batch)[tok];
        // Negative slot: no cache slot assigned for this position — skip.
        if (slot < 0)
            return;
        const size_t row_bytes_k = head_size * k_input.m_element_size;
        const size_t row_bytes_v = head_size * v_input.m_element_size;
        std::memcpy(past_k_output.ptr_v(slot, head, 0), k_input.ptr_v(batch, head, tok, 0), row_bytes_k);
        std::memcpy(past_v_output.ptr_v(slot, head, 0), v_input.ptr_v(batch, head, tok, 0), row_bytes_v);
    });
}

void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& v_input,
const ov::intel_cpu::PlainTensor& past_k_output,
Expand All @@ -90,6 +127,23 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in attn_memcpy");
}
}

// Copies freshly computed K/V tensors into the paged KV-cache blocks selected
// by slot_mapping, dispatching on src/dst precision:
//   - equal precisions      -> raw byte-copy kernel
//   - f32 -> f16 / bf16     -> converting kernel
// Throws for any other precision combination.
void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                       const ov::intel_cpu::PlainTensor& v_input,
                       const ov::intel_cpu::PlainTensor& past_k_output,
                       const ov::intel_cpu::PlainTensor& past_v_output,
                       const ov::intel_cpu::PlainTensor& slot_mapping) {
    if (past_k_output.get_precision() == k_input.get_precision()) {
        // Same element type on both sides: no conversion, plain memcpy kernel.
        paged_attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output, slot_mapping);
    } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::f16) {
        paged_attn_memcpy_kernel<float, ov::float16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
    } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::bf16) {
        paged_attn_memcpy_kernel<float, ov::bfloat16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
    } else {
        // Fixed typo in the diagnostic: "unsupport" -> "unsupported".
        OPENVINO_THROW("unsupported src type: ",
                       k_input.get_precision(),
                       ", dst type: ",
                       past_k_output.get_precision(),
                       " in paged_attn_memcpy");
    }
}

} // namespace XARCH
} // namespace Cpu
} // namespace Extensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
const ov::intel_cpu::PlainTensor& past_k_output,
const ov::intel_cpu::PlainTensor& past_v_output);

// Writes new K/V tokens into the paged KV-cache blocks addressed by
// slot_mapping (int32 block index per batch/token; negative entries are
// skipped by the implementation). Converts f32 inputs to an f16/bf16 cache
// when precisions differ; otherwise performs a raw byte copy.
void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                       const ov::intel_cpu::PlainTensor& v_input,
                       const ov::intel_cpu::PlainTensor& past_k_output,
                       const ov::intel_cpu::PlainTensor& past_v_output,
                       const ov::intel_cpu::PlainTensor& slot_mapping);

} // namespace XARCH
} // namespace Cpu
} // namespace Extensions
Expand Down
Loading

0 comments on commit 0b260ff

Please sign in to comment.