diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 143603f0415373..f002e0043a8744 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -549,7 +549,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s auto cat_Concat_505 = makePattern({flatten_Reshape_501, slice_Slice_443 | var_split_1->output(1)}, {{"axis", -1}}); - auto result = cat_Concat_505; + auto result = cat_Concat_505 | flatten_Reshape_501; matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); @@ -577,6 +577,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s config.slice_stop = static_cast(config.slice_start + validator["total_size_k"]); } + if (ov::is_type(root)) { + if (config.rotary_ndims != config.head_size) + return false; + } + new_args.push_back(pattern_map.at(qkv_linear)); new_args.push_back(pattern_map.at(cos_sin_cache)); new_args.push_back(pattern_map.at(cos_sin_cache)); @@ -585,9 +590,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s auto new_node = std::make_shared(new_args, config); new_node->set_friendly_name(old_node->get_friendly_name()); - ov::copy_runtime_info({pattern_map.at(flatten_Reshape_501).get_node_shared_ptr(), - pattern_map.at(cat_Concat_505).get_node_shared_ptr()}, - new_node); + ov::copy_runtime_info({root->get_input_node_shared_ptr(0), root}, new_node); ov::replace_node(old_node, new_node); return true; }; diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index 6eb0add525c815..ea928de5c01702 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -854,4 +854,112 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML_2d_rope) { model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin_cache, position_ids}); } +} + +TEST_F(TransformationTestsF, ConvertToROPE_chatGML_nano_2d_rope) { + disable_rt_info_check(); + const int batch = 2; + const int seq_len = 7; + const int num_heads = 16; + const int ndims = 128; + const int rotary_ndims = 128; + const int max_pos_length = 2048; + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{batch, seq_len, 3072}); + auto cos_sin_cache = + std::make_shared(ov::element::f32, + ov::PartialShape{max_pos_length, (rotary_ndims / 2), 2}); + auto position_ids = std::make_shared(ov::element::i32, ov::PartialShape{batch, seq_len}); + + auto __module_transformer_index_67_Gather = + makeOP({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}}); + + auto ListUnpack_321 = makeOP({input, -1, {2048, 512, 512}}); + auto view_Reshape = makeOP({ListUnpack_321->output(0), {0, 0, num_heads, ndims}}, + {{"special_zero", true}}); + + auto permute_Transpose = makeOP({view_Reshape, {0, 2, 1, 3}}, {}); + + auto slice_Slice_357 = + makeOP({permute_Transpose, {0, 0, 0, 0}, {0, 0, 0, rotary_ndims}, {1, 1, 1, 1}}, + {{"begin_mask", {1, 1, 1, 0}}, + {"end_mask", {1, 1, 1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + + auto aten_view_Reshape_1 = + makeOP({ListUnpack_321->output(1), {0, 0, 2, ndims}}, {{"special_zero", true}}); + auto aten_transpose_1 = makeOP({aten_view_Reshape_1, {0, 2, 1, 3}}); + auto shape_of_105249 = makeOP({aten_transpose_1}, {{"output_type", "i32"}}); + auto gather_105252 = makeOP({shape_of_105249, {2}, {0}}, {{"batch_dims", 0}}); + auto scatter_update_63441 = makeOP({{0, 0}, {1}, gather_105252, {0}}); + // connected to cos_sin_cache + auto slice_Slice_369 = makeOP( + {__module_transformer_index_67_Gather, {0, 0}, scatter_update_63441, {1, 1}}, + {{"begin_mask", {1, 0}}, + {"end_mask", {1, 0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto list_construct_concat_1 = + makeOP({{-1}, {1}, gather_105252, {rotary_ndims / 2}, {2}}, {{"axis", 0}}); + + auto reshape_Reshape_373 = + makeOP({slice_Slice_357, {0, 16, 0, 64, 2}}, {{"special_zero", true}}); + auto select_Gather_384 = + makeOP({reshape_Reshape_373, 0, -1}, {{"batch_dims", 0}}); // x_even + auto select_Gather_381 = + makeOP({reshape_Reshape_373, 1, -1}, {{"batch_dims", 0}}); // x_odd + auto view_Reshape_380 = + makeOP({slice_Slice_369, list_construct_concat_1}, {{"special_zero", false}}); + auto select_Gather_385 = makeOP({view_Reshape_380, 0, -1}, {{"batch_dims", 0}}); // cos_tab + auto select_Gather_382 = makeOP({view_Reshape_380, 1, -1}, {{"batch_dims", 0}}); // sin_tab + + auto mul_Multiply_386 = makeOP({select_Gather_381, select_Gather_382}, + {{"auto_broadcast", "numpy"}}); // x_odd_sin + auto mul_Multiply_383 = makeOP({select_Gather_384, select_Gather_385}, + {{"auto_broadcast", "numpy"}}); // x_even_cos + auto Multiply_101315 = + makeOP({mul_Multiply_386, -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto sub_Subtract_389 = + makeOP({mul_Multiply_383, Multiply_101315}, {{"auto_broadcast", "numpy"}}); + + auto mul_Multiply_391 = makeOP({select_Gather_381, select_Gather_385}, + {{"auto_broadcast", "numpy"}}); // x_odd_cos + auto mul_Multiply_393 = makeOP({select_Gather_384, select_Gather_382}, + {{"auto_broadcast", "numpy"}}); // x_even_sin + auto add_Add_396 = makeOP({mul_Multiply_391, mul_Multiply_393}, {{"auto_broadcast", "numpy"}}); + + auto Unsqueeze_62716 = makeOP({sub_Subtract_389, -1}, {}); + auto Unsqueeze_62717 = makeOP({add_Add_396, -1}, {}); + + auto stack_401 = makeOP({Unsqueeze_62716, Unsqueeze_62717}, {{"axis", -1}}); + auto flatten_Reshape_421 = + makeOP({stack_401, {0, num_heads, 0, rotary_ndims}}, {{"special_zero", true}}); + model = std::make_shared(ov::NodeVector{flatten_Reshape_421}, + ov::ParameterVector{input, cos_sin_cache, position_ids}); + } + manager.register_pass(true); + { + auto input = std::make_shared(ov::element::f32, ov::Shape{batch, seq_len, 3072}); + auto cos_sin_cache = + std::make_shared(ov::element::f32, ov::Shape{max_pos_length, (rotary_ndims / 2), 2}); + auto position_ids = std::make_shared(ov::element::i32, ov::PartialShape{batch, seq_len}); + auto gather_cos_sin = makeOP({cos_sin_cache, position_ids, 0}, {{"batch_dims", 0}}); + auto rope = makeOP({input, gather_cos_sin, gather_cos_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", 2048}, + {"config.input_trans0213", false}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", rotary_ndims}, + {"config.is_chatglm", true}, + {"config.support_2d_rope", true}, + {"config.is_qwen", false}, + {"config.head_cnt", num_heads}, + {"config.head_size", ndims}, + {"config.gather_position_arg_id", 0}}); + model_ref = + std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin_cache, position_ids}); + } } \ No newline at end of file