[CPU] remove extra convert for fp16 #26755

Open · wants to merge 18 commits into master · Changes from 11 commits
39 changes: 26 additions & 13 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
@@ -616,14 +616,21 @@ void ReduceAdd2bh::generate() {
vmovups(zmm3, ptr[src1 + loop_i * 4 + 16 * 4]);
vaddps(zmm0, zmm0, zmm1);
vaddps(zmm2, zmm2, zmm3);
if (m_to_f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
if (m_out_f32 && m_to_f16) {
Contributor (usstq):

it seems that m_out_f32 should override the m_to_f16 flag? I mean, once a convert to fp32 is fused, it should always store the f32 result, right? So `if (m_out_f32)` should be enough, right?

Contributor Author:

@usstq Unified this logic: replaced m_to_f16 with m_output_type to mark the possible output precision.
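A sketch of that unification as described (not the PR's final code; the enum and function names are illustrative): a single output-precision field selects among the three store sequences shown in this hunk.

```cpp
// Illustrative sketch only: m_output_type replaces the m_to_f16/m_out_f32
// pair, so each store site reduces to one switch. The Xbyak store sequences
// are the ones from this hunk, reproduced as comments.
enum class OutputType { f32, f16, bf16 };

inline void emit_store(OutputType m_output_type) {
    switch (m_output_type) {
    case OutputType::f32:
        // vmovups(ptr[dst + loop_i * 4], zmm0);
        // vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
        break;
    case OutputType::f16:
        // vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
        // vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
        break;
    case OutputType::bf16:
        // vcvtne2ps2bf16(zmm4, zmm2, zmm0);
        // vmovups(ptr[dst + loop_i * 2], zmm4);
        break;
    }
    // prefetchwt1(ptr[prefetch_dst + loop_i * 2]);  // common to all three paths
}
```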

vmovups(ptr[dst + loop_i * 4], zmm0);
vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
} else {
// convert fp32 to fp16 or bf16
if (m_to_f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
}
}
}
add(loop_i, 32);
@@ -647,14 +654,20 @@
{
vmovups(zmm0, ptr[src0 + loop_i * 4]);
vmovups(zmm2, ptr[src0 + loop_i * 4 + 16 * 4]);
if (m_to_f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
if (m_out_f32 && m_to_f16) {
Contributor (usstq):

same as above

Contributor Author:

@usstq Updated as above.

vmovups(ptr[dst + loop_i * 4], zmm0);
vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
if (m_to_f16) {
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
} else {
vcvtne2ps2bf16(zmm4, zmm2, zmm0);
prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
vmovups(ptr[dst + loop_i * 2], zmm4);
}
}
}
add(loop_i, 32);
26 changes: 19 additions & 7 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
@@ -502,20 +502,32 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator {

const bool m_do_reduce2;
const bool m_to_f16;
ReduceAdd2bh(bool do_reduce2, bool to_f16) : jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_to_f16(to_f16) {
const bool m_out_f32;
ReduceAdd2bh(bool do_reduce2, bool to_f16, bool out_f32 = false) :
jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_to_f16(to_f16), m_out_f32(out_f32) {
create_kernel();
}

void generate() override;

// add two float inputs elementwise and convert the sum to bf16/f16, or store it as f32: dst = Convert(src0 + src1)
void call(float * src0, float * src1, size_t src_stride, void * pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
// the prefetch distance is increased to ensure that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, src1, dst, prefetch_dst, num_cols);
if (m_out_f32) {
auto* dst = reinterpret_cast<float*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
// the prefetch distance is increased to ensure that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, src1, dst, prefetch_dst, num_cols);
}
} else {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
// the prefetch distance is increased to ensure that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, src1, dst, prefetch_dst, num_cols);
}
}
}
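The two loops above differ only in the destination element type. A minimal sketch of folding them into one member template of ReduceAdd2bh (illustrative; `run_rows` is a hypothetical name, not the PR's code):

```cpp
// Hypothetical member template of ReduceAdd2bh (sketch): both branches of
// call() collapse onto it, differing only in the destination type DstT.
template <typename DstT>
void run_rows(float* src0, float* src1, size_t src_stride,
              DstT* dst, size_t dst_stride, int num_rows, int num_cols) {
    for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
        // keep the prefetch target two rows ahead of the store,
        // as the comment above explains
        auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : dst;
        (*this)(src0, src1, dst, prefetch_dst, num_cols);
    }
}
// call() would then reduce to:
//   if (m_out_f32) run_rows(src0, src1, src_stride, static_cast<float*>(pf16_dst),   dst_stride, num_rows, num_cols);
//   else           run_rows(src0, src1, src_stride, static_cast<int16_t*>(pf16_dst), dst_stride, num_rows, num_cols);
```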

39 changes: 24 additions & 15 deletions src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
@@ -27,7 +27,7 @@ namespace node {

#if defined(OPENVINO_ARCH_X86_64)

template<typename T>
template<typename T, typename U>
class LinearKsplit2 {
public:
std::vector<Work> works;
@@ -119,11 +119,11 @@ class LinearKsplit2 {
DEBUG_LOG(" setup is done. weight @ ", static_cast<void*>(p_weight));
}

void run(uint8_t* pA, int strideA, int M, T* dstC, int strideC,
void run(uint8_t* pA, int strideA, int M, U* dstC, int strideC,
const LLMMLPNode::Config& config,
MatrixDynQuantPerRow& src_dq,
float * w_scale) {
static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<T, ov::float16>::value);
static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<T, ov::float16>::value, config.tail_f32);

ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
auto& work = works[ithr];
@@ -311,7 +311,7 @@ class LinearGateUp {
int m_threads_num = 0;
};

template<typename T>
template<typename T, typename U>
struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
LLMMLP* m_pnode;
const LLMMLPNode::Config m_config;
@@ -320,7 +320,7 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
uint8_t* m_scratch_base = nullptr;

LinearGateUp<T> gate_up;
LinearKsplit2<T> down;
LinearKsplit2<T, U> down;
int m_N;
int m_M = 0;

@@ -438,9 +438,9 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
int M = shape_size(ishape) / ishape[ishape.size() - 1];

auto output = m_pnode->getDstMemoryAtPort(0);
auto* dstC = output->getDataAs<T>();
auto* dstC = output->getDataAs<U>();
const auto& dstStrides = output->getDescWithType<BlockedMemoryDesc>()->getStrides();
int strideC = dstStrides[dstStrides.size() - 2] * sizeof(T);
int strideC = dstStrides[dstStrides.size() - 2] * sizeof(U);

float* p_w_scale_down = nullptr;
if (m_config.down_quantized) {
@@ -479,21 +479,21 @@
}

down.run(p_up_act, stride_up_act, BM, dstC, strideC,
m_config,
m_quant_up_act,
p_w_scale_down);
m_config,
m_quant_up_act,
p_w_scale_down);

m += BM;
pA += BM * strideA_in_bytes;
dstC += BM * strideC / sizeof(T);
dstC += BM * strideC / sizeof(U);
}
}

private:
size_t m_threads_num = 0lu;
};
#else
template<typename T>
template<typename T, typename U>
struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
Executor(LLMMLP* pnode, const LLMMLPNode::Config& config, DnnlScratchPadPtr scrachPad) {}
void execute() {}
@@ -557,7 +557,8 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
inPortConfigs.emplace_back(LayoutType::ncsp, weightPrecision, getInputShapeAtPort(3), false, -1); // down

// initialize output port
outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);
auto outPrecision = m_mlp_config.tail_f32 ? ov::element::f32 : rtPrecision;
outPortConfigs.emplace_back(LayoutType::ncsp, outPrecision, getOutputShapeAtPort(0), false, -1);
}
addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}
@@ -566,9 +567,17 @@ void LLMMLP::createPrimitive() {
auto rtPrecision = getInputPrecisions()[0];
#ifdef OPENVINO_ARCH_X86_64
if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<Executor<ov::bfloat16>>(this, m_mlp_config, context->getScratchPad());
if (m_mlp_config.tail_f32) {
m_executor = std::make_shared<Executor<ov::bfloat16, float>>(this, m_mlp_config, context->getScratchPad());
} else {
m_executor = std::make_shared<Executor<ov::bfloat16, ov::bfloat16>>(this, m_mlp_config, context->getScratchPad());
}
} else if (rtPrecision == ov::element::f16) {
m_executor = std::make_shared<Executor<ov::float16>>(this, m_mlp_config, context->getScratchPad());
if (m_mlp_config.tail_f32) {
m_executor = std::make_shared<Executor<ov::float16, float>>(this, m_mlp_config, context->getScratchPad());
} else {
m_executor = std::make_shared<Executor<ov::float16, ov::float16>>(this, m_mlp_config, context->getScratchPad());
}
}
#endif
if (!m_executor) {
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/llm_mlp.h
Expand Up @@ -39,7 +39,7 @@ class LLMMLP : public Node {
virtual ~ExecutorBase() = default;
};
std::shared_ptr<ExecutorBase> m_executor;
template <typename T> struct Executor;
template <typename T, typename U> struct Executor;
LLMMLPNode::Config m_mlp_config;
};

New file: fc_convert_fusion.cpp
@@ -0,0 +1,52 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "fc_convert_fusion.hpp"

#include <utils/general_utils.h>

#include <openvino/core/rt_info.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>

#include "itt.hpp"
#include "transformations/cpu_opset/common/op/fully_connected.hpp"

namespace ov {
namespace intel_cpu {

FcConvertFusion::FcConvertFusion() {
MATCHER_SCOPE(FcConvertFusion);
using namespace ov::pass::pattern;

auto a = any_input();
auto b = any_input();
auto fc = wrap_type<ov::intel_cpu::FullyConnectedNode>({a, b}, consumers_count(1));
auto convert = wrap_type<ov::op::v0::Convert>({fc}, type_matches(ov::element::f32));

ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

const auto& m_a = pattern_map.at(a).get_node_shared_ptr();
const auto& m_b = pattern_map.at(b).get_node_shared_ptr();
const auto& m_fc = pattern_map.at(fc).get_node_shared_ptr();
const auto& m_convert = pattern_map.at(convert).get_node_shared_ptr();
auto output_type = m_convert->get_output_element_type(0);

Contributor (usstq):

maybe add a check here to make sure convert is the only child of the fc node.

Contributor Author:

@usstq Updated as above.
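Note that the fc pattern above already carries consumers_count(1), which should enforce this at match time; an explicit callback-side guard is a one-liner. A sketch using ov::Output::get_target_inputs (not the PR's final code):

```cpp
// Sketch of the requested guard: bail out unless the matched Convert is the
// FullyConnected node's sole consumer.
if (m_fc->output(0).get_target_inputs().size() != 1) {
    return false;
}
```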

auto rank = m_fc->get_output_partial_shape(0).rank();
auto new_fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(m_a, m_b, rank, output_type);

new_fc->set_friendly_name(m_convert->get_friendly_name());
copy_runtime_info(m.get_matched_nodes(), new_fc);
replace_node(m_convert, new_fc);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(convert, matcher_name);
this->register_matcher(m, callback);
}

} // namespace intel_cpu
} // namespace ov
New file: fc_convert_fusion.hpp
@@ -0,0 +1,18 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <openvino/pass/graph_rewrite.hpp>

namespace ov {
namespace intel_cpu {
class FcConvertFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("FcConvertFusion", "0");
FcConvertFusion();
};

} // namespace intel_cpu
} // namespace ov
@@ -15,6 +15,7 @@
#include "common/pass/convert_to_swish_cpu.hpp"
#include "common/pass/move_fc_reshape_to_weights.hpp"
#include "common/pass/split_fc.hpp"
#include "common/pass/fc_convert_fusion.hpp"
#include "transformations/convert_precision.hpp"
#include "transformations/utils/utils.hpp"
#include "common/pass/rnn_sequences_optimization.hpp"
@@ -33,6 +34,7 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ov::Model> &model) {
manager.set_per_pass_validation(false);
CPU_REGISTER_PASS_COMMON(manager, ConvertMatMulToFC);
CPU_REGISTER_PASS_X64(manager, MoveFCReshapeToWeights);
CPU_REGISTER_PASS_X64(manager, FcConvertFusion);
CPU_REGISTER_PASS_X64(manager, ov::pass::Validate);
CPU_REGISTER_PASS_COMMON(manager, AlignMatMulInputRanks);
CPU_REGISTER_PASS_COMMON(manager, ConvertTileToSeqTiles);
transformations/cpu_opset/x64/op/llm_mlp.cpp
@@ -56,7 +56,8 @@ void LLMMLPNode::validate_and_infer_types() {

auto oshape = ishape;
oshape[oshape.size() - 1] = w_down_shape[0];
set_output_type(0, itype, oshape);
auto otype = m_config.tail_f32 ? ov::element::f32 : itype;
set_output_type(0, otype, oshape);
}

std::shared_ptr<Node> LLMMLPNode::clone_with_new_inputs(const ov::OutputVector& new_args) const {
transformations/cpu_opset/x64/op/llm_mlp.hpp
@@ -25,6 +25,7 @@ class LLMMLPNode : public ov::op::Op {
int hidden_size;
int up_size;
bool gate_up_combined;
bool tail_f32 = false;
Contributor (usstq):

maybe we can add `ov::element::Type output_type = ov::element::undefined;` instead, to be consistent with FullyConnectedNode

Contributor Author:

@usstq I see. Updated the spec.

};
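A sketch of the suggested Config shape, assuming ov::element::undefined serves as a "keep the input precision" sentinel (illustrative, not the PR's final code):

```cpp
#include "openvino/core/type/element_type.hpp"

// Illustrative sketch of the reviewer's suggestion: carry an explicit output
// element type instead of a bool, matching FullyConnectedNode's convention.
struct Config {
    int hidden_size;
    int up_size;
    bool gate_up_combined;
    // ...other existing fields unchanged...
    // replaces `bool tail_f32 = false`; undefined means "same as input precision"
    ov::element::Type output_type = ov::element::undefined;
};
```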

// args:
@@ -33,6 +34,7 @@ class LLMMLPNode : public ov::op::Op {
// 2: up_proj
// 3: down_proj
LLMMLPNode(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) {
m_args = args;
validate_and_infer_types();
}

@@ -46,8 +48,17 @@
return m_config;
}

void set_config(const Config& config) {
m_config = config;
}

const OutputVector& get_args() {
return m_args;
}

private:
Config m_config;
OutputVector m_args;
};

} // namespace intel_cpu
New file: mlp_fuse_convert.cpp
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mlp_fuse_convert.hpp"

#include <transformations/utils/utils.hpp>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/cpu_opset/x64/op/llm_mlp.hpp"


using namespace ov;
using namespace ov::pass::pattern;

intel_cpu::MLPFuseConvert::MLPFuseConvert() {
MATCHER_SCOPE(MLPFuseConvert);

auto mlp = wrap_type<ov::intel_cpu::LLMMLPNode>();
auto convert = wrap_type<ov::op::v0::Convert>({mlp}, type_matches(ov::element::f32));

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pass::pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
const auto& m_mlp = pattern_map.at(mlp).get_node_shared_ptr();
const auto& m_cvt = pattern_map.at(convert).get_node_shared_ptr();

auto mlp_node = as_type_ptr<ov::intel_cpu::LLMMLPNode>(m_mlp);
if (!mlp_node) {
return false;
}

Contributor (usstq):

maybe add a check here to make sure convert is the only child of the mlp node.

Contributor Author:

@usstq Added a has_only_child check.
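A sketch of such a helper under the name the author mentions (the name and exact form are assumptions; the merged check may differ):

```cpp
// Hypothetical has_only_child helper (sketch): true when `node` has a single
// output feeding exactly one consumer input.
static bool has_only_child(const std::shared_ptr<ov::Node>& node) {
    return node->get_output_size() == 1 &&
           node->output(0).get_target_inputs().size() == 1;
}

// ...used in the callback before rebuilding the node:
// if (!has_only_child(m_mlp)) {
//     return false;
// }
```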

OutputVector args = mlp_node->get_args();
auto cfg = mlp_node->get_config();

cfg.tail_f32 = true;

auto new_mlp = std::make_shared<ov::intel_cpu::LLMMLPNode>(args, cfg);

copy_runtime_info(m_cvt, new_mlp);
ov::replace_node(m_cvt, new_mlp);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(convert, matcher_name);
this->register_matcher(m, callback);
}