From f0ebba0438392fc141bf73b3eb126e905b7f2313 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 28 Mar 2024 05:53:15 +0000 Subject: [PATCH] [CPU] [ARM64] jit emitter pipeline fix (#23387) ### Details: - *[CPU] [ARM64] jit emitter pipeline fix* --- .../emitters/plugin/aarch64/jit_emitter.cpp | 168 +++++++++++++----- .../emitters/plugin/aarch64/jit_emitter.hpp | 11 +- .../src/emitters/plugin/x64/jit_emitter.cpp | 1 + .../aarch64/jit_uni_eltwise_generic.hpp | 34 ++-- 4 files changed, 152 insertions(+), 62 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp index 65aac61ba853e3..f180bc6a2c39d4 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp @@ -5,6 +5,7 @@ #include "jit_emitter.hpp" #include #include "utils/general_utils.h" +#include "emitters/utils.hpp" using namespace dnnl::impl::cpu; using namespace dnnl::impl; @@ -13,7 +14,7 @@ namespace ov { namespace intel_cpu { namespace aarch64 { -const std::vector jit_emitter::store_gpr_regs = { +const std::vector jit_emitter::store_gpr_regs = { // Parameter/result registers 0, 1, 2, 3, 4, 5, 6, 7, // r8: Indirect result location register @@ -24,6 +25,13 @@ const std::vector jit_emitter::store_gpr_regs = { 29, 30 }; +static const std::vector vec_regs = { + 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31 +}; + void jit_emitter::emit_code(const std::vector &in_idxs, const std::vector &out_idxs, const std::vector &pool_vec_idxs, @@ -98,14 +106,68 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, OPENVINO_THROW("Failed to allocate required number of gpr registers"); } + using namespace Xbyak_aarch64::util; + const bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) || + (in_out_type_ == emitter_in_out_map::vec_to_gpr); + const bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) || + (in_out_type_ == emitter_in_out_map::gpr_to_vec); + + // vector registers for (auto idx : pool_aux_vec_idxs) { aux_vec_idxs.push_back(static_cast(idx)); } + for (size_t idx = 0; idx < get_max_vecs_count(); idx++) { + if (aux_vec_idxs.size() >= get_aux_vecs_count()) break; + + if (is_vec_input) { + if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue; + } + if (is_vec_output) { + if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue; + } + + if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue; + if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue; + + if (std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) != aux_vec_idxs.end()) continue; + if (std::find(preserved_vec_idxs.begin(), preserved_vec_idxs.end(), idx) != preserved_vec_idxs.end()) continue; + + aux_vec_idxs.push_back(idx); + preserved_vec_idxs.push_back(idx); + } + if (aux_vec_idxs.size() < get_aux_vecs_count()) + OV_CPU_JIT_EMITTER_THROW("Failed to allocate required number of vector registers"); + + // gpr registers for (auto idx : pool_aux_gpr_idxs) { - aux_gpr_idxs.push_back(static_cast(idx)); - preserved_gpr_idxs.push_back(static_cast(idx)); + aux_gpr_idxs.push_back(idx); + } + + const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X30; + for (size_t gpr_idx = 0; gpr_idx <= end_gpr_idx; ++gpr_idx) { + size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end + + if (aux_gpr_idxs.size() >= get_aux_gprs_count()) break; + if ((_idx == Xbyak_aarch64::Operand::X18) || + (_idx == Xbyak_aarch64::Operand::X23) || + (_idx == Xbyak_aarch64::Operand::X28)) continue; + + if (!is_vec_input) { + if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) continue; + } + if (!is_vec_output) { + if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) continue; + } + + if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) continue; + if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) continue; + + aux_gpr_idxs.push_back(_idx); + preserved_gpr_idxs.push_back(_idx); } + if (aux_gpr_idxs.size() < get_aux_gprs_count()) + OV_CPU_JIT_EMITTER_THROW("Failed to allocate required number of general-purpose registers"); if (!entry_map_.empty()) { // last aux_gpr_idx is for p_table, we can use aux_gpr_idxs from idx 0 for other purpose @@ -113,9 +175,7 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, aux_gpr_idxs.erase(aux_gpr_idxs.end() - 1); } - for (size_t i = 0; i < preserved_gpr_idxs.size(); ++i) { - h->str(Xbyak_aarch64::XReg(preserved_gpr_idxs[i]), pre_ptr(h->sp, -16)); - } + store_context(preserved_gpr_idxs, preserved_vec_idxs); if (!entry_map_.empty()) { load_table_addr(); @@ -123,10 +183,9 @@ void jit_emitter::emitter_preamble(const std::vector& in_idxs, } void jit_emitter::emitter_postamble() const { - const int size = static_cast(preserved_gpr_idxs.size()); - for (int i = (size - 1); i >= 0; --i) { - h->ldr(Xbyak_aarch64::XReg(preserved_gpr_idxs[i]), post_ptr(h->sp, 16)); - } + restore_context(preserved_gpr_idxs, preserved_vec_idxs); + + preserved_vec_idxs.clear(); preserved_gpr_idxs.clear(); aux_vec_idxs.clear(); @@ -134,99 +193,120 @@ void jit_emitter::emitter_postamble() const { } void jit_emitter::store_context(const std::unordered_set& ignore_registers) const { + store_context(store_gpr_regs, vec_regs, ignore_registers); +} + +void jit_emitter::store_context( + const std::vector& gpr_regs, + const std::vector& vec_regs, + const std::unordered_set& ignore_vec_regs) const { // 1. General-purpose Registers // 1.1. store pair registers - const auto store_gpr_regs_size = store_gpr_regs.size(); + const auto store_gpr_regs_size = gpr_regs.size(); const auto last = store_gpr_regs_size % 2; for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) { - h->stp(Xbyak_aarch64::XReg(store_gpr_regs[i]), - Xbyak_aarch64::XReg(store_gpr_regs[i + 1]), - pre_ptr(h->sp, -get_gpr_length() * 2)); + h->stp(Xbyak_aarch64::XReg(gpr_regs[i]), + Xbyak_aarch64::XReg(gpr_regs[i + 1]), + pre_ptr(h->sp, -get_gpr_length() * 2)); } - - // 1.1. store the remaining register + // 1.2. store the remaining register if (last != 0) { - h->str(Xbyak_aarch64::XReg(store_gpr_regs[store_gpr_regs_size - 1]), - pre_ptr(h->sp, -get_gpr_length() * 2)); + h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]), + pre_ptr(h->sp, -get_gpr_length())); } // 2. SIMD and Floating-Point registers // 2.1. store pair registers int prev_reg_idx = -1; size_t ignore_registers_count = 0; - for (size_t reg_idx = 0; reg_idx < get_asimd_vectors_count(); reg_idx++) { - if (ignore_registers.find(reg_idx) != ignore_registers.end()) { + for (size_t reg_idx = 0; reg_idx < vec_regs.size(); reg_idx++) { + if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) { ignore_registers_count++; continue; } - if (prev_reg_idx == -1) { prev_reg_idx = static_cast(reg_idx); continue; } - h->stp(Xbyak_aarch64::QReg(prev_reg_idx), Xbyak_aarch64::QReg(reg_idx), pre_ptr(h->sp, -get_vec_length() * 2)); prev_reg_idx = -1; } - OPENVINO_ASSERT(ignore_registers_count == ignore_registers.size(), - "ignored registers size is not equal actual ignored registers count"); // 2.1. store the remaining register if (prev_reg_idx != -1) { - h->str(Xbyak_aarch64::QReg(prev_reg_idx), - pre_ptr(h->sp, -get_vec_length())); + if (ignore_vec_regs.find(prev_reg_idx) == ignore_vec_regs.end()) { + h->str(Xbyak_aarch64::QReg(prev_reg_idx), + pre_ptr(h->sp, -get_vec_length())); + } else { + ignore_registers_count++; + } } + + OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(), + "ignored registers size is not equal actual ignored registers count"); } -void jit_emitter::restore_context(const std::unordered_set& ignore_registers) const { +void jit_emitter::restore_context(const std::unordered_set& ignore_vec_regs) const { + restore_context(store_gpr_regs, vec_regs, ignore_vec_regs); +} + +void jit_emitter::restore_context( + const std::vector& gpr_regs, + const std::vector& vec_regs, + const std::unordered_set& ignore_vec_regs) const { // 1. SIMD and Floating-Point registers // 1.1. restore the remaining register - const auto v_last = (get_asimd_vectors_count() - ignore_registers.size()) % 2; + auto v_last = (vec_regs.size() - ignore_vec_regs.size()) % 2; if (v_last != 0) { - const auto reg_idx = get_asimd_vectors_count() - 1; - h->ldr(Xbyak_aarch64::QReg(reg_idx), - post_ptr(h->sp, get_vec_length())); + for (size_t i = 0; i < vec_regs.size(); i++) { + const auto reg_idx = vec_regs.size() - 1 - i; + if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) { + v_last++; + continue; + } + + h->ldr(Xbyak_aarch64::QReg(reg_idx), + post_ptr(h->sp, get_vec_length())); + break; + } } - - // 2.2. restore pair registers + // 1.2. restore pair registers size_t ignore_registers_count = 0; int prev_reg_idx = -1; - for (size_t i = v_last; i < get_asimd_vectors_count(); i++) { - const auto reg_idx = get_asimd_vectors_count() - 1 - i; - if (ignore_registers.find(reg_idx) != ignore_registers.end()) { + for (size_t i = v_last; i < vec_regs.size(); i++) { + const auto reg_idx = vec_regs.size() - 1 - i; + if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) { ignore_registers_count++; continue; } - if (prev_reg_idx == -1) { prev_reg_idx = static_cast(reg_idx); continue; } - h->ldp(Xbyak_aarch64::QReg(reg_idx), Xbyak_aarch64::QReg(prev_reg_idx), post_ptr(h->sp, get_vec_length() * 2)); prev_reg_idx = -1; } - OPENVINO_ASSERT(ignore_registers_count == ignore_registers.size(), + OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(), "ignored registers size is not equal actual ignored registers count"); // 2. General-purpose Registers // 2.1. restore the remaining register - const auto save_gpr_regs_size = store_gpr_regs.size(); + const auto save_gpr_regs_size = gpr_regs.size(); const auto last = save_gpr_regs_size % 2; if (last != 0) { - h->ldr(Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1]), - post_ptr(h->sp, get_gpr_length() * 2)); + h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]), + post_ptr(h->sp, get_gpr_length())); } // 2.2. restore pair registers for (size_t i = last; i < save_gpr_regs_size; i += 2) { - h->ldp(Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1 - (i + 1)]), - Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1 - i]), + h->ldp(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - (i + 1)]), + Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - i]), post_ptr(h->sp, get_gpr_length() * 2)); } } diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp index 0a77d29ef8f2f6..3f247cdee9f6b7 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp @@ -141,10 +141,11 @@ class jit_emitter : public ov::snippets::Emitter { } private: + mutable std::vector preserved_vec_idxs; mutable std::vector preserved_gpr_idxs; // General-purpose Registers - static const std::vector store_gpr_regs; + static const std::vector store_gpr_regs; size_t table_off(const std::string& key, const size_t key_off_val_shift = 0) const { // assumption: all table entries sharing the same key also @@ -163,6 +164,14 @@ class jit_emitter : public ov::snippets::Emitter { inline int32_t get_gpr_length() const { return h->x0.getBit() / 8; } + + void store_context(const std::vector& gpr_regs, + const std::vector& vec_regs, + const std::unordered_set& ignore_vec_regs = {}) const; + + void restore_context(const std::vector& gpr_regs, + const std::vector& vec_regs, + const std::unordered_set& ignore_vec_regs = {}) const; }; } // namespace aarch64 diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp index 498d03fb6c0af2..7aea76e914bd2d 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp @@ -5,6 +5,7 @@ #include "jit_emitter.hpp" #include #include "utils/general_utils.h" +#include "utils.hpp" using namespace dnnl::impl::cpu; using namespace dnnl::impl; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp index 71eee36187c967..89469ba603a402 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp @@ -97,10 +97,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { void generate() override; private: - const Xbyak_aarch64::XReg X_TMP_0 = x10; - const Xbyak_aarch64::XReg X_TMP_1 = x11; - - XReg reg_post_op_ptrs = X_TMP_0; + XReg reg_post_op_ptrs = x10; XReg start_to_offsets = reg_post_op_ptrs; XReg reg_oc_off = x12; @@ -126,10 +123,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // X10 | ker temporary| R10 | src ptr // X11 | ker temporary| R11 | src ptr // X12 | ker temporary (abi_not_param1) | R12 | src ptr - // X13 | [not used] | R13 | src ptr - // X14 | [not used] | R14 | src ptr - // X15 | dst | R15 | temporary - // X16 | [not used: IP1] + // X13 | temporary | R13 | src ptr + // X14 | temporary | R14 | src ptr + // X15 | temporary | R15 | temporary + // X16 | dst // X17 | [not used: IP0] // X18 | [not used: Apple: The platforms reserve register x18. Don't use this register.] @@ -138,32 +135,35 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { // X20 | src ptr // X21 | src ptr // X22 | src ptr - // X23 | src ptr - // X24 | src ptr + // X23 | kernel used (oneDNN: X_TMP_0) + // X24 | kernel used (oneDNN: X_TMP_1) // X25 | src ptr - // X26 | temporary - // X27 | temporary - // X28 | kernel used (X_DEFAULT_ADDR) + // X26 | src ptr + // X27 | src ptr + // X28 | kernel used (oneDNN: X_DEFAULT_ADDR) // X29 | [not used: The Frame Pointer (FP)] // X30 | [not used: The Link Register (LR)] // X31 | [not used: The Stack Pointer (SP)] const XReg reg_work_amount = x9; - const XReg reg_dst = x15; + const XReg reg_dst = x16; inline XReg get_src_reg(uint32_t idx) { if (idx > MAX_ELTWISE_INPUTS) { OPENVINO_THROW("source vector ptr register " + std::to_string(idx) + " is not supported"); } - return XReg(19 + idx); + + static const std::vector src_gprs = { 19, 20, 21, 22, 25, 26, 27 }; + return XReg(src_gprs[idx]); } inline XReg get_aux_gpr(const uint32_t idx) { - if (idx > 2) { + if (idx > 3) { OPENVINO_THROW("aux gpr register " + std::to_string(idx) + " is not supported"); } - return XReg(26 + idx); + + return XReg(13 + idx); } // Vector registers mapping