-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CPU] remove extra convert for fp16 #26755
base: master
Are you sure you want to change the base?
Changes from 11 commits
36578f0
41eccac
b94ac6d
7796c19
b0caf24
1720dc9
fd99fce
c13dfe1
0455e56
ca8f26b
5901ec9
85f3d9e
4b4d8eb
688d26b
aa4c1bb
10b2492
5e33466
0dba6ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -616,14 +616,21 @@ void ReduceAdd2bh::generate() { | |
vmovups(zmm3, ptr[src1 + loop_i * 4 + 16 * 4]); | ||
vaddps(zmm0, zmm0, zmm1); | ||
vaddps(zmm2, zmm2, zmm3); | ||
if (m_to_f16) { | ||
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4); | ||
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4); | ||
if (m_out_f32 && m_to_f16) { | ||
vmovups(ptr[dst + loop_i * 4], zmm0); | ||
vmovups(ptr[dst + loop_i * 4 + 64], zmm2); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
} else { | ||
vcvtne2ps2bf16(zmm4, zmm2, zmm0); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
vmovups(ptr[dst + loop_i * 2], zmm4); | ||
} else { | ||
// convert fp32 to fp16 or bf16 | ||
if (m_to_f16) { | ||
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4); | ||
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
} else { | ||
vcvtne2ps2bf16(zmm4, zmm2, zmm0); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
vmovups(ptr[dst + loop_i * 2], zmm4); | ||
} | ||
} | ||
} | ||
add(loop_i, 32); | ||
|
@@ -647,14 +654,20 @@ void ReduceAdd2bh::generate() { | |
{ | ||
vmovups(zmm0, ptr[src0 + loop_i * 4]); | ||
vmovups(zmm2, ptr[src0 + loop_i * 4 + 16 * 4]); | ||
if (m_to_f16) { | ||
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4); | ||
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4); | ||
if (m_out_f32 && m_to_f16) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@usstq update as above. |
||
vmovups(ptr[dst + loop_i * 4], zmm0); | ||
vmovups(ptr[dst + loop_i * 4 + 64], zmm2); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
} else { | ||
vcvtne2ps2bf16(zmm4, zmm2, zmm0); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
vmovups(ptr[dst + loop_i * 2], zmm4); | ||
if (m_to_f16) { | ||
vcvtps2ph(ptr[dst + loop_i * 2], zmm0, 0x4); | ||
vcvtps2ph(ptr[dst + loop_i * 2 + 32], zmm2, 0x4); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
} else { | ||
vcvtne2ps2bf16(zmm4, zmm2, zmm0); | ||
prefetchwt1(ptr[prefetch_dst + loop_i * 2]); | ||
vmovups(ptr[dst + loop_i * 2], zmm4); | ||
} | ||
} | ||
} | ||
add(loop_i, 32); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "fc_convert_fusion.hpp" | ||
|
||
#include <utils/general_utils.h> | ||
|
||
#include <openvino/core/rt_info.hpp> | ||
#include <openvino/pass/pattern/op/wrap_type.hpp> | ||
#include <transformations/utils/utils.hpp> | ||
|
||
#include "itt.hpp" | ||
#include "transformations/cpu_opset/common/op/fully_connected.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
|
||
FcConvertFusion::FcConvertFusion() { | ||
MATCHER_SCOPE(FcConvertFusion); | ||
using namespace ov::pass::pattern; | ||
|
||
auto a = any_input(); | ||
auto b = any_input(); | ||
auto fc = wrap_type<ov::intel_cpu::FullyConnectedNode>({a, b}, consumers_count(1)); | ||
auto convert = wrap_type<ov::op::v0::Convert>({fc}, type_matches(ov::element::f32)); | ||
|
||
ov::matcher_pass_callback callback = [=](Matcher& m) { | ||
const auto& pattern_map = m.get_pattern_value_map(); | ||
|
||
const auto& m_a = pattern_map.at(a).get_node_shared_ptr(); | ||
const auto& m_b = pattern_map.at(b).get_node_shared_ptr(); | ||
const auto& m_fc = pattern_map.at(fc).get_node_shared_ptr(); | ||
const auto& m_convert = pattern_map.at(convert).get_node_shared_ptr(); | ||
auto output_type = m_convert->get_output_element_type(0); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe add a check here to make sure convert is the only child of fc node. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@usstq update as above. |
||
// auto rank = m_fc->get_output_rank(); | ||
auto rank = m_fc->get_output_partial_shape(0).rank(); | ||
auto new_fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(m_a, m_b, rank, output_type); | ||
|
||
new_fc->set_friendly_name(m_convert->get_friendly_name()); | ||
copy_runtime_info(m.get_matched_nodes(), new_fc); | ||
replace_node(m_convert, new_fc); | ||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ov::pass::pattern::Matcher>(convert, matcher_name); | ||
this->register_matcher(m, callback); | ||
} | ||
|
||
} // namespace intel_cpu | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <openvino/pass/graph_rewrite.hpp> | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
class FcConvertFusion : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("FcConvertFusion", "0"); | ||
FcConvertFusion(); | ||
}; | ||
|
||
} // namespace intel_cpu | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ class LLMMLPNode : public ov::op::Op { | |
int hidden_size; | ||
int up_size; | ||
bool gate_up_combined; | ||
bool tail_f32 = false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we can add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@usstq I see. update the spec |
||
}; | ||
|
||
// args: | ||
|
@@ -33,6 +34,7 @@ class LLMMLPNode : public ov::op::Op { | |
// 2: up_proj | ||
// 3: down_proj | ||
LLMMLPNode(const OutputVector& args, const Config& cfg) : Op(args), m_config(cfg) { | ||
m_args = args; | ||
validate_and_infer_types(); | ||
} | ||
|
||
|
@@ -46,8 +48,17 @@ class LLMMLPNode : public ov::op::Op { | |
return m_config; | ||
} | ||
|
||
void set_config(const Config& config) { | ||
m_config = config; | ||
} | ||
|
||
const OutputVector& get_args() { | ||
return m_args; | ||
} | ||
|
||
private: | ||
Config m_config; | ||
OutputVector m_args; | ||
}; | ||
|
||
} // namespace intel_cpu | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "mlp_fuse_convert.hpp" | ||
|
||
#include <transformations/utils/utils.hpp> | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
#include "transformations/cpu_opset/x64/op/llm_mlp.hpp" | ||
|
||
/* | ||
*/ | ||
|
||
using namespace ov; | ||
using namespace ov::pass::pattern; | ||
|
||
intel_cpu::MLPFuseConvert::MLPFuseConvert() { | ||
MATCHER_SCOPE(MLPFuseConvert); | ||
|
||
auto mlp = wrap_type<ov::intel_cpu::LLMMLPNode>(); | ||
auto convert = wrap_type<ov::op::v0::Convert>({mlp}, type_matches(ov::element::f32)); | ||
|
||
matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pass::pattern::Matcher& m) { | ||
auto& pattern_map = m.get_pattern_value_map(); | ||
const auto& m_mlp = pattern_map.at(mlp).get_node_shared_ptr(); | ||
const auto& m_cvt = pattern_map.at(convert).get_node_shared_ptr(); | ||
|
||
auto mlp_node = as_type_ptr<ov::intel_cpu::LLMMLPNode>(m_mlp); | ||
if (!mlp_node) { | ||
return false; | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe add a check here to make sure convert is the only child of mlp node. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@usstq add |
||
OutputVector args = mlp_node->get_args(); | ||
auto cfg = mlp_node->get_config(); | ||
|
||
cfg.tail_f32 = true; | ||
|
||
auto new_mlp = std::make_shared<ov::intel_cpu::LLMMLPNode>(args, cfg); | ||
|
||
copy_runtime_info(m_cvt, new_mlp); | ||
ov::replace_node(m_cvt, new_mlp); | ||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ov::pass::pattern::Matcher>(convert, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems that
m_out_f32
should overridem_to_f16
flag? I mean, once a convert to fp32 is fused, it should always store f32 result, right? soif (m_out_f32)
should be enough, right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@usstq unify this logic. replace
m_to_16
withm_output_type
to mark possible output precision.