loose fp16 support limitation of jit_load_store_emitters to avx512_core and avx2

liubo-intel committed Dec 26, 2023
1 parent 74d0202 commit f45f9f8
Showing 1 changed file with 65 additions and 73 deletions.
138 changes: 65 additions & 73 deletions src/plugins/intel_cpu/src/emitters/x64/jit_load_store_emitters.cpp
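
For context on the relaxed checks below: these emitters touch fp16 only through the conversion instructions vcvtph2ps / vcvtps2ph, so they do not need the full avx512_core_fp16 or avx2_vnni_2 feature sets — avx512_core (via AVX-512F) and avx2 (via F16C, which in practice ships on every avx2-capable part; the relaxed check makes the same assumption) already provide those conversions. A minimal standalone sketch of the avx2-level load-side conversion with standard F16C intrinsics (function name is illustrative, not from this patch):

#include <immintrin.h>
#include <cstdint>

// Widen 8 fp16 values to fp32 using F16C alone (compile with -mavx2 -mf16c),
// mirroring what the load emitter's vcvtph2ps path does per 16-byte block.
void load_fp16_as_fp32(const uint16_t* src, float* dst) {
    __m128i half = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src)); // 8 x fp16
    __m256 full  = _mm256_cvtph_ps(half);                                  // vcvtph2ps
    _mm256_storeu_ps(dst, full);
}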
@@ -476,10 +476,8 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak
     bool is_f16 = (prc == ov::element::f16);
     bool is_signed = prc.is_signed();

-    if (is_f16 && !mayiuse(cpu::x64::avx512_core_fp16) && !mayiuse(cpu::x64::avx2_vnni_2))
-        OPENVINO_THROW("Load emitter in ",
-                       name_,
-                       " only support fp16 on platform with avx512_core_fp16 or avx2_vnni_2.");
+    if (is_f16 && !mayiuse(cpu::x64::avx512_core) && !mayiuse(cpu::x64::avx2))
+        OPENVINO_THROW("Load emitter in ", name_, " only support fp16 on platform with avx512_core or avx2.");

     // Ensure extended double words fit inside Zmm (32/2(num) * 32 <= 512)
     // For Ymm register, load capacity is halved (16/2(num) * 32 <= 128)
@@ -504,87 +502,82 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak
     // For load_size == 32/16/8, do load/extension in one go
     // including xmm/ymm tail block for ymm/zmm, so explicite xmm/ymm/zmm
     switch (load_size) {
         case 32: {
-            // needed here?
-            if (!is_zmm)
-                IE_THROW() << "Load emitter in " << name_
-                           << " has unexpected number of values(32) to load to non-zmm in load_words_to_dword_extension.";
             if (is_bf16) {
                 h->uni_vpmovzxwd(zmm, ptr[reg + offset]);
                 h->uni_vpslld(zmm, zmm, 16);
             } else if (is_f16) {
                 h->vcvtph2ps(zmm, ptr[reg + offset]);
             } else {
                 if (is_signed)
                     h->uni_vpmovsxwd(zmm, ptr[reg + offset]);
                 else
                     h->uni_vpmovzxwd(zmm, ptr[reg + offset]);
             }
             break;
         }
         case 16: {
             if (is_bf16) {
                 h->uni_vpmovzxwd(ymm, ptr[reg + offset]);
                 h->uni_vpslld(ymm, ymm, 16);
             } else if (is_f16) {
                 h->vcvtph2ps(ymm, ptr[reg + offset]);
             } else {
                 if (is_signed)
                     h->uni_vpmovsxwd(ymm, ptr[reg + offset]);
                 else
                     h->uni_vpmovzxwd(ymm, ptr[reg + offset]);
             }
             break;
         }
         case 8: {
             if (is_bf16) {
                 h->uni_vpmovzxwd(xmm, ptr[reg + offset]);
                 h->uni_vpslld(xmm, xmm, 16);
             } else if (is_f16) {
                 h->vcvtph2ps(xmm, ptr[reg + offset]);
             } else {
                 if (is_signed)
                     h->uni_vpmovsxwd(xmm, ptr[reg + offset]);
                 else
                     h->uni_vpmovzxwd(xmm, ptr[reg + offset]);
             }
             break;
         }
         default: {
             if (is_zmm && load_size > threshold_for_mask_emu_load) {
                 unsigned int mask = 1;
                 mask = (mask << (load_size / 2)) - mask;
                 h->mov(Reg32(aux_gpr_idxs[0]), mask);
                 h->kmovw(k_mask, Reg32(aux_gpr_idxs[0]));
                 if (is_bf16) {
                     h->uni_vpmovzxwd(vmm | k_mask | T_z, ptr[reg + offset]);
                     h->uni_vpslld(vmm, vmm, 16);
                 } else if (is_f16) {
                     h->vcvtph2ps(vmm | k_mask | T_z, ptr[reg + offset]);
                 } else {
                     if (is_signed)
                         h->uni_vpmovsxwd(vmm | k_mask | T_z, ptr[reg + offset]);
                     else
                         h->uni_vpmovzxwd(vmm | k_mask | T_z, ptr[reg + offset]);
                 }
             } else {
                 // xmm or ymm version
                 load_bytes(xmm, reg, offset, load_size);
                 if (is_bf16) {
                     h->uni_vpmovzxwd(vmm, xmm);
                     h->uni_vpslld(vmm, vmm, 16);
                 } else if (is_f16) {
                     h->vcvtph2ps(ymm, xmm);
                 } else {
                     if (is_signed)
                         h->uni_vpmovsxwd(vmm, xmm);
                     else
                         h->uni_vpmovzxwd(vmm, xmm);
                 }
             }
             break;
         }
     }
 }
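
Two details of the switch above, restated as a scalar C++ sketch (illustrative only, not part of the commit): bf16 widening works because a bf16 value is the upper half of an fp32 bit pattern — hence vpmovzxwd followed by vpslld(..., 16) — and the default branch's k-mask arithmetic sets one bit per 2-byte word of the tail.

#include <cstdint>
#include <cstring>

// Scalar analogue of the emitter's vpmovzxwd + vpslld(..., 16) pair:
// zero-extend the bf16 word into a 32-bit lane, then shift it to the top.
inline float bf16_to_fp32(uint16_t bits) {
    uint32_t widened = static_cast<uint32_t>(bits) << 16;
    float out;
    std::memcpy(&out, &widened, sizeof(out));  // bit-cast without aliasing UB
    return out;
}

// Tail-mask arithmetic from the default branch: one bit per fp16/bf16 word,
// e.g. a 10-byte tail covers 10 / 2 = 5 words, so the mask is 0b11111.
inline unsigned tail_load_mask(int load_size_bytes) {
    return (1u << (load_size_bytes / 2)) - 1u;
}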

@@ -1197,7 +1190,7 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 &reg,
             store_bytes<Vmm>(reg, offset, store_num * 2);
         }
     } else if (is_f16) {
-        if (mayiuse(cpu::x64::avx512_core_fp16)) {
+        if (mayiuse(cpu::x64::avx512_core)) {
             // to avoid src vmm pollution
             if (src_prc_ == ov::element::f32) {
                 // since avx512, zmm(fp32) => ymm(fp16)
@@ -1211,7 +1204,7 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 &reg,
                 data_idx = static_cast<int>(ymm.getIdx());
                 store_bytes<Vmm>(reg, offset, store_num * 2);
             }
-        } else if (mayiuse(cpu::x64::avx2_vnni_2)) {
+        } else if (mayiuse(cpu::x64::avx2)) {
             // to avoid src vmm pollution
             if (src_prc_ == ov::element::f32) {
                 xmm = Xmm(aux_vec_idxs[0]);
@@ -1224,8 +1217,7 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 &reg,
                 store_bytes<Vmm>(reg, offset, store_num * 2);
             }
         } else {
-            IE_THROW() << "Store emitter in " << name_
-                       << " only support fp16 on platform with avx512_core_fp16 or avx2_vnni_2.";
+            IE_THROW() << "Store emitter in " << name_ << " only support fp16 on platform with avx512_core or avx2.";
         }
     } else {
         switch (store_num) {
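
And the store-side counterpart: before store_bytes writes store_num * 2 bytes, the emitter narrows fp32 to fp16 with vcvtps2ph. A hedged sketch assuming only F16C — the same capability the relaxed avx2 check relies on (function name is illustrative, not from this patch):

#include <immintrin.h>
#include <cstdint>

// Narrow 8 fp32 values to fp16 (compile with -mavx2 -mf16c); the emitter's
// avx2 branch performs this conversion before storing store_num * 2 bytes.
void store_fp32_as_fp16(const float* src, uint16_t* dst) {
    __m256 full  = _mm256_loadu_ps(src);                               // 8 x fp32
    __m128i half = _mm256_cvtps_ph(full, _MM_FROUND_TO_NEAREST_INT);   // vcvtps2ph
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), half);
}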
