[CPU] Add PagedAttention support #23524

Merged
src/common/transformations/src/transformations/convert_precision.cpp

@@ -607,7 +607,11 @@ bool fuse_type_to_parameter(const std::shared_ptr<ov::Node>& node,
     auto convert = std::make_shared<opset4::Convert>(param, to);
     for (auto& input : param_consumers) {
         const auto consumer = input.get_node();
-        if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer)) {
+        if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer) ||
+            // TODO: refactor after ngraph op defined
+            // The fourth and fifth inputs are kvcache and should be directly connected to parameters
+            (consumer->get_type_name() == std::string("PagedAttentionExtension") &&
+             (input.get_index() == 3 || input.get_index() == 4))) {
            continue;
        }
        input.replace_source_output(convert);
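PagedAttentionExtension's fourth and fifth inputs (ports 3 and 4) are the key/value cache tensors, which the op updates in place. Interposing a Convert here would make the op read and write a converted copy rather than the cache's backing memory, so these inputs are left wired directly to their Parameters, exactly like Result and Convert consumers.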
src/plugins/intel_cpu/CMakeLists.txt (1 addition & 1 deletion)

@@ -176,7 +176,7 @@ cross_compiled_file(${TARGET_NAME}
     ARCH AVX512F AVX2 ANY
         src/nodes/kernels/scaled_attn/attn_memcpy.cpp
     API src/nodes/kernels/scaled_attn/attn_memcpy.hpp
-    NAME attn_memcpy
+    NAME attn_memcpy paged_attn_memcpy
     NAMESPACE ov::Extensions::Cpu::XARCH
 )
 cross_compiled_file(${TARGET_NAME}
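cross_compiled_file compiles attn_memcpy.cpp once per entry in ARCH (AVX512F, AVX2, and a generic fallback) and emits a runtime dispatcher for every symbol listed under NAME, so adding paged_attn_memcpy to NAME is what routes the new kernel to the best ISA variant available on the host.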
src/plugins/intel_cpu/src/cpu_types.cpp (1 addition)

@@ -217,6 +217,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
         {"Ngram", Type::Ngram},
         {"ScaledDotProductAttention", Type::ScaledDotProductAttention},
         {"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
+        {"PagedAttentionExtension", Type::ScaledDotProductAttention},
         {"RoPE", Type::RoPE},
     };
     return type_to_name_tbl;
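Registering PagedAttentionExtension under Type::ScaledDotProductAttention lets the CPU plugin reuse its existing SDPA node for paged attention; the two flavors are then distinguished by input count (PagedAttentionExtension carries 13 inputs, as checked in graph.cpp below).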
src/plugins/intel_cpu/src/graph.cpp (4 additions)

@@ -1680,6 +1680,10 @@ void Graph::EnforceInferencePrecision() {
         if (node->getOriginalInputPrecisionAtPort(inPort) != ov::element::f32)
             return true;

+        // kvcache of PagedAttention should be written directly
+        if (node->getType() == Type::ScaledDotProductAttention && node->getOriginalInputsNumber() == 13 &&
+            (inPort == 3 || inPort == 4))
+            return true;
         const auto &parent = node->getParentEdgeAt(inPort)->getParent();
         /* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
          * Element type conversion to bf16 is done automatically, if convolution follows up after Constant Inputs
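Both passes above exempt the same two ports. As a standalone sketch, the shared predicate looks roughly like this (hypothetical helper and constant names; in the PR the checks are written inline in convert_precision.cpp and graph.cpp):

#include <cstddef>

// Hypothetical names for illustration; the PR hard-codes the indices inline.
constexpr std::size_t kKeyCachePort = 3;    // 4th input: key cache
constexpr std::size_t kValueCachePort = 4;  // 5th input: value cache

// True when this input of a 13-input (PagedAttention) SDPA node is a KV-cache
// tensor that must keep its precision and stay connected to its Parameter.
static bool is_paged_attn_kv_cache_input(std::size_t in_port, std::size_t num_inputs) {
    return num_inputs == 13 && (in_port == kKeyCachePort || in_port == kValueCachePort);
}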
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp

@@ -76,6 +76,43 @@ static void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
     });
 }

+template <typename T, typename T2>
+static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
+                                     const ov::intel_cpu::PlainTensor& v_input,
+                                     const ov::intel_cpu::PlainTensor& past_k_output,
+                                     const ov::intel_cpu::PlainTensor& past_v_output,
+                                     const ov::intel_cpu::PlainTensor& slot_mapping) {
+    size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
+    parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+        auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
+        if (block_idx < 0)
+            return;
+        attn_copy(past_k_output.ptr<T2>(block_idx, h, 0),
+                  k_input.ptr<T>(b, h, m, 0),
+                  S);
+        attn_copy(past_v_output.ptr<T2>(block_idx, h, 0),
+                  v_input.ptr<T>(b, h, m, 0),
+                  S);
+    });
+}
+
+static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
+                                     const ov::intel_cpu::PlainTensor& v_input,
+                                     const ov::intel_cpu::PlainTensor& past_k_output,
+                                     const ov::intel_cpu::PlainTensor& past_v_output,
+                                     const ov::intel_cpu::PlainTensor& slot_mapping) {
+    size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
+    parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+        auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
+        if (block_idx < 0)
+            return;
+        std::memcpy(past_k_output.ptr_v(block_idx, h, 0),
+                    k_input.ptr_v(b, h, m, 0),
+                    S * k_input.m_element_size);
+        std::memcpy(past_v_output.ptr_v(block_idx, h, 0),
+                    v_input.ptr_v(b, h, m, 0),
+                    S * v_input.m_element_size);
+    });
+}
+
 void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                  const ov::intel_cpu::PlainTensor& v_input,
                  const ov::intel_cpu::PlainTensor& past_k_output,
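Both kernels above scatter each token's new K/V vectors into the paged cache through slot_mapping: entry [b][m] holds the physical slot for token m of sequence b, and a negative slot means the token has no cache entry (e.g. padding). A self-contained model of the same scatter over plain buffers instead of PlainTensor (all names illustrative):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// keys:         [B, H, L, S] new key (or value) states for this step
// cache:        [num_slots, H, S] paged cache storage
// slot_mapping: [B, L] physical slot per token; negative means skip
static void scatter_to_cache(const std::vector<float>& keys, std::vector<float>& cache,
                             const std::vector<int32_t>& slot_mapping,
                             std::size_t B, std::size_t H, std::size_t L, std::size_t S) {
    for (std::size_t b = 0; b < B; ++b)
        for (std::size_t h = 0; h < H; ++h)
            for (std::size_t m = 0; m < L; ++m) {
                int32_t slot = slot_mapping[b * L + m];
                if (slot < 0)
                    continue;  // token has no cache slot
                std::memcpy(&cache[(static_cast<std::size_t>(slot) * H + h) * S],
                            &keys[((b * H + h) * L + m) * S],
                            S * sizeof(float));
            }
}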
@@ -90,6 +127,23 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
         OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in attn_memcpy");
     }
 }

+void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
+                       const ov::intel_cpu::PlainTensor& v_input,
+                       const ov::intel_cpu::PlainTensor& past_k_output,
+                       const ov::intel_cpu::PlainTensor& past_v_output,
+                       const ov::intel_cpu::PlainTensor& slot_mapping) {
+    if (past_k_output.get_precision() == k_input.get_precision()) {
+        paged_attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output, slot_mapping);
+    } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::f16) {
+        paged_attn_memcpy_kernel<float, ov::float16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
+    } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::bf16) {
+        paged_attn_memcpy_kernel<float, ov::bfloat16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
+    } else {
+        OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in paged_attn_memcpy");
+    }
+}
+
 } // namespace XARCH
 } // namespace Cpu
 } // namespace Extensions
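paged_attn_memcpy dispatches on precision: same-type pairs take the raw-memcpy overload, while f32 -> f16 and f32 -> bf16 take the templated kernel, whose attn_copy converts element-wise (vectorized per the ISA list in CMakeLists.txt). A scalar model of such a converting copy, assuming that contract (not the actual kernel):

#include <cstddef>

// Read TSrc, convert, store TDst (e.g. float -> bfloat16 via TDst's
// converting constructor); the real attn_copy does this with SIMD.
template <typename TDst, typename TSrc>
static void convert_copy(TDst* dst, const TSrc* src, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        dst[i] = static_cast<TDst>(src[i]);
}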
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp

@@ -20,6 +20,12 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                  const ov::intel_cpu::PlainTensor& past_k_output,
                  const ov::intel_cpu::PlainTensor& past_v_output);

+void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
+                       const ov::intel_cpu::PlainTensor& v_input,
+                       const ov::intel_cpu::PlainTensor& past_k_output,
+                       const ov::intel_cpu::PlainTensor& past_v_output,
+                       const ov::intel_cpu::PlainTensor& slot_mapping);
+
 } // namespace XARCH
 } // namespace Cpu
 } // namespace Extensions
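Shape contract implied by the kernels: k_input/v_input are [B, H, L, S] (batch, heads, tokens, head size), past_k_output/past_v_output are paged caches indexed [slot, H, S], and slot_mapping is a [B, L] int32 tensor of physical slots where negative values mean "skip this token". The inputs may stay f32 while the cache is f32, f16, or bf16; any other combination throws.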