Merge branch 'master' into revert
ilya-lavrenov authored Dec 24, 2024
2 parents ffc9b28 + 9d78056 commit 0c7febc
Showing 13 changed files with 845 additions and 50 deletions.
@@ -15,6 +15,7 @@ namespace ov {
namespace pass {

class TRANSFORMATIONS_API PositionIDsReplacer;
class TRANSFORMATIONS_API PositionIDsReplacerQwen;

} // namespace pass
} // namespace ov
@@ -24,3 +25,22 @@ class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass {
OPENVINO_MATCHER_PASS_RTTI("PositionIDsReplacer");
explicit PositionIDsReplacer(const Output<Node>& position_ids);
};

/**
 * @brief The Qwen model expects its data to be processed in sequential order: the "position ids"
 * input is detached and not used explicitly in the model. Instead, the model derives position ids
 * implicitly from the past KV cache size.
 *
 * To use this model in continuous batching mode, we need to apply position_ids explicitly and
 * use the corresponding rotary_emb_cos/rotary_emb_sin values.
 * For this, we replace
 *      rotary_emb_cos/rotary_emb_sin -> Slice -> Slice
 * with
 *      rotary_emb_cos/rotary_emb_sin -> Gather(by position_ids),
 * which enables applying RoPE to each token independently of its order in the input tensor.
 */
class ov::pass::PositionIDsReplacerQwen : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("PositionIDsReplacerQwen");
explicit PositionIDsReplacerQwen(const Output<Node>& position_ids);
};
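
A minimal usage sketch for these passes (hedged: the position_ids parameter shape and the surrounding wiring are illustrative assumptions, not taken from this commit):

    #include "openvino/pass/manager.hpp"

    // `model` is an ov::Model; Qwen-like models need a position_ids Parameter added first.
    auto position_ids = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{-1});
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::PositionIDsReplacer>(position_ids->output(0));
    manager.register_pass<ov::pass::PositionIDsReplacerQwen>(position_ids->output(0));
    manager.run_passes(model);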
@@ -4,7 +4,6 @@

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/matcher_pass.hpp"
@@ -22,6 +21,8 @@ class TRANSFORMATIONS_API PrevSequenceLengthPattern;

class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
public:
-OPENVINO_MATCHER_PASS_RTTI("PrevSequenceLengthPattern");
-explicit PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len, std::shared_ptr<ov::Node> batch_dim);
+OPENVINO_MATCHER_PASS_RTTI("PrevSequenceLengthPattern", "0");
+explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::Node>& unsqueezed_input_ids,
+                                   const std::shared_ptr<ov::Node>& max_context_len,
+                                   const std::shared_ptr<ov::Node>& position_ids);
};
@@ -15,6 +15,7 @@ namespace ov {
namespace pass {

class TRANSFORMATIONS_API TotalSequenceLengthPattern;
class TRANSFORMATIONS_API TotalSequenceLengthPatternQwen;

} // namespace pass
} // namespace ov
@@ -24,3 +25,22 @@ class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass {
OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPattern");
explicit TotalSequenceLengthPattern(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
};

/**
 * @brief The Qwen model uses a specific subgraph for total-sequence-length detection.
 *
 * Common pattern: Add(PrevSeqLen, CurrentSeqLen)
 *
 * CurrentSeqLen is produced by the following subgraph:
 *      CurrentSeqLen: Parameter(name: input_ids) -> ShapeOf -> Gather
 *
 * Before this transformation is applied, the PrevSeqLen place has already been detected by
 * PrevSequenceLengthPattern and replaced with the following subgraph:
 *      PrevSeqLen: Subtract(in: Parameter(name: max_context_len), in: CurrentSeqLen)
 */
class ov::pass::TotalSequenceLengthPatternQwen : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPattern", "0");
explicit TotalSequenceLengthPatternQwen(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
};
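
Since PrevSequenceLengthPattern rewires PrevSeqLen to max_context_len - CurrentSeqLen, the matched Add computes (max_context_len - CurrentSeqLen) + CurrentSeqLen == max_context_len, which is why the callback in the implementation below can replace the whole subgraph with the max_context_len parameter. The sketch rebuilds the matched subgraph with OpenVINO ops to make that visible (shapes, element types, and the skipped optional Convert/Reshape nodes are illustrative assumptions):

    // Assumes `using namespace ov;` and `using namespace ov::op;`, as in the sources below.
    auto input_ids = std::make_shared<v0::Parameter>(element::i64, PartialShape{-1});
    auto unsqueeze_axes = v0::Constant::create(element::i64, Shape{1}, {1});
    auto unsqueezed = std::make_shared<v0::Unsqueeze>(input_ids, unsqueeze_axes);
    auto shape = std::make_shared<v3::ShapeOf>(unsqueezed);
    auto current_seq_len = std::make_shared<v8::Gather>(shape,
                                                        v0::Constant::create(element::i64, Shape{}, {0}),
                                                        v0::Constant::create(element::i64, Shape{}, {0}));
    auto current_i32 = std::make_shared<v0::Convert>(current_seq_len, element::i32);

    auto max_context_len = std::make_shared<v0::Parameter>(element::i32, PartialShape{});
    // Inserted earlier by PrevSequenceLengthPattern:
    auto prev_seq_len = std::make_shared<v1::Subtract>(max_context_len, current_i32);
    // Matched by this pass and replaced with max_context_len:
    auto total_seq_len = std::make_shared<v1::Add>(current_i32, prev_seq_len);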
@@ -539,6 +539,11 @@ class AttrSetter : public ov::AttributeVisitor {
a->set(m_attr_map[name].as_vector<int64_t>());
} else if (auto a = ov::as_type<ov::AttributeAdapter<ov::element::TypeVector>>(&adapter)) {
a->set(m_attr_map[name].as_T_vector<ov::element::Type>());
} else if (auto a = dynamic_cast<ov::AttributeAdapter<std::shared_ptr<ov::op::util::Variable>>*>(&adapter)) {
ov::op::util::VariableInfo var_info;
var_info.variable_id = m_attr_map[name].as_string();
auto variable = std::make_shared<ov::op::util::Variable>(var_info);
a->set(variable);
} else {
OPENVINO_THROW("unsupported AttributeAdapter for attribute : ", name);
}
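
The new branch turns a string attribute into a stateful Variable. A hedged sketch of what a single attribute-map entry resolves to (the variable id is hypothetical):

    ov::op::util::VariableInfo var_info;
    var_info.variable_id = "past_key.0";  // hypothetical id; data_shape/data_type keep their defaults
    auto variable = std::make_shared<ov::op::util::Variable>(var_info);
    // Such a variable can then back a stateful op, e.g. a v6::ReadValue:
    auto read_value = std::make_shared<ov::op::v6::ReadValue>(variable);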
@@ -896,6 +901,7 @@ struct PatternNode {
// scalar constant (treated as wildcard for single-element-constant with any rank)
PatternNode(int v) : node(std::make_shared<ov::op::v0::Constant>(element::from<int>(), Shape({}), v)) {}
PatternNode(float v) : node(std::make_shared<ov::op::v0::Constant>(element::from<float>(), Shape({}), v)) {}
PatternNode(long long v) : node(std::make_shared<ov::op::v0::Constant>(element::from<int64_t>(), Shape({}), v)) {}

PatternNode(std::initializer_list<int> v, values_info vi = nullptr) {
node = ConstVector(std::vector<int>(v), vi);
@@ -19,6 +19,7 @@
#include "openvino/core/model.hpp"
#include "openvino/core/node.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/utils/utils.hpp"

@@ -7,37 +7,106 @@
#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;
using namespace ov::pass::pattern;

// TODO: Instead of using the following transformation, which matches quite a specific place in the
// model graph when the position_ids parameter is missing, consider replacing the always-present
// attention_mask parameter with a sub-graph that uses a new slot_mapping parameter.
ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output<Node>& position_ids) {
MATCHER_SCOPE(PositionIDsReplacer);

-auto input_ids = pattern::any_input();
-auto input_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), input_ids, pattern::any_input()});
+auto input_ids = any_input();
+auto input_embed = wrap_type<v8::Gather>({any_input(), input_ids, any_input()});

-auto position_ids_pattern = pattern::any_input();
-auto offset = pattern::wrap_type<v0::Constant>();
-auto add_offset = pattern::wrap_type<v1::Add>({position_ids_pattern, offset});
-auto convert = pattern::wrap_type<v0::Convert>({add_offset});
-auto position_embed = pattern::wrap_type<v8::Gather>({pattern::any_input(), convert, pattern::any_input()});
+auto position_ids_pattern = any_input();
+auto offset = wrap_type<v0::Constant>();
+auto add_offset = wrap_type<v1::Add>({position_ids_pattern, offset});
+auto convert = wrap_type<v0::Convert>({add_offset});
+auto position_embed = wrap_type<v8::Gather>({any_input(), convert, any_input()});

-auto mul = pattern::optional<v0::MatMul>({input_embed, pattern::any_input()});
+auto mul = optional<v0::MatMul>({input_embed, any_input()});

-auto add = pattern::wrap_type<v1::Add>({mul, position_embed});
+auto add = wrap_type<v1::Add>({mul, position_embed});

-ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
replace_node(pattern_map.at(position_ids_pattern).get_node_shared_ptr(), position_ids.get_node_shared_ptr());
return true;
};

-auto m = std::make_shared<ov::pass::pattern::Matcher>(add, matcher_name);
+auto m = std::make_shared<Matcher>(add, matcher_name);
register_matcher(m, callback);
}

ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& position_ids) {
MATCHER_SCOPE(PositionIDsReplacerQwen);

auto _const = []() {
return wrap_type<v0::Constant>();
};

// total seq len:
auto p_max_context_len = wrap_type<v0::Parameter>();
auto p_opt_convert = optional<v0::Convert>(p_max_context_len);
auto p_opt_reshape = optional<v1::Reshape>({p_opt_convert, any_input()});

// current seq len
auto p_input_ids = wrap_type<v0::Parameter>();
auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, _const()});
auto p_shape_of = wrap_type<v3::ShapeOf>({p_unsqueeze});
auto p_current_len = wrap_type<v8::Gather>({p_shape_of, _const(), _const()});

auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_neg_const = wrap_type<v0::Constant>();
auto p_neg_mul = wrap_type<v1::Multiply>({p_current_len, p_neg_const});
// the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()});
auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, p_neg_mul, _const(), _const(), _const()});

ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto max_context_len = pattern_map.at(p_max_context_len).get_node_shared_ptr();
if (max_context_len->get_friendly_name() != "max_context_len") {
return false;
}
auto rotary_emb_sincos = pattern_map.at(p_rotary_emb_sincos).get_node_shared_ptr();
auto slice_1 = pattern_map.at(p_slice_1).get_node_shared_ptr();
auto slice_2 = pattern_map.at(p_slice_2).get_node_shared_ptr();

auto axis = v0::Constant::create(element::i64, Shape{}, {1});
// In the case of PagedAttention (continuous batching), rotary_emb_cos/rotary_emb_sin are not
// consumed in sequential order, so we need position_ids to gather the expected values.
auto gather = std::make_shared<v8::Gather>(slice_1->input_value(0), position_ids, axis);
gather->set_friendly_name(slice_2->get_friendly_name());
gather->validate_and_infer_types();

auto pshape = rotary_emb_sincos->get_output_partial_shape(0);
if (pshape.rank().is_dynamic() || pshape.rank().get_length() != 4) {
return false;
}

// PagedAttention expects the following layout for Q, K, V:
// [batch_size_in_tokens, num_kv_heads * head_size],
// so we reshape the output tensor here to move the seq dim (num tokens) into the batch dim;
// num_kv_heads * head_size is already handled by the StateManagementPattern transformation.
auto head_size = static_cast<int64_t>(pshape[3].get_length());
auto new_shape = v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{-1, 1, 1, head_size});
auto reshape = std::make_shared<v1::Reshape>(gather, new_shape, false);
replace_node(slice_2, reshape);
return true;
};

auto m = std::make_shared<Matcher>(p_slice_2, matcher_name);
register_matcher(m, callback);
}
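
A rough numeric sketch of the replacement subgraph built by the callback above (the table contents, the 4096-entry length, and the three token positions are illustrative):

    // Stand-in sincos table of shape [1, 4096, 1, 128]; three tokens at positions 5, 17, 3.
    auto table = v0::Constant::create(element::f32, Shape{1, 4096, 1, 128},
                                      std::vector<float>(4096 * 128, 0.0f));
    auto position_ids = v0::Constant::create(element::i64, Shape{3}, {5, 17, 3});
    auto axis = v0::Constant::create(element::i64, Shape{}, {1});
    auto gathered = std::make_shared<v8::Gather>(table, position_ids, axis);    // -> [1, 3, 1, 128]
    auto new_shape = v0::Constant::create(element::i64, Shape{4}, {-1, 1, 1, 128});
    auto reshaped = std::make_shared<v1::Reshape>(gathered, new_shape, false);  // -> [3, 1, 1, 128]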
@@ -14,8 +14,9 @@

using namespace ov::op;

-ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len,
-                                                               std::shared_ptr<ov::Node> batch_dim) {
+ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(const std::shared_ptr<ov::Node>& unsqueezed_input_ids,
+                                                               const std::shared_ptr<ov::Node>& max_context_len,
+                                                               const std::shared_ptr<ov::Node>& position_ids) {
MATCHER_SCOPE(PrevSequenceLengthPattern);
// The transformation addresses two cases that look similar: (1) the previous sequence length and
// (2) the batch size in the kv-cache state. In the first case the match should be replaced with
// prev_max_seq_len; in the second case, connected to batch_dim.
@@ -40,8 +41,16 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr<o
auto target_type = gather->get_output_element_type(0);
std::shared_ptr<ov::Node> replacement;
if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) {
auto cur_seq_len = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(unsqueezed_input_ids),
v0::Constant::create(element::i64, Shape{}, {1}),
v0::Constant::create(element::i64, Shape{}, {0}));
auto cur_seq_len_i32 = std::make_shared<v0::Convert>(cur_seq_len, element::i32);
auto prev_max_seq_len = std::make_shared<v1::Subtract>(max_context_len, cur_seq_len_i32);
replacement = prev_max_seq_len;
} else {
// it is not always required, so it will be disposed of if not needed
auto batch_dim = std::make_shared<v3::ShapeOf>(position_ids);

// Assumption: any other axis points to the batch dimension; precise reasoning is too complex.
// TODO: provide a more reliable check
replacement = batch_dim;
@@ -437,6 +437,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
parameters_to_remove.push_back(param);
}

pa_transpose->set_friendly_name(sdpa_node->get_friendly_name());
replace_node(m.get_match_root(), pa_transpose);
return true;
};
@@ -6,27 +6,49 @@

#include "openvino/cc/pass/itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;
using namespace ov::pass::pattern;

namespace {

void align_replacement(std::shared_ptr<ov::Node>& replacement,
const ov::PartialShape& required_shape,
ov::element::Type target_type) {
if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}

if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = ov::op::util::reshapeTo(replacement, ov::Shape(required_shape.rank().get_length(), 1));
}
}

} // namespace
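
A small usage sketch of the helper above (the scalar constant and the target shape/type are illustrative):

    std::shared_ptr<ov::Node> repl = v0::Constant::create(ov::element::i64, ov::Shape{}, {7});
    align_replacement(repl, ov::PartialShape{1, 1}, ov::element::i32);
    // repl is now Constant -> Convert(i32) -> Reshape to {1, 1}: same value, aligned type and rank.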

ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
MATCHER_SCOPE(TotalSequenceLengthPattern);

-auto kv_past = pattern::wrap_type<v6::ReadValue>({pattern::any_input()});
-auto kv_gather = pattern::wrap_type<v8::Gather>({kv_past, pattern::any_input(), pattern::any_input()});
-auto kv_current = pattern::any_input();
-auto kv_concat = pattern::wrap_type<v0::Concat>({kv_gather, kv_current});
-auto kv_shape = pattern::wrap_type<v3::ShapeOf>({kv_concat});
-auto gather_idx_label = pattern::wrap_type<v0::Constant>();
-auto seq = pattern::wrap_type<v8::Gather>({kv_shape, gather_idx_label, pattern::any_input()});
+auto kv_past = wrap_type<v6::ReadValue>({any_input()});
+auto kv_gather = wrap_type<v8::Gather>({kv_past, any_input(), any_input()});
+auto kv_current = any_input();
+auto kv_concat = wrap_type<v0::Concat>({kv_gather, kv_current});
+auto kv_shape = wrap_type<v3::ShapeOf>({kv_concat});
+auto gather_idx_label = wrap_type<v0::Constant>();
+auto seq = wrap_type<v8::Gather>({kv_shape, gather_idx_label, any_input()});

-ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+ov::matcher_pass_callback callback = [=](Matcher& m) {
// TODO: Check that seq has axis that really takes sequence len but not any other dimension --
// use symbolic infra or look at the constant input
const auto& pattern_map = m.get_pattern_value_map();
@@ -71,16 +93,8 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(

if (concat_axis_to_compare == gather_idx_to_compare) {
auto target_type = gather->get_output_element_type(0);
-
-if (replacement->get_output_element_type(0) != target_type) {
-replacement = std::make_shared<v0::Convert>(replacement, target_type);
-}
-
auto required_shape = gather->get_output_partial_shape(0);
-
-if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
-replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
-}
+align_replacement(replacement, required_shape, target_type);
} else {
// TODO: change in the future when we start supporting dynamic shapes here
replacement = ov::util::get_constant_from_source(gather->output(0));
@@ -94,6 +108,41 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
return true;
};

-auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
+auto m = std::make_shared<Matcher>(seq, matcher_name);
register_matcher(m, callback);
}

ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen(
const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
MATCHER_SCOPE(TotalSequenceLengthPatternQwen);

auto p_input_ids = wrap_type<v0::Parameter>();
auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, any_input()});
auto p_opt_reshape_1 = optional<v1::Reshape>({p_unsqueeze, any_input()});
auto p_opt_convert_1 = optional<v0::Convert>(p_opt_reshape_1);
auto p_kv_shape_current = wrap_type<v3::ShapeOf>({p_opt_convert_1});
auto p_seq_current = wrap_type<v8::Gather>({p_kv_shape_current, any_input(), any_input()});
auto p_opt_convert_2 = optional<v0::Convert>(p_seq_current);

auto p_max_context_len = wrap_type<v0::Parameter>();
auto p_prev_max_seq_len = wrap_type<v1::Subtract>({p_max_context_len, any_input()});
auto p_opt_convert_3 = optional<v0::Convert>(p_prev_max_seq_len);
auto p_opt_reshape_2 = optional<v1::Reshape>({p_opt_convert_3, any_input()});
auto p_total_seq = wrap_type<v1::Add>({p_opt_convert_2, p_opt_reshape_2});

ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto total_seq = pattern_map.at(p_total_seq).get_node_shared_ptr();
std::shared_ptr<Node> replacement = max_context_len;

auto target_type = total_seq->get_output_element_type(0);
auto required_shape = total_seq->get_output_partial_shape(0);
align_replacement(replacement, required_shape, target_type);

replace_node(total_seq, replacement);
return true;
};

auto m = std::make_shared<Matcher>(p_total_seq, matcher_name);
register_matcher(m, callback);
}
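
For context, these matcher passes are building blocks of OpenVINO's SDPA-to-PagedAttention pipeline; a typical entry point is the umbrella pass (a sketch, assuming `model` is a loaded ov::Model):

    #include "openvino/pass/manager.hpp"
    #include "openvino/pass/sdpa_to_paged_attention.hpp"

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>();
    manager.run_passes(model);  // rewires SDPA and KV-cache states to PagedAttention inputs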