Skip to content

Commit

Permalink
[CPU] [ARM] JIT HSwish (openvinotoolkit#24089)
Browse files Browse the repository at this point in the history
### Details:
 - *[CPU] [ARM] jit hswish*

### Tickets:
 - *CVS-136940*
  • Loading branch information
eshoguli authored May 9, 2024
1 parent f61dddf commit aa900f4
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,67 @@ std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(c
return {{element::f32}};
}

/// HARD_SWISH ///
// Node-based constructor: the execution precision handed to the base emitter is
// derived from the node via get_arithmetic_binary_exec_precision().
jit_hswish_emitter::jit_hswish_emitter(
    dnnl::impl::cpu::aarch64::jit_generator* host,
    dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
    const std::shared_ptr<ov::Node>& node)
    : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
    // NOTE(review): prepare_table() presumably gathers the constants declared in
    // register_table_entries() - confirm against the jit_emitter base class.
    prepare_table();
}

// Precision-based constructor: the caller states the execution precision directly
// instead of deriving it from a graph node.
jit_hswish_emitter::jit_hswish_emitter(
    dnnl::impl::cpu::aarch64::jit_generator* host,
    dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
    const ov::element::Type exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    // NOTE(review): prepare_table() presumably gathers the constants declared in
    // register_table_entries() - confirm against the jit_emitter base class.
    prepare_table();
}

// Hard-swish is a unary operation: exactly one input vector register is consumed.
size_t jit_hswish_emitter::get_inputs_count() const {
    return 1;
}

// Two auxiliary vector registers are requested: emit_isa() reads aux_vec_idxs[0]
// and aux_vec_idxs[1].
size_t jit_hswish_emitter::get_aux_vecs_count() const {
    return 2;
}

// One auxiliary general-purpose register is requested.
// NOTE(review): presumably needed to address the constant table - confirm against jit_emitter.
size_t jit_hswish_emitter::get_aux_gprs_count() const {
    return 1;
}

// Dispatches code generation by host ISA. Only ASIMD is handled; any other ISA is
// reported through OV_CPU_JIT_EMITTER_THROW.
void jit_hswish_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
    constexpr auto supported_isa = dnnl::impl::cpu::aarch64::asimd;

    if (host_isa_ != supported_isa) {
        OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
    } else {
        emit_isa<supported_isa>(in_vec_idxs, out_vec_idxs);
    }
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_hswish_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src = TReg(in_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);
TReg aux0 = TReg(aux_vec_idxs[0]);
TReg aux1 = TReg(aux_vec_idxs[1]);

// result = (x * min(max(x + 3, 0), 6)) / 6
h->ld1r(aux0.s, table_val2("three"));
h->fadd(aux0.s, src.s, aux0.s);
h->ld1r(aux1.s, table_val2("zero"));
h->fmaxnm(aux0.s, aux0.s, aux1.s);
h->ld1r(aux1.s, table_val2("six"));
h->fminnm(aux0.s, aux0.s, aux1.s);
h->fmul(aux0.s, aux0.s, src.s);
h->ld1r(aux1.s, table_val2("one_sixth"));
h->fmul(dst.s, aux0.s, aux1.s);
}

// Declares the f32 constants that emit_isa() looks up by name through table_val2().
// Values are stored as raw IEEE-754 single-precision bit patterns.
// NOTE(review): the trailing `true` argument is presumably a broadcast flag - confirm
// against push_arg_entry_of() in jit_emitter.
void jit_hswish_emitter::register_table_entries() {
    push_arg_entry_of("zero", 0x00000000, true);
    push_arg_entry_of("three", dnnl::impl::float2int(3.f), true);          // 0x40400000
    push_arg_entry_of("six", dnnl::impl::float2int(6.f), true);            // 0x40c00000
    push_arg_entry_of("one_sixth", dnnl::impl::float2int(1.f / 6.f), true);
}

// Returns the supported input-precision combinations: a single entry holding one f32 input.
// The node argument is not inspected.
std::set<std::vector<element::Type>> jit_hswish_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    const std::vector<element::Type> f32_input{element::f32};
    return {f32_input};
}

/// MUL_ADD ///
jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,32 @@ class jit_exp_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

// JIT emitter for the hard-swish activation:
//     result = (x * min(max(x + 3, 0), 6)) / 6
class jit_hswish_emitter : public jit_emitter {
public:
    // Creates the emitter from a graph node.
    jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                       dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                       const std::shared_ptr<ov::Node>& node);

    // Creates the emitter with an explicitly chosen execution precision (f32 by default).
    jit_hswish_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                       dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                       const ov::element::Type exec_prc = ov::element::f32);

    // Resource requirements reported by this emitter: input vector registers,
    // auxiliary vector registers and auxiliary general-purpose registers.
    size_t get_inputs_count() const override;
    size_t get_aux_vecs_count() const override;
    size_t get_aux_gprs_count() const override;

    // Declares the named constants used by the generated code.
    void register_table_entries() override;

    // Input-precision combinations this emitter accepts; the node argument is optional.
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
    // Entry point for code generation; selects the ISA-specific implementation.
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    // ISA-specific code generation.
    template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

class jit_mul_add_emitter : public jit_emitter {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseDivide,
Algorithm::EltwiseEqual,
Algorithm::EltwiseExp,
Algorithm::EltwiseHswish,
Algorithm::EltwiseMultiply,
Algorithm::EltwiseMulAdd,
Algorithm::EltwisePowerStatic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter),
OV_CASE(Algorithm::EltwiseEqual, ov::intel_cpu::aarch64::jit_equal_emitter),
OV_CASE(Algorithm::EltwiseExp, ov::intel_cpu::aarch64::jit_exp_emitter),
OV_CASE(Algorithm::EltwiseHswish, ov::intel_cpu::aarch64::jit_hswish_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter),
Expand Down Expand Up @@ -786,6 +787,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter),
OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter),
OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter),
OV_CASE(Algorithm::EltwiseHswish, jit_hswish_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
if ((element_type == ov::element::f32) &&
((activation_type == utils::ActivationTypes::Clamp) ||
(activation_type == utils::ActivationTypes::Exp) ||
(activation_type == utils::ActivationTypes::HSwish) ||
(activation_type == utils::ActivationTypes::Relu) ||
(activation_type == utils::ActivationTypes::Sigmoid) ||
(activation_type == utils::ActivationTypes::Swish) ||
Expand Down

0 comments on commit aa900f4

Please sign in to comment.