diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp
index cd6021a40a5a8c..c3eadf7228433f 100644
--- a/src/plugins/intel_cpu/src/cpu_types.cpp
+++ b/src/plugins/intel_cpu/src/cpu_types.cpp
@@ -247,6 +247,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
         {"EmbeddingBagOffsets", Type::EmbeddingBagOffsets},
         {"LLMMLP", Type::LLMMLP},
         {"QKVProjection", Type::QKVProjection},
+        {"RMSNorm", Type::RMSNorm}
     };
     return type_to_name_tbl;
 }
@@ -373,6 +374,7 @@ std::string NameFromType(const Type type) {
     CASE(CausalMaskPreprocess);
     CASE(LLMMLP);
     CASE(QKVProjection);
+    CASE(RMSNorm);
     CASE(Unknown);
     }
 #undef CASE
diff --git a/src/plugins/intel_cpu/src/cpu_types.h b/src/plugins/intel_cpu/src/cpu_types.h
index 6834225c1f2515..0d2c9ba25abd00 100644
--- a/src/plugins/intel_cpu/src/cpu_types.h
+++ b/src/plugins/intel_cpu/src/cpu_types.h
@@ -127,6 +127,7 @@ enum class Type {
     CausalMaskPreprocess,
     LLMMLP,
     QKVProjection,
+    RMSNorm
 };
 
 enum class Algorithm {
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp
new file mode 100644
index 00000000000000..9b1995fbc2b535
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp
@@ -0,0 +1,241 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "rms_kernel.hpp"
+
+using namespace dnnl::impl::cpu::x64;
+using namespace Xbyak;
+
+namespace ov {
+namespace intel_cpu {
+namespace kernel {
+
+#define GET_OFF(field) offsetof(jit_rms_call_args, field)
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::reduce_zmm_to_ymm(const Xmm& acc, const Xmm& tmp) {
+    const Zmm zmm_acc(acc.getIdx());
+    const Ymm ymm_acc(acc.getIdx());
+    const Ymm ymm_to_acc(tmp.getIdx());
+    vextractf64x4(ymm_to_acc, zmm_acc, 1);
+    vaddps(ymm_acc, ymm_acc, ymm_to_acc);
+}
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::reduce_ymm_to_xmm(const Xmm& acc, const Xmm& tmp) {
+    const Ymm ymm_acc(acc.getIdx());
+    const Xmm xmm_acc(acc.getIdx());
+    const Xmm xmm_to_acc(tmp.getIdx());
+    vextractf128(xmm_to_acc, ymm_acc, 1);
+    vaddps(xmm_acc, xmm_acc, xmm_to_acc);
+}
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::reduce_xmm_to_scalar(const Xmm& acc, const Xmm& tmp,
+                                               const std::size_t number_of_values_to_reduce) {
+    assert(number_of_values_to_reduce <= number_of_f32_in_xmm_);
+
+    const Xmm xmm_acc(acc.getIdx());
+    const Xmm xmm_to_acc(tmp.getIdx());
+
+    static constexpr int number_of_f32_to_move = number_of_f32_in_xmm_ - 1;
+    static constexpr uint8_t insertps_configuration[number_of_f32_to_move] = {0b01001110, 0b10001110, 0b11001110};
+
+    for (std::size_t i = 0; i < number_of_values_to_reduce - 1; i++) {
+        vinsertps(xmm_to_acc, xmm_to_acc, xmm_acc, insertps_configuration[i]);
+        vaddss(xmm_acc, xmm_acc, xmm_to_acc);
+    }
+}
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::reduce_ymm_to_scalar(const Xbyak::Xmm& acc, const Xbyak::Xmm& tmp1, const Xbyak::Xmm& tmp2,
+                                               const std::size_t number_of_values_to_reduce) {
+    assert(number_of_values_to_reduce <= number_of_f32_in_ymm_);
+
+    const Ymm ymm_acc(acc.getIdx());
+    const Xmm xmm_acc(acc.getIdx());
+    const Xmm xmm_tmp(tmp1.getIdx());
+    const Xmm xmm_acc_upper_half(tmp2.getIdx());
+
+    if (number_of_values_to_reduce == number_of_f32_in_ymm_) {
+        reduce_ymm_to_xmm(ymm_acc, xmm_tmp);
+        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
+    } else if (number_of_values_to_reduce > number_of_f32_in_xmm_) {
+        vextractf128(xmm_acc_upper_half, ymm_acc, 1);
+        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
+        reduce_xmm_to_scalar(xmm_acc_upper_half, xmm_tmp, number_of_values_to_reduce - number_of_f32_in_xmm_);
+        vaddss(xmm_acc, xmm_acc, xmm_acc_upper_half);
+    } else if (number_of_values_to_reduce <= number_of_f32_in_xmm_) {
+        reduce_xmm_to_scalar(xmm_acc, xmm_tmp, number_of_values_to_reduce);
+    }
+}
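+
+// Note on the reduction helpers (adapted from oneDNN, see the "from onednn" marker in
+// rms_kernel.hpp): the accumulator is folded in halves (zmm -> ymm -> xmm) and the last
+// xmm is reduced lane by lane. reduce_xmm_to_scalar() relies on the VINSERTPS immediate
+// encoding: bits [7:6] select the source lane, bits [5:4] the destination lane, and
+// bits [3:0] zero-mask lanes, so e.g. 0b01001110 copies lane 1 into lane 0 of the tmp
+// register and zeroes lanes 1..3 before VADDSS accumulates it into lane 0 of acc.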
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::reduce_vmm_to_scalar(const Xbyak::Xmm& acc, const Xbyak::Xmm& tmp1, const Xbyak::Xmm& tmp2,
+                                               const Xbyak::Xmm& tmp3, const std::size_t number_of_values_to_reduce) {
+    assert(number_of_values_to_reduce <= number_of_f32_in_zmm_);
+
+    const Zmm zmm_acc(acc.getIdx());
+    const Ymm ymm_acc(acc.getIdx());
+    const Xmm xmm_acc(acc.getIdx());
+    const Ymm ymm_acc_upper_half(tmp1.getIdx());
+    const Xmm xmm_acc_upper_half(tmp1.getIdx());
+    const Ymm ymm_tmp(tmp2.getIdx());
+    const Xmm xmm_tmp1(tmp2.getIdx());
+    const Xmm xmm_tmp2(tmp3.getIdx());
+
+    if (number_of_values_to_reduce == number_of_f32_in_zmm_) {
+        reduce_zmm_to_ymm(zmm_acc, ymm_tmp);
+        reduce_ymm_to_xmm(ymm_acc, xmm_tmp1);
+        reduce_xmm_to_scalar(xmm_acc, xmm_tmp1);
+    } else if (number_of_values_to_reduce > number_of_f32_in_ymm_) {
+        vextractf64x4(ymm_acc_upper_half, zmm_acc, 1);
+        reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2);
+        reduce_ymm_to_scalar(ymm_acc_upper_half, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce - number_of_f32_in_ymm_);
+        vaddps(xmm_acc, xmm_acc, xmm_acc_upper_half);
+    } else if (number_of_values_to_reduce <= number_of_f32_in_ymm_) {
+        reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce);
+    }
+}
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::generate() {
+    this->preamble();
+    mov(reg_src, ptr[abi_param1 + GET_OFF(src)]);
+    mov(reg_scale, ptr[abi_param1 + GET_OFF(scale)]);
+    mov(reg_dst, ptr[abi_param1 + GET_OFF(dst)]);
+    uni_vpxor(vmm_sum0, vmm_sum0, vmm_sum0);
+    uni_vpxor(vmm_sum1, vmm_sum1, vmm_sum1);
+    uni_vpxor(vmm_sum2, vmm_sum2, vmm_sum2);
+    uni_vpxor(vmm_sum3, vmm_sum3, vmm_sum3);
+    mov(reg_src_org, reg_src);
+
+    // x * 1/Sqrt(ReduceMean(x^2, axes) + eps) * gamma
+    // step 1: sum(x^2), unrolled over 4 accumulators; emit the loop only when at least
+    // one full block of 4 * vec_size elements exists, otherwise reg_size would wrap
+    // below zero on the first dec()
+    if (m_jcp.data_size >= vec_size * 4) {
+        mov(reg_size, m_jcp.data_size / (vec_size * 4));
+        align(16);
+        Xbyak::Label loop_4reg;
+        L(loop_4reg);
+        {
+            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
+            vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
+            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 1);
+            vfmadd231ps(vmm_sum1, vmm_src, vmm_src);
+            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 2);
+            vfmadd231ps(vmm_sum2, vmm_src, vmm_src);
+            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 3);
+            vfmadd231ps(vmm_sum3, vmm_src, vmm_src);
+
+            add(reg_src, vec_size * m_jcp.src_prc.size() * 4);
+            dec(reg_size);
+            jnz(loop_4reg);
+        }
+    }
+    // leftover full vectors (1 ~ 3 vmm)
+    for (size_t i = m_jcp.data_size / (vec_size * 4) * 4; i < m_jcp.data_size / vec_size; i++) {
+        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
+        vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
+        add(reg_src, vec_size * m_jcp.src_prc.size());
+    }
+    // tail
+    if (m_jcp.data_size % vec_size) {
+        load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
+        vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
+    }
+    vaddps(vmm_sum0, vmm_sum0, vmm_sum1);
+    vaddps(vmm_sum2, vmm_sum2, vmm_sum3);
+    vaddps(vmm_rsqrt, vmm_sum0, vmm_sum2);
+    reduce_vmm_to_scalar(vmm_rsqrt, vmm_sum0, vmm_sum1, vmm_sum3, vec_size);
+
+    // mean(x^2)
+    mov(reg_tmp.cvt32(), float2int(1.0f / m_jcp.data_size));
+    vmovd(xmm_tmp, reg_tmp.cvt32());
+    vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
+    // mean(x^2) + eps
+    mov(reg_tmp.cvt32(), float2int(m_jcp.eps));
+    vmovd(xmm_tmp, reg_tmp.cvt32());
+    vaddss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
+    // rsqrt(mean(x^2) + eps)
+    vrsqrtss(xmm_rsqrt, xmm_rsqrt, xmm_rsqrt);
+
+    // x * rsqrt(mean(x^2) + eps)
+    if (m_jcp.has_scale && m_jcp.scale_size == 1) {
+        // a scalar scale can be folded into the rsqrt factor once, before broadcasting
+        vmovd(xmm_tmp, ptr[reg_scale]);
+        vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
+    }
+    vbroadcastss(vmm_rsqrt, xmm_rsqrt);
+    mov(reg_src, reg_src_org);
+    // step 2: normalize (and multiply by a vector scale element-wise); same guard as above
+    if (m_jcp.data_size >= vec_size) {
+        mov(reg_size, m_jcp.data_size / vec_size);
+        align(16);
+        Xbyak::Label loop_mul;
+        L(loop_mul);
+        {
+            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
+            vmulps(vmm_src, vmm_src, vmm_rsqrt);
+            if (m_jcp.has_scale && m_jcp.scale_size != 1) {
+                load(vmm_tmp, reg_scale, ov::element::f32, vec_size, false);
+                vmulps(vmm_src, vmm_src, vmm_tmp);
+            }
+            store(reg_dst, vmm_src, m_jcp.dst_prc, vec_size);
+
+            add(reg_src, vec_size * m_jcp.src_prc.size());
+            if (m_jcp.has_scale && m_jcp.scale_size != 1) {
+                add(reg_scale, vec_size * sizeof(float));
+            }
+            add(reg_dst, vec_size * m_jcp.dst_prc.size());
+            dec(reg_size);
+            jnz(loop_mul);
+        }
+    }
+    // tail
+    if (m_jcp.data_size % vec_size) {
+        load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
+        vmulps(vmm_src, vmm_src, vmm_rsqrt);
+        if (m_jcp.has_scale && m_jcp.scale_size != 1) {
+            load(vmm_tmp, reg_scale, ov::element::f32, m_jcp.data_size % vec_size, false);
+            vmulps(vmm_src, vmm_src, vmm_tmp);
+        }
+        store(reg_dst, vmm_src, m_jcp.dst_prc, m_jcp.data_size % vec_size);
+    }
+
+    this->postamble();
+    for (const auto& emitter : emitters) {
+        if (emitter.second)
+            emitter.second->emit_data();
+    }
+}
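+
+// load()/store() below wrap jit_load_emitter/jit_store_emitter and memoize them in the
+// `emitters` map, keyed by the hash of the emitter parameters, so generate() reuses one
+// emitter per (precision, element count, fill) combination; emit_data() is then called
+// once per cached emitter after the postamble to lay down their constant tables.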
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, ov::element::Type src_prc,
+                               const int& elt_num, bool fill, size_t offset) {
+    const auto seed = load_emitter_params(src_prc, ov::element::f32, elt_num, fill, "float_min").hash();
+    if (!emitters[seed]) {
+        emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, ov::element::f32, elt_num, ov::element::f32, fill, "float_min"));
+    }
+    emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), offset},
+                              {static_cast<size_t>(vmm_dst.getIdx())},
+                              pool_aux_vmm_idxs,
+                              pool_aux_gpr_idxs);
+}
+
+template <cpu_isa_t isa>
+void jit_rms_kernel<isa>::store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, ov::element::Type dst_prc,
+                                const int& elt_num, size_t offset) {
+    const auto seed = store_emitter_params(ov::element::f32, dst_prc, elt_num).hash();
+    if (!emitters[seed]) {
+        emitters[seed].reset(new jit_store_emitter(this, isa, ov::element::f32, dst_prc, elt_num));
+    }
+    emitters[seed]->emit_code({static_cast<size_t>(vmm_src.getIdx()), offset},
+                              {static_cast<size_t>(reg_dst.getIdx())},
+                              pool_aux_vmm_idxs,
+                              pool_aux_gpr_idxs);
+}
+
+template struct jit_rms_kernel<avx512_core>;
+template struct jit_rms_kernel<avx2>;
+
+} // namespace kernel
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp
new file mode 100644
index 00000000000000..3ef5607a705c8f
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp
@@ -0,0 +1,89 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "jit_kernel_base.hpp"
+
+#if defined(OPENVINO_ARCH_X86_64)
+#include "emitters/plugin/x64/jit_load_store_emitters.hpp"
+#endif
+
+namespace ov {
+namespace intel_cpu {
+namespace kernel {
+
+struct jit_rms_compile_params {
+    ov::element::Type src_prc;
+    ov::element::Type dst_prc;
+    size_t data_size;
+    float eps;
+    bool has_scale;
+    size_t scale_size;
+};
+
+struct jit_rms_call_args {
+    const uint8_t* src;
+    const float* scale;
+    uint8_t* dst;
+};
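+
+// Rough usage sketch (illustrative values only; the real call site is RMSNormExecutor
+// in nodes/rms_norm.cpp, which goes through createJitKernel()/execJitKernel()):
+//
+//   jit_rms_compile_params jcp{ov::element::f32, ov::element::f32,
+//                              /*data_size*/ 1024, /*eps*/ 1e-6f,
+//                              /*has_scale*/ true, /*scale_size*/ 1024};
+//   jit_rms_kernel<dnnl::impl::cpu::x64::avx2> ker(jcp);
+//   ker.create_kernel();
+//   jit_rms_call_args args{src_row, scale_data, dst_row};  // one row of data_size elements
+//   ker(&args);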
+
+#if defined(OPENVINO_ARCH_X86_64)
+
+template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+struct jit_rms_kernel : public JitKernel<jit_rms_compile_params, jit_rms_call_args> {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_rms_kernel)
+
+    static constexpr size_t vec_size = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen / sizeof(float);
+
+    explicit jit_rms_kernel(const jit_rms_compile_params& jcp) : JitKernel(jit_name(), jcp, isa) {}
+
+private:
+    using Xmm = Xbyak::Xmm;
+    using Vmm = typename dnnl::impl::utils::conditional3<isa == dnnl::impl::cpu::x64::sse41, Xbyak::Xmm,
+                                                         isa == dnnl::impl::cpu::x64::avx2, Xbyak::Ymm,
+                                                         Xbyak::Zmm>::type;
+
+    void generate() override;
+    void load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, ov::element::Type src_prc, const int& elt_num, bool fill, size_t offset = 0);
+    void store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, ov::element::Type dst_prc, const int& elt_num, size_t offset = 0);
+
+    // from onednn
+    void reduce_zmm_to_ymm(const Xmm& acc, const Xmm& tmp);
+    void reduce_ymm_to_xmm(const Xmm& acc, const Xmm& tmp);
+    void reduce_xmm_to_scalar(const Xmm& acc, const Xmm& tmp, const std::size_t number_of_values_to_reduce = number_of_f32_in_xmm_);
+    void reduce_ymm_to_scalar(const Xbyak::Xmm& acc, const Xbyak::Xmm& tmp1, const Xbyak::Xmm& tmp2,
+                              const std::size_t number_of_values_to_reduce = number_of_f32_in_ymm_);
+    void reduce_vmm_to_scalar(const Xbyak::Xmm& acc, const Xbyak::Xmm& tmp1, const Xbyak::Xmm& tmp2,
+                              const Xbyak::Xmm& tmp3, const std::size_t number_of_values_to_reduce = number_of_f32_in_zmm_);
+    static constexpr std::size_t number_of_f32_in_xmm_ = 4;
+    static constexpr std::size_t number_of_f32_in_ymm_ = 8;
+    static constexpr std::size_t number_of_f32_in_zmm_ = 16;
+
+    const Vmm vmm_src = Vmm(0);
+    const Vmm vmm_sum0 = Vmm(2);
+    const Vmm vmm_rsqrt = Vmm(2);
+    const Xmm xmm_rsqrt = Xmm(2);
+    const Vmm vmm_sum1 = Vmm(3);
+    const Vmm vmm_tmp = Vmm(3);
+    const Xmm xmm_tmp = Xmm(3);
+    const Vmm vmm_sum2 = Vmm(4);
+    const Vmm vmm_sum3 = Vmm(5);
+    const Vmm vmm_dst = Vmm(6);
+    const Xbyak::Reg64 reg_src = r8;
+    const Xbyak::Reg64 reg_src_org = r13;
+    const Xbyak::Reg64 reg_scale = r10;
+    const Xbyak::Reg64 reg_size = r11;
+    const Xbyak::Reg64 reg_dst = r12;
+    const Xbyak::Reg64 reg_tmp = rdx;
+
+    std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
+    const std::vector<size_t> pool_aux_gpr_idxs = { static_cast<size_t>(rax.getIdx()), static_cast<size_t>(r9.getIdx()) };
+    const std::vector<size_t> pool_aux_vmm_idxs = { 7 };
+};
+
+#endif
+
+} // namespace kernel
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/rms_norm.cpp b/src/plugins/intel_cpu/src/nodes/rms_norm.cpp
new file mode 100644
index 00000000000000..bfde8301f7f5c7
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/rms_norm.cpp
@@ -0,0 +1,269 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "rms_norm.h"
+
+#include "common/arbitrary_order_desc_creator.h"
+#include "common/primitive_hashing_utils.hpp"
+#include "cpu/x64/cpu_isa_traits.hpp"
+#include "dnnl_extension_utils.h"
+#include "memory_desc/cpu_memory_desc_utils.h"
+#include "memory_desc/dnnl_blocked_memory_desc.h"
+#include "onednn/dnnl.h"
+#include "openvino/core/parallel.hpp"
+#include "openvino/util/common_util.hpp"
+#include "shape_inference/custom/rms_norm.hpp"
+#include "openvino/op/rms_norm.hpp"
+#include "openvino/opsets/opset6.hpp"
+#include "kernels/x64/rms_kernel.hpp"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+using namespace ov::intel_cpu::kernel;
+using namespace dnnl::impl;
+using namespace dnnl::impl::cpu::x64;
+
+namespace ov {
+namespace intel_cpu {
+namespace node {
+
+struct RMSNormKey {
+    ov::element::Type precision;
+    size_t data_size;
+    size_t scale_size;
+    size_t eps;
+    size_t hash() const;
+    bool operator==(const RMSNormKey& rhs) const;
+};
+
+size_t RMSNormKey::hash() const {
+    size_t seed = 0;
+    seed = hash_combine(seed, precision.hash());
+    seed = hash_combine(seed, data_size);
+    seed = hash_combine(seed, scale_size);
+    seed = hash_combine(seed, eps);
+
+    return seed;
+}
+
+bool RMSNormKey::operator==(const RMSNormKey& rhs) const {
+    auto retVal = precision == rhs.precision &&
+                  data_size == rhs.data_size &&
+                  scale_size == rhs.scale_size &&
+                  eps == rhs.eps;
+
+    return retVal;
+}
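+
+// eps is stored in the key as the raw IEEE-754 bit pattern of the float (see the
+// float2int() call in RMSNorm::createPrimitive()), so it can be fed to hash_combine()
+// and compared with plain integer equality.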
+
+static std::shared_ptr<JitKernelBase> createJitKernel(const jit_rms_compile_params& param) {
+    std::shared_ptr<JitKernelBase> res;
+
+    MAYBE_UNUSED(param);
+
+#if defined(OPENVINO_ARCH_X86_64)
+
+    if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) {
+        res = std::make_shared<jit_rms_kernel<dnnl::impl::cpu::x64::avx512_core>>(param);
+    } else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
+        res = std::make_shared<jit_rms_kernel<dnnl::impl::cpu::x64::avx2>>(param);
+    }
+
+    if (res)
+        res->create_kernel();
+
+#endif // OPENVINO_ARCH_X86_64
+
+    return res;
+}
+
+static void execJitKernel(const std::shared_ptr<JitKernelBase>& ker, const uint8_t* src, uint8_t* dst, const float* scale) {
+    MAYBE_UNUSED(ker);
+    MAYBE_UNUSED(src);
+    MAYBE_UNUSED(dst);
+    MAYBE_UNUSED(scale);
+
+#if defined(OPENVINO_ARCH_X86_64)
+
+    jit_rms_call_args call_args;
+    call_args.src = src;
+    call_args.dst = dst;
+    call_args.scale = scale;
+    (*ker)(&call_args);
+
+#endif // OPENVINO_ARCH_X86_64
+}
+
+struct RMSNorm::RMSNormExecutor : public RMSNorm::Executor {
+    RMSNormExecutor(ov::element::Type precision, size_t data_size, size_t scale_size, float eps, bool has_scale)
+        : m_precision(precision) {
+        jit_rms_compile_params jcp;
+        jcp.src_prc = precision;
+        jcp.dst_prc = precision;
+        jcp.data_size = data_size;
+        jcp.scale_size = scale_size;
+        jcp.eps = eps;
+        jcp.has_scale = has_scale;
+        m_kernel = createJitKernel(jcp);
+    }
+
+    void execute(const std::vector<MemoryPtr>& inputs, const MemoryPtr output) override {
+        auto src = inputs[0]->getDataAs<uint8_t>();
+        auto dst = output->getDataAs<uint8_t>();
+        float* scale = nullptr;
+        if (inputs.size() > 2)
+            scale = inputs[2]->getDataAs<float>();
+
+        const auto& src_strides = inputs[0]->getDescWithType<BlockedMemoryDesc>()->getStrides();
+        const auto& dst_strides = output->getDescWithType<BlockedMemoryDesc>()->getStrides();
+        const auto& shape = inputs[0]->getStaticDims();
+        const auto src_stride = src_strides[src_strides.size() - 2] * m_precision.size();
+        const auto dst_stride = dst_strides[dst_strides.size() - 2] * m_precision.size();
+        auto n = shape_size(shape) / shape[shape.size() - 1];
+        parallel_for(n, [&](size_t i) {
+            execJitKernel(m_kernel, src + i * src_stride, dst + i * dst_stride, scale);
+        });
+    }
+
+private:
+    ov::element::Type m_precision;
+    std::shared_ptr<JitKernelBase> m_kernel;
+};
+
+RMSNorm::RMSNorm(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
+    : Node(op, context, RMSNormShapeInferFactory(op)) {
+    std::string errorMessage;
+    if (!isSupportedOperation(op, errorMessage)) {
+        OPENVINO_THROW("CPU: " + errorMessage);
+    }
+    const auto rms = std::dynamic_pointer_cast<ov::op::internal::RMSNorm>(op);
+    m_eps = static_cast<float>(rms->get_epsilon());
+    m_has_scale = op->get_input_size() > 2;
+}
+
+void RMSNorm::initSupportedPrimitiveDescriptors() {
+    if (!supportedPrimitiveDescriptors.empty())
+        return;
+    auto precision = getOriginalInputPrecisionAtPort(0);
+
+    impl_desc_type impl_type;
+    if (mayiuse(cpu::x64::avx512_core)) {
+        impl_type = impl_desc_type::jit_avx512;
+    } else if (mayiuse(cpu::x64::avx2)) {
+        impl_type = impl_desc_type::jit_avx2;
+    } else if (mayiuse(cpu::x64::sse41)) {
+        impl_type = impl_desc_type::jit_sse42;
+    } else {
+        impl_type = impl_desc_type::ref;
+    }
+
+    if (m_has_scale) {
+        addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::i32}, {LayoutType::ncsp, ov::element::f32}},
+                             {{LayoutType::ncsp, precision}},
+                             impl_type);
+    } else {
+        addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::i32}},
+                             {{LayoutType::ncsp, precision}},
+                             impl_type);
+    }
+}
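+
+// Executors are shared through the graph-context parameter cache: two RMSNorm nodes with
+// the same precision, row width, scale size and eps bit pattern produce equal RMSNormKeys
+// and reuse one JIT-compiled kernel instead of compiling twice (see createPrimitive()).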
+
+void RMSNorm::createPrimitive() {
+    auto precision = getOriginalInputPrecisionAtPort(0);
+    auto data_dims = getSrcMemoryAtPort(0)->getDescWithType<BlockedMemoryDesc>()->getBlockDims();
+    auto has_scale = getOriginalInputsNumber() > 2;
+    size_t data_size = data_dims[data_dims.size() - 1];
+    size_t scale_size = 0;
+    if (has_scale) {
+        scale_size = getSrcMemoryAtPort(2)->getDescWithType<BlockedMemoryDesc>()->getBlockDims()[0];
+    }
+
+    RMSNormKey key = {precision, data_size, scale_size, static_cast<size_t>(dnnl::impl::float2int(m_eps))};
+
+    auto builder = [&](const RMSNormKey& key) -> std::shared_ptr<Executor> {
+#ifdef OPENVINO_ARCH_X86_64
+        return std::make_shared<RMSNormExecutor>(precision, data_size, scale_size, m_eps, has_scale);
+#else
+        return nullptr;
+#endif
+    };
+
+    auto cache = context->getParamsCache();
+    auto result = cache->getOrCreate(key, builder);
+    if (!result.first) {
+        OPENVINO_THROW("RMSNorm Executor creation fails with precision " + precision.to_string());
+    }
+    m_executor = result.first;
+}
+
+void RMSNorm::execute(dnnl::stream strm) {
+    auto inputsNumber = getOriginalInputsNumber();
+    std::vector<MemoryPtr> inputs(inputsNumber);
+
+    for (size_t i = 0; i < inputsNumber; i++) {
+        inputs[i] = getSrcMemoryAtPort(i);
+    }
+
+    m_executor->execute(inputs, getDstMemoryAtPort(0));
+}
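+
+// This predicate also gates RMSFusion in transformation_pipeline.cpp, so the constraints
+// below decide whether the internal RMSNorm op is created at all: static last data
+// dimension with rank > 1, axes given as a Constant selecting the last dimension, an
+// optional static 1D-or-scalar scale, and at least AVX2 on the host.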
+
+bool RMSNorm::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
+    try {
+        const auto rms = std::dynamic_pointer_cast<const ov::op::internal::RMSNorm>(op);
+        if (rms) {
+            // check the last dimension of data
+            auto data_pshape = op->input_value(0).get_partial_shape();
+            if (data_pshape.rank().is_dynamic()) {
+                errorMessage = "RMSNorm data rank is not static.";
+                return false;
+            }
+            const auto data_rank = op->get_input_partial_shape(0).rank().get_length();
+            if (data_pshape[data_rank - 1].is_dynamic()) {
+                errorMessage = "RMSNorm last dimension of data is not static.";
+                return false;
+            }
+            if (data_rank == 1) {
+                errorMessage = "RMSNorm data rank must be greater than 1.";
+                return false;
+            }
+            // check axes
+            auto axes_op = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(1));
+            if (!axes_op) {
+                errorMessage = "RMSNorm axes is expected as Constant.";
+                return false;
+            }
+            // axes should be 1d or scalar in spec
+            auto axes_vals = axes_op->cast_vector<int64_t>();
+            if (axes_vals[0] != -1 && axes_vals[0] != data_rank - 1) {
+                errorMessage = "RMSNorm axes must be the last dimension.";
+                return false;
+            }
+
+            // check scale
+            if (op->get_input_size() > 2) {
+                if (op->get_input_partial_shape(2).rank().get_length() > 1) {
+                    errorMessage = "RMSNorm scale must be 1D or scalar.";
+                    return false;
+                }
+                if (op->get_input_partial_shape(2).is_dynamic()) {
+                    errorMessage = "RMSNorm scale shape is not static.";
+                    return false;
+                }
+            }
+            if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
+                errorMessage = "RMSNorm needs avx2+.";
+                return false;
+            }
+        } else {
+            errorMessage = "Only RMSNorm operation is supported";
+            return false;
+        }
+    } catch (...) {
+        return false;
+    }
+    return true;
+}
+
+} // namespace node
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/rms_norm.h b/src/plugins/intel_cpu/src/nodes/rms_norm.h
new file mode 100644
index 00000000000000..a47285d0fd2fc6
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/rms_norm.h
@@ -0,0 +1,49 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "node.h"
+#include "utils/plain_tensor.hpp"
+
+namespace ov {
+namespace intel_cpu {
+namespace node {
+
+class RMSNorm : public Node {
+public:
+    RMSNorm(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
+
+    void getSupportedDescriptors() override {}
+    bool created() const override {
+        return getType() == Type::RMSNorm;
+    }
+    bool needPrepareParams() const override {
+        return false;
+    }
+    void executeDynamicImpl(dnnl::stream strm) override {
+        execute(strm);
+    }
+    void initSupportedPrimitiveDescriptors() override;
+    void execute(dnnl::stream strm) override;
+    void createPrimitive() override;
+    static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
+
+private:
+    struct Executor {
+        virtual void execute(const std::vector<MemoryPtr>& inputs, const MemoryPtr output) = 0;
+        virtual ~Executor() = default;
+    };
+
+    std::shared_ptr<Executor> m_executor;
+    struct RMSNormExecutor;
+    friend struct RMSNormKey;
+
+    float m_eps = 0.0f;
+    bool m_has_scale = false;
+};
+
+} // namespace node
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp
index 9012c37f5ac23b..1a430541af0c76 100644
--- a/src/plugins/intel_cpu/src/nodes_factory.cpp
+++ b/src/plugins/intel_cpu/src/nodes_factory.cpp
@@ -80,6 +80,7 @@
 #include "nodes/reorg_yolo.h"
 #include "nodes/reshape.h"
 #include "nodes/reverse_sequence.h"
+#include "nodes/rms_norm.h"
 #include "nodes/rnn.h"
 #include "nodes/roi_align.h"
 #include "nodes/roi_align_rotated.h"
@@ -218,6 +219,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") {
     INTEL_CPU_NODE(QKVProjection, Type::QKVProjection);
     INTEL_CPU_NODE(MHA, Type::MHA);
     INTEL_CPU_NODE(PagedAttention, Type::PagedAttention);
+    INTEL_CPU_NODE(RMSNorm, Type::RMSNorm);
 #endif
 }
diff --git a/src/plugins/intel_cpu/src/shape_inference/custom/rms_norm.cpp b/src/plugins/intel_cpu/src/shape_inference/custom/rms_norm.cpp
new file mode 100644
index 00000000000000..5c1c00208053af
--- /dev/null
+++ b/src/plugins/intel_cpu/src/shape_inference/custom/rms_norm.cpp
@@ -0,0 +1,48 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "rms_norm.hpp"
+
+namespace ov {
+namespace intel_cpu {
+namespace node {
+
+class RMSNormShapeInfer : public ShapeInferEmptyPads {
+public:
+    RMSNormShapeInfer() {}
+
+    IShapeInfer::Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
+                              const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
+        const auto& dims = input_shapes.front().get();
+        return {{dims}, ShapeInferStatus::success};
+    }
+
+    port_mask_t get_port_mask() const override {
+        return EMPTY_PORT_MASK;
+    }
+};
+
+class PAShapeInfer : public ShapeInferEmptyPads {
+public:
+    PAShapeInfer() {}
+
+    IShapeInfer::Result infer(const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
+                              const std::unordered_map<size_t, MemoryPtr>& data_dependency) override {
+        const auto& query_dims = input_shapes.front().get();
+
+        return {{query_dims}, ShapeInferStatus::success};
+    }
+
+    port_mask_t get_port_mask() const override {
+        return EMPTY_PORT_MASK;
+    }
+};
+
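+// Shape inference for RMSNorm is an identity on the first input: both infer()
+// implementations above echo the incoming dims, and EMPTY_PORT_MASK declares that no
+// input tensor *data* is needed for inference. (PAShapeInfer is currently not
+// referenced by the factory below.)
+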
+ShapeInferPtr RMSNormShapeInferFactory::makeShapeInfer() const {
+    return std::make_shared<RMSNormShapeInfer>();
+}
+
+} // namespace node
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/shape_inference/custom/rms_norm.hpp b/src/plugins/intel_cpu/src/shape_inference/custom/rms_norm.hpp
new file mode 100644
index 00000000000000..b1eacae9ce4778
--- /dev/null
+++ b/src/plugins/intel_cpu/src/shape_inference/custom/rms_norm.hpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+
+#include "shape_inference/shape_inference_cpu.hpp"
+
+namespace ov {
+namespace intel_cpu {
+namespace node {
+
+class RMSNormShapeInferFactory : public ShapeInferFactory {
+public:
+    RMSNormShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
+    ShapeInferPtr makeShapeInfer() const override;
+
+private:
+    std::shared_ptr<ov::Node> m_op;
+};
+
+} // namespace node
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index 98da1be4c74876..b106885143bba8 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -6,6 +6,7 @@
 #include "defs.hpp"
 
 // Operations
+#include "openvino/op/constant.hpp"
 #include "openvino/opsets/opset1.hpp"
 #include "openvino/opsets/opset2.hpp"
 #include "openvino/opsets/opset3.hpp"
@@ -36,6 +37,7 @@
 #include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"
 #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp"
 #include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp"
+#include "transformations/common_optimizations/rms_fusion.hpp"
 #include "transformations/control_flow/unroll_tensor_iterator.hpp"
 #include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
 #include "transformations/op_conversions/convert_avgpool_downgrade.hpp"
@@ -157,6 +159,7 @@
 #include "nodes/scaled_attn.h"
 #include "nodes/llm_mlp.h"
 #include "nodes/qkv_proj.h"
+#include "nodes/rms_norm.h"
 #include "dnnl.hpp"
 #if defined(OPENVINO_ARCH_ARM64)
 #include "cpu/aarch64/cpu_isa_traits.hpp"
@@ -852,6 +855,14 @@ void Transformations::PostLpt() {
     }
     CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward);
     CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion);
+    CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion);
+    CPU_SET_CALLBACK_X64(postLPTPassManager,
+                         [](const std::shared_ptr<const ov::Node>& node) -> bool {
+                             std::string errorMsg;
+                             return node::RMSNorm::isSupportedOperation(node, errorMsg);
+                         },
+                         ov::pass::RMSFusion);
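+    // The callback keeps the fusion decision in sync with the CPU node: it consults
+    // node::RMSNorm::isSupportedOperation(), so subgraphs the JIT kernel cannot execute
+    // are left in their decomposed form and run through the regular kernels.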
+
     // markup Rope Input when BF16/F16 inference.
     if (one_of(inferencePrecision, ov::element::bf16, ov::element::f16))
         CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::MarkRopeInputsToKeepInMixedPrecision);
diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.cpp
new file mode 100644
index 00000000000000..33d271a3b81cc6
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.cpp
@@ -0,0 +1,123 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "rms_norm.hpp"
+
+#include "gtest/gtest.h"
+#include "openvino/core/shape.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/rms_norm.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "utils/cpu_test_utils.hpp"
+#include "openvino/pass/manager.hpp"
+
+using namespace CPUTestUtils;
+
+namespace ov {
+namespace test {
+
+std::string RMSNormLayerCPUTest::getTestCaseName(const testing::TestParamInfo<RMSNormCPUTestParams>& obj) {
+    CPUSpecificParams cpuParams;
+    ElementType inType;
+    std::vector<InputShape> inputShapes;
+    std::string targetDevice;
+    std::tie(inType, inputShapes, targetDevice, cpuParams) = obj.param;
+
+    std::ostringstream result;
+    result << "netPRC=" << inType << "_";
+    result << "IS=";
+    for (const auto& inputShape : inputShapes) {
+        result << ov::test::utils::partialShape2str({inputShape.first}) << "_";
+    }
+    result << "TS=";
+    for (const auto& shapes : inputShapes) {
+        for (const auto& shape : shapes.second) {
+            result << ov::test::utils::vec2str(shape);
+            result << "_";
+        }
+    }
+    result << "trgDev=" << targetDevice;
+    result << CPUTestsBase::getTestCaseName(cpuParams);
+
+    return result.str();
+}
+
+template <typename IT, typename T>
+void strided_iota(IT first, size_t n, T value, T stride) {
+    for (size_t i = 0; i < n; i++) {
+        *first++ = value;
+        value += stride;
+    }
+}
+
+void RMSNormLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
+    inputs.clear();
+    auto create_input = [this](std::shared_ptr<ov::op::v0::Parameter> param, ov::Shape shape, float val) {
+        if (param->get_element_type() == element::i32) {
+            ov::Tensor t{ov::element::i32, shape};
+            auto size = shape[0];
+            auto* p = static_cast<int32_t*>(t.data());
+            auto start = static_cast<int32_t>(val);
+            for (size_t i = 0; i < size; i++) {
+                p[i] = (start + i) % size;
+            }
+            inputs.insert({param, t});
+        } else if (param->get_element_type() == element::f32) {
+            ov::Tensor t{ov::element::f32, shape};
+            strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
+            inputs.insert({param, t});
+        } else {
+            ov::Tensor t{ov::element::bf16, shape};
+            strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
+            inputs.insert({param, t});
+        }
+    };
+    // the first parameter is always the data tensor; the second, when present, is the scale
+    create_input(function->get_parameters()[0], targetInputStaticShapes[0], 1.0f);
+    if (targetInputStaticShapes.size() > 1)
+        create_input(function->get_parameters()[1], targetInputStaticShapes[1], 0.0f);
+}
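+
+// SetUp() builds the smallest graph that exercises the CPU RMSNorm node: a data
+// Parameter, an i64 axes Constant fixed to -1 (the last dimension) and an optional
+// scale Parameter, all feeding ov::op::internal::RMSNorm with eps = 0.1.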
+void RMSNormLayerCPUTest::SetUp() {
+    ElementType inType;
+    CPUSpecificParams cpuParams;
+    std::vector<InputShape> inputShapes;
+    std::tie(inType, inputShapes, targetDevice, cpuParams) = this->GetParam();
+
+    std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
+    if (selectedType.empty()) {
+        selectedType = getPrimitiveType();
+    }
+
+    rel_threshold = 0.001f;
+    if (inType == ElementType::bf16) {
+        rel_threshold = 2e-2f;
+    }
+    selectedType = makeSelectedTypeStr(selectedType, inType);
+    init_input_shapes(inputShapes);
+    ov::ParameterVector inputParams;
+    // data, axes, scale
+    auto data = std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]);
+    auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
+    inputParams.push_back(data);
+    std::shared_ptr<ov::op::v0::Parameter> scale;
+    if (inputDynamicShapes.size() > 1) {
+        scale = std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]);
+        inputParams.push_back(scale);
+    }
+    auto rms = scale ? std::make_shared<ov::op::internal::RMSNorm>(data, axes, scale, 0.1f)
+                     : std::make_shared<ov::op::internal::RMSNorm>(data, axes, 0.1f);
+    rms->set_friendly_name("rms");
+    function = makeNgraphFunction(inType, inputParams, rms, "rms");
+}
+
+TEST_P(RMSNormLayerCPUTest, CompareWithRefs) {
+    run();
+}
+
+namespace RMSNorm {
+
+} // namespace RMSNorm
+} // namespace test
+} // namespace ov
diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.hpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.hpp
new file mode 100644
index 00000000000000..da361f35dd7db5
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/rms_norm.hpp
@@ -0,0 +1,31 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "utils/cpu_test_utils.hpp"
+
+using namespace CPUTestUtils;
+
+namespace ov {
+namespace test {
+
+typedef std::tuple<ElementType,              // input precision
+                   std::vector<InputShape>,  // shape
+                   std::string,              // targetDevice
+                   CPUSpecificParams>
+    RMSNormCPUTestParams;
+
+class RMSNormLayerCPUTest : public testing::WithParamInterface<RMSNormCPUTestParams>,
+                            virtual public SubgraphBaseTest,
+                            public CPUTestsBase {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<RMSNormCPUTestParams>& obj);
+
+protected:
+    void SetUp() override;
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
+};
+
+} // namespace test
+} // namespace ov
diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/rms_norm.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/rms_norm.cpp
new file mode 100644
index 00000000000000..68a37a02d0660e
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/rms_norm.cpp
@@ -0,0 +1,59 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "custom/single_layer_tests/classes/rms_norm.hpp"
+#include "utils/cpu_test_utils.hpp"
+
+using namespace CPUTestUtils;
+
+namespace ov {
+namespace test {
+namespace RMSNorm {
+const auto cpuSpec = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"};
+
+const std::vector<std::vector<InputShape>> shapes{
+    // normal
+    {
+        // data shape
+        {ov::test::InputShape{ov::PartialShape{-1, -1, 1024 + 16 + 1},
+                              {ov::Shape{1, 8, 1024 + 16 + 1}, ov::Shape{2, 3, 1024 + 16 + 1}}}
+        },
+        // scale shape
+        {ov::test::InputShape{ov::PartialShape{1024 + 16 + 1},
+                              {ov::Shape{1024 + 16 + 1}, ov::Shape{1024 + 16 + 1}}}
+        },
+    },
+    // scale is scalar
+    {
+        // data shape
+        {ov::test::InputShape{ov::PartialShape{-1, -1, 1094},
+                              {ov::Shape{1, 8, 1094}, ov::Shape{2, 3, 1094}}}
+        },
+        // scale shape
+        {ov::test::InputShape{ov::PartialShape{1},
+                              {ov::Shape{1}, ov::Shape{1}}}
+        },
+    },
+    // no scale
+    {
+        // data shape
+        {ov::test::InputShape{ov::PartialShape{-1, -1, 1094},
+                              {ov::Shape{1, 8, 1094}, ov::Shape{2, 3, 1094}}}
+        },
+    }
+};
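+
+// The three shape groups above cover the kernel's scale paths: a full per-row scale
+// vector, a broadcast scalar scale (scale_size == 1, folded into the scalar rsqrt
+// factor before broadcasting) and no scale at all. Row widths (1041 = 1024 + 16 + 1,
+// 1094) are deliberately not multiples of the SIMD width so the tail path is exercised.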
+
+const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16),
+                                     testing::ValuesIn(shapes),
+                                     testing::Values(ov::test::utils::DEVICE_CPU),
+                                     testing::Values(cpuSpec));
+
+INSTANTIATE_TEST_SUITE_P(smoke_RMSNorm_CPU,
+                         RMSNormLayerCPUTest,
+                         params,
+                         RMSNormLayerCPUTest::getTestCaseName);
+
+} // namespace RMSNorm
+} // namespace test
+} // namespace ov
diff --git a/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp b/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp
index 6caa069e0d8288..5c5684cab4f6ca 100644
--- a/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp
+++ b/src/tests/functional/shared_test_classes/src/base/utils/compare_results.cpp
@@ -5,6 +5,7 @@
 #include
 
 #include "openvino/op/ops.hpp"
+#include "openvino/op/rms_norm.hpp"
 #include "ov_ops/augru_cell.hpp"
 #include "ov_ops/augru_sequence.hpp"
@@ -207,6 +208,9 @@ OPENVINO_SUPPRESS_DEPRECATED_START
 #include "openvino/opsets/opset15_tbl.hpp"
 #include "ov_ops/opset_private_tbl.hpp"
+
+_OPENVINO_OP_REG(RMSNorm, ov::op::internal)
+
 #undef _OPENVINO_OP_REG
 };
 OPENVINO_SUPPRESS_DEPRECATED_END
diff --git a/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp b/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp
index 171266db31a4b9..13c0674cfc86b0 100644
--- a/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp
+++ b/src/tests/functional/shared_test_classes/src/base/utils/generate_inputs.cpp
@@ -9,6 +9,7 @@
 #include "shared_test_classes/base/utils/generate_inputs.hpp"
 
 #include "openvino/op/ops.hpp"
+#include "openvino/op/rms_norm.hpp"
 #include "ov_ops/augru_cell.hpp"
 #include "ov_ops/augru_sequence.hpp"
@@ -1020,6 +1021,9 @@ InputsMap getInputMap() {
 #include "openvino/opsets/opset15_tbl.hpp"
 #include "ov_ops/opset_private_tbl.hpp"
+
+_OPENVINO_OP_REG(RMSNorm, ov::op::internal)
+
 #undef _OPENVINO_OP_REG
 };
     return inputsMap;