[TRANSFORMATIONS][GPU] SDPA Fusion passes #28042

Open · wants to merge 1 commit into base: master
transformations/common_optimizations/sdpa_fusion.hpp
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// This pass transforms the following sub-graph into a single ScaledDotProductAttention operation.
/// Before:
/// ┌───────┐     ┌───────┐    ┌───────┐
/// │   Q   │     │   K   │    │   V   │
/// └───┬───┘     └───┬───┘    └───┬───┘
///     │             │            │
///     │             │            │
/// ┌───┴───┐   ┌─────┴──────┐     │
/// │ MatMul│<──│ Transpose  │     │
/// └───┬───┘   │ (Optional) │     │
///     │       └────────────┘     │
/// ┌───┴───┐   ┌─────────────┐    │
/// │  Add  │<──│AttentionMask│    │
/// └───┬───┘   │ (Optional)  │    │
///     │       └─────────────┘    │
/// ┌───┴───┐                      │
/// │Softmax│                      │
/// └───┬───┘                      │
///     │                          │
/// ┌───┴───┐                      │
/// │ MatMul│<─────────────────────┘
/// └───┬───┘
/// ┌───┴───┐
/// │ Output│
/// └───────┘
///
/// After:
/// ┌───────┐    ┌───────┐    ┌───────┐    ┌─────────────┐
/// │   Q   │    │   K   │    │   V   │    │AttentionMask│
/// └───┬───┘    └───┬───┘    └───┬───┘    └──────┬──────┘
///     │            │            │               │
///     │            │            │               │
/// ┌───┴────────────┴────────────┴───────────────┴─┐
/// │           ScaledDotProductAttention            │
/// └────────────────────┬───────────────────────────┘
///                      │
///                      │
///                 ┌────┴────┐
///                 │ Output  │
///                 └─────────┘
class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SDPAFusion", "0");
SDPAFusion();
};

} // namespace pass
} // namespace ov
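
For reference, a minimal sketch (not part of this diff) of applying the pass through the standard pass manager; the model object is assumed to be built elsewhere:

#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"

void run_sdpa_fusion(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAFusion>();
    manager.run_passes(model);
}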
transformations/common_optimizations/sdpa_scale_fusion.hpp
@@ -0,0 +1,58 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// Merges explicit multiplications of Q and K by a scalar value into the scale input of the SDPA op.
/// Before:
/// ┌───────┐    ┌───────┐    ┌───────┐  ┌─────────────┐     ┌─────────────┐
/// │   Q   │    │   K   │    │   V   │  │AttentionMask│     │    Scale    │
/// └───┬───┘    └───┬───┘    └───┬───┘  │ (Optional)  │     │ (Optional)  │
///     │            │            │      └──────┬──────┘     └───────┬─────┘
///     │            │            │             │                    │
/// ┌───┴───┐    ┌───┴───┐        │             │                    │
/// │  Mul  │    │  Mul  │        │             │                    │
/// └───┬───┘    └───┬───┘        │             │                    │
///     │            │            │             │                    │
///     │            │            │             │                    │
/// ┌───┴────────────┴────────────┴─────────────┴─┐                  │
/// │          ScaledDotProductAttention          │──────────────────┘
/// └────────────────────┬────────────────────────┘
///                      │
///                      │
///                 ┌────┴────┐
///                 │ Output  │
///                 └─────────┘
/// After:
/// ┌───────┐    ┌───────┐    ┌───────┐  ┌─────────────┐  ┌───────┐
/// │   Q   │    │   K   │    │   V   │  │AttentionMask│  │ Scale │
/// └───┬───┘    └───┬───┘    └───┬───┘  └──────┬──────┘  └───┬───┘
///     │            │            │             │             │
///     │            │            │             │             │
///     │            │            │             │             │
/// ┌───┴────────────┴────────────┴─────────────┴─┐           │
/// │          ScaledDotProductAttention          │───────────┘
/// └────────────────────┬────────────────────────┘
///                      │
///                      │
///                 ┌────┴────┐
///                 │ Output  │
///                 └─────────┘
/// Multiply ops for Q and K are eliminated in the following cases:
/// 1. Q_scale and K_scale are constants
/// 2. Q_scale * SDPA_Scale == 1 or K_scale * SDPA_Scale == 1 (the non-constant
///    scale factor is then reused as the SDPA scale input)
class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SDPAScaleFusion", "0");
SDPAScaleFusion();
};

} // namespace pass
} // namespace ov
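
To make condition 2 concrete, a small self-contained sketch (not part of this diff) of the scale arithmetic the pass performs; the numbers are illustrative:

#include <cmath>
#include <cstdio>

int main() {
    // No explicit scale input on SDPA: the implicit scale is 1/sqrt(head_size).
    const float head_size = 64.0f;
    const float prev_scale = 1.0f / std::sqrt(head_size);    // 0.125
    const float q_scale = 2.0f;                              // constant Multiply on Q
    const float k_scale = 4.0f;                              // constant Multiply on K
    // The pass folds all factors into a single scale value.
    const float new_scale = prev_scale * q_scale * k_scale;  // 0.125 * 2 * 4 = 1.0
    std::printf("fused SDPA scale = %g\n", new_scale);
    // Both Multiply ops are removed and SDPA receives the fused scalar scale.
    return 0;
}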
transformations/common_optimizations/moc_transformations.cpp
@@ -65,6 +65,7 @@
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"
#include "transformations/common_optimizations/ric_fusion.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
#include "transformations/common_optimizations/sequence_fusion.hpp"
#include "transformations/common_optimizations/shared_ops_optimization.hpp"
@@ -229,6 +230,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(common_fusions, ConvertTensorIteratorToSequence)
ADD_MATCHER(common_fusions, SplitConcatPairToInterpolateFusion, m_use_shapes)
ADD_MATCHER(common_fusions, ConvolutionToGroupConvolutionFusion)
ADD_MATCHER(common_fusions, SDPAFusion)
if (m_use_shapes) {
ADD_MATCHER(common_fusions, NearestNeighborUpsamplingFusion)
}
transformations/common_optimizations/sdpa_fusion.cpp
@@ -0,0 +1,120 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAFusion::SDPAFusion() {
using namespace ov::pass::pattern;
using namespace ov::gen_pattern;

auto q = makePattern(ov::Rank(4));
auto k = makePattern(ov::Rank(4));
auto v = makePattern(ov::Rank(4));
auto mask = makePattern();

auto k_t = makePattern<ov::op::v1::Transpose>({k, {0, 1, 3, 2}});
auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}});
auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}});
auto qk = qk_nt | qk_nn;
auto optional_add_mask = optional<ov::op::v1::Add>({qk, mask});
auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}});
auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}});

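// The head-size dimension of Q (last axis) must be static and equal to the
// matching dimension of K, whose position depends on whether K arrives
// pre-transposed (transpose_b) or via an explicit Transpose.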
auto valid_qk_shapes = [](const std::shared_ptr<ov::op::v0::MatMul>& qk_matmul) {
auto q_pshape = qk_matmul->get_input_partial_shape(0);
auto k_pshape = qk_matmul->get_input_partial_shape(1);

const size_t q_head_size_idx = 3;
const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2;

return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() &&
k_pshape[k_head_size_idx].is_static() &&
q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length();
};

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
if (transformation_callback(m.get_match_root())) {
return false;
}

auto q_node = pattern_map.at(q);
auto k_node = pattern_map.at(k);
auto v_node = pattern_map.at(v);

if (!valid_qk_shapes(ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(qk).get_node_shared_ptr()))) {
return false;
}

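// The Q*K^T and Softmax intermediates must have exactly one consumer each;
// otherwise removing the subgraph would cut off their other users.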
if (pattern_map.at(qk).get_target_inputs().size() > 1 ||
pattern_map.at(softmax).get_target_inputs().size() > 1) {
return false;
}
if (pattern_map.count(optional_add_mask) && (pattern_map.at(optional_add_mask).get_target_inputs().size() > 1 ||
                                             pattern_map.at(mask).get_partial_shape().size() > 4)) {
return false;
}

Contributor commented on the line above:
Why do we need this restriction on the number of mask consumers?

Contributor (author) replied:
That's not mask consumers, but Add op consumers.

Output<ov::Node> mask_value;
Output<ov::Node> mask_input;
if (pattern_map.find(optional_add_mask) != pattern_map.end()) {
mask_value = pattern_map.at(mask);
} else {
mask_value = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{0});
}

if (mask_value.get_partial_shape().size() > 4) {
return false;
}

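// Rank-0 and rank-4 masks broadcast against the Q*K^T scores as-is; lower-rank
// masks get leading axes unsqueezed until the rank matches Q's rank of 4.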
if (mask_value.get_partial_shape().rank() == 0 || mask_value.get_partial_shape().rank() == 4) {
mask_input = mask_value;
} else {
size_t rank_diff = q_node.get_partial_shape().size() - mask_value.get_partial_shape().size();
std::vector<int64_t> axes(rank_diff);
std::iota(axes.begin(), axes.end(), 0);
mask_input = std::make_shared<ov::op::v0::Unsqueeze>(
mask_value,
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank_diff}, axes));
}

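// The matched subgraph applies no scaling of its own, so pass an explicit scale
// of 1; folding real Q/K scale factors is handled separately by SDPAScaleFusion.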
std::shared_ptr<ov::Node> scale_node =
ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{1.0f});

std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_node,
k_node,
v_node,
mask_input,
scale_node,
false);

sdpa->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
ov::replace_node(m.get_match_root(), sdpa);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(qkv, "SDPAFusion");
this->register_matcher(m, callback);
}

} // namespace pass
} // namespace ov
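
A minimal sketch (not part of this diff) of a subgraph this pass is meant to collapse; shapes and names are illustrative:

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/softmax.hpp"

std::shared_ptr<ov::Model> make_attention_subgraph() {
    auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 8, 128, 64});
    auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 8, 128, 64});
    auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 8, 128, 64});
    // Q x K^T via MatMul(transpose_b = true): the qk_nt branch of the pattern.
    auto qk = std::make_shared<ov::op::v0::MatMul>(q, k, false, true);
    auto softmax = std::make_shared<ov::op::v8::Softmax>(qk, -1);
    auto qkv = std::make_shared<ov::op::v0::MatMul>(softmax, v, false, false);
    return std::make_shared<ov::Model>(ov::OutputVector{qkv}, ov::ParameterVector{q, k, v});
}

After SDPAFusion runs, the three Parameters should feed a single ScaledDotProductAttention op with a zero mask and a scale of 1.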
transformations/common_optimizations/sdpa_scale_fusion.cpp
@@ -0,0 +1,140 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_scale_fusion.hpp"

#include <cmath>  // std::sqrt
#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAScaleFusion::SDPAScaleFusion() {
using namespace ov::pass::pattern;
using namespace ov::gen_pattern;

auto q = makePattern(ov::Rank(4));
auto k = makePattern(ov::Rank(4));
auto v = makePattern(ov::Rank(4));
auto mask = makePattern();
auto sdpa_scale = makeConst({});
auto scale_q = makePattern("[]") | makePattern("[1]");
auto scale_k = makePattern("[]") | makePattern("[1]");
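// gen_pattern shape strings: "[]" matches a rank-0 scalar constant, "[1]" a one-element tensor.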

auto scaled_q = optional<ov::op::v1::Multiply>({q, scale_q});
auto scaled_k = optional<ov::op::v1::Multiply>({k, scale_k});
auto sdpa_mask_scale =
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask, sdpa_scale},
{{"causal", false}});
auto sdpa_mask =
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask}, {{"causal", false}});
auto sdpa_simple =
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v}, {{"causal", false}});
auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale;

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
if (transformation_callback(m.get_match_root())) {
return false;
}

auto sdpa = m.get_match_root();

const bool has_q_scale = pattern_map.count(scaled_q);
const bool has_k_scale = pattern_map.count(scaled_k);

// Nothing to do
if (!has_q_scale && !has_k_scale)
return false;

auto prev_scale_value = 1.0f;
auto scale_q_value = 1.0f;
auto scale_k_value = 1.0f;
auto scale_et = sdpa->get_output_element_type(0);

Output<ov::Node> q_input = sdpa->get_input_source_output(0);
Output<ov::Node> k_input = sdpa->get_input_source_output(1);

std::shared_ptr<ov::Node> scale_q_node = nullptr;
std::shared_ptr<ov::Node> scale_k_node = nullptr;

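// Current effective scale: the explicit scale input when present; otherwise the
// implicit default of 1/sqrt(head_size), which requires a static head size.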
if (pattern_map.find(sdpa_scale) != pattern_map.end()) {
auto prev_scale_node =
ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(sdpa_scale).get_node_shared_ptr());
prev_scale_value = prev_scale_node->cast_vector<float>()[0];
scale_et = prev_scale_node->get_output_element_type(0);
} else {
auto head_size = q_input.get_partial_shape()[3];
if (head_size.is_dynamic())
return false;

prev_scale_value = 1.0f / std::sqrt(static_cast<float>(head_size.get_length()));
}

// Extract scalar scale values for Q and K if those are constant and set new inputs for SDPA
if (has_q_scale) {
scale_q_node = pattern_map.at(scale_q).get_node_shared_ptr();
if (ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
scale_q_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_q_node)->cast_vector<float>()[0];
q_input = pattern_map.at(q);
}
}
if (has_k_scale) {
scale_k_node = pattern_map.at(scale_k).get_node_shared_ptr();
if (ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
scale_k_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_k_node)->cast_vector<float>()[0];
k_input = pattern_map.at(k);
}
}

Output<ov::Node> new_scale_node;
auto new_scale_val = prev_scale_value * scale_q_value * scale_k_value;

// If the combined scale is 1 and there is a non-constant scale node for either Q or K, reuse that node as the SDPA scale input
if (new_scale_val == 1.0f) {
if (has_q_scale && !ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
new_scale_node = pattern_map.at(scale_q);
q_input = pattern_map.at(q);
} else if (has_k_scale && !ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
new_scale_node = pattern_map.at(scale_k);
k_input = pattern_map.at(k);
} else {
new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val});
}
} else {
new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val});
}

OutputVector new_inputs = {q_input, k_input, pattern_map.at(v)};
if (pattern_map.find(mask) != pattern_map.end()) {
new_inputs.push_back(pattern_map.at(mask));
} else {
new_inputs.push_back(
ov::op::v0::Constant::create(new_scale_node.get_element_type(), ov::Shape{}, std::vector<float>{0.0f}));
}

new_inputs.push_back(new_scale_node);

auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs);
new_sdpa->set_friendly_name(sdpa->get_friendly_name());
ov::copy_runtime_info(sdpa, new_sdpa);
ov::replace_node(sdpa, new_sdpa);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa, "SDPAScaleFusion");
this->register_matcher(m, callback);
}

} // namespace pass
} // namespace ov
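
And a sketch of running both new passes in order (assuming a model that contains a decomposed attention block with constant scale multiplies on Q or K):

#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"
#include "transformations/common_optimizations/sdpa_scale_fusion.hpp"

void fuse_attention(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // First collapse MatMul -> Softmax -> MatMul into ScaledDotProductAttention,
    manager.register_pass<ov::pass::SDPAFusion>();
    // then fold the explicit Q/K Multiply ops into its scale input.
    manager.register_pass<ov::pass::SDPAScaleFusion>();
    manager.run_passes(model);
}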