Skip to content

Commit

Permalink
refactoring + comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 25, 2024
1 parent 888f2cb commit 94ec445
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 39 deletions.
127 changes: 113 additions & 14 deletions src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
OPENVINO_THROW("Failed to allocate required number of gpr registers");
}

using namespace Xbyak_aarch64::util;
const bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
(in_out_type_ == emitter_in_out_map::vec_to_gpr);
const bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) ||
(in_out_type_ == emitter_in_out_map::gpr_to_vec);

// vector registers
for (auto idx : pool_aux_vec_idxs) {
aux_vec_idxs.push_back(static_cast<uint32_t>(idx));
Expand All @@ -114,6 +120,13 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
for (size_t idx = 0; idx < get_max_vecs_count(); idx++) {
if (aux_vec_idxs.size() >= get_aux_vecs_count()) break;

if (is_vec_input) {
if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue;
}
if (is_vec_output) {
if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue;
}

if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue;
if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue;

Expand All @@ -131,12 +144,21 @@ void jit_emitter::emitter_preamble(const std::vector<size_t>& in_idxs,
aux_gpr_idxs.push_back(idx);
}

const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X28;
const uint32_t end_gpr_idx = Xbyak_aarch64::Operand::X30;
for (size_t gpr_idx = 0; gpr_idx <= end_gpr_idx; ++gpr_idx) {
size_t _idx = end_gpr_idx - gpr_idx; // we allocate from the end

if (aux_gpr_idxs.size() >= get_aux_gprs_count()) break;
if (_idx == Xbyak_aarch64::Operand::X18) continue;
if ((_idx == Xbyak_aarch64::Operand::X18) ||
(_idx == Xbyak_aarch64::Operand::X23) ||
(_idx == Xbyak_aarch64::Operand::X28)) continue;

if (!is_vec_input) {
if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) continue;
}
if (!is_vec_output) {
if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) continue;
}

if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) continue;
if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) continue;
Expand Down Expand Up @@ -178,16 +200,52 @@ void jit_emitter::store_context(
const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs) const {
for (const auto i : gpr_regs) {
h->str(Xbyak_aarch64::XReg(i), pre_ptr(h->sp, -get_gpr_length() * 2));
// 1. General-purpose Registers
// 1.1. store pair registers
const auto store_gpr_regs_size = gpr_regs.size();
const auto last = store_gpr_regs_size % 2;
for (size_t i = 0; i < (store_gpr_regs_size - last); i += 2) {
h->stp(Xbyak_aarch64::XReg(gpr_regs[i]),
Xbyak_aarch64::XReg(gpr_regs[i + 1]),
pre_ptr(h->sp, -get_gpr_length() * 2));
}
// 1.2. store the remaining register
if (last != 0) {
h->str(Xbyak_aarch64::XReg(gpr_regs[store_gpr_regs_size - 1]),
pre_ptr(h->sp, -get_gpr_length()));
}

for (const auto i : vec_regs) {
if (ignore_vec_regs.find(i) != ignore_vec_regs.end()) {
// 2. SIMD and Floating-Point registers
// 2.1. store pair registers
int prev_reg_idx = -1;
size_t ignore_registers_count = 0;
for (size_t reg_idx = 0; reg_idx < vec_regs.size(); reg_idx++) {
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
ignore_registers_count++;
continue;
}
h->str(Xbyak_aarch64::QReg(i), pre_ptr(h->sp, -get_vec_length()));
if (prev_reg_idx == -1) {
prev_reg_idx = static_cast<int>(reg_idx);
continue;
}
h->stp(Xbyak_aarch64::QReg(prev_reg_idx),
Xbyak_aarch64::QReg(reg_idx),
pre_ptr(h->sp, -get_vec_length() * 2));
prev_reg_idx = -1;
}

// 2.2. store the remaining register
if (prev_reg_idx != -1) {
if (ignore_vec_regs.find(prev_reg_idx) == ignore_vec_regs.end()) {
h->str(Xbyak_aarch64::QReg(prev_reg_idx),
pre_ptr(h->sp, -get_vec_length()));
} else {
ignore_registers_count++;
}
}

OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
"ignored registers size is not equal actual ignored registers count");
}

void jit_emitter::restore_context(const std::unordered_set<size_t>& ignore_vec_regs) const {
Expand All @@ -198,17 +256,58 @@ void jit_emitter::restore_context(
const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs) const {
const int vec_regs_size = static_cast<int>(vec_regs.size());
for (int i = (vec_regs_size - 1); i >= 0; --i) {
if (ignore_vec_regs.find(i) != ignore_vec_regs.end()) {
// 1. SIMD and Floating-Point registers
// 1.1. restore the remaining register
auto v_last = (vec_regs.size() - ignore_vec_regs.size()) % 2;
if (v_last != 0) {
for (size_t i = 0; i < vec_regs.size(); i++) {
const auto reg_idx = vec_regs.size() - 1 - i;
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
v_last++;
continue;
}

h->ldr(Xbyak_aarch64::QReg(reg_idx),
post_ptr(h->sp, get_vec_length()));
break;
}
}
// 1.2. restore pair registers
size_t ignore_registers_count = 0;
int prev_reg_idx = -1;
for (size_t i = v_last; i < vec_regs.size(); i++) {
const auto reg_idx = vec_regs.size() - 1 - i;
if (ignore_vec_regs.find(reg_idx) != ignore_vec_regs.end()) {
ignore_registers_count++;
continue;
}
h->ldr(Xbyak_aarch64::QReg(vec_regs[i]), post_ptr(h->sp, get_vec_length()));
if (prev_reg_idx == -1) {
prev_reg_idx = static_cast<int>(reg_idx);
continue;
}
h->ldp(Xbyak_aarch64::QReg(reg_idx),
Xbyak_aarch64::QReg(prev_reg_idx),
post_ptr(h->sp, get_vec_length() * 2));
prev_reg_idx = -1;
}

OPENVINO_ASSERT(ignore_registers_count == ignore_vec_regs.size(),
"ignored registers size is not equal actual ignored registers count");

// 2. General-purpose Registers
// 2.1. restore the remaining register
const auto save_gpr_regs_size = gpr_regs.size();
const auto last = save_gpr_regs_size % 2;
if (last != 0) {
h->ldr(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1]),
post_ptr(h->sp, get_gpr_length()));
}

const int gpr_regs_size = static_cast<int>(gpr_regs.size());
for (int i = (gpr_regs_size - 1); i >= 0; --i) {
h->ldr(Xbyak_aarch64::XReg(gpr_regs[i]), post_ptr(h->sp, get_gpr_length() * 2));
// 2.2. restore pair registers
for (size_t i = last; i < save_gpr_regs_size; i += 2) {
h->ldp(Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - (i + 1)]),
Xbyak_aarch64::XReg(gpr_regs[save_gpr_regs_size - 1 - i]),
post_ptr(h->sp, get_gpr_length() * 2));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
void generate() override;

private:
const Xbyak_aarch64::XReg X_TMP_0 = x10;
const Xbyak_aarch64::XReg X_TMP_1 = x11;

XReg reg_post_op_ptrs = X_TMP_0;
XReg reg_post_op_ptrs = x10;
XReg start_to_offsets = reg_post_op_ptrs;

XReg reg_oc_off = x12;
Expand All @@ -126,10 +123,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// X10 | ker temporary| R10 | src ptr
// X11 | ker temporary| R11 | src ptr
// X12 | ker temporary (abi_not_param1) | R12 | src ptr
// X13 | [not used] | R13 | src ptr
// X14 | [not used] | R14 | src ptr
// X15 | dst | R15 | temporary
// X16 | [not used: IP1]
// X13 | temporary | R13 | src ptr
// X14 | temporary | R14 | src ptr
// X15 | temporary | R15 | temporary
// X16 | dst
// X17 | [not used: IP0]
// X18 | [not used: Apple: The platforms reserve register x18. Don't use this register.]

Expand All @@ -138,44 +135,35 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// X20 | src ptr
// X21 | src ptr
// X22 | src ptr
// X23 | temporary & kernel used (oneDNN: X_TMP_0)
// X24 | src ptr
// X23 | kernel used (oneDNN: X_TMP_0)
// X24 | kernel used (oneDNN: X_TMP_1)
// X25 | src ptr
// X26 | src ptr
// X27 | temporary
// X28 | temporary & kernel used (oneDNN: X_DEFAULT_ADDR)
// X27 | src ptr
// X28 | kernel used (oneDNN: X_DEFAULT_ADDR)

// X29 | [not used: The Frame Pointer (FP)]
// X30 | [not used: The Link Register (LR)]
// X31 | [not used: The Stack Pointer (SP)]

const XReg reg_work_amount = x9;
const XReg reg_dst = x15;
const XReg reg_dst = x16;

// Returns the general-purpose register that holds the pointer of source `idx`.
//
// Source pointers are kept in X19-X22 and X25-X27. X23 and X24 are skipped on purpose:
// the register map in this struct marks them as kernel used (oneDNN X_TMP_0 / X_TMP_1).
//
// idx     - zero-based source index.
// returns - XReg assigned to that source.
// Raises through OPENVINO_THROW when no register is assigned to `idx`.
inline XReg get_src_reg(uint32_t idx) {
    // One entry per supported source; the gap between 22 and 25 is intentional.
    static const std::vector<uint32_t> src_gprs = { 19, 20, 21, 22, 25, 26, 27 };

    // Keep the original MAX_ELTWISE_INPUTS limit and additionally guard the table bound,
    // so an accepted index can never read past the end of src_gprs.
    if ((idx > MAX_ELTWISE_INPUTS) || (idx >= src_gprs.size())) {
        OPENVINO_THROW("source vector ptr register " + std::to_string(idx) + " is not supported");
    }

    return XReg(src_gprs[idx]);
}

// Returns the idx-th auxiliary (temporary) general-purpose register.
//
// The temporaries are X13, X14 and X15 (see the register map in this struct).
//
// idx     - zero-based index of the requested temporary, valid range [0, 2].
// returns - XReg(13 + idx).
// Raises through OPENVINO_THROW when `idx` is outside the valid range.
inline XReg get_aux_gpr(const uint32_t idx) {
    constexpr uint32_t first_aux_gpr = 13;
    // Only three temporaries exist: 13 + 3 would be X16, which is reg_dst and must not
    // be handed out as a scratch register.
    constexpr uint32_t aux_gprs_count = 3;

    if (idx >= aux_gprs_count) {
        OPENVINO_THROW("aux gpr register " + std::to_string(idx) + " is not supported");
    }

    return XReg(first_aux_gpr + idx);
}

// Vector registers mapping
Expand Down

0 comments on commit 94ec445

Please sign in to comment.