-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TRANSFORMATIONS][GPU] SDPA Fusion passes
Signed-off-by: Vladimir Paramuzov <[email protected]>
- Loading branch information
1 parent
357eb54
commit 83ccfe3
Showing
8 changed files
with
822 additions
and
0 deletions.
There are no files selected for viewing
60 changes: 60 additions & 0 deletions
60
src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
/// This pass transforms the following sub-graph to a single Scaled Dot Product Attention operation. | ||
/// Before: | ||
/// ┌───────┐ ┌───────┐ ┌───────┐ | ||
/// │ Q │ │ K │ │ V │ | ||
/// └───┬───┘ └───┬───┘ └───┬───┘ | ||
/// │ │ │ | ||
/// │ │ │ | ||
/// ┌───┴───┐ ┌─────┴──────┐ │ | ||
/// │ MatMul│<──│ Transpose │ │ | ||
/// └───┬───┘ | (Optional) │ │ | ||
/// │ └────────────┘ │ | ||
/// ┌───┴───┐ ┌─────────────┐ │ | ||
/// │ Add │<───│AttentionMask│ │ | ||
/// └───┬───┘ | (Optional) │ │ | ||
/// │ └─────────────┘ │ | ||
/// ┌───┴───┐ │ | ||
/// │Softmax│ │ | ||
/// └───┬───┘ │ | ||
/// │ │ | ||
/// ┌───┴───┐ │ | ||
/// │ MatMul│<─────────────────────┘ | ||
/// └───┬───┘ | ||
/// ┌───┴───┐ | ||
/// │ Output│ | ||
/// └───────┘ | ||
/// | ||
/// After: | ||
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ | ||
/// │ Q │ │ K │ │ V │ │AttentionMask│ | ||
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ | ||
/// │ │ │ │ | ||
/// │ │ │ │ | ||
/// ┌───┴────────────┴────────────┴───────────────┴─┐ | ||
/// │ ScaledDotProductAttention │ | ||
/// └────────────────────┬──────────────────────────┘ | ||
/// │ | ||
/// │ | ||
/// ┌────┴────┐ | ||
/// │ Output │ | ||
/// └─────────┘ | ||
class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("SDPAFusion", "0"); | ||
SDPAFusion(); | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace ov |
58 changes: 58 additions & 0 deletions
58
...common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
/// Merges explicit multiplication by scalar value for Q and K into scale attribute of SDPA op | ||
/// Before: | ||
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌─────────────┐ | ||
/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale | | ||
/// └───┬───┘ └───┬───┘ └───┬───┘ │ (Optional) │ │ (Optional) │ | ||
/// │ │ │ └──────┬──────┘ └───────┬─────┘ | ||
/// │ │ │ │ | | ||
/// ┌───┴───┐ ┌───┴───┐ │ │ | | ||
/// │ Mul | │ Mul │ | │ | | ||
/// └───┬───┘ └───┬───┘ │ │ │ | ||
/// │ │ │ │ │ | ||
/// | │ │ │ │ | ||
/// ┌───┴────────────┴────────────┴─────────────┴─┐ | | ||
/// │ ScaledDotProductAttention │──────────────────┘ | ||
/// └────────────────────┬────────────────────────┘ | ||
/// │ | ||
/// │ | ||
/// ┌────┴────┐ | ||
/// │ Output │ | ||
/// └─────────┘ | ||
/// After: | ||
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌───────┐ | ||
/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale | | ||
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ └───┬───┘ | ||
/// │ │ │ │ | | ||
/// │ │ │ │ | | ||
/// | │ │ │ | | ||
/// ┌───┴────────────┴────────────┴─────────────┴─┐ | | ||
/// │ ScaledDotProductAttention │───────────┘ | ||
/// └────────────────────┬────────────────────────┘ | ||
/// │ | ||
/// │ | ||
/// ┌────┴────┐ | ||
/// │ Output │ | ||
/// └─────────┘ | ||
/// Multiply ops for Q and K are eliminated in the following cases: | ||
/// 1. Q_scale and K_scale are constant | ||
/// 2. Q_scale * SDPA_Scale == 1 or K_scale * SDPA_Scale == 1 | ||
class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("SDPAScaleFusion", "0"); | ||
SDPAScaleFusion(); | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 120 additions & 0 deletions
120
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/sdpa_fusion.hpp" | ||
|
||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/core/type.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/matmul.hpp" | ||
#include "openvino/op/scaled_dot_product_attention.hpp" | ||
#include "openvino/op/softmax.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/op/unsqueeze.hpp" | ||
#include "openvino/pass/pattern/op/optional.hpp" | ||
#include "openvino/pass/pattern/op/pattern.hpp" | ||
#include "transformations/utils/gen_pattern.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
SDPAFusion::SDPAFusion() { | ||
using namespace ov::pass::pattern; | ||
using namespace ov::gen_pattern; | ||
|
||
auto q = makePattern(ov::Rank(4)); | ||
auto k = makePattern(ov::Rank(4)); | ||
auto v = makePattern(ov::Rank(4)); | ||
auto mask = makePattern(); | ||
|
||
auto k_t = makePattern<ov::op::v1::Transpose>({k, {0, 1, 3, 2}}); | ||
auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}}); | ||
auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}}); | ||
auto qk = qk_nt | qk_nn; | ||
auto optional_add_mask = optional<ov::op::v1::Add>({qk, mask}); | ||
auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}}); | ||
auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}}); | ||
|
||
auto valid_qk_shapes = [](const std::shared_ptr<ov::op::v0::MatMul>& qk_matmul) { | ||
auto q_pshape = qk_matmul->get_input_partial_shape(0); | ||
auto k_pshape = qk_matmul->get_input_partial_shape(1); | ||
|
||
const size_t q_head_size_idx = 3; | ||
const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2; | ||
|
||
return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() && | ||
k_pshape[k_head_size_idx].is_static() && | ||
q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length(); | ||
}; | ||
|
||
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
if (transformation_callback(m.get_match_root())) { | ||
return false; | ||
} | ||
|
||
auto q_node = pattern_map.at(q); | ||
auto k_node = pattern_map.at(k); | ||
auto v_node = pattern_map.at(v); | ||
|
||
if (!valid_qk_shapes(ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(qk).get_node_shared_ptr()))) { | ||
return false; | ||
} | ||
|
||
if (pattern_map.at(qk).get_target_inputs().size() > 1 || | ||
pattern_map.at(softmax).get_target_inputs().size() > 1) { | ||
return false; | ||
} | ||
if (pattern_map.count(optional_add_mask) && (pattern_map.at(optional_add_mask).get_target_inputs().size() > 1 || | ||
pattern_map.at(mask).get_partial_shape().size() > 4)) { | ||
return false; | ||
} | ||
|
||
Output<ov::Node> mask_value; | ||
Output<ov::Node> mask_input; | ||
if (pattern_map.find(optional_add_mask) != pattern_map.end()) { | ||
mask_value = pattern_map.at(mask); | ||
} else { | ||
mask_value = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{0}); | ||
} | ||
|
||
if (mask_value.get_partial_shape().size() > 4) { | ||
return false; | ||
} | ||
|
||
if (mask_value.get_partial_shape().rank() == 0 || mask_value.get_partial_shape().rank() == 4) { | ||
mask_input = mask_value; | ||
} else { | ||
size_t rank_diff = q_node.get_partial_shape().size() - mask_value.get_partial_shape().size(); | ||
std::vector<int64_t> axes(rank_diff); | ||
std::iota(axes.begin(), axes.end(), 0); | ||
mask_input = std::make_shared<ov::op::v0::Unsqueeze>( | ||
mask_value, | ||
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank_diff}, axes)); | ||
} | ||
|
||
std::shared_ptr<ov::Node> scale_node = | ||
ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{1.0f}); | ||
|
||
std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_node, | ||
k_node, | ||
v_node, | ||
mask_input, | ||
scale_node, | ||
false); | ||
|
||
sdpa->set_friendly_name(m.get_match_root()->get_friendly_name()); | ||
ov::copy_runtime_info(m.get_matched_nodes(), sdpa); | ||
ov::replace_node(m.get_match_root(), sdpa); | ||
|
||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ov::pass::pattern::Matcher>(qkv, "SDPAFusion"); | ||
this->register_matcher(m, callback); | ||
} | ||
|
||
} // namespace pass | ||
} // namespace ov |
140 changes: 140 additions & 0 deletions
140
src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/sdpa_scale_fusion.hpp" | ||
|
||
#include <memory> | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/core/type.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/scaled_dot_product_attention.hpp" | ||
#include "openvino/pass/pattern/op/optional.hpp" | ||
#include "openvino/pass/pattern/op/pattern.hpp" | ||
#include "transformations/utils/gen_pattern.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
SDPAScaleFusion::SDPAScaleFusion() { | ||
using namespace ov::pass::pattern; | ||
using namespace ov::gen_pattern; | ||
|
||
auto q = makePattern(ov::Rank(4)); | ||
auto k = makePattern(ov::Rank(4)); | ||
auto v = makePattern(ov::Rank(4)); | ||
auto mask = makePattern(); | ||
auto sdpa_scale = makeConst({}); | ||
auto scale_q = makePattern("[]") | makePattern("[1]"); | ||
auto scale_k = makePattern("[]") | makePattern("[1]"); | ||
|
||
auto scaled_q = optional<ov::op::v1::Multiply>({q, scale_q}); | ||
auto scaled_k = optional<ov::op::v1::Multiply>({k, scale_k}); | ||
auto sdpa_mask_scale = | ||
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask, sdpa_scale}, | ||
{{"causal", false}}); | ||
auto sdpa_mask = | ||
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask}, {{"causal", false}}); | ||
auto sdpa_simple = | ||
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v}, {{"causal", false}}); | ||
auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale; | ||
|
||
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
if (transformation_callback(m.get_match_root())) { | ||
return false; | ||
} | ||
|
||
auto sdpa = m.get_match_root(); | ||
|
||
const bool has_q_scale = pattern_map.count(scaled_q); | ||
const bool has_k_scale = pattern_map.count(scaled_k); | ||
|
||
// Nothing to do | ||
if (!has_q_scale && !has_k_scale) | ||
return false; | ||
|
||
auto prev_scale_value = 1.0f; | ||
auto scale_q_value = 1.0f; | ||
auto scale_k_value = 1.0f; | ||
auto scale_et = sdpa->get_output_element_type(0); | ||
|
||
Output<ov::Node> q_input = sdpa->get_input_source_output(0); | ||
Output<ov::Node> k_input = sdpa->get_input_source_output(1); | ||
|
||
std::shared_ptr<ov::Node> scale_q_node = nullptr; | ||
std::shared_ptr<ov::Node> scale_k_node = nullptr; | ||
|
||
if (pattern_map.find(sdpa_scale) != pattern_map.end()) { | ||
auto prev_scale_node = | ||
ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(sdpa_scale).get_node_shared_ptr()); | ||
prev_scale_value = prev_scale_node->cast_vector<float>()[0]; | ||
scale_et = prev_scale_node->get_output_element_type(0); | ||
} else { | ||
auto head_size = q_input.get_partial_shape()[3]; | ||
if (head_size.is_dynamic()) | ||
return false; | ||
|
||
prev_scale_value = 1.0f / std::sqrt(static_cast<float>(head_size.get_length())); | ||
} | ||
|
||
// Extract scalar scale values for Q and K if those are constant and set new inputs for SDPA | ||
if (has_q_scale) { | ||
scale_q_node = pattern_map.at(scale_q).get_node_shared_ptr(); | ||
if (ov::is_type<ov::op::v0::Constant>(scale_q_node)) { | ||
scale_q_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_q_node)->cast_vector<float>()[0]; | ||
q_input = pattern_map.at(q); | ||
} | ||
} | ||
if (has_k_scale) { | ||
scale_k_node = pattern_map.at(scale_k).get_node_shared_ptr(); | ||
if (ov::is_type<ov::op::v0::Constant>(scale_k_node)) { | ||
scale_k_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_k_node)->cast_vector<float>()[0]; | ||
k_input = pattern_map.at(k); | ||
} | ||
} | ||
|
||
Output<ov::Node> new_scale_node; | ||
auto new_scale_val = prev_scale_value * scale_q_value * scale_k_value; | ||
|
||
// If new scale is 1 and we have non-constant scale node for either Q or K, then we can make it a scale of SDPA | ||
if (new_scale_val == 1.0f) { | ||
if (has_q_scale && !ov::is_type<ov::op::v0::Constant>(scale_q_node)) { | ||
new_scale_node = pattern_map.at(scale_q); | ||
q_input = pattern_map.at(q); | ||
} else if (has_k_scale && !ov::is_type<ov::op::v0::Constant>(scale_k_node)) { | ||
new_scale_node = pattern_map.at(scale_k); | ||
k_input = pattern_map.at(k); | ||
} else { | ||
new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val}); | ||
} | ||
} else { | ||
new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val}); | ||
} | ||
|
||
OutputVector new_inputs = {q_input, k_input, pattern_map.at(v)}; | ||
if (pattern_map.find(mask) != pattern_map.end()) { | ||
new_inputs.push_back(pattern_map.at(mask)); | ||
} else { | ||
new_inputs.push_back( | ||
ov::op::v0::Constant::create(new_scale_node.get_element_type(), ov::Shape{}, std::vector<float>{0.0f})); | ||
} | ||
|
||
new_inputs.push_back(new_scale_node); | ||
|
||
auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs); | ||
new_sdpa->set_friendly_name(sdpa->get_friendly_name()); | ||
ov::copy_runtime_info(sdpa, new_sdpa); | ||
ov::replace_node(sdpa, new_sdpa); | ||
|
||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa, "SDPAScaleFusion"); | ||
this->register_matcher(m, callback); | ||
} | ||
|
||
} // namespace pass | ||
} // namespace ov |
Oops, something went wrong.