Skip to content

Commit

Permalink
comments: refactoring
Browse files (browse the repository at this point in the history)
  • Loading branch information
eshoguli committed Jan 28, 2024
1 parent 2b03e32 commit c550713
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ ov::element::Type get_arithmetic_binary_exec_precision(const std::shared_ptr<ov:
/// ADD ///
// Node-based constructor of the add emitter. It forwards host, host_isa and node to the
// jit_emitter base and passes get_arithmetic_binary_exec_precision(node) as the
// execution precision.
// NOTE(review): this span comes from a diff whose +/- markers were lost, so two variants
// of the parameter-list tail are listed back to back: one with a trailing `alpha` that
// is forwarded to the base, and one without it. Presumably the alpha-free lines are the
// post-change version -- confirm against the repository.
jit_add_emitter::jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node), alpha) {
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
}

// Precision-based constructor of the add emitter. It forwards host, host_isa and the
// explicit execution precision exec_prc to the jit_emitter base; the body is empty.
// NOTE(review): two variants of the signature tail are shown (with and without a
// trailing `alpha` forwarded to the base) because the diff markers were lost in
// extraction. Presumably the variant without alpha is the post-change one -- confirm.
jit_add_emitter::jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc,
const float alpha) : jit_emitter(host, host_isa, exec_prc, alpha) {
const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {
}

// Returns the input count reported by the add emitter: always 2.
size_t jit_add_emitter::get_inputs_count() const { return 2; }
Expand Down Expand Up @@ -78,16 +76,14 @@ std::set<std::vector<element::Type>> jit_add_emitter::get_supported_precisions(c
/// MUL_ADD ///
// Node-based constructor of the mul_add emitter. It forwards host, host_isa and node to
// the jit_emitter base and passes get_arithmetic_binary_exec_precision(node) as the
// execution precision.
// NOTE(review): the diff's +/- markers were lost, so the parameter-list tail appears
// twice: once with a trailing `alpha` forwarded to the base, once without it. Presumably
// the alpha-free lines are the post-change version -- confirm against the repository.
jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node), alpha) {
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
}

// Precision-based constructor of the mul_add emitter. It forwards host, host_isa and
// the explicit execution precision exec_prc to the jit_emitter base; the body is empty.
// NOTE(review): two variants of the signature tail are shown (with and without a
// trailing `alpha` forwarded to the base) because the diff markers were lost in
// extraction. Presumably the variant without alpha is the post-change one -- confirm.
jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc,
const float alpha)
: jit_emitter(host, host_isa, exec_prc, alpha) {
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
}

// Returns the input count reported by the mul_add emitter: always 3.
size_t jit_mul_add_emitter::get_inputs_count() const { return 3; }
Expand Down Expand Up @@ -125,15 +121,13 @@ std::set<std::vector<element::Type>> jit_mul_add_emitter::get_supported_precisio
/// MULTIPLY ///
// Node-based constructor of the multiply emitter. It forwards host, host_isa and node
// to the jit_emitter base and passes get_arithmetic_binary_exec_precision(node) as the
// execution precision; the body is empty.
// NOTE(review): the diff's +/- markers were lost, so the parameter-list tail appears
// twice: once with a trailing `alpha` forwarded to the base, once without it. Presumably
// the alpha-free lines are the post-change version -- confirm against the repository.
jit_multiply_emitter::jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node), alpha) {}
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {}

// Precision-based constructor of the multiply emitter. It forwards host, host_isa and
// the explicit execution precision exec_prc to the jit_emitter base; the body is empty.
// NOTE(review): two variants of the signature tail are shown (with and without a
// trailing `alpha` forwarded to the base) because the diff markers were lost in
// extraction. Presumably the variant without alpha is the post-change one -- confirm.
jit_multiply_emitter::jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc,
const float alpha)
: jit_emitter(host, host_isa, exec_prc, alpha) {}
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {}

// Returns the input count reported by the multiply emitter: always 2.
size_t jit_multiply_emitter::get_inputs_count() const { return 2; }

Expand Down Expand Up @@ -196,7 +190,7 @@ jit_power_static_emitter::jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit

// Returns the input count reported by the power_static emitter: always 1.
size_t jit_power_static_emitter::get_inputs_count() const { return 1; }

// Returns the count of auxiliary vector registers the power_static emitter asks for.
// NOTE(review): the same definition is listed twice (return 2 / return 1) because the
// diff's +/- markers were lost. Presumably the second line (1) is the post-change value;
// the visible part of emit_isa only reads aux_vec_idxs[0] -- confirm against the
// repository that no other aux vector index is used.
size_t jit_power_static_emitter::get_aux_vecs_count() const { return 2; }
size_t jit_power_static_emitter::get_aux_vecs_count() const { return 1; }

// Returns the count of auxiliary general-purpose registers the power_static emitter
// asks for: always 2.
size_t jit_power_static_emitter::get_aux_gprs_count() const { return 2; }

Expand Down Expand Up @@ -225,14 +219,19 @@ void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs,
}

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg aux = TReg(aux_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);

if (power == 0.f) {
h->fmov(dst.s, 1.);
return;
}

bool get_from_dst = false;
const auto src = [&in_vec_idxs, &out_vec_idxs, &get_from_dst]() -> TReg {
return get_from_dst ? TReg(out_vec_idxs[0]) : TReg(in_vec_idxs[0]);
};

TReg aux = TReg(aux_vec_idxs[0]);
if (scale != 1.f) {
auto adr = table_val2("scale");
h->ld1r(aux.s, adr);
Expand All @@ -247,11 +246,6 @@ void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs,
get_from_dst = true;
}

if (power == 0.f) {
h->fmov(dst.s, 1.);
return;
}

if (power == 1.f) {
if (!get_from_dst && (in_vec_idxs[0] != dst.getIdx())) {
h->uni_orr(dst, src(), src());
Expand Down Expand Up @@ -283,36 +277,33 @@ void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs,
Xbyak_aarch64::SReg s1(1);

for (auto i = 0; i < 4; i++) {
store_context();
//store_context();
h->mov(s0, src().s[i]);
h->ldr(s1, table_val("power"));

const int32_t qreg_len = 16;
h->str(Xbyak_aarch64::QReg(dst.getIdx()), pre_ptr(h->sp, -qreg_len * 2));
h->str(Xbyak_aarch64::QReg(dst.getIdx()), pre_ptr(h->sp, -16));
h->blr(func_reg);
h->ldr(Xbyak_aarch64::QReg(dst.getIdx()), post_ptr(h->sp, qreg_len * 2));
h->ldr(Xbyak_aarch64::QReg(dst.getIdx()), post_ptr(h->sp, 16));

Xbyak_aarch64::WReg w0(0);
h->fmov(w0, s0);
h->mov(dst.s[i], w0);
restore_context();
//restore_context();
}
}
}

/// RELU ///
// Node-based constructor of the relu emitter. It forwards host, host_isa and node to
// the jit_emitter base and passes get_arithmetic_binary_exec_precision(node) as the
// execution precision.
// NOTE(review): the diff's +/- markers were lost, so the parameter-list tail appears
// twice: once with a trailing `alpha` forwarded to the base, once without it. Presumably
// the alpha-free lines are the post-change version -- confirm against the repository.
jit_relu_emitter::jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node), alpha) {
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
}

// Precision-based constructor of the relu emitter. It forwards host, host_isa and the
// explicit execution precision exec_prc to the jit_emitter base; the body is empty.
// NOTE(review): two variants of the signature tail are shown (with and without a
// trailing `alpha` forwarded to the base) because the diff markers were lost in
// extraction. Presumably the variant without alpha is the post-change one -- confirm.
jit_relu_emitter::jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc,
const float alpha)
: jit_emitter(host, host_isa, exec_prc, alpha) {
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
}

// Returns the input count reported by the relu emitter: always 1.
size_t jit_relu_emitter::get_inputs_count() const { return 1; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ class jit_add_emitter : public jit_emitter {
public:
jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32,
const float alpha = 0.f);
const ov::element::Type exec_prc = ov::element::f32);

jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha = 0.f);
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

Expand All @@ -38,13 +36,11 @@ class jit_mul_add_emitter : public jit_emitter {
public:
jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type exec_prc = ov::element::f32,
const float alpha = 0.f);
ov::element::Type exec_prc = ov::element::f32);

jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha = 0.f);
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);
Expand All @@ -61,13 +57,11 @@ class jit_multiply_emitter : public jit_emitter {
public:
jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type exec_prc = ov::element::f32,
const float alpha = 0.f);
ov::element::Type exec_prc = ov::element::f32);

jit_multiply_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha = 0.f);
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);
Expand Down Expand Up @@ -118,13 +112,11 @@ class jit_relu_emitter : public jit_emitter {
public:
jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32,
const float alpha = 0.f);
const ov::element::Type exec_prc = ov::element::f32);

jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
const float alpha = 0.f);
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ class jit_emitter : public ov::snippets::Emitter {
// Base-emitter constructor taking an explicit execution precision (default f32) and an
// in/out mapping kind (default vec_to_vec). It stores host in h, host_isa in host_isa_,
// exec_prc in exec_prc_, in_out_type in in_out_type_, sets p_table to 0 and initializes
// l_table from a newly allocated Xbyak_aarch64::Label.
// NOTE(review): diff markers were lost, so two variants are interleaved below: one with
// a live `alpha` parameter stored via alpha(alpha), and one where the parameter is
// commented out and the member is fixed with alpha(0.f). Presumably the latter is the
// post-change version -- confirm. If so, the commented-out parameter line is leftover
// dead text that could be removed outright.
jit_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type exec_prc = ov::element::f32,
const float alpha = 0.f,
//const float alpha = 0.f,
emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) :
Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc),
alpha(alpha), in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) {
alpha(0.f), in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) {
}

// Base-emitter constructor overload that additionally accepts a graph node `n`. In the
// lines visible here `n` is not read: the initializer list is the same as in the
// node-free overload (h, host_isa_, exec_prc_, alpha, in_out_type_, p_table(0) and
// l_table from a newly allocated Xbyak_aarch64::Label) and the body is empty.
// NOTE(review): diff markers were lost, so two variants are interleaved below: one with
// a live `alpha` parameter stored via alpha(alpha), and one where the parameter is
// commented out and the member is fixed with alpha(0.f). Presumably the latter is the
// post-change version -- confirm. If so, the commented-out parameter line is leftover
// dead text that could be removed outright.
jit_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ngraph::Node>& n,
ov::element::Type exec_prc = ov::element::f32,
const float alpha = 0.f,
//const float alpha = 0.f,
emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) :
Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc),
alpha(alpha), in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) {
alpha(0.f), in_out_type_(in_out_type), p_table(0), l_table (new Xbyak_aarch64::Label()) {
}

void emit_code(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
#pragma once

#include "cpu_types.h"
#include "../executor.hpp"
#include "../eltwise.hpp"
#include <node.h>
#include "nodes/executors/eltwise.hpp"
#include "node.h"

namespace ov {
namespace intel_cpu {
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/executors/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ namespace intel_cpu {
// Enumerates the kinds of executor implementation.
// NOTE(review): this listing comes from a diff whose +/- markers were lost, so renamed
// enumerators appear as adjacent pairs: `x64` next to `jit_x64`, and `Jit` next to
// `jit_aarch64`. Presumably the second name of each pair is the post-change spelling
// (a transpose executor list elsewhere on this page switches from ExecutorType::x64 to
// ExecutorType::jit_x64) -- confirm against the repository.
enum class ExecutorType {
Undefined,
Common,
x64,
jit_x64,
Dnnl,
Acl,
Mlas,
Jit
jit_aarch64
};

class ExecutorContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const std::vector<TransposeExecutorDesc>& getTransposeExecutorsList() {
OV_CPU_INSTANCE_MLAS_ARM64(ExecutorType::Mlas, std::make_shared<MlasTransposeExecutorBuilder>())
OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared<RefOptimizedTransposeExecutorBuilder>())
OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared<ACLTransposeExecutorBuilder>())
OV_CPU_INSTANCE_X64(ExecutorType::x64, std::make_shared<JitTransposeExecutorBuilder>())
OV_CPU_INSTANCE_X64(ExecutorType::jit_x64, std::make_shared<JitTransposeExecutorBuilder>())
OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared<RefTransposeExecutorBuilder>())
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,28 +253,28 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
const ov::element::Type& src_prc,
const ov::element::Type& dst_prc,
const bool broadcast,
const int32_t offset) {
const int32_t ptr_offset) {
switch (src_prc) {
case ov::element::f16: {
if (broadcast) {
if (offset == 0) {
if (ptr_offset == 0) {
ld1r(data.h, ptr(ptr_reg));
} else {
add_imm(X_DEFAULT_ADDR, ptr_reg, offset, X_TMP_0);
add_imm(X_DEFAULT_ADDR, ptr_reg, ptr_offset, X_TMP_0);
ld1r(data.h, ptr(X_DEFAULT_ADDR));
}
} else {
ldr(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr_reg, offset));
ldr(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr_reg, ptr_offset));
}
break;
}
case ov::element::f32:
case ov::element::i32:
case ov::element::u32: {
if (broadcast) {
jit_generator::uni_ld1rw(data.s, ptr_reg, offset);
jit_generator::uni_ld1rw(data.s, ptr_reg, ptr_offset);
} else {
jit_generator::uni_ldr(data, ptr_reg, offset);
jit_generator::uni_ldr(data, ptr_reg, ptr_offset);
}
break;
}
Expand Down Expand Up @@ -314,16 +314,16 @@ void jit_uni_eltwise_generic<isa>::load_scalar(const SReg& data,
const XReg& ptr,
const ov::element::Type& src_prc,
const ov::element::Type& dst_prc,
const int32_t offset) {
const int32_t ptr_offset) {
switch (src_prc) {
case ov::element::f16: {
ldr(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, offset));
ldr(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::f32:
case ov::element::i32:
case ov::element::u32: {
ldr(data, Xbyak_aarch64::ptr(ptr, offset));
ldr(data, Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
default: {
Expand Down Expand Up @@ -362,7 +362,7 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,
const TReg& data,
const ov::element::Type& src_prc,
const ov::element::Type& dst_prc,
const int32_t offset) {
const int32_t ptr_offset) {
if (src_prc != dst_prc) {
switch (src_prc) {
case ov::element::f32: {
Expand Down Expand Up @@ -393,13 +393,13 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,

switch (dst_prc) {
case ov::element::f16: {
str(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, offset));
str(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::f32:
case ov::element::i32:
case ov::element::u32: {
str(Xbyak_aarch64::QReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, offset));
str(Xbyak_aarch64::QReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
default: {
Expand All @@ -413,7 +413,7 @@ void jit_uni_eltwise_generic<isa>::store_scalar(const XReg& ptr,
const SReg& data,
const ov::element::Type& src_prc,
const ov::element::Type& dst_prc,
const int32_t offset) {
const int32_t ptr_offset) {
if (src_prc != dst_prc) {
switch (src_prc) {
case ov::element::f32: {
Expand Down Expand Up @@ -444,13 +444,13 @@ void jit_uni_eltwise_generic<isa>::store_scalar(const XReg& ptr,

switch (dst_prc) {
case ov::element::f16: {
str(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, offset));
str(Xbyak_aarch64::HReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::i32:
case ov::element::u32:
case ov::element::f32: {
str(data, Xbyak_aarch64::ptr(ptr, offset));
str(data, Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
default: {
Expand All @@ -470,7 +470,7 @@ struct EltwiseEmitterContext {
// Functor that builds an emitter of type T with std::make_shared and stores it in
// ctx.emitter, using ctx.host, ctx.host_isa and ctx.exec_prc as constructor arguments.
// NOTE(review): diff markers were lost, so the assignment appears twice: once also
// passing ctx.opData.alpha and once without it. Presumably the alpha-free call is the
// post-change line, matching emitter constructors that no longer take alpha -- confirm.
template<typename T>
struct EltwiseEmitter {
void operator()(EltwiseEmitterContext& ctx) {
ctx.emitter = std::make_shared<T>(ctx.host, ctx.host_isa, ctx.exec_prc, ctx.opData.alpha);
ctx.emitter = std::make_shared<T>(ctx.host, ctx.host_isa, ctx.exec_prc);
}
};

Expand Down Expand Up @@ -525,7 +525,7 @@ void jit_uni_eltwise_generic<isa>::compute_eltwise_op() {
out_idxs.push_back(vmm_dst.getIdx());

std::vector<size_t> gpr_idxs;
for (size_t i = 0; i < eltwise_emitter->get_aux_vecs_count(); i++) {
for (size_t i = 0; i < eltwise_emitter->get_aux_gprs_count(); i++) {
gpr_idxs.push_back(get_aux_gpr(i).getIdx());
}

Expand Down
Loading

0 comments on commit c550713

Please sign in to comment.