forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TRANSFORMATIONS][GPU] SDPA Fusion passes
Signed-off-by: Vladimir Paramuzov <[email protected]>
- Loading branch information
1 parent
702ce05
commit fbedd1e
Showing
8 changed files
with
591 additions
and
0 deletions.
There are no files selected for viewing
60 changes: 60 additions & 0 deletions
60
src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov {
namespace pass {

/// This pass transforms the following sub-graph to a single Scaled Dot Product Attention operation.
/// The replacement is semantics-preserving: the matched graph applies no scaling before Softmax,
/// so the produced SDPA op is given an explicit scale of 1.0 (overriding SDPA's implicit
/// 1/sqrt(head_size) default).
/// Before:
///     ┌───────┐     ┌───────┐      ┌───────┐
///     │   Q   │     │   K   │      │   V   │
///     └───┬───┘     └───┬───┘      └───┬───┘
///         │             │              │
///         │             │              │
///     ┌───┴───┐   ┌─────┴──────┐       │
///     │ MatMul│<──│ Transpose  │       │
///     └───┬───┘   │ (Optional) │       │
///         │       └────────────┘       │
///     ┌───┴───┐   ┌─────────────┐      │
///     │  Add  │<───│AttentionMask│     │
///     └───┬───┘   │ (Optional)  │      │
///         │       └─────────────┘      │
///     ┌───┴───┐                        │
///     │Softmax│                        │
///     └───┬───┘                        │
///         │                            │
///     ┌───┴───┐                        │
///     │ MatMul│<─────────────────────┘
///     └───┬───┘
///     ┌───┴───┐
///     │ Output│
///     └───────┘
///
/// After:
///     ┌───────┐    ┌───────┐    ┌───────┐    ┌─────────────┐
///     │   Q   │    │   K   │    │   V   │    │AttentionMask│
///     └───┬───┘    └───┬───┘    └───┬───┘    └──────┬──────┘
///         │            │            │               │
///         │            │            │               │
///     ┌───┴────────────┴────────────┴───────────────┴─┐
///     │           ScaledDotProductAttention            │
///     └────────────────────┬──────────────────────────┘
///                          │
///                          │
///                     ┌────┴────┐
///                     │  Output │
///                     └─────────┘
class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SDPAFusion", "0");
    SDPAFusion();
};

}  // namespace pass
}  // namespace ov
55 changes: 55 additions & 0 deletions
55
...common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/matcher_pass.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov {
namespace pass {

/// Merges explicit multiplication by a scalar value for Q and K into the scale input of the SDPA op.
/// The folded scale is the product of the Q multiplier, the K multiplier, and any scale the SDPA
/// already had; all of them must be compile-time constants for the fusion to apply.
/// Before:
///     ┌───────┐    ┌───────┐    ┌───────┐    ┌─────────────┐
///     │   Q   │    │   K   │    │   V   │    │AttentionMask│
///     └───┬───┘    └───┬───┘    └───┬───┘    │  (Optional) │
///         │            │            │        └──────┬──────┘
///         │            │            │               │
///     ┌───┴───┐    ┌───┴───┐        │               │
///     │  Mul  │    │  Mul  │        │               │
///     └───┬───┘    └───┬───┘        │               │
///         │            │            │               │
///         │            │            │               │
///     ┌───┴────────────┴────────────┴─────────────┴─┐
///     │           ScaledDotProductAttention          │
///     └────────────────────┬────────────────────────┘
///                          │
///                          │
///                     ┌────┴────┐
///                     │  Output │
///                     └─────────┘
/// After:
///     ┌───────┐    ┌───────┐    ┌───────┐    ┌─────────────┐   ┌───────┐
///     │   Q   │    │   K   │    │   V   │    │AttentionMask│   │ Scale │
///     └───┬───┘    └───┬───┘    └───┬───┘    └──────┬──────┘   └───┬───┘
///         │            │            │               │              │
///         │            │            │               │              │
///         │            │            │               │              │
///     ┌───┴────────────┴────────────┴─────────────┴─┐            │
///     │           ScaledDotProductAttention          │───────────┘
///     └────────────────────┬────────────────────────┘
///                          │
///                          │
///                     ┌────┴────┐
///                     │  Output │
///                     └─────────┘
class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SDPAScaleFusion", "0");
    SDPAScaleFusion();
};

}  // namespace pass
}  // namespace ov
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/sdpa_fusion.hpp" | ||
|
||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/core/type.hpp" | ||
#include "openvino/op/add.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/matmul.hpp" | ||
#include "openvino/op/scaled_dot_product_attention.hpp" | ||
#include "openvino/op/softmax.hpp" | ||
#include "openvino/op/transpose.hpp" | ||
#include "openvino/pass/pattern/op/optional.hpp" | ||
#include "openvino/pass/pattern/op/pattern.hpp" | ||
#include "transformations/utils/gen_pattern.hpp" | ||
|
||
namespace ov {
namespace pass {

// Fuses MatMul(Q, K^T) -> [Add(mask)] -> Softmax -> MatMul(., V) into a single
// ScaledDotProductAttention-13 op. See sdpa_fusion.hpp for the full diagram.
SDPAFusion::SDPAFusion() {
    using namespace ov::pass::pattern;
    using namespace ov::gen_pattern;

    // Q/K/V are restricted to rank 4 — presumably [batch, num_heads, seq_len, head_size];
    // TODO(review): confirm the intended layout against the callers of this pass.
    auto q = makePattern(ov::Rank(4));
    auto k = makePattern(ov::Rank(4));
    auto v = makePattern(ov::Rank(4));
    auto mask = makePattern();

    // K may be transposed either by an explicit Transpose with order {0, 1, 3, 2} (qk_nn)
    // or via MatMul's transpose_b attribute (qk_nt); both forms compute Q x K^T.
    auto k_t = makePattern<ov::op::v1::Transpose>({k, {0, 1, 3, 2}});
    auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t | k}, {{"transpose_a", false}, {"transpose_b", false}});
    auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}});
    auto qk = qk_nt | qk_nn;
    // The attention-mask Add is optional; Softmax must reduce over the last axis.
    auto optional_add_mask = optional<ov::op::v1::Add>({qk, mask});
    auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}});
    auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}});

    // Both inputs of the QK MatMul must be rank 4 with a static, equal head-size dimension.
    // The head-size axis of K depends on whether the MatMul transposes K itself.
    auto valid_qk_shapes = [](const std::shared_ptr<ov::op::v0::MatMul>& qk_matmul) {
        auto q_pshape = qk_matmul->get_input_partial_shape(0);
        auto k_pshape = qk_matmul->get_input_partial_shape(1);

        const size_t q_head_size_idx = 3;
        const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2;

        return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() &&
               k_pshape[k_head_size_idx].is_static() &&
               q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length();
    };

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        // Respect the user-supplied transformation_callback veto.
        if (transformation_callback(m.get_match_root())) {
            return false;
        }

        auto q_node = pattern_map.at(q);
        auto k_node = pattern_map.at(k);
        auto v_node = pattern_map.at(v);

        if (!valid_qk_shapes(ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(qk).get_node_shared_ptr()))) {
            return false;
        }

        // When the graph has no attention mask, substitute a scalar zero constant so the
        // 5-input SDPA form (with explicit scale) can still be constructed.
        std::shared_ptr<ov::Node> mask_node = nullptr;
        if (pattern_map.find(optional_add_mask) != pattern_map.end()) {
            mask_node = pattern_map.at(mask).get_node_shared_ptr();
        } else {
            mask_node = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{0});
        }

        // Explicit scale of 1.0: the matched subgraph performed no scaling, so SDPA's
        // implicit 1/sqrt(head_size) default must be overridden to keep the math identical.
        std::shared_ptr<ov::Node> scale_node =
            ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{1.0f});

        // causal = false: the pattern carries no causal-masking structure.
        std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_node,
                                                                                                  k_node,
                                                                                                  v_node,
                                                                                                  mask_node,
                                                                                                  scale_node,
                                                                                                  false);

        sdpa->set_friendly_name(m.get_match_root()->get_friendly_name());
        ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
        ov::replace_node(m.get_match_root(), sdpa);

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(qkv, "SDPAFusion");
    this->register_matcher(m, callback);
}

}  // namespace pass
}  // namespace ov
83 changes: 83 additions & 0 deletions
83
src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/sdpa_scale_fusion.hpp" | ||
|
||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/core/type.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/scaled_dot_product_attention.hpp" | ||
#include "openvino/pass/pattern/op/pattern.hpp" | ||
#include "transformations/utils/gen_pattern.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
SDPAScaleFusion::SDPAScaleFusion() { | ||
using namespace ov::pass::pattern; | ||
using namespace ov::gen_pattern; | ||
|
||
auto q = makePattern(ov::Rank(4)); | ||
auto k = makePattern(ov::Rank(4)); | ||
auto v = makePattern(ov::Rank(4)); | ||
auto mask = makePattern(); | ||
auto sdpa_scale = makePattern("[]"); | ||
auto scale_q = makePattern("[]"); | ||
auto scale_k = makePattern("[]"); | ||
|
||
auto scaled_q = makePattern<ov::op::v1::Multiply>({q, scale_q}); | ||
auto scaled_k = makePattern<ov::op::v1::Multiply>({k, scale_k}); | ||
auto sdpa_mask_scale = | ||
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask, sdpa_scale}); | ||
auto sdpa_mask = makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask}); | ||
auto sdpa_simple = makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v}); | ||
auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale; | ||
|
||
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
if (transformation_callback(m.get_match_root())) { | ||
return false; | ||
} | ||
|
||
auto sdpa = m.get_match_root(); | ||
|
||
auto prev_scale_value = 1.0f; | ||
auto scale_et = sdpa->get_output_element_type(0); | ||
if (pattern_map.find(sdpa_scale) != pattern_map.end()) { | ||
auto prev_scale_node = | ||
ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(sdpa_scale).get_node_shared_ptr()); | ||
prev_scale_value = prev_scale_node->cast_vector<float>()[0]; | ||
scale_et = prev_scale_node->get_output_element_type(0); | ||
} | ||
auto scale_q_node = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(scale_q).get_node_shared_ptr()); | ||
auto scale_k_node = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(scale_k).get_node_shared_ptr()); | ||
auto new_scale_val = | ||
prev_scale_value * scale_q_node->cast_vector<float>()[0] * scale_k_node->cast_vector<float>()[0]; | ||
auto new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val}); | ||
|
||
OutputVector new_inputs = {pattern_map.at(q), pattern_map.at(k), pattern_map.at(v)}; | ||
if (pattern_map.find(mask) != pattern_map.end()) { | ||
new_inputs.push_back(pattern_map.at(mask)); | ||
} else { | ||
new_inputs.push_back(ov::op::v0::Constant::create(new_scale_node->get_output_element_type(0), | ||
ov::Shape{}, | ||
std::vector<float>{0.0f})); | ||
} | ||
|
||
new_inputs.push_back(new_scale_node); | ||
|
||
auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs); | ||
new_sdpa->set_friendly_name(sdpa->get_friendly_name()); | ||
ov::copy_runtime_info(sdpa, new_sdpa); | ||
ov::replace_node(sdpa, new_sdpa); | ||
|
||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa, "SDPAScaleFusion"); | ||
this->register_matcher(m, callback); | ||
} | ||
|
||
} // namespace pass | ||
} // namespace ov |
Oops, something went wrong.