diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp index 50c0ecd20e76af..825ce8acbd7998 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp @@ -15,6 +15,7 @@ namespace ov { namespace pass { class TRANSFORMATIONS_API PositionIDsReplacer; +class TRANSFORMATIONS_API PositionIDsReplacerQwen; } // namespace pass } // namespace ov @@ -24,3 +25,22 @@ class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass { OPENVINO_MATCHER_PASS_RTTI("PositionIDsReplacer"); explicit PositionIDsReplacer(const Output& position_ids); }; + +/** + * @brief Qwen model expects data processing in order, the "position ids" input is detached and + * is not explicitly used in the model. The model uses implicitly defined "position ids" based + * on the past KV cache size. + * + * To use this model in Continuous batching mode, we need to apply position_ids and + * use the corresponding rotary_emb_cos/rotary_emb_sin. + * For this, we replace + * rotary_emb_cos/rotary_emb_sin -> Slice -> Slice + * With + * rotary_emb_cos/rotary_emb_sin -> Gather(by position_ids) + * Which enables applying RoPE for each token independently of their order in the input tensor. + */ +class ov::pass::PositionIDsReplacerQwen : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("PositionIDsReplacerQwen"); + explicit PositionIDsReplacerQwen(const Output& position_ids); +}; diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp index f5497207eb4e17..d1cc5d5126cd67 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp @@ -4,7 +4,6 @@ #pragma once -#include "openvino/cc/pass/itt.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/subtract.hpp" #include "openvino/pass/matcher_pass.hpp" @@ -22,6 +21,8 @@ class TRANSFORMATIONS_API PrevSequenceLengthPattern; class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass { public: - OPENVINO_MATCHER_PASS_RTTI("PrevSequenceLengthPattern"); - explicit PrevSequenceLengthPattern(std::shared_ptr prev_max_seq_len, std::shared_ptr batch_dim); + OPENVINO_MATCHER_PASS_RTTI("PrevSequenceLengthPattern", "0"); + explicit PrevSequenceLengthPattern(const std::shared_ptr& unsqueezed_input_ids, + const std::shared_ptr& max_context_len, + const std::shared_ptr& position_ids); }; diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp index b5ecb96fa95198..2456161ea80a78 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp @@ -15,6 +15,7 @@ namespace ov { namespace pass { class TRANSFORMATIONS_API TotalSequenceLengthPattern; +class TRANSFORMATIONS_API TotalSequenceLengthPatternQwen; } // namespace pass } // namespace ov @@ -24,3 +25,22 @@ class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass { OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPattern"); explicit TotalSequenceLengthPattern(const std::shared_ptr& max_context_len); }; + +/** + * @brief Qwen model has a specific pattern for TotalSequenceLen place detection. + * + * common pattern: Add (PrevSeqLen, CurrentSeqLen) + * + * The CurrentSeqLen is presented in this form: + * CurrentSeqLen: Parameter(name: input_ids) -> ShapeOf -> Gather + * + * Before applying this transformation, we already detected the PrevSeqLen place in the PrevSequenceLengthPattern + * and replaced it with the next subgraph: + * PrevSeqLen: Subtract (in: Parameter(name: max_context_len), in: CurrentSeqLen) + * + **/ +class ov::pass::TotalSequenceLengthPatternQwen : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPattern", "0"); + explicit TotalSequenceLengthPatternQwen(const std::shared_ptr& max_context_len); +}; diff --git a/src/common/transformations/include/transformations/utils/gen_pattern.hpp b/src/common/transformations/include/transformations/utils/gen_pattern.hpp index 21309e339c959c..976561b4844a17 100644 --- a/src/common/transformations/include/transformations/utils/gen_pattern.hpp +++ b/src/common/transformations/include/transformations/utils/gen_pattern.hpp @@ -539,6 +539,11 @@ class AttrSetter : public ov::AttributeVisitor { a->set(m_attr_map[name].as_vector()); } else if (auto a = ov::as_type>(&adapter)) { a->set(m_attr_map[name].as_T_vector()); + } else if (auto a = dynamic_cast>*>(&adapter)) { + ov::op::util::VariableInfo var_info; + var_info.variable_id = m_attr_map[name].as_string(); + auto variable = std::make_shared(var_info); + a->set(variable); } else { OPENVINO_THROW("unsupported AttributeAdapter for attribute : ", name); } @@ -896,6 +901,7 @@ struct PatternNode { // scalar constant (treated as wildcard for single-element-constant with any rank) PatternNode(int v) : node(std::make_shared(element::from(), Shape({}), v)) {} PatternNode(float v) : node(std::make_shared(element::from(), Shape({}), v)) {} + PatternNode(long long v) : node(std::make_shared(element::from(), Shape({}), v)) {} PatternNode(std::initializer_list v, values_info vi = nullptr) { node = ConstVector(std::vector(v), vi); diff --git a/src/common/transformations/include/transformations/utils/print_model.hpp b/src/common/transformations/include/transformations/utils/print_model.hpp index 0829cd7e320e88..53fa7de51c5eca 100644 --- a/src/common/transformations/include/transformations/utils/print_model.hpp +++ b/src/common/transformations/include/transformations/utils/print_model.hpp @@ -19,6 +19,7 @@ #include "openvino/core/model.hpp" #include "openvino/core/node.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/util/multi_subgraph_base.hpp" #include "openvino/pass/pass.hpp" #include "transformations/utils/utils.hpp" diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp index a72a49fb4832eb..1cc9be37606950 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp @@ -7,11 +7,18 @@ #include "openvino/cc/pass/itt.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/unsqueeze.hpp" #include "openvino/pass/pattern/op/optional.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" using namespace ov::op; +using namespace ov::pass::pattern; // TODO: Instead of using the following transformation that matches quite a specific place in a model graph in case when // position_ids parameter is missing, consider replacing always existing attention_mask parameter with a sub-graph using @@ -19,25 +26,87 @@ using namespace ov::op; ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output& position_ids) { MATCHER_SCOPE(PositionIDsReplacer); - auto input_ids = pattern::any_input(); - auto input_embed = pattern::wrap_type({pattern::any_input(), input_ids, pattern::any_input()}); + auto input_ids = any_input(); + auto input_embed = wrap_type({any_input(), input_ids, any_input()}); - auto position_ids_pattern = pattern::any_input(); - auto offset = pattern::wrap_type(); - auto add_offset = pattern::wrap_type({position_ids_pattern, offset}); - auto convert = pattern::wrap_type({add_offset}); - auto position_embed = pattern::wrap_type({pattern::any_input(), convert, pattern::any_input()}); + auto position_ids_pattern = any_input(); + auto offset = wrap_type(); + auto add_offset = wrap_type({position_ids_pattern, offset}); + auto convert = wrap_type({add_offset}); + auto position_embed = wrap_type({any_input(), convert, any_input()}); - auto mul = pattern::optional({input_embed, pattern::any_input()}); + auto mul = optional({input_embed, any_input()}); - auto add = pattern::wrap_type({mul, position_embed}); + auto add = wrap_type({mul, position_embed}); - ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + ov::matcher_pass_callback callback = [=](Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); replace_node(pattern_map.at(position_ids_pattern).get_node_shared_ptr(), position_ids.get_node_shared_ptr()); return true; }; - auto m = std::make_shared(add, matcher_name); + auto m = std::make_shared(add, matcher_name); register_matcher(m, callback); -} \ No newline at end of file +} + +ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output& position_ids) { + MATCHER_SCOPE(PositionIDsReplacerQwen); + + auto _const = []() { + return wrap_type(); + }; + + // total seq len: + auto p_max_context_len = wrap_type(); + auto p_opt_convert = optional(p_max_context_len); + auto p_opt_reshape = optional({p_opt_convert, any_input()}); + + // current seg len + auto p_input_ids = wrap_type(); + auto p_unsqueeze = wrap_type({p_input_ids, _const()}); + auto p_shape_of = wrap_type({p_unsqueeze}); + auto p_current_len = wrap_type({p_shape_of, _const(), _const()}); + + auto p_rotary_emb_sincos = wrap_type(); + auto p_neg_const = wrap_type(); + auto p_neg_mul = wrap_type({p_current_len, p_neg_const}); + // the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128] + auto p_slice_1 = wrap_type({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()}); + auto p_slice_2 = wrap_type({p_slice_1, p_neg_mul, _const(), _const(), _const()}); + + ov::matcher_pass_callback callback = [=](Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto max_context_len = pattern_map.at(p_max_context_len).get_node_shared_ptr(); + if (max_context_len->get_friendly_name() != "max_context_len") { + return false; + } + auto rotary_emb_sincos = pattern_map.at(p_rotary_emb_sincos).get_node_shared_ptr(); + auto slice_1 = pattern_map.at(p_slice_1).get_node_shared_ptr(); + auto slice_2 = pattern_map.at(p_slice_2).get_node_shared_ptr(); + + auto axis = v0::Constant::create(element::i64, Shape{}, {1}); + // in case of PagedAttention (Continuous batching) the rotary_emb_cos/rotary_emb_sin + // are used not in the sequential order, so we need to use position_ids to get the expected values. + auto gather = std::make_shared(slice_1->input_value(0), position_ids, axis); + gather->set_friendly_name(slice_2->get_friendly_name()); + gather->validate_and_infer_types(); + + auto pshape = rotary_emb_sincos->get_output_partial_shape(0); + if (pshape.rank().is_dynamic() || pshape.rank().get_length() != 4) { + return false; + } + + // PagedAttention expects the next layout for Q,K,V: + // [batch_size_in_tokens, num_kv_heads * head_size] + // so here we need to reshape the output tensor to move the seq dim (num tokens) to the batch + // num_kv_heads * head_size are already handled in the StateManagementPattern transformation + auto head_size = static_cast(pshape[3].get_length()); + auto new_shape = v0::Constant::create(element::i64, Shape{4}, std::vector{-1, 1, 1, head_size}); + auto reshape = std::make_shared(gather, new_shape, false); + replace_node(slice_2, reshape); + return true; + }; + + auto m = std::make_shared(p_slice_2, matcher_name); + register_matcher(m, callback); +} diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp index 36d9d88975b2e0..55d7af822c3857 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp @@ -14,8 +14,9 @@ using namespace ov::op; -ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr prev_max_seq_len, - std::shared_ptr batch_dim) { +ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(const std::shared_ptr& unsqueezed_input_ids, + const std::shared_ptr& max_context_len, + const std::shared_ptr& position_ids) { MATCHER_SCOPE(PrevSequenceLengthPattern); // The transformation addresses two cases that look similar: (1) previous sequence length, (2) batch size in // kv-cache state In first case it should replace it by prev_max_seq_len. For the second case, connect to batch_dim. @@ -40,8 +41,16 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptrget_output_element_type(0); std::shared_ptr replacement; if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) { + auto cur_seq_len = std::make_shared(std::make_shared(unsqueezed_input_ids), + v0::Constant::create(element::i64, Shape{}, {1}), + v0::Constant::create(element::i64, Shape{}, {0})); + auto cur_seq_len_i32 = std::make_shared(cur_seq_len, element::i32); + auto prev_max_seq_len = std::make_shared(max_context_len, cur_seq_len_i32); replacement = prev_max_seq_len; } else { + // it is not always required, so will be disposed if not needed + auto batch_dim = std::make_shared(position_ids); + // assumption that any other axis should point to batch dimension, precise reasoning is too complex // TODO: provide more reliable check replacement = batch_dim; diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp index b55c3d73316120..a36085c34237a4 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp @@ -437,6 +437,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par parameters_to_remove.push_back(param); } + pa_transpose->set_friendly_name(sdpa_node->get_friendly_name()); replace_node(m.get_match_root(), pa_transpose); return true; }; diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp index 18387d5ca1ae04..cbf9426a0c82c5 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp @@ -6,27 +6,49 @@ #include "openvino/cc/pass/itt.hpp" #include "openvino/core/validation_util.hpp" +#include "openvino/op/add.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/gather.hpp" +#include "openvino/op/reshape.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/pass/pattern/op/optional.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" using namespace ov::op; +using namespace ov::pass::pattern; + +namespace { + +void align_replacement(std::shared_ptr& replacement, + const ov::PartialShape& required_shape, + ov::element::Type target_type) { + if (replacement->get_output_element_type(0) != target_type) { + replacement = std::make_shared(replacement, target_type); + } + + if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) { + replacement = ov::op::util::reshapeTo(replacement, ov::Shape(required_shape.rank().get_length(), 1)); + } +} + +} // namespace ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern( const std::shared_ptr& max_context_len) { MATCHER_SCOPE(TotalSequenceLengthPattern); - auto kv_past = pattern::wrap_type({pattern::any_input()}); - auto kv_gather = pattern::wrap_type({kv_past, pattern::any_input(), pattern::any_input()}); - auto kv_current = pattern::any_input(); - auto kv_concat = pattern::wrap_type({kv_gather, kv_current}); - auto kv_shape = pattern::wrap_type({kv_concat}); - auto gather_idx_label = pattern::wrap_type(); - auto seq = pattern::wrap_type({kv_shape, gather_idx_label, pattern::any_input()}); + auto kv_past = wrap_type({any_input()}); + auto kv_gather = wrap_type({kv_past, any_input(), any_input()}); + auto kv_current = any_input(); + auto kv_concat = wrap_type({kv_gather, kv_current}); + auto kv_shape = wrap_type({kv_concat}); + auto gather_idx_label = wrap_type(); + auto seq = wrap_type({kv_shape, gather_idx_label, any_input()}); - ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + ov::matcher_pass_callback callback = [=](Matcher& m) { // TODO: Check that seq has axis that really takes sequence len but not any other dimension -- // use symbolic infra or look at the constant input const auto& pattern_map = m.get_pattern_value_map(); @@ -71,16 +93,8 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern( if (concat_axis_to_compare == gather_idx_to_compare) { auto target_type = gather->get_output_element_type(0); - - if (replacement->get_output_element_type(0) != target_type) { - replacement = std::make_shared(replacement, target_type); - } - auto required_shape = gather->get_output_partial_shape(0); - - if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) { - replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1)); - } + align_replacement(replacement, required_shape, target_type); } else { // TODO: change in the future when we start supporting dynamic shapes here replacement = ov::util::get_constant_from_source(gather->output(0)); @@ -94,6 +108,41 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern( return true; }; - auto m = std::make_shared(seq, matcher_name); + auto m = std::make_shared(seq, matcher_name); + register_matcher(m, callback); +} + +ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen( + const std::shared_ptr& max_context_len) { + MATCHER_SCOPE(TotalSequenceLengthPatternQwen); + + auto p_input_ids = wrap_type(); + auto p_unsqueeze = wrap_type({p_input_ids, any_input()}); + auto p_opt_reshape_1 = optional({p_unsqueeze, any_input()}); + auto p_opt_convert_1 = optional(p_opt_reshape_1); + auto p_kv_shape_current = wrap_type({p_opt_convert_1}); + auto p_seq_current = wrap_type({p_kv_shape_current, any_input(), any_input()}); + auto p_opt_convert_2 = optional(p_seq_current); + + auto p_max_context_len = wrap_type(); + auto p_prev_max_seq_len = wrap_type({p_max_context_len, any_input()}); + auto p_opt_convert_3 = optional(p_prev_max_seq_len); + auto p_opt_reshape_2 = optional({p_opt_convert_3, any_input()}); + auto p_total_seq = wrap_type({p_opt_convert_2, p_opt_reshape_2}); + + ov::matcher_pass_callback callback = [=](Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto total_seq = pattern_map.at(p_total_seq).get_node_shared_ptr(); + std::shared_ptr replacement = max_context_len; + + auto target_type = total_seq->get_output_element_type(0); + auto required_shape = total_seq->get_output_partial_shape(0); + align_replacement(replacement, required_shape, target_type); + + replace_node(total_seq, replacement); + return true; + }; + + auto m = std::make_shared(p_total_seq, matcher_name); register_matcher(m, callback); -} \ No newline at end of file +} diff --git a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp new file mode 100644 index 00000000000000..840309993c939a --- /dev/null +++ b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp @@ -0,0 +1,618 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/pass/sdpa_to_paged_attention.hpp" + +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/core/model.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/ops.hpp" +#include "openvino/op/paged_attention.hpp" +#include "openvino/op/power.hpp" +#include "openvino/op/reduce_mean.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp" +#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp" +#include "transformations/utils/gen_pattern.hpp" +#include "transformations/utils/print_model.hpp" + +using namespace ov; +using namespace std; +using namespace testing; +using namespace ov::op; +using namespace ov::gen_pattern; + +namespace { + +// Constants and Parameters attributes: +auto el_type_i64 = std::pair({"element_type", "i64"}); +auto el_type_i32 = std::pair({"element_type", "i32"}); +auto el_type_f32 = std::pair({"element_type", "f32"}); + +// Convert ops attributes: +auto dest_type_i64 = std::pair({"destination_type", "i64"}); +auto dest_type_f32 = std::pair({"destination_type", "f32"}); +auto dest_type_f16 = std::pair({"destination_type", "f16"}); + +// Other attributes: +auto numpy_broadcast = std::pair({"auto_broadcast", "numpy"}); +auto special_zero_true = std::pair({"special_zero", true}); + +auto single_val = [](int rank, float val) { + return makeConst(element::f32, ov::Shape{std::vector(rank, 1)}, {val}); +}; + +ov::ParameterVector nodes_to_params(const ov::NodeVector& node_vec) { + ov::ParameterVector params; + params.reserve(node_vec.size()); + for (const auto& node : node_vec) { + params.push_back(ov::as_type_ptr(node)); + } + return params; +} + +enum QKV : int { Q = 0, K = 1, V = 2 }; +vector MOCK_VALUE = {1}; + +// original weights = 151936, attention_weights = 12288 +#define WEIGHTS 1024 +#define ATTENTION_WEIGHTS 512 + +class Qwen7bChatSDPA { +public: + static std::shared_ptr gen_embeddings(const std::shared_ptr& input_ids) { + auto view_reshape = makeOP({input_ids, {-1, 0}}, {special_zero_true}); + auto input_ids_i64 = makeOP({view_reshape}, {dest_type_i64}); + + auto weights = makeConst(element::u8, {WEIGHTS, 4096}, MOCK_VALUE); + auto weights_fp16 = makeOP({weights}, {dest_type_f16}); + auto zero_point = makeConst(element::u8, {WEIGHTS, 1}, MOCK_VALUE); + auto zero_point_fp16 = makeOP({zero_point}, {dest_type_f16}); + auto zero_point_subtract = makeOP({weights_fp16, zero_point_fp16}, {numpy_broadcast}); + + auto scale = makeConst(element::f16, {WEIGHTS, 1}, MOCK_VALUE); + auto mul_scale = makeOP({zero_point_subtract, scale}, {numpy_broadcast}); + auto fq_weights = makeOP({mul_scale}, {dest_type_f32}); + + return makeOP({fq_weights, input_ids_i64, 0}, {{"batch_dims", 0}}); + } + + static std::shared_ptr gen_attention_weights() { + auto weights = makeConst(element::u8, {ATTENTION_WEIGHTS, 4096}, MOCK_VALUE); + auto weights_f16 = makeOP({weights}, {dest_type_f16}); + + auto zero_points = makeConst(element::u8, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto zero_points_f16 = makeOP({zero_points}, {dest_type_f16}); + auto subtract = makeOP({weights_f16, zero_points_f16}, {numpy_broadcast}); + + auto scale = makeConst(element::f16, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto mul = makeOP({subtract, scale}, {numpy_broadcast}); + return makeOP({mul}, {dest_type_f32}); + } + + static std::shared_ptr gen_qkv_proj(const std::shared_ptr& embeddings) { + auto _const_0 = single_val(/*rank*/ 3, /*val*/ 2); + auto pow = makeOP({embeddings, _const_0}, {numpy_broadcast}); + auto mean = makeOP({pow, {-1}}, {{"keep_dims", true}}); + + auto _const_1 = single_val(/*rank*/ 3, /*val*/ 1); + auto add = makeOP({mean, _const_1}, {numpy_broadcast}); + auto sqrt = makeOP({add}); + + auto _const_2 = single_val(/*rank*/ 3, /*val*/ 1); + auto div = makeOP({_const_2, sqrt}, {numpy_broadcast, {"m_pythondiv", true}}); + auto mul_0 = makeOP({embeddings, div}, {numpy_broadcast}); + + auto _const_3 = makeConst(element::f32, {1, 1, 4096}, MOCK_VALUE); + auto mul_1 = makeOP({mul_0, _const_3}, {numpy_broadcast}); + auto attention_weights = gen_attention_weights(); + auto linear_matmul = + makeOP({mul_1, attention_weights}, {{"transpose_a", false}, {"transpose_b", true}}); + + auto _const_4 = makeConst(element::f32, {1, 1, ATTENTION_WEIGHTS}, MOCK_VALUE); + auto linear_add = makeOP({linear_matmul, _const_4}, {numpy_broadcast}); + return makeOP({linear_add, 2, {4096, 4096, -1}}); + } + + static std::shared_ptr gen_cache(const std::shared_ptr& input_ids, + const std::shared_ptr& beam_idx, + const std::string& name) { + auto shape_of = makeOP({input_ids}, {{"output_type", "i64"}}); + auto gather = makeOP({shape_of, {0}, 0}, {{"batch_dims", 0}}); + auto concat = makeOP({gather, {0ll}, {32ll}, {128ll}}, {{"axis", 0}}); + auto init_to_read = makeOP({0.000000f, concat}, {{"mode", "numpy"}}); + auto cache = makeOP( + {init_to_read}, + {{"variable_id", name}, {"variable_type", "f32"}, {"variable_shape", PartialShape{DYN, DYN, 32, 128}}}); + return makeOP({cache, beam_idx, 0}, {{"batch_dims", 0}}); + } + + static std::shared_ptr gen_current_len(const std::shared_ptr& input_ids) { + auto shape_of = makeOP({input_ids}, {{"output_type", "i64"}}); + return makeOP({shape_of, {1}, 0}, {{"batch_dims", 0}}); + } + + static std::shared_ptr gen_past_len(const std::shared_ptr& k_cache) { + auto shape_of = makeOP({k_cache}, {{"output_type", "i64"}}); + return makeOP({shape_of, {1}, 0}, {{"batch_dims", 0}}); + } + + static std::shared_ptr gen_total_len(const std::shared_ptr& cur_len, + const std::shared_ptr& past_len) { + return makeOP({cur_len, past_len}, {numpy_broadcast}); + } + + static std::shared_ptr gen_rope(QKV idx, + const std::shared_ptr& qkv_proj, + const std::shared_ptr& head_size, + const std::shared_ptr& sliced_sin, + const std::shared_ptr& sliced_cos) { + auto current_k = makeOP({qkv_proj->output(idx), {0, 0, 32, 128}}, {special_zero_true}); + auto sliced_k = makeOP({current_k, {0}, head_size, {1}, {3}}); + auto mul_1 = makeOP({sliced_k, sliced_cos}, {numpy_broadcast}); + + auto reshape = makeOP({sliced_k, {0, 0, 32, 2, 64}}, {special_zero_true}); + auto split_1 = makeOP({reshape, -2}, {{"num_splits", 2}}); + auto list_unpack_1 = makeOP({split_1->output(1), -2}); + + auto _const = single_val(/*rank*/ 4, /*val*/ 1); + auto mul_2 = makeOP({list_unpack_1, _const}, {numpy_broadcast}); + auto list_unpack_2 = makeOP({split_1->output(0), -2}); + auto concat = makeOP({mul_2, list_unpack_2}, {{"axis", -1}}); + + auto mul_3 = makeOP({concat, sliced_sin}, {numpy_broadcast}); + return makeOP({mul_1, mul_3}, {numpy_broadcast}); + } + + static std::shared_ptr gen_rope_emb_sin(const std::shared_ptr& total_seq_len, + const std::shared_ptr& neg_mul, + std::shared_ptr& head_size) { + auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + auto sliced_sin_by_total = makeOP({sin, {0}, total_seq_len, {1}, {1}}); + auto rotary_emb_sin_shape = makeOP({sliced_sin_by_total}, {{"output_type", "i64"}}); + head_size = makeOP({rotary_emb_sin_shape, {3}, 0}, {{"batch_dims", 0}}); + return makeOP({sliced_sin_by_total, neg_mul, {LLONG_MAX}, {1}, {1}}); + } + + static std::shared_ptr gen_rope_emb_cos(const std::shared_ptr& total_seq_len, + const std::shared_ptr& neg_mul) { + auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + auto sliced_cos_by_total = makeOP({cos, {0}, total_seq_len, {1}, {1}}); + return makeOP({sliced_cos_by_total, neg_mul, {LLONG_MAX}, {1}, {1}}); + } + + static std::shared_ptr neg_mul(const std::shared_ptr& current_seq_len) { + return makeOP({current_seq_len, {-1ll}}, {numpy_broadcast}); + } + + static std::shared_ptr gen_V(const std::shared_ptr& cache, const std::shared_ptr& qkv_proj) { + auto v_current = makeOP({qkv_proj->output(2), {0, 0, 32, 128}}, {special_zero_true}); + auto v_total = makeOP({cache, v_current}, {{"axis", 1}}); + return makeOP({v_total, {0, 2, 1, 3}}); + } + + static std::shared_ptr gen_K(const std::shared_ptr& cache, const std::shared_ptr& rope_K) { + auto full_k = makeOP({cache, rope_K}, {{"axis", 1}}); + return makeOP({full_k, {0, 2, 1, 3}}); + } + + static std::shared_ptr gen_Q(const std::shared_ptr& past_seq_len_2, + const std::shared_ptr& total_seq_len_2, + const std::shared_ptr& rope_Q) { + auto _const = makeConst(element::f32, {1, 32767, 1, 1}, MOCK_VALUE); + auto slice = makeOP({_const, past_seq_len_2, total_seq_len_2, {1}, {1}}); + auto mul = makeOP({rope_Q, slice}, {numpy_broadcast}); + return makeOP({mul, {0, 2, 1, 3}}); + } + + static std::shared_ptr gen_total_seq_len_2(const std::shared_ptr& past_k_len, + const std::shared_ptr& rope_k) { + auto shape_rope_k = makeOP({rope_k}, {{"output_type", "i64"}}); + auto cur_len = makeOP({shape_rope_k, {1}, 0}, {{"batch_dims", 0}}); + return makeOP({past_k_len, cur_len}, {numpy_broadcast}); + } + + static std::shared_ptr gen_past_seq_len_2(const std::shared_ptr& total_seq_len, + const std::shared_ptr& rope_q) { + auto shape_rope_q = makeOP({rope_q}, {{"output_type", "i64"}}); + auto cur_len = makeOP({shape_rope_q, {1}, 0}, {{"batch_dims", 0}}); + return makeOP({total_seq_len, cur_len}, {numpy_broadcast}); + } + + static std::shared_ptr gen_attention_mask(const std::shared_ptr& Q_in, + const std::shared_ptr& attention_mask_in, + const std::shared_ptr& total_seq_len) { + auto _const = makeConst(element::boolean, {1, 1, 8192, 8192}, MOCK_VALUE); + auto shape_of_q = makeOP({Q_in}, {{"output_type", "i64"}}); + auto gather = makeOP({shape_of_q, {2}, 0}, {{"batch_dims", 0}}); + auto sub_1 = makeOP({total_seq_len, gather}, {numpy_broadcast}); + auto concat = makeOP({sub_1, {0ll}}, {{"axis", 0}}); + auto broadcast = makeOP({total_seq_len, {2}}, {{"mode", "numpy"}}); + auto slice = makeOP({_const, concat, broadcast, {1, 1}, {2, 3}}); + auto bitwise_not = makeOP({slice}); + + auto _const_1 = single_val(/*rank*/ 4, /*val*/ 1); + auto view_reshape = makeOP({attention_mask_in, {0, 0}}, {special_zero_true}); + auto unsqueeze_0 = makeOP({view_reshape, 1}); + auto unsqueeze_1 = makeOP({unsqueeze_0, 2}); + auto convert_0 = makeOP({unsqueeze_1}, {dest_type_f32}); + + auto _const_2 = single_val(/*rank*/ 4, /*val*/ 1); + auto mul_1 = makeOP({convert_0, _const_2}, {numpy_broadcast}); + auto sub_2 = makeOP({_const_1, mul_1}, {numpy_broadcast}); + + auto _const_3 = single_val(/*rank*/ 4, /*val*/ 1); + auto mul_2 = makeOP({sub_2, _const_3}, {numpy_broadcast}); + auto list_construct = makeOP({{1ll}, {1ll}, gather, {1ll}}, {{"axis", 0}}); + auto expand_broadcast = makeOP({mul_2, list_construct}, {{"mode", "bidirectional"}}); + return makeOP({bitwise_not, -FLT_MAX, expand_broadcast}, {numpy_broadcast}); + } +}; + +class Qwen7bChatPA { +public: + static std::shared_ptr gen_embeddings(const std::shared_ptr& input_ids) { + auto weights = makeConst(element::u8, {WEIGHTS, 4096}, MOCK_VALUE); + auto weights_fp16 = makeOP({weights}, {dest_type_f16}); + + auto zero_point = makeConst(element::u8, {WEIGHTS, 1}, MOCK_VALUE); + auto zero_point_fp16 = makeOP({zero_point}, {dest_type_f16}); + auto sub = makeOP({weights_fp16, zero_point_fp16}, {numpy_broadcast}); + + auto scale = makeConst(element::f16, {WEIGHTS, 1}, MOCK_VALUE); + auto mul = makeOP({sub, scale}, {numpy_broadcast}); + auto mul_fp32 = makeOP({mul}, {dest_type_f32}); + + auto reshape_view = makeOP({input_ids, {-1, 0}}, {special_zero_true}); + auto reshape_view_i64 = makeOP({reshape_view}, {dest_type_i64}); + return makeOP({mul_fp32, reshape_view_i64, 0}, {{"batch_dims", 0}}); + } + + static std::shared_ptr gen_qkv_proj(const std::shared_ptr& embeddings) { + auto _const_0 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); + auto pow = makeOP({embeddings, _const_0}, {numpy_broadcast}); + auto mean = makeOP({pow, {-1}}, {{"keep_dims", true}}); + auto _const_1 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); + auto add_0 = makeOP({mean, _const_1}, {numpy_broadcast}); + + auto sqrt = makeOP({add_0}); + auto _const_2 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); + auto div = makeOP({_const_2, sqrt}, {numpy_broadcast, {"m_pythondiv", true}}); + auto mul_0 = makeOP({embeddings, div}, {numpy_broadcast}); + + auto _const_3 = makeConst(element::f32, {1, 1, 4096}, MOCK_VALUE); + auto mul_1 = makeOP({mul_0, _const_3}, {numpy_broadcast}); + + auto _const_4 = makeConst(element::u8, {ATTENTION_WEIGHTS, 4096}, MOCK_VALUE); + auto convert_0 = makeOP({_const_4}, {dest_type_f16}); + + auto _const_5 = makeConst(element::u8, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto convert_1 = makeOP({_const_5}, {dest_type_f16}); + auto sub = makeOP({convert_0, convert_1}, {numpy_broadcast}); + + auto _const_6 = makeConst(element::f16, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto mul_2 = makeOP({sub, _const_6}, {numpy_broadcast}); + auto convert_2 = makeOP({mul_2}, {dest_type_f32}); + auto matmul = makeOP({mul_1, convert_2}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Constant_270 = makeConst(element::f32, {1, 1, ATTENTION_WEIGHTS}, MOCK_VALUE); + auto add_1 = makeOP({matmul, Constant_270}, {numpy_broadcast}); + + return makeOP({add_1, 2, {4096, 4096, -1}}); + } + + static std::shared_ptr gen_rope(QKV idx, + const std::shared_ptr& qkv_proj, + const std::shared_ptr& head_size, + const std::shared_ptr& sin, + const std::shared_ptr& cos) { + auto Q_or_K = makeOP({qkv_proj->output(idx), {0, 0, 32, 128}}, {special_zero_true}); + auto sliced = makeOP({Q_or_K, {0}, head_size, {1}, {3}}); + auto mul_0 = makeOP({sliced, sin}, {numpy_broadcast}); + + auto reshape = makeOP({sliced, {0, 0, 32, 2, 64}}, {special_zero_true}); + auto split = makeOP({reshape, -2}, {{"num_splits", 2}}); + auto squeeze_0 = makeOP({split->output(1), -2}); + auto _const_0 = makeConst(element::f32, {1, 1, 1, 1}, {1.000000f}); + auto mul_1 = makeOP({squeeze_0, _const_0}, {numpy_broadcast}); + + auto squeeze_1 = makeOP({split->output(0), -2}); + auto concat = makeOP({mul_1, squeeze_1}, {{"axis", -1}}); + auto mul_2 = makeOP({concat, cos}, {numpy_broadcast}); + return makeOP({mul_0, mul_2}, {numpy_broadcast}); + } + + static std::shared_ptr gen_rope_emb_sin(const std::shared_ptr& max_context_len, + const std::shared_ptr& position_ids, + std::shared_ptr& head_size) { + auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + auto slice_sin = makeOP({sin, position_ids, 1}, {{"batch_dims", 0}}); + + auto slice = makeOP({sin, {0}, max_context_len, {1}, {1}}); + auto shape_of = makeOP({slice}, {{"output_type", "i64"}}); + head_size = makeOP({shape_of, {3}, 0}, {{"batch_dims", 0}}); + + return makeOP({slice_sin, {-1, 1, 1, 128}}, {{"special_zero", false}}); + } + + static std::shared_ptr gen_rope_emb_cos(const std::shared_ptr& max_context_len, + const std::shared_ptr& position_ids) { + auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + auto slice = makeOP({cos, position_ids, 1}, {{"batch_dims", 0}}); + return makeOP({slice, {-1, 1, 1, 128}}, {{"special_zero", false}}); + } + + static std::shared_ptr align_pa_layout(const std::shared_ptr& pa, + const std::shared_ptr& head_size) { + auto shape = makeOP({{0ll}, {1ll}, {-1ll}, head_size}, {{"axis", 0}}); + auto reshaped = makeOP({pa->output(0), shape}, {special_zero_true}); + return makeOP({reshaped, {0, 2, 1, 3}}); + } + + static std::shared_ptr gen_current_len(const std::shared_ptr& rope_K) { + auto shape_of = makeOP({rope_K}, {{"output_type", "i32"}}); + return makeOP({shape_of, {1}, 0ll}, {{"batch_dims", 0}}); + } + + static std::shared_ptr gen_past_len(const std::shared_ptr& input_ids, + const std::shared_ptr& max_context_len) { + auto shape_of = makeOP({input_ids}, {{"output_type", "i64"}}); + auto cur_len = makeOP({shape_of, 1ll, 0ll}, {{"batch_dims", 0}}); + auto cur_len_i32 = makeOP({cur_len}, {{"destination_type", "i32"}}); + + auto past_len = makeOP({max_context_len, cur_len_i32}, {numpy_broadcast}); + auto past_len_i32 = makeOP({past_len}, {{"destination_type", "i32"}}); + return makeOP({past_len_i32, {1}}, {special_zero_true}); + } + + static std::shared_ptr gen_total_len(const std::shared_ptr& cur_len, + const std::shared_ptr& past_len) { + return makeOP({past_len, cur_len}, {numpy_broadcast}); + } + + static std::shared_ptr gen_V(const std::shared_ptr& qkv_proj, std::shared_ptr& head_size) { + auto current_V = makeOP({qkv_proj->output(2), {0, 0, 32, 128}}, {special_zero_true}); + auto gather = makeOP({{0, 2, 1, 3}, {0, 2, 1, 3}, 0ll}, {{"batch_dims", 0}}); + auto transpose = makeOP({current_V, gather}); + + auto shape_of = makeOP({transpose}, {{"output_type", "i64"}}); + auto gather_2 = makeOP({shape_of, -1ll, 0ll}, {{"batch_dims", 0}}); + head_size = makeOP({gather_2, 0}); + + return makeOP({transpose, {0, -1}}, {special_zero_true}); + } + + static std::shared_ptr gen_K(const std::shared_ptr& rope_K) { + auto gather = makeOP({{0, 2, 1, 3}, {0, 2, 1, 3}, 0ll}, {{"batch_dims", 0}}); + auto transpose = makeOP({rope_K, gather}); + return makeOP({transpose, {0, -1}}, {special_zero_true}); + } + + static std::shared_ptr gen_Q(const std::shared_ptr& total_seq_len, + const std::shared_ptr& rope_Q) { + auto _const_1 = makeConst(element::f32, {1, 32767, 1, 1}, MOCK_VALUE); + auto shape_of = makeOP({rope_Q}, {{"output_type", "i32"}}); + auto current_seq_len = makeOP({shape_of, {1}, 0ll}, {{"batch_dims", 0}}); + auto past_seq_len = makeOP({total_seq_len, current_seq_len}, {numpy_broadcast}); + + auto slice = makeOP({_const_1, past_seq_len, total_seq_len, {1}, {1}}); + auto mul = makeOP({rope_Q, slice}, {numpy_broadcast}); + auto transpose_1 = makeOP({mul, {0, 2, 1, 3}}); + + auto transpose_2 = makeOP({transpose_1, {0, 2, 1, 3}}); + return makeOP({transpose_2, {0, -1}}, {special_zero_true}); + } +}; + +} // namespace + +TEST_F(TransformationTestsF, SDPAToPA_Qwen) { + { + // Inputs to SDPA transformer: + auto beam_idx = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i64}); + auto position_ids = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, el_type_i64}); + auto attention_mask = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, el_type_i64}); + auto input_ids = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, el_type_i64}); + ParameterVector params = nodes_to_params({position_ids, input_ids, attention_mask, beam_idx}); + + beam_idx->output(0).add_names({"beam_idx"}); + position_ids->output(0).add_names({"position_ids"}); + attention_mask->output(0).add_names({"attention_mask"}); + input_ids->output(0).add_names({"input_ids"}); + + // Embeddings processing: + auto embeddings = Qwen7bChatSDPA::gen_embeddings(input_ids); + auto qkv_proj = Qwen7bChatSDPA::gen_qkv_proj(embeddings); + + // KV cache: + auto k_cache = Qwen7bChatSDPA::gen_cache(input_ids, beam_idx, "K_cache"); + auto v_cache = Qwen7bChatSDPA::gen_cache(input_ids, beam_idx, "V_cache"); + + // Current/past/total Seq lengths calculation: + auto current_seq_len = Qwen7bChatSDPA::gen_current_len(input_ids); + auto past_seq_len = Qwen7bChatSDPA::gen_past_len(k_cache); + auto total_seq_len = Qwen7bChatSDPA::gen_total_len(current_seq_len, past_seq_len); + + // RoPE emb sin/cos init: + auto neg_cur_seq_len = Qwen7bChatSDPA::neg_mul(current_seq_len); + auto head_size = shared_ptr(); + auto rope_emb_sin = Qwen7bChatSDPA::gen_rope_emb_sin(total_seq_len, neg_cur_seq_len, head_size); + auto rope_emb_cos = Qwen7bChatSDPA::gen_rope_emb_cos(total_seq_len, neg_cur_seq_len); + + // RoPE for Q,K inputs: + auto rope_q = Qwen7bChatSDPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos); + auto rope_k = Qwen7bChatSDPA::gen_rope(QKV::K, qkv_proj, head_size, rope_emb_sin, rope_emb_cos); + + // Lengths: + auto total_seq_len_2 = Qwen7bChatSDPA::gen_total_seq_len_2(past_seq_len, rope_k); + auto past_seq_len_2 = Qwen7bChatSDPA::gen_past_seq_len_2(total_seq_len_2, rope_q); + + // Q, K, V: + auto Q = Qwen7bChatSDPA::gen_Q(past_seq_len_2, total_seq_len_2, rope_q); + auto K = Qwen7bChatSDPA::gen_K(k_cache, rope_k); + auto V = Qwen7bChatSDPA::gen_V(v_cache, qkv_proj); + + // Attention mask: + auto attention_mask_to_sdpa = Qwen7bChatSDPA::gen_attention_mask(Q, attention_mask, total_seq_len_2); + + // SDPA: + auto sdpa = makeOP({Q, K, V, attention_mask_to_sdpa}, {{"causal", false}}); + auto res = makeOP({sdpa}); + + model = std::make_shared(OutputVector{res}, params); + manager.register_pass(); + } + + { + // Inputs to PA transformer: + auto max_context_len = makeOP({}, {{"shape", PartialShape{}}, el_type_i32}); + auto block_indices_begins = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i32}); + auto block_indices = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i32}); + auto subsequence_begins = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i32}); + auto past_lens = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i32}); + auto value_cache_0 = makeOP({}, {{"shape", PartialShape{DYN, 32, 128}}, el_type_f32}); + auto key_cache_0 = makeOP({}, {{"shape", PartialShape{DYN, 32, 128}}, el_type_f32}); + auto input_ids = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i64}); + auto position_ids = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i64}); + auto params = nodes_to_params({max_context_len, + block_indices_begins, + block_indices, + subsequence_begins, + past_lens, + value_cache_0, + key_cache_0, + input_ids, + position_ids}); + + // Inputs pre-processing: + auto max_context_len_i64 = makeOP({max_context_len}, {dest_type_i64}); + auto max_context_len_aligned = makeOP({max_context_len_i64, {1}}, {special_zero_true}); + auto input_ids_aligned = makeOP({input_ids, 1}); + auto position_ids_aligned = makeOP({position_ids, 1}); + + // Embeddings processing: + auto embeddings = Qwen7bChatPA::gen_embeddings(input_ids_aligned); + auto qkv_proj = Qwen7bChatPA::gen_qkv_proj(embeddings); + + // RoPE emb sin/cos init: + auto head_size = shared_ptr(); + auto rope_emb_sin = Qwen7bChatPA::gen_rope_emb_sin(max_context_len_aligned, position_ids_aligned, head_size); + auto rope_emb_cos = Qwen7bChatPA::gen_rope_emb_cos(max_context_len_aligned, position_ids_aligned); + + // rope Q, K: + auto rope_Q = Qwen7bChatPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos); + auto rope_K = Qwen7bChatPA::gen_rope(QKV::K, qkv_proj, head_size, rope_emb_sin, rope_emb_cos); + + // Current/past/total Seq lengths calculation: + auto current_seq_len = Qwen7bChatPA::gen_current_len(rope_K); + auto past_seq_len = Qwen7bChatPA::gen_past_len(input_ids_aligned, max_context_len); + auto total_seq_len = Qwen7bChatPA::gen_total_len(current_seq_len, past_seq_len); + + // Q, K, V: + shared_ptr head_size_2; + auto Q = Qwen7bChatPA::gen_Q(total_seq_len, rope_Q); + auto K = Qwen7bChatPA::gen_K(rope_K); + auto V = Qwen7bChatPA::gen_V(qkv_proj, head_size_2); + + // Additional PA arguments: + auto sliding_window = std::make_shared(element::i32, Shape{}, 0); + auto alibi_slopes = std::make_shared(element::f32, Shape{0}); + auto scale = std::make_shared(element::f32, Shape{}, MOCK_VALUE); + + // PagedAttention: + auto pa = std::make_shared(OutputVector{Q, + K, + V, + key_cache_0, + value_cache_0, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len}); + pa->set_out_type(0, element::i64); + auto pa_aligned = Qwen7bChatPA::align_pa_layout(pa, head_size_2); + auto res = makeOP({pa_aligned}); + + model_ref = std::make_shared(OutputVector{res}, params); + } + // TODO: align precisions, check the copying of "fuse_names" attr in SDPAToPagedAttention + // checking the graph structure and names, other checks are temporarily disabled: + comparator.disable(FunctionsComparator::PRECISIONS); + disable_rt_info_check(); +} + +TEST_F(TransformationTestsF, SDPAToPA_TotalSequenceLengthPatternQwen) { + { + // Inputs to SDPA transformer: + auto beam_idx = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i64}); + auto input_ids = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, el_type_i64}); + ParameterVector params = nodes_to_params({input_ids, beam_idx}); + + // K cache + auto k_cache = Qwen7bChatSDPA::gen_cache(input_ids, beam_idx, "K_cache"); + + // Current/past/total Seq lengths calculation: + auto current_len = Qwen7bChatSDPA::gen_current_len(input_ids); + auto past_len = Qwen7bChatSDPA::gen_past_len(k_cache); + auto total_len = Qwen7bChatSDPA::gen_total_len(current_len, past_len); + auto result = std::make_shared(total_len); + + // Expected that these Nodes to be created inside SDPAToPagedAttention + auto new_input_ids = std::make_shared(element::i64, PartialShape{DYN}); + auto axis = v0::Constant::create(element::i32, Shape{}, {1}); + auto aligned_input_ids = std::make_shared(new_input_ids, axis); + + input_ids->output(0).replace(aligned_input_ids); + auto max_context_len = std::make_shared(element::i32, PartialShape{}); + max_context_len->output(0).set_names({"max_context_len"}); + auto position_ids = std::make_shared(element::i64, PartialShape{DYN}); + position_ids->output(0).set_names({"position_ids"}); + + params.push_back(max_context_len); + params.push_back(new_input_ids); + + // Model and Transformations: + model = std::make_shared(ResultVector{result}, params); + manager.register_pass(aligned_input_ids, max_context_len, position_ids); + manager.register_pass(max_context_len); + } + + { + // Inputs to PA transformer: + auto max_context_len = makeOP({}, {{"shape", PartialShape{}}, el_type_i32}); + auto params = nodes_to_params({max_context_len}); + + // Inputs pre-processing: + auto max_context_len_i64 = makeOP({max_context_len}, {dest_type_i64}); + auto max_context_len_aligned = makeOP({max_context_len_i64, {1}}, {special_zero_true}); + + auto result = std::make_shared(max_context_len_aligned); + model_ref = std::make_shared(ResultVector{result}, params); + } + // TODO: align precisions, check the copying of "fuse_names" attr in SDPAToPagedAttention + // checking the graph structure and names, other checks are temporarily disabled: + comparator.disable(FunctionsComparator::PRECISIONS); + disable_result_friendly_names_check(); + disable_rt_info_check(); +} diff --git a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp index 74aeacb0719cee..d52e78dbd6a489 100644 --- a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp +++ b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp @@ -19,7 +19,7 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass { public: OPENVINO_MODEL_PASS_RTTI("SDPAToPagedAttention"); - SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false); + explicit SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false); bool run_on_model(const std::shared_ptr& model) override; private: diff --git a/src/core/src/pass/manager.cpp b/src/core/src/pass/manager.cpp index a6f1fc287e221c..b084ec4dc38e09 100644 --- a/src/core/src/pass/manager.cpp +++ b/src/core/src/pass/manager.cpp @@ -5,6 +5,7 @@ #include "openvino/pass/manager.hpp" #include +#include #include #include #include diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index 872e4539eda8df..e6fc744bb5ef4f 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -81,15 +81,12 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptrset_partial_shape(PartialShape{-1}); + auto input_ids_target_inputs = input_ids_node->get_output_target_inputs(0); auto unsqueezed_input_ids = std::make_shared(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1})); - replace_node(input_ids_node, unsqueezed_input_ids); - - auto cur_seq_len = std::make_shared(std::make_shared(unsqueezed_input_ids), - v0::Constant::create(element::i64, Shape{}, {1}), - v0::Constant::create(element::i64, Shape{}, {0})); - auto prev_max_seq_len = - std::make_shared(max_context_len, std::make_shared(cur_seq_len, element::i32)); + for (const auto& target : input_ids_target_inputs) { + target.replace_source_output(unsqueezed_input_ids); + } ParameterVector kv_parameters; ParameterVector parameters_to_remove; @@ -106,15 +103,15 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptrset_partial_shape(PartialShape{-1}); position_ids->validate_and_infer_types(); } + auto position_ids_target_inputs = position_ids->get_output_target_inputs(0); auto unsqueezed_position_ids = std::make_shared(position_ids, v0::Constant::create(element::i32, Shape{}, {1})); - replace_node(position_ids, unsqueezed_position_ids); + for (const auto& target : position_ids_target_inputs) { + target.replace_source_output(unsqueezed_position_ids); + } int layer_index = 0; - auto batch_dim = - std::make_shared(position_ids); // it is not always required, so will be disposed if not needed - ov::pass::Manager manager("SDPA to PA"); manager.set_per_pass_validation(false); manager.register_pass(kv_parameters, @@ -127,9 +124,12 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr(prev_max_seq_len, batch_dim); + + manager.register_pass(unsqueezed_input_ids, max_context_len, position_ids); manager.register_pass(max_context_len); - manager.register_pass(unsqueezed_position_ids->output(0)); + manager.register_pass(max_context_len); + manager.register_pass(unsqueezed_position_ids); + manager.register_pass(unsqueezed_position_ids); manager.run_passes(model); {