Skip to content

Commit

Permalink
Added skipping tests for non-AVX512, fixed Select
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 10, 2023
1 parent ebc1eac commit 85e3a20
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
22 changes: 16 additions & 6 deletions src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2152,8 +2152,12 @@ jit_select_emitter::jit_select_emitter(x64::jit_generator *host, x64::cpu_isa_t
// Select always consumes three operands: the condition mask plus the two
// data inputs it chooses between (src0/src1 in emit_isa).
size_t jit_select_emitter::get_inputs_num() const {
    return 3;
}

size_t jit_select_emitter::aux_vecs_count() const {
    // AVX512 performs the select through an opmask register, so no auxiliary
    // vector registers are required.
    if (host_isa_ == x64::avx512_core)
        return 0;
    // AVX2 needs one temporary vector to materialize the mask; SSE4.1 needs
    // that temporary plus an extra one because the blend mask must be in xmm0.
    return host_isa_ == x64::avx2 ? 1 : 2;
}

void jit_select_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
Expand All @@ -2179,15 +2183,21 @@ void jit_select_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

if (isa == x64::sse41) {
Vmm vmm_aux = Vmm(aux_vec_idxs[0]);
if (vmm_aux.getIdx() != vmm_cond.getIdx()) {
h->uni_vmovups(vmm_aux, vmm_cond);
Vmm vmm_mask = Vmm(aux_vec_idxs[0]);
Vmm vmm_zero = Vmm(aux_vec_idxs[1]);
h->uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
h->uni_vcmpps(vmm_cond, vmm_cond, vmm_zero, 0x4);
if (vmm_mask.getIdx() != vmm_cond.getIdx()) {
h->uni_vmovups(vmm_mask, vmm_cond);
}
if (vmm_src1.getIdx() != vmm_dst.getIdx()) {
h->uni_vmovups(vmm_dst, vmm_src1);
}
h->uni_vblendvps(vmm_dst, vmm_dst, vmm_src0, vmm_aux);
h->uni_vblendvps(vmm_dst, vmm_dst, vmm_src0, vmm_mask);
} else if (isa == x64::avx2) {
Vmm vmm_zero = Vmm(aux_vec_idxs[0]);
h->uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
h->uni_vcmpps(vmm_cond, vmm_cond, vmm_zero, 0x4);
h->uni_vblendvps(vmm_dst, vmm_src1, vmm_src0, vmm_cond);
} else {
h->vptestmd(k_mask, vmm_cond, vmm_cond);
Expand Down
8 changes: 4 additions & 4 deletions src/plugins/intel_cpu/src/transformation_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,15 +564,15 @@ void Transformations::MainSnippets(void) {
snippetsManager.register_pass<SnippetsMarkSkipped>();
snippetsManager.register_pass<ngraph::snippets::pass::SnippetsTokenization>();

if (enableBF16) {
// TODO: Need to add BF16 support for MHA in Snippets
const bool isMHASupported =
!enableBF16 && // TODO: Need to add BF16 support for MHA in Snippets
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // MHA has BRGEMM that is supported only on AVX512 platforms
if (!isMHASupported) {
snippetsManager.get_pass_config()->disable<ngraph::snippets::pass::TokenizeMHASnippets>();
}
if (snippetsMode != Config::SnippetsMode::IgnoreCallback) {
snippetsManager.get_pass_config()->set_callback<ngraph::snippets::pass::TokenizeMHASnippets>(
[](const std::shared_ptr<const ov::Node>& n) -> bool {
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))
return true;
const auto pshape = n->get_output_partial_shape(0);
const auto shape = pshape.get_shape();
const auto parallel_work_amount =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ std::vector<std::string> disabledTestPatterns() {
// tests are useless on such platforms
retVector.emplace_back(R"(.*(BF|bf)16.*)");
retVector.emplace_back(R"(.*bfloat16.*)");
// MatMul in Snippets uses BRGEMM that is supported only on AVX512 platforms
// Disabled Snippets MHA tests as well because MHA pattern contains MatMul
retVector.emplace_back(R"(.*Snippets.*MHA.*)");
retVector.emplace_back(R"(.*Snippets.*(MatMul|Matmul).*)");
}
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8())
//TODO: Issue 92895
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::vector<std::vector<ov::PartialShape>> input_shapes{
};
std::vector<element::Type> precisions{element::f32};
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul,
::testing::Combine(
::testing::Combine(
::testing::ValuesIn(input_shapes),
::testing::ValuesIn(precisions),
                                 ::testing::Values(1), // MatMul
Expand Down

0 comments on commit 85e3a20

Please sign in to comment.