From 033b0b536ec65ff80786b72c77457b68781109b1 Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Mon, 23 Dec 2024 23:22:59 +0400 Subject: [PATCH] resolve review comments; add new unit tests; update transformations --- .../position_ids_replacer.hpp | 8 +- .../prev_sequence_length_pattern.hpp | 7 +- .../total_sequence_length_pattern.hpp | 6 +- .../position_ids_replacer.cpp | 29 +- .../prev_sequence_length_pattern.cpp | 13 +- .../total_sequence_length_pattern.cpp | 36 +- .../sdpa_to_paged_attention_test.cpp | 397 +++++++++++------- .../openvino/pass/sdpa_to_paged_attention.hpp | 2 +- src/core/src/pass/sdpa_to_paged_attention.cpp | 15 +- 9 files changed, 310 insertions(+), 203 deletions(-) diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp index e341116cdf847a..d2921e837ddbe3 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp @@ -27,15 +27,17 @@ class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass { }; /** - * @brief Qwen model has a specific feature in the model structure not to use position_ids input, - * this input is detached. The model expects data processing in order. + * @brief The Qwen model expects its inputs to be processed in order; the "position_ids" input is detached and + * not used explicitly in the model. Instead, the model derives the position ids implicitly + * from the size of the past KV cache. * * To use this model in Continuous batching mode, we need to apply position_ids and * use the corresponding rotary_emb_cos/rotary_emb_sin. * For this, we replace * rotary_emb_cos/rotary_emb_sin -> Slice -> Slice * With - * rotary_emb_cos/rotary_emb_sin -> Slice -> Gather(by position_ids) + * rotary_emb_cos/rotary_emb_sin -> Gather(by position_ids) + * This enables applying RoPE to each token independently of its position in the input tensor.
*/ class ov::pass::PositionIDsReplacerQwen : public ov::pass::MatcherPass { public: diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp index fd4e22c69262ae..055ea727c2736e 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp @@ -4,7 +4,6 @@ #pragma once -#include "openvino/cc/pass/itt.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/subtract.hpp" #include "openvino/pass/matcher_pass.hpp" @@ -22,6 +21,8 @@ class TRANSFORMATIONS_API PrevSequenceLengthPattern; class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("PrevSequenceLengthPattern", "0"); - explicit PrevSequenceLengthPattern(std::shared_ptr prev_max_seq_len, std::shared_ptr batch_dim); + OPENVINO_MATCHER_PASS_RTTI("PrevSequenceLengthPattern", "0"); + explicit PrevSequenceLengthPattern(const std::shared_ptr& unsqueezed_input_ids, + const std::shared_ptr& max_context_len, + const std::shared_ptr& position_ids); }; \ No newline at end of file diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp index 0de0aef1df1105..b5cd5ec88e369f 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp @@ -22,7 +22,7 @@ class TRANSFORMATIONS_API TotalSequenceLengthPatternQwen; class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("TotalSequenceLengthPattern", "0"); + OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPattern", "0"); explicit TotalSequenceLengthPattern(const std::shared_ptr& max_context_len); }; @@ -35,12 +35,12 @@ class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass { * CurrentSeqLen: Parameter(name: input_ids) -> ShapeOf -> Gather * * Before applying this transformation, we already detected the PrevSeqLen place in the PrevSequenceLengthPattern - * and replaced it with the next subgraph + * and replaced it with the following subgraph: * PrevSeqLen: Subtract (in: Parameter(name: max_context_len), in: CurrentSeqLen) * **/ class ov::pass::TotalSequenceLengthPatternQwen : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("TotalSequenceLengthPattern", "0"); + OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPatternQwen", "0"); explicit TotalSequenceLengthPatternQwen(const std::shared_ptr& max_context_len); }; diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp index 04702d0d102cee..1cc9be37606950 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp @@ -7,9 +7,12 @@ #include "openvino/cc/pass/itt.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" #include "openvino/op/slice.hpp" #include "openvino/op/squeeze.hpp" +#include "openvino/op/unsqueeze.hpp" #include "openvino/pass/pattern/op/optional.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" @@ -49,14 +52,27 @@ ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output& position_ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output& position_ids) { MATCHER_SCOPE(PositionIDsReplacerQwen); + auto _const = []() { + return wrap_type(); + }; + + // total seq len: auto p_max_context_len = wrap_type(); auto p_opt_convert = optional(p_max_context_len); auto p_opt_reshape = optional({p_opt_convert, any_input()}); + // current seg len + auto p_input_ids = wrap_type(); + auto p_unsqueeze = wrap_type({p_input_ids, _const()}); + auto p_shape_of = wrap_type({p_unsqueeze}); + auto p_current_len = wrap_type({p_shape_of, _const(), _const()}); + auto p_rotary_emb_sincos = wrap_type(); + auto p_neg_const = wrap_type(); + auto p_neg_mul = wrap_type({p_current_len, p_neg_const}); // the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128] - auto p_slice_1 = wrap_type({p_rotary_emb_sincos, any_input(), p_opt_reshape, any_input(), any_input()}); - auto p_slice_2 = wrap_type({p_slice_1, any_input(), any_input(), any_input(), any_input()}); + auto p_slice_1 = wrap_type({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()}); + auto p_slice_2 = wrap_type({p_slice_1, p_neg_mul, _const(), _const(), _const()}); ov::matcher_pass_callback callback = [=](Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); @@ -71,15 +87,20 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output& p auto axis = v0::Constant::create(element::i64, Shape{}, {1}); // in case of PagedAttention (Continuous batching) the rotary_emb_cos/rotary_emb_sin // are used not in the sequential order, so we need to use position_ids to get the expected values. 
- auto gather = std::make_shared(slice_1, position_ids, axis); + auto gather = std::make_shared(slice_1->input_value(0), position_ids, axis); gather->set_friendly_name(slice_2->get_friendly_name()); gather->validate_and_infer_types(); + auto pshape = rotary_emb_sincos->get_output_partial_shape(0); + if (pshape.rank().is_dynamic() || pshape.rank().get_length() != 4) { + return false; + } + // PagedAttention expects the next layout for Q,K,V: // [batch_size_in_tokens, num_kv_heads * head_size] // so here we need to reshape the output tensor to move the seq dim (num tokens) to the batch // num_kv_heads * head_size are already handled in the StateManagementPattern transformation - auto head_size = static_cast(rotary_emb_sincos->get_shape()[3]); + auto head_size = static_cast(pshape[3].get_length()); auto new_shape = v0::Constant::create(element::i64, Shape{4}, std::vector{-1, 1, 1, head_size}); auto reshape = std::make_shared(gather, new_shape, false); replace_node(slice_2, reshape); diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp index 36d9d88975b2e0..55d7af822c3857 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.cpp @@ -14,8 +14,9 @@ using namespace ov::op; -ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr prev_max_seq_len, - std::shared_ptr batch_dim) { +ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(const std::shared_ptr& unsqueezed_input_ids, + const std::shared_ptr& max_context_len, + const std::shared_ptr& position_ids) { MATCHER_SCOPE(PrevSequenceLengthPattern); // The transformation addresses two cases that look similar: (1) previous sequence length, (2) batch size in // kv-cache state In first case it should replace it by prev_max_seq_len. For the second case, connect to batch_dim. 
@@ -40,8 +41,16 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptrget_output_element_type(0); std::shared_ptr replacement; if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) { + auto cur_seq_len = std::make_shared(std::make_shared(unsqueezed_input_ids), + v0::Constant::create(element::i64, Shape{}, {1}), + v0::Constant::create(element::i64, Shape{}, {0})); + auto cur_seq_len_i32 = std::make_shared(cur_seq_len, element::i32); + auto prev_max_seq_len = std::make_shared(max_context_len, cur_seq_len_i32); replacement = prev_max_seq_len; } else { + // it is not always required, so will be disposed if not needed + auto batch_dim = std::make_shared(position_ids); + // assumption that any other axis should point to batch dimension, precise reasoning is too complex // TODO: provide more reliable check replacement = batch_dim; diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp index c14bcca8bce2ae..cbf9426a0c82c5 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp @@ -20,6 +20,22 @@ using namespace ov::op; using namespace ov::pass::pattern; +namespace { + +void align_replacement(std::shared_ptr& replacement, + const ov::PartialShape& required_shape, + ov::element::Type target_type) { + if (replacement->get_output_element_type(0) != target_type) { + replacement = std::make_shared(replacement, target_type); + } + + if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) { + replacement = ov::op::util::reshapeTo(replacement, ov::Shape(required_shape.rank().get_length(), 1)); + } +} + +} // namespace + ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern( const std::shared_ptr& max_context_len) { MATCHER_SCOPE(TotalSequenceLengthPattern); @@ -77,16 +93,8 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern( if (concat_axis_to_compare == gather_idx_to_compare) { auto target_type = gather->get_output_element_type(0); - - if (replacement->get_output_element_type(0) != target_type) { - replacement = std::make_shared(replacement, target_type); - } - auto required_shape = gather->get_output_partial_shape(0); - - if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) { - replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1)); - } + align_replacement(replacement, required_shape, target_type); } else { // TODO: change in the future when we start supporting dynamic shapes here replacement = ov::util::get_constant_from_source(gather->output(0)); @@ -128,14 +136,8 @@ ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen( std::shared_ptr replacement = max_context_len; auto target_type = total_seq->get_output_element_type(0); - if (replacement->get_output_element_type(0) != target_type) { - replacement = std::make_shared(replacement, target_type); - } - auto required_shape = total_seq->get_output_partial_shape(0); - if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) { - replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1)); - } + 
align_replacement(replacement, required_shape, target_type); replace_node(total_seq, replacement); return true; @@ -143,4 +145,4 @@ ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen( auto m = std::make_shared(p_total_seq, matcher_name); register_matcher(m, callback); -} \ No newline at end of file +} diff --git a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp index e4c9e488708e18..840309993c939a 100644 --- a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp +++ b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp @@ -29,6 +29,7 @@ #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" +#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp" #include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp" #include "transformations/utils/gen_pattern.hpp" #include "transformations/utils/print_model.hpp" @@ -71,19 +72,23 @@ ov::ParameterVector nodes_to_params(const ov::NodeVector& node_vec) { enum QKV : int { Q = 0, K = 1, V = 2 }; vector MOCK_VALUE = {1}; +// original weights = 151936, attention_weights = 12288 +#define WEIGHTS 1024 +#define ATTENTION_WEIGHTS 512 + class Qwen7bChatSDPA { public: static std::shared_ptr gen_embeddings(const std::shared_ptr& input_ids) { auto view_reshape = makeOP({input_ids, {-1, 0}}, {special_zero_true}); auto input_ids_i64 = makeOP({view_reshape}, {dest_type_i64}); - auto weights = makeConst(element::u8, {151936, 4096}, MOCK_VALUE); + auto weights = makeConst(element::u8, {WEIGHTS, 4096}, MOCK_VALUE); auto weights_fp16 = makeOP({weights}, {dest_type_f16}); - auto zero_point = makeConst(element::u8, {151936, 1}, MOCK_VALUE); + auto zero_point = makeConst(element::u8, {WEIGHTS, 1}, MOCK_VALUE); auto zero_point_fp16 = makeOP({zero_point}, {dest_type_f16}); auto zero_point_subtract = makeOP({weights_fp16, zero_point_fp16}, {numpy_broadcast}); - auto scale = makeConst(element::f16, {151936, 1}, MOCK_VALUE); + auto scale = makeConst(element::f16, {WEIGHTS, 1}, MOCK_VALUE); auto mul_scale = makeOP({zero_point_subtract, scale}, {numpy_broadcast}); auto fq_weights = makeOP({mul_scale}, {dest_type_f32}); @@ -91,35 +96,39 @@ class Qwen7bChatSDPA { } static std::shared_ptr gen_attention_weights() { - auto weights = makeConst(element::u8, {12288, 4096}, MOCK_VALUE); - auto Convert_375820 = makeOP({weights}, {dest_type_f16}); - auto attn_c_attn_weight_zero_point = makeConst(element::u8, {12288, 1}, MOCK_VALUE); - auto Convert_375823 = makeOP({attn_c_attn_weight_zero_point}, {dest_type_f16}); - auto attn_c_attn_weight_zero_point_subtract = - makeOP({Convert_375820, Convert_375823}, {numpy_broadcast}); - auto attn_c_attn_weight_scale = makeConst(element::f16, {12288, 1}, MOCK_VALUE); - auto attn_c_attn_weight_fq_weights_1 = - makeOP({attn_c_attn_weight_zero_point_subtract, attn_c_attn_weight_scale}, {numpy_broadcast}); - return makeOP({attn_c_attn_weight_fq_weights_1}, {dest_type_f32}); + auto weights = makeConst(element::u8, {ATTENTION_WEIGHTS, 4096}, MOCK_VALUE); + auto weights_f16 = makeOP({weights}, {dest_type_f16}); + + auto zero_points = makeConst(element::u8, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto zero_points_f16 = makeOP({zero_points}, {dest_type_f16}); + auto subtract = makeOP({weights_f16, zero_points_f16}, {numpy_broadcast}); + + auto scale = 
makeConst(element::f16, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto mul = makeOP({subtract, scale}, {numpy_broadcast}); + return makeOP({mul}, {dest_type_f32}); } static std::shared_ptr gen_qkv_proj(const std::shared_ptr& embeddings) { - auto Constant_244726 = single_val(/*rank*/ 3, /*val*/ 1); - auto Constant_244724 = single_val(/*rank*/ 3, /*val*/ 2); - auto pow_Power = makeOP({embeddings, Constant_244724}, {numpy_broadcast}); - auto mean_ReduceMean = makeOP({pow_Power, {-1}}, {{"keep_dims", true}}); - auto Constant_244725 = single_val(/*rank*/ 3, /*val*/ 1); - auto add_Add = makeOP({mean_ReduceMean, Constant_244725}, {numpy_broadcast}); - auto rsqrt_Sqrt = makeOP({add_Add}); - auto rsqrt_Divide = makeOP({Constant_244726, rsqrt_Sqrt}, {numpy_broadcast, {"m_pythondiv", true}}); - auto mul_Multiply_0 = makeOP({embeddings, rsqrt_Divide}, {numpy_broadcast}); - auto Constant_244727 = makeConst(element::f32, {1, 1, 4096}, MOCK_VALUE); - auto mul_Multiply_1 = makeOP({mul_Multiply_0, Constant_244727}, {numpy_broadcast}); + auto _const_0 = single_val(/*rank*/ 3, /*val*/ 2); + auto pow = makeOP({embeddings, _const_0}, {numpy_broadcast}); + auto mean = makeOP({pow, {-1}}, {{"keep_dims", true}}); + + auto _const_1 = single_val(/*rank*/ 3, /*val*/ 1); + auto add = makeOP({mean, _const_1}, {numpy_broadcast}); + auto sqrt = makeOP({add}); + + auto _const_2 = single_val(/*rank*/ 3, /*val*/ 1); + auto div = makeOP({_const_2, sqrt}, {numpy_broadcast, {"m_pythondiv", true}}); + auto mul_0 = makeOP({embeddings, div}, {numpy_broadcast}); + + auto _const_3 = makeConst(element::f32, {1, 1, 4096}, MOCK_VALUE); + auto mul_1 = makeOP({mul_0, _const_3}, {numpy_broadcast}); auto attention_weights = gen_attention_weights(); - auto linear_MatMul = - makeOP({mul_Multiply_1, attention_weights}, {{"transpose_a", false}, {"transpose_b", true}}); - auto Constant_244728 = makeConst(element::f32, {1, 1, 12288}, MOCK_VALUE); - auto linear_add = makeOP({linear_MatMul, Constant_244728}, {numpy_broadcast}); + auto linear_matmul = + makeOP({mul_1, attention_weights}, {{"transpose_a", false}, {"transpose_b", true}}); + + auto _const_4 = makeConst(element::f32, {1, 1, ATTENTION_WEIGHTS}, MOCK_VALUE); + auto linear_add = makeOP({linear_matmul, _const_4}, {numpy_broadcast}); return makeOP({linear_add, 2, {4096, 4096, -1}}); } @@ -154,20 +163,23 @@ class Qwen7bChatSDPA { static std::shared_ptr gen_rope(QKV idx, const std::shared_ptr& qkv_proj, const std::shared_ptr& head_size, - const std::shared_ptr& sliced_sin_by_current, - const std::shared_ptr& sliced_cos_by_current) { + const std::shared_ptr& sliced_sin, + const std::shared_ptr& sliced_cos) { auto current_k = makeOP({qkv_proj->output(idx), {0, 0, 32, 128}}, {special_zero_true}); auto sliced_k = makeOP({current_k, {0}, head_size, {1}, {3}}); - auto mul_Multiply_2 = makeOP({sliced_k, sliced_cos_by_current}, {numpy_broadcast}); - auto reshape_Reshape_1 = makeOP({sliced_k, {0, 0, 32, 2, 64}}, {special_zero_true}); - auto ListUnpack_Split_1 = makeOP({reshape_Reshape_1, -2}, {{"num_splits", 2}}); - auto ListUnpack_Squeeze_2 = makeOP({ListUnpack_Split_1->output(1), -2}); - auto Constant_244730 = single_val(/*rank*/ 4, /*val*/ 1); - auto neg_Multiply_3 = makeOP({ListUnpack_Squeeze_2, Constant_244730}, {numpy_broadcast}); - auto ListUnpack_Squeeze_1 = makeOP({ListUnpack_Split_1->output(0), -2}); - auto cat_Concat_2 = makeOP({neg_Multiply_3, ListUnpack_Squeeze_1}, {{"axis", -1}}); - auto mul_Multiply_3 = makeOP({cat_Concat_2, sliced_sin_by_current}, {numpy_broadcast}); - return 
makeOP({mul_Multiply_2, mul_Multiply_3}, {numpy_broadcast}); + auto mul_1 = makeOP({sliced_k, sliced_cos}, {numpy_broadcast}); + + auto reshape = makeOP({sliced_k, {0, 0, 32, 2, 64}}, {special_zero_true}); + auto split_1 = makeOP({reshape, -2}, {{"num_splits", 2}}); + auto list_unpack_1 = makeOP({split_1->output(1), -2}); + + auto _const = single_val(/*rank*/ 4, /*val*/ 1); + auto mul_2 = makeOP({list_unpack_1, _const}, {numpy_broadcast}); + auto list_unpack_2 = makeOP({split_1->output(0), -2}); + auto concat = makeOP({mul_2, list_unpack_2}, {{"axis", -1}}); + + auto mul_3 = makeOP({concat, sliced_sin}, {numpy_broadcast}); + return makeOP({mul_1, mul_3}, {numpy_broadcast}); } static std::shared_ptr gen_rope_emb_sin(const std::shared_ptr& total_seq_len, @@ -175,8 +187,8 @@ class Qwen7bChatSDPA { std::shared_ptr& head_size) { auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); auto sliced_sin_by_total = makeOP({sin, {0}, total_seq_len, {1}, {1}}); - auto rotery_emb_sin_shape = makeOP({sliced_sin_by_total}, {{"output_type", "i64"}}); - head_size = makeOP({rotery_emb_sin_shape, {3}, 0}, {{"batch_dims", 0}}); + auto rotary_emb_sin_shape = makeOP({sliced_sin_by_total}, {{"output_type", "i64"}}); + head_size = makeOP({rotary_emb_sin_shape, {3}, 0}, {{"batch_dims", 0}}); return makeOP({sliced_sin_by_total, neg_mul, {LLONG_MAX}, {1}, {1}}); } @@ -205,10 +217,10 @@ class Qwen7bChatSDPA { static std::shared_ptr gen_Q(const std::shared_ptr& past_seq_len_2, const std::shared_ptr& total_seq_len_2, const std::shared_ptr& rope_Q) { - auto slice_Slice_10 = makeConst(element::f32, {1, 32767, 1, 1}, MOCK_VALUE); - auto slice_Slice_13 = makeOP({slice_Slice_10, past_seq_len_2, total_seq_len_2, {1}, {1}}); - auto mul_Multiply_4 = makeOP({rope_Q, slice_Slice_13}, {numpy_broadcast}); - return makeOP({mul_Multiply_4, {0, 2, 1, 3}}); + auto _const = makeConst(element::f32, {1, 32767, 1, 1}, MOCK_VALUE); + auto slice = makeOP({_const, past_seq_len_2, total_seq_len_2, {1}, {1}}); + auto mul = makeOP({rope_Q, slice}, {numpy_broadcast}); + return makeOP({mul, {0, 2, 1, 3}}); } static std::shared_ptr gen_total_seq_len_2(const std::shared_ptr& past_k_len, @@ -228,132 +240,146 @@ class Qwen7bChatSDPA { static std::shared_ptr gen_attention_mask(const std::shared_ptr& Q_in, const std::shared_ptr& attention_mask_in, const std::shared_ptr& total_seq_len) { - auto slice_Slice_17 = makeConst(element::boolean, {1, 1, 8192, 8192}, MOCK_VALUE); + auto _const = makeConst(element::boolean, {1, 1, 8192, 8192}, MOCK_VALUE); auto shape_of_q = makeOP({Q_in}, {{"output_type", "i64"}}); - auto Gather_255227 = makeOP({shape_of_q, {2}, 0}, {{"batch_dims", 0}}); - auto sub_Subtract_1 = makeOP({total_seq_len, Gather_255227}, {numpy_broadcast}); - auto Concat_238310 = makeOP({sub_Subtract_1, {0ll}}, {{"axis", 0}}); - auto Concat_238311 = makeOP({total_seq_len, {2}}, {{"mode", "numpy"}}); - auto slice_Slice_19 = makeOP({slice_Slice_17, Concat_238310, Concat_238311, {1, 1}, {2, 3}}); - auto bitwise_not_BitwiseNot = makeOP({slice_Slice_19}); - auto Constant_244732 = single_val(/*rank*/ 4, /*val*/ 1); - auto view_Reshape_3 = makeOP({attention_mask_in, {0, 0}}, {special_zero_true}); - auto unsqueeze_Unsqueeze = makeOP({view_Reshape_3, 1}); - auto unsqueeze_Unsqueeze_1 = makeOP({unsqueeze_Unsqueeze, 2}); - auto to_Convert = makeOP({unsqueeze_Unsqueeze_1}, {dest_type_f32}); - auto Constant_244731 = single_val(/*rank*/ 4, /*val*/ 1); - auto rsub_Multiply = makeOP({to_Convert, Constant_244731}, {numpy_broadcast}); - auto rsub_Subtract = 
makeOP({Constant_244732, rsub_Multiply}, {numpy_broadcast}); - auto Constant_244733 = single_val(/*rank*/ 4, /*val*/ 1); - auto mul_Multiply_5 = makeOP({rsub_Subtract, Constant_244733}, {numpy_broadcast}); - auto ListConstruct_5 = makeOP({{1ll}, {1ll}, Gather_255227, {1ll}}, {{"axis", 0}}); - auto expand_Broadcast = makeOP({mul_Multiply_5, ListConstruct_5}, {{"mode", "bidirectional"}}); - return makeOP({bitwise_not_BitwiseNot, -FLT_MAX, expand_Broadcast}, {numpy_broadcast}); + auto gather = makeOP({shape_of_q, {2}, 0}, {{"batch_dims", 0}}); + auto sub_1 = makeOP({total_seq_len, gather}, {numpy_broadcast}); + auto concat = makeOP({sub_1, {0ll}}, {{"axis", 0}}); + auto broadcast = makeOP({total_seq_len, {2}}, {{"mode", "numpy"}}); + auto slice = makeOP({_const, concat, broadcast, {1, 1}, {2, 3}}); + auto bitwise_not = makeOP({slice}); + + auto _const_1 = single_val(/*rank*/ 4, /*val*/ 1); + auto view_reshape = makeOP({attention_mask_in, {0, 0}}, {special_zero_true}); + auto unsqueeze_0 = makeOP({view_reshape, 1}); + auto unsqueeze_1 = makeOP({unsqueeze_0, 2}); + auto convert_0 = makeOP({unsqueeze_1}, {dest_type_f32}); + + auto _const_2 = single_val(/*rank*/ 4, /*val*/ 1); + auto mul_1 = makeOP({convert_0, _const_2}, {numpy_broadcast}); + auto sub_2 = makeOP({_const_1, mul_1}, {numpy_broadcast}); + + auto _const_3 = single_val(/*rank*/ 4, /*val*/ 1); + auto mul_2 = makeOP({sub_2, _const_3}, {numpy_broadcast}); + auto list_construct = makeOP({{1ll}, {1ll}, gather, {1ll}}, {{"axis", 0}}); + auto expand_broadcast = makeOP({mul_2, list_construct}, {{"mode", "bidirectional"}}); + return makeOP({bitwise_not, -FLT_MAX, expand_broadcast}, {numpy_broadcast}); } }; class Qwen7bChatPA { public: static std::shared_ptr gen_embeddings(const std::shared_ptr& input_ids) { - auto Constant_241 = makeConst(element::u8, {151936, 4096}, MOCK_VALUE); - auto Convert_242 = makeOP({Constant_241}, {dest_type_f16}); - auto Constant_243 = makeConst(element::u8, {151936, 1}, MOCK_VALUE); - auto Convert_244 = makeOP({Constant_243}, {dest_type_f16}); - auto Subtract_245 = makeOP({Convert_242, Convert_244}, {numpy_broadcast}); - auto Constant_246 = makeConst(element::f16, {151936, 1}, MOCK_VALUE); - auto Multiply_247 = makeOP({Subtract_245, Constant_246}, {numpy_broadcast}); - auto Convert_248 = makeOP({Multiply_247}, {dest_type_f32}); - auto Reshape_239 = makeOP({input_ids, {-1, 0}}, {special_zero_true}); - auto Convert_240 = makeOP({Reshape_239}, {dest_type_i64}); - return makeOP({Convert_248, Convert_240, 0}, {{"batch_dims", 0}}); + auto weights = makeConst(element::u8, {WEIGHTS, 4096}, MOCK_VALUE); + auto weights_fp16 = makeOP({weights}, {dest_type_f16}); + + auto zero_point = makeConst(element::u8, {WEIGHTS, 1}, MOCK_VALUE); + auto zero_point_fp16 = makeOP({zero_point}, {dest_type_f16}); + auto sub = makeOP({weights_fp16, zero_point_fp16}, {numpy_broadcast}); + + auto scale = makeConst(element::f16, {WEIGHTS, 1}, MOCK_VALUE); + auto mul = makeOP({sub, scale}, {numpy_broadcast}); + auto mul_fp32 = makeOP({mul}, {dest_type_f32}); + + auto reshape_view = makeOP({input_ids, {-1, 0}}, {special_zero_true}); + auto reshape_view_i64 = makeOP({reshape_view}, {dest_type_i64}); + return makeOP({mul_fp32, reshape_view_i64, 0}, {{"batch_dims", 0}}); } static std::shared_ptr gen_qkv_proj(const std::shared_ptr& embeddings) { - auto Constant_236 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); - auto Constant_237 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); - auto Power_251 = makeOP({embeddings, Constant_237}, 
{numpy_broadcast}); - auto ReduceMean_253 = makeOP({Power_251, {-1}}, {{"keep_dims", true}}); - auto Constant_254 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); - auto Add_255 = makeOP({ReduceMean_253, Constant_254}, {numpy_broadcast}); - auto Sqrt_256 = makeOP({Add_255}); - auto Divide_257 = makeOP({Constant_236, Sqrt_256}, {numpy_broadcast, {"m_pythondiv", true}}); - auto Multiply_258 = makeOP({embeddings, Divide_257}, {numpy_broadcast}); - auto Constant_259 = makeConst(element::f32, {1, 1, 4096}, MOCK_VALUE); - auto Multiply_260 = makeOP({Multiply_258, Constant_259}, {numpy_broadcast}); - auto Constant_261 = makeConst(element::u8, {12288, 4096}, MOCK_VALUE); - auto Convert_262 = makeOP({Constant_261}, {dest_type_f16}); - auto Constant_263 = makeConst(element::u8, {12288, 1}, MOCK_VALUE); - auto Convert_264 = makeOP({Constant_263}, {dest_type_f16}); - auto Subtract_265 = makeOP({Convert_262, Convert_264}, {numpy_broadcast}); - auto Constant_266 = makeConst(element::f16, {12288, 1}, MOCK_VALUE); - auto Multiply_267 = makeOP({Subtract_265, Constant_266}, {numpy_broadcast}); - auto Convert_268 = makeOP({Multiply_267}, {dest_type_f32}); - auto MatMul_269 = - makeOP({Multiply_260, Convert_268}, {{"transpose_a", false}, {"transpose_b", true}}); - auto Constant_270 = makeConst(element::f32, {1, 1, 12288}, MOCK_VALUE); - auto Add_271 = makeOP({MatMul_269, Constant_270}, {numpy_broadcast}); - - return makeOP({Add_271, 2, {4096, 4096, -1}}); + auto _const_0 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); + auto pow = makeOP({embeddings, _const_0}, {numpy_broadcast}); + auto mean = makeOP({pow, {-1}}, {{"keep_dims", true}}); + auto _const_1 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); + auto add_0 = makeOP({mean, _const_1}, {numpy_broadcast}); + + auto sqrt = makeOP({add_0}); + auto _const_2 = makeConst(element::f32, {1, 1, 1}, MOCK_VALUE); + auto div = makeOP({_const_2, sqrt}, {numpy_broadcast, {"m_pythondiv", true}}); + auto mul_0 = makeOP({embeddings, div}, {numpy_broadcast}); + + auto _const_3 = makeConst(element::f32, {1, 1, 4096}, MOCK_VALUE); + auto mul_1 = makeOP({mul_0, _const_3}, {numpy_broadcast}); + + auto _const_4 = makeConst(element::u8, {ATTENTION_WEIGHTS, 4096}, MOCK_VALUE); + auto convert_0 = makeOP({_const_4}, {dest_type_f16}); + + auto _const_5 = makeConst(element::u8, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto convert_1 = makeOP({_const_5}, {dest_type_f16}); + auto sub = makeOP({convert_0, convert_1}, {numpy_broadcast}); + + auto _const_6 = makeConst(element::f16, {ATTENTION_WEIGHTS, 1}, MOCK_VALUE); + auto mul_2 = makeOP({sub, _const_6}, {numpy_broadcast}); + auto convert_2 = makeOP({mul_2}, {dest_type_f32}); + auto matmul = makeOP({mul_1, convert_2}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Constant_270 = makeConst(element::f32, {1, 1, ATTENTION_WEIGHTS}, MOCK_VALUE); + auto add_1 = makeOP({matmul, Constant_270}, {numpy_broadcast}); + + return makeOP({add_1, 2, {4096, 4096, -1}}); } static std::shared_ptr gen_rope(QKV idx, const std::shared_ptr& qkv_proj, const std::shared_ptr& head_size, - const std::shared_ptr& sliced_sin_by_current, - const std::shared_ptr& sliced_cos_by_current) { - auto Reshape_276 = makeOP({qkv_proj->output(idx), {0, 0, 32, 128}}, {special_zero_true}); - auto Slice_437 = makeOP({Reshape_276, {0}, head_size, {1}, {3}}); - auto Multiply_440 = makeOP({Slice_437, sliced_cos_by_current}, {numpy_broadcast}); - auto Reshape_442 = makeOP({Slice_437, {0, 0, 32, 2, 64}}, {special_zero_true}); - auto Split_444 = makeOP({Reshape_442, -2}, 
{{"num_splits", 2}}); - auto Squeeze_446 = makeOP({Split_444->output(1), -2}); - auto Constant_447 = makeConst(element::f32, {1, 1, 1, 1}, {1.000000f}); - auto Multiply_448 = makeOP({Squeeze_446, Constant_447}, {numpy_broadcast}); - auto Squeeze_450 = makeOP({Split_444->output(0), -2}); - auto Concat_451 = makeOP({Multiply_448, Squeeze_450}, {{"axis", -1}}); - auto Multiply_461 = makeOP({Concat_451, sliced_sin_by_current}, {numpy_broadcast}); - return makeOP({Multiply_440, Multiply_461}, {numpy_broadcast}); + const std::shared_ptr& sin, + const std::shared_ptr& cos) { + auto Q_or_K = makeOP({qkv_proj->output(idx), {0, 0, 32, 128}}, {special_zero_true}); + auto sliced = makeOP({Q_or_K, {0}, head_size, {1}, {3}}); + auto mul_0 = makeOP({sliced, sin}, {numpy_broadcast}); + + auto reshape = makeOP({sliced, {0, 0, 32, 2, 64}}, {special_zero_true}); + auto split = makeOP({reshape, -2}, {{"num_splits", 2}}); + auto squeeze_0 = makeOP({split->output(1), -2}); + auto _const_0 = makeConst(element::f32, {1, 1, 1, 1}, {1.000000f}); + auto mul_1 = makeOP({squeeze_0, _const_0}, {numpy_broadcast}); + + auto squeeze_1 = makeOP({split->output(0), -2}); + auto concat = makeOP({mul_1, squeeze_1}, {{"axis", -1}}); + auto mul_2 = makeOP({concat, cos}, {numpy_broadcast}); + return makeOP({mul_0, mul_2}, {numpy_broadcast}); } static std::shared_ptr gen_rope_emb_sin(const std::shared_ptr& max_context_len, const std::shared_ptr& position_ids, std::shared_ptr& head_size) { - auto Constant_277 = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); - auto Slice_293 = makeOP({Constant_277, {0}, max_context_len, {1}, {1}}); - auto slice_sin = makeOP({Slice_293, position_ids, 1}, {{"batch_dims", 0}}); - auto ShapeOf_430 = makeOP({Slice_293}, {{"output_type", "i64"}}); - head_size = makeOP({ShapeOf_430, {3}, 0}, {{"batch_dims", 0}}); + auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + auto slice_sin = makeOP({sin, position_ids, 1}, {{"batch_dims", 0}}); + + auto slice = makeOP({sin, {0}, max_context_len, {1}, {1}}); + auto shape_of = makeOP({slice}, {{"output_type", "i64"}}); + head_size = makeOP({shape_of, {3}, 0}, {{"batch_dims", 0}}); + return makeOP({slice_sin, {-1, 1, 1, 128}}, {{"special_zero", false}}); } static std::shared_ptr gen_rope_emb_cos(const std::shared_ptr& max_context_len, const std::shared_ptr& position_ids) { - auto Constant_452 = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); - auto Slice_456 = makeOP({Constant_452, {0}, max_context_len, {1}, {1}}); - auto Slice_460 = makeOP({Slice_456, position_ids, 1}, {{"batch_dims", 0}}); - return makeOP({Slice_460, {-1, 1, 1, 128}}, {{"special_zero", false}}); + auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE); + auto slice = makeOP({cos, position_ids, 1}, {{"batch_dims", 0}}); + return makeOP({slice, {-1, 1, 1, 128}}, {{"special_zero", false}}); } static std::shared_ptr align_pa_layout(const std::shared_ptr& pa, const std::shared_ptr& head_size) { - auto Concat_1257 = makeOP({{0ll}, {1ll}, {-1ll}, head_size}, {{"axis", 0}}); - auto Reshape_1258 = makeOP({pa->output(0), Concat_1257}, {special_zero_true}); - return makeOP({Reshape_1258, {0, 2, 1, 3}}); + auto shape = makeOP({{0ll}, {1ll}, {-1ll}, head_size}, {{"axis", 0}}); + auto reshaped = makeOP({pa->output(0), shape}, {special_zero_true}); + return makeOP({reshaped, {0, 2, 1, 3}}); } static std::shared_ptr gen_current_len(const std::shared_ptr& rope_K) { - auto ShapeOf_484 = makeOP({rope_K}, {{"output_type", "i32"}}); - return makeOP({ShapeOf_484, {1}, 0ll}, 
{{"batch_dims", 0}}); + auto shape_of = makeOP({rope_K}, {{"output_type", "i32"}}); + return makeOP({shape_of, {1}, 0ll}, {{"batch_dims", 0}}); } static std::shared_ptr gen_past_len(const std::shared_ptr& input_ids, const std::shared_ptr& max_context_len) { - auto ShapeOf_897 = makeOP({input_ids}, {{"output_type", "i64"}}); - auto Gather_898 = makeOP({ShapeOf_897, 1ll, 0ll}, {{"batch_dims", 0}}); - auto Convert_899 = makeOP({Gather_898}, {{"destination_type", "i32"}}); - auto past_len = makeOP({max_context_len, Convert_899}, {numpy_broadcast}); - auto Convert_1000 = makeOP({past_len}, {{"destination_type", "i32"}}); - return makeOP({Convert_1000, {1}}, {special_zero_true}); + auto shape_of = makeOP({input_ids}, {{"output_type", "i64"}}); + auto cur_len = makeOP({shape_of, 1ll, 0ll}, {{"batch_dims", 0}}); + auto cur_len_i32 = makeOP({cur_len}, {{"destination_type", "i32"}}); + + auto past_len = makeOP({max_context_len, cur_len_i32}, {numpy_broadcast}); + auto past_len_i32 = makeOP({past_len}, {{"destination_type", "i32"}}); + return makeOP({past_len_i32, {1}}, {special_zero_true}); } static std::shared_ptr gen_total_len(const std::shared_ptr& cur_len, @@ -362,35 +388,36 @@ class Qwen7bChatPA { } static std::shared_ptr gen_V(const std::shared_ptr& qkv_proj, std::shared_ptr& head_size) { - auto Reshape_641 = makeOP({qkv_proj->output(2), {0, 0, 32, 128}}, {special_zero_true}); - auto Gather_1231 = makeOP({{0, 2, 1, 3}, {0, 2, 1, 3}, 0ll}, {{"batch_dims", 0}}); - auto Transpose_1232 = makeOP({Reshape_641, Gather_1231}); + auto current_V = makeOP({qkv_proj->output(2), {0, 0, 32, 128}}, {special_zero_true}); + auto gather = makeOP({{0, 2, 1, 3}, {0, 2, 1, 3}, 0ll}, {{"batch_dims", 0}}); + auto transpose = makeOP({current_V, gather}); - auto ShapeOf_1250 = makeOP({Transpose_1232}, {{"output_type", "i64"}}); - auto Gather_1251 = makeOP({ShapeOf_1250, -1ll, 0ll}, {{"batch_dims", 0}}); - head_size = makeOP({Gather_1251, 0}); + auto shape_of = makeOP({transpose}, {{"output_type", "i64"}}); + auto gather_2 = makeOP({shape_of, -1ll, 0ll}, {{"batch_dims", 0}}); + head_size = makeOP({gather_2, 0}); - return makeOP({Transpose_1232, {0, -1}}, {special_zero_true}); + return makeOP({transpose, {0, -1}}, {special_zero_true}); } static std::shared_ptr gen_K(const std::shared_ptr& rope_K) { - auto Gather_1227 = makeOP({{0, 2, 1, 3}, {0, 2, 1, 3}, 0ll}, {{"batch_dims", 0}}); - auto Transpose_1228 = makeOP({rope_K, Gather_1227}); - return makeOP({Transpose_1228, {0, -1}}, {special_zero_true}); + auto gather = makeOP({{0, 2, 1, 3}, {0, 2, 1, 3}, 0ll}, {{"batch_dims", 0}}); + auto transpose = makeOP({rope_K, gather}); + return makeOP({transpose, {0, -1}}, {special_zero_true}); } static std::shared_ptr gen_Q(const std::shared_ptr& total_seq_len, const std::shared_ptr& rope_Q) { - auto Constant_463 = makeConst(element::f32, {1, 32767, 1, 1}, MOCK_VALUE); - auto ShapeOf_489 = makeOP({rope_Q}, {{"output_type", "i32"}}); - auto Gather_492 = makeOP({ShapeOf_489, {1}, 0ll}, {{"batch_dims", 0}}); - auto past_seq_len_2 = makeOP({total_seq_len, Gather_492}, {numpy_broadcast}); - auto Slice_496 = makeOP({Constant_463, past_seq_len_2, total_seq_len, {1}, {1}}); - auto Multiply_631 = makeOP({rope_Q, Slice_496}, {numpy_broadcast}); - auto Transpose_633 = makeOP({Multiply_631, {0, 2, 1, 3}}); + auto _const_1 = makeConst(element::f32, {1, 32767, 1, 1}, MOCK_VALUE); + auto shape_of = makeOP({rope_Q}, {{"output_type", "i32"}}); + auto current_seq_len = makeOP({shape_of, {1}, 0ll}, {{"batch_dims", 0}}); + auto past_seq_len = 
makeOP({total_seq_len, current_seq_len}, {numpy_broadcast}); + + auto slice = makeOP({_const_1, past_seq_len, total_seq_len, {1}, {1}}); + auto mul = makeOP({rope_Q, slice}, {numpy_broadcast}); + auto transpose_1 = makeOP({mul, {0, 2, 1, 3}}); - auto Transpose_1223 = makeOP({Transpose_633, {0, 2, 1, 3}}); - return makeOP({Transpose_1223, {0, -1}}, {special_zero_true}); + auto transpose_2 = makeOP({transpose_1, {0, 2, 1, 3}}); + return makeOP({transpose_2, {0, -1}}, {special_zero_true}); } }; @@ -475,8 +502,8 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) { position_ids}); // Inputs pre-processing: - auto Convert_1001 = makeOP({max_context_len}, {dest_type_i64}); - auto max_context_len_aligned = makeOP({Convert_1001, {1}}, {special_zero_true}); + auto max_context_len_i64 = makeOP({max_context_len}, {dest_type_i64}); + auto max_context_len_aligned = makeOP({max_context_len_i64, {1}}, {special_zero_true}); auto input_ids_aligned = makeOP({input_ids, 1}); auto position_ids_aligned = makeOP({position_ids, 1}); @@ -499,7 +526,6 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) { auto total_seq_len = Qwen7bChatPA::gen_total_len(current_seq_len, past_seq_len); // Q, K, V: - shared_ptr head_size_2; auto Q = Qwen7bChatPA::gen_Q(total_seq_len, rope_Q); auto K = Qwen7bChatPA::gen_K(rope_K); @@ -535,3 +561,58 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) { comparator.disable(FunctionsComparator::PRECISIONS); disable_rt_info_check(); } + +TEST_F(TransformationTestsF, SDPAToPA_TotalSequenceLengthPatternQwen) { + { + // Inputs to SDPA transformer: + auto beam_idx = makeOP({}, {{"shape", PartialShape{DYN}}, el_type_i64}); + auto input_ids = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, el_type_i64}); + ParameterVector params = nodes_to_params({input_ids, beam_idx}); + + // K cache + auto k_cache = Qwen7bChatSDPA::gen_cache(input_ids, beam_idx, "K_cache"); + + // Current/past/total Seq lengths calculation: + auto current_len = Qwen7bChatSDPA::gen_current_len(input_ids); + auto past_len = Qwen7bChatSDPA::gen_past_len(k_cache); + auto total_len = Qwen7bChatSDPA::gen_total_len(current_len, past_len); + auto result = std::make_shared(total_len); + + // Expected that these Nodes to be created inside SDPAToPagedAttention + auto new_input_ids = std::make_shared(element::i64, PartialShape{DYN}); + auto axis = v0::Constant::create(element::i32, Shape{}, {1}); + auto aligned_input_ids = std::make_shared(new_input_ids, axis); + + input_ids->output(0).replace(aligned_input_ids); + auto max_context_len = std::make_shared(element::i32, PartialShape{}); + max_context_len->output(0).set_names({"max_context_len"}); + auto position_ids = std::make_shared(element::i64, PartialShape{DYN}); + position_ids->output(0).set_names({"position_ids"}); + + params.push_back(max_context_len); + params.push_back(new_input_ids); + + // Model and Transformations: + model = std::make_shared(ResultVector{result}, params); + manager.register_pass(aligned_input_ids, max_context_len, position_ids); + manager.register_pass(max_context_len); + } + + { + // Inputs to PA transformer: + auto max_context_len = makeOP({}, {{"shape", PartialShape{}}, el_type_i32}); + auto params = nodes_to_params({max_context_len}); + + // Inputs pre-processing: + auto max_context_len_i64 = makeOP({max_context_len}, {dest_type_i64}); + auto max_context_len_aligned = makeOP({max_context_len_i64, {1}}, {special_zero_true}); + + auto result = std::make_shared(max_context_len_aligned); + model_ref = std::make_shared(ResultVector{result}, params); + } + // TODO: align 
precisions, check the copying of "fuse_names" attr in SDPAToPagedAttention + // checking the graph structure and names, other checks are temporarily disabled: + comparator.disable(FunctionsComparator::PRECISIONS); + disable_result_friendly_names_check(); + disable_rt_info_check(); +} diff --git a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp index 74aeacb0719cee..d52e78dbd6a489 100644 --- a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp +++ b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp @@ -19,7 +19,7 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass { public: OPENVINO_MODEL_PASS_RTTI("SDPAToPagedAttention"); - SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false); + explicit SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false); bool run_on_model(const std::shared_ptr& model) override; private: diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index 2f697102afc972..e6fc744bb5ef4f 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -88,12 +88,6 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr(std::make_shared(unsqueezed_input_ids), - v0::Constant::create(element::i64, Shape{}, {1}), - v0::Constant::create(element::i64, Shape{}, {0})); - auto prev_max_seq_len = - std::make_shared(max_context_len, std::make_shared(cur_seq_len, element::i32)); - ParameterVector kv_parameters; ParameterVector parameters_to_remove; ResultVector results_to_remove; // # used, but cannot really track all Results in stateless model @@ -118,9 +112,6 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr(position_ids); // it is not always required, so will be disposed if not needed - ov::pass::Manager manager("SDPA to PA"); manager.set_per_pass_validation(false); manager.register_pass(kv_parameters, @@ -134,11 +125,11 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr(prev_max_seq_len, batch_dim); + manager.register_pass(unsqueezed_input_ids, max_context_len, position_ids); manager.register_pass(max_context_len); manager.register_pass(max_context_len); - manager.register_pass(unsqueezed_position_ids->output(0)); - manager.register_pass(unsqueezed_position_ids->output(0)); + manager.register_pass(unsqueezed_position_ids); + manager.register_pass(unsqueezed_position_ids); manager.run_passes(model); {
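For reference, a minimal usage sketch (not part of this patch) of the public SDPAToPagedAttention model pass whose internal registration sequence is shown above. The model path and the ov::Core loading flow are illustrative assumptions; the two boolean arguments mirror the defaults declared in sdpa_to_paged_attention.hpp, and the model pass then runs the matcher passes listed in run_on_model.

    #include <memory>

    #include "openvino/core/model.hpp"
    #include "openvino/pass/manager.hpp"
    #include "openvino/pass/sdpa_to_paged_attention.hpp"
    #include "openvino/runtime/core.hpp"

    int main() {
        ov::Core core;
        // Hypothetical IR path; any SDPA-based stateful LLM would do.
        std::shared_ptr<ov::Model> model = core.read_model("model.xml");

        ov::pass::Manager manager("ApplySDPAToPA");
        // use_block_indices_inputs = false, use_score_outputs = false (header defaults).
        manager.register_pass<ov::pass::SDPAToPagedAttention>(false, false);
        manager.run_passes(model);
        return 0;
    }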