From 0d504ccf348195e0e286f6743ddbff68cb13218b Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Thu, 22 Aug 2024 09:30:22 +0200 Subject: [PATCH] op definition uses ov::RMS instead of RMSNorm --- .../include/ov_ops/opset_private_tbl.hpp | 1 + .../common_optimizations/rms_fusion.cpp | 5 +- src/plugins/intel_cpu/src/cpu_types.cpp | 4 +- src/plugins/intel_cpu/src/cpu_types.h | 2 +- .../src/nodes/kernels/x64/rms_kernel.cpp | 8 +- .../src/nodes/kernels/x64/rms_kernel.hpp | 1 - src/plugins/intel_cpu/src/nodes/rms_norm.cpp | 75 +++++++------------ src/plugins/intel_cpu/src/nodes/rms_norm.h | 3 +- src/plugins/intel_cpu/src/nodes_factory.cpp | 2 +- .../common/pass/decompose_rms_norm.cpp | 52 +++++++++++++ .../common/pass/decompose_rms_norm.hpp | 19 +++++ .../transformation_pipeline.cpp | 5 +- .../single_layer_tests/classes/rms_norm.cpp | 27 +++---- .../single_layer_tests/classes/rms_norm.hpp | 1 + .../instances/x64/rms_norm.cpp | 12 ++- .../src/base/utils/compare_results.cpp | 4 +- .../src/base/utils/generate_inputs.cpp | 4 +- 17 files changed, 138 insertions(+), 87 deletions(-) create mode 100644 src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.cpp create mode 100644 src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.hpp diff --git a/src/common/transformations/include/ov_ops/opset_private_tbl.hpp b/src/common/transformations/include/ov_ops/opset_private_tbl.hpp index e592b1e40ebc63..30dbfb7519c621 100644 --- a/src/common/transformations/include/ov_ops/opset_private_tbl.hpp +++ b/src/common/transformations/include/ov_ops/opset_private_tbl.hpp @@ -9,3 +9,4 @@ _OPENVINO_OP_REG(AUGRUCell, ov::op::internal) _OPENVINO_OP_REG(AUGRUSequence, ov::op::internal) +_OPENVINO_OP_REG(RMS, ov::op::internal) diff --git a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp index 3aaebaf39234ef..7bd50daea94c88 100644 --- a/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/rms_fusion.cpp @@ -13,6 +13,7 @@ #include "openvino/op/reduce_mean.hpp" #include "openvino/op/sqrt.hpp" #include "openvino/pass/manager.hpp" +#include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "ov_ops/rms.hpp" #include "transformations/utils/utils.hpp" @@ -68,7 +69,9 @@ RMSFusion::RMSFusion() { auto mul2 = wrap_type({gamma, mul1}); // compress RMS result - auto comp = wrap_type({mul2}); + auto convert = wrap_type({mul2}); + + auto comp = std::make_shared(OutputVector{mul2, convert}); ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index c3eadf7228433f..8ef12a7e374155 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -247,7 +247,7 @@ static const TypeToNameMap& get_type_to_name_tbl() { {"EmbeddingBagOffsets", Type::EmbeddingBagOffsets}, {"LLMMLP", Type::LLMMLP}, {"QKVProjection", Type::QKVProjection}, - {"RMSNorm", Type::RMSNorm} + {"RMS", Type::RMS} }; return type_to_name_tbl; } @@ -374,7 +374,7 @@ std::string NameFromType(const Type type) { CASE(CausalMaskPreprocess); CASE(LLMMLP); CASE(QKVProjection); - CASE(RMSNorm); + CASE(RMS); CASE(Unknown); } #undef CASE diff --git a/src/plugins/intel_cpu/src/cpu_types.h b/src/plugins/intel_cpu/src/cpu_types.h index 0d2c9ba25abd00..d4eb431b62f8a3 100644 --- a/src/plugins/intel_cpu/src/cpu_types.h +++ b/src/plugins/intel_cpu/src/cpu_types.h @@ -127,7 +127,7 @@ enum class Type { CausalMaskPreprocess, LLMMLP, QKVProjection, - RMSNorm + RMS }; enum class Algorithm { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp index 9b1995fbc2b535..be6ebfbd80ad8f 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp @@ -167,7 +167,7 @@ void jit_rms_kernel::generate() { vrsqrtss(xmm_rsqrt, xmm_rsqrt, xmm_rsqrt); // x * rsqrt(mean(x^2)+eps) - if (m_jcp.has_scale && m_jcp.scale_size == 1) { + if (m_jcp.scale_size == 1) { // rsqrt(mean(x^2)+eps) vmovd(xmm_tmp, ptr[reg_scale]); vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp); @@ -181,14 +181,14 @@ void jit_rms_kernel::generate() { { load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false); vmulps(vmm_src, vmm_src, vmm_rsqrt); - if (m_jcp.has_scale && m_jcp.scale_size != 1) { + if (m_jcp.scale_size != 1) { load(vmm_tmp, reg_scale, ov::element::f32, vec_size, false); vmulps(vmm_src, vmm_src, vmm_tmp); } store(reg_dst, vmm_src, m_jcp.dst_prc, vec_size); add(reg_src, vec_size * m_jcp.src_prc.size()); - if (m_jcp.has_scale && m_jcp.scale_size != 1) { + if (m_jcp.scale_size != 1) { add(reg_scale, vec_size * sizeof(float)); } add(reg_dst, vec_size * m_jcp.dst_prc.size()); @@ -199,7 +199,7 @@ void jit_rms_kernel::generate() { if (m_jcp.data_size % vec_size) { load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false); vmulps(vmm_src, vmm_src, vmm_rsqrt); - if (m_jcp.has_scale && m_jcp.scale_size != 1) { + if (m_jcp.scale_size != 1) { load(vmm_tmp, reg_scale, ov::element::f32, m_jcp.data_size % vec_size, false); vmulps(vmm_src, vmm_src, vmm_tmp); } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp index 3ef5607a705c8f..40379a725905b7 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp @@ -19,7 +19,6 @@ struct jit_rms_compile_params { ov::element::Type dst_prc; size_t data_size; float eps; - bool has_scale; size_t scale_size; }; diff --git a/src/plugins/intel_cpu/src/nodes/rms_norm.cpp b/src/plugins/intel_cpu/src/nodes/rms_norm.cpp index bfde8301f7f5c7..0a9febd29873e5 100644 --- a/src/plugins/intel_cpu/src/nodes/rms_norm.cpp +++ b/src/plugins/intel_cpu/src/nodes/rms_norm.cpp @@ -13,8 +13,8 @@ #include "onednn/dnnl.h" #include "openvino/core/parallel.hpp" #include "openvino/util/common_util.hpp" +#include "ov_ops/rms.hpp" #include "shape_inference/custom/rms_norm.hpp" -#include "openvino/op/rms_norm.hpp" #include "openvino/opsets/opset6.hpp" #include "kernels/x64/rms_kernel.hpp" @@ -97,22 +97,19 @@ static void execJitKernel(const std::shared_ptr& ker, con } struct RMSNorm::RMSNormExecutor : public RMSNorm::Executor { - RMSNormExecutor(ov::element::Type precision, size_t data_size, size_t scale_size, float eps, bool has_scale) : m_precision(precision) { + RMSNormExecutor(ov::element::Type precision, size_t data_size, size_t scale_size, float eps) : m_precision(precision) { jit_rms_compile_params jcp; jcp.src_prc = precision; jcp.dst_prc = precision; jcp.data_size = data_size; jcp.scale_size = scale_size; jcp.eps = eps; - jcp.has_scale = has_scale; m_kernel = createJitKernel(jcp); } void execute(const std::vector& inputs, const MemoryPtr output) override { auto src = inputs[0]->getDataAs(); auto dst = output->getDataAs(); - float* scale = nullptr; - if (inputs.size() > 2) - scale = inputs[2]->getDataAs(); + float* scale = inputs[1]->getDataAs(); const auto& src_strides = inputs[0]->getDescWithType()->getStrides(); const auto& dst_strides = output->getDescWithType()->getStrides(); @@ -136,9 +133,8 @@ RMSNorm::RMSNorm(const std::shared_ptr& op, const GraphContext::CPtr c if (!isSupportedOperation(op, errorMessage)) { OPENVINO_THROW("CPU: " + errorMessage); } - const auto rms = std::dynamic_pointer_cast(op); + const auto rms = std::dynamic_pointer_cast(op); m_eps = static_cast(rms->get_epsilon()); - m_has_scale = op->get_input_size() > 2; } void RMSNorm::initSupportedPrimitiveDescriptors() { @@ -151,38 +147,26 @@ void RMSNorm::initSupportedPrimitiveDescriptors() { impl_type = impl_desc_type::jit_avx512; } else if (mayiuse(cpu::x64::avx2)) { impl_type = impl_desc_type::jit_avx2; - } else if (mayiuse(cpu::x64::sse41)) { - impl_type = impl_desc_type::jit_sse42; } else { impl_type = impl_desc_type::ref; } - if (m_has_scale) { - addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::i32}, {LayoutType::ncsp, ov::element::f32}}, - {{LayoutType::ncsp, precision}}, - impl_type); - } else { - addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::i32}}, - {{LayoutType::ncsp, precision}}, - impl_type); - } + addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::f32}}, + {{LayoutType::ncsp, precision}}, + impl_type); } void RMSNorm::createPrimitive() { auto precision = getOriginalInputPrecisionAtPort(0); auto data_dims = getSrcMemoryAtPort(0)->getDescWithType()->getBlockDims(); - auto has_scale = getOriginalInputsNumber() > 2; size_t data_size = data_dims[data_dims.size() - 1]; - size_t scale_size = 0; - if (has_scale) { - scale_size = getSrcMemoryAtPort(2)->getDescWithType()->getBlockDims()[0]; - } + size_t scale_size = shape_size(getSrcMemoryAtPort(1)->getDescWithType()->getBlockDims()); RMSNormKey key = {precision, data_size, scale_size, static_cast(dnnl::impl::float2int(m_eps))}; auto builder = [&](const RMSNormKey& key) -> std::shared_ptr { #ifdef OPENVINO_ARCH_X86_64 - return std::make_shared(precision, data_size, scale_size, m_eps, has_scale); + return std::make_shared(precision, data_size, scale_size, m_eps); #else return nullptr; #endif @@ -209,8 +193,12 @@ void RMSNorm::execute(dnnl::stream strm) { bool RMSNorm::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { - const auto rms = std::dynamic_pointer_cast(op); + const auto rms = std::dynamic_pointer_cast(op); if (rms) { + if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) { + errorMessage = "RMSNorm needs avx2+."; + return false; + } // check the last dimension of data auto data_pshape = op->input_value(0).get_partial_shape(); if (data_pshape.rank().is_dynamic()) { @@ -226,34 +214,21 @@ bool RMSNorm::isSupportedOperation(const std::shared_ptr& op, st errorMessage = "RMSNorm data rank must be greater than 1."; return false; } - // check axes - auto axes_op = ov::as_type_ptr(op->get_input_node_shared_ptr(1)); - if (!axes_op) { - errorMessage = "RMSNorm axes is expected as Constant."; - return false; - } - // axes should be 1d or scalar in spec - auto axes_vals = axes_op->cast_vector(); - if (axes_vals[0] != -1 && axes_vals[0] != data_rank - 1) { - errorMessage = "RMSNorm axes must be the last dimension."; - return false; - } // check scale - if (op->get_input_size() > 2) { - if (op->get_input_partial_shape(2).rank().get_length() > 1) { - errorMessage = "RMSNorm scale must be 1D or scalar."; - return false; - } - if (op->get_input_partial_shape(2).is_dynamic()) { - errorMessage = "RMSNorm scale shape is not static."; - return false; - } - } - if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) { - errorMessage = "RMSNorm needs avx2+."; + if (op->get_input_partial_shape(1).is_dynamic()) { + errorMessage = "RMSNorm scale shape is not static."; return false; } + auto scale_pshape = op->get_input_partial_shape(1); + if (scale_pshape.rank().get_length() > 1) { + for (int64_t i = 0; i < scale_pshape.rank().get_length() - 1; i++) { + if (scale_pshape[i] != 1) { + errorMessage = "RMSNorm scale shape must be [1,..., N]."; + return false; + } + } + } } else { errorMessage = "Only RMSNorm operation is supported"; return false; diff --git a/src/plugins/intel_cpu/src/nodes/rms_norm.h b/src/plugins/intel_cpu/src/nodes/rms_norm.h index a47285d0fd2fc6..0adc320a3c07fb 100644 --- a/src/plugins/intel_cpu/src/nodes/rms_norm.h +++ b/src/plugins/intel_cpu/src/nodes/rms_norm.h @@ -17,7 +17,7 @@ class RMSNorm : public Node { void getSupportedDescriptors() override {} bool created() const override { - return getType() == Type::RMSNorm; + return getType() == Type::RMS; } bool needPrepareParams() const override { return false; @@ -41,7 +41,6 @@ class RMSNorm : public Node { friend struct RMSNormKey; float m_eps = 0.0f; - bool m_has_scale = false; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp index 1a430541af0c76..e80447a372e2f1 100644 --- a/src/plugins/intel_cpu/src/nodes_factory.cpp +++ b/src/plugins/intel_cpu/src/nodes_factory.cpp @@ -219,7 +219,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") { INTEL_CPU_NODE(QKVProjection, Type::QKVProjection); INTEL_CPU_NODE(MHA, Type::MHA); INTEL_CPU_NODE(PagedAttention, Type::PagedAttention); - INTEL_CPU_NODE(RMSNorm, Type::RMSNorm); + INTEL_CPU_NODE(RMSNorm, Type::RMS); #endif } diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.cpp new file mode 100644 index 00000000000000..6f135139d9a148 --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2020-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "decompose_rms_norm.hpp" +#include "itt.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/core/rt_info.hpp" +#include "ov_ops/rms.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +namespace ov { +namespace intel_cpu { + +DecomposeRMSNorm::DecomposeRMSNorm() { + MATCHER_SCOPE(DecomposeRMSNorm); + auto pattern_node = ov::pass::pattern::wrap_type(); + + matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto node = std::dynamic_pointer_cast( + pattern_to_output.at(pattern_node).get_node_shared_ptr()); + + if (node == nullptr || transformation_callback(node)) { + return false; + } + auto data = node->get_input_node_shared_ptr(0); + auto data_precision = node->get_input_element_type(0); + auto scale = node->get_input_node_shared_ptr(1); + + auto power_const = ov::opset10::Constant::create(data_precision, {}, std::vector{2.f}); + auto power = std::make_shared(data, power_const); + auto mean_axes = ov::opset10::Constant::create(ov::element::i32, ov::Shape{1}, {-1}); + auto mean = std::make_shared(power, mean_axes, true); + auto eps = ov::opset10::Constant::create(data_precision, {}, {node->get_epsilon()}); + auto add_eps = std::make_shared(mean, eps); + auto sqrt = std::make_shared(add_eps); + auto div_const = ov::opset10::Constant::create(data_precision, {}, {-1}); + auto div = std::make_shared(sqrt, div_const); + auto mul1 = std::make_shared(data, div); + auto mul2 = std::make_shared(scale, mul1); + + ov::replace_node(node, mul2); + return true; + }; + + auto m = std::make_shared(pattern_node, matcher_name); + register_matcher(m, callback); +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.hpp new file mode 100644 index 00000000000000..33ad0346dd6c27 --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/decompose_rms_norm.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2020-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/pass/graph_rewrite.hpp" + +namespace ov { +namespace intel_cpu { + +class DecomposeRMSNorm: public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("DecomposeRMSNorm", "0"); + DecomposeRMSNorm(); +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index b106885143bba8..97d2855c8e2e60 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -6,7 +6,6 @@ #include "defs.hpp" // Operations -#include "openvino/op/constant.hpp" #include "openvino/opsets/opset1.hpp" #include "openvino/opsets/opset2.hpp" #include "openvino/opsets/opset3.hpp" @@ -130,6 +129,7 @@ #include "transformations/cpu_opset/arm/pass/mish_decomposition.hpp" #include "transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.hpp" #include "transformations/cpu_opset/common/pass/decompose_integer_divide.hpp" +#include "transformations/cpu_opset/common/pass/decompose_rms_norm.hpp" #include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp" #include "transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp" #include "transformations/cpu_opset/common/pass/ngram_fusion.hpp" @@ -856,12 +856,13 @@ void Transformations::PostLpt() { CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward); CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion); CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion); + CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm); CPU_SET_CALLBACK_X64(postLPTPassManager, [](const std::shared_ptr& node) -> bool { std::string errorMsg; return node::RMSNorm::isSupportedOperation(node, errorMsg); }, - ov::pass::RMSFusion); + ov::intel_cpu::DecomposeRMSNorm); // markup Rope Input when BF16/F16 inference. if (one_of(inferencePrecision, ov::element::bf16, ov::element::f16)) diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.cpp index 785e3092809403..5a720a0087e2a9 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.cpp @@ -7,7 +7,7 @@ #include "gtest/gtest.h" #include "openvino/core/shape.hpp" #include "openvino/op/constant.hpp" -#include "openvino/op/rms_norm.hpp" +#include "ov_ops/rms.hpp" #include "shared_test_classes/base/ov_subgraph.hpp" #include "utils/cpu_test_utils.hpp" #include "openvino/pass/manager.hpp" @@ -78,10 +78,15 @@ void RMSNormLayerCPUTest::generate_inputs(const std::vector& targetIn inputs.insert({param, t}); } }; - // q, k, v, pastkv create_input(function->get_parameters()[0], targetInputStaticShapes[0], 1.0f); - if (targetInputStaticShapes.size() > 1) - create_input(function->get_parameters()[1], targetInputStaticShapes[1], 0.0f); + create_input(function->get_parameters()[1], targetInputStaticShapes[1], 0.0f); + for (size_t i = 0; i < targetInputStaticShapes[1].size() - 1; i++) { + if (targetInputStaticShapes[1][i] != 1) { + // decomposed rms expected + m_rms_decomposed = true; + break; + } + } } void RMSNormLayerCPUTest::SetUp() { @@ -102,23 +107,19 @@ void RMSNormLayerCPUTest::SetUp() { selectedType = makeSelectedTypeStr(selectedType, inType); init_input_shapes(inputShapes); ov::ParameterVector inputParams; - // data, axes, scale + // data, scale auto data = std::make_shared(inType, inputDynamicShapes[0]); - auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, std::vector{-1}); inputParams.push_back(data); - std::shared_ptr scale; - if (inputDynamicShapes.size() > 1) { - scale = std::make_shared(inType, inputDynamicShapes[1]); - inputParams.push_back(scale); - } - auto rms = scale ? std::make_shared(data, axes, scale, 0.1f) : - std::make_shared(data, axes, 0.1f); + auto scale = std::make_shared(inType, inputDynamicShapes[1]); + inputParams.push_back(scale); + auto rms = std::make_shared(data, scale, 0.1f); rms->set_friendly_name("rms"); function = makeNgraphFunction(inType, inputParams, rms, "rms"); } TEST_P(RMSNormLayerCPUTest, CompareWithRefs) { run(); + CheckNumberOfNodesWithType(compiledModel, "RMS", m_rms_decomposed ? 0 : 1); } namespace RMSNorm { diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.hpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.hpp index da361f35dd7db5..2e64208c53f2b8 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.hpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.hpp @@ -25,6 +25,7 @@ class RMSNormLayerCPUTest : public testing::WithParamInterface& targetInputStaticShapes) override; + bool m_rms_decomposed = false; }; } // namespace test diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/rms_norm.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/rms_norm.cpp index 75f71f41c727f7..a7f4ea7090d5c2 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/rms_norm.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/rms_norm.cpp @@ -35,13 +35,17 @@ const std::vector> shapes{ {ov::Shape{1}, ov::Shape{1}}} }, }, - // no scale + // decomposition path { // data shape - {ov::test::InputShape{ov::PartialShape{-1, -1, 1094}, - {ov::Shape{1, 8, 1094}, ov::Shape{2, 3, 1094}}} + {ov::test::InputShape{ov::PartialShape{-1, -1, 64}, + {ov::Shape{1, 8, 64}}} }, - } + // scale shape + {ov::test::InputShape{ov::PartialShape{1, 8, 64}, + {ov::Shape{1, 8, 64}}} + }, + }, }; const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16, ElementType::f16), diff --git a/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp b/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp index 5c5684cab4f6ca..33a1f38b19db76 100644 --- a/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp +++ b/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp @@ -5,9 +5,9 @@ #include #include "openvino/op/ops.hpp" -#include "openvino/op/rms_norm.hpp" #include "ov_ops/augru_cell.hpp" #include "ov_ops/augru_sequence.hpp" +#include "ov_ops/rms.hpp" #include "shared_test_classes/base/utils/compare_results.hpp" #include @@ -209,8 +209,6 @@ OPENVINO_SUPPRESS_DEPRECATED_START #include "ov_ops/opset_private_tbl.hpp" -_OPENVINO_OP_REG(RMSNorm, ov::op::internal) - #undef _OPENVINO_OP_REG }; OPENVINO_SUPPRESS_DEPRECATED_END diff --git a/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp b/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp index 13c0674cfc86b0..e09a53c243a3e5 100644 --- a/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp +++ b/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp @@ -9,9 +9,9 @@ #include "shared_test_classes/base/utils/generate_inputs.hpp" #include "openvino/op/ops.hpp" -#include "openvino/op/rms_norm.hpp" #include "ov_ops/augru_cell.hpp" #include "ov_ops/augru_sequence.hpp" +#include "ov_ops/rms.hpp" #include "common_test_utils/ov_tensor_utils.hpp" #include "common_test_utils/data_utils.hpp" @@ -1022,8 +1022,6 @@ InputsMap getInputMap() { #include "ov_ops/opset_private_tbl.hpp" -_OPENVINO_OP_REG(RMSNorm, ov::op::internal) - #undef _OPENVINO_OP_REG }; return inputsMap;