From fbedd1eb39b03e62e26849ebbd7c83dd84607550 Mon Sep 17 00:00:00 2001 From: Vladimir Paramuzov Date: Fri, 6 Dec 2024 09:24:06 +0400 Subject: [PATCH] [TRANSFORMATIONS][GPU] SDPA Fusion passes Signed-off-by: Vladimir Paramuzov --- .../common_optimizations/sdpa_fusion.hpp | 60 ++++++ .../sdpa_scale_fusion.hpp | 55 +++++ .../moc_transformations.cpp | 2 + .../common_optimizations/sdpa_fusion.cpp | 94 +++++++++ .../sdpa_scale_fusion.cpp | 83 ++++++++ .../common_optimizations/sdpa_fusion_test.cpp | 193 ++++++++++++++++++ .../sdpa_scale_fusion_test.cpp | 102 +++++++++ .../src/plugin/transformations_pipeline.cpp | 2 + 8 files changed, 591 insertions(+) create mode 100644 src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp create mode 100644 src/common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp create mode 100644 src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp create mode 100644 src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp diff --git a/src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp new file mode 100644 index 00000000000000..87083900336c75 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/sdpa_fusion.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +/// This pass transforms the following sub-graph to a single Scaled Dot Product Attention operation. +/// Before: +/// ┌───────┐ ┌───────┐ ┌───────┐ +/// │ Q │ │ K │ │ V │ +/// └───┬───┘ └───┬───┘ └───┬───┘ +/// │ │ │ +/// │ │ │ +/// ┌───┴───┐ ┌─────┴──────┐ │ +/// │ MatMul│<──│ Transpose │ │ +/// └───┬───┘ | (Optional) │ │ +/// │ └────────────┘ │ +/// ┌───┴───┐ ┌─────────────┐ │ +/// │ Add │<───│AttentionMask│ │ +/// └───┬───┘ | (Optional) │ │ +/// │ └─────────────┘ │ +/// ┌───┴───┐ │ +/// │Softmax│ │ +/// └───┬───┘ │ +/// │ │ +/// ┌───┴───┐ │ +/// │ MatMul│<─────────────────────┘ +/// └───┬───┘ +/// ┌───┴───┐ +/// │ Output│ +/// └───────┘ +/// +/// After: +/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ +/// │ Q │ │ K │ │ V │ │AttentionMask│ +/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ +/// │ │ │ │ +/// │ │ │ │ +/// ┌───┴────────────┴────────────┴───────────────┴─┐ +/// │ ScaledDotProductAttention │ +/// └────────────────────┬──────────────────────────┘ +/// │ +/// │ +/// ┌────┴────┐ +/// │ Output │ +/// └─────────┘ +class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("SDPAFusion", "0"); + SDPAFusion(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp new file mode 100644 index 00000000000000..dff66c32494f78 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/sdpa_scale_fusion.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +/// Merges explicit multiplication by scalar value for Q and K into scale attribute of SDPA op +/// Before: +/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ +/// │ Q │ │ K │ │ V │ │AttentionMask│ +/// └───┬───┘ └───┬───┘ └───┬───┘ │ (Optional) │ +/// │ │ │ └──────┬──────┘ +/// │ │ │ │ +/// ┌───┴───┐ ┌───┴───┐ │ │ +/// │ Mul | │ Mul │ | │ +/// └───┬───┘ └───┬───┘ │ │ +/// │ │ │ │ +/// | │ │ │ +/// ┌───┴────────────┴────────────┴─────────────┴─┐ +/// │ ScaledDotProductAttention │ +/// └────────────────────┬────────────────────────┘ +/// │ +/// │ +/// ┌────┴────┐ +/// │ Output │ +/// └─────────┘ +/// After: +/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌───────┐ +/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale | +/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ └───┬───┘ +/// │ │ │ │ | +/// │ │ │ │ | +/// | │ │ │ | +/// ┌───┴────────────┴────────────┴─────────────┴─┐ | +/// │ ScaledDotProductAttention │───────────┘ +/// └────────────────────┬────────────────────────┘ +/// │ +/// │ +/// ┌────┴────┐ +/// │ Output │ +/// └─────────┘ +class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("SDPAScaleFusion", "0"); + SDPAScaleFusion(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 282fc69486b923..10f7b9a989f48a 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -65,6 +65,7 @@ #include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp" #include "transformations/common_optimizations/reshape_sequence_fusion.hpp" #include "transformations/common_optimizations/ric_fusion.hpp" +#include "transformations/common_optimizations/sdpa_fusion.hpp" #include "transformations/common_optimizations/select_with_one_value_condition.hpp" #include "transformations/common_optimizations/sequence_fusion.hpp" #include "transformations/common_optimizations/shared_ops_optimization.hpp" @@ -229,6 +230,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr ADD_MATCHER(common_fusions, ConvertTensorIteratorToSequence) ADD_MATCHER(common_fusions, SplitConcatPairToInterpolateFusion, m_use_shapes) ADD_MATCHER(common_fusions, ConvolutionToGroupConvolutionFusion) + ADD_MATCHER(common_fusions, SDPAFusion) if (m_use_shapes) { ADD_MATCHER(common_fusions, NearestNeighborUpsamplingFusion) } diff --git a/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp new file mode 100644 index 00000000000000..8616b40e831d0a --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp @@ -0,0 +1,94 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/sdpa_fusion.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/core/type.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/pass/pattern/op/optional.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" +#include "transformations/utils/gen_pattern.hpp" + +namespace ov { +namespace pass { + +SDPAFusion::SDPAFusion() { + using namespace ov::pass::pattern; + using namespace ov::gen_pattern; + + auto q = makePattern(ov::Rank(4)); + auto k = makePattern(ov::Rank(4)); + auto v = makePattern(ov::Rank(4)); + auto mask = makePattern(); + + auto k_t = makePattern({k, {0, 1, 3, 2}}); + auto qk_nn = makePattern({q, k_t | k}, {{"transpose_a", false}, {"transpose_b", false}}); + auto qk_nt = makePattern({q, k}, {{"transpose_a", false}, {"transpose_b", true}}); + auto qk = qk_nt | qk_nn; + auto optional_add_mask = optional({qk, mask}); + auto softmax = makePattern({optional_add_mask}, {{"axis", "-1"}}); + auto qkv = makePattern({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}}); + + auto valid_qk_shapes = [](const std::shared_ptr& qk_matmul) { + auto q_pshape = qk_matmul->get_input_partial_shape(0); + auto k_pshape = qk_matmul->get_input_partial_shape(1); + + const size_t q_head_size_idx = 3; + const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2; + + return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() && + k_pshape[k_head_size_idx].is_static() && + q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length(); + }; + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + if (transformation_callback(m.get_match_root())) { + return false; + } + + auto q_node = pattern_map.at(q); + auto k_node = pattern_map.at(k); + auto v_node = pattern_map.at(v); + + if (!valid_qk_shapes(ov::as_type_ptr(pattern_map.at(qk).get_node_shared_ptr()))) { + return false; + } + + std::shared_ptr mask_node = nullptr; + if (pattern_map.find(optional_add_mask) != pattern_map.end()) { + mask_node = pattern_map.at(mask).get_node_shared_ptr(); + } else { + mask_node = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector{0}); + } + + std::shared_ptr scale_node = + ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector{1.0f}); + + std::shared_ptr sdpa = std::make_shared(q_node, + k_node, + v_node, + mask_node, + scale_node, + false); + + sdpa->set_friendly_name(m.get_match_root()->get_friendly_name()); + ov::copy_runtime_info(m.get_matched_nodes(), sdpa); + ov::replace_node(m.get_match_root(), sdpa); + + return true; + }; + + auto m = std::make_shared(qkv, "SDPAFusion"); + this->register_matcher(m, callback); +} + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp new file mode 100644 index 00000000000000..7b2b83cf889dd8 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/sdpa_scale_fusion.cpp @@ -0,0 +1,83 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/sdpa_scale_fusion.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/core/type.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" +#include "transformations/utils/gen_pattern.hpp" + +namespace ov { +namespace pass { + +SDPAScaleFusion::SDPAScaleFusion() { + using namespace ov::pass::pattern; + using namespace ov::gen_pattern; + + auto q = makePattern(ov::Rank(4)); + auto k = makePattern(ov::Rank(4)); + auto v = makePattern(ov::Rank(4)); + auto mask = makePattern(); + auto sdpa_scale = makePattern("[]"); + auto scale_q = makePattern("[]"); + auto scale_k = makePattern("[]"); + + auto scaled_q = makePattern({q, scale_q}); + auto scaled_k = makePattern({k, scale_k}); + auto sdpa_mask_scale = + makePattern({scaled_q, scaled_k, v, mask, sdpa_scale}); + auto sdpa_mask = makePattern({scaled_q, scaled_k, v, mask}); + auto sdpa_simple = makePattern({scaled_q, scaled_k, v}); + auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale; + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + if (transformation_callback(m.get_match_root())) { + return false; + } + + auto sdpa = m.get_match_root(); + + auto prev_scale_value = 1.0f; + auto scale_et = sdpa->get_output_element_type(0); + if (pattern_map.find(sdpa_scale) != pattern_map.end()) { + auto prev_scale_node = + ov::as_type_ptr(pattern_map.at(sdpa_scale).get_node_shared_ptr()); + prev_scale_value = prev_scale_node->cast_vector()[0]; + scale_et = prev_scale_node->get_output_element_type(0); + } + auto scale_q_node = ov::as_type_ptr(pattern_map.at(scale_q).get_node_shared_ptr()); + auto scale_k_node = ov::as_type_ptr(pattern_map.at(scale_k).get_node_shared_ptr()); + auto new_scale_val = + prev_scale_value * scale_q_node->cast_vector()[0] * scale_k_node->cast_vector()[0]; + auto new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector{new_scale_val}); + + OutputVector new_inputs = {pattern_map.at(q), pattern_map.at(k), pattern_map.at(v)}; + if (pattern_map.find(mask) != pattern_map.end()) { + new_inputs.push_back(pattern_map.at(mask)); + } else { + new_inputs.push_back(ov::op::v0::Constant::create(new_scale_node->get_output_element_type(0), + ov::Shape{}, + std::vector{0.0f})); + } + + new_inputs.push_back(new_scale_node); + + auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs); + new_sdpa->set_friendly_name(sdpa->get_friendly_name()); + ov::copy_runtime_info(sdpa, new_sdpa); + ov::replace_node(sdpa, new_sdpa); + + return true; + }; + + auto m = std::make_shared(sdpa, "SDPAScaleFusion"); + this->register_matcher(m, callback); +} + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp new file mode 100644 index 00000000000000..6762d8eb4e179f --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp @@ -0,0 +1,193 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/transpose.hpp" + +using namespace testing; +using namespace ov::pass; +using namespace ov; + +TEST_F(TransformationTestsF, SDPAFusionTest1) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, true); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{1.0f}); + const auto mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto sdpa = std::make_shared(query, + key, + value, + mask_const, + scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest2) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, true); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto mask_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{0.0f}); + const auto sdpa = std::make_shared(query, + key, + value, + mask_const, + scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest3) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto casual = false; + { + const auto key_t = + std::make_shared(key, + op::v0::Constant::create(element::i64, Shape{4}, {0, 1, 3, 2})); + const auto qk = std::make_shared(query, key_t, false, false); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto mask_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{0.0f}); + const auto sdpa = std::make_shared(query, + key, + value, + mask_const, + scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest4) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, 32, -1}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, false); + const auto softmax = std::make_shared(qk, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto mask_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{0.0f}); + const auto sdpa = std::make_shared(query, + key, + value, + mask_const, + scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAFusionTest5) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, 32, -1}; + const PartialShape value_shape{1, 32, -1, 32}; + const PartialShape attention_mask_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f16, query_shape); + const auto key = std::make_shared(element::f16, key_shape); + const auto value = std::make_shared(element::f16, value_shape); + const auto mask = std::make_shared(element::f16, attention_mask_shape); + const auto casual = false; + { + const auto qk = std::make_shared(query, key, false, false); + const auto mask_add = std::make_shared(qk, mask); + const auto softmax = std::make_shared(mask_add, -1); + const auto qkv = std::make_shared(softmax, value, false, false); + + model = std::make_shared(NodeVector{qkv}, ParameterVector{query, key, value, mask}); + manager.register_pass(); + } + + { + const auto scale_const = ov::op::v0::Constant::create(element::f16, ov::Shape{}, std::vector{1.0f}); + const auto sdpa = + std::make_shared(query, key, value, mask, scale_const, casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value, mask}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} diff --git a/src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp new file mode 100644 index 00000000000000..e15f913758a0ec --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/sdpa_scale_fusion_test.cpp @@ -0,0 +1,102 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" + +using namespace testing; +using namespace ov::pass; +using namespace ov; + +TEST_F(TransformationTestsF, SDPAScaleFusionTest1) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{8.0f}); + const auto v_scaled = std::make_shared(value, scale_const); + const auto casual = false; + { + const auto q_scaled = std::make_shared(query, scale_const); + const auto k_scaled = std::make_shared(key, scale_const); + const auto sdpa = + std::make_shared(q_scaled, k_scaled, v_scaled, casual); + + model = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto new_mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto new_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{64.0f}); + const auto sdpa = std::make_shared(query, + key, + v_scaled, + new_mask_const, + new_scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, SDPAScaleFusionTest2) { + const PartialShape query_shape{1, 32, -1, 32}; + const PartialShape key_shape{1, 32, -1, 32}; + const PartialShape value_shape{1, 32, -1, 32}; + + const auto query = std::make_shared(element::f32, query_shape); + const auto key = std::make_shared(element::f32, key_shape); + const auto value = std::make_shared(element::f32, value_shape); + const auto sdpa_mask_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{0.0f}); + const auto sdpa_scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{2.0f}); + const auto scale_const = ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{8.0f}); + const auto v_scaled = std::make_shared(value, scale_const); + const auto casual = false; + { + const auto q_scaled = std::make_shared(query, scale_const); + const auto k_scaled = std::make_shared(key, scale_const); + const auto sdpa = std::make_shared(q_scaled, + k_scaled, + v_scaled, + sdpa_mask_const, + sdpa_scale_const, + casual); + + model = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + manager.register_pass(); + } + + { + const auto new_scale_const = + ov::op::v0::Constant::create(element::f32, ov::Shape{}, std::vector{128.0f}); + const auto sdpa = std::make_shared(query, + key, + v_scaled, + sdpa_mask_const, + new_scale_const, + casual); + model_ref = std::make_shared(NodeVector{sdpa}, ParameterVector{query, key, value}); + } + + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 50eecf51b945b7..e73acab9832a32 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -92,6 +92,7 @@ #include "transformations/common_optimizations/lstm_cell_fusion.hpp" #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp" #include "transformations/common_optimizations/mvn_fusion.hpp" +#include "transformations/common_optimizations/sdpa_scale_fusion.hpp" #include "transformations/common_optimizations/softmax_fusion.hpp" #include "transformations/common_optimizations/glu_fusion.hpp" #include "transformations/common_optimizations/transpose_sinking.hpp" @@ -936,6 +937,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { if (!disable_horizontal_fc_fusion) manager.register_pass(); + manager.register_pass(); manager.register_pass(); auto pass_config = manager.get_pass_config(); manager.register_pass();