Skip to content

Commit

Permalink
[GPU] Support rope for glm4v (openvinotoolkit#27545)
Browse files Browse the repository at this point in the history
support rope kernel for glm4v

**Tickets:** CVS-157422

---------

Co-authored-by: Chen Peter <[email protected]>
Co-authored-by: Xiake Sun <[email protected]>
  • Loading branch information
3 people authored Nov 22, 2024
1 parent c542f21 commit edcd4d8
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
auto cat_Concat_505 =
makePattern<opset1::Concat>({flatten_Reshape_501, slice_Slice_443 | var_split_1->output(1)}, {{"axis", -1}});

auto result = cat_Concat_505;
auto result = cat_Concat_505 | flatten_Reshape_501;

matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand Down Expand Up @@ -577,6 +577,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
config.slice_stop = static_cast<size_t>(config.slice_start + validator["total_size_k"]);
}

if (ov::is_type<opset1::Reshape>(root)) {
if (config.rotary_ndims != config.head_size)
return false;
}

new_args.push_back(pattern_map.at(qkv_linear));
new_args.push_back(pattern_map.at(cos_sin_cache));
new_args.push_back(pattern_map.at(cos_sin_cache));
Expand All @@ -585,9 +590,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s

auto new_node = std::make_shared<op::internal::RoPE>(new_args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::copy_runtime_info({pattern_map.at(flatten_Reshape_501).get_node_shared_ptr(),
pattern_map.at(cat_Concat_505).get_node_shared_ptr()},
new_node);
ov::copy_runtime_info({root->get_input_node_shared_ptr(0), root}, new_node);
ov::replace_node(old_node, new_node);
return true;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -854,4 +854,112 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML_2d_rope) {
model_ref =
std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin_cache, position_ids});
}
}

TEST_F(TransformationTestsF, ConvertToROPE_chatGML_nano_2d_rope) {
disable_rt_info_check();
const int batch = 2;
const int seq_len = 7;
const int num_heads = 16;
const int ndims = 128;
const int rotary_ndims = 128;
const int max_pos_length = 2048;
{
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{batch, seq_len, 3072});
auto cos_sin_cache =
std::make_shared<ov::opset1::Parameter>(ov::element::f32,
ov::PartialShape{max_pos_length, (rotary_ndims / 2), 2});
auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::PartialShape{batch, seq_len});

auto __module_transformer_index_67_Gather =
makeOP<ov::opset8::Gather>({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}});

auto ListUnpack_321 = makeOP<ov::opset1::VariadicSplit>({input, -1, {2048, 512, 512}});
auto view_Reshape = makeOP<ov::opset1::Reshape>({ListUnpack_321->output(0), {0, 0, num_heads, ndims}},
{{"special_zero", true}});

auto permute_Transpose = makeOP<ov::opset1::Transpose>({view_Reshape, {0, 2, 1, 3}}, {});

auto slice_Slice_357 =
makeOP<ov::opset1::StridedSlice>({permute_Transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});

auto aten_view_Reshape_1 =
makeOP<ov::opset1::Reshape>({ListUnpack_321->output(1), {0, 0, 2, ndims}}, {{"special_zero", true}});
auto aten_transpose_1 = makeOP<ov::opset8::Transpose>({aten_view_Reshape_1, {0, 2, 1, 3}});
auto shape_of_105249 = makeOP<ov::opset8::ShapeOf>({aten_transpose_1}, {{"output_type", "i32"}});
auto gather_105252 = makeOP<ov::opset8::Gather>({shape_of_105249, {2}, {0}}, {{"batch_dims", 0}});
auto scatter_update_63441 = makeOP<ov::opset8::ScatterUpdate>({{0, 0}, {1}, gather_105252, {0}});
// connected to cos_sin_cache
auto slice_Slice_369 = makeOP<ov::opset1::StridedSlice>(
{__module_transformer_index_67_Gather, {0, 0}, scatter_update_63441, {1, 1}},
{{"begin_mask", {1, 0}},
{"end_mask", {1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto list_construct_concat_1 =
makeOP<ov::opset1::Concat>({{-1}, {1}, gather_105252, {rotary_ndims / 2}, {2}}, {{"axis", 0}});

auto reshape_Reshape_373 =
makeOP<ov::opset1::Reshape>({slice_Slice_357, {0, 16, 0, 64, 2}}, {{"special_zero", true}});
auto select_Gather_384 =
makeOP<ov::opset8::Gather>({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}}); // x_even
auto select_Gather_381 =
makeOP<ov::opset8::Gather>({reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}}); // x_odd
auto view_Reshape_380 =
makeOP<ov::opset1::Reshape>({slice_Slice_369, list_construct_concat_1}, {{"special_zero", false}});
auto select_Gather_385 = makeOP<ov::opset8::Gather>({view_Reshape_380, 0, -1}, {{"batch_dims", 0}}); // cos_tab
auto select_Gather_382 = makeOP<ov::opset8::Gather>({view_Reshape_380, 1, -1}, {{"batch_dims", 0}}); // sin_tab

auto mul_Multiply_386 = makeOP<ov::opset1::Multiply>({select_Gather_381, select_Gather_382},
{{"auto_broadcast", "numpy"}}); // x_odd_sin
auto mul_Multiply_383 = makeOP<ov::opset1::Multiply>({select_Gather_384, select_Gather_385},
{{"auto_broadcast", "numpy"}}); // x_even_cos
auto Multiply_101315 =
makeOP<ov::opset1::Multiply>({mul_Multiply_386, -1.000000f}, {{"auto_broadcast", "numpy"}});
auto sub_Subtract_389 =
makeOP<ov::opset1::Add>({mul_Multiply_383, Multiply_101315}, {{"auto_broadcast", "numpy"}});

auto mul_Multiply_391 = makeOP<ov::opset1::Multiply>({select_Gather_381, select_Gather_385},
{{"auto_broadcast", "numpy"}}); // x_odd_cos
auto mul_Multiply_393 = makeOP<ov::opset1::Multiply>({select_Gather_384, select_Gather_382},
{{"auto_broadcast", "numpy"}}); // x_even_sin
auto add_Add_396 = makeOP<ov::opset1::Add>({mul_Multiply_391, mul_Multiply_393}, {{"auto_broadcast", "numpy"}});

auto Unsqueeze_62716 = makeOP<ov::opset1::Unsqueeze>({sub_Subtract_389, -1}, {});
auto Unsqueeze_62717 = makeOP<ov::opset1::Unsqueeze>({add_Add_396, -1}, {});

auto stack_401 = makeOP<ov::opset1::Concat>({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}});
auto flatten_Reshape_421 =
makeOP<ov::opset1::Reshape>({stack_401, {0, num_heads, 0, rotary_ndims}}, {{"special_zero", true}});
model = std::make_shared<ov::Model>(ov::NodeVector{flatten_Reshape_421},
ov::ParameterVector{input, cos_sin_cache, position_ids});
}
manager.register_pass<ov::pass::RoPEFusion>(true);
{
auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{batch, seq_len, 3072});
auto cos_sin_cache =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{max_pos_length, (rotary_ndims / 2), 2});
auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i32, ov::PartialShape{batch, seq_len});
auto gather_cos_sin = makeOP<ov::opset8::Gather>({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}});
auto rope = makeOP<ov::op::internal::RoPE>({input, gather_cos_sin, gather_cos_sin},
{{"config.slice_start", 0},
{"config.slice_stop", 2048},
{"config.input_trans0213", false},
{"config.is_interleaved", false},
{"config.rotary_ndims", rotary_ndims},
{"config.is_chatglm", true},
{"config.support_2d_rope", true},
{"config.is_qwen", false},
{"config.head_cnt", num_heads},
{"config.head_size", ndims},
{"config.gather_position_arg_id", 0}});
model_ref =
std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin_cache, position_ids});
}
}

0 comments on commit edcd4d8

Please sign in to comment.