Skip to content

Commit

Permalink
[CPU] [ARM64] jit emitter pipeline fix (#23387)
Browse files Browse the repository at this point in the history
### Details:
 - *[CPU] [ARM64] jit emitter pipeline fix*
  • Loading branch information
eshoguli authored Mar 28, 2024
1 parent 496a5de commit f0ebba0
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 62 deletions.
168 changes: 124 additions & 44 deletions src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "jit_emitter.hpp"
#include <vector>
#include "utils/general_utils.h"
#include "emitters/utils.hpp"

using namespace dnnl::impl::cpu;
using namespace dnnl::impl;
Expand All @@ -13,7 +14,7 @@ namespace ov {
namespace intel_cpu {
namespace aarch64 {

const std::vector<uint32_t> jit_emitter::store_gpr_regs = {
const std::vector<size_t> jit_emitter::store_gpr_regs = {
// Parameter/result registers
0, 1, 2, 3, 4, 5, 6, 7,
// r8: Indirect result location register
Expand All @@ -24,6 +25,13 @@ const std::vector<uint32_t> jit_emitter::store_gpr_regs = {
29, 30
};

// Default list of vector register indices (0..31) handed to the three-argument
// store_context()/restore_context() by their single-argument overloads.
static const std::vector<size_t> vec_regs = {
0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31
};

void jit_emitter::emit_code(const std::vector<size_t> &in_idxs,
const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs,
Expand Down Expand Up @@ -98,135 +106,207 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
OPENVINO_THROW("Failed to allocate required number of gpr registers");
}

using namespace Xbyak_aarch64::util;
const bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
(in_out_type_ == emitter_in_out_map::vec_to_gpr);
const bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
(in_out_type_ == emitter_in_out_map::gpr_to_vec);

// vector registers
for (auto idx : pool_aux_vec_idxs) {
aux_vec_idxs.push_back(static_cast<uint32_t>(idx));
}

for (size_t idx = 0; idx < get_max_vecs_count(); idx++) {
if (aux_vec_idxs.size() >= get_aux_vecs_count()) break;

if (is_vec_input) {
if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue;
}
if (is_vec_output) {
if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue;
}

if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue;
if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue;

if (std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) != aux_vec_idxs.end()) continue;
if (std::find(preserved_vec_idxs.begin(), preserved_vec_idxs.end(), idx) != preserved_vec_idxs.end()) continue;

aux_vec_idxs.push_back(idx);
preserved_vec_idxs.push_back(idx);
}
if (aux_vec_idxs.size() < get_aux_vecs_count())
OV_CPU_JIT_EMITTER_THROW("Failed to allocate required number of vector registers");

// gpr registers
for (auto idx : pool_aux_gpr_idxs) {
aux_gpr_idxs.push_back(static_cast<uint32_t>(idx));
preserved_gpr_idxs.push_back(static_cast<uint32_t>(idx));
aux_gpr_idxs.push_back(idx);
}

const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X30;
for (size_t gpr_idx = 0; gpr_idx <= end_gpr_idx; ++gpr_idx) {
size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end

if (aux_gpr_idxs.size() >= get_aux_gprs_count()) break;
if ((_idx == Xbyak_aarch64::Operand::X18) ||
(_idx == Xbyak_aarch64::Operand::X23) ||
(_idx == Xbyak_aarch64::Operand::X28)) continue;

if (!is_vec_input) {
if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) continue;
}
if (!is_vec_output) {
if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) continue;
}

if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) continue;
if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) continue;

aux_gpr_idxs.push_back(_idx);
preserved_gpr_idxs.push_back(_idx);
}
if (aux_gpr_idxs.size() < get_aux_gprs_count())
OV_CPU_JIT_EMITTER_THROW("Failed to allocate required number of general-purpose registers");

if (!entry_map_.empty()) {
// last aux_gpr_idx is for p_table, we can use aux_gpr_idxs from idx 0 for other purpose
p_table = Xbyak_aarch64::XReg(aux_gpr_idxs[aux_gpr_idxs.size() - 1]);
aux_gpr_idxs.erase(aux_gpr_idxs.end() - 1);
}

for (size_t i = 0; i < preserved_gpr_idxs.size(); ++i) {
h->str(Xbyak_aarch64::XReg(preserved_gpr_idxs[i]), pre_ptr(h->sp, -16));
}
store_context(preserved_gpr_idxs, preserved_vec_idxs);

if (!entry_map_.empty()) {
load_table_addr();
}
}

/**
 * Emits the epilogue of an emitter invocation.
 *
 * Reloads the registers recorded in preserved_gpr_idxs / preserved_vec_idxs through
 * restore_context() and then clears the per-invocation register bookkeeping
 * (preserved_* and aux_* index lists).
 */
void jit_emitter::emitter_postamble() const {
    // restore_context() is the only place that pops the preserved registers.
    // Emitting an additional per-register ldr loop here would pop the same stack slots twice.
    restore_context(preserved_gpr_idxs, preserved_vec_idxs);

    preserved_vec_idxs.clear();
    preserved_gpr_idxs.clear();

    aux_vec_idxs.clear();
    aux_gpr_idxs.clear();
}

void jit_emitter::store_context(const std::unordered_set<size_t>& ignore_registers) const {
store_context(store_gpr_regs, vec_regs, ignore_registers);
}

void jit_emitter::store_context(
const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs) const {
// 1. General-purpose Registers
// 1.1. store pair registers
const auto store_gpr_regs_size = store_gpr_regs.size();
const auto store_gpr_regs_size = gpr_regs.size();
const auto last = store_gpr_regs_size % 2;
for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) {
h->stp(Xbyak_aarch64::XReg(store_gpr_regs[i]),
Xbyak_aarch64::XReg(store_gpr_regs[i + 1]),
pre_ptr(h->sp, -get_gpr_length() * 2));
h->stp(Xbyak_aarch64::XReg(gpr_regs[i]),
Xbyak_aarch64::XReg(gpr_regs[i + 1]),
pre_ptr(h->sp, -get_gpr_length() * 2));
}

// 1.1. store the remaining register
// 1.2. store the remaining register
if (last != 0) {
h->str(Xbyak_aarch64::XReg(store_gpr_regs[store_gpr_regs_size - 1]),
pre_ptr(h->sp, -get_gpr_length() * 2));
h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]),
pre_ptr(h->sp, -get_gpr_length()));
}

// 2. SIMD and Floating-Point registers
// 2.1. store pair registers
int prev_reg_idx = -1;
size_t ignore_registers_count = 0;
for (size_t reg_idx = 0; reg_idx < get_asimd_vectors_count(); reg_idx++) {
if (ignore_registers.find(reg_idx) != ignore_registers.end()) {
for (size_t reg_idx = 0; reg_idx < vec_regs.size(); reg_idx++) {
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
ignore_registers_count++;
continue;
}

if (prev_reg_idx == -1) {
prev_reg_idx = static_cast<int>(reg_idx);
continue;
}

h->stp(Xbyak_aarch64::QReg(prev_reg_idx),
Xbyak_aarch64::QReg(reg_idx),
pre_ptr(h->sp, -get_vec_length() * 2));
prev_reg_idx = -1;
}
OPENVINO_ASSERT(ignore_registers_count == ignore_registers.size(),
"ignored registers size is not equal actual ignored registers count");

// 2.1. store the remaining register
if (prev_reg_idx != -1) {
h->str(Xbyak_aarch64::QReg(prev_reg_idx),
pre_ptr(h->sp, -get_vec_length()));
if (ignore_vec_regs.find(prev_reg_idx) == ignore_vec_regs.end()) {
h->str(Xbyak_aarch64::QReg(prev_reg_idx),
pre_ptr(h->sp, -get_vec_length()));
} else {
ignore_registers_count++;
}
}

OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
"ignored registers size is not equal actual ignored registers count");
}

void jit_emitter::restore_context(const std::unordered_set<size_t>& ignore_registers) const {
void jit_emitter::restore_context(const std::unordered_set<size_t>& ignore_vec_regs) const {
restore_context(store_gpr_regs, vec_regs, ignore_vec_regs);
}

void jit_emitter::restore_context(
const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs) const {
// 1. SIMD and Floating-Point registers
// 1.1. restore the remaining register
const auto v_last = (get_asimd_vectors_count() - ignore_registers.size()) % 2;
auto v_last = (vec_regs.size() - ignore_vec_regs.size()) % 2;
if (v_last != 0) {
const auto reg_idx = get_asimd_vectors_count() - 1;
h->ldr(Xbyak_aarch64::QReg(reg_idx),
post_ptr(h->sp, get_vec_length()));
for (size_t i = 0; i < vec_regs.size(); i++) {
const auto reg_idx = vec_regs.size() - 1 - i;
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
v_last++;
continue;
}

h->ldr(Xbyak_aarch64::QReg(reg_idx),
post_ptr(h->sp, get_vec_length()));
break;
}
}

// 2.2. restore pair registers
// 1.2. restore pair registers
size_t ignore_registers_count = 0;
int prev_reg_idx = -1;
for (size_t i = v_last; i < get_asimd_vectors_count(); i++) {
const auto reg_idx = get_asimd_vectors_count() - 1 - i;
if (ignore_registers.find(reg_idx) != ignore_registers.end()) {
for (size_t i = v_last; i < vec_regs.size(); i++) {
const auto reg_idx = vec_regs.size() - 1 - i;
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
ignore_registers_count++;
continue;
}

if (prev_reg_idx == -1) {
prev_reg_idx = static_cast<int>(reg_idx);
continue;
}

h->ldp(Xbyak_aarch64::QReg(reg_idx),
Xbyak_aarch64::QReg(prev_reg_idx),
post_ptr(h->sp, get_vec_length() * 2));
prev_reg_idx = -1;
}

OPENVINO_ASSERT(ignore_registers_count == ignore_registers.size(),
OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
"ignored registers size is not equal actual ignored registers count");

// 2. General-purpose Registers
// 2.1. restore the remaining register
const auto save_gpr_regs_size = store_gpr_regs.size();
const auto save_gpr_regs_size = gpr_regs.size();
const auto last = save_gpr_regs_size % 2;
if (last != 0) {
h->ldr(Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1]),
post_ptr(h->sp, get_gpr_length() * 2));
h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]),
post_ptr(h->sp, get_gpr_length()));
}

// 2.2. restore pair registers
for (size_t i = last; i < save_gpr_regs_size; i += 2) {
h->ldp(Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1 - (i + 1)]),
Xbyak_aarch64::XReg(store_gpr_regs[save_gpr_regs_size - 1 - i]),
h->ldp(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - (i + 1)]),
Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - i]),
post_ptr(h->sp, get_gpr_length() * 2));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,11 @@ class jit_emitter : public ov::snippets::Emitter {
}

private:
mutable std::vector<size_t> preserved_vec_idxs;
mutable std::vector<size_t> preserved_gpr_idxs;

// General-purpose Registers
static const std::vector<uint32_t> store_gpr_regs;
static const std::vector<size_t> store_gpr_regs;

size_t table_off(const std::string& key, const size_t key_off_val_shift = 0) const {
// assumption: all table entries sharing the same key also
Expand All @@ -163,6 +164,14 @@ class jit_emitter : public ov::snippets::Emitter {
// Size of one general-purpose register in bytes: the bit width reported for x0, divided by 8.
inline int32_t get_gpr_length() const {
return h->x0.getBit() / 8;
}

void store_context(const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs = {}) const;

void restore_context(const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs = {}) const;
};

} // namespace aarch64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "jit_emitter.hpp"
#include <vector>
#include "utils/general_utils.h"
#include "utils.hpp"

using namespace dnnl::impl::cpu;
using namespace dnnl::impl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
void generate() override;

private:
const Xbyak_aarch64::XReg X_TMP_0 = x10;
const Xbyak_aarch64::XReg X_TMP_1 = x11;

XReg reg_post_op_ptrs = X_TMP_0;
XReg reg_post_op_ptrs = x10;
XReg start_to_offsets = reg_post_op_ptrs;

XReg reg_oc_off = x12;
Expand All @@ -126,10 +123,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// X10 | ker temporary| R10 | src ptr
// X11 | ker temporary| R11 | src ptr
// X12 | ker temporary (abi_not_param1) | R12 | src ptr
// X13 | [not used] | R13 | src ptr
// X14 | [not used] | R14 | src ptr
// X15 | dst | R15 | temporary
// X16 | [not used: IP1]
// X13 | temporary | R13 | src ptr
// X14 | temporary | R14 | src ptr
// X15 | temporary | R15 | temporary
// X16 | dst
// X17 | [not used: IP0]
// X18 | [not used: Apple: The platforms reserve register x18. Don't use this register.]

Expand All @@ -138,32 +135,35 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// X20 | src ptr
// X21 | src ptr
// X22 | src ptr
// X23 | src ptr
// X24 | src ptr
// X23 | kernel used (oneDNN: X_TMP_0)
// X24 | kernel used (oneDNN: X_TMP_1)
// X25 | src ptr
// X26 | temporary
// X27 | temporary
// X28 | kernel used (X_DEFAULT_ADDR)
// X26 | src ptr
// X27 | src ptr
// X28 | kernel used (oneDNN: X_DEFAULT_ADDR)

// X29 | [not used: The Frame Pointer (FP)]
// X30 | [not used: The Link Register (LR)]
// X31 | [not used: The Stack Pointer (SP)]

const XReg reg_work_amount = x9;
const XReg reg_dst = x15;
const XReg reg_dst = x16;

// Returns the general-purpose register that holds the pointer of source input `idx`.
// Throws when `idx` exceeds MAX_ELTWISE_INPUTS or has no entry in the register table.
inline XReg get_src_reg(uint32_t idx) {
    // x23 and x24 are skipped: the register table above marks them as used by the kernel
    // (oneDNN X_TMP_0 / X_TMP_1), so the source pointers cannot be a contiguous x19.. range.
    static const std::vector<uint32_t> src_gprs = { 19, 20, 21, 22, 25, 26, 27 };

    // The table size is checked as well, so an index accepted by the
    // MAX_ELTWISE_INPUTS comparison can never read past the end of src_gprs.
    if (idx > MAX_ELTWISE_INPUTS || idx >= src_gprs.size()) {
        OPENVINO_THROW("source vector ptr register " + std::to_string(idx) + " is not supported");
    }

    return XReg(src_gprs[idx]);
}

// Returns the auxiliary (temporary) general-purpose register number `idx`.
// Throws when `idx` is outside the range of available temporaries.
inline XReg get_aux_gpr(const uint32_t idx) {
    // Only x13, x14 and x15 are temporaries in the register table above; x16 is reg_dst,
    // so idx 3 (13 + 3 == 16) must be rejected instead of aliasing the destination pointer.
    if (idx > 2) {
        OPENVINO_THROW("aux gpr register " + std::to_string(idx) + " is not supported");
    }

    return XReg(13 + idx);
}

// Vector registers mapping
Expand Down

0 comments on commit f0ebba0

Please sign in to comment.