Skip to content

Commit

Permalink
[CPU] Fix SDPA pattern matching (#23581)
Browse files Browse the repository at this point in the history
### Details:
Limit the Concat layer to have maximum 3 children. The third one is
allowed to be a ShapeOf op only (to support Mixtral).

### Tickets:
 - 135375
  • Loading branch information
maxnick authored Mar 22, 2024
1 parent 0902fe4 commit 908dac9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,7 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {
presentv_input = inputs[ID_VCACHE];
} else {
if (m_config.config.fuse_concat) {
CPU_NODE_ASSERT(m_k_state && m_v_state, "has null input states");
// initialization will be also completed in this func
gatherConcatPastkv(inputs[1], inputs[2], getSrcMemoryAtPort(orginSDPInputNumber));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {

auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
auto present_to = out.get_target_inputs();
if (present_to.size() < 2)
return false;
for (auto& to : present_to) {
auto to_node = to.get_node();
if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
Expand Down Expand Up @@ -149,6 +147,28 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());

for (auto&& item : {concat_k_node, concat_v_node}) {
auto&& children = item->get_output_target_inputs(0);
switch (children.size()) {
case 2:
// pass, as the existence of Assign will be checked later
break;
case 3:
// the first one leads to SDPA, otherwise the matcher doesn't find the pattern
// the second one leads to Assign, and this is checked later
// the third child is allowed to be a ShapeOf op only, thus one of them must be ShapeOf
if (!std::any_of(children.begin(), children.end(), [](const ov::Input<ov::Node>& child) {
return ov::is_type<ov::op::v3::ShapeOf>(child.get_node()) ||
ov::is_type<ov::op::v0::ShapeOf>(child.get_node());
})) {
return false;
}
break;
default:
return false;
}
}

opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))
Expand Down

0 comments on commit 908dac9

Please sign in to comment.