Skip to content

Commit

Permalink
Change the op definition to use ov::RMS instead of RMSNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 committed Aug 22, 2024
1 parent e700e54 commit 0d504cc
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
_OPENVINO_OP_REG(RMS, ov::op::internal)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "openvino/op/reduce_mean.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/rms.hpp"
#include "transformations/utils/utils.hpp"
Expand Down Expand Up @@ -68,7 +69,9 @@ RMSFusion::RMSFusion() {
auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});

// compress RMS result
auto comp = wrap_type<ov::op::v0::Convert>({mul2});
auto convert = wrap_type<ov::op::v0::Convert>({mul2});

auto comp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul2, convert});

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/cpu_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"EmbeddingBagOffsets", Type::EmbeddingBagOffsets},
{"LLMMLP", Type::LLMMLP},
{"QKVProjection", Type::QKVProjection},
{"RMSNorm", Type::RMSNorm}
{"RMS", Type::RMS}
};
return type_to_name_tbl;
}
Expand Down Expand Up @@ -374,7 +374,7 @@ std::string NameFromType(const Type type) {
CASE(CausalMaskPreprocess);
CASE(LLMMLP);
CASE(QKVProjection);
CASE(RMSNorm);
CASE(RMS);
CASE(Unknown);
}
#undef CASE
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/cpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ enum class Type {
CausalMaskPreprocess,
LLMMLP,
QKVProjection,
RMSNorm
RMS
};

enum class Algorithm {
Expand Down
8 changes: 4 additions & 4 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void jit_rms_kernel<isa>::generate() {
vrsqrtss(xmm_rsqrt, xmm_rsqrt, xmm_rsqrt);

// x * rsqrt(mean(x^2)+eps)
if (m_jcp.has_scale && m_jcp.scale_size == 1) {
if (m_jcp.scale_size == 1) {
// rsqrt(mean(x^2)+eps)
vmovd(xmm_tmp, ptr[reg_scale]);
vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
Expand All @@ -181,14 +181,14 @@ void jit_rms_kernel<isa>::generate() {
{
load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
vmulps(vmm_src, vmm_src, vmm_rsqrt);
if (m_jcp.has_scale && m_jcp.scale_size != 1) {
if (m_jcp.scale_size != 1) {
load(vmm_tmp, reg_scale, ov::element::f32, vec_size, false);
vmulps(vmm_src, vmm_src, vmm_tmp);
}
store(reg_dst, vmm_src, m_jcp.dst_prc, vec_size);

add(reg_src, vec_size * m_jcp.src_prc.size());
if (m_jcp.has_scale && m_jcp.scale_size != 1) {
if (m_jcp.scale_size != 1) {
add(reg_scale, vec_size * sizeof(float));
}
add(reg_dst, vec_size * m_jcp.dst_prc.size());
Expand All @@ -199,7 +199,7 @@ void jit_rms_kernel<isa>::generate() {
if (m_jcp.data_size % vec_size) {
load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
vmulps(vmm_src, vmm_src, vmm_rsqrt);
if (m_jcp.has_scale && m_jcp.scale_size != 1) {
if (m_jcp.scale_size != 1) {
load(vmm_tmp, reg_scale, ov::element::f32, m_jcp.data_size % vec_size, false);
vmulps(vmm_src, vmm_src, vmm_tmp);
}
Expand Down
1 change: 0 additions & 1 deletion src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ struct jit_rms_compile_params {
ov::element::Type dst_prc;
size_t data_size;
float eps;
bool has_scale;
size_t scale_size;
};

Expand Down
75 changes: 25 additions & 50 deletions src/plugins/intel_cpu/src/nodes/rms_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
#include "onednn/dnnl.h"
#include "openvino/core/parallel.hpp"
#include "openvino/util/common_util.hpp"
#include "ov_ops/rms.hpp"
#include "shape_inference/custom/rms_norm.hpp"
#include "openvino/op/rms_norm.hpp"
#include "openvino/opsets/opset6.hpp"
#include "kernels/x64/rms_kernel.hpp"

Expand Down Expand Up @@ -97,22 +97,19 @@ static void execJitKernel(const std::shared_ptr<kernel::JitKernelBase>& ker, con
}

struct RMSNorm::RMSNormExecutor : public RMSNorm::Executor {
RMSNormExecutor(ov::element::Type precision, size_t data_size, size_t scale_size, float eps, bool has_scale) : m_precision(precision) {
RMSNormExecutor(ov::element::Type precision, size_t data_size, size_t scale_size, float eps) : m_precision(precision) {
jit_rms_compile_params jcp;
jcp.src_prc = precision;
jcp.dst_prc = precision;
jcp.data_size = data_size;
jcp.scale_size = scale_size;
jcp.eps = eps;
jcp.has_scale = has_scale;
m_kernel = createJitKernel(jcp);
}
void execute(const std::vector<MemoryPtr>& inputs, const MemoryPtr output) override {
auto src = inputs[0]->getDataAs<uint8_t>();
auto dst = output->getDataAs<uint8_t>();
float* scale = nullptr;
if (inputs.size() > 2)
scale = inputs[2]->getDataAs<float>();
float* scale = inputs[1]->getDataAs<float>();

const auto& src_strides = inputs[0]->getDescWithType<BlockedMemoryDesc>()->getStrides();
const auto& dst_strides = output->getDescWithType<BlockedMemoryDesc>()->getStrides();
Expand All @@ -136,9 +133,8 @@ RMSNorm::RMSNorm(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr c
if (!isSupportedOperation(op, errorMessage)) {
OPENVINO_THROW("CPU: " + errorMessage);
}
const auto rms = std::dynamic_pointer_cast<const ov::op::internal::RMSNorm>(op);
const auto rms = std::dynamic_pointer_cast<const ov::op::internal::RMS>(op);
m_eps = static_cast<float>(rms->get_epsilon());
m_has_scale = op->get_input_size() > 2;
}

void RMSNorm::initSupportedPrimitiveDescriptors() {
Expand All @@ -151,38 +147,26 @@ void RMSNorm::initSupportedPrimitiveDescriptors() {
impl_type = impl_desc_type::jit_avx512;
} else if (mayiuse(cpu::x64::avx2)) {
impl_type = impl_desc_type::jit_avx2;
} else if (mayiuse(cpu::x64::sse41)) {
impl_type = impl_desc_type::jit_sse42;
} else {
impl_type = impl_desc_type::ref;
}

if (m_has_scale) {
addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::i32}, {LayoutType::ncsp, ov::element::f32}},
{{LayoutType::ncsp, precision}},
impl_type);
} else {
addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::i32}},
{{LayoutType::ncsp, precision}},
impl_type);
}
addSupportedPrimDesc({{LayoutType::ncsp, precision}, {LayoutType::ncsp, ov::element::f32}},
{{LayoutType::ncsp, precision}},
impl_type);
}

void RMSNorm::createPrimitive() {
auto precision = getOriginalInputPrecisionAtPort(0);
auto data_dims = getSrcMemoryAtPort(0)->getDescWithType<BlockedMemoryDesc>()->getBlockDims();
auto has_scale = getOriginalInputsNumber() > 2;
size_t data_size = data_dims[data_dims.size() - 1];
size_t scale_size = 0;
if (has_scale) {
scale_size = getSrcMemoryAtPort(2)->getDescWithType<BlockedMemoryDesc>()->getBlockDims()[0];
}
size_t scale_size = shape_size(getSrcMemoryAtPort(1)->getDescWithType<BlockedMemoryDesc>()->getBlockDims());

RMSNormKey key = {precision, data_size, scale_size, static_cast<size_t>(dnnl::impl::float2int(m_eps))};

auto builder = [&](const RMSNormKey& key) -> std::shared_ptr<RMSNormExecutor> {
#ifdef OPENVINO_ARCH_X86_64
return std::make_shared<RMSNormExecutor>(precision, data_size, scale_size, m_eps, has_scale);
return std::make_shared<RMSNormExecutor>(precision, data_size, scale_size, m_eps);
#else
return nullptr;
#endif
Expand All @@ -209,8 +193,12 @@ void RMSNorm::execute(dnnl::stream strm) {

bool RMSNorm::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
const auto rms = std::dynamic_pointer_cast<const ov::op::internal::RMSNorm>(op);
const auto rms = std::dynamic_pointer_cast<const ov::op::internal::RMS>(op);
if (rms) {
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
errorMessage = "RMSNorm needs avx2+.";
return false;
}
// check the last dimension of data
auto data_pshape = op->input_value(0).get_partial_shape();
if (data_pshape.rank().is_dynamic()) {
Expand All @@ -226,34 +214,21 @@ bool RMSNorm::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, st
errorMessage = "RMSNorm data rank must be greater than 1.";
return false;
}
// check axes
auto axes_op = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(1));
if (!axes_op) {
errorMessage = "RMSNorm axes is expected as Constant.";
return false;
}
// axes should be 1d or scalar in spec
auto axes_vals = axes_op->cast_vector<int>();
if (axes_vals[0] != -1 && axes_vals[0] != data_rank - 1) {
errorMessage = "RMSNorm axes must be the last dimension.";
return false;
}

// check scale
if (op->get_input_size() > 2) {
if (op->get_input_partial_shape(2).rank().get_length() > 1) {
errorMessage = "RMSNorm scale must be 1D or scalar.";
return false;
}
if (op->get_input_partial_shape(2).is_dynamic()) {
errorMessage = "RMSNorm scale shape is not static.";
return false;
}
}
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
errorMessage = "RMSNorm needs avx2+.";
if (op->get_input_partial_shape(1).is_dynamic()) {
errorMessage = "RMSNorm scale shape is not static.";
return false;
}
auto scale_pshape = op->get_input_partial_shape(1);
if (scale_pshape.rank().get_length() > 1) {
for (int64_t i = 0; i < scale_pshape.rank().get_length() - 1; i++) {
if (scale_pshape[i] != 1) {
errorMessage = "RMSNorm scale shape must be [1,..., N].";
return false;
}
}
}
} else {
errorMessage = "Only RMSNorm operation is supported";
return false;
Expand Down
3 changes: 1 addition & 2 deletions src/plugins/intel_cpu/src/nodes/rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class RMSNorm : public Node {

void getSupportedDescriptors() override {}
bool created() const override {
return getType() == Type::RMSNorm;
return getType() == Type::RMS;
}
bool needPrepareParams() const override {
return false;
Expand All @@ -41,7 +41,6 @@ class RMSNorm : public Node {
friend struct RMSNormKey;

float m_eps = 0.0f;
bool m_has_scale = false;
};

} // namespace node
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ Node::NodesFactory::NodesFactory() : Factory("NodesFactory") {
INTEL_CPU_NODE(QKVProjection, Type::QKVProjection);
INTEL_CPU_NODE(MHA, Type::MHA);
INTEL_CPU_NODE(PagedAttention, Type::PagedAttention);
INTEL_CPU_NODE(RMSNorm, Type::RMSNorm);
INTEL_CPU_NODE(RMSNorm, Type::RMS);
#endif
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (C) 2020-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "decompose_rms_norm.hpp"
#include "itt.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/core/rt_info.hpp"
#include "ov_ops/rms.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

namespace ov {
namespace intel_cpu {

// Matcher pass that replaces the internal fused ov::op::internal::RMS op with
// an equivalent subgraph of standard ops:
//   RMS(x, scale) = x * scale / sqrt(ReduceMean(x^2, axis=-1) + eps)
// Registered with a transformation_callback so it only fires for RMS nodes the
// fused CPU implementation rejects (see the callback set at registration site).
DecomposeRMSNorm::DecomposeRMSNorm() {
    MATCHER_SCOPE(DecomposeRMSNorm);
    auto pattern_node = ov::pass::pattern::wrap_type<ov::op::internal::RMS>();

    matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto node = std::dynamic_pointer_cast<ov::op::internal::RMS>(
            pattern_to_output.at(pattern_node).get_node_shared_ptr());

        // transformation_callback(node) == true means the fused CPU node can
        // handle this RMS, so keep it fused and skip the decomposition.
        if (node == nullptr || transformation_callback(node)) {
            return false;
        }
        // Use input_value() (not get_input_node_shared_ptr()) so a producer
        // with multiple outputs keeps the correct output port.
        auto data = node->input_value(0);
        auto data_precision = node->get_input_element_type(0);
        auto scale = node->input_value(1);

        auto power_const = ov::opset10::Constant::create(data_precision, {}, std::vector<float>{2.f});
        auto power = std::make_shared<ov::opset10::Power>(data, power_const);
        // Reduce over the last axis only, keeping dims for broadcasting back.
        auto mean_axes = ov::opset10::Constant::create(ov::element::i32, ov::Shape{1}, {-1});
        auto mean = std::make_shared<ov::opset10::ReduceMean>(power, mean_axes, true);
        auto eps = ov::opset10::Constant::create(data_precision, {}, {node->get_epsilon()});
        auto add_eps = std::make_shared<ov::opset10::Add>(mean, eps);
        auto sqrt = std::make_shared<ov::opset10::Sqrt>(add_eps);
        // Power(y, -1) == 1/y, i.e. the reciprocal of the RMS denominator.
        auto div_const = ov::opset10::Constant::create(data_precision, {}, {-1});
        auto div = std::make_shared<ov::opset10::Power>(sqrt, div_const);
        auto mul1 = std::make_shared<ov::opset10::Multiply>(data, div);
        auto mul2 = std::make_shared<ov::opset10::Multiply>(scale, mul1);

        // Preserve the original node's identity on the decomposed subgraph so
        // debugging/serialization still maps back to the fused RMS op.
        mul2->set_friendly_name(node->get_friendly_name());
        ov::copy_runtime_info(node, {power, mean, add_eps, sqrt, div, mul1, mul2});

        ov::replace_node(node, mul2);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node, matcher_name);
    register_matcher(m, callback);
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (C) 2020-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/graph_rewrite.hpp"

namespace ov {
namespace intel_cpu {

// Matcher pass that decomposes the internal fused ov::op::internal::RMS op
// into a subgraph of standard ops (Power / ReduceMean / Add / Sqrt / Multiply).
// Intended as a fallback for RMS nodes the fused CPU kernel cannot execute;
// whether it fires is controlled by the transformation callback installed at
// the pass's registration site — TODO confirm against Transformations::PostLpt.
class DecomposeRMSNorm: public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("DecomposeRMSNorm", "0");
    DecomposeRMSNorm();
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "defs.hpp"

// Operations
#include "openvino/op/constant.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset2.hpp"
#include "openvino/opsets/opset3.hpp"
Expand Down Expand Up @@ -130,6 +129,7 @@
#include "transformations/cpu_opset/arm/pass/mish_decomposition.hpp"
#include "transformations/cpu_opset/arm/pass/convert_reduce_no_keep_dims.hpp"
#include "transformations/cpu_opset/common/pass/decompose_integer_divide.hpp"
#include "transformations/cpu_opset/common/pass/decompose_rms_norm.hpp"
#include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp"
#include "transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp"
#include "transformations/cpu_opset/common/pass/ngram_fusion.hpp"
Expand Down Expand Up @@ -856,12 +856,13 @@ void Transformations::PostLpt() {
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm);
CPU_SET_CALLBACK_X64(postLPTPassManager,
[](const std::shared_ptr<const ov::Node>& node) -> bool {
std::string errorMsg;
return node::RMSNorm::isSupportedOperation(node, errorMsg);
},
ov::pass::RMSFusion);
ov::intel_cpu::DecomposeRMSNorm);

// markup Rope Input when BF16/F16 inference.
if (one_of(inferencePrecision, ov::element::bf16, ov::element::f16))
Expand Down
Loading

0 comments on commit 0d504cc

Please sign in to comment.