From 7d636017edb77b0a2a89ada9fb47aaa287edd89c Mon Sep 17 00:00:00 2001 From: xuchen-intel Date: Sat, 17 Jun 2023 02:54:18 +0800 Subject: [PATCH] Apply review comments #2 --- .../src/emitters/x64/cpu_generator.cpp | 4 +- .../emitters/x64/jit_snippets_emitters.cpp | 58 ++++++++----------- .../emitters/x64/jit_snippets_emitters.hpp | 24 ++------ 3 files changed, 30 insertions(+), 56 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/x64/cpu_generator.cpp index 1244fac99ad2fb..6d776ab57ebd6b 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/x64/cpu_generator.cpp @@ -139,8 +139,8 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_ jitters[ngraph::op::v7::Gelu::get_type_info_static()] = CREATE_EMITTER(ov::intel_cpu::jit_gelu_v7_emitter); jitters[snippets::op::Fill::get_type_info_static()] = CREATE_EMITTER(FillEmitter); - jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_EMITTER(HorizonMaxEmitter); - jitters[snippets::op::HorizonSum::get_type_info_static()] = CREATE_EMITTER(HorizonSumEmitter); + jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_EMITTER(HorizonEmitter); + jitters[snippets::op::HorizonSum::get_type_info_static()] = CREATE_EMITTER(HorizonEmitter); jitters[snippets::op::Kernel::get_type_info_static()] = CREATE_EMITTER(KernelEmitter); jitters[snippets::op::LoopBegin::get_type_info_static()] = CREATE_EMITTER(LoopBeginEmitter); diff --git a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp index bb6ae5b3bb9ea5..7b7b33cda26504 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp @@ -1335,7 +1335,13 @@ void BrgemmCopyBEmitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kernel, con } HorizonEmitter::HorizonEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : - jit_emitter(h, isa, n, Precision::FP32, emitter_in_out_map::vec_to_vec) {} + jit_emitter(h, isa, n, Precision::FP32, emitter_in_out_map::vec_to_vec) { + if (ov::as_type_ptr(n)) { + m_op_type = OpType::max; + } else if (ov::as_type_ptr(n)) { + m_op_type = OpType::sum; + } +} void HorizonEmitter::emit_impl(const std::vector& in, const std::vector& out) const { @@ -1365,49 +1371,33 @@ void HorizonEmitter::emit_isa(const std::vector &in, const std::vectorvshuff32x4(aux_zmm, dst_zmm, dst_zmm, 0x4E); - perform_op(dst_zmm, aux_zmm); + perform_op(dst_zmm, dst_zmm, aux_zmm); h->vshuff32x4(aux_zmm, dst_zmm, dst_zmm, 0xB1); - perform_op(dst_zmm, aux_zmm); + perform_op(dst_zmm, dst_zmm, aux_zmm); } else if (isa == dnnl::impl::cpu::x64::avx2) { Ymm dst_ymm = Ymm(out[0]); Ymm aux_ymm = Ymm(aux_vec_idxs[0]); h->vperm2i128(aux_ymm, dst_ymm, dst_ymm, 0x01); - perform_op(dst_ymm, aux_ymm); + perform_op(dst_ymm, dst_ymm, aux_ymm); } h->uni_vshufps(aux_vmm, dst_vmm, dst_vmm, 0x4E); - perform_op(dst_vmm, aux_vmm); + perform_op(dst_vmm, dst_vmm, aux_vmm); h->uni_vshufps(aux_vmm, dst_vmm, dst_vmm, 0xB1); - perform_op(dst_vmm, aux_vmm); -} - -void HorizonEmitter::perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const { - assert(!"Horizontal operation is not implemented."); -} - -void HorizonEmitter::perform_op(const Xbyak::Ymm &dst_xmm, const Xbyak::Ymm &src_xmm) const { - assert(!"Horizontal operation is not implemented."); -} - -HorizonMaxEmitter::HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : - HorizonEmitter(h, isa, n) {} - -void HorizonMaxEmitter::perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const { - h->uni_vmaxps(dst_xmm, dst_xmm, src_xmm); -} - -void HorizonMaxEmitter::perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const { - h->uni_vmaxps(dst_ymm, dst_ymm, src_ymm); -} - -HorizonSumEmitter::HorizonSumEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : - HorizonEmitter(h, isa, n) {} - -void HorizonSumEmitter::perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const { - h->uni_vaddps(dst_xmm, dst_xmm, src_xmm); + perform_op(dst_vmm, dst_vmm, aux_vmm); } -void HorizonSumEmitter::perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const { - h->uni_vaddps(dst_ymm, dst_ymm, src_ymm); +template +void HorizonEmitter::perform_op(const Vmm &vmm1, const Vmm &vmm2, const Vmm &vmm3) const { + switch (m_op_type) { + case OpType::max: + h->uni_vmaxps(vmm1, vmm2, vmm3); + break; + case OpType::sum: + h->uni_vaddps(vmm1, vmm2, vmm3); + break; + default: + assert(!"Unsupported horizontal operation."); + } } VectorBufferEmitter::VectorBufferEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : diff --git a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp index 04290ee5ab1803..a4c3e1f835e4b4 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp @@ -427,7 +427,6 @@ class HorizonEmitter : public jit_emitter { } protected: - size_t aux_gprs_count() const override {return 1;} size_t aux_vecs_count() const override {return 1;} private: @@ -437,26 +436,11 @@ class HorizonEmitter : public jit_emitter { template void emit_isa(const std::vector &in, const std::vector &out) const; - virtual void perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const; - virtual void perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const; -}; - -class HorizonMaxEmitter : public HorizonEmitter { -public: - HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); - -private: - void perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const override; - void perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const override; -}; + template + void perform_op(const Vmm &vmm1, const Vmm &vmm2, const Vmm &vmm3) const; -class HorizonSumEmitter : public HorizonEmitter { -public: - HorizonSumEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); - -private: - void perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const override; - void perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const override; + enum class OpType { max, sum }; + OpType m_op_type = OpType::max; }; class VectorBufferEmitter : public jit_emitter {