[CPU] Add RMSNorm jit implementation #26147

Merged
@@ -9,3 +9,4 @@

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
+_OPENVINO_OP_REG(RMS, ov::op::internal)
@@ -30,7 +30,7 @@ namespace pass {
class RMSFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("RMSFusion", "0");
-    RMSFusion();
+    RMSFusion(bool force_tail_convert = true);
};

} // namespace pass
@@ -34,7 +34,7 @@ static std::function<bool(ov::Output<ov::Node>)> constant_value(const float targ
    };
}

-RMSFusion::RMSFusion() {
+RMSFusion::RMSFusion(bool force_tail_convert) {
    using namespace ov::pass::pattern;

    // Detect RMS decomposition pattern
@@ -67,8 +67,11 @@ RMSFusion::RMSFusion() {
    auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
    auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});

-    // compress RMS result
-    auto comp = wrap_type<ov::op::v0::Convert>({mul2});
+    std::shared_ptr<ov::Node> comp = mul2;
Contributor: Not a requirement for this PR, but still worth mentioning: the ConvertPrecision(FP32->FP16) pass keeps the normalization subgraph in higher precision to maintain accuracy, which results in an additional Convert op (FP32->FP16) at the end of the pattern. Ideally, even for CPU (with FP16 inference precision), we would need to fuse such a Convert into the RMS for better performance. Basically two possible solutions:

  1. Match two different patterns (with and without Convert) within the RMSFusion transformation
  2. Fuse RMS+Convert in a separate transformation (ideally it should be a generic pass suitable for any parent op)
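A rough sketch of option 1, assuming the stock pattern `Or` node (hypothetical, untested):

    // accept the tail either with or without the trailing Convert
    auto convert = wrap_type<ov::op::v0::Convert>({mul2});
    auto tail = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{convert, mul2});
    // and build the Matcher on `tail` instead of `comp`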

Contributor (author): Got it.

+    if (force_tail_convert) {
+        // compress RMS result
+        comp = wrap_type<ov::op::v0::Convert>({mul2});
+    }

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
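For reference, the decomposition this pass matches computes (the same formula as the kernel comment later in this PR):

$$y = x \cdot \frac{1}{\sqrt{\operatorname{ReduceMean}(x^2,\ \text{axes}) + \varepsilon}} \cdot \gamma$$

with an optional trailing Convert that compresses the f32 result when force_tail_convert is set.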
@@ -159,3 +159,36 @@ TEST_F(TransformationTestsF, RMSNormFusionTest5) {
        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rms}, ov::ParameterVector{input});
    }
}

// no convert at the end of the subgraph
TEST_F(TransformationTestsF, RMSNormFusionTest6) {
    {
        auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
        auto power_const = ov::opset10::Constant::create(ov::element::f32, {}, {2.f});
        auto power = std::make_shared<ov::opset10::Power>(input, power_const);
        auto mean_axes = ov::opset10::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
        auto mean = std::make_shared<ov::opset10::ReduceMean>(power, mean_axes, true);
        auto eps = ov::opset10::Constant::create(ov::element::f32, {}, {1e-5f});
        auto add_eps = std::make_shared<ov::opset10::Add>(mean, eps);
        auto sqrt = std::make_shared<ov::opset10::Sqrt>(add_eps);
        auto div_const = ov::opset10::Constant::create(ov::element::f32, {}, {-1});
        auto div = std::make_shared<ov::opset10::Power>(sqrt, div_const);
        auto mul1 = std::make_shared<ov::opset10::Multiply>(input, div);
        auto gamma = ov::opset10::Constant::create(ov::element::f32,
                                                   ov::Shape{6},
                                                   {0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f});
        auto mul2 = std::make_shared<ov::opset10::Multiply>(gamma, mul1);

        model = std::make_shared<ov::Model>(ov::NodeVector{mul2}, ov::ParameterVector{input});
        manager.register_pass<RMSFusion>(false);
    }
    {
        auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
        auto rms_const = ov::opset10::Constant::create(ov::element::f32,
                                                       ov::Shape{6},
                                                       {0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f});
        auto rms = std::make_shared<ov::op::internal::RMS>(input, rms_const, 1e-5f);

        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rms}, ov::ParameterVector{input});
    }
}
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
@@ -247,6 +247,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"EmbeddingBagOffsets", Type::EmbeddingBagOffsets},
{"LLMMLP", Type::LLMMLP},
{"QKVProjection", Type::QKVProjection},
{"RMS", Type::RMS}
};
return type_to_name_tbl;
}
@@ -373,6 +374,7 @@ std::string NameFromType(const Type type) {
        CASE(CausalMaskPreprocess);
        CASE(LLMMLP);
        CASE(QKVProjection);
+        CASE(RMS);
        CASE(Unknown);
    }
#undef CASE
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
@@ -127,6 +127,7 @@ enum class Type {
    CausalMaskPreprocess,
    LLMMLP,
    QKVProjection,
+    RMS
};

enum class Algorithm {
241 changes: 241 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp
@@ -0,0 +1,241 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "rms_kernel.hpp"

using namespace dnnl::impl::cpu::x64;
using namespace Xbyak;

namespace ov {
namespace intel_cpu {
namespace kernel {

#define GET_OFF(field) offsetof(jit_rms_call_args, field)
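
// The reduce_* helpers below perform a horizontal sum of a vector register:
// each step folds the upper half of the accumulator into the lower half
// (zmm -> ymm -> xmm), then reduce_xmm_to_scalar collapses the remaining
// lanes into lane 0 via vinsertps/vaddss.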

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_zmm_to_ymm(const Xmm &acc, const Xmm &tmp) {
    const Zmm zmm_acc(acc.getIdx());
    const Ymm ymm_acc(acc.getIdx());
    const Ymm ymm_to_acc(tmp.getIdx());
    vextractf64x4(ymm_to_acc, zmm_acc, 1);
    vaddps(ymm_acc, ymm_acc, ymm_to_acc);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_ymm_to_xmm(const Xmm &acc, const Xmm &tmp) {
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_to_acc(tmp.getIdx());
    vextractf128(xmm_to_acc, ymm_acc, 1);
    vaddps(xmm_acc, xmm_acc, xmm_to_acc);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_xmm_to_scalar(const Xmm &acc,
        const Xmm &tmp,
        const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_xmm_);

    const Xmm xmm_acc(acc.getIdx());
    const Xmm ymm_to_acc(tmp.getIdx());

    static constexpr int number_of_f32_to_move = number_of_f32_in_xmm_ - 1;
    static constexpr uint8_t insertps_configuration[number_of_f32_to_move] = {0b01001110, 0b10001110, 0b11001110};

    for (std::size_t i = 0; i < number_of_values_to_reduce - 1; i++) {
        vinsertps(ymm_to_acc, ymm_to_acc, xmm_acc, insertps_configuration[i]);
        vaddss(xmm_acc, xmm_acc, ymm_to_acc);
    }
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_ymm_to_scalar(const Xbyak::Xmm &acc,
        const Xbyak::Xmm &tmp1,
        const Xbyak::Xmm &tmp2,
        const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_ymm_);

    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_tmp(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp2.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_ymm_) {
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
    } else if (number_of_values_to_reduce > number_of_f32_in_xmm_) {
        vextractf128(xmm_acc_upper_half, ymm_acc, 1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc_upper_half, xmm_tmp, number_of_values_to_reduce - number_of_f32_in_xmm_);
        vaddss(xmm_acc, xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_xmm_) {
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp, number_of_values_to_reduce);
    }
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_vmm_to_scalar(const Xbyak::Xmm &acc,
        const Xbyak::Xmm &tmp1,
        const Xbyak::Xmm &tmp2,
        const Xbyak::Xmm &tmp3,
        const std::size_t number_of_values_to_reduce) {
    assert(number_of_values_to_reduce <= number_of_f32_in_zmm_);

    const Zmm zmm_acc(acc.getIdx());
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Ymm ymm_acc_upper_half(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp1.getIdx());
    const Ymm ymm_tmp(tmp2.getIdx());
    const Xmm xmm_tmp1(tmp2.getIdx());
    const Xmm xmm_tmp2(tmp3.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_zmm_) {
        reduce_zmm_to_ymm(zmm_acc, ymm_tmp);
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp1);
    } else if (number_of_values_to_reduce > number_of_f32_in_ymm_) {
        vextractf64x4(ymm_acc_upper_half, zmm_acc, 1);
        reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2);
        reduce_ymm_to_scalar(ymm_acc_upper_half, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce - number_of_f32_in_ymm_);
        vaddps(xmm_acc, xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_ymm_) {
        reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce);
    }
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::generate() {
    this->preamble();
    mov(reg_src, ptr[abi_param1 + GET_OFF(src)]);
    mov(reg_scale, ptr[abi_param1 + GET_OFF(scale)]);
    mov(reg_dst, ptr[abi_param1 + GET_OFF(dst)]);
    uni_vpxor(vmm_sum0, vmm_sum0, vmm_sum0);
    uni_vpxor(vmm_sum1, vmm_sum1, vmm_sum1);
    uni_vpxor(vmm_sum2, vmm_sum2, vmm_sum2);
    uni_vpxor(vmm_sum3, vmm_sum3, vmm_sum3);
    mov(reg_src_org, reg_src);

    mov(reg_size, m_jcp.data_size / (vec_size * 4));
    // x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
    // sum(x^2)
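    // main loop, 4x unrolled: four independent accumulators hide the vfmadd231ps latency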
    align(16);
    Xbyak::Label loop_4reg;
    L(loop_4reg);
    {
        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
        vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 1);
        vfmadd231ps(vmm_sum1, vmm_src, vmm_src);
        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 2);
        vfmadd231ps(vmm_sum2, vmm_src, vmm_src);
        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 3);
        vfmadd231ps(vmm_sum3, vmm_src, vmm_src);

        add(reg_src, vec_size * m_jcp.src_prc.size() * 4);
        dec(reg_size);
        jnz(loop_4reg);
    }
    // remaining 1~3 full vectors after the unrolled loop
    for (size_t i = m_jcp.data_size / (vec_size * 4) * 4; i < m_jcp.data_size / vec_size; i++) {
        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
        vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
        add(reg_src, vec_size * m_jcp.src_prc.size());
    }
    // tail
    if (m_jcp.data_size % vec_size) {
        load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
        vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
    }
    vaddps(vmm_sum0, vmm_sum0, vmm_sum1);
    vaddps(vmm_sum2, vmm_sum2, vmm_sum3);
    vaddps(vmm_rsqrt, vmm_sum0, vmm_sum2);
    reduce_vmm_to_scalar(vmm_rsqrt, vmm_sum0, vmm_sum1, vmm_sum3, vec_size);

    // mean(x^2)
    mov(reg_tmp.cvt32(), float2int(1.0f / m_jcp.data_size));
    vmovd(xmm_tmp, reg_tmp.cvt32());
    vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
    // mean(x^2)+eps
    mov(reg_tmp.cvt32(), float2int(m_jcp.eps));
    vmovd(xmm_tmp, reg_tmp.cvt32());
    vaddss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
    // rsqrt(mean(x^2)+eps)
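    // note: vrsqrtss is an approximate reciprocal square root (~12 bits of
    // relative accuracy); no Newton-Raphson refinement is applied here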
    vrsqrtss(xmm_rsqrt, xmm_rsqrt, xmm_rsqrt);

    // x * rsqrt(mean(x^2)+eps)
    if (m_jcp.scale_size == 1) {
        // scale is a scalar: fold gamma into the broadcast factor
        vmovd(xmm_tmp, ptr[reg_scale]);
        vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
    }
    vbroadcastss(vmm_rsqrt, xmm_rsqrt);
    mov(reg_size, m_jcp.data_size / vec_size);
    mov(reg_src, reg_src_org);
    align(16);
    Xbyak::Label loop_mul;
    L(loop_mul);
    {
        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
        vmulps(vmm_src, vmm_src, vmm_rsqrt);
        if (m_jcp.scale_size != 1) {
            load(vmm_tmp, reg_scale, ov::element::f32, vec_size, false);
            vmulps(vmm_src, vmm_src, vmm_tmp);
        }
        store(reg_dst, vmm_src, m_jcp.dst_prc, vec_size);

        add(reg_src, vec_size * m_jcp.src_prc.size());
        if (m_jcp.scale_size != 1) {
            add(reg_scale, vec_size * sizeof(float));
        }
        add(reg_dst, vec_size * m_jcp.dst_prc.size());
        dec(reg_size);
        jnz(loop_mul);
    }
    // tail
    if (m_jcp.data_size % vec_size) {
        load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
        vmulps(vmm_src, vmm_src, vmm_rsqrt);
        if (m_jcp.scale_size != 1) {
            load(vmm_tmp, reg_scale, ov::element::f32, m_jcp.data_size % vec_size, false);
            vmulps(vmm_src, vmm_src, vmm_tmp);
        }
        store(reg_dst, vmm_src, m_jcp.dst_prc, m_jcp.data_size % vec_size);
    }

    this->postamble();
    for (const auto& emitter : emitters) {
        if (emitter.second)
            emitter.second->emit_data();
    }
}
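
// Note: the load/store helpers below create jit_load_emitter/jit_store_emitter
// objects lazily, cached in `emitters` by a hash of their parameters, and reuse
// them across call sites; generate() flushes their data sections via emit_data()
// after the postamble.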

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, ov::element::Type src_prc, const int& elt_num, bool fill, size_t offset) {
    const auto seed = load_emitter_params(src_prc, ov::element::f32, elt_num, fill, "float_min").hash();
    if (!emitters[seed]) {
        emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, ov::element::f32, elt_num, ov::element::f32, fill, "float_min"));
    }
    emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), offset}, {static_cast<size_t>(vmm_dst.getIdx())},
                              pool_aux_vmm_idxs, pool_aux_gpr_idxs);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, ov::element::Type dst_prc, const int& elt_num, size_t offset) {
    const auto seed = store_emitter_params(ov::element::f32, dst_prc, elt_num).hash();
    if (!emitters[seed]) {
        emitters[seed].reset(new jit_store_emitter(this, isa, ov::element::f32, dst_prc, elt_num));
    }
    emitters[seed]->emit_code({static_cast<size_t>(vmm_src.getIdx()), offset}, {static_cast<size_t>(reg_dst.getIdx())},
                              pool_aux_vmm_idxs, pool_aux_gpr_idxs);
}

template struct jit_rms_kernel<cpu_isa_t::avx512_core>;
template struct jit_rms_kernel<cpu_isa_t::avx2>;

} // namespace kernel
} // namespace intel_cpu
} // namespace ov
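
For readers cross-checking the jit kernel, a minimal scalar sketch of the computation it implements (a hypothetical reference helper, not part of this PR):

    #include <cmath>
    #include <cstddef>

    // dst = src * 1/sqrt(mean(src^2) + eps) * scale, for one row of data_size floats;
    // scale_size == 1 means gamma is a single scalar (matches m_jcp.scale_size above)
    static void rms_reference(const float* src, const float* scale, float* dst,
                              std::size_t data_size, std::size_t scale_size, float eps) {
        float sum = 0.0f;
        for (std::size_t i = 0; i < data_size; ++i)  // sum(x^2), no unrolling
            sum += src[i] * src[i];
        const float rsqrt = 1.0f / std::sqrt(sum / data_size + eps);
        for (std::size_t i = 0; i < data_size; ++i)
            dst[i] = src[i] * rsqrt * (scale_size == 1 ? scale[0] : scale[i]);
    }

Unlike the jit code, this uses an exact square root rather than the vrsqrtss approximation, so results may differ in the last few bits.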