Skip to content

Commit

Permalink
fix the test: apply changes from the current master
Browse files (browse the repository at this point in the history)
  • Loading branch information
itikhono committed Dec 20, 2024
1 parent 415a16a commit 6380475
Showing 1 changed file with 4 additions and 10 deletions.
Original file line number | Diff line number | Diff line change
@@ -380,8 +380,7 @@ class Qwen7bChatPA {
}

static std::shared_ptr<Node> gen_Q(const std::shared_ptr<Node>& total_seq_len,
const std::shared_ptr<Node>& rope_Q,
std::shared_ptr<Node>& scale) {
const std::shared_ptr<Node>& rope_Q) {
auto Constant_463 = makeConst(element::f32, {1, 32767, 1, 1}, MOCK_VALUE);
auto ShapeOf_489 = makeOP<opset3::ShapeOf>({rope_Q}, {{"output_type", "i32"}});
auto Gather_492 = makeOP<v8::Gather>({ShapeOf_489, {1}, 0ll}, {{"batch_dims", 0}});
@@ -390,12 +389,6 @@ class Qwen7bChatPA {
auto Multiply_631 = makeOP<v1::Multiply>({rope_Q, Slice_496}, {numpy_broadcast});
auto Transpose_633 = makeOP<v1::Transpose>({Multiply_631, {0, 2, 1, 3}});

auto ShapeOf_1238 = makeOP<opset3::ShapeOf>({Transpose_633}, {{"output_type", "i64"}});
auto Gather_1241 = makeOP<v8::Gather>({ShapeOf_1238, -1ll, 0ll}, {{"batch_dims", 0}});
auto Convert_1242 = makeOP<v0::Convert>({Gather_1241}, {dest_type_f32});
auto Sqrt_1243 = makeOP<v0::Sqrt>({Convert_1242});
scale = makeOP<v1::Divide>({1.000000f, Sqrt_1243}, {numpy_broadcast, {"m_pythondiv", true}});

auto Transpose_1223 = makeOP<v1::Transpose>({Transpose_633, {0, 2, 1, 3}});
return makeOP<v1::Reshape>({Transpose_1223, {0, -1}}, {special_zero_true});
}
@@ -506,15 +499,16 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) {
auto total_seq_len = Qwen7bChatPA::gen_total_len(current_seq_len, past_seq_len);

// Q, K, V:
shared_ptr<Node> scale;

shared_ptr<Node> head_size_2;
auto Q = Qwen7bChatPA::gen_Q(total_seq_len, rope_Q, scale);
auto Q = Qwen7bChatPA::gen_Q(total_seq_len, rope_Q);
auto K = Qwen7bChatPA::gen_K(rope_K);
auto V = Qwen7bChatPA::gen_V(qkv_proj, head_size_2);

// Additional PA arguments:
auto sliding_window = std::make_shared<v0::Constant>(element::i32, Shape{}, 0);
auto alibi_slopes = std::make_shared<v0::Constant>(element::f32, Shape{0});
auto scale = std::make_shared<v0::Constant>(element::f32, Shape{}, MOCK_VALUE);

// PagedAttention:
auto pa = std::make_shared<op::PagedAttentionExtension>(OutputVector{Q,
Expand Down

0 comments on commit 6380475

Please sign in to comment.