Commit
init version
luo-cheng2021 committed Aug 21, 2024
1 parent 72516c3 commit 7167a1b
Showing 15 changed files with 957 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
@@ -247,6 +247,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"EmbeddingBagOffsets", Type::EmbeddingBagOffsets},
{"LLMMLP", Type::LLMMLP},
{"QKVProjection", Type::QKVProjection},
{"RMSNorm", Type::RMSNorm}
};
return type_to_name_tbl;
}
@@ -373,6 +374,7 @@ std::string NameFromType(const Type type) {
CASE(CausalMaskPreprocess);
CASE(LLMMLP);
CASE(QKVProjection);
CASE(RMSNorm);
CASE(Unknown);
}
#undef CASE
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
@@ -127,6 +127,7 @@ enum class Type {
CausalMaskPreprocess,
LLMMLP,
QKVProjection,
RMSNorm
};

enum class Algorithm {
241 changes: 241 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp
@@ -0,0 +1,241 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "rms_kernel.hpp"

using namespace dnnl::impl::cpu::x64;
using namespace Xbyak;

namespace ov {
namespace intel_cpu {
namespace kernel {

#define GET_OFF(field) offsetof(jit_rms_call_args, field)

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_zmm_to_ymm(
const Xmm &acc, const Xmm &tmp) {
const Zmm zmm_acc(acc.getIdx());
const Ymm ymm_acc(acc.getIdx());
const Ymm ymm_to_acc(tmp.getIdx());
vextractf64x4(ymm_to_acc, zmm_acc, 1);
vaddps(ymm_acc, ymm_acc, ymm_to_acc);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_ymm_to_xmm(
const Xmm &acc, const Xmm &tmp) {
const Ymm ymm_acc(acc.getIdx());
const Xmm xmm_acc(acc.getIdx());
const Xmm xmm_to_acc(tmp.getIdx());
vextractf128(xmm_to_acc, ymm_acc, 1);
vaddps(xmm_acc, xmm_acc, xmm_to_acc);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_xmm_to_scalar(const Xmm &acc,
const Xmm &tmp, const std::size_t number_of_values_to_reduce) {
assert(number_of_values_to_reduce <= number_of_f32_in_xmm_);

const Xmm xmm_acc(acc.getIdx());
const Xmm xmm_to_acc(tmp.getIdx());

static constexpr int number_of_f32_to_move = number_of_f32_in_xmm_ - 1;
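// insertps imm8 layout: bits[7:6] pick the source lane, bits[5:4] the destination
// lane, bits[3:0] zero-mask the remaining lanes; each immediate below moves acc
// lane i+1 into lane 0 of the temporary (zeroing the rest) so vaddss can accumulate it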
static constexpr uint8_t insertps_configuration[number_of_f32_to_move]
= {0b01001110, 0b10001110, 0b11001110};

for (std::size_t i = 0; i < number_of_values_to_reduce - 1; i++) {
vinsertps(xmm_to_acc, xmm_to_acc, xmm_acc, insertps_configuration[i]);
vaddss(xmm_acc, xmm_acc, xmm_to_acc);
}
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_ymm_to_scalar(
const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
const std::size_t number_of_values_to_reduce) {
assert(number_of_values_to_reduce <= number_of_f32_in_ymm_);

const Ymm ymm_acc(acc.getIdx());
const Xmm xmm_acc(acc.getIdx());
const Xmm xmm_tmp(tmp1.getIdx());
const Xmm xmm_acc_upper_half(tmp2.getIdx());

if (number_of_values_to_reduce == number_of_f32_in_ymm_) {
reduce_ymm_to_xmm(ymm_acc, xmm_tmp);
reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
} else if (number_of_values_to_reduce > number_of_f32_in_xmm_) {
vextractf128(xmm_acc_upper_half, ymm_acc, 1);
reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
reduce_xmm_to_scalar(xmm_acc_upper_half, xmm_tmp,
number_of_values_to_reduce - number_of_f32_in_xmm_);
vaddss(xmm_acc, xmm_acc, xmm_acc_upper_half);
} else if (number_of_values_to_reduce <= number_of_f32_in_xmm_) {
reduce_xmm_to_scalar(xmm_acc, xmm_tmp, number_of_values_to_reduce);
}
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_vmm_to_scalar(
const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
const Xbyak::Xmm &tmp3, const std::size_t number_of_values_to_reduce) {
assert(number_of_values_to_reduce <= number_of_f32_in_zmm_);

const Zmm zmm_acc(acc.getIdx());
const Ymm ymm_acc(acc.getIdx());
const Xmm xmm_acc(acc.getIdx());
const Ymm ymm_acc_upper_half(tmp1.getIdx());
const Xmm xmm_acc_upper_half(tmp1.getIdx());
const Ymm ymm_tmp(tmp2.getIdx());
const Xmm xmm_tmp1(tmp2.getIdx());
const Xmm xmm_tmp2(tmp3.getIdx());

if (number_of_values_to_reduce == number_of_f32_in_zmm_) {
reduce_zmm_to_ymm(zmm_acc, ymm_tmp);
reduce_ymm_to_xmm(ymm_acc, xmm_tmp1);
reduce_xmm_to_scalar(xmm_acc, xmm_tmp1);
} else if (number_of_values_to_reduce > number_of_f32_in_ymm_) {
vextractf64x4(ymm_acc_upper_half, zmm_acc, 1);
reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2);
reduce_ymm_to_scalar(ymm_acc_upper_half, xmm_tmp1, xmm_tmp2,
number_of_values_to_reduce - number_of_f32_in_ymm_);
vaddps(xmm_acc, xmm_acc, xmm_acc_upper_half);
} else if (number_of_values_to_reduce <= number_of_f32_in_ymm_) {
reduce_ymm_to_scalar(
ymm_acc, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce);
}
}
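Taken together, these helpers perform a standard horizontal reduction: the accumulator is halved repeatedly (zmm to ymm to xmm) and the last few lanes are folded with insertps/vaddss. A minimal scalar sketch of the result they compute (illustrative only; lane counts follow the number_of_f32_in_* constants in the header):

#include <cstddef>

// scalar model of reduce_vmm_to_scalar: sum the first n lanes into lane 0
static float horizontal_sum(const float* lanes, std::size_t n) {
    float acc = 0.0f;
    for (std::size_t i = 0; i < n; ++i)
        acc += lanes[i];
    return acc;
}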

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::generate() {
this->preamble();
mov(reg_src, ptr[abi_param1 + GET_OFF(src)]);
mov(reg_scale, ptr[abi_param1 + GET_OFF(scale)]);
mov(reg_dst, ptr[abi_param1 + GET_OFF(dst)]);
uni_vpxor(vmm_sum0, vmm_sum0, vmm_sum0);
uni_vpxor(vmm_sum1, vmm_sum1, vmm_sum1);
uni_vpxor(vmm_sum2, vmm_sum2, vmm_sum2);
uni_vpxor(vmm_sum3, vmm_sum3, vmm_sum3);
mov(reg_src_org, reg_src);

mov(reg_size, m_jcp.data_size / (vec_size * 4));
// x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
// sum(x^2)
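// main loop: unrolled 4x with independent accumulators to hide vfmadd231ps latency;
// note the dec/jnz structure always runs the body once, so this appears to assume
// data_size >= 4 * vec_size here (and >= vec_size for loop_mul below)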
align(16);
Xbyak::Label loop_4reg;
L(loop_4reg);
{
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 1);
vfmadd231ps(vmm_sum1, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 2);
vfmadd231ps(vmm_sum2, vmm_src, vmm_src);
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 3);
vfmadd231ps(vmm_sum3, vmm_src, vmm_src);

add(reg_src, vec_size * m_jcp.src_prc.size() * 4);
dec(reg_size);
jnz(loop_4reg);
}
// remaining 1 to 3 full vectors that did not fit the 4x-unrolled loop
for (size_t i = m_jcp.data_size / (vec_size * 4) * 4; i < m_jcp.data_size / vec_size; i++) {
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
add(reg_src, vec_size * m_jcp.src_prc.size());
}
// tail
if (m_jcp.data_size % vec_size) {
load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
}
vaddps(vmm_sum0, vmm_sum0, vmm_sum1);
vaddps(vmm_sum2, vmm_sum2, vmm_sum3);
vaddps(vmm_rsqrt, vmm_sum0, vmm_sum2);
reduce_vmm_to_scalar(vmm_rsqrt, vmm_sum0, vmm_sum1, vmm_sum3, vec_size);

// mean(x^2)
mov(reg_tmp.cvt32(), float2int(1.0f / m_jcp.data_size));
vmovd(xmm_tmp, reg_tmp.cvt32());
vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
// mean(x^2)+eps
mov(reg_tmp.cvt32(), float2int(m_jcp.eps));
vmovd(xmm_tmp, reg_tmp.cvt32());
vaddss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
// rsqrt(mean(x^2)+eps)
vrsqrtss(xmm_rsqrt, xmm_rsqrt, xmm_rsqrt);
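// note: vrsqrtss is an approximation (relative error on the order of 2^-12) and
// no Newton-Raphson refinement step follows, trading a little accuracy for speed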

// x * rsqrt(mean(x^2)+eps)
if (m_jcp.has_scale && m_jcp.scale_size == 1) {
// fold the single per-tensor scale into the scalar multiplier
vmovd(xmm_tmp, ptr[reg_scale]);
vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
}
vbroadcastss(vmm_rsqrt, xmm_rsqrt);
mov(reg_size, m_jcp.data_size / vec_size);
mov(reg_src, reg_src_org);
align(16);
Xbyak::Label loop_mul;
L(loop_mul);
{
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
vmulps(vmm_src, vmm_src, vmm_rsqrt);
if (m_jcp.has_scale && m_jcp.scale_size != 1) {
load(vmm_tmp, reg_scale, ov::element::f32, vec_size, false);
vmulps(vmm_src, vmm_src, vmm_tmp);
}
store(reg_dst, vmm_src, m_jcp.dst_prc, vec_size);

add(reg_src, vec_size * m_jcp.src_prc.size());
if (m_jcp.has_scale && m_jcp.scale_size != 1) {
add(reg_scale, vec_size * sizeof(float));
}
add(reg_dst, vec_size * m_jcp.dst_prc.size());
dec(reg_size);
jnz(loop_mul);
}
// tail
if (m_jcp.data_size % vec_size) {
load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
vmulps(vmm_src, vmm_src, vmm_rsqrt);
if (m_jcp.has_scale && m_jcp.scale_size != 1) {
load(vmm_tmp, reg_scale, ov::element::f32, m_jcp.data_size % vec_size, false);
vmulps(vmm_src, vmm_src, vmm_tmp);
}
store(reg_dst, vmm_src, m_jcp.dst_prc, m_jcp.data_size % vec_size);
}

this->postamble();
for (const auto& emitter : emitters) {
if (emitter.second)
emitter.second->emit_data();
}
}
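For readers following the control flow, here is a scalar reference of what generate() emits — a sketch assuming f32 input and output (the real kernel converts src_prc/dst_prc through the load/store emitters):

#include <cmath>
#include <cstddef>

void rms_norm_ref(const float* src, const float* scale, float* dst,
                  std::size_t data_size, std::size_t scale_size, float eps) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < data_size; ++i)
        sum += src[i] * src[i];                        // sum(x^2)
    const float mul = 1.0f / std::sqrt(sum / data_size + eps);
    for (std::size_t i = 0; i < data_size; ++i) {
        const float gamma = scale ? scale[scale_size == 1 ? 0 : i] : 1.0f;
        dst[i] = src[i] * mul * gamma;                 // x * rsqrt(mean(x^2)+eps) * gamma
    }
}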

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, ov::element::Type src_prc, const int& elt_num, bool fill, size_t offset) {
const auto seed = load_emitter_params(src_prc, ov::element::f32, elt_num, fill, "float_min").hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, ov::element::f32, elt_num, ov::element::f32, fill, "float_min"));
}
emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), offset}, {static_cast<size_t>(vmm_dst.getIdx())},
pool_aux_vmm_idxs, pool_aux_gpr_idxs);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, ov::element::Type dst_prc, const int& elt_num, size_t offset) {
const auto seed = store_emitter_params(ov::element::f32, dst_prc, elt_num).hash();
if (!emitters[seed]) {
emitters[seed].reset(new jit_store_emitter(this, isa, ov::element::f32, dst_prc, elt_num));
}
emitters[seed]->emit_code({static_cast<size_t>(vmm_src.getIdx()), offset}, {static_cast<size_t>(reg_dst.getIdx())},
pool_aux_vmm_idxs, pool_aux_gpr_idxs);
}
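load() and store() memoize their emitters: the parameter hash keys a cache, so each unique (precision, element count, fill) configuration is JIT-constructed once and reused on every later emit_code call. The pattern in isolation (Emitter and Params are hypothetical stand-ins, not OpenVINO types):

#include <memory>
#include <unordered_map>

template <typename Emitter, typename Params>
Emitter& get_or_create(std::unordered_map<size_t, std::unique_ptr<Emitter>>& cache,
                       const Params& p) {
    auto& slot = cache[p.hash()];  // default-constructs an empty slot on first lookup
    if (!slot)
        slot = std::make_unique<Emitter>(p);
    return *slot;
}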

template struct jit_rms_kernel<cpu_isa_t::avx512_core>;
template struct jit_rms_kernel<cpu_isa_t::avx2>;

} // namespace kernel
} // namespace intel_cpu
} // namespace ov
89 changes: 89 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp
@@ -0,0 +1,89 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "jit_kernel_base.hpp"

#if defined(OPENVINO_ARCH_X86_64)
#include "emitters/plugin/x64/jit_load_store_emitters.hpp"
#endif

namespace ov {
namespace intel_cpu {
namespace kernel {

struct jit_rms_compile_params {
ov::element::Type src_prc;
ov::element::Type dst_prc;
size_t data_size;
float eps;
bool has_scale;
size_t scale_size;
};

struct jit_rms_call_args {
const uint8_t* src;
const float* scale;
uint8_t* dst;
};

#if defined(OPENVINO_ARCH_X86_64)

template <dnnl::impl::cpu::x64::cpu_isa_t isa>
struct jit_rms_kernel : public JitKernel<jit_rms_compile_params, jit_rms_call_args> {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_rms_kernel)

static constexpr size_t vec_size = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen / sizeof(float);

explicit jit_rms_kernel(const jit_rms_compile_params& jcp) : JitKernel(jit_name(), jcp, isa) {}

private:
using Xmm = Xbyak::Xmm;
using Vmm = typename dnnl::impl::utils::conditional3<isa == dnnl::impl::cpu::x64::sse41, Xbyak::Xmm,
isa == dnnl::impl::cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;

void generate() override;
void load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, ov::element::Type src_prc, const int& elt_num, bool fill, size_t offset = 0);
void store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, ov::element::Type dst_prc, const int& elt_num, size_t offset = 0);

// from onednn
void reduce_zmm_to_ymm(const Xmm &acc, const Xmm &tmp);
void reduce_ymm_to_xmm(const Xmm &acc, const Xmm &tmp);
void reduce_xmm_to_scalar(const Xmm &acc, const Xmm &tmp, const std::size_t number_of_values_to_reduce = number_of_f32_in_xmm_);
void reduce_ymm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
const std::size_t number_of_values_to_reduce = number_of_f32_in_ymm_);
void reduce_vmm_to_scalar(const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
const Xbyak::Xmm &tmp3, const std::size_t number_of_values_to_reduce = number_of_f32_in_zmm_);
static constexpr std::size_t number_of_f32_in_xmm_ = 4;
static constexpr std::size_t number_of_f32_in_ymm_ = 8;
static constexpr std::size_t number_of_f32_in_zmm_ = 16;

const Vmm vmm_src = Vmm(0);
const Vmm vmm_sum0 = Vmm(2);
const Vmm vmm_rsqrt = Vmm(2);
const Xmm xmm_rsqrt = Xmm(2);
const Vmm vmm_sum1 = Vmm(3);
const Vmm vmm_tmp = Vmm(3);
const Xmm xmm_tmp = Xmm(3);
const Vmm vmm_sum2 = Vmm(4);
const Vmm vmm_sum3 = Vmm(5);
const Vmm vmm_dst = Vmm(6);
const Xbyak::Reg64 reg_src = r8;
const Xbyak::Reg64 reg_src_org = r13;
const Xbyak::Reg64 reg_scale = r10;
const Xbyak::Reg64 reg_size = r11;
const Xbyak::Reg64 reg_dst = r12;
const Xbyak::Reg64 reg_tmp = rdx;

std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;
const std::vector<size_t> pool_aux_gpr_idxs = { static_cast<size_t>(rax.getIdx()), static_cast<size_t>(r9.getIdx()) };
const std::vector<size_t> pool_aux_vmm_idxs = { 7 };
};
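A hypothetical usage sketch (the create/invoke calls are assumptions based on the JitKernel base class, not confirmed by this diff; buffer names are illustrative):

jit_rms_compile_params jcp{};
jcp.src_prc = ov::element::f32;
jcp.dst_prc = ov::element::f32;
jcp.data_size = 4096;      // elements along the normalized axis
jcp.eps = 1e-6f;
jcp.has_scale = true;
jcp.scale_size = 4096;     // per-channel gamma

auto kernel = std::make_unique<jit_rms_kernel<dnnl::impl::cpu::x64::avx512_core>>(jcp);
kernel->create_kernel();   // assumed JitKernel API

jit_rms_call_args args{reinterpret_cast<const uint8_t*>(src_row),   // hypothetical buffers
                       gamma_ptr,
                       reinterpret_cast<uint8_t*>(dst_row)};
(*kernel)(&args);          // invocation form assumed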

#endif

} // namespace kernel
} // namespace intel_cpu
} // namespace ov