Commit
resolve review comments; add new unit tests; update transformations
itikhono committed Dec 23, 2024
1 parent 6380475 commit 033b0b5
Showing 9 changed files with 310 additions and 203 deletions.
@@ -27,15 +27,17 @@ class ov::pass::PositionIDsReplacer : public ov::pass::MatcherPass {
};

/**
- * @brief Qwen model has a specific feature in the model structure not to use position_ids input,
- * this input is detached. The model expects data processing in order.
+ * @brief The Qwen model expects data to be processed in order; the "position ids" input is
+ * detached and is not explicitly used in the model. The model relies on implicitly defined
+ * "position ids" based on the past KV cache size.
*
* To use this model in Continuous batching mode, we need to apply position_ids and
* use the corresponding rotary_emb_cos/rotary_emb_sin.
* For this, we replace
* rotary_emb_cos/rotary_emb_sin -> Slice -> Slice
* With
- * rotary_emb_cos/rotary_emb_sin -> Slice -> Gather(by position_ids)
+ * rotary_emb_cos/rotary_emb_sin -> Gather(by position_ids)
+ * which enables applying RoPE to each token independently of its order in the input tensor.
*/
class ov::pass::PositionIDsReplacerQwen : public ov::pass::MatcherPass {
public:
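To make the replacement concrete, here is a minimal sketch of the subgraph this pass produces. The helper name, includes, and shapes are illustrative assumptions, not part of the commit:

#include <memory>
#include "openvino/core/node.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"

// Hypothetical helper mirroring the replacement: pick the sin/cos rows addressed by
// position_ids (axis 1 of rotary_emb_sincos, assumed shape [1, max_len, 1, head_size]),
// so RoPE can be applied per token regardless of token order in the input tensor.
std::shared_ptr<ov::Node> gather_rope_by_position_ids(const ov::Output<ov::Node>& rotary_emb_sincos,
                                                      const ov::Output<ov::Node>& position_ids) {
    auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
    return std::make_shared<ov::op::v8::Gather>(rotary_emb_sincos, position_ids, axis);
}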
@@ -4,7 +4,6 @@

#pragma once

#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/matcher_pass.hpp"
@@ -22,6 +21,8 @@ class TRANSFORMATIONS_API PrevSequenceLengthPattern;

class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("PrevSequenceLengthPattern", "0");
explicit PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len, std::shared_ptr<ov::Node> batch_dim);
OPENVINO_MATCHER_PASS_RTTI("PrevSequenceLengthPattern", "0");
explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::Node>& unsqueezed_input_ids,
const std::shared_ptr<ov::Node>& max_context_len,
const std::shared_ptr<ov::Node>& position_ids);
};
@@ -22,7 +22,7 @@ class TRANSFORMATIONS_API TotalSequenceLengthPatternQwen;

class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TotalSequenceLengthPattern", "0");
OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPattern", "0");
explicit TotalSequenceLengthPattern(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
};

@@ -35,12 +35,12 @@ class ov::pass::TotalSequenceLengthPattern : public ov::pass::MatcherPass {
* CurrentSeqLen: Parameter(name: input_ids) -> ShapeOf -> Gather
*
* Before applying this transformation, we already detected the PrevSeqLen place in the PrevSequenceLengthPattern
- * and replaced it with the next subgraph
+ * and replaced it with the next subgraph:
* PrevSeqLen: Subtract (in: Parameter(name: max_context_len), in: CurrentSeqLen)
*
**/
class ov::pass::TotalSequenceLengthPatternQwen : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TotalSequenceLengthPattern", "0");
OPENVINO_MATCHER_PASS_RTTI("TotalSequenceLengthPattern", "0");
explicit TotalSequenceLengthPatternQwen(const std::shared_ptr<ov::op::v0::Parameter>& max_context_len);
};
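To make the quantities these two patterns manipulate concrete, a small worked example (the numbers are illustrative, not from the commit):

// Suppose a request has 7 tokens already in the KV cache and 3 new input tokens:
//   CurrentSeqLen = 3                      // Parameter(input_ids) -> ShapeOf -> Gather
//   max_context_len (TotalSeqLen) = 10     // past + current tokens
//   PrevSeqLen = max_context_len - CurrentSeqLen = 10 - 3 = 7   // the Subtract subgraph above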
@@ -7,9 +7,12 @@
#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
@@ -49,14 +52,27 @@ ov::pass::PositionIDsReplacer::PositionIDsReplacer(const Output<Node>& position_
ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& position_ids) {
MATCHER_SCOPE(PositionIDsReplacerQwen);

+ auto _const = []() {
+     return wrap_type<v0::Constant>();
+ };

+ // total seq len:
auto p_max_context_len = wrap_type<v0::Parameter>();
auto p_opt_convert = optional<v0::Convert>(p_max_context_len);
auto p_opt_reshape = optional<v1::Reshape>({p_opt_convert, any_input()});

+ // current seq len:
+ auto p_input_ids = wrap_type<v0::Parameter>();
+ auto p_unsqueeze = wrap_type<v0::Unsqueeze>({p_input_ids, _const()});
+ auto p_shape_of = wrap_type<v3::ShapeOf>({p_unsqueeze});
+ auto p_current_len = wrap_type<v8::Gather>({p_shape_of, _const(), _const()});

auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
+ auto p_neg_const = wrap_type<v0::Constant>();
+ auto p_neg_mul = wrap_type<v1::Multiply>({p_current_len, p_neg_const});
// the rotary_emb_cos/rotary_emb_sin are sliced by the total length (shape, e.g., [1, 4096, 1, 128])
- auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, any_input(), p_opt_reshape, any_input(), any_input()});
- auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, any_input(), any_input(), any_input(), any_input()});
+ auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()});
+ auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, p_neg_mul, _const(), _const(), _const()});

ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
@@ -71,15 +87,20 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& p
auto axis = v0::Constant::create(element::i64, Shape{}, {1});
// in case of PagedAttention (continuous batching) the rotary_emb_cos/rotary_emb_sin
// are not used in sequential order, so we need position_ids to get the expected values.
- auto gather = std::make_shared<v8::Gather>(slice_1, position_ids, axis);
+ auto gather = std::make_shared<v8::Gather>(slice_1->input_value(0), position_ids, axis);
gather->set_friendly_name(slice_2->get_friendly_name());
gather->validate_and_infer_types();

+ auto pshape = rotary_emb_sincos->get_output_partial_shape(0);
+ if (pshape.rank().is_dynamic() || pshape.rank().get_length() != 4) {
+     return false;
+ }

// PagedAttention expects the following layout for Q, K, V:
// [batch_size_in_tokens, num_kv_heads * head_size]
// so here we need to reshape the output tensor to move the seq dim (num tokens) to the batch;
// num_kv_heads * head_size is already handled in the StateManagementPattern transformation
- auto head_size = static_cast<int64_t>(rotary_emb_sincos->get_shape()[3]);
+ auto head_size = static_cast<int64_t>(pshape[3].get_length());
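// Illustrative example (assumed shapes): rotary_emb_sincos [1, 4096, 1, 128] and
// position_ids [n] -> Gather(axis=1) yields [1, n, 1, 128]; the Reshape below to
// {-1, 1, 1, head_size} yields [n, 1, 1, 128], moving the token (seq) dimension
// into the batch dimension.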
auto new_shape = v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{-1, 1, 1, head_size});
auto reshape = std::make_shared<v1::Reshape>(gather, new_shape, false);
replace_node(slice_2, reshape);
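For context, a typical way to run this matcher pass. This is a sketch under assumptions: the function name and include path are hypothetical, and `position_ids` is assumed to be the Parameter added to the model for continuous batching (not shown in this commit):

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"

// Hypothetical wiring of the pass into a pass pipeline.
void apply_qwen_position_ids(const std::shared_ptr<ov::Model>& model,
                             const std::shared_ptr<ov::op::v0::Parameter>& position_ids) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::PositionIDsReplacerQwen>(position_ids->output(0));
    manager.run_passes(model);
}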
@@ -14,8 +14,9 @@

using namespace ov::op;

- ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len,
-                                                                std::shared_ptr<ov::Node> batch_dim) {
+ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(const std::shared_ptr<ov::Node>& unsqueezed_input_ids,
+                                                                const std::shared_ptr<ov::Node>& max_context_len,
+                                                                const std::shared_ptr<ov::Node>& position_ids) {
MATCHER_SCOPE(PrevSequenceLengthPattern);
// The transformation addresses two cases that look similar: (1) the previous sequence length and (2) the batch size
// in the kv-cache state. In the first case it should be replaced with prev_max_seq_len; in the second, connected to batch_dim.
Expand All @@ -40,8 +41,16 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr<o
auto target_type = gather->get_output_element_type(0);
std::shared_ptr<ov::Node> replacement;
if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) {
+ auto cur_seq_len = std::make_shared<v8::Gather>(std::make_shared<v3::ShapeOf>(unsqueezed_input_ids),
+                                                 v0::Constant::create(element::i64, Shape{}, {1}),
+                                                 v0::Constant::create(element::i64, Shape{}, {0}));
+ auto cur_seq_len_i32 = std::make_shared<v0::Convert>(cur_seq_len, element::i32);
+ auto prev_max_seq_len = std::make_shared<v1::Subtract>(max_context_len, cur_seq_len_i32);
replacement = prev_max_seq_len;
} else {
+ // it is not always required, so it will be disposed of if not needed
+ auto batch_dim = std::make_shared<v3::ShapeOf>(position_ids);

// assumption: any other axis should point to the batch dimension; precise reasoning is too complex
// TODO: provide a more reliable check
replacement = batch_dim;
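A call site for the new three-argument signature would look like the following one-line sketch, assuming the same hypothetical `manager` as in the earlier snippet and input nodes (`unsqueezed_input_ids`, `max_context_len`, `position_ids`) created by the caller, which this commit does not show:

manager.register_pass<ov::pass::PrevSequenceLengthPattern>(unsqueezed_input_ids, max_context_len, position_ids);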
@@ -20,6 +20,22 @@
using namespace ov::op;
using namespace ov::pass::pattern;

+ namespace {
+
+ void align_replacement(std::shared_ptr<ov::Node>& replacement,
+                        const ov::PartialShape& required_shape,
+                        ov::element::Type target_type) {
+     if (replacement->get_output_element_type(0) != target_type) {
+         replacement = std::make_shared<v0::Convert>(replacement, target_type);
+     }
+
+     if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
+         replacement = ov::op::util::reshapeTo(replacement, ov::Shape(required_shape.rank().get_length(), 1));
+     }
+ }
+
+ }  // namespace

ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
const std::shared_ptr<ov::op::v0::Parameter>& max_context_len) {
MATCHER_SCOPE(TotalSequenceLengthPattern);
@@ -77,16 +93,8 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(

if (concat_axis_to_compare == gather_idx_to_compare) {
auto target_type = gather->get_output_element_type(0);

- if (replacement->get_output_element_type(0) != target_type) {
-     replacement = std::make_shared<v0::Convert>(replacement, target_type);
- }

auto required_shape = gather->get_output_partial_shape(0);

- if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
-     replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
- }
+ align_replacement(replacement, required_shape, target_type);
} else {
// TODO: change in the future when we start supporting dynamic shapes here
replacement = ov::util::get_constant_from_source(gather->output(0));
@@ -128,19 +136,13 @@ ov::pass::TotalSequenceLengthPatternQwen::TotalSequenceLengthPatternQwen(
std::shared_ptr<Node> replacement = max_context_len;

auto target_type = total_seq->get_output_element_type(0);
- if (replacement->get_output_element_type(0) != target_type) {
-     replacement = std::make_shared<v0::Convert>(replacement, target_type);
- }

auto required_shape = total_seq->get_output_partial_shape(0);
- if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
-     replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
- }
+ align_replacement(replacement, required_shape, target_type);

replace_node(total_seq, replacement);
return true;
};

auto m = std::make_shared<Matcher>(p_total_seq, matcher_name);
register_matcher(m, callback);
}
}
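A quick illustration of what the extracted align_replacement helper does, under assumed types and shapes (the values below are hypothetical):

// Assume `replacement` produces i64 with shape [] (a scalar), while the node being
// replaced produced i32 with a static rank-2 shape. align_replacement then rewrites it as:
//   Convert(i64 -> i32)        // element type alignment to the target type
//   Reshape to Shape{1, 1}     // rank alignment: an all-ones shape of the required rank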