[TRANSFORMATIONS][GPU] SDPA Fusion passes
Signed-off-by: Vladimir Paramuzov <[email protected]>
vladimir-paramuzov committed Dec 16, 2024
1 parent 357eb54 commit 83ccfe3
Showing 8 changed files with 822 additions and 0 deletions.
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// This pass transforms the following sub-graph into a single ScaledDotProductAttention operation.
/// Before:
/// ┌───────┐ ┌───────┐ ┌───────┐
/// │ Q │ │ K │ │ V │
/// └───┬───┘ └───┬───┘ └───┬───┘
/// │ │ │
/// │ │ │
/// ┌───┴───┐ ┌─────┴──────┐ │
/// │ MatMul│<──│ Transpose │ │
/// └───┬───┘ | (Optional) │ │
/// │ └────────────┘ │
/// ┌───┴───┐ ┌─────────────┐ │
/// │ Add │<───│AttentionMask│ │
/// └───┬───┘ | (Optional) │ │
/// │ └─────────────┘ │
/// ┌───┴───┐ │
/// │Softmax│ │
/// └───┬───┘ │
/// │ │
/// ┌───┴───┐ │
/// │ MatMul│<─────────────────────┘
/// └───┬───┘
/// ┌───┴───┐
/// │ Output│
/// └───────┘
///
/// After:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘
/// │ │ │ │
/// │ │ │ │
/// ┌───┴────────────┴────────────┴───────────────┴─┐
/// │ ScaledDotProductAttention │
/// └────────────────────┬──────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
class TRANSFORMATIONS_API SDPAFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SDPAFusion", "0");
SDPAFusion();
};

} // namespace pass
} // namespace ov
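A minimal usage sketch for the new pass, assuming an already loaded std::shared_ptr<ov::Model> (the helper function and standalone invocation are illustrative; inside MOCTransformations the pass is registered through ADD_MATCHER, as shown further down):

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"

void fuse_sdpa(const std::shared_ptr<ov::Model>& model) {
    // Run SDPAFusion as a standalone matcher pass on the given model.
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAFusion>();
    manager.run_passes(model);
}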
@@ -0,0 +1,58 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/// Merges an explicit multiplication of Q and/or K by a scalar value into the scale input of the SDPA op
/// Before:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌─────────────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale |
/// └───┬───┘ └───┬───┘ └───┬───┘ │ (Optional) │ │ (Optional) │
/// │ │ │ └──────┬──────┘ └───────┬─────┘
/// │ │ │ │ |
/// ┌───┴───┐ ┌───┴───┐ │ │ |
/// │ Mul | │ Mul │ | │ |
/// └───┬───┘ └───┬───┘ │ │ │
/// │ │ │ │ │
/// | │ │ │ │
/// ┌───┴────────────┴────────────┴─────────────┴─┐ |
/// │ ScaledDotProductAttention │──────────────────┘
/// └────────────────────┬────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
/// After:
/// ┌───────┐ ┌───────┐ ┌───────┐ ┌─────────────┐ ┌───────┐
/// │ Q │ │ K │ │ V │ │AttentionMask│ │ Scale |
/// └───┬───┘ └───┬───┘ └───┬───┘ └──────┬──────┘ └───┬───┘
/// │ │ │ │ |
/// │ │ │ │ |
/// | │ │ │ |
/// ┌───┴────────────┴────────────┴─────────────┴─┐ |
/// │ ScaledDotProductAttention │───────────┘
/// └────────────────────┬────────────────────────┘
/// │
/// │
/// ┌────┴────┐
/// │ Output │
/// └─────────┘
/// Multiply ops for Q and K are eliminated in the following cases:
/// 1. Q_scale and K_scale are constant
/// 2. Q_scale * SDPA_Scale == 1 or K_scale * SDPA_Scale == 1
class TRANSFORMATIONS_API SDPAScaleFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SDPAScaleFusion", "0");
SDPAScaleFusion();
};

} // namespace pass
} // namespace ov
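A quick numeric illustration of the folding arithmetic (head_size = 64 and the 8.0 factor are assumed example values, not taken from the commit): when the SDPA op has no explicit scale input, the pass derives 1/sqrt(head_size) and multiplies it by the constant Q and K scales to obtain the new scale.

#include <cmath>
#include <iostream>

int main() {
    const float default_scale = 1.0f / std::sqrt(64.0f);   // implicit SDPA scale for head_size = 64 -> 0.125
    const float q_scale = 8.0f;                             // constant Multiply applied to Q
    const float k_scale = 1.0f;                             // no Multiply applied to K
    const float fused_scale = default_scale * q_scale * k_scale;
    std::cout << "fused SDPA scale = " << fused_scale << "\n";  // prints 1: both Multiply ops can be dropped
    return 0;
}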
@@ -65,6 +65,7 @@
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"
#include "transformations/common_optimizations/ric_fusion.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
#include "transformations/common_optimizations/sequence_fusion.hpp"
#include "transformations/common_optimizations/shared_ops_optimization.hpp"
@@ -229,6 +230,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(common_fusions, ConvertTensorIteratorToSequence)
ADD_MATCHER(common_fusions, SplitConcatPairToInterpolateFusion, m_use_shapes)
ADD_MATCHER(common_fusions, ConvolutionToGroupConvolutionFusion)
ADD_MATCHER(common_fusions, SDPAFusion)
if (m_use_shapes) {
ADD_MATCHER(common_fusions, NearestNeighborUpsamplingFusion)
}
@@ -0,0 +1,120 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAFusion::SDPAFusion() {
using namespace ov::pass::pattern;
using namespace ov::gen_pattern;

auto q = makePattern(ov::Rank(4));
auto k = makePattern(ov::Rank(4));
auto v = makePattern(ov::Rank(4));
auto mask = makePattern();

// K may be transposed either by an explicit Transpose op or via MatMul's transpose_b attribute
auto k_t = makePattern<ov::op::v1::Transpose>({k, {0, 1, 3, 2}});
auto qk_nn = makePattern<ov::op::v0::MatMul>({q, k_t}, {{"transpose_a", false}, {"transpose_b", false}});
auto qk_nt = makePattern<ov::op::v0::MatMul>({q, k}, {{"transpose_a", false}, {"transpose_b", true}});
auto qk = qk_nt | qk_nn;
auto optional_add_mask = optional<ov::op::v1::Add>({qk, mask});
auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}});
auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}});

// Both MatMul inputs must be 4D with a static and equal head size
auto valid_qk_shapes = [](const std::shared_ptr<ov::op::v0::MatMul>& qk_matmul) {
auto q_pshape = qk_matmul->get_input_partial_shape(0);
auto k_pshape = qk_matmul->get_input_partial_shape(1);

const size_t q_head_size_idx = 3;
const size_t k_head_size_idx = qk_matmul->get_transpose_b() ? 3 : 2;

return q_pshape.size() == 4 && k_pshape.size() == 4 && q_pshape[q_head_size_idx].is_static() &&
k_pshape[k_head_size_idx].is_static() &&
q_pshape[q_head_size_idx].get_length() == k_pshape[k_head_size_idx].get_length();
};

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
if (transformation_callback(m.get_match_root())) {
return false;
}

auto q_node = pattern_map.at(q);
auto k_node = pattern_map.at(k);
auto v_node = pattern_map.at(v);

if (!valid_qk_shapes(ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(qk).get_node_shared_ptr()))) {
return false;
}

if (pattern_map.at(qk).get_target_inputs().size() > 1 ||
pattern_map.at(softmax).get_target_inputs().size() > 1) {
return false;
}
if (pattern_map.count(optional_add_mask) && (pattern_map.at(optional_add_mask).get_target_inputs().size() > 1 ||
pattern_map.at(mask).get_partial_shape().size() > 4)) {
return false;
}

Output<ov::Node> mask_value;
Output<ov::Node> mask_input;
if (pattern_map.find(optional_add_mask) != pattern_map.end()) {
mask_value = pattern_map.at(mask);
} else {
mask_value = ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{0});
}

if (mask_value.get_partial_shape().size() > 4) {
return false;
}

// Scalar and 4D masks are passed through as-is; lower-rank masks get their missing leading dimensions unsqueezed
if (mask_value.get_partial_shape().rank() == 0 || mask_value.get_partial_shape().rank() == 4) {
mask_input = mask_value;
} else {
size_t rank_diff = q_node.get_partial_shape().size() - mask_value.get_partial_shape().size();
std::vector<int64_t> axes(rank_diff);
std::iota(axes.begin(), axes.end(), 0);
mask_input = std::make_shared<ov::op::v0::Unsqueeze>(
mask_value,
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{rank_diff}, axes));
}

std::shared_ptr<ov::Node> scale_node =
ov::op::v0::Constant::create(q_node.get_element_type(), ov::Shape{}, std::vector<float>{1.0f});

std::shared_ptr<ov::Node> sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_node,
k_node,
v_node,
mask_input,
scale_node,
false);

sdpa->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
ov::replace_node(m.get_match_root(), sdpa);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(qkv, "SDPAFusion");
this->register_matcher(m, callback);
}

} // namespace pass
} // namespace ov
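A self-contained construction of the sub-graph matched by the qk_nt branch, useful for seeing the callback in action (the zero-mask and scale expectations mirror the code above; the shapes and the standalone main are illustrative assumptions):

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/sdpa_fusion.hpp"

int main() {
    using namespace ov;
    // Illustrative static shapes: [batch, num_heads, seq_len, head_size].
    const PartialShape shape{1, 8, 32, 64};
    auto q = std::make_shared<op::v0::Parameter>(element::f32, shape);
    auto k = std::make_shared<op::v0::Parameter>(element::f32, shape);
    auto v = std::make_shared<op::v0::Parameter>(element::f32, shape);

    // Q x K^T expressed through MatMul's transpose_b attribute (the qk_nt branch of the pattern).
    auto qk = std::make_shared<op::v0::MatMul>(q, k, false, true);
    auto softmax = std::make_shared<op::v8::Softmax>(qk, -1);
    auto qkv = std::make_shared<op::v0::MatMul>(softmax, v, false, false);

    auto model = std::make_shared<Model>(OutputVector{qkv}, ParameterVector{q, k, v});

    pass::Manager manager;
    manager.register_pass<pass::SDPAFusion>();
    manager.run_passes(model);
    // The MatMul -> Softmax -> MatMul chain should now be a single ScaledDotProductAttention
    // with a zero attention mask and a scale of 1.0, as constructed in the callback above.
    return 0;
}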
@@ -0,0 +1,140 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/sdpa_scale_fusion.hpp"

#include <cmath>
#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"

namespace ov {
namespace pass {

SDPAScaleFusion::SDPAScaleFusion() {
using namespace ov::pass::pattern;
using namespace ov::gen_pattern;

auto q = makePattern(ov::Rank(4));
auto k = makePattern(ov::Rank(4));
auto v = makePattern(ov::Rank(4));
auto mask = makePattern();
auto sdpa_scale = makeConst({});
auto scale_q = makePattern("[]") | makePattern("[1]");
auto scale_k = makePattern("[]") | makePattern("[1]");

auto scaled_q = optional<ov::op::v1::Multiply>({q, scale_q});
auto scaled_k = optional<ov::op::v1::Multiply>({k, scale_k});
auto sdpa_mask_scale =
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask, sdpa_scale},
{{"causal", false}});
auto sdpa_mask =
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v, mask}, {{"causal", false}});
auto sdpa_simple =
makePattern<ov::op::v13::ScaledDotProductAttention>({scaled_q, scaled_k, v}, {{"causal", false}});
auto sdpa = sdpa_simple | sdpa_mask | sdpa_mask_scale;

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
if (transformation_callback(m.get_match_root())) {
return false;
}

auto sdpa = m.get_match_root();

const bool has_q_scale = pattern_map.count(scaled_q);
const bool has_k_scale = pattern_map.count(scaled_k);

// Nothing to do
if (!has_q_scale && !has_k_scale)
return false;

auto prev_scale_value = 1.0f;
auto scale_q_value = 1.0f;
auto scale_k_value = 1.0f;
auto scale_et = sdpa->get_output_element_type(0);

Output<ov::Node> q_input = sdpa->get_input_source_output(0);
Output<ov::Node> k_input = sdpa->get_input_source_output(1);

std::shared_ptr<ov::Node> scale_q_node = nullptr;
std::shared_ptr<ov::Node> scale_k_node = nullptr;

if (pattern_map.find(sdpa_scale) != pattern_map.end()) {
auto prev_scale_node =
ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(sdpa_scale).get_node_shared_ptr());
prev_scale_value = prev_scale_node->cast_vector<float>()[0];
scale_et = prev_scale_node->get_output_element_type(0);
} else {
auto head_size = q_input.get_partial_shape()[3];
if (head_size.is_dynamic())
return false;

prev_scale_value = 1.0f / std::sqrt(static_cast<float>(head_size.get_length()));
}

// Extract scalar scale values for Q and K if they are constant, and set new inputs for SDPA
if (has_q_scale) {
scale_q_node = pattern_map.at(scale_q).get_node_shared_ptr();
if (ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
scale_q_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_q_node)->cast_vector<float>()[0];
q_input = pattern_map.at(q);
}
}
if (has_k_scale) {
scale_k_node = pattern_map.at(scale_k).get_node_shared_ptr();
if (ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
scale_k_value = ov::as_type_ptr<ov::op::v0::Constant>(scale_k_node)->cast_vector<float>()[0];
k_input = pattern_map.at(k);
}
}

Output<ov::Node> new_scale_node;
auto new_scale_val = prev_scale_value * scale_q_value * scale_k_value;

// If the new scale equals 1 and either Q or K has a non-constant scale node, that node can be reused as the SDPA scale
if (new_scale_val == 1.0f) {
if (has_q_scale && !ov::is_type<ov::op::v0::Constant>(scale_q_node)) {
new_scale_node = pattern_map.at(scale_q);
q_input = pattern_map.at(q);
} else if (has_k_scale && !ov::is_type<ov::op::v0::Constant>(scale_k_node)) {
new_scale_node = pattern_map.at(scale_k);
k_input = pattern_map.at(k);
} else {
new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val});
}
} else {
new_scale_node = ov::op::v0::Constant::create(scale_et, ov::Shape{}, std::vector<float>{new_scale_val});
}

OutputVector new_inputs = {q_input, k_input, pattern_map.at(v)};
if (pattern_map.find(mask) != pattern_map.end()) {
new_inputs.push_back(pattern_map.at(mask));
} else {
new_inputs.push_back(
ov::op::v0::Constant::create(new_scale_node.get_element_type(), ov::Shape{}, std::vector<float>{0.0f}));
}

new_inputs.push_back(new_scale_node);

auto new_sdpa = sdpa->clone_with_new_inputs(new_inputs);
new_sdpa->set_friendly_name(sdpa->get_friendly_name());
ov::copy_runtime_info(sdpa, new_sdpa);
ov::replace_node(sdpa, new_sdpa);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa, "SDPAScaleFusion");
this->register_matcher(m, callback);
}

} // namespace pass
} // namespace ov
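A companion sketch for SDPAScaleFusion (the shapes, the 0.125 constant and the helper function are illustrative assumptions): a constant scalar Multiply on Q feeding an SDPA without an explicit scale input should be folded away, with the product 0.125 * 1/sqrt(64) becoming the new scale input.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/sdpa_scale_fusion.hpp"

std::shared_ptr<ov::Model> make_scaled_sdpa_model() {
    using namespace ov;
    const PartialShape shape{1, 8, 32, 64};
    auto q = std::make_shared<op::v0::Parameter>(element::f32, shape);
    auto k = std::make_shared<op::v0::Parameter>(element::f32, shape);
    auto v = std::make_shared<op::v0::Parameter>(element::f32, shape);

    // Constant scalar pre-scaling of Q (e.g. baked into the graph by a frontend).
    auto q_scale = op::v0::Constant::create(element::f32, Shape{}, {0.125f});
    auto scaled_q = std::make_shared<op::v1::Multiply>(q, q_scale);

    // SDPA without an explicit scale input; the pass derives the implicit 1/sqrt(head_size) itself.
    auto sdpa = std::make_shared<op::v13::ScaledDotProductAttention>(scaled_q, k, v, false);
    return std::make_shared<Model>(OutputVector{sdpa}, ParameterVector{q, k, v});
}

int main() {
    auto model = make_scaled_sdpa_model();
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAScaleFusion>();
    manager.run_passes(model);
    // Expected outcome: the Multiply on Q is removed and the combined constant
    // (0.125 * 1/sqrt(64)) becomes the explicit scale input of the rebuilt SDPA node.
    return 0;
}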