From 94ec445c04b608143a6f73e0a57e018e4a584820 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 25 Mar 2024 03:15:20 +0000 Subject: [PATCH] refactoring + comments --- .../emitters/plugin/aarch64/jit_emitter.cpp | 127 ++++++++++++++++-- .../aarch64/jit_uni_eltwise_generic.hpp | 38 ++---- 2 files changed, 126 insertions(+), 39 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp index bb6f93f97fedec..f180bc6a2c39d4 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp @@ -106,6 +106,12 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, OPENVINO_THROW("Failed to allocate required number of gpr registers"); } + using namespace Xbyak_aarch64::util; + const bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) || + (in_out_type_ == emitter_in_out_map::vec_to_gpr); + const bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) || + (in_out_type_ == emitter_in_out_map::gpr_to_vec); + // vector registers for (auto idx : pool_aux_vec_idxs) { aux_vec_idxs.push_back(static_cast(idx)); @@ -114,6 +120,13 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, for (size_t idx = 0; idx < get_max_vecs_count(); idx++) { if (aux_vec_idxs.size() >= get_aux_vecs_count()) break; + if (is_vec_input) { + if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue; + } + if (is_vec_output) { + if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue; + } + if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue; if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue; @@ -131,12 +144,21 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, aux_gpr_idxs.push_back(idx); } - const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X28; 
+ const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X30; for (size_t gpr_idx = 0; gpr_idx <= end_gpr_idx; ++gpr_idx) { size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end if (aux_gpr_idxs.size() >= get_aux_gprs_count()) break; - if (_idx == Xbyak_aarch64::Operand::X18) continue; + if ((_idx == Xbyak_aarch64::Operand::X18) || + (_idx == Xbyak_aarch64::Operand::X23) || + (_idx == Xbyak_aarch64::Operand::X28)) continue; + + if (!is_vec_input) { + if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) continue; + } + if (!is_vec_output) { + if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) continue; + } if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) continue; if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) continue; @@ -178,16 +200,52 @@ void jit_emitter::store_context( const std::vector& gpr_regs, const std::vector& vec_regs, const std::unordered_set& ignore_vec_regs) const { - for (const auto i : gpr_regs) { - h->str(Xbyak_aarch64::XReg(i), pre_ptr(h->sp, -get_gpr_length() * 2)); + // 1. General-purpose Registers + // 1.1. store pair registers + const auto store_gpr_regs_size = gpr_regs.size(); + const auto last = store_gpr_regs_size % 2; + for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) { + h->stp(Xbyak_aarch64::XReg(gpr_regs[i]), + Xbyak_aarch64::XReg(gpr_regs[i + 1]), + pre_ptr(h->sp, -get_gpr_length() * 2)); + } + // 1.2. store the remaining register + if (last != 0) { + h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]), + pre_ptr(h->sp, -get_gpr_length())); } - for (const auto i : vec_regs) { - if (ignore_vec_regs.find(i) != ignore_vec_regs.end()) { + // 2. SIMD and Floating-Point registers + // 2.1. 
store pair registers + int prev_reg_idx = -1; + size_t ignore_registers_count = 0; + for (size_t reg_idx = 0; reg_idx < vec_regs.size(); reg_idx++) { + if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) { + ignore_registers_count++; continue; } - h->str(Xbyak_aarch64::QReg(i), pre_ptr(h->sp, -get_vec_length())); + if (prev_reg_idx == -1) { + prev_reg_idx = static_cast(reg_idx); + continue; + } + h->stp(Xbyak_aarch64::QReg(prev_reg_idx), + Xbyak_aarch64::QReg(reg_idx), + pre_ptr(h->sp, -get_vec_length() * 2)); + prev_reg_idx = -1; } + + // 2.2. store the remaining register + if (prev_reg_idx != -1) { + if (ignore_vec_regs.find(prev_reg_idx) == ignore_vec_regs.end()) { + h->str(Xbyak_aarch64::QReg(prev_reg_idx), + pre_ptr(h->sp, -get_vec_length())); + } else { + ignore_registers_count++; + } + } + + OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(), + "ignored registers size is not equal actual ignored registers count"); } void jit_emitter::restore_context(const std::unordered_set& ignore_vec_regs) const { @@ -198,17 +256,58 @@ void jit_emitter::restore_context( const std::vector& gpr_regs, const std::vector& vec_regs, const std::unordered_set& ignore_vec_regs) const { - const int vec_regs_size = static_cast(vec_regs.size()); - for (int i = (vec_regs_size - 1); i >= 0; --i) { - if (ignore_vec_regs.find(i) != ignore_vec_regs.end()) { + // 1. SIMD and Floating-Point registers + // 1.1. restore the remaining register + auto v_last = (vec_regs.size() - ignore_vec_regs.size()) % 2; + if (v_last != 0) { + for (size_t i = 0; i < vec_regs.size(); i++) { + const auto reg_idx = vec_regs.size() - 1 - i; + if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) { + v_last++; + continue; + } + + h->ldr(Xbyak_aarch64::QReg(reg_idx), + post_ptr(h->sp, get_vec_length())); + break; + } + } + // 1.2. 
restore pair registers + size_t ignore_registers_count = 0; + int prev_reg_idx = -1; + for (size_t i = v_last; i < vec_regs.size(); i++) { + const auto reg_idx = vec_regs.size() - 1 - i; + if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) { + ignore_registers_count++; continue; } - h->ldr(Xbyak_aarch64::QReg(vec_regs[i]), post_ptr(h->sp, get_vec_length())); + if (prev_reg_idx == -1) { + prev_reg_idx = static_cast(reg_idx); + continue; + } + h->ldp(Xbyak_aarch64::QReg(reg_idx), + Xbyak_aarch64::QReg(prev_reg_idx), + post_ptr(h->sp, get_vec_length() * 2)); + prev_reg_idx = -1; + } + + OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(), + "ignored registers size is not equal actual ignored registers count"); + + // 2. General-purpose Registers + // 2.1. restore the remaining register + const auto save_gpr_regs_size = gpr_regs.size(); + const auto last = save_gpr_regs_size % 2; + if (last != 0) { + h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]), + post_ptr(h->sp, get_gpr_length())); } - const int gpr_regs_size = static_cast(gpr_regs.size()); - for (int i = (gpr_regs_size - 1); i >= 0; --i) { - h->ldr(Xbyak_aarch64::XReg(gpr_regs[i]), post_ptr(h->sp, get_gpr_length() * 2)); + // 2.2. 
restore pair registers + for (size_t i = last; i < save_gpr_regs_size; i += 2) { + h->ldp(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - (i + 1)]), + Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - i]), + post_ptr(h->sp, get_gpr_length() * 2)); } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp index a03c4813c4c1ed..c92f0e77e313c3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp @@ -97,10 +97,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { void generate() override; private: - const Xbyak_aarch64::XReg X_TMP_0 = x10; - const Xbyak_aarch64::XReg X_TMP_1 = x11; - - XReg reg_post_op_ptrs = X_TMP_0; + XReg reg_post_op_ptrs = x10; XReg start_to_offsets = reg_post_op_ptrs; XReg reg_oc_off = x12; @@ -126,10 +123,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // X10 | ker temporary| R10 | src ptr // X11 | ker temporary| R11 | src ptr // X12 | ker temporary (abi_not_param1) | R12 | src ptr - // X13 | [not used] | R13 | src ptr - // X14 | [not used] | R14 | src ptr - // X15 | dst | R15 | temporary - // X16 | [not used: IP1] + // X13 | temporary | R13 | src ptr + // X14 | temporary | R14 | src ptr + // X15 | temporary | R15 | temporary + // X16 | dst // X17 | [not used: IP0] // X18 | [not used: Apple: The platforms reserve register x18. Don't use this register.] 
@@ -138,31 +135,27 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // X20 | src ptr // X21 | src ptr // X22 | src ptr - // X23 | temporary & kernel used (oneDNN: X_TMP_0) - // X24 | src ptr + // X23 | kernel used (oneDNN: X_TMP_0) + // X24 | kernel used (oneDNN: X_TMP_1) // X25 | src ptr // X26 | src ptr - // X27 | temporary - // X28 | temporary & kernel used (oneDNN: X_DEFAULT_ADDR) + // X27 | src ptr + // X28 | kernel used (oneDNN: X_DEFAULT_ADDR) // X29 | [not used: The Frame Pointer (FP)] // X30 | [not used: The Link Register (LR)] // X31 | [not used: The Stack Pointer (SP)] const XReg reg_work_amount = x9; - const XReg reg_dst = x15; + const XReg reg_dst = x16; inline XReg get_src_reg(uint32_t idx) { if (idx > MAX_ELTWISE_INPUTS) { OPENVINO_THROW("source vector ptr register " + std::to_string(idx) + " is not supported"); } - const uint32_t base = 19; - if ((base + idx) == 23) { - idx++; - } - - return XReg(base + idx); + static const std::vector src_gprs = { 19, 20, 21, 22, 25, 26, 27 }; + return XReg(src_gprs[idx]); } inline XReg get_aux_gpr(const uint32_t idx) { @@ -170,12 +163,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { OPENVINO_THROW("aux gpr register " + std::to_string(idx) + " is not supported"); } - if (idx == 0) { - return XReg(23); - } - - const uint32_t base = 27; - return XReg(base + idx - 1); + return XReg(13 + idx); } // Vector registers mapping