diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
index a7babe7db2fc23..9f4f1d03fcea13 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
@@ -616,21 +616,20 @@ void ReduceAdd2bh::generate() {
             vmovups(zmm3, ptr[src1 + loop_i * 4 + 16 * 4]);
             vaddps(zmm0, zmm0, zmm1);
             vaddps(zmm2, zmm2, zmm3);
-            if (m_out_f32 && m_to_f16) {
+            if (m_output_type == ov::element::f32) {
                 vmovups(ptr[dst + loop_i * 4], zmm0);
                 vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
                 prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+            } else if (m_output_type == ov::element::f16) {
+                vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
+                vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
+                prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+            } else if (m_output_type == ov::element::bf16) {
+                vcvtne2ps2bf16(zmm4, zmm2, zmm0);
+                prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+                vmovups(ptr[dst + loop_i * 2], zmm4);
             } else {
-                // convert fp32 to fp16 or bf16
-                if (m_to_f16) {
-                    vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
-                    vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
-                    prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-                } else {
-                    vcvtne2ps2bf16(zmm4, zmm2, zmm0);
-                    prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-                    vmovups(ptr[dst + loop_i * 2], zmm4);
-                }
+                OPENVINO_THROW("ReduceAdd2bh cannot be generated with precision " + m_output_type.to_string());
             }
         }
         add(loop_i, 32);
@@ -654,20 +653,20 @@ void ReduceAdd2bh::generate() {
         {
             vmovups(zmm0, ptr[src0 + loop_i * 4]);
             vmovups(zmm2, ptr[src0 + loop_i * 4 + 16 * 4]);
-            if (m_out_f32 && m_to_f16) {
+            if (m_output_type == ov::element::f32) {
                 vmovups(ptr[dst + loop_i * 4], zmm0);
                 vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
                 prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+            } else if (m_output_type == ov::element::f16) {
+                vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
+                vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
+                prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+            } else if (m_output_type == ov::element::bf16) {
+                vcvtne2ps2bf16(zmm4, zmm2, zmm0);
+                prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+                vmovups(ptr[dst + loop_i * 2], zmm4);
             } else {
-                if (m_to_f16) {
-                    vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
-                    vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
-                    prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-                } else {
-                    vcvtne2ps2bf16(zmm4, zmm2, zmm0);
-                    prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-                    vmovups(ptr[dst + loop_i * 2], zmm4);
-                }
+                OPENVINO_THROW("ReduceAdd2bh cannot be generated with precision " + m_output_type.to_string());
             }
         }
         add(loop_i, 32);
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
index 8f7b5b604ce3d3..a2b919ba6c8de5 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
@@ -501,44 +501,50 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(ReduceAdd2bh)
     const bool m_do_reduce2;
-    const bool m_to_f16;
-    const bool m_out_f32;
-    ReduceAdd2bh(bool do_reduce2, bool to_f16, bool out_f32 = false) :
-        jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_to_f16(to_f16), m_out_f32(out_f32) {
+    const ov::element::Type m_output_type;
+    ReduceAdd2bh(bool do_reduce2, const ov::element::Type output_type = ov::element::undefined) :
+        jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_output_type(output_type) {
         create_kernel();
     }
 
     void generate() override;
 
     // add two float input eltwise and convert to bf16 : ConvertFP32toBF16(src0 + src1)
-    void call(float * src0, float * src1, size_t src_stride, void * pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
-        if (m_out_f32) {
-            auto* dst = reinterpret_cast<float*>(pf16_dst);
+    void call(float * src0, float * src1, size_t src_stride, void * out_dst, size_t dst_stride, int num_rows, int num_cols) {
+        if (m_output_type == ov::element::f32) {
+            auto* dst = reinterpret_cast<float*>(out_dst);
             for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
                 // the prefetch distance is increased to ensure by the time store happens
                 // prefetch has done and no HW prefetcher is triggered
                 auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
                 (*this)(src0, src1, dst, prefetch_dst, num_cols);
             }
-        } else {
-            auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
+        } else if (one_of(m_output_type, ov::element::bf16, ov::element::f16)) {
+            // f16 and bf16 outputs share the same 16-bit store path
+            auto* dst = reinterpret_cast<int16_t*>(out_dst);
             for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
                 // the prefetch distance is increased to ensure by the time store happens
                 // prefetch has done and no HW prefetcher is triggered
                 auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
                 (*this)(src0, src1, dst, prefetch_dst, num_cols);
             }
+        } else {
+            OPENVINO_THROW("ReduceAdd2bh cannot be called with output precision " + m_output_type.to_string());
         }
     }
 
     // convert tensor to bf16: ConvertFP32toBF16(src0)
     void call(float * src0, size_t src_stride, void * pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
-        auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
-        for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) {
-            // the prefetch distance is increased to ensure by the time store happens
-            // prefetch has done and no HW prefetcher is triggered
-            auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
-            (*this)(src0, dst, prefetch_dst, num_cols);
+        if (one_of(m_output_type, ov::element::bf16, ov::element::f16)) {
+            auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
+            for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) {
+                // the prefetch distance is increased to ensure by the time store happens
+                // prefetch has done and no HW prefetcher is triggered
+                auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
+                (*this)(src0, dst, prefetch_dst, num_cols);
+            }
+        } else {
+            OPENVINO_THROW("ReduceAdd2bh cannot be called with output precision " + m_output_type.to_string());
         }
     }
 };
diff --git a/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp b/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
index 6c9a214be32236..0184e51ee61f85 100644
--- a/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
+++ b/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
@@ -122,8 +122,9 @@ class LinearKsplit2 {
     void run(uint8_t* pA, int strideA, int M, U* dstC, int strideC, const LLMMLPNode::Config& config,
              MatrixDynQuantPerRow& src_dq,
-             float * w_scale) {
-        static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<U, ov::float16>::value, config.tail_f32);
+             float * w_scale,
+             ov::element::Type output_type) {
+        static ReduceAdd2bh jit_reduce2cvt(true, output_type);
 
         ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
             auto& work = works[ithr];
@@ -438,6 +439,7 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
         int M = shape_size(ishape) / ishape[ishape.size() - 1];
 
         auto output = m_pnode->getDstMemoryAtPort(0);
+        auto outPrecision = output->getPrecision();
         auto* dstC = output->getDataAs<U>();
         const auto& dstStrides = output->getDescWithType<BlockedMemoryDesc>()->getStrides();
         int strideC = dstStrides[dstStrides.size() - 2] * sizeof(U);
@@ -481,7 +483,8 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
             down.run(p_up_act, stride_up_act, BM, dstC, strideC, m_config,
                      m_quant_up_act,
-                     p_w_scale_down);
+                     p_w_scale_down,
+                     outPrecision);
 
             m += BM;
             pA += BM * strideA_in_bytes;
@@ -519,6 +522,7 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
     std::vector<PortConfigurator> outPortConfigs;
 
     auto rtPrecision = getOriginalInputPrecisionAtPort(0);
+    auto outPrecision = getOriginalOutputPrecisionAtPort(0);
 
     if (rtPrecision == ov::element::f32) {
         // fallback to supported precision if possible
@@ -546,7 +550,7 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
         inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(6), false, -1);  // down_weight scales per OC
 
         // initialize output port
-        outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);
+        outPortConfigs.emplace_back(LayoutType::ncsp, outPrecision, getOutputShapeAtPort(0), false, -1);
     } else {
         auto weightPrecision = ov::element::f16;
 
@@ -557,7 +561,6 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
         inPortConfigs.emplace_back(LayoutType::ncsp, weightPrecision, getInputShapeAtPort(3), false, -1);  // down
 
         // initialize output port
-        auto outPrecision = m_mlp_config.tail_f32 ? ov::element::f32 : rtPrecision;
         outPortConfigs.emplace_back(LayoutType::ncsp, outPrecision, getOutputShapeAtPort(0), false, -1);
     }
     addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
@@ -565,23 +568,24 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
 
 void LLMMLP::createPrimitive() {
     auto rtPrecision = getInputPrecisions()[0];
+    auto outPrecision = getOutputPrecisions()[0];
 
 #ifdef OPENVINO_ARCH_X86_64
     if (rtPrecision == ov::element::bf16) {
-        if (m_mlp_config.tail_f32) {
+        if (outPrecision == ov::element::f32) {
             m_executor = std::make_shared<Executor<ov::bfloat16, float>>(this, m_mlp_config, context->getScratchPad());
-        } else {
+        } else if (outPrecision == ov::element::bf16) {
             m_executor = std::make_shared<Executor<ov::bfloat16, ov::bfloat16>>(this, m_mlp_config, context->getScratchPad());
         }
     } else if (rtPrecision == ov::element::f16) {
-        if (m_mlp_config.tail_f32) {
+        if (outPrecision == ov::element::f32) {
             m_executor = std::make_shared<Executor<ov::float16, float>>(this, m_mlp_config, context->getScratchPad());
-        } else {
+        } else if (outPrecision == ov::element::f16) {
             m_executor = std::make_shared<Executor<ov::float16, ov::float16>>(this, m_mlp_config, context->getScratchPad());
         }
     }
 #endif
     if (!m_executor) {
-        OPENVINO_THROW("LLMMLP Executor creation fails with precision " + rtPrecision.to_string());
+        OPENVINO_THROW("LLMMLP Executor cannot be created for runtime precision " + rtPrecision.to_string() + " and output precision " + outPrecision.to_string());
     }
 }
diff --git a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
index 00c8b6f9b17c0b..c1e3d9ee1335c2 100644
--- a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
+++ b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
@@ -66,11 +66,21 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase {
     WeightBuffer wbuffer;
 
+    ov::element::Type output_type = ov::element::undefined;
+
     Executor(QKVProjection * pnode, DnnlScratchPadPtr scrachPad) : m_node(pnode), m_scrachPad(scrachPad) {
         PlainTensor w0(pnode->getSrcMemoryAtPort(1));
         PlainTensor w1(pnode->getSrcMemoryAtPort(2));
         PlainTensor w2(pnode->getSrcMemoryAtPort(3));
 
+        if (std::is_same<T, ov::float16>::value) {
+            output_type = ov::element::f16;
+        } else if (std::is_same<T, ov::bfloat16>::value) {
+            output_type = ov::element::bf16;
+        } else {
+            OPENVINO_THROW("QKVProjection Executor cannot be created for output type " + std::string(typeid(T).name()));
+        }
+
         // in quantized mode, weights are already quantized in per-OC mode into INT8
         // and activations will be dynamically per-token quantized and using AMX-INT8 to get the result
         bool quantized_int8 = m_node->m_config.quantized;
@@ -187,7 +197,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase {
     }
 
     void execute() override {
-        static ReduceAdd2bh jit_cvt(false, std::is_same<T, ov::float16>::value);
+        static ReduceAdd2bh jit_cvt(false, output_type);
 
         auto input = m_node->getSrcMemoryAtPort(0);
         const auto& ishape = input->getStaticDims();
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.cpp
index a5f5bf7bc67183..1fa191ad7af97b 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.cpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.cpp
@@ -56,7 +56,7 @@ void LLMMLPNode::validate_and_infer_types() {
     auto oshape = ishape;
     oshape[oshape.size() - 1] = w_down_shape[0];
 
-    auto otype = m_config.tail_f32 ? ov::element::f32 : itype;
+    auto otype = m_output_type == ov::element::undefined ? itype : m_output_type;
     set_output_type(0, otype, oshape);
 }
 
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.hpp
index a7d7566902defc..4f8cb8f66d1ae4 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.hpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/op/llm_mlp.hpp
@@ -33,7 +33,7 @@ class LLMMLPNode : public ov::op::Op {
     // 1: gate_proj
    // 2: up_proj
     // 3: down_proj
-    LLMMLPNode(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) {
+    LLMMLPNode(const OutputVector& args, const Config& cfg, const ov::element::Type output_type = ov::element::undefined) : Op(args), m_config(cfg), m_output_type(output_type) {
         m_args = args;
         validate_and_infer_types();
     }
@@ -56,9 +56,14 @@ class LLMMLPNode : public ov::op::Op {
         return m_args;
     }
 
+    ov::element::Type get_output_type() const {
+        return m_output_type;
+    }
+
 private:
     Config m_config;
     OutputVector m_args;
+    ov::element::Type m_output_type;
 };
 
 }  // namespace intel_cpu
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fuse_convert.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fuse_convert.cpp
index 53c11ef4945e04..741cdb788e3f21 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fuse_convert.cpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/x64/pass/mlp_fuse_convert.cpp
@@ -12,9 +12,6 @@
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "transformations/cpu_opset/x64/op/llm_mlp.hpp"
 
-/*
- */
-
 using namespace ov;
 using namespace ov::pass::pattern;
 
@@ -41,11 +38,9 @@ intel_cpu::MLPFuseConvert::MLPFuseConvert() {
         }
 
         OutputVector args = mlp_node->get_args();
-        auto cfg = mlp_node->get_config();
-
-        cfg.tail_f32 = true;
+        const auto cfg = mlp_node->get_config();
 
-        auto new_mlp = std::make_shared<LLMMLPNode>(args, cfg);
+        auto new_mlp = std::make_shared<LLMMLPNode>(args, cfg, ov::element::f32);
 
         copy_runtime_info(m_cvt, new_mlp);
         ov::replace_node(m_cvt, new_mlp);
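
Usage note (illustrative, not part of the patch): a minimal sketch, assuming the intel_cpu plugin build environment and the reworked ReduceAdd2bh interface above, of how a caller now selects the destination precision once at kernel construction and keeps a single call() path for f32/f16/bf16 outputs. The helper name reduce_add_rows, its parameters, the include path, and the ov::intel_cpu namespace qualification are assumptions for illustration only.

// Assumed include path inside the intel_cpu plugin source tree.
#include "nodes/kernels/x64/mlp_kernel.hpp"

// Hypothetical caller: sums two fp32 row blocks and stores the result in the
// precision chosen at construction time, mirroring what LinearKsplit2::run
// does after this patch. A real caller would cache the generated kernel
// (the patch uses a function-local static) instead of re-JITting per call.
void reduce_add_rows(float* src0, float* src1, size_t src_stride,
                     void* dst, size_t dst_stride, int rows, int cols,
                     ov::element::Type out_prec) {
    // An unsupported precision now throws during kernel generation instead of
    // silently falling through to the bf16 path as the removed code did.
    ov::intel_cpu::ReduceAdd2bh jit(/*do_reduce2=*/true, out_prec);
    jit.call(src0, src1, src_stride, dst, dst_stride, rows, cols);
}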