Skip to content

Commit

Permalink
Apply review comments #2
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchen-intel committed Jun 17, 2023
1 parent 3e66abd commit 7d63601
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 56 deletions.
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/emitters/x64/cpu_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_
jitters[ngraph::op::v7::Gelu::get_type_info_static()] = CREATE_EMITTER(ov::intel_cpu::jit_gelu_v7_emitter);
jitters[snippets::op::Fill::get_type_info_static()] = CREATE_EMITTER(FillEmitter);

jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_EMITTER(HorizonMaxEmitter);
jitters[snippets::op::HorizonSum::get_type_info_static()] = CREATE_EMITTER(HorizonSumEmitter);
jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_EMITTER(HorizonEmitter);
jitters[snippets::op::HorizonSum::get_type_info_static()] = CREATE_EMITTER(HorizonEmitter);

jitters[snippets::op::Kernel::get_type_info_static()] = CREATE_EMITTER(KernelEmitter);
jitters[snippets::op::LoopBegin::get_type_info_static()] = CREATE_EMITTER(LoopBeginEmitter);
Expand Down
58 changes: 24 additions & 34 deletions src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,13 @@ void BrgemmCopyBEmitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kernel, con
}

HorizonEmitter::HorizonEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n) :
jit_emitter(h, isa, n, Precision::FP32, emitter_in_out_map::vec_to_vec) {}
jit_emitter(h, isa, n, Precision::FP32, emitter_in_out_map::vec_to_vec) {
if (ov::as_type_ptr<const snippets::op::HorizonMax>(n)) {
m_op_type = OpType::max;
} else if (ov::as_type_ptr<const snippets::op::HorizonSum>(n)) {
m_op_type = OpType::sum;
}
}

void HorizonEmitter::emit_impl(const std::vector<size_t>& in,
const std::vector<size_t>& out) const {
Expand Down Expand Up @@ -1365,49 +1371,33 @@ void HorizonEmitter::emit_isa(const std::vector<size_t> &in, const std::vector<s
Zmm dst_zmm = Zmm(out[0]);
Zmm aux_zmm = Zmm(aux_vec_idxs[0]);
h->vshuff32x4(aux_zmm, dst_zmm, dst_zmm, 0x4E);
perform_op(dst_zmm, aux_zmm);
perform_op<Zmm>(dst_zmm, dst_zmm, aux_zmm);
h->vshuff32x4(aux_zmm, dst_zmm, dst_zmm, 0xB1);
perform_op(dst_zmm, aux_zmm);
perform_op<Zmm>(dst_zmm, dst_zmm, aux_zmm);
} else if (isa == dnnl::impl::cpu::x64::avx2) {
Ymm dst_ymm = Ymm(out[0]);
Ymm aux_ymm = Ymm(aux_vec_idxs[0]);
h->vperm2i128(aux_ymm, dst_ymm, dst_ymm, 0x01);
perform_op(dst_ymm, aux_ymm);
perform_op<Ymm>(dst_ymm, dst_ymm, aux_ymm);
}
h->uni_vshufps(aux_vmm, dst_vmm, dst_vmm, 0x4E);
perform_op(dst_vmm, aux_vmm);
perform_op<Xmm>(dst_vmm, dst_vmm, aux_vmm);
h->uni_vshufps(aux_vmm, dst_vmm, dst_vmm, 0xB1);
perform_op(dst_vmm, aux_vmm);
}

void HorizonEmitter::perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const {
assert(!"Horizontal operation is not implemented.");
}

void HorizonEmitter::perform_op(const Xbyak::Ymm &dst_xmm, const Xbyak::Ymm &src_xmm) const {
assert(!"Horizontal operation is not implemented.");
}

HorizonMaxEmitter::HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n) :
HorizonEmitter(h, isa, n) {}

void HorizonMaxEmitter::perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const {
h->uni_vmaxps(dst_xmm, dst_xmm, src_xmm);
}

void HorizonMaxEmitter::perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const {
h->uni_vmaxps(dst_ymm, dst_ymm, src_ymm);
}

HorizonSumEmitter::HorizonSumEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n) :
HorizonEmitter(h, isa, n) {}

void HorizonSumEmitter::perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const {
h->uni_vaddps(dst_xmm, dst_xmm, src_xmm);
perform_op<Xmm>(dst_vmm, dst_vmm, aux_vmm);
}

void HorizonSumEmitter::perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const {
h->uni_vaddps(dst_ymm, dst_ymm, src_ymm);
template<typename Vmm>
void HorizonEmitter::perform_op(const Vmm &vmm1, const Vmm &vmm2, const Vmm &vmm3) const {
switch (m_op_type) {
case OpType::max:
h->uni_vmaxps(vmm1, vmm2, vmm3);
break;
case OpType::sum:
h->uni_vaddps(vmm1, vmm2, vmm3);
break;
default:
assert(!"Unsupported horizontal operation.");
}
}

VectorBufferEmitter::VectorBufferEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n) :
Expand Down
24 changes: 4 additions & 20 deletions src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ class HorizonEmitter : public jit_emitter {
}

protected:
size_t aux_gprs_count() const override {return 1;}
size_t aux_vecs_count() const override {return 1;}

private:
Expand All @@ -437,26 +436,11 @@ class HorizonEmitter : public jit_emitter {
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in, const std::vector<size_t> &out) const;

virtual void perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const;
virtual void perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const;
};

class HorizonMaxEmitter : public HorizonEmitter {
public:
HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

private:
void perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const override;
void perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const override;
};
template<typename Vmm>
void perform_op(const Vmm &vmm1, const Vmm &vmm2, const Vmm &vmm3) const;

class HorizonSumEmitter : public HorizonEmitter {
public:
HorizonSumEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr<ov::Node>& n);

private:
void perform_op(const Xbyak::Xmm &dst_xmm, const Xbyak::Xmm &src_xmm) const override;
void perform_op(const Xbyak::Ymm &dst_ymm, const Xbyak::Ymm &src_ymm) const override;
enum class OpType { max, sum };
OpType m_op_type = OpType::max;
};

class VectorBufferEmitter : public jit_emitter {
Expand Down

0 comments on commit 7d63601

Please sign in to comment.