From b87a515726773977e601e8a940002c0e0b4e41cc Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 21 Mar 2024 01:28:54 +0100 Subject: [PATCH] fix f16 path --- .../src/transformations/convert_precision.cpp | 6 +++++- .../src/nodes/kernels/scaled_attn/attn_memcpy.cpp | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/common/transformations/src/transformations/convert_precision.cpp b/src/common/transformations/src/transformations/convert_precision.cpp index d5c8204663c242..b29ab74981a483 100644 --- a/src/common/transformations/src/transformations/convert_precision.cpp +++ b/src/common/transformations/src/transformations/convert_precision.cpp @@ -607,7 +607,11 @@ bool fuse_type_to_parameter(const std::shared_ptr& node, auto convert = std::make_shared(param, to); for (auto& input : param_consumers) { const auto consumer = input.get_node(); - if (ov::is_type(consumer) || ov::is_type(consumer)) { + if (ov::is_type(consumer) || ov::is_type(consumer) || + // TODO: refactor after ngraph op defined + // The fourth and fifth inputs are kvcache and should be directly connected to parameters + (consumer->get_type_name() == std::string("PagedAttentionExtension") && + (input.get_index() == 3 || input.get_index() == 4))) { continue; } input.replace_source_output(convert); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp index 0361a8cbe1e0f2..c170464eeb47ee 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp @@ -86,11 +86,11 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) { auto block_idx = slot_mapping.ptr(b)[m]; if (block_idx < 0) return; - attn_copy(past_k_output.ptr(block_idx, h, m, 0), - k_input.ptr(b, h, 0), + attn_copy(past_k_output.ptr(block_idx, h, 0), + k_input.ptr(b, h, m, 0), S); - attn_copy(past_v_output.ptr(block_idx, h, m, 0), - v_input.ptr(b, h, 0), + attn_copy(past_v_output.ptr(block_idx, h, 0), + v_input.ptr(b, h, m, 0), S); }); }