From a34ce8bc61a90b852583879f98bfa50cb5ef2369 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 11 Dec 2024 15:17:48 +0800 Subject: [PATCH] [CPU]add cache precision check Signed-off-by: Zhang Yi3 --- src/plugins/intel_cpu/src/nodes/paged_attn.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 468fc72b28296a..b78510d6b8934f 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -196,6 +196,14 @@ void PagedAttention::execute(dnnl::stream strm) { bool PagedAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { + auto vCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_VCACHE); + auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE); + if (one_of(vCachePrecision, ov::element::i4, ov::element::u4, ov::element::u8)) { + if (kCachePrecision != ov::element::u8) { + errorMessage = "PageAttn key value cache compression doesn't support key cache prec " + kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string(); + return false; + } + } int orgInput = static_cast(op->get_input_size()); if (op->get_type_name() == std::string("PagedAttentionExtension") && orgInput == PagedAttentionExecutor::ID_SLIDING_WINDOW + 1) { return true;