diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp
index d973dbd164b08d..e341116cdf847a 100644
--- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp
+++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/position_ids_replacer.hpp
@@ -26,6 +26,17 @@ class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass {
     explicit PositionIDsReplacer(const Output<Node>& position_ids);
 };
 
+/**
+ * @brief The Qwen model does not use the position_ids input: this input is
+ * detached, and the model expects its data to be processed in sequential order.
+ *
+ * To run this model in continuous batching mode, we need to apply position_ids
+ * and select the corresponding rotary_emb_cos/rotary_emb_sin values.
+ * For this, we replace
+ *     rotary_emb_cos/rotary_emb_sin -> Slice -> Slice
+ * with
+ *     rotary_emb_cos/rotary_emb_sin -> Slice -> Gather(by position_ids)
+ */
 class ov::pass::PositionIDsReplacerQwen : public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("PositionIDsReplacerQwen", "0");
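For readers unfamiliar with the rewrite described in the comment above, here is a minimal sketch of the replacement subgraph built directly with the OpenVINO graph API. The table shape [1, 4096, 1, 128] and all names are illustrative assumptions, not values taken from this PR:

```cpp
#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"

int main() {
    using namespace ov;
    // Stand-in for rotary_emb_cos/rotary_emb_sin: a [1, 4096, 1, 128] table (zeros here)
    std::vector<float> table(1 * 4096 * 1 * 128, 0.0f);
    auto rotary = op::v0::Constant::create(element::f32, Shape{1, 4096, 1, 128}, table);

    // In continuous batching mode, position_ids arrive per token and may be out of order
    auto position_ids = std::make_shared<op::v0::Parameter>(element::i64, PartialShape{-1});

    // Gather rows along the sequence axis (1) instead of slicing a contiguous range,
    // so non-sequential positions still select the correct cos/sin rows
    auto axis = op::v0::Constant::create(element::i64, Shape{}, {1});
    auto gather = std::make_shared<op::v8::Gather>(rotary, position_ids, axis);

    auto model = std::make_shared<Model>(OutputVector{gather->output(0)}, ParameterVector{position_ids});
    return model ? 0 : 1;
}
```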
diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp
index 709c76ae423fb3..0de0aef1df1105 100644
--- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp
+++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp
@@ -26,6 +26,19 @@ class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass {
     explicit TotalSequenceLengthPattern(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
 };
 
+/**
+ * @brief The Qwen model uses a specific pattern to detect the TotalSequenceLen place.
+ *
+ * Common pattern: Add(PrevSeqLen, CurrentSeqLen)
+ *
+ * CurrentSeqLen is represented by the following subgraph:
+ *     CurrentSeqLen: Parameter(name: input_ids) -> ShapeOf -> Gather
+ *
+ * Before this transformation is applied, the PrevSeqLen place has already been
+ * detected by PrevSequenceLengthPattern and replaced with the subgraph:
+ *     PrevSeqLen: Subtract(in: Parameter(name: max_context_len), in: CurrentSeqLen)
+ *
+ */
 class ov::pass::TotalSequenceLengthPatternQwen : public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("TotalSequenceLengthPattern", "0");
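A sketch of the subgraph this matcher is expected to find, assuming illustrative shapes and the friendly names mentioned in the comment (input_ids, max_context_len); none of this code is from the PR itself:

```cpp
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"

int main() {
    using namespace ov;
    // CurrentSeqLen: Parameter(input_ids) -> ShapeOf -> Gather(index 1, axis 0)
    auto input_ids = std::make_shared<op::v0::Parameter>(element::i64, PartialShape{-1, -1});
    input_ids->set_friendly_name("input_ids");
    auto shape_of = std::make_shared<op::v3::ShapeOf>(input_ids);
    auto idx = op::v0::Constant::create(element::i64, Shape{}, {1});
    auto axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    auto cur_seq_len = std::make_shared<op::v8::Gather>(shape_of, idx, axis);

    // PrevSeqLen: Subtract(max_context_len, CurrentSeqLen), as produced by
    // PrevSequenceLengthPattern before this matcher runs
    auto max_context_len = std::make_shared<op::v0::Parameter>(element::i64, Shape{});
    max_context_len->set_friendly_name("max_context_len");
    auto prev_seq_len = std::make_shared<op::v1::Subtract>(max_context_len, cur_seq_len);

    // TotalSequenceLen: Add(PrevSeqLen, CurrentSeqLen)
    auto total_seq_len = std::make_shared<op::v1::Add>(prev_seq_len, cur_seq_len);

    auto model = std::make_shared<Model>(OutputVector{total_seq_len->output(0)},
                                         ParameterVector{input_ids, max_context_len});
    return model ? 0 : 1;
}
```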
diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp
index 3eb954d5daaac9..04702d0d102cee 100644
--- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/position_ids_replacer.cpp
@@ -15,6 +15,7 @@
 #include "transformations/utils/utils.hpp"
 
 using namespace ov::op;
+using namespace ov::pass::pattern;
 
 // TODO: Instead of using the following transformation that matches quite a specific place in a model graph in case when
 // position_ids parameter is missing, consider replacing always existing attention_mask parameter with a sub-graph using
@@ -22,65 +23,69 @@ using namespace ov::op;
 ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output<Node>& position_ids) {
     MATCHER_SCOPE(PositionIDsReplacer);
 
-    auto input_ids = pattern::any_input();
-    auto input_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), input_ids, pattern::any_input()});
+    auto input_ids = any_input();
+    auto input_embed = wrap_type<v8::Gather>({any_input(), input_ids, any_input()});
 
-    auto position_ids_pattern = pattern::any_input();
-    auto offset = pattern::wrap_type<v0::Constant>();
-    auto add_offset = pattern::wrap_type<v1::Add>({position_ids_pattern, offset});
-    auto convert = pattern::wrap_type<v0::Convert>({add_offset});
-    auto position_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), convert, pattern::any_input()});
+    auto position_ids_pattern = any_input();
+    auto offset = wrap_type<v0::Constant>();
+    auto add_offset = wrap_type<v1::Add>({position_ids_pattern, offset});
+    auto convert = wrap_type<v0::Convert>({add_offset});
+    auto position_embed = wrap_type<v8::Gather>({any_input(), convert, any_input()});
 
-    auto mul = pattern::optional<v1::Multiply>({input_embed, pattern::any_input()});
+    auto mul = optional<v1::Multiply>({input_embed, any_input()});
 
-    auto add = pattern::wrap_type<v1::Add>({mul, position_embed});
+    auto add = wrap_type<v1::Add>({mul, position_embed});
 
-    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
         replace_node(pattern_map.at(position_ids_pattern).get_node_shared_ptr(), position_ids.get_node_shared_ptr());
         return true;
     };
 
-    auto m = std::make_shared<ov::pass::pattern::Matcher>(add, matcher_name);
+    auto m = std::make_shared<Matcher>(add, matcher_name);
     register_matcher(m, callback);
 }
 
 ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& position_ids) {
     MATCHER_SCOPE(PositionIDsReplacerQwen);
 
-    auto max_context_len_pattern = pattern::wrap_type<v0::Parameter>();
-    auto optional_convert = pattern::optional<v0::Convert>(max_context_len_pattern);
-    auto optional_reshape = pattern::optional<v1::Reshape>({optional_convert, pattern::any_input()});
+    auto p_max_context_len = wrap_type<v0::Parameter>();
+    auto p_opt_convert = optional<v0::Convert>(p_max_context_len);
+    auto p_opt_reshape = optional<v1::Reshape>({p_opt_convert, any_input()});
 
-    auto slice_1_pattern = pattern::wrap_type<v8::Slice>(
-        {pattern::any_input(), pattern::any_input(), optional_reshape, pattern::any_input(), pattern::any_input()});
-    auto slice_2_pattern = pattern::wrap_type<v8::Slice>(
-        {slice_1_pattern, pattern::any_input(), pattern::any_input(), pattern::any_input(), pattern::any_input()});
+    auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
+    // the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
+    auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, any_input(), p_opt_reshape, any_input(), any_input()});
+    auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, any_input(), any_input(), any_input(), any_input()});
 
-    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
-        auto max_context_len = pattern_map.at(max_context_len_pattern).get_node_shared_ptr();
+        auto max_context_len = pattern_map.at(p_max_context_len).get_node_shared_ptr();
         if (max_context_len->get_friendly_name() != "max_context_len") {
             return false;
         }
-
-        auto slice_1 = pattern_map.at(slice_1_pattern).get_node_shared_ptr();
-        auto slice_2 = pattern_map.at(slice_2_pattern).get_node_shared_ptr();
-
-        auto gather =
-            std::make_shared<v8::Gather>(slice_1, position_ids, v0::Constant::create(element::i64, Shape{}, {1}));
+        auto rotary_emb_sincos = pattern_map.at(p_rotary_emb_sincos).get_node_shared_ptr();
+        auto slice_1 = pattern_map.at(p_slice_1).get_node_shared_ptr();
+        auto slice_2 = pattern_map.at(p_slice_2).get_node_shared_ptr();
+
+        auto axis = v0::Constant::create(element::i64, Shape{}, {1});
+        // in case of PagedAttention (continuous batching) the rotary_emb_cos/rotary_emb_sin
+        // are used not in the sequential order, so we need to use position_ids to get the expected values.
+        auto gather = std::make_shared<v8::Gather>(slice_1, position_ids, axis);
         gather->set_friendly_name(slice_2->get_friendly_name());
-        auto axis = std::make_shared<v0::Constant>(element::i64, Shape{1}, 2);
-        auto squeeze = std::make_shared<v0::Squeeze>(gather, axis);
+        gather->validate_and_infer_types();
 
-        auto reshape_shape = v0::Constant::create(element::i64, Shape{4}, {-1, 1, 1, 128});
-        auto reshape = std::make_shared<v1::Reshape>(squeeze, reshape_shape, false);
+        // PagedAttention expects the next layout for Q,K,V:
+        // [batch_size_in_tokens, num_kv_heads * head_size]
+        // so here we need to reshape the output tensor to move the seq dim (num tokens) to the batch
+        // num_kv_heads * head_size are already handled in the StateManagementPattern transformation
+        auto head_size = static_cast<int64_t>(rotary_emb_sincos->get_shape()[3]);
+        auto new_shape = v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{-1, 1, 1, head_size});
+        auto reshape = std::make_shared<v1::Reshape>(gather, new_shape, false);
 
         replace_node(slice_2, reshape);
-
-        gather->validate_and_infer_types();
         return true;
     };
 
-    auto m = std::make_shared<ov::pass::pattern::Matcher>(slice_2_pattern, matcher_name);
+    auto m = std::make_shared<Matcher>(p_slice_2, matcher_name);
     register_matcher(m, callback);
 }
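A usage sketch: how a caller might wire a position_ids input into a model and run the new pass through ov::pass::Manager. The helper name and the way position_ids is added here are assumptions for illustration, not part of the PR:

```cpp
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"

// Hypothetical helper: attaches a position_ids parameter and applies the pass.
void apply_qwen_position_ids(const std::shared_ptr<ov::Model>& model) {
    // position_ids input assumed to be added for continuous batching
    auto position_ids = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{-1});
    position_ids->set_friendly_name("position_ids");
    model->add_parameters({position_ids});

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::PositionIDsReplacerQwen>(position_ids->output(0));
    manager.run_passes(model);
}
```

Note the callback only fires when the matched Parameter is literally named "max_context_len", so the pass is a no-op on models that have not been prepared for PagedAttention.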
diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp
index e0f6ee41249f07..eb485dc30d4292 100644
--- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/total_sequence_length_pattern.cpp
@@ -12,6 +12,7 @@
 #include "openvino/op/reshape.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/slice.hpp"
+#include "openvino/op/subtract.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/pass/pattern/op/optional.hpp"
 #include "openvino/pass/pattern/op/or.hpp"
@@ -19,20 +20,21 @@
 #include "transformations/utils/utils.hpp"
 
 using namespace ov::op;
+using namespace ov::pass::pattern;
 
 ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
     const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
     MATCHER_SCOPE(TotalSequenceLengthPattern);
 
-    auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
-    auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
-    auto kv_current = pattern::any_input();
-    auto kv_concat = pattern::wrap_type<v0::Concat>({kv_gather, kv_current});
-    auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_concat});
-    auto gather_idx_label = pattern::wrap_type<v0::Constant>();
-    auto seq = pattern::wrap_type<v8::Gather>({kv_shape, gather_idx_label, pattern::any_input()});
+    auto kv_past = wrap_type<v6::ReadValue>({any_input()});
+    auto kv_gather = wrap_type<v8::Gather>({kv_past, any_input(), any_input()});
+    auto kv_current = any_input();
+    auto kv_concat = wrap_type<v0::Concat>({kv_gather, kv_current});
+    auto kv_shape = wrap_type<v3::ShapeOf>({kv_concat});
+    auto gather_idx_label = wrap_type<v0::Constant>();
+    auto seq = wrap_type<v8::Gather>({kv_shape, gather_idx_label, any_input()});
 
-    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
         // TODO: Check that seq has axis that really takes sequence len but not any other dimension --
         // use symbolic infra or look at the constant input
         const auto& pattern_map = m.get_pattern_value_map();
@@ -100,7 +102,7 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
         return true;
     };
 
-    auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
+    auto m = std::make_shared<Matcher>(seq, matcher_name);
     register_matcher(m, callback);
 }
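The generic matcher above targets the stateful KV-cache read/concat chain. Below is a sketch of a graph with that shape, under an assumed [batch, heads, seq, head_size] layout with the sequence on axis 2; every shape and name is illustrative:

```cpp
#include <memory>

#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/util/variable.hpp"

int main() {
    using namespace ov;
    const auto kv_layout = PartialShape{-1, 8, -1, 64};  // [batch, heads, seq, head_size]

    // Past KV read from state (ReadValue), as in a stateful decoder
    auto init = std::make_shared<op::v0::Parameter>(element::f32, kv_layout);
    auto var = std::make_shared<op::util::Variable>(op::util::VariableInfo{kv_layout, element::f32, "kv_cache"});
    auto kv_past = std::make_shared<op::v6::ReadValue>(init, var);

    // Beam-search reorder of the past KV along the batch axis
    auto beam_idx = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
    auto batch_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    auto kv_gather = std::make_shared<op::v8::Gather>(kv_past, beam_idx, batch_axis);

    // Concat past and current KV along the sequence axis (2)
    auto kv_current = std::make_shared<op::v0::Parameter>(element::f32, kv_layout);
    auto kv_concat = std::make_shared<op::v0::Concat>(OutputVector{kv_gather, kv_current}, 2);

    // Total sequence length = dimension 2 of the concatenated KV shape
    auto kv_shape = std::make_shared<op::v3::ShapeOf>(kv_concat);
    auto seq_idx = op::v0::Constant::create(element::i64, Shape{}, {2});
    auto seq_axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    auto seq = std::make_shared<op::v8::Gather>(kv_shape, seq_idx, seq_axis);

    return seq ? 0 : 1;
}
```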
 
@@ -108,34 +110,32 @@ ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen(
     const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
     MATCHER_SCOPE(TotalSequenceLengthPatternQwen);
 
-    auto prev_max_seq_len = pattern::wrap_type<v0::Parameter>();
-    auto opt_convert_1 = pattern::optional<v0::Convert>(prev_max_seq_len);
-    auto opt_reshape_1 = pattern::optional<v1::Reshape>({opt_convert_1, pattern::any_input()});
+    auto p_input_ids = wrap_type<v0::Parameter>();
+    auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, any_input()});
+    auto p_opt_reshape_2 = optional<v1::Reshape>({p_unsqueeze, any_input()});
+    auto p_opt_convert_2 = optional<v0::Convert>(p_opt_reshape_2);
+    auto p_kv_shape_current = wrap_type<v3::ShapeOf>({p_opt_convert_2});
+    auto p_seq_current = wrap_type<v8::Gather>({p_kv_shape_current, any_input(), any_input()});
 
-    auto input_ids = pattern::wrap_type<v0::Parameter>();
-    auto unsqueeze = pattern::wrap_type<v0::Unsqueeze>({input_ids, pattern::any_input()});
-    auto opt_reshape_2 = pattern::optional<v1::Reshape>({unsqueeze, pattern::any_input()});
-    auto opt_convert_2 = pattern::optional<v0::Convert>(opt_reshape_2);
-    auto kv_shape_current = pattern::wrap_type<v3::ShapeOf>({opt_convert_2});
-    auto seq_current = pattern::wrap_type<v8::Gather>({kv_shape_current, pattern::any_input(), pattern::any_input()});
+    auto p_max_context_len = wrap_type<v0::Parameter>();
+    auto p_prev_max_seq_len = wrap_type<v1::Subtract>({max_context_len, any_input()});
+    auto p_opt_convert_1 = optional<v0::Convert>(p_prev_max_seq_len);
+    auto opt_reshape_1 = optional<v1::Reshape>({p_opt_convert_1, p_seq_current});
 
-    auto pattern_total_seq = pattern::wrap_type<v1::Add>({seq_current, opt_reshape_1});
+    auto p_total_seq = wrap_type<v1::Add>({p_seq_current, opt_reshape_1});
 
-    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
-        // TODO: Check that seq has axis that really takes sequence len but not any other dimension --
-        // use symbolic infra or look at the constant input
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
-        auto total_seq = pattern_map.at(pattern_total_seq).get_node_shared_ptr();
+        auto total_seq = pattern_map.at(p_total_seq).get_node_shared_ptr();
         std::shared_ptr<Node> replacement = max_context_len;
 
-        auto target_type = total_seq->get_output_element_type(0);
+        auto target_type = total_seq->get_output_element_type(0);
         if (replacement->get_output_element_type(0) != target_type) {
             replacement = std::make_shared<v0::Convert>(replacement, target_type);
         }
 
         auto required_shape = total_seq->get_output_partial_shape(0);
-
         if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
             replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
         }
@@ -144,6 +144,6 @@ ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen(
         return true;
     };
 
-    auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_total_seq, matcher_name);
+    auto m = std::make_shared<Matcher>(p_total_seq, matcher_name);
     register_matcher(m, callback);
 }
\ No newline at end of file
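Finally, a sketch of what the Qwen callback does once the Add is matched: max_context_len is aligned to the matched node's element type and rank before replace_node is called. Since op::util::reshapeTo is an internal helper, this sketch substitutes a plain v1::Reshape, and the function name align_replacement is hypothetical:

```cpp
#include <memory>
#include <vector>

#include "openvino/core/node.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/reshape.hpp"

// Hypothetical helper mirroring the callback's alignment logic.
std::shared_ptr<ov::Node> align_replacement(const std::shared_ptr<ov::Node>& max_context_len,
                                            const std::shared_ptr<ov::Node>& total_seq) {
    using namespace ov;
    std::shared_ptr<Node> replacement = max_context_len;

    // Match the element type of the node being replaced
    auto target_type = total_seq->get_output_element_type(0);
    if (replacement->get_output_element_type(0) != target_type) {
        replacement = std::make_shared<op::v0::Convert>(replacement, target_type);
    }

    // Match the rank with an all-ones shape, e.g. rank 2 -> {1, 1}
    // (stand-in for op::util::reshapeTo)
    auto required_shape = total_seq->get_output_partial_shape(0);
    if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
        std::vector<int64_t> ones(required_shape.rank().get_length(), 1);
        auto shape_const = op::v0::Constant::create(element::i64, Shape{ones.size()}, ones);
        replacement = std::make_shared<op::v1::Reshape>(replacement, shape_const, false);
    }
    return replacement;
}
```

This keeps the replacement broadcast-compatible with every consumer of the original Add output, regardless of the precision or rank the model happened to use for its sequence-length arithmetic.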