diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index a5ca0425c50bc7..273f81a07ad9c1 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -1111,6 +1111,7 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {
         presentv_input = inputs[ID_VCACHE];
     } else {
         if (m_config.config.fuse_concat) {
+            CPU_NODE_ASSERT(m_k_state && m_v_state, "has null input states");
             // initialization will be also completed in this func
             gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber));
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp
index 1f0af5c7b22f30..fc43aeaccf5fab 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp
@@ -108,8 +108,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
     auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
         auto present_to = out.get_target_inputs();
-        if (present_to.size() < 2)
-            return false;
         for (auto& to : present_to) {
             auto to_node = to.get_node();
             if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
@@ -149,6 +147,28 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
         const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
         const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
+        for (auto&& item : {concat_k_node, concat_v_node}) {
+            auto&& children = item->get_output_target_inputs(0);
+            switch (children.size()) {
+            case 2:
+                // pass, as the existence of Assign will be checked later
+                break;
+            case 3:
+                // the first one leads to SDPA, otherwise the matcher doesn't find the pattern
+                // the second one leads to Assign, and this is checked later
+                // the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf
+                if (!std::any_of(children.begin(), children.end(), [](const ov::Input<ov::Node>& child) {
+                        return ov::is_type<ov::op::v0::ShapeOf>(child.get_node()) ||
+                               ov::is_type<ov::op::v3::ShapeOf>(child.get_node());
+                    })) {
+                    return false;
+                }
+                break;
+            default:
+                return false;
+            }
+        }
+
         opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
         opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
         if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))