[CPU] Relu implementation
eshoguli committed Sep 5, 2023
1 parent 6f868d6 commit 30630b4
Showing 6 changed files with 105 additions and 12 deletions.
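This commit adds a JIT-generated ReLU kernel for AArch64 to the CPU plugin's Eltwise infrastructure: a new `jit_relu_emitter`, its registration in the executor and emitter factory, a guard that keeps parameterized ReLU variants on the fallback path, and test updates. As a point of reference, the semantics being lowered are simply max(x, 0); a minimal scalar sketch (an illustration, not code from this commit):

```cpp
#include <algorithm>
#include <cstddef>

// Reference semantics for the new kernel: ReLU(x) = max(x, 0).
// The emitter below produces a vectorized AArch64 equivalent.
void relu_reference(const float* src, float* dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        dst[i] = std::max(src[i], 0.0f);
}
```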
@@ -61,7 +61,7 @@ void jit_add_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const st
 template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
 void jit_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
     if (exec_prc_ != Precision::FP32) {
-        IE_THROW() << "unsupported precision";
+        IE_THROW() << "unsupported precision: " << exec_prc_;
     }
 
     using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
@@ -104,7 +104,7 @@ void jit_mul_add_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, cons
 template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
 void jit_mul_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
     if (exec_prc_ != Precision::FP32) {
-        IE_THROW() << "unsupported precision";
+        IE_THROW() << "unsupported precision: " << exec_prc_;
     }
 
     using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
@@ -148,7 +148,7 @@ void jit_multiply_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, con
 template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
 void jit_multiply_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
     if (exec_prc_ != Precision::FP32) {
-        IE_THROW() << "unsupported precision";
+        IE_THROW() << "unsupported precision: " << exec_prc_;
     }
 
     using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
@@ -199,7 +199,7 @@ void jit_power_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const
 template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
 void jit_power_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
     if (exec_prc_ != Precision::FP32) {
-        IE_THROW() << "unsupported precision";
+        IE_THROW() << "unsupported precision: " << exec_prc_;
     }
 
     using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
@@ -232,6 +232,57 @@ void jit_power_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const s
     }
 }
 
+/// RELU ///
+jit_relu_emitter::jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
+                                   dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
+                                   const std::shared_ptr<ov::Node>& node,
+                                   const float alpha)
+    : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node), alpha) {
+}
+
+jit_relu_emitter::jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
+                                   dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
+                                   const Precision exec_prc,
+                                   const float alpha)
+    : jit_emitter(host, host_isa, exec_prc, alpha) {
+}
+
+size_t jit_relu_emitter::get_inputs_count() const { return 1; }
+
+size_t jit_relu_emitter::get_aux_vecs_count() const { return 1; }
+
+std::set<std::vector<element::Type>> jit_relu_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
+    return {{element::f32}};
+}
+
+void jit_relu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+    if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
+        emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
+    } else {
+        IE_THROW() << "Can't create jit eltwise kernel";
+    }
+}
+
+template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
+void jit_relu_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+    if (exec_prc_ != Precision::FP32) {
+        IE_THROW() << "unsupported precision: " << exec_prc_;
+    }
+
+    if (alpha != 0.f) {
+        IE_THROW() << "not zero alpha is not supported";
+    }
+
+    using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
+
+    TReg tmp = TReg(aux_vec_idxs[0]);
+    h->movi(tmp.s, 0);
+
+    TReg src = TReg(in_vec_idxs[0]);
+    TReg dst = TReg(out_vec_idxs[0]);
+    h->fmaxnm(dst.s, src.s, tmp.s);
+}
+
 } // namespace aarch64
 } // namespace intel_cpu
 } // namespace ov
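In the emitted sequence above, `movi` zeroes an auxiliary vector register and `fmaxnm` takes the element-wise maximum of the source and zero. A rough NEON-intrinsics analogue for one 128-bit lane (an illustration of the generated instructions, not part of the commit):

```cpp
#include <arm_neon.h>

// Approximate intrinsics analogue of the emitted code:
//   movi   tmp.s, 0              ->  vmovq_n_f32(0.f)
//   fmaxnm dst.s, src.s, tmp.s   ->  vmaxnmq_f32(src, zero)
// Note: fmaxnm returns the numeric operand when the other is a quiet
// NaN, so NaN inputs come out as 0 rather than NaN.
float32x4_t relu_lane(float32x4_t src) {
    const float32x4_t zero = vmovq_n_f32(0.0f);
    return vmaxnmq_f32(src, zero);
}
```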
@@ -107,6 +107,31 @@ class jit_power_emitter : public jit_emitter {
     void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
 };
 
+class jit_relu_emitter : public jit_emitter {
+public:
+    jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
+                     dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
+                     const InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32,
+                     const float alpha = 0.f);
+
+    jit_relu_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
+                     dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
+                     const std::shared_ptr<ov::Node>& node,
+                     const float alpha = 0.f);
+
+    size_t get_inputs_count() const override;
+
+    size_t get_aux_vecs_count() const override;
+
+    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);
+
+private:
+    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+
+    template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
+    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+};
+
 } // namespace aarch64
 } // namespace intel_cpu
 } // namespace ov
7 changes: 7 additions & 0 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
@@ -111,6 +111,13 @@ bool is_supported(const Node* node) {
         }
     }
 
+    if (node->getAlgorithm() == Algorithm::EltwiseRelu) {
+        const auto eltwise = dynamic_cast<const Eltwise*>(node);
+        if ((eltwise == nullptr) || (eltwise->getAlpha() != 0.f) || (eltwise->getBeta() != 0.f) || (eltwise->getGamma() != 0.f)) {
+            return false;
+        }
+    }
+
     if ((node->getAlgorithm() != Algorithm::EltwisePowerDynamic) &&
         (node->getAlgorithm() != Algorithm::EltwisePowerStatic)) {
         return true;
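The alpha/beta/gamma guard above keeps parameterized variants on the non-JIT fallback path: with a nonzero alpha the operation is Leaky ReLU, which `jit_relu_emitter` explicitly rejects. For reference (an illustration, not code from this commit):

```cpp
// Leaky ReLU: with alpha != 0 the operation is no longer max(x, 0),
// so such Eltwise nodes must not be routed to jit_relu_emitter.
inline float leaky_relu(float x, float alpha) {
    return x > 0.0f ? x : alpha * x;
}
```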
@@ -18,7 +18,8 @@ bool JitEltwiseExecutor::isSupported(const Algorithm& algorithm) {
                                    Algorithm::EltwiseMultiply,
                                    Algorithm::EltwiseMulAdd,
                                    Algorithm::EltwisePowerDynamic,
-                                   Algorithm::EltwisePowerStatic);
+                                   Algorithm::EltwisePowerStatic,
+                                   Algorithm::EltwiseRelu);
     if (!is_supported) {
         return false;
     }
@@ -380,7 +380,8 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
               OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
               OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
               OV_CASE(Algorithm::EltwisePowerDynamic, ov::intel_cpu::aarch64::jit_power_emitter),
-              OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_emitter));
+              OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_emitter),
+              OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter));
 
     if (!ctx.emitter)
         IE_THROW() << "Unsupported operation type '" << algToString(data.algo) << "' for Eltwise emitter";
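The new `OV_CASE` entry registers `jit_relu_emitter` for `Algorithm::EltwiseRelu` in the emitter factory. A simplified, self-contained sketch of the dispatch pattern this table expresses (types and names here are invented for the example; the real code goes through the OV_SWITCH/OV_CASE machinery):

```cpp
#include <memory>

// Hypothetical illustration of algorithm-to-emitter dispatch.
enum class Algo { Add, Multiply, Relu };

struct Emitter { virtual ~Emitter() = default; };
struct ReluEmitter : Emitter {};

std::shared_ptr<Emitter> make_emitter(Algo algo) {
    switch (algo) {
    case Algo::Relu: return std::make_shared<ReluEmitter>();
    default:         return nullptr;  // unsupported -> caller throws
    }
}
```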
@@ -259,16 +259,18 @@ CPUTestsBase::CPUInfo CPUTestsBase::getCPUInfo() const {
 }
 
 #if defined(OPENVINO_ARCH_ARM64)
+namespace {
+bool is_static(const std::vector<std::pair<ov::PartialShape, std::vector<ov::Shape>>>& input_shapes) {
+    return std::all_of(input_shapes.begin(),
+                       input_shapes.end(),
+                       [](const std::pair<ov::PartialShape, std::vector<ov::Shape>>& shape) { return shape.first.is_static(); });
+}
+} // namespace
+
 std::string CPUTestsBase::getPrimitiveType(const ngraph::helpers::EltwiseTypes& eltwise_type,
                                            const ov::element::Type_t& element_type,
                                            const std::vector<std::pair<ov::PartialShape, std::vector<ov::Shape>>>& input_shapes) const {
     if (element_type == ov::element::f32) {
-        const auto is_static = [](const std::vector<std::pair<ov::PartialShape, std::vector<ov::Shape>>>& input_shapes) {
-            return std::all_of(input_shapes.begin(),
-                               input_shapes.end(),
-                               [](const std::pair<ov::PartialShape, std::vector<ov::Shape>>& shape) { return shape.first.is_static(); });
-        };
-
         if (is_static(input_shapes) &&
             ((eltwise_type == ngraph::helpers::EltwiseTypes::ADD) ||
              (eltwise_type == ngraph::helpers::EltwiseTypes::MULTIPLY) ||
@@ -283,6 +285,12 @@ std::string CPUTestsBase::getPrimitiveType(const ngraph::helpers::EltwiseTypes&
 std::string CPUTestsBase::getPrimitiveType(const ngraph::helpers::ActivationTypes& activation_type,
                                            const ov::element::Type_t& element_type,
                                            const std::vector<std::pair<ov::PartialShape, std::vector<ov::Shape>>>& input_shapes) const {
+    if (element_type == ov::element::f32) {
+        if (is_static(input_shapes) && (activation_type == ngraph::helpers::ActivationTypes::Relu)) {
+            return "jit";
+        }
+    }
+
     if (activation_type == ngraph::helpers::ActivationTypes::Mish) {
         // operation is decomposed and executed by different kernels
         return "";
