diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp index 7a7a563d488801..c533dfd5847146 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp @@ -72,6 +72,136 @@ std::set> jit_add_emitter::get_supported_precisions(c return {{element::f32, element::f32}}; } +/// EXPONENT /// +jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) { + prepare_table(); +} + +jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { + prepare_table(); +} + +size_t jit_exp_emitter::get_inputs_count() const { return 1; } + +size_t jit_exp_emitter::get_aux_vecs_count() const { return 4; } + +size_t jit_exp_emitter::get_aux_gprs_count() const { return 1; } + +void jit_exp_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { + if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) { + emit_isa(in_vec_idxs, out_vec_idxs); + } else { + OPENVINO_THROW("Can't create jit eltwise kernel"); + } +} + +template +void jit_exp_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { + if (exec_prc_ != ov::element::f32) { + OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string()); + } + + using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits::TReg; + const TReg vmm_src(in_vec_idxs[0]); + const TReg vmm_dst(out_vec_idxs[0]); + const TReg vmm_aux1(aux_vec_idxs[0]); + const TReg vmm_aux2(aux_vec_idxs[1]); + const TReg vmm_aux0(aux_vec_idxs[2]); + + const TReg vmm_mask(aux_vec_idxs[3]); + + h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_max_f")); + h->fmin(vmm_dst.s, vmm_src.s, vmm_aux0.s); + h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f")); + + // get mask of values lower than log(FLT_MIN) to zero them in the output + h->facgt(vmm_mask.s, vmm_aux0.s, vmm_src.s); + + h->fmax(vmm_dst.s, vmm_dst.s, vmm_aux0.s); + h->mov(vmm_aux1.b16, vmm_dst.b16); + + // calculate exp(x) + // fx = x * log2ef + 0.5 + h->ld1r(vmm_aux0.s, table_val2("exp_log2ef")); + h->ld1r(vmm_aux2.s, table_val2("half")); + h->fmla(vmm_aux2.s, vmm_dst.s, vmm_aux0.s); + + // tmp = floorf(fx) + h->frintm(vmm_aux2.s, vmm_aux2.s); + + // keep vmm_src = fx for further computations + h->mov(vmm_dst.b16, vmm_aux2.b16); + + // x = x - fx * ln2 + h->ld1r(vmm_aux0.s, table_val2("ln2f")); + h->fmls(vmm_aux1.s, vmm_aux2.s, vmm_aux0.s); + + // We do not count 2^n here, because n can reach 128 and 2^128 is not + // representable by fp32, so to get around this problem, instead of computing + // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127 + // and 2 are numbers representable in fp32. + + // compute 2^(n-1) + h->ld1r(vmm_aux0.s, table_val2("one")); + h->fsub(vmm_dst.s, vmm_dst.s, vmm_aux0.s); + h->fcvtzs(vmm_aux2.s, vmm_dst.s); + + h->ld1r(vmm_aux0.s, table_val2("exponent_bias")); + h->add(vmm_aux2.s, vmm_aux2.s, vmm_aux0.s); + + h->sqshl(vmm_aux2.s, vmm_aux2.s, 23); + + // set zeroes at those points which were < log(FLT_MIN) + h->and_(vmm_aux2.b16, vmm_mask.b16, vmm_aux2.b16); + + // compute polynomial + h->ld1r(vmm_aux0.s, table_val2("exp_pol5")); + h->ld1r(vmm_dst.s, table_val2("exp_pol4")); + h->fmla(vmm_dst.s, vmm_aux1.s, vmm_aux0.s); + + h->ld1r(vmm_aux0.s, table_val2("exp_pol3")); + h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s); + + h->ld1r(vmm_dst.s, table_val2("exp_pol2")); + h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s); + + h->ld1r(vmm_aux0.s, table_val2("exp_pol1")); + h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s); + + h->ld1r(vmm_dst.s, table_val2("one")); + h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s); + + // y = y * 2^n + h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux2.s); + h->ld1r(vmm_aux0.s, table_val2("two")); + h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux0.s); +} + +void jit_exp_emitter::register_table_entries() { + push_arg_entry_of("exp_ln_flt_max_f", 0x42b17218, true); + push_arg_entry_of("exp_ln_flt_min_f", 0xc2aeac50, true); + push_arg_entry_of("exp_log2ef", 0x3fb8aa3b, true); + push_arg_entry_of("one", 0x3f800000, true); + push_arg_entry_of("two", 0x40000000, true); + push_arg_entry_of("half", 0x3f000000, true); + push_arg_entry_of("ln2f", 0x3f317218, true); + push_arg_entry_of("exponent_bias", 0x0000007f, true); + push_arg_entry_of("exp_pol1", 0x3f7ffffb, true); + push_arg_entry_of("exp_pol2", 0x3efffee3, true); + push_arg_entry_of("exp_pol3", 0x3e2aad40, true); + push_arg_entry_of("exp_pol4", 0x3d2b9d0d, true); + push_arg_entry_of("exp_pol5", 0x3c07cfce, true); +} + +std::set> jit_exp_emitter::get_supported_precisions(const std::shared_ptr& node) { + return {{element::f32, element::f32}}; +} + /// MUL_ADD /// jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp index 1c8aa44b357f2d..0f1e3fe5ecb997 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.hpp @@ -31,6 +31,32 @@ class jit_add_emitter : public jit_emitter { void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; }; +class jit_exp_emitter : public jit_emitter { +public: + jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const ov::element::Type exec_prc = ov::element::f32); + + jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, + dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, + const std::shared_ptr& node); + + size_t get_inputs_count() const override; + + size_t get_aux_vecs_count() const override; + + size_t get_aux_gprs_count() const override; + + void register_table_entries() override; + + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + +private: + void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + + template + void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; +}; class jit_mul_add_emitter : public jit_emitter { public: diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp index 24dfeb6f845115..c6bedffde36edc 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp @@ -19,6 +19,7 @@ bool JitEltwiseExecutor::isSupported( const float gamma) { const auto is_supported = one_of(algorithm, Algorithm::EltwiseAdd, + Algorithm::EltwiseExp, Algorithm::EltwiseMultiply, Algorithm::EltwiseMulAdd, Algorithm::EltwisePowerStatic, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp index 3050ec4785c5f3..a97914a19f8621 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp @@ -503,6 +503,7 @@ std::shared_ptr jit_uni_eltwise_generic::create_eltwise_emitte OV_SWITCH(intel_cpu, EltwiseEmitter, ctx, data.algo, OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter), + OV_CASE(Algorithm::EltwiseExp, ov::intel_cpu::aarch64::jit_exp_emitter), OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter), OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter), OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter), @@ -654,6 +655,7 @@ std::set> eltwise_precision_helper::get_supported_pre OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo, OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter), OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), + OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp index a03c4813c4c1ed..f5c2df6eb9c2ae 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp @@ -185,7 +185,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // 09 | dst // 10 | aux // 11 | aux - // 12-15 | [not used] + // 12 | aux + // 13 | aux + // 14 | aux + // 15 | [not used] // 16 | src // 17 | src // 18 | src @@ -213,7 +216,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { } inline TReg get_aux_vmm(const uint32_t idx) { - if (idx > 2) { + if (idx > 3) { OPENVINO_THROW("aux vector register " + std::to_string(idx) + " is not supported"); } return TReg(10 + idx); diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp index d81cdb9698b4d2..750ec5de462cf1 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp @@ -139,7 +139,9 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType const std::vector>>& input_shapes) const { #if defined(OV_CPU_WITH_ACL) #if defined(OPENVINO_ARCH_ARM64) - if ((element_type == ov::element::f32) && (activation_type == utils::ActivationTypes::Relu)) { + if ((element_type == ov::element::f32) && + ((activation_type == utils::ActivationTypes::Relu) || + (activation_type == utils::ActivationTypes::Exp))) { return "jit"; }