From 25987d1481c6d92079f43b9c317e918600956538 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 22 Feb 2024 00:51:00 +0000 Subject: [PATCH] min & flags --- .../emitters/plugin/aarch64/jit_eltwise_emitters.cpp | 11 ++++++++++- .../nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp | 5 +++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp index 2596fb4d51dfbb..f1984497ed0e19 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp @@ -127,7 +127,7 @@ jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host, size_t jit_exp_emitter::get_inputs_count() const { return 1; } -size_t jit_exp_emitter::get_aux_vecs_count() const { return 3; } +size_t jit_exp_emitter::get_aux_vecs_count() const { return 4; } size_t jit_exp_emitter::get_aux_gprs_count() const { return 1; } @@ -152,9 +152,15 @@ void jit_exp_emitter::emit_isa(const std::vector &in_vec_idxs, const std const TReg vmm_aux2(aux_vec_idxs[1]); const TReg vmm_aux0(aux_vec_idxs[2]); + const TReg vmm_mask(aux_vec_idxs[3]); + h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_max_f")); h->fmin(vmm_dst.s, vmm_src.s, vmm_aux0.s); h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f")); + + // get mask of values greater than log(FLT_MIN) using a signed compare (not an absolute-value compare): lanes at or below it get an all-zero mask and are zeroed in the output + h->fcmgt(vmm_mask.s, vmm_src.s, vmm_aux0.s); + h->fmax(vmm_dst.s, vmm_dst.s, vmm_aux0.s); h->mov(vmm_aux1.b16, vmm_dst.b16); @@ -189,6 +195,9 @@ void jit_exp_emitter::emit_isa(const std::vector &in_vec_idxs, const std h->sqshl(vmm_aux2.s, vmm_aux2.s, 23); + // set zeroes at those points which were < log(FLT_MIN) + h->and_(vmm_aux2.b16, vmm_mask.b16, vmm_aux2.b16); + // compute polynomial h->ld1r(vmm_aux0.s, table_val2("exp_pol5")); h->ld1r(vmm_dst.s, table_val2("exp_pol4")); 
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp index 3e6ba17eee51db..db1faa9b3192b0 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp @@ -174,7 +174,8 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // 11 | aux // 12 | aux // 13 | aux - // 14-15 | [not used] + // 14 | aux + // 15 | [not used] // 16 | src // 17 | src // 18 | src @@ -202,7 +203,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { } inline TReg get_aux_vmm(const uint32_t idx) { - if (idx > 2) { + if (idx > 3) { OPENVINO_THROW("aux vector register " + std::to_string(idx) + " is not supported"); } return TReg(10 + idx);