Merge pull request #18 from xczhai/xc/debug_f16_convert

Xc/debug f16 convert

xczhai authored Nov 20, 2024
2 parents 0455e56 + c13dfe1 commit ca8f26b
Showing 17 changed files with 1,491 additions and 38 deletions.
@@ -410,8 +410,11 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
    if (m_keep_precision_sensitive_in_fp32 && has_fp16_compression) {
        pass::Manager manager(get_pass_config(), "KeepPrecisionSensitiveInFP32");
        // Mark subgraphs with disable_fp16_compression to keep them in FP32
+       // manager.register_pass<pass::Serialize>("opt1_1.xml", "");
        manager.register_pass<pass::MarkSugraphsToKeepInMixedPrecision>();
+       // manager.register_pass<pass::Serialize>("opt1_2.xml", "");
        manager.register_pass<pass::AlignMixedFP32FP16Types>();
+       // manager.register_pass<pass::Serialize>("opt1_3.xml", "");
        manager.run_passes(f);
    }
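The commented-out Serialize registrations in this hunk (and the one below) are the debugging hooks behind this PR's branch name: uncommenting one dumps the intermediate IR right after the preceding pass, so the FP16-conversion pipeline can be compared stage by stage. A minimal sketch of the idea, assuming the usual ov::pass::Serialize(xml_path, bin_path) signature; the output path is illustrative, not from this commit:

    // Dump the model right after subgraph marking; the empty bin path follows
    // the usage in this commit. "after_mark_subgraphs.xml" is an example name.
    manager.register_pass<pass::MarkSugraphsToKeepInMixedPrecision>();
    manager.register_pass<pass::Serialize>("after_mark_subgraphs.xml", "");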

@@ -491,8 +494,11 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
    // to remove extra converts
    if (m_keep_precision_sensitive_in_fp32) {
        pass::Manager manager(get_pass_config(), "KeepPrecisionSensitiveInFP32:RemoveConverts");
+       // manager.register_pass<pass::Serialize>("opt1_4.xml", "");
        manager.register_pass<pass::EnableDecompressionConvertConstantFolding>();
+       // manager.register_pass<pass::Serialize>("opt1_5.xml", "");
        manager.register_pass<pass::ConstantFolding>();
+       // manager.register_pass<pass::Serialize>("opt1_6.xml", "");
        manager.run_passes(f);
    }

@@ -34,6 +34,7 @@ bool ov::pass::AlignMixedFP32FP16Types::run_on_model(const std::shared_ptr<ov::Model>&
    for (const auto& input : node->inputs()) {
        const auto& incoming_output = input.get_source_output();
        const auto& incoming_node = incoming_output.get_node_shared_ptr();
+       const auto& node_name = incoming_node->get_friendly_name();  // debug aid: producer node's friendly name

        if (fp16_compression_is_disabled(incoming_node))
            continue;  // we are in the middle of a subgraph already kept in FP32
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/graph.cpp
@@ -40,6 +40,7 @@
#include "utils/node_dumper.h"
#include "utils/verbose.h"
#include "utils/precision_support.h"
#include "utils/linux_perf.hpp"

#include <oneapi/dnnl/dnnl.hpp>
#include "common/primitive_desc_iface.hpp"
@@ -1374,6 +1375,7 @@ void Graph::InferDynamic(SyncInferRequest* request, int numaId, UpdateStrategy&&

    for (; inferCounter < stopIndx; ++inferCounter) {
        auto& node = m_executableGraphNodes[inferCounter];
+       auto prof = LinuxPerf::Profile(node->getTypeStr());

ExecuteNodeWithCatch(node, request, numaId);
}
@@ -1394,6 +1396,7 @@ static int GetNumaNodeId(const GraphContext::CPtr& context) {
void Graph::Infer(SyncInferRequest* request) {
    DEBUG_LOG("Infer graph: ", GetName(), ". Status: ", static_cast<int>(status));
    const int numaId = GetNumaNodeId(m_context);
+   auto prof = LinuxPerf::Profile("Graph::Infer");

if (!m_pMemoryControl) {
OPENVINO_THROW("Memory control unit is not initialized in graph: ", GetName());
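LinuxPerf::Profile, provided by the newly included utils/linux_perf.hpp, brackets each executed node and the whole Graph::Infer call with a profiling scope. That header is internal to this branch; below is a minimal sketch of the RAII pattern such a scope profiler typically follows (the class name and the wall-clock measurement are assumptions — the real utility presumably drives Linux perf events):

    // Hedged sketch of an RAII scope profiler in the spirit of LinuxPerf::Profile;
    // the real utils/linux_perf.hpp may hook perf_event_open instead of a clock.
    #include <chrono>
    #include <cstdio>
    #include <string>
    #include <utility>

    class ScopeProfile {
    public:
        explicit ScopeProfile(std::string name)
            : m_name(std::move(name)), m_start(std::chrono::steady_clock::now()) {}
        ~ScopeProfile() {
            const auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                                std::chrono::steady_clock::now() - m_start).count();
            std::fprintf(stderr, "%s: %lld us\n", m_name.c_str(), static_cast<long long>(us));
        }
    private:
        std::string m_name;
        std::chrono::steady_clock::time_point m_start;
    };

    // Usage mirrors the diff: auto prof = ScopeProfile(node->getTypeStr());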
39 changes: 26 additions & 13 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
@@ -616,14 +616,21 @@ void ReduceAdd2bh::generate() {
            vmovups(zmm3, ptr[src1 + loop_i * 4 + 16 * 4]);
            vaddps(zmm0, zmm0, zmm1);
            vaddps(zmm2, zmm2, zmm3);
-           if (m_to_f16) {
-               vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
-               vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
-               prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-           } else {
-               vcvtne2ps2bf16(zmm4, zmm2, zmm0);
-               prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-               vmovups(ptr[dst + loop_i * 2], zmm4);
+           if (m_out_f32 && m_to_f16) {
+               vmovups(ptr[dst + loop_i * 4], zmm0);
+               vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
+               prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+           } else {
+               // convert fp32 to fp16 or bf16
+               if (m_to_f16) {
+                   vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
+                   vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
+                   prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+               } else {
+                   vcvtne2ps2bf16(zmm4, zmm2, zmm0);
+                   prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+                   vmovups(ptr[dst + loop_i * 2], zmm4);
+               }
            }
        }
add(loop_i, 32);
@@ -647,14 +654,20 @@
        {
            vmovups(zmm0, ptr[src0 + loop_i * 4]);
            vmovups(zmm2, ptr[src0 + loop_i * 4 + 16 * 4]);
-           if (m_to_f16) {
-               vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
-               vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
-               prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-           } else {
-               vcvtne2ps2bf16(zmm4, zmm2, zmm0);
-               prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
-               vmovups(ptr[dst + loop_i * 2], zmm4);
+           if (m_out_f32 && m_to_f16) {
+               vmovups(ptr[dst + loop_i * 4], zmm0);
+               vmovups(ptr[dst + loop_i * 4 + 64], zmm2);
+               prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+           } else {
+               if (m_to_f16) {
+                   vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4);
+                   vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4);
+                   prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+               } else {
+                   vcvtne2ps2bf16(zmm4, zmm2, zmm0);
+                   prefetchwt1(ptr[prefetch_dst + loop_i * 2]);
+                   vmovups(ptr[dst + loop_i * 2], zmm4);
+               }
            }
        }
add(loop_i, 32);
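Read together with the flag wiring in llm_mlp.cpp below (jit_reduce2cvt(true, std::is_same<T, ov::float16>::value, config.tail_f32)), the new branch stores the raw fp32 sums with a four-byte element stride instead of narrowing to f16/bf16; since the fp32 store requires both m_out_f32 and m_to_f16, the path is effectively exercised by the f16 executor, in line with the branch's f16-debugging focus. A hedged scalar reference of the per-element store selection (the function and its names are illustrative, not part of the commit):

    #include "openvino/core/type/bfloat16.hpp"
    #include "openvino/core/type/float16.hpp"

    // Scalar sketch of ReduceAdd2bh's store-path selection (not the JIT code):
    // f32 is written when (out_f32 && to_f16), f16 when to_f16, bf16 otherwise.
    static void reduce_add_row_ref(const float* src0, const float* src1, void* dst,
                                   int num_cols, bool to_f16, bool out_f32) {
        for (int i = 0; i < num_cols; ++i) {
            const float sum = src0[i] + src1[i];
            if (out_f32 && to_f16) {
                reinterpret_cast<float*>(dst)[i] = sum;                      // vmovups path
            } else if (to_f16) {
                reinterpret_cast<ov::float16*>(dst)[i] = ov::float16(sum);   // vcvtps2ph path
            } else {
                reinterpret_cast<ov::bfloat16*>(dst)[i] = ov::bfloat16(sum); // vcvtne2ps2bf16 path
            }
        }
    }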
26 changes: 19 additions & 7 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
@@ -502,20 +502,32 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator {

    const bool m_do_reduce2;
    const bool m_to_f16;
-   ReduceAdd2bh(bool do_reduce2, bool to_f16) : jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_to_f16(to_f16) {
+   const bool m_out_f32;
+   ReduceAdd2bh(bool do_reduce2, bool to_f16, bool out_f32 = false) :
+       jit_generator(jit_name()), m_do_reduce2(do_reduce2), m_to_f16(to_f16), m_out_f32(out_f32) {
create_kernel();
}

void generate() override;

    // element-wise add of the two f32 inputs, storing the sum as bf16/f16
    // (or as raw f32 when m_out_f32 is set): dst = Convert(src0 + src1)
    void call(float * src0, float * src1, size_t src_stride, void * pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
-       auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
-       for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
-           // the prefetch distance is increased to ensure by the time store happens
-           // prefetch has done and no HW prefetcher is triggered
-           auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
-           (*this)(src0, src1, dst, prefetch_dst, num_cols);
+       if (m_out_f32) {
+           auto* dst = reinterpret_cast<float*>(pf16_dst);
+           for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
+               // the prefetch distance is increased so that, by the time the store happens,
+               // the prefetch has completed and no HW prefetcher is triggered
+               auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
+               (*this)(src0, src1, dst, prefetch_dst, num_cols);
+           }
+       } else {
+           auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
+           for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
+               // the prefetch distance is increased so that, by the time the store happens,
+               // the prefetch has completed and no HW prefetcher is triggered
+               auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
+               (*this)(src0, src1, dst, prefetch_dst, num_cols);
+           }
        }
    }
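A hedged usage sketch of the widened call() interface (sizes are illustrative, and an AVX-512 capable CPU is assumed since ReduceAdd2bh JITs AVX-512 code): when out_f32 is set, the void* destination is reinterpreted as a float buffer, with dst_stride still counted in elements.

    #include <vector>

    // Reduce two fp32 partial results into a plain fp32 output, row by row.
    const int M = 4, K = 64;                                    // illustrative sizes
    std::vector<float> part0(M * K, 1.0f), part1(M * K, 2.0f), out(M * K);
    ReduceAdd2bh reduce(/*do_reduce2=*/true, /*to_f16=*/true, /*out_f32=*/true);
    reduce.call(part0.data(), part1.data(), /*src_stride=*/K,
                out.data(), /*dst_stride=*/K, /*num_rows=*/M, /*num_cols=*/K);
    // every out[i] ends up as 3.0f, still in fp32 rather than f16/bf16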

42 changes: 27 additions & 15 deletions src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
@@ -27,7 +27,7 @@ namespace node {

#if defined(OPENVINO_ARCH_X86_64)

-template<typename T>
+template<typename T, typename U>
class LinearKsplit2 {
public:
std::vector<Work> works;
@@ -119,11 +119,12 @@ class LinearKsplit2 {
DEBUG_LOG(" setup is done. weight @ ", static_cast<void*>(p_weight));
}

-   void run(uint8_t* pA, int strideA, int M, T* dstC, int strideC,
+   void run(uint8_t* pA, int strideA, int M, U* dstC, int strideC,
             const LLMMLPNode::Config& config,
             MatrixDynQuantPerRow& src_dq,
             float * w_scale) {
-       static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<T, ov::float16>::value);
+       // static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<T, ov::float16>::value);
+       static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<T, ov::float16>::value, config.tail_f32);

ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
auto& work = works[ithr];
@@ -311,7 +312,7 @@ class LinearGateUp {
int m_threads_num = 0;
};

-template<typename T>
+template<typename T, typename U>
struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
LLMMLP* m_pnode;
const LLMMLPNode::Config m_config;
@@ -320,7 +321,7 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
uint8_t* m_scratch_base = nullptr;

LinearGateUp<T> gate_up;
-   LinearKsplit2<T> down;
+   LinearKsplit2<T, U> down;
int m_N;
int m_M = 0;

@@ -438,9 +439,11 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
int M = shape_size(ishape) / ishape[ishape.size() - 1];

auto output = m_pnode->getDstMemoryAtPort(0);
-       auto* dstC = output->getDataAs<T>();
+       auto out_prec = output->getPrecision();
+       // need to cast to the target output precision U
+       auto* dstC = output->getDataAs<U>();
        const auto& dstStrides = output->getDescWithType<BlockedMemoryDesc>()->getStrides();
-       int strideC = dstStrides[dstStrides.size() - 2] * sizeof(T);
+       int strideC = dstStrides[dstStrides.size() - 2] * sizeof(U);

float* p_w_scale_down = nullptr;
if (m_config.down_quantized) {
@@ -479,21 +482,21 @@
}

            down.run(p_up_act, stride_up_act, BM, dstC, strideC,
                     m_config,
                     m_quant_up_act,
                     p_w_scale_down);

m += BM;
pA += BM * strideA_in_bytes;
-           dstC += BM * strideC / sizeof(T);
+           dstC += BM * strideC / sizeof(U);
}
}

private:
size_t m_threads_num = 0lu;
};
#else
-template<typename T>
+template<typename T, typename U>
struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
Executor(LLMMLP* pnode, const LLMMLPNode::Config& config, DnnlScratchPadPtr scrachPad) {}
void execute() {}
@@ -557,7 +560,8 @@ void LLMMLP::initSupportedPrimitiveDescriptors() {
inPortConfigs.emplace_back(LayoutType::ncsp, weightPrecision, getInputShapeAtPort(3), false, -1); // down

// initialize output port
-   outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);
+   auto outPrecision = m_mlp_config.tail_f32 ? ov::element::f32 : rtPrecision;
+   outPortConfigs.emplace_back(LayoutType::ncsp, outPrecision, getOutputShapeAtPort(0), false, -1);
}
addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}
@@ -566,9 +570,17 @@ void LLMMLP::createPrimitive() {
auto rtPrecision = getInputPrecisions()[0];
#ifdef OPENVINO_ARCH_X86_64
if (rtPrecision == ov::element::bf16) {
-       m_executor = std::make_shared<Executor<ov::bfloat16>>(this, m_mlp_config, context->getScratchPad());
+       if (m_mlp_config.tail_f32) {
+           m_executor = std::make_shared<Executor<ov::bfloat16, float>>(this, m_mlp_config, context->getScratchPad());
+       } else {
+           m_executor = std::make_shared<Executor<ov::bfloat16, ov::bfloat16>>(this, m_mlp_config, context->getScratchPad());
+       }
} else if (rtPrecision == ov::element::f16) {
-       m_executor = std::make_shared<Executor<ov::float16>>(this, m_mlp_config, context->getScratchPad());
+       if (m_mlp_config.tail_f32) {
+           m_executor = std::make_shared<Executor<ov::float16, float>>(this, m_mlp_config, context->getScratchPad());
+       } else {
+           m_executor = std::make_shared<Executor<ov::float16, ov::float16>>(this, m_mlp_config, context->getScratchPad());
+       }
}
#endif
if (!m_executor) {
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/llm_mlp.h
@@ -39,7 +39,7 @@ class LLMMLP : public Node {
virtual ~ExecutorBase() = default;
};
std::shared_ptr<ExecutorBase> m_executor;
-   template <typename T> struct Executor;
+   template <typename T, typename U> struct Executor;
LLMMLPNode::Config m_mlp_config;
};

@@ -0,0 +1,53 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "fc_convert_fusion.hpp"

#include <utils/general_utils.h>

#include <openvino/core/rt_info.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>

#include "itt.hpp"
#include "transformations/cpu_opset/common/op/fully_connected.hpp"

namespace ov {
namespace intel_cpu {

FcConvertFusion::FcConvertFusion() {
    MATCHER_SCOPE(FcConvertFusion);
    using namespace ov::pass::pattern;

    auto a = any_input();
    auto b = any_input();
    auto fc = wrap_type<ov::intel_cpu::FullyConnectedNode>({a, b}, consumers_count(1));
    auto convert = wrap_type<ov::op::v0::Convert>({fc}, type_matches(ov::element::f32));

    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();

        const auto& m_a = pattern_map.at(a).get_node_shared_ptr();
        const auto& m_b = pattern_map.at(b).get_node_shared_ptr();
        const auto& m_fc = pattern_map.at(fc).get_node_shared_ptr();

        const auto& m_convert = pattern_map.at(convert).get_node_shared_ptr();
        auto output_type = m_convert->get_output_element_type(0);

        // auto rank = m_fc->get_output_rank();
        auto rank = m_fc->get_output_partial_shape(0).rank();
        auto new_fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(m_a, m_b, rank, output_type);

        new_fc->set_friendly_name(m_convert->get_friendly_name());
        copy_runtime_info(m.get_matched_nodes(), new_fc);
        replace_node(m_convert, new_fc);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(convert, matcher_name);
    this->register_matcher(m, callback);
}

} // namespace intel_cpu
} // namespace ov
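In effect, FcConvertFusion folds a sole-consumer Convert-to-f32 into the FullyConnected node itself, so the FC produces f32 directly rather than a narrow result followed by an explicit conversion. A sketch of the rewrite (arrows and node labels illustrative):

    Before:  FullyConnectedNode(a, b) --f16/bf16--> Convert(f32) --> consumers
    After:   FullyConnectedNode(a, b, rank, output_type=f32) -----> consumers

Like the LLMMLPNode tail_f32 change below, it removes a trailing f32 Convert from the graph.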
@@ -0,0 +1,18 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <openvino/pass/graph_rewrite.hpp>

namespace ov {
namespace intel_cpu {
class FcConvertFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("FcConvertFusion", "0");
    FcConvertFusion();
};

} // namespace intel_cpu
} // namespace ov
@@ -15,6 +15,7 @@
#include "common/pass/convert_to_swish_cpu.hpp"
#include "common/pass/move_fc_reshape_to_weights.hpp"
#include "common/pass/split_fc.hpp"
#include "common/pass/fc_convert_fusion.hpp"
#include "transformations/convert_precision.hpp"
#include "transformations/utils/utils.hpp"
#include "common/pass/rnn_sequences_optimization.hpp"
@@ -33,6 +34,7 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ov::Model> &model) {
manager.set_per_pass_validation(false);
CPU_REGISTER_PASS_COMMON(manager, ConvertMatMulToFC);
CPU_REGISTER_PASS_X64(manager, MoveFCReshapeToWeights);
+CPU_REGISTER_PASS_X64(manager, FcConvertFusion);
CPU_REGISTER_PASS_X64(manager, ov::pass::Validate);
CPU_REGISTER_PASS_COMMON(manager, AlignMatMulInputRanks);
CPU_REGISTER_PASS_COMMON(manager, ConvertTileToSeqTiles);
@@ -56,7 +56,8 @@ void LLMMLPNode::validate_and_infer_types() {

auto oshape = ishape;
oshape[oshape.size() - 1] = w_down_shape[0];
-   set_output_type(0, itype, oshape);
+   auto otype = m_config.tail_f32 ? ov::element::f32 : itype;
+   set_output_type(0, otype, oshape);
}

std::shared_ptr<Node> LLMMLPNode::clone_with_new_inputs(const ov::OutputVector& new_args) const {