Skip to content

Commit

Permalink
[TRANSFORMATIONS][GPU] SDPA Fusion passes
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Paramuzov <[email protected]>
  • Loading branch information
vladimir-paramuzov committed Dec 12, 2024
1 parent 702ce05 commit fbedd1e
Show file tree
Hide file tree
Showing 8 changed files with 591 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// This pass transforms the following sub-graph to a single Scaled Dot Product Attention operation.
/// Before:
/// ┌───────┐ ┌───────┐ ┌───────┐
/// │ Q │ │ K │ │ V │
/// └───┬───┘ └───┬───┘ └───┬───┘
/// │ │ │
/// │ │ │
/// ┌───┴───┐ ┌─────┴──────┐ │
/// │ MatMul│<──│ Transpose │ │
/// └───┬───┘ | (Optional) │ │
/// │ └────────────┘ │
/// ┌───┴───┐ ┌─────────────┐ │
/// │ Add │<───│AttentionMask│ │
/// └───┬───┘ | (Optional) │ │
/// │ └─────────────┘ │
/// ┌───┴───┐ │
/// │Softmax│ │
/// └───┬───┘ │
/// │ │
/// ┌───┴───┐ │
/// │ MatMul│<─────────────────────┘
/// └───┬───┘
/// ┌───┴───┐
/// │ Output│
/// └───────┘
///
/// After:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘
/// │ │ │ │
/// │ │ │ │
/// ┌───┴────────────┴────────────┴───────────────┴─┐
/// │ ScaledDotProductAttention │
/// └────────────────────┬──────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
///
/// Matching constraints (see the matcher implementation):
/// - Q, K and V must be rank-4 tensors, and the head-size dimensions of Q and K
///   must be static and equal.
/// - When the optional AttentionMask input is absent, a scalar zero mask
///   constant is substituted. The fused op is always created with an explicit
///   scale constant of 1.0, which preserves the semantics of the unscaled
///   Q*K^T sub-graph above.
class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SDPAFusion", "0");
    SDPAFusion();
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// Merges explicit multiplication by scalar value for Q and K into scale attribute of SDPA op
/// Before:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│
/// └───┬───┘ └───┬───┘ └───┬───┘ │ (Optional) │
/// │ │ │ └──────┬──────┘
/// │ │ │ │
/// ┌───┴───┐ ┌───┴───┐ │ │
/// │ Mul | │ Mul │ | │
/// └───┬───┘ └───┬───┘ │ │
/// │ │ │ │
/// | │ │ │
/// ┌───┴────────────┴────────────┴─────────────┴─┐
/// │ ScaledDotProductAttention │
/// └────────────────────┬────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
/// After:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌───────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale |
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ └───┬───┘
/// │ │ │ │ |
/// │ │ │ │ |
/// | │ │ │ |
/// ┌───┴────────────┴────────────┴─────────────┴─┐ |
/// │ ScaledDotProductAttention │───────────┘
/// └────────────────────┬────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
///
/// Notes (see the matcher implementation):
/// - Both the Q and K branches must be multiplied by scalar (rank-0) values;
///   the two factors and any pre-existing SDPA scale input are folded into a
///   single scale constant as their product.
/// - If the original SDPA had no attention mask, a scalar zero mask constant
///   is inserted so the fused scale can be passed as the fifth input.
class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SDPAScaleFusion", "0");
    SDPAScaleFusion();
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"
#include "transformations/common_optimizations/ric_fusion.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
#include "transformations/common_optimizations/sequence_fusion.hpp"
#include "transformations/common_optimizations/shared_ops_optimization.hpp"
Expand Down Expand Up @@ -229,6 +230,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(common_fusions, ConvertTensorIteratorToSequence)
ADD_MATCHER(common_fusions, SplitConcatPairToInterpolateFusion, m_use_shapes)
ADD_MATCHER(common_fusions, ConvolutionToGroupConvolutionFusion)
ADD_MATCHER(common_fusions, SDPAFusion)
if (m_use_shapes) {
ADD_MATCHER(common_fusions, NearestNeighborUpsamplingFusion)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAFusion::SDPAFusion() {
    using namespace ov::pass::pattern;
    using namespace ov::gen_pattern;

    // Q/K/V activations are expected as rank-4 tensors:
    // [batch, num_heads, seq_len, head_size].
    auto q = makePattern(ov::Rank(4));
    auto k = makePattern(ov::Rank(4));
    auto v = makePattern(ov::Rank(4));
    auto mask = makePattern();

    // Q*K^T is expressed either as MatMul(Q, Transpose(K)) or as
    // MatMul(Q, K) with transpose_b=true.
    auto k_t = makePattern<ov::op::v1::Transpose>({k, {0, 1, 3, 2}});
    auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t | k}, {{"transpose_a", false}, {"transpose_b", false}});
    auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}});
    auto qk = qk_nt | qk_nn;
    auto optional_add_mask = optional<ov::op::v1::Add>({qk, mask});
    auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}});
    auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}});

    // The head-size dimensions of Q and K must be static and equal, otherwise
    // the MatMul cannot be mapped onto SDPA's Q*K^T contraction.
    auto valid_qk_shapes = [](const std::shared_ptr<ov::op::v0::MatMul>& qk_matmul) {
        auto q_pshape = qk_matmul->get_input_partial_shape(0);
        auto k_pshape = qk_matmul->get_input_partial_shape(1);

        const size_t q_head_size_idx = 3;
        // With transpose_b the head size is the innermost K dim, otherwise dim 2.
        const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2;

        return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() &&
               k_pshape[k_head_size_idx].is_static() &&
               q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length();
    };

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        if (transformation_callback(m.get_match_root())) {
            return false;
        }

        auto q_node = pattern_map.at(q);
        auto k_node = pattern_map.at(k);
        auto v_node = pattern_map.at(v);

        auto qk_matmul = ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(qk).get_node_shared_ptr());
        if (!qk_matmul || !valid_qk_shapes(qk_matmul)) {
            return false;
        }

        // Bug fix: the qk_nn alternative may match MatMul(Q, K) with
        // transpose_b == false and no Transpose node on K. In that case the
        // original sub-graph computes Q*K, while SDPA(Q, K, V) computes
        // Q*K^T, so fusing would change the result. Accept a non-transposed
        // MatMul only when the explicit Transpose of K was matched.
        if (!qk_matmul->get_transpose_b() && pattern_map.count(k_t) == 0) {
            return false;
        }

        // Constants below are created with Q's element type, which therefore
        // must be static.
        if (q_node.get_element_type().is_dynamic()) {
            return false;
        }

        // If no attention mask was matched, substitute a scalar zero mask so
        // the scale can still be provided as the 5th SDPA input.
        std::shared_ptr<ov::Node> mask_node = nullptr;
        if (pattern_map.find(optional_add_mask) != pattern_map.end()) {
            mask_node = pattern_map.at(mask).get_node_shared_ptr();
        } else {
            mask_node = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{0});
        }

        // An explicit scale of 1.0 preserves the semantics of the unscaled
        // Q*K^T sub-graph (SDPA's implicit default is 1/sqrt(head_size)).
        std::shared_ptr<ov::Node> scale_node =
            ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{1.0f});

        std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_node,
                                                                                                  k_node,
                                                                                                  v_node,
                                                                                                  mask_node,
                                                                                                  scale_node,
                                                                                                  false);

        sdpa->set_friendly_name(m.get_match_root()->get_friendly_name());
        ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
        ov::replace_node(m.get_match_root(), sdpa);

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(qkv, "SDPAFusion");
    this->register_matcher(m, callback);
}

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_scale_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAScaleFusion::SDPAScaleFusion() {
    using namespace ov::pass::pattern;
    using namespace ov::gen_pattern;

    auto q = makePattern(ov::Rank(4));
    auto k = makePattern(ov::Rank(4));
    auto v = makePattern(ov::Rank(4));
    auto mask = makePattern();
    // "[]" constrains the matched node to be a scalar (rank-0); it does not
    // by itself guarantee that the node is a Constant.
    auto sdpa_scale = makePattern("[]");
    auto scale_q = makePattern("[]");
    auto scale_k = makePattern("[]");

    // Both Q and K must be explicitly multiplied by a scalar; SDPA may come
    // with 3, 4 or 5 inputs (mask and scale are optional trailing inputs).
    auto scaled_q = makePattern<ov::op::v1::Multiply>({q, scale_q});
    auto scaled_k = makePattern<ov::op::v1::Multiply>({k, scale_k});
    auto sdpa_mask_scale =
        makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask, sdpa_scale});
    auto sdpa_mask = makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask});
    auto sdpa_simple = makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v});
    auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale;

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        if (transformation_callback(m.get_match_root())) {
            return false;
        }

        auto sdpa = m.get_match_root();

        // Fold the Q scale, the K scale and any pre-existing SDPA scale into
        // a single scalar: new_scale = prev_scale * scale_q * scale_k.
        auto prev_scale_value = 1.0f;
        auto scale_et = sdpa->get_output_element_type(0);
        if (pattern_map.find(sdpa_scale) != pattern_map.end()) {
            auto prev_scale_node =
                ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(sdpa_scale).get_node_shared_ptr());
            // Bug fix: the scalar pattern may match a non-constant node; bail
            // out of the match instead of dereferencing a null pointer.
            if (!prev_scale_node) {
                return false;
            }
            prev_scale_value = prev_scale_node->cast_vector<float>()[0];
            scale_et = prev_scale_node->get_output_element_type(0);
        }
        auto scale_q_node = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(scale_q).get_node_shared_ptr());
        auto scale_k_node = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(scale_k).get_node_shared_ptr());
        // Bug fix: same null-check for the Q/K scale factors — only constant
        // scales can be folded into the SDPA scale input.
        if (!scale_q_node || !scale_k_node) {
            return false;
        }
        auto new_scale_val =
            prev_scale_value * scale_q_node->cast_vector<float>()[0] * scale_k_node->cast_vector<float>()[0];
        auto new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val});

        // Re-wire SDPA onto the unscaled Q/K. If there was no attention mask,
        // insert a scalar zero mask so the scale can be passed as input #5.
        OutputVector new_inputs = {pattern_map.at(q), pattern_map.at(k), pattern_map.at(v)};
        if (pattern_map.find(mask) != pattern_map.end()) {
            new_inputs.push_back(pattern_map.at(mask));
        } else {
            new_inputs.push_back(ov::op::v0::Constant::create(new_scale_node->get_output_element_type(0),
                                                              ov::Shape{},
                                                              std::vector<float>{0.0f}));
        }

        new_inputs.push_back(new_scale_node);

        auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs);
        new_sdpa->set_friendly_name(sdpa->get_friendly_name());
        ov::copy_runtime_info(sdpa, new_sdpa);
        ov::replace_node(sdpa, new_sdpa);

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa, "SDPAScaleFusion");
    this->register_matcher(m, callback);
}

} // namespace pass
} // namespace ov
Loading

0 comments on commit fbedd1e

Please sign in to comment.