[CPU] Enable bf16/fp16 inference precision support for platforms with avx2_vnni_2 ISA (#20486)
liubo-intel authored Jan 8, 2024
1 parent 93727d1 commit f0d8269
Showing 10 changed files with 92 additions and 44 deletions.
9 changes: 5 additions & 4 deletions src/plugins/intel_cpu/src/config.cpp
@@ -12,6 +12,7 @@
 #include "openvino/runtime/internal_properties.hpp"
 #include "openvino/runtime/properties.hpp"
 #include "utils/debug_capabilities.h"
+#include "utils/precision_support.h"

 #include <algorithm>
 #include <map>
@@ -219,7 +220,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
                            ". Expected only true/false");
         }
         if (enable) {
-            if (mayiuse(avx512_core)) {
+            if (hasHardwareSupport(ov::element::bf16)) {
                 inferencePrecision = ov::element::bf16;
             } else {
                 OPENVINO_THROW("Platform doesn't support BF16 format");
@@ -234,12 +235,12 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
             auto const prec = val.as<ov::element::Type>();
             inferencePrecisionSetExplicitly = true;
             if (prec == ov::element::bf16) {
-                if (mayiuse(avx512_core)) {
+                if (hasHardwareSupport(ov::element::bf16)) {
                     inferencePrecision = ov::element::bf16;
                 }
             } else if (prec == ov::element::f16) {
 #if defined(OPENVINO_ARCH_X86_64)
-                if (mayiuse(avx512_core_fp16) || mayiuse(avx512_core_amx_fp16)) {
+                if (hasHardwareSupport(ov::element::f16)) {
                     inferencePrecision = ov::element::f16;
                 }
 #elif defined(OV_CPU_ARM_ENABLE_FP16)
@@ -398,4 +399,4 @@ void Config::updateProperties() {
 }

 } // namespace intel_cpu
-} // namespace ov
\ No newline at end of file
+} // namespace ov
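With the config change above, the standard inference-precision hint now succeeds on avx2_vnni_2 machines as well as avx512_core ones. A minimal usage sketch (the model path is a placeholder; the properties are the public OpenVINO 2.0 API):

#include <openvino/openvino.hpp>
#include <iostream>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder path

    // Ask the CPU plugin for bf16 execution; readProperties() above now
    // accepts this whenever hasHardwareSupport(ov::element::bf16) holds,
    // i.e. on avx512_core or avx2_vnni_2 hardware.
    auto compiled = core.compile_model(model, "CPU",
                                       ov::hint::inference_precision(ov::element::bf16));

    // Report the precision the plugin actually selected.
    std::cout << compiled.get_property(ov::hint::inference_precision) << std::endl;
}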
@@ -13,7 +13,8 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
 public:
     jit_uni_vcvtneps2bf16(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                           ov::element::Type exec_prc = ov::element::bf16) : jit_emitter(host, host_isa, exec_prc) {
-        if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16))
+        if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) &&
+            !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2))
             prepare_table();
     }

@@ -55,6 +56,9 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
             h->vfixupimmps(aux, in, table_val("selector"), 0);
             h->vpsrad(aux, aux, 16);
             h->vpmovdw(out, aux);
+        } else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::cpu_isa_t::avx2_vnni_2)) {
+            Xmm out = Xmm(out_vec_idxs[0]);
+            h->vcvtneps2bf16(out, in, PreferredEncoding::VexEncoding);
         } else { // round_to_nearest_even emulation
             Vmm aux = Vmm(aux_vec_idxs[0]);
             Xmm out = Xmm(out_vec_idxs[0]);
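On avx2_vnni_2 the emitter can now issue a single VEX-encoded vcvtneps2bf16, so the lookup table for the emulation path is skipped. For reference, a scalar sketch of the round-to-nearest-even fp32-to-bf16 conversion the emulation branch implements with vector ops (NaN fixup, handled by vfixupimmps in the real code, is omitted here):

#include <cstdint>
#include <cstring>

// bf16 keeps the top 16 bits of an IEEE-754 fp32 value; rounding adds
// 0x7FFF plus the lowest surviving mantissa bit so ties round to even.
static uint16_t f32_to_bf16_rne(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    uint32_t lsb = (bits >> 16) & 1;  // lowest bit kept after truncation
    bits += 0x7FFF + lsb;             // round to nearest, ties to even
    return static_cast<uint16_t>(bits >> 16);
}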
@@ -474,8 +474,8 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak
     bool is_f16 = (prc == ov::element::f16);
     bool is_signed = prc.is_signed();

-    if (is_f16 && !mayiuse(cpu::x64::avx512_core_fp16))
-        OPENVINO_THROW("Load emitter in ", name_, " only support fp16 on platform with avx512_core_fp16.");
+    if (is_f16 && !mayiuse(cpu::x64::avx2))
+        OPENVINO_THROW("Load emitter in ", name_, " only support fp16 on platform with avx2 or above.");

     // Ensure extended double words fit inside Zmm (32/2(num) * 32 <= 512)
     // For Ymm register, load capacity is halved (16/2(num) * 32 <= 128)
@@ -1188,20 +1188,34 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 &reg,
             store_bytes<Vmm>(reg, offset, store_num * 2);
         }
     } else if (is_f16) {
-        if (!mayiuse(cpu::x64::avx512_core_fp16))
-            OPENVINO_THROW("Store emitter in ", name_, " only support fp16 on platform with avx512_core_fp16.");
-        // to avoid src vmm pollution
-        if (src_prc_ == ov::element::f32) {
-            // since avx512, zmm(fp32) => ymm(fp16)
-            ymm = Ymm(aux_vec_idxs[0]);
-        } // in I32 case, zmm&ymm is already in aux reg
-
-        h->vcvtps2ph(ymm, zmm, 0x4);
-        if (store_num == 16) {
-            h->vmovdqu16(ptr[reg + offset], ymm);
+        if (mayiuse(cpu::x64::avx512_core)) {
+            // to avoid src vmm pollution
+            if (src_prc_ == ov::element::f32) {
+                // since avx512, zmm(fp32) => ymm(fp16)
+                ymm = Ymm(aux_vec_idxs[0]);
+            } // in I32 case, zmm&ymm is already in aux reg
+
+            h->vcvtps2ph(ymm, zmm, 0x4);
+            if (store_num == 16) {
+                h->vmovdqu16(ptr[reg + offset], ymm);
+            } else {
+                data_idx = static_cast<int>(ymm.getIdx());
+                store_bytes<Vmm>(reg, offset, store_num * 2);
+            }
+        } else if (mayiuse(cpu::x64::avx2)) {
+            // to avoid src vmm pollution
+            if (src_prc_ == ov::element::f32) {
+                xmm = Xmm(aux_vec_idxs[0]);
+            }
+            h->vcvtps2ph(xmm, ymm, 0x4);
+            if (store_num == 8) {
+                h->uni_vmovdqu(ptr[reg + offset], xmm);
+            } else {
+                data_idx = static_cast<int>(xmm.getIdx());
+                store_bytes<Vmm>(reg, offset, store_num * 2);
+            }
         } else {
-            data_idx = static_cast<int>(ymm.getIdx());
-            store_bytes<Vmm>(reg, offset, store_num * 2);
+            OPENVINO_THROW("Store emitter in ", name_, " only support fp16 on platform with avx512_core or avx2.");
         }
     } else {
         switch (store_num) {
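The new avx2 branch leans on F16C for vcvtps2ph, which AVX2-capable CPUs provide in practice. A standalone sketch of the 8-element case (mirrors h->vcvtps2ph(xmm, ymm, 0x4); immediate 0x4 selects the current MXCSR rounding mode, round-to-nearest-even by default; compile with -mavx2 -mf16c; not the emitter itself):

#include <immintrin.h>
#include <cstdint>

// Convert 8 fp32 values to fp16 and store them, as the avx2 path does.
void store8_f32_as_f16(const float* src, uint16_t* dst) {
    __m256 v = _mm256_loadu_ps(src);                              // 8 x fp32 in ymm
    __m128i half = _mm256_cvtps_ph(v, _MM_FROUND_CUR_DIRECTION);  // vcvtps2ph xmm, ymm, 0x4
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), half);      // plain 128-bit store
}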
21 changes: 16 additions & 5 deletions src/plugins/intel_cpu/src/nodes/conv.cpp
@@ -330,7 +330,7 @@ ov::element::Type Convolution::fusedEltwisePrecision(const NodePtr& fusingNode)
 }

 const std::vector<impl_desc_type>& Convolution::getDefaultImplPriority() {
-    static const std::vector<impl_desc_type> priorities = {
+    static std::vector<impl_desc_type> priorities = {
         impl_desc_type::unknown,
         impl_desc_type::dw_acl,
         impl_desc_type::winograd_acl,
@@ -349,6 +349,8 @@ const std::vector<impl_desc_type>& Convolution::getDefaultImplPriority() {
         impl_desc_type::jit_avx512_dw,
         impl_desc_type::jit_avx512_1x1,
         impl_desc_type::jit_avx512,
+        impl_desc_type::brgconv_avx2_1x1,
+        impl_desc_type::brgconv_avx2,
         impl_desc_type::jit_avx2_dw,
         impl_desc_type::jit_avx2_1x1,
         impl_desc_type::jit_avx2,
@@ -369,11 +371,19 @@ const std::vector<impl_desc_type>& Convolution::getDefaultImplPriority() {
         impl_desc_type::ref,
     };

+    priorities.erase(std::remove_if(priorities.begin(),
+                                    priorities.end(),
+                                    [](impl_desc_type type) {
+                                        return !isBrgConvAvailable() && (type & impl_desc_type::brgconv);
+                                    }),
+                     priorities.end());
+
     return priorities;
 }

 const bool Convolution::isBrgConvAvailable() {
-    static const bool isBrgConvAvailable = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core);
+    static const bool isBrgConvAvailable = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ||
+                                           dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2);
     return isBrgConvAvailable;
 }

@@ -1634,12 +1644,13 @@ void Convolution::initializeInputZeroPoints(const uint8_t* inputZpData, const si
         if (inputZpData[j] != inputZpData[0])
             inputZeroPointType = zpType::PerChannel;
     }
-    // Only enable per-tensor zero point on avx512-amx and avx512-core-vnni.
+    // Only enable per-tensor zero point on avx512-amx, avx512-core-vnni and avx2_vnni_2.
     // If zero point is pertensor, both legacy zp and stock zp
     // would be passed into conv node. The conv node would determine how to create
     // post-ops attribute and prioritize to choose final onednn kernel.
-    if (inputZeroPointType == zpType::PerTensor &&
-        (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx) || impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_vnni)))
+    if (inputZeroPointType == zpType::PerTensor && (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx) ||
+                                                    impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_vnni) ||
+                                                    impl::cpu::x64::mayiuse(impl::cpu::x64::avx2_vnni_2)))
         inputZeroPoints.push_back(static_cast<int32_t>(inputZpData[0]));
     else
         inputZeroPointType = zpType::PerChannel;
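Note that priorities lost its const qualifier so the erase-remove step can strip brgconv entries on machines where isBrgConvAvailable() is false (the filter is a no-op on later calls). A simplified, self-contained illustration of that filter — impl_desc_type is a bit-flag enum in the real code; the names below are stand-ins:

#include <algorithm>
#include <vector>

enum Impl : unsigned { jit_avx2 = 1u << 0, brgconv = 1u << 1, brgconv_avx2 = brgconv | (1u << 2) };

std::vector<Impl> filterPriorities(std::vector<Impl> priorities, bool brgConvAvailable) {
    // Drop every brgconv-flavoured implementation when the ISA check fails;
    // the relative order of the remaining entries is preserved.
    priorities.erase(std::remove_if(priorities.begin(), priorities.end(),
                                    [&](Impl type) { return !brgConvAvailable && (type & brgconv); }),
                     priorities.end());
    return priorities;
}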
30 changes: 22 additions & 8 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
@@ -283,7 +283,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
                 this, p->entry_[i], vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias));
         }

-        if (mayiuse(avx512_core))
+        if (mayiuse(avx512_core) || mayiuse(avx2_vnni_2))
             uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(this, isa));

         const auto &jep = jep_;
@@ -771,11 +771,19 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
             uni_vmovss(xmm_src, op);
             break;
         case ov::element::bf16:
-            uni_vpinsrw(xmm_src, xmm_src, op, 0);
-            uni_vpslld(xmm_src, xmm_src, 16);
+            if (isa == x64::avx2_vnni_2) {
+                vbcstnebf162ps(xmm_src, op);
+            } else {
+                uni_vpinsrw(xmm_src, xmm_src, op, 0);
+                uni_vpslld(xmm_src, xmm_src, 16);
+            }
             break;
         case ov::element::f16:
-            vcvtph2ps(xmm_src, op);
+            if (isa == x64::avx2_vnni_2) {
+                vbcstnesh2ps(xmm_src, op);
+            } else {
+                vcvtph2ps(xmm_src, op);
+            }
             break;
         case ov::element::i16:
             uni_vpinsrw(xmm_src, xmm_src, op, 0);
Expand Down Expand Up @@ -839,8 +847,15 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
uni_vmovups(op, vmm_dst);
break;
case ov::element::bf16:
uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(ymm_dst.getIdx())});
vmovdqu16(op, ymm_dst);
if (isa == x64::avx512_core) {
uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
{static_cast<size_t>(ymm_dst.getIdx())});
vmovdqu16(op, ymm_dst);
} else {
uni_vcvtneps2bf16->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
{static_cast<size_t>(xmm_dst.getIdx())});
uni_vmovdqu(op, xmm_dst);
}
break;
case ov::element::f16:
vcvtps2ph(op, vmm_dst, 0x4);
@@ -2184,8 +2199,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
     if (!fusedWith.empty()) {
         outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
     }
-
-    if (!mayiuse(avx512_core)) {
+    if (!hasHardwareSupport(ov::element::bf16)) {
         bool hasBF16 = false;
         for (auto &inPrc : inputPrecisions)
             if (inPrc == ov::element::bf16)
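The avx2_vnni_2 paths above use the AVX-NE-CONVERT instructions vbcstnebf162ps/vbcstnesh2ps, which broadcast one bf16/fp16 element to fp32 lanes in a single instruction. On older targets bf16 is emulated by inserting the 16-bit word and shifting; a scalar sketch of that widening (illustration only, not the emitter code):

#include <cstdint>
#include <cstring>

// What uni_vpinsrw + uni_vpslld achieve per element: the bf16 bit pattern
// becomes the high half of an fp32 whose low mantissa bits are zero.
// The widening is exact; no rounding is involved.
static float bf16_to_f32(uint16_t bf16) {
    uint32_t bits = static_cast<uint32_t>(bf16) << 16;  // uni_vpslld(xmm, xmm, 16)
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}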
4 changes: 4 additions & 0 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -206,6 +206,10 @@ void FullyConnected::getSupportedDescriptors() {
         if (one_of(outputDataType, memory::data_type::u8, memory::data_type::s8)) {
             outputDataType = memory::data_type::bf16;
         }
+        // TODO: Ticket CVS-122347 - support WeightsDecompression with bf16 inputDataType on avx2_vnni_2
+        if (useWeightsDecompressionImpl && !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) {
+            inputDataType = outputDataType = memory::data_type::f32;
+        }
     } else if (inputDataType == memory::data_type::f16) {
 #if defined(OV_CPU_WITH_ACL)
         // acl fc does not support precisions conversion
3 changes: 1 addition & 2 deletions src/plugins/intel_cpu/src/nodes/interpolate.cpp
@@ -2024,9 +2024,8 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
         inputPrecision = ov::element::f32;
     }

-    if ((inputPrecision == ov::element::bf16) && !mayiuse(avx512_core)) {
+    if (!hasHardwareSupport(inputPrecision))
         inputPrecision = ov::element::f32;
-    }

     // support input with rank<=3 only with float precision and planar layout.
     // Jit for avx2(gather is available) and ref for no-avx2 machine.
6 changes: 2 additions & 4 deletions src/plugins/intel_cpu/src/nodes/mvn.cpp
@@ -1829,10 +1829,8 @@

     ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort(0);
     ov::element::Type outputPrecision = getOriginalOutputPrecisionAtPort(0);
-    if (!mayiuse(avx512_core)) {
-        if (outputPrecision == ov::element::bf16)
-            outputPrecision = ov::element::f32;
-    }
+    if (!hasHardwareSupport(outputPrecision))
+        outputPrecision = ov::element::f32;

     if (!fusedWith.empty()) {
         outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0);
@@ -283,7 +283,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
     };

     // @todo should we always convert to f32 regardless of hardware support, as it is done for f16?
-    if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
+    if (!hasHardwareSupport(ov::element::bf16))
         map.insert({ov::element::bf16, ov::element::f32});
 #if defined(OV_CPU_ARM_ENABLE_FP16)
     if (inferencePrecision != ov::element::f16)
@@ -518,9 +518,10 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
     using namespace ov::pass::low_precision;
     CPU_LPT_SCOPE(LowPrecisionTransformations_Part4);
     OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "LowPrecisionTransformations");
-    //Only enable conv/group conv signed input on AMX platform.
+    // Only enable conv/group conv signed input on AMX and avx2_vnni_2 platforms.
     std::vector<ov::element::Type> input0LowPrecisionList;
-    if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) {
+    if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) ||
+        dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2)) {
         input0LowPrecisionList = {ov::element::u8, ov::element::i8};
     } else {
         input0LowPrecisionList = {ov::element::u8};
6 changes: 4 additions & 2 deletions src/plugins/intel_cpu/src/utils/precision_support.cpp
@@ -14,7 +14,8 @@ bool hasHardwareSupport(const ov::element::Type& precision) {
     switch (precision) {
     case ov::element::f16: {
 #if defined(OPENVINO_ARCH_X86_64)
-        if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_fp16))
+        if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_fp16) ||
+            dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2))
             return true;
         return false;
 #elif defined(OV_CPU_ARM_ENABLE_FP16)
@@ -25,7 +26,8 @@
     }
     case ov::element::bf16: {
 #if defined(OPENVINO_ARCH_X86_64)
-        if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
+        if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ||
+            dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2))
             return true;
         return false;
 #else
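hasHardwareSupport() is now the single choke point the config and node changes above converge on. A sketch of the calling pattern they share (choosePrecision is a hypothetical helper, not part of the commit; it assumes the function lives in ov::intel_cpu as the surrounding plugin code does):

#include "utils/precision_support.h"

// Ask the helper instead of testing ISA flags directly; fall back to f32
// whenever the requested precision lacks hardware support.
static ov::element::Type choosePrecision(const ov::element::Type& requested) {
    return ov::intel_cpu::hasHardwareSupport(requested) ? requested : ov::element::f32;
}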
