Skip to content

Commit

Permalink
[CPU] [ARM64] jit emitters: exp
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 12, 2024
1 parent 62a8746 commit 5115e00
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,136 @@ std::set<std::vector<element::Type>> jit_add_emitter::get_supported_precisions(c
return {{element::f32, element::f32}};
}

/// EXPONENT ///
jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
prepare_table();
}

jit_exp_emitter::jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {
prepare_table();
}

size_t jit_exp_emitter::get_inputs_count() const { return 1; }

size_t jit_exp_emitter::get_aux_vecs_count() const { return 4; }

size_t jit_exp_emitter::get_aux_gprs_count() const { return 1; }

void jit_exp_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg vmm_src(in_vec_idxs[0]);
const TReg vmm_dst(out_vec_idxs[0]);
const TReg vmm_aux1(aux_vec_idxs[0]);
const TReg vmm_aux2(aux_vec_idxs[1]);
const TReg vmm_aux0(aux_vec_idxs[2]);

const TReg vmm_mask(aux_vec_idxs[3]);

h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_max_f"));
h->fmin(vmm_dst.s, vmm_src.s, vmm_aux0.s);
h->ld1r(vmm_aux0.s, table_val2("exp_ln_flt_min_f"));

// get mask of values lower than log(FLT_MIN) to zero them in the output
h->facgt(vmm_mask.s, vmm_aux0.s, vmm_src.s);

h->fmax(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->mov(vmm_aux1.b16, vmm_dst.b16);

// calculate exp(x)
// fx = x * log2ef + 0.5
h->ld1r(vmm_aux0.s, table_val2("exp_log2ef"));
h->ld1r(vmm_aux2.s, table_val2("half"));
h->fmla(vmm_aux2.s, vmm_dst.s, vmm_aux0.s);

// tmp = floorf(fx)
h->frintm(vmm_aux2.s, vmm_aux2.s);

// keep vmm_src = fx for further computations
h->mov(vmm_dst.b16, vmm_aux2.b16);

// x = x - fx * ln2
h->ld1r(vmm_aux0.s, table_val2("ln2f"));
h->fmls(vmm_aux1.s, vmm_aux2.s, vmm_aux0.s);

// We do not count 2^n here, because n can reach 128 and 2^128 is not
// representable by fp32, so to get around this problem, instead of computing
// 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127
// and 2 are numbers representable in fp32.

// compute 2^(n-1)
h->ld1r(vmm_aux0.s, table_val2("one"));
h->fsub(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
h->fcvtzs(vmm_aux2.s, vmm_dst.s);

h->ld1r(vmm_aux0.s, table_val2("exponent_bias"));
h->add(vmm_aux2.s, vmm_aux2.s, vmm_aux0.s);

h->sqshl(vmm_aux2.s, vmm_aux2.s, 23);

// set zeroes at those points which were < log(FLT_MIN)
h->and_(vmm_aux2.b16, vmm_mask.b16, vmm_aux2.b16);

// compute polynomial
h->ld1r(vmm_aux0.s, table_val2("exp_pol5"));
h->ld1r(vmm_dst.s, table_val2("exp_pol4"));
h->fmla(vmm_dst.s, vmm_aux1.s, vmm_aux0.s);

h->ld1r(vmm_aux0.s, table_val2("exp_pol3"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

h->ld1r(vmm_dst.s, table_val2("exp_pol2"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

h->ld1r(vmm_aux0.s, table_val2("exp_pol1"));
h->fmla(vmm_aux0.s, vmm_dst.s, vmm_aux1.s);

h->ld1r(vmm_dst.s, table_val2("one"));
h->fmla(vmm_dst.s, vmm_aux0.s, vmm_aux1.s);

// y = y * 2^n
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux2.s);
h->ld1r(vmm_aux0.s, table_val2("two"));
h->fmul(vmm_dst.s, vmm_dst.s, vmm_aux0.s);
}

void jit_exp_emitter::register_table_entries() {
push_arg_entry_of("exp_ln_flt_max_f", 0x42b17218, true);
push_arg_entry_of("exp_ln_flt_min_f", 0xc2aeac50, true);
push_arg_entry_of("exp_log2ef", 0x3fb8aa3b, true);
push_arg_entry_of("one", 0x3f800000, true);
push_arg_entry_of("two", 0x40000000, true);
push_arg_entry_of("half", 0x3f000000, true);
push_arg_entry_of("ln2f", 0x3f317218, true);
push_arg_entry_of("exponent_bias", 0x0000007f, true);
push_arg_entry_of("exp_pol1", 0x3f7ffffb, true);
push_arg_entry_of("exp_pol2", 0x3efffee3, true);
push_arg_entry_of("exp_pol3", 0x3e2aad40, true);
push_arg_entry_of("exp_pol4", 0x3d2b9d0d, true);
push_arg_entry_of("exp_pol5", 0x3c07cfce, true);
}

std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
}

/// MUL_ADD ///
jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,32 @@ class jit_add_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_exp_emitter : public jit_emitter {
public:
jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_exp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

size_t get_aux_vecs_count() const override;

size_t get_aux_gprs_count() const override;

void register_table_entries() override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_mul_add_emitter : public jit_emitter {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ bool JitEltwiseExecutor::isSupported(
const float gamma) {
const auto is_supported = one_of(algorithm,
Algorithm::EltwiseAdd,
Algorithm::EltwiseExp,
Algorithm::EltwiseMultiply,
Algorithm::EltwiseMulAdd,
Algorithm::EltwisePowerStatic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte

OV_SWITCH(intel_cpu, EltwiseEmitter, ctx, data.algo,
OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter),
OV_CASE(Algorithm::EltwiseExp, ov::intel_cpu::aarch64::jit_exp_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter),
Expand Down Expand Up @@ -654,6 +655,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo,
OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter),
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// 09 | dst
// 10 | aux
// 11 | aux
// 12-15 | [not used]
// 12 | aux
// 13 | aux
// 14 | aux
// 15 | [not used]
// 16 | src
// 17 | src
// 18 | src
Expand Down Expand Up @@ -213,7 +216,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
}

inline TReg get_aux_vmm(const uint32_t idx) {
if (idx > 2) {
if (idx > 3) {
OPENVINO_THROW("aux vector register " + std::to_string(idx) + " is not supported");
}
return TReg(10 + idx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
const std::vector<std::pair<ov::PartialShape, std::vector<ov::Shape>>>& input_shapes) const {
#if defined(OV_CPU_WITH_ACL)
#if defined(OPENVINO_ARCH_ARM64)
if ((element_type == ov::element::f32) && (activation_type == utils::ActivationTypes::Relu)) {
if ((element_type == ov::element::f32) &&
((activation_type == utils::ActivationTypes::Relu) ||
(activation_type == utils::ActivationTypes::Exp))) {
return "jit";
}

Expand Down

0 comments on commit 5115e00

Please sign in to comment.