From 2b3ad6566fdad38e815e8966e49f8f3b4fc38680 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 13 Mar 2024 23:27:02 +0000 Subject: [PATCH 1/2] [CPU] [ARM64] jit select --- .../plugin/aarch64/jit_eltwise_emitters.cpp | 54 +++++++++++++++++++ .../plugin/aarch64/jit_eltwise_emitters.hpp | 28 ++++++++++ .../nodes/executors/aarch64/jit_eltwise.cpp | 1 + .../aarch64/jit_uni_eltwise_generic.cpp | 2 + 4 files changed, 85 insertions(+) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp index 5fbe24b4f2b637..cf91f7643c05a3 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp @@ -495,6 +495,60 @@ void jit_relu_emitter::emit_isa(const std::vector &in_vec_idxs, const st h->fmaxnm(dst.s, src.s, tmp.s); } +/// SELECT /// +jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) { + prepare_table(); +} +jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { + prepare_table(); +} + +size_t jit_select_emitter::get_inputs_count() const { return 3; } + +size_t jit_select_emitter::get_aux_vecs_count() const { return 1; } + +size_t jit_select_emitter::get_aux_gprs_count() const { return 1; } + +std::set> jit_select_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + +void jit_select_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { + if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { + emit_isa(in_vec_idxs, out_vec_idxs); + } else { + OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel"); + } +} + +template +void jit_select_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { + OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string()); + + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + const TReg src1 = TReg(in_vec_idxs[0]); + const TReg src2 = TReg(in_vec_idxs[1]); + const TReg src3 = TReg(in_vec_idxs[2]); + const TReg dst = TReg(out_vec_idxs[0]); + const TReg aux = TReg(aux_vec_idxs[0]); + + h->ld1r(aux.s, table_val2("one")); + h->facge(aux.s, src1.s, aux.s); + + h->bsl(aux.b16, src2.b16, src3.b16); + h->mov(dst.b16, aux.b16); +} + +void jit_select_emitter::register_table_entries() { + push_arg_entry_of("one", 0x3f800000, true); +} + /// SUBTRACT /// jit_subtract_emitter::jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp index 80e1c2ed7e9c42..dff1dd9edf08af 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp @@ -209,6 +209,34 @@ class jit_relu_emitter : public jit_emitter { void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; }; +class jit_select_emitter : public jit_emitter { +public: + jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); + + jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& n); + + size_t get_inputs_count() const override; + + size_t get_aux_vecs_count() const override; + + size_t get_aux_gprs_count() const override; + + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); + +private: + void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + + template + void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + + void register_table_entries() override; +}; + class jit_subtract_emitter : public jit_emitter { public: jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp index 5c4bd9fa8ecafb..d86a6770e8534d 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp @@ -26,6 +26,7 @@ bool JitEltwiseExecutor::isSupported( Algorithm::EltwisePowerStatic, Algorithm::EltwisePrelu, Algorithm::EltwiseRelu, + Algorithm::EltwiseSelect, Algorithm::EltwiseSubtract); if (!is_supported) { return false; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp index 340f4632ef5eb5..ddde571506d825 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp @@ -617,6 +617,7 @@ std::shared_ptr jit_uni_eltwise_generic::create_eltwise_emitte OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter), OV_CASE(Algorithm::EltwisePrelu, ov::intel_cpu::aarch64::jit_prelu_emitter), OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter), + OV_CASE(Algorithm::EltwiseSelect, ov::intel_cpu::aarch64::jit_select_emitter), OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter)); if (!ctx.emitter) @@ -770,6 +771,7 @@ std::set> eltwise_precision_helper::get_supported_pre OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), + OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter)); if (precisions.empty()) From bdc27eec3fa0ec0fb0fc6ba7b5e27bda6dbd9f93 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 21 Mar 2024 13:15:23 +0000 Subject: [PATCH 2/2] review comments --- .../plugin/aarch64/jit_eltwise_emitters.cpp | 14 +++----------- .../plugin/aarch64/jit_eltwise_emitters.hpp | 4 ---- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp index cf91f7643c05a3..41f798565c4e4c 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp @@ -500,23 +500,19 @@ jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator * dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const std::shared_ptr& node) : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) { - prepare_table(); } jit_select_emitter::jit_select_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { - prepare_table(); } size_t jit_select_emitter::get_inputs_count() const { return 3; } size_t jit_select_emitter::get_aux_vecs_count() const { return 1; } -size_t jit_select_emitter::get_aux_gprs_count() const { return 1; } - std::set> jit_select_emitter::get_supported_precisions(const std::shared_ptr& node) { - return {{element::f32, element::f32}}; + return {{element::f32, element::f32, element::f32}}; } void jit_select_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { @@ -538,17 +534,13 @@ void jit_select_emitter::emit_isa(const std::vector &in_vec_idxs, const const TReg dst = TReg(out_vec_idxs[0]); const TReg aux = TReg(aux_vec_idxs[0]); - h->ld1r(aux.s, table_val2("one")); - h->facge(aux.s, src1.s, aux.s); + h->eor(aux.b16, aux.b16, aux.b16); + h->fcmgt(aux.s, src1.s, aux.s); h->bsl(aux.b16, src2.b16, src3.b16); h->mov(dst.b16, aux.b16); } -void jit_select_emitter::register_table_entries() { - push_arg_entry_of("one", 0x3f800000, true); -} - /// SUBTRACT /// jit_subtract_emitter::jit_subtract_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp index dff1dd9edf08af..e0ff36c2657730 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp @@ -223,8 +223,6 @@ class jit_select_emitter : public jit_emitter { size_t get_aux_vecs_count() const override; - size_t get_aux_gprs_count() const override; - static std::set> get_supported_precisions( const std::shared_ptr& node = nullptr); @@ -233,8 +231,6 @@ class jit_select_emitter : public jit_emitter { template void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; - - void register_table_entries() override; }; class jit_subtract_emitter : public jit_emitter {