Skip to content

Commit

Permalink
refactor mlp op spec; refactor ReduceAdd2bh kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
xczhai committed Nov 25, 2024
1 parent 4b4d8eb commit db91b52
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 56 deletions.
41 changes: 20 additions & 21 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,21 +616,20 @@ void ReduceAdd2bh::generate() {
vmovups(zmm3, ptr[src1 + loop_i * 4 + 16 * 4]);
vaddps(zmm0, zmm0, zmm1);
vaddps(zmm2, zmm2, zmm3);
if (m_out_f32 && m_to_f16) {
if (m_output_type == ov::element::f32) {
vmovups(ptr[dst + loop_i * 4], zmm0);
vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else if (m_output_type == ov::element::f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else if (m_output_type == ov::element::bf16) {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
} else {
// convert fp32 to fp16 or bf16
if (m_to_f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
}
OPENVINO_THROW("ReduceAdd2hb cannot be generated with precision " + m_output_type.to_string());
}
}
add(loop_i, 32);
Expand All @@ -654,20 +653,20 @@ void ReduceAdd2bh::generate() {
{
vmovups(zmm0, ptr[src0 + loop_i * 4]);
vmovups(zmm2, ptr[src0 + loop_i * 4 + 16 * 4]);
if (m_out_f32 && m_to_f16) {
if (m_output_type == ov::element::f32) {
vmovups(ptr[dst + loop_i * 4], zmm0);
vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else if (m_output_type == ov::element::f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else if (m_output_type == ov::element::bf16) {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
} else {
if (m_to_f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
}
OPENVINO_THROW("ReduceAdd2hb cannot be generated with precision " + m_output_type.to_string());
}
}
add(loop_i, 32);
Expand Down
36 changes: 21 additions & 15 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,44 +501,50 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(ReduceAdd2bh)

const bool m_do_reduce2;
const bool m_to_f16;
const bool m_out_f32;
ReduceAdd2bh(bool do_reduce2, bool to_f16, bool out_f32 = false) :
jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_to_f16(to_f16), m_out_f32(out_f32) {
const ov::element::Type m_output_type;
ReduceAdd2bh(bool do_reduce2, const ov::element::Type output_type = ov::element::undefined) :
jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_output_type(output_type) {
create_kernel();
}

void generate() override;

// add two float input eltwise and convert to bf16 : ConvertFP32toBF16(src0 + src1)
void call(float * src0, float * src1, size_t src_stride, void * pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
if (m_out_f32) {
auto* dst = reinterpret_cast<float*>(pf16_dst);
void call(float * src0, float * src1, size_t src_stride, void * out_dst, size_t dst_stride, int num_rows, int num_cols) {
if (m_output_type == ov::element::f32) {
auto* dst = reinterpret_cast<float*>(out_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
// the prefetch distance is increased to ensure by the time store happens
// prefetch has done and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, src1, dst, prefetch_dst, num_cols);
}
} else {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
} else if (one_of(m_output_type, ov::element::bf16, ov::element::f16)) {
// one_of(m_output_type, ov::element::f16, ov::element::bf16)
auto* dst = reinterpret_cast<int16_t*>(out_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
// the prefetch distance is increased to ensure by the time store happens
// prefetch has done and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, src1, dst, prefetch_dst, num_cols);
}
} else {
OPENVINO_THROW("ReduceAdd2bh call with precision " + m_output_type.to_string());
}
}

// convert tensor to bf16: ConvertFP32toBF16(src0)
void call(float * src0, size_t src_stride, void * pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) {
// the prefetch distance is increased to ensure by the time store happens
// prefetch has done and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, dst, prefetch_dst, num_cols);
if (one_of(m_output_type, ov::element::bf16, ov::element::f16)) {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) {
// the prefetch distance is increased to ensure by the time store happens
// prefetch has done and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, dst, prefetch_dst, num_cols);
}
} else {
OPENVINO_THROW("ReduceAdd2bh call with precision " + m_output_type.to_string());
}
}
};
Expand Down
24 changes: 14 additions & 10 deletions src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ class LinearKsplit2 {
void run(uint8_t* pA, int strideA, int M, U* dstC, int strideC,
const LLMMLPNode::Config& config,
MatrixDynQuantPerRow& src_dq,
float * w_scale) {
static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<T, ov::float16>::value, config.tail_f32);
float * w_scale,
ov::element::Type output_type) {
static ReduceAdd2bh jit_reduce2cvt(true, output_type);

ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
auto& work = works[ithr];
Expand Down Expand Up @@ -438,6 +439,7 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
int M = shape_size(ishape) / ishape[ishape.size() - 1];

auto output = m_pnode->getDstMemoryAtPort(0);
auto outPrecision = output->getPrecision();
auto* dstC = output->getDataAs<U>();
const auto& dstStrides = output->getDescWithType<BlockedMemoryDesc>()->getStrides();
int strideC = dstStrides[dstStrides.size() - 2] * sizeof(U);
Expand Down Expand Up @@ -481,7 +483,8 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
down.run(p_up_act, stride_up_act, BM, dstC, strideC,
m_config,
m_quant_up_act,
p_w_scale_down);
p_w_scale_down,
outPrecision);

m += BM;
pA += BM * strideA_in_bytes;
Expand Down Expand Up @@ -519,6 +522,7 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
std::vector<PortConfigurator> outPortConfigs;

auto rtPrecision = getOriginalInputPrecisionAtPort(0);
auto outPrecision = getOriginalOutputPrecisionAtPort(0);

if (rtPrecision == ov::element::f32) {
// fallback to supported precision if possible
Expand Down Expand Up @@ -546,7 +550,7 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(6), false, -1); // down_weight scales per OC

// initialize output port
outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);
outPortConfigs.emplace_back(LayoutType::ncsp, outPrecision, getOutputShapeAtPort(0), false, -1);
} else {
auto weightPrecision = ov::element::f16;

Expand All @@ -557,31 +561,31 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
inPortConfigs.emplace_back(LayoutType::ncsp, weightPrecision, getInputShapeAtPort(3), false, -1); // down

// initialize output port
auto outPrecision = m_mlp_config.tail_f32 ? ov::element::f32 : rtPrecision;
outPortConfigs.emplace_back(LayoutType::ncsp, outPrecision, getOutputShapeAtPort(0), false, -1);
}
addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}

void LLMMLP::createPrimitive() {
auto rtPrecision = getInputPrecisions()[0];
auto outPrecision = getOutputPrecisions()[0];
#ifdef OPENVINO_ARCH_X86_64
if (rtPrecision == ov::element::bf16) {
if (m_mlp_config.tail_f32) {
if (outPrecision == ov::element::f32) {
m_executor = std::make_shared<Executor<ov::bfloat16, float>>(this, m_mlp_config, context->getScratchPad());
} else {
} else if (outPrecision == ov::element::bf16) {
m_executor = std::make_shared<Executor<ov::bfloat16, ov::bfloat16>>(this, m_mlp_config, context->getScratchPad());
}
} else if (rtPrecision == ov::element::f16) {
if (m_mlp_config.tail_f32) {
if (outPrecision == ov::element::f32) {
m_executor = std::make_shared<Executor<ov::float16, float>>(this, m_mlp_config, context->getScratchPad());
} else {
} else if (outPrecision == ov::element::f16) {
m_executor = std::make_shared<Executor<ov::float16, ov::float16>>(this, m_mlp_config, context->getScratchPad());
}
}
#endif
if (!m_executor) {
OPENVINO_THROW("LLMMLP Executor creation fails with precision " + rtPrecision.to_string());
OPENVINO_THROW("LLMMLP Executor creation fails with runtime precision " + rtPrecision.to_string() + ", output precision " + outPrecision.to_string());
}
}

Expand Down
12 changes: 11 additions & 1 deletion src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,21 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase {

WeightBuffer wbuffer;

ov::element::Type output_type = ov::element::undefined;

Executor(QKVProjection * pnode, DnnlScratchPadPtr scrachPad) : m_node(pnode), m_scrachPad(scrachPad) {
PlainTensor w0(pnode->getSrcMemoryAtPort(1));
PlainTensor w1(pnode->getSrcMemoryAtPort(2));
PlainTensor w2(pnode->getSrcMemoryAtPort(3));

if (std::is_same<T, ov::float16>::value) {
output_type = ov::element::f16;
} else if (std::is_same<T, ov::bfloat16>::value) {
output_type = ov::element::bf16;
} else {
OPENVINO_THROW("QKVProjection Executor creation fails with output precision " + std::string(typeid(T).name()));
}

// in quantized mode, weights are already quantized in per-OC mode into INT8
// and activations will be dynamically per-token quantized and using AMX-INT8 to get the result
bool quantized_int8 = m_node->m_config.quantized;
Expand Down Expand Up @@ -187,7 +197,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase {
}

void execute() override {
static ReduceAdd2bh jit_cvt(false, std::is_same<T, ov::float16>::value);
static ReduceAdd2bh jit_cvt(false, output_type);

auto input = m_node->getSrcMemoryAtPort(0);
const auto& ishape = input->getStaticDims();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void LLMMLPNode::validate_and_infer_types() {

auto oshape = ishape;
oshape[oshape.size() - 1] = w_down_shape[0];
auto otype = m_config.tail_f32 ? ov::element::f32 : itype;
auto otype = m_output_type == ov::element::undefined ? itype : m_output_type;
set_output_type(0, otype, oshape);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class LLMMLPNode : public ov::op::Op {
// 1: gate_proj
// 2: up_proj
// 3: down_proj
LLMMLPNode(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) {
LLMMLPNode(const OutputVector& args, const Config& cfg, const ov::element::Type output_type = ov::element::undefined) : Op(args), m_config(cfg), m_output_type(output_type) {
m_args = args;
validate_and_infer_types();
}
Expand All @@ -56,9 +56,14 @@ class LLMMLPNode : public ov::op::Op {
return m_args;
}

ov::element::Type get_output_type() const {
return m_output_type;
}

private:
Config m_config;
OutputVector m_args;
ov::element::Type m_output_type;
};

} // namespace intel_cpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/cpu_opset/x64/op/llm_mlp.hpp"

/*
*/

using namespace ov;
using namespace ov::pass::pattern;

Expand All @@ -41,11 +38,9 @@ intel_cpu::MLPFuseConvert::MLPFuseConvert() {
}

OutputVector args = mlp_node->get_args();
auto cfg = mlp_node->get_config();

cfg.tail_f32 = true;
const auto cfg = mlp_node->get_config();

auto new_mlp = std::make_shared<ov::intel_cpu::LLMMLPNode>(args, cfg);
auto new_mlp = std::make_shared<ov::intel_cpu::LLMMLPNode>(args, cfg, ov::element::f32);

copy_runtime_info(m_cvt, new_mlp);
ov::replace_node(m_cvt, new_mlp);
Expand Down

0 comments on commit db91b52

Please sign in to comment.