diff --git a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp index 7bd50daea94c88..e84d66ad1bcc45 100644 --- a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp @@ -68,11 +68,6 @@ RMSFusion::RMSFusion() { auto gamma = wrap_type(type_matches(element::f32)); auto mul2 = wrap_type({gamma, mul1}); - // compress RMS result - auto convert = wrap_type({mul2}); - - auto comp = std::make_shared(OutputVector{mul2, convert}); - ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); auto node = m.get_match_root(); @@ -80,6 +75,18 @@ RMSFusion::RMSFusion() { return false; } + auto find_opt_convert = [&](const ov::Output& out) -> std::shared_ptr { + auto present_to = out.get_target_inputs(); + if (present_to.size() == 1) { + auto convert_raw = dynamic_cast(present_to.begin()->get_node()); + if (convert_raw) + return convert_raw->shared_from_this(); + return nullptr; + } + // if multiple children, skip finding convert even if there is one. 
+ return nullptr; + }; + auto x_output = pattern_map.at(x); auto const_eps_node = @@ -101,16 +108,22 @@ RMSFusion::RMSFusion() { return false; } - auto output_type = m.get_match_root()->get_output_element_type(0); + auto root = m.get_match_root(); + // compress RMS result + auto convert = find_opt_convert(root); + if (convert) + root = convert; + + auto output_type = root->get_output_element_type(0); auto rms = std::make_shared(x_output, gamma_node, eps_value, output_type); - rms->set_friendly_name(m.get_match_root()->get_friendly_name()); + rms->set_friendly_name(root->get_friendly_name()); ov::copy_runtime_info(m.get_matched_nodes(), rms); - ov::replace_node(m.get_match_root(), rms); + ov::replace_node(root, rms); return true; }; - auto m = std::make_shared(comp, "RMSFusion"); + auto m = std::make_shared(mul2, "RMSFusion"); this->register_matcher(m, callback); } diff --git a/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp b/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp index 579f5f56114dcf..888dc872a55b1d 100644 --- a/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp +++ b/src/common/transformations/tests/common_optimizations/rms_norm_decomposition_test.cpp @@ -159,3 +159,36 @@ TEST_F(TransformationTestsF, RMSNormFusionTest5) { model_ref = std::make_shared(ov::NodeVector{rms}, ov::ParameterVector{input}); } } + +// no convert at the end of the subgraph +TEST_F(TransformationTestsF, RMSNormFusionTest6) { + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 6}); + auto power_const = ov::opset10::Constant::create(ov::element::f32, {}, {2.f}); + auto power = std::make_shared(input, power_const); + auto mean_axes = ov::opset10::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); + auto mean = std::make_shared(power, mean_axes, true); + auto eps = ov::opset10::Constant::create(ov::element::f32, {}, {1e-5f}); + auto add_eps = 
std::make_shared(mean, eps); + auto sqrt = std::make_shared(add_eps); + auto div_const = ov::opset10::Constant::create(ov::element::f32, {}, {-1}); + auto div = std::make_shared(sqrt, div_const); + auto mul1 = std::make_shared(input, div); + auto gamma = ov::opset10::Constant::create(ov::element::f32, + ov::Shape{6}, + {0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f}); + auto mul2 = std::make_shared(gamma, mul1); + + model = std::make_shared(ov::NodeVector{mul2}, ov::ParameterVector{input}); + manager.register_pass(); + } + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, 6}); + auto rms_const = ov::opset10::Constant::create(ov::element::f32, + ov::Shape{6}, + {0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f}); + auto rms = std::make_shared(input, rms_const, 1e-5f); + + model_ref = std::make_shared(ov::NodeVector{rms}, ov::ParameterVector{input}); + } +} diff --git a/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp b/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp index 33a1f38b19db76..d438778eba8393 100644 --- a/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp +++ b/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp @@ -208,7 +208,6 @@ OPENVINO_SUPPRESS_DEPRECATED_START #include "openvino/opsets/opset15_tbl.hpp" #include "ov_ops/opset_private_tbl.hpp" - #undef _OPENVINO_OP_REG }; OPENVINO_SUPPRESS_DEPRECATED_END diff --git a/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp b/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp index e09a53c243a3e5..ae963375fc7f5d 100644 --- a/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp +++ b/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp @@ -1021,7 +1021,6 @@ InputsMap getInputMap() { #include "openvino/opsets/opset15_tbl.hpp" #include "ov_ops/opset_private_tbl.hpp" - #undef 
_OPENVINO_OP_REG }; return inputsMap;