Skip to content

Commit

Permalink
Update TotalSequenceLength and PositionIds replacer patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Dec 17, 2024
1 parent 879ae7a commit 4764822
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass {
explicit PositionIDsReplacer(const Output<Node>& position_ids);
};

/**
 * @brief The Qwen model has a specific structural feature: it does not use the position_ids input
 * (this input is detached). The model expects its input data to be processed in sequential order.
 *
 * To use this model in Continuous Batching mode, we need to apply position_ids and
 * use the corresponding rotary_emb_cos/rotary_emb_sin values.
 * For this, we replace
 *   rotary_emb_cos/rotary_emb_sin -> Slice -> Slice
 * with
 *   rotary_emb_cos/rotary_emb_sin -> Slice -> Gather (by position_ids)
 */
class ov::pass::PositionIDsReplacerQwen : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("PositionIDsReplacerQwen", "0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass {
explicit TotalSequenceLengthPattern(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
};

/**
 * @brief The Qwen model has a specific pattern for detecting the TotalSequenceLen place.
 *
 * Common pattern: Add(PrevSeqLen, CurrentSeqLen)
 *
 * The CurrentSeqLen is presented in the following form:
 *   CurrentSeqLen: Parameter(name: input_ids) -> ShapeOf -> Gather
 *
 * Before applying this transformation, the PrevSeqLen place has already been detected by the
 * PrevSequenceLengthPattern and replaced with the following subgraph:
 *   PrevSeqLen: Subtract(in: Parameter(name: max_context_len), in: CurrentSeqLen)
 **/
class ov::pass::TotalSequenceLengthPatternQwen : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TotalSequenceLengthPattern", "0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,72 +15,77 @@
#include "transformations/utils/utils.hpp"

using namespace ov::op;
using namespace ov::pass::pattern;

// TODO: Instead of using the following transformation that matches quite a specific place in a model graph in case when
// position_ids parameter is missing, consider replacing always existing attention_mask parameter with a sub-graph using
// a new slot_mapping parameter.
ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output<Node>& position_ids) {
    MATCHER_SCOPE(PositionIDsReplacer);

    // Token-embedding lookup: embedding table gathered by input_ids.
    auto input_ids = any_input();
    auto input_embed = wrap_type<v8::Gather>({any_input(), input_ids, any_input()});

    // Positional-embedding lookup: (position_ids + constant offset) -> Convert -> Gather.
    auto position_ids_pattern = any_input();
    auto offset = wrap_type<v0::Constant>();
    auto add_offset = wrap_type<v1::Add>({position_ids_pattern, offset});
    auto convert = wrap_type<v0::Convert>({add_offset});
    auto position_embed = wrap_type<v8::Gather>({any_input(), convert, any_input()});

    // Optional scaling/projection of the token embeddings before they are summed
    // with the positional embeddings.
    auto mul = optional<v0::MatMul>({input_embed, any_input()});

    auto add = wrap_type<v1::Add>({mul, position_embed});

    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        // Redirect the matched position-ids producer to the externally supplied
        // position_ids input; the rest of the matched subgraph is left intact.
        replace_node(pattern_map.at(position_ids_pattern).get_node_shared_ptr(), position_ids.get_node_shared_ptr());
        return true;
    };

    auto m = std::make_shared<Matcher>(add, matcher_name);
    register_matcher(m, callback);
}

ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& position_ids) {
    MATCHER_SCOPE(PositionIDsReplacerQwen);

    // max_context_len may pass through an optional Convert and/or Reshape before
    // it is used as the stop input of the first Slice.
    auto p_max_context_len = wrap_type<v0::Parameter>();
    auto p_opt_convert = optional<v0::Convert>(p_max_context_len);
    auto p_opt_reshape = optional<v1::Reshape>({p_opt_convert, any_input()});

    auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
    // the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
    auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, any_input(), p_opt_reshape, any_input(), any_input()});
    auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, any_input(), any_input(), any_input(), any_input()});

    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        auto max_context_len = pattern_map.at(p_max_context_len).get_node_shared_ptr();
        // Only a Parameter explicitly named "max_context_len" identifies the target
        // pattern; any other Parameter means this is not the place we look for.
        if (max_context_len->get_friendly_name() != "max_context_len") {
            return false;
        }

        auto rotary_emb_sincos = pattern_map.at(p_rotary_emb_sincos).get_node_shared_ptr();
        auto slice_1 = pattern_map.at(p_slice_1).get_node_shared_ptr();
        auto slice_2 = pattern_map.at(p_slice_2).get_node_shared_ptr();

        auto axis = v0::Constant::create(element::i64, Shape{}, {1});
        // in case of PagedAttention (Continuous batching) the rotary_emb_cos/rotary_emb_sin
        // are used not in the sequential order, so we need to use position_ids to get the expected values.
        auto gather = std::make_shared<v8::Gather>(slice_1, position_ids, axis);
        gather->set_friendly_name(slice_2->get_friendly_name());
        gather->validate_and_infer_types();

        // PagedAttention expects the next layout for Q,K,V:
        // [batch_size_in_tokens, num_kv_heads * head_size]
        // so here we need to reshape the output tensor to move the seq dim (num tokens) to the batch
        // num_kv_heads * head_size are already handled in the StateManagementPattern transformation
        auto head_size = static_cast<int64_t>(rotary_emb_sincos->get_shape()[3]);
        auto new_shape = v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{-1, 1, 1, head_size});
        auto reshape = std::make_shared<v1::Reshape>(gather, new_shape, false);
        replace_node(slice_2, reshape);

        return true;
    };

    auto m = std::make_shared<Matcher>(p_slice_2, matcher_name);
    register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,29 @@
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;
using namespace ov::pass::pattern;

ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
MATCHER_SCOPE(TotalSequenceLengthPattern);

auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
auto kv_current = pattern::any_input();
auto kv_concat = pattern::wrap_type<v0::Concat>({kv_gather, kv_current});
auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_concat});
auto gather_idx_label = pattern::wrap_type<v0::Constant>();
auto seq = pattern::wrap_type<v8::Gather>({kv_shape, gather_idx_label, pattern::any_input()});
auto kv_past = wrap_type<v6::ReadValue>({any_input()});
auto kv_gather = wrap_type<v8::Gather>({kv_past, any_input(), any_input()});
auto kv_current = any_input();
auto kv_concat = wrap_type<v0::Concat>({kv_gather, kv_current});
auto kv_shape = wrap_type<v3::ShapeOf>({kv_concat});
auto gather_idx_label = wrap_type<v0::Constant>();
auto seq = wrap_type<v8::Gather>({kv_shape, gather_idx_label, any_input()});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
ov::matcher_pass_callback callback = [=](Matcher& m) {
// TODO: Check that seq has axis that really takes sequence len but not any other dimension --
// use symbolic infra or look at the constant input
const auto& pattern_map = m.get_pattern_value_map();
Expand Down Expand Up @@ -100,42 +102,40 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
auto m = std::make_shared<Matcher>(seq, matcher_name);
register_matcher(m, callback);
}

ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen(
const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
MATCHER_SCOPE(TotalSequenceLengthPatternQwen);

auto prev_max_seq_len = pattern::wrap_type<v0::Parameter>();
auto opt_convert_1 = pattern::optional<v0::Convert>(prev_max_seq_len);
auto opt_reshape_1 = pattern::optional<v1::Reshape>({opt_convert_1, pattern::any_input()});
auto p_input_ids = wrap_type<v0::Parameter>();
auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, any_input()});
auto p_opt_reshape_2 = optional<v1::Reshape>({p_unsqueeze, any_input()});
auto p_opt_convert_2 = optional<v0::Convert>(p_opt_reshape_2);
auto p_kv_shape_current = wrap_type<v3::ShapeOf>({p_opt_convert_2});
auto p_seq_current = wrap_type<v8::Gather>({p_kv_shape_current, any_input(), any_input()});

auto input_ids = pattern::wrap_type<v0::Parameter>();
auto unsqueeze = pattern::wrap_type<v0::Unsqueeze>({input_ids, pattern::any_input()});
auto opt_reshape_2 = pattern::optional<v1::Reshape>({unsqueeze, pattern::any_input()});
auto opt_convert_2 = pattern::optional<v0::Convert>(opt_reshape_2);
auto kv_shape_current = pattern::wrap_type<v3::ShapeOf>({opt_convert_2});
auto seq_current = pattern::wrap_type<v8::Gather>({kv_shape_current, pattern::any_input(), pattern::any_input()});
auto p_max_context_len = wrap_type<v0::Parameter>();
auto p_prev_max_seq_len = wrap_type<v1::Subtract>({max_context_len, any_input()});
auto p_opt_convert_1 = optional<v0::Convert>(p_prev_max_seq_len);
auto opt_reshape_1 = optional<v1::Reshape>({p_opt_convert_1, p_seq_current});

auto pattern_total_seq = pattern::wrap_type<v1::Add>({seq_current, opt_reshape_1});
auto p_total_seq = wrap_type<v1::Add>({p_seq_current, opt_reshape_1});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
// TODO: Check that seq has axis that really takes sequence len but not any other dimension --
// use symbolic infra or look at the constant input
ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto total_seq = pattern_map.at(pattern_total_seq).get_node_shared_ptr();
auto total_seq = pattern_map.at(p_total_seq).get_node_shared_ptr();

std::shared_ptr<Node> replacement = max_context_len;
auto target_type = total_seq->get_output_element_type(0);

auto target_type = total_seq->get_output_element_type(0);
if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}

auto required_shape = total_seq->get_output_partial_shape(0);

if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
}
Expand All @@ -144,6 +144,6 @@ ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen(
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_total_seq, matcher_name);
auto m = std::make_shared<Matcher>(p_total_seq, matcher_name);
register_matcher(m, callback);
}

0 comments on commit 4764822

Please sign in to comment.