From 27c7056f586eb70c3c55034b23d84cae36fffbcb Mon Sep 17 00:00:00 2001 From: Maksim Kutakov Date: Wed, 20 Mar 2024 15:50:05 +0100 Subject: [PATCH 1/3] [CPU] Fix SDPA pattern matching --- .../intel_cpu/src/nodes/scaled_attn.cpp | 1 + .../common/pass/stateful_sdpa_fusion.cpp | 24 +++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index a5ca0425c50bc7..273f81a07ad9c1 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -1111,6 +1111,7 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) { presentv_input = inputs[ID_VCACHE]; } else { if (m_config.config.fuse_concat) { + CPU_NODE_ASSERT(m_k_state && m_v_state, "has null input states"); // initialization will be also completed in this func gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber)); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index 1f0af5c7b22f30..6f8ca7d3b6204f 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -108,8 +108,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() { auto find_assign = [&](const ov::Output& out, opset6::Assign*& assign, opset1::Convert*& cvt) { auto present_to = out.get_target_inputs(); - if (present_to.size() < 2) - return false; for (auto& to : present_to) { auto to_node = to.get_node(); if (auto convert = dynamic_cast(to_node)) { @@ -149,6 +147,28 @@ StatefulSDPAFusion::StatefulSDPAFusion() { const auto concat_k_node = ov::as_type_ptr(pattern_map.at(concat_k).get_node_shared_ptr()); const auto concat_v_node = ov::as_type_ptr(pattern_map.at(concat_v).get_node_shared_ptr()); + for (auto&& item : {concat_k_node, concat_v_node}) { + auto&& children = item->get_output_target_inputs(0); + switch (children.size()) { + case 2: + // pass, as the existence of Assign will be checked later + break; + case 3: + // the fist one leads to SDPA, otherwise the matcher don't find the pattern + // the second one leads to Assign, and this is checked later + // the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf + if (!std::any_of(children.begin(), children.end(), [](const ov::Input& child) { + return ov::is_type(child.get_node()) || + ov::is_type(child.get_node()); + })) { + return false; + } + break; + default: + return false; + } + } + opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr; opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr; if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node)) From 1175121e458e24a85f2ae23ff2cdda412313bd3e Mon Sep 17 00:00:00 2001 From: Maksim Kutakov Date: Wed, 20 Mar 2024 19:05:04 +0100 Subject: [PATCH 2/3] Minor fix --- .../cpu_opset/common/pass/stateful_sdpa_fusion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index 6f8ca7d3b6204f..c66d500e39e28b 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -154,7 +154,7 @@ StatefulSDPAFusion::StatefulSDPAFusion() { // pass, as the existence of Assign will be checked later break; case 3: - // the fist one leads to SDPA, otherwise the matcher don't find the pattern + // the fist one leads to SDPA, otherwise the matcher doesn't find the pattern // the second one leads to Assign, and this is checked later // the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf if (!std::any_of(children.begin(), children.end(), [](const ov::Input& child) { From 2edd72f93355883ac0c0b8402b8ca137b7d12b8c Mon Sep 17 00:00:00 2001 From: Maksim Kutakov Date: Thu, 21 Mar 2024 12:11:03 +0100 Subject: [PATCH 3/3] Fix typo --- .../cpu_opset/common/pass/stateful_sdpa_fusion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index c66d500e39e28b..fc43aeaccf5fab 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -154,7 +154,7 @@ StatefulSDPAFusion::StatefulSDPAFusion() { // pass, as the existence of Assign will be checked later break; case 3: - // the fist one leads to SDPA, otherwise the matcher doesn't find the pattern + // the first one leads to SDPA, otherwise the matcher doesn't find the pattern // the second one leads to Assign, and this is checked later // the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf if (!std::any_of(children.begin(), children.end(), [](const ov::Input& child) {