[TRANSFORMATIONS][GPU] SDPA Fusion passes #28042
Status: Open
vladimir-paramuzov wants to merge 1 commit into openvinotoolkit:master from vladimir-paramuzov:sdpa_fusion
+822 −0
src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp (60 additions, 0 deletions)
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// This pass transforms the following sub-graph to a single Scaled Dot Product Attention operation.
/// Before:
///  ┌───────┐     ┌───────┐    ┌───────┐
///  │   Q   │     │   K   │    │   V   │
///  └───┬───┘     └───┬───┘    └───┬───┘
///      │             │            │
///      │             │            │
///  ┌───┴───┐   ┌─────┴─────┐      │
///  │ MatMul│<──│ Transpose │      │
///  └───┬───┘   │(Optional) │      │
///      │       └───────────┘      │
///  ┌───┴───┐   ┌─────────────┐    │
///  │  Add  │<──│AttentionMask│    │
///  └───┬───┘   │ (Optional)  │    │
///      │       └─────────────┘    │
///  ┌───┴───┐                      │
///  │Softmax│                      │
///  └───┬───┘                      │
///      │                          │
///  ┌───┴───┐                      │
///  │ MatMul│<─────────────────────┘
///  └───┬───┘
///  ┌───┴───┐
///  │ Output│
///  └───────┘
///
/// After:
///  ┌───────┐    ┌───────┐    ┌───────┐    ┌─────────────┐
///  │   Q   │    │   K   │    │   V   │    │AttentionMask│
///  └───┬───┘    └───┬───┘    └───┬───┘    └──────┬──────┘
///      │            │            │               │
///      │            │            │               │
///  ┌───┴────────────┴────────────┴───────────────┴─┐
///  │           ScaledDotProductAttention           │
///  └───────────────────────┬───────────────────────┘
///                          │
///                          │
///                     ┌────┴────┐
///                     │ Output  │
///                     └─────────┘
class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SDPAFusion", "0");
    SDPAFusion();
};

}  // namespace pass
}  // namespace ov
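For context, matcher passes like this one are normally applied through ov::pass::Manager. A minimal usage sketch follows; it is not part of this diff, and the helper function name is made up for illustration:

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"

// Hypothetical helper: runs SDPAFusion over a model in place.
void apply_sdpa_fusion(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAFusion>();
    manager.run_passes(model);
}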
...common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp (58 additions, 0 deletions)
@@ -0,0 +1,58 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// Merges an explicit multiplication of Q and K by a scalar value into the scale input of the SDPA op.
/// Before:
///  ┌───────┐   ┌───────┐   ┌───────┐   ┌─────────────┐   ┌────────────┐
///  │   Q   │   │   K   │   │   V   │   │AttentionMask│   │   Scale    │
///  └───┬───┘   └───┬───┘   └───┬───┘   │ (Optional)  │   │ (Optional) │
///      │           │           │       └──────┬──────┘   └──────┬─────┘
///      │           │           │              │                 │
///  ┌───┴───┐   ┌───┴───┐       │              │                 │
///  │  Mul  │   │  Mul  │       │              │                 │
///  └───┬───┘   └───┬───┘       │              │                 │
///      │           │           │              │                 │
///      │           │           │              │                 │
///  ┌───┴───────────┴───────────┴──────────────┴─┐               │
///  │         ScaledDotProductAttention          │───────────────┘
///  └─────────────────────┬──────────────────────┘
///                        │
///                        │
///                   ┌────┴────┐
///                   │ Output  │
///                   └─────────┘
/// After:
///  ┌───────┐   ┌───────┐   ┌───────┐   ┌─────────────┐   ┌───────┐
///  │   Q   │   │   K   │   │   V   │   │AttentionMask│   │ Scale │
///  └───┬───┘   └───┬───┘   └───┬───┘   └──────┬──────┘   └───┬───┘
///      │           │           │              │              │
///      │           │           │              │              │
///  ┌───┴───────────┴───────────┴──────────────┴─┐            │
///  │         ScaledDotProductAttention          │────────────┘
///  └─────────────────────┬──────────────────────┘
///                        │
///                        │
///                   ┌────┴────┐
///                   │ Output  │
///                   └─────────┘
/// The Multiply ops for Q and K are eliminated in the following cases:
/// 1. Q_scale and K_scale are constant
/// 2. Q_scale * SDPA_Scale == 1 or K_scale * SDPA_Scale == 1
class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SDPAScaleFusion", "0");
    SDPAScaleFusion();
};

}  // namespace pass
}  // namespace ov
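To make the fusion arithmetic concrete, here is a worked example with assumed numbers (head size 64 and constant multipliers on Q and K; none of these values come from the diff). As implemented in sdpa_scale_fusion.cpp further down, the pass folds constant multipliers into a single value, new_scale = prev_scale * Q_scale * K_scale:

#include <cmath>

int main() {
    // Assumed example: head_size = 64, so the implicit SDPA scale is 1/sqrt(64).
    const float prev_scale = 1.0f / std::sqrt(64.0f);        // 0.125
    const float q_scale = 0.25f;                             // constant Multiply on Q
    const float k_scale = 2.0f;                              // constant Multiply on K
    // Both Multiply ops are removed; this value feeds the SDPA scale input.
    const float new_scale = prev_scale * q_scale * k_scale;  // 0.0625
    return new_scale == 0.0625f ? 0 : 1;
}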
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp (120 additions, 0 deletions)
@@ -0,0 +1,120 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_fusion.hpp"

#include <numeric>

#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAFusion::SDPAFusion() {
    using namespace ov::pass::pattern;
    using namespace ov::gen_pattern;

    auto q = makePattern(ov::Rank(4));
    auto k = makePattern(ov::Rank(4));
    auto v = makePattern(ov::Rank(4));
    auto mask = makePattern();

    auto k_t = makePattern<ov::op::v1::Transpose>({k, {0, 1, 3, 2}});
    auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}});
    auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}});
    auto qk = qk_nt | qk_nn;
    auto optional_add_mask = optional<ov::op::v1::Add>({qk, mask});
    auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}});
    auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}});

    auto valid_qk_shapes = [](const std::shared_ptr<ov::op::v0::MatMul>& qk_matmul) {
        auto q_pshape = qk_matmul->get_input_partial_shape(0);
        auto k_pshape = qk_matmul->get_input_partial_shape(1);

        const size_t q_head_size_idx = 3;
        const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2;

        return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() &&
               k_pshape[k_head_size_idx].is_static() &&
               q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length();
    };

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        if (transformation_callback(m.get_match_root())) {
            return false;
        }

        auto q_node = pattern_map.at(q);
        auto k_node = pattern_map.at(k);
        auto v_node = pattern_map.at(v);

        if (!valid_qk_shapes(ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(qk).get_node_shared_ptr()))) {
            return false;
        }

        if (pattern_map.at(qk).get_target_inputs().size() > 1 ||
            pattern_map.at(softmax).get_target_inputs().size() > 1) {
            return false;
        }
        if (pattern_map.count(optional_add_mask) &&
            (pattern_map.at(optional_add_mask).get_target_inputs().size() > 1 ||
             pattern_map.at(mask).get_partial_shape().size() > 4)) {
            return false;
        }

        Output<ov::Node> mask_value;
        Output<ov::Node> mask_input;
        if (pattern_map.find(optional_add_mask) != pattern_map.end()) {
            mask_value = pattern_map.at(mask);
        } else {
            mask_value = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{0});
        }

        if (mask_value.get_partial_shape().size() > 4) {
            return false;
        }

        if (mask_value.get_partial_shape().rank() == 0 || mask_value.get_partial_shape().rank() == 4) {
            mask_input = mask_value;
        } else {
            size_t rank_diff = q_node.get_partial_shape().size() - mask_value.get_partial_shape().size();
            std::vector<int64_t> axes(rank_diff);
            std::iota(axes.begin(), axes.end(), 0);
            mask_input = std::make_shared<ov::op::v0::Unsqueeze>(
                mask_value,
                ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank_diff}, axes));
        }

        std::shared_ptr<ov::Node> scale_node =
            ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{1.0f});

        std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_node,
                                                                                                  k_node,
                                                                                                  v_node,
                                                                                                  mask_input,
                                                                                                  scale_node,
                                                                                                  false);

        sdpa->set_friendly_name(m.get_match_root()->get_friendly_name());
        ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
        ov::replace_node(m.get_match_root(), sdpa);

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(qkv, "SDPAFusion");
    this->register_matcher(m, callback);
}

}  // namespace pass
}  // namespace ov
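To illustrate what the matcher targets, here is a minimal model that should hit the qk_nt branch of the pattern (MatMul with transpose_b = true, no mask). Shapes, element type, and the builder name are assumptions for this sketch, not taken from the PR or its tests:

#include "openvino/core/model.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/softmax.hpp"

// Hypothetical builder: Q/K/V are [batch, num_heads, seq_len, head_size];
// head_size is static (64), as valid_qk_shapes requires.
std::shared_ptr<ov::Model> make_sdpa_subgraph() {
    auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 8, -1, 64});
    auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 8, -1, 64});
    auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{1, 8, -1, 64});

    auto qk = std::make_shared<ov::op::v0::MatMul>(q, k, false, true);  // Q x K^T
    auto softmax = std::make_shared<ov::op::v8::Softmax>(qk, -1);       // over the last axis
    auto qkv = std::make_shared<ov::op::v0::MatMul>(softmax, v, false, false);

    return std::make_shared<ov::Model>(ov::OutputVector{qkv}, ov::ParameterVector{q, k, v});
}

Running SDPAFusion on such a model should replace the MatMul-Softmax-MatMul chain with a single ScaledDotProductAttention whose mask is a zero constant and whose scale is 1, per the callback above.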
src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp (140 additions, 0 deletions)
@@ -0,0 +1,140 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_scale_fusion.hpp"

#include <cmath>
#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAScaleFusion::SDPAScaleFusion() {
    using namespace ov::pass::pattern;
    using namespace ov::gen_pattern;

    auto q = makePattern(ov::Rank(4));
    auto k = makePattern(ov::Rank(4));
    auto v = makePattern(ov::Rank(4));
    auto mask = makePattern();
    auto sdpa_scale = makeConst({});
    auto scale_q = makePattern("[]") | makePattern("[1]");
    auto scale_k = makePattern("[]") | makePattern("[1]");

    auto scaled_q = optional<ov::op::v1::Multiply>({q, scale_q});
    auto scaled_k = optional<ov::op::v1::Multiply>({k, scale_k});
    auto sdpa_mask_scale =
        makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask, sdpa_scale},
                                                            {{"causal", false}});
    auto sdpa_mask =
        makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask}, {{"causal", false}});
    auto sdpa_simple =
        makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v}, {{"causal", false}});
    auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale;

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        if (transformation_callback(m.get_match_root())) {
            return false;
        }

        auto sdpa = m.get_match_root();

        const bool has_q_scale = pattern_map.count(scaled_q);
        const bool has_k_scale = pattern_map.count(scaled_k);

        // Nothing to do
        if (!has_q_scale && !has_k_scale)
            return false;

        auto prev_scale_value = 1.0f;
        auto scale_q_value = 1.0f;
        auto scale_k_value = 1.0f;
        auto scale_et = sdpa->get_output_element_type(0);

        Output<ov::Node> q_input = sdpa->get_input_source_output(0);
        Output<ov::Node> k_input = sdpa->get_input_source_output(1);

        std::shared_ptr<ov::Node> scale_q_node = nullptr;
        std::shared_ptr<ov::Node> scale_k_node = nullptr;

        if (pattern_map.find(sdpa_scale) != pattern_map.end()) {
            auto prev_scale_node =
                ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(sdpa_scale).get_node_shared_ptr());
            prev_scale_value = prev_scale_node->cast_vector<float>()[0];
            scale_et = prev_scale_node->get_output_element_type(0);
        } else {
            auto head_size = q_input.get_partial_shape()[3];
            if (head_size.is_dynamic())
                return false;

            prev_scale_value = 1.0f / std::sqrt(static_cast<float>(head_size.get_length()));
        }

        // Extract scalar scale values for Q and K if those are constant and set new inputs for SDPA
        if (has_q_scale) {
            scale_q_node = pattern_map.at(scale_q).get_node_shared_ptr();
            if (ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
                scale_q_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_q_node)->cast_vector<float>()[0];
                q_input = pattern_map.at(q);
            }
        }
        if (has_k_scale) {
            scale_k_node = pattern_map.at(scale_k).get_node_shared_ptr();
            if (ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
                scale_k_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_k_node)->cast_vector<float>()[0];
                k_input = pattern_map.at(k);
            }
        }

        Output<ov::Node> new_scale_node;
        auto new_scale_val = prev_scale_value * scale_q_value * scale_k_value;

        // If the new scale is 1 and there is a non-constant scale node for either Q or K, it can be used directly as the SDPA scale
        if (new_scale_val == 1.0f) {
            if (has_q_scale && !ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
                new_scale_node = pattern_map.at(scale_q);
                q_input = pattern_map.at(q);
            } else if (has_k_scale && !ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
                new_scale_node = pattern_map.at(scale_k);
                k_input = pattern_map.at(k);
            } else {
                new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val});
            }
        } else {
            new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val});
        }

        OutputVector new_inputs = {q_input, k_input, pattern_map.at(v)};
        if (pattern_map.find(mask) != pattern_map.end()) {
            new_inputs.push_back(pattern_map.at(mask));
        } else {
            new_inputs.push_back(
                ov::op::v0::Constant::create(new_scale_node.get_element_type(), ov::Shape{}, std::vector<float>{0.0f}));
        }

        new_inputs.push_back(new_scale_node);

        auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs);
        new_sdpa->set_friendly_name(sdpa->get_friendly_name());
        ov::copy_runtime_info(sdpa, new_sdpa);
        ov::replace_node(sdpa, new_sdpa);

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa, "SDPAScaleFusion");
    this->register_matcher(m, callback);
}

}  // namespace pass
}  // namespace ov
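And a sketch of an input graph for SDPAScaleFusion: an SDPA whose Q input is pre-multiplied by a scalar constant. With head size 64 and a Q multiplier of 0.5, the pass would drop the Multiply and set the SDPA scale input to 0.125 * 0.5 = 0.0625. All names, shapes, and values here are illustrative assumptions, not taken from the PR:

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

// Hypothetical builder for a graph that the sdpa_simple pattern branch matches.
std::shared_ptr<ov::Model> make_scaled_sdpa() {
    auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 8, -1, 64});
    auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 8, -1, 64});
    auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 8, -1, 64});

    // Explicit scalar scaling of Q that the pass folds into the SDPA scale.
    auto q_scale = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {0.5f});
    auto scaled_q = std::make_shared<ov::op::v1::Multiply>(q, q_scale);

    // 3-input SDPA with causal = false, matching the sdpa_simple pattern.
    auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(scaled_q, k, v, false);
    return std::make_shared<ov::Model>(ov::OutputVector{sdpa}, ov::ParameterVector{q, k, v});
}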
Review question: Why do we need this restriction on the number of mask consumers?
Reply: That restriction is not on the mask's consumers but on the Add op's consumers.