Skip to content

Commit

Permalink
[CPU] [ARM] Element-wise fp16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Oct 3, 2023
1 parent 699c5a9 commit 57f18e9
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 30 deletions.
96 changes: 79 additions & 17 deletions src/plugins/intel_cpu/src/emitters/aarch64/jit_eltwise_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void jit_add_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const st

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != Precision::FP32) {
if ((exec_prc_ != Precision::FP16) && (exec_prc_ != Precision::FP32)) {
IE_THROW() << "unsupported precision: " << exec_prc_;
}

Expand All @@ -69,11 +69,26 @@ void jit_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
TReg src1 = TReg(in_vec_idxs[1]);
TReg dst = TReg(out_vec_idxs[0]);

h->uni_fadd(dst.s, src0.s, src1.s);
switch (exec_prc_) {
case Precision::FP16: {
h->uni_fadd(dst.h, src0.h, src1.h);
break;
}
case Precision::FP32: {
h->uni_fadd(dst.s, src0.s, src1.s);
break;
}
default: {
assert(!"unsupported precision");
}
}
}

/// Lists the input precision combinations accepted by the Add emitter.
///
/// Each inner vector carries one element type per input; both inputs must
/// share the same type (f16 or f32). The `node` argument is not consulted.
std::set<std::vector<element::Type>> jit_add_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    return {
        {element::f16, element::f16},
        {element::f32, element::f32}
    };
}

/// MUL_ADD ///
Expand Down Expand Up @@ -103,7 +118,7 @@ void jit_mul_add_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, cons

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_mul_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != Precision::FP32) {
if ((exec_prc_ != Precision::FP16) && (exec_prc_ != Precision::FP32)) {
IE_THROW() << "unsupported precision: " << exec_prc_;
}

Expand All @@ -114,12 +129,28 @@ void jit_mul_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
TReg dst = TReg(out_vec_idxs[0]);

// uni_fmad implementation
h->fmul(dst.s, src0.s, src1.s);
h->fadd(dst.s, dst.s, src2.s);
switch (exec_prc_) {
case Precision::FP16: {
h->fmul(dst.h, src0.h, src1.h);
h->fadd(dst.h, dst.h, src2.h);
break;
}
case Precision::FP32: {
h->fmul(dst.s, src0.s, src1.s);
h->fadd(dst.s, dst.s, src2.s);
break;
}
default: {
assert(!"unsupported precision");
}
}
}

/// Lists the input precision combinations accepted by the MulAdd emitter.
///
/// Each inner vector carries one element type per input; all three inputs
/// must share the same type (f16 or f32). The `node` argument is not consulted.
std::set<std::vector<element::Type>> jit_mul_add_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    return {
        {element::f16, element::f16, element::f16},
        {element::f32, element::f32, element::f32}
    };
}

/// MULTIPLY ///
Expand Down Expand Up @@ -147,7 +178,7 @@ void jit_multiply_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, con

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_multiply_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != Precision::FP32) {
if ((exec_prc_ != Precision::FP16) && (exec_prc_ != Precision::FP32)) {
IE_THROW() << "unsupported precision: " << exec_prc_;
}

Expand All @@ -156,11 +187,26 @@ void jit_multiply_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, cons
TReg src1 = TReg(in_vec_idxs[1]);
TReg dst = TReg(out_vec_idxs[0]);

h->uni_fmul(dst.s, src0.s, src1.s);
switch (exec_prc_) {
case Precision::FP16: {
h->uni_fmul(dst.h, src0.h, src1.h);
break;
}
case Precision::FP32: {
h->uni_fmul(dst.s, src0.s, src1.s);
break;
}
default: {
assert(!"unsupported precision");
}
}
}

/// Lists the input precision combinations accepted by the Multiply emitter.
///
/// Each inner vector carries one element type per input; both inputs must
/// share the same type (f16 or f32). The `node` argument is not consulted.
std::set<std::vector<element::Type>> jit_multiply_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    return {
        {element::f16, element::f16},
        {element::f32, element::f32}
    };
}

/// POWER ///
Expand Down Expand Up @@ -202,7 +248,10 @@ void jit_power_emitter::register_table_entries() {
}

/// Lists the input precision combinations accepted by the Power emitter.
///
/// Each inner vector carries one element type per input; both inputs must
/// share the same type (f16 or f32). The `node` argument is not consulted.
///
/// NOTE(review): the body of jit_power_emitter::emit_isa is not fully visible
/// from here -- confirm that its f16 path is actually implemented before
/// relying on the f16 entry below.
std::set<std::vector<element::Type>> jit_power_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    return {
        {element::f16, element::f16},
        {element::f32, element::f32}
    };
}

void jit_power_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
Expand All @@ -222,7 +271,7 @@ float pow_f32(float v1, float v2) {

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_power_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != Precision::FP32) {
if ((exec_prc_ != Precision::FP16) && (exec_prc_ != Precision::FP32)) {
IE_THROW() << "unsupported precision: " << exec_prc_;
}

Expand Down Expand Up @@ -360,7 +409,7 @@ size_t jit_relu_emitter::get_inputs_count() const { return 1; }
/// Number of auxiliary vector registers the ReLU emitter requests.
size_t jit_relu_emitter::get_aux_vecs_count() const {
    // emit_isa uses aux_vec_idxs[0] as the zero operand passed to fmaxnm.
    constexpr size_t zero_operand_regs = 1;
    return zero_operand_regs;
}

/// Lists the input precisions accepted by the ReLU emitter.
///
/// ReLU has a single input, so each inner vector carries exactly one element
/// type (f16 or f32). The `node` argument is not consulted.
std::set<std::vector<element::Type>> jit_relu_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    return {
        {element::f16},
        {element::f32}
    };
}

void jit_relu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
Expand All @@ -373,7 +422,7 @@ void jit_relu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const s

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_relu_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != Precision::FP32) {
if ((exec_prc_ != Precision::FP16) && (exec_prc_ != Precision::FP32)) {
IE_THROW() << "unsupported precision: " << exec_prc_;
}

Expand All @@ -384,11 +433,24 @@ void jit_relu_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const st
using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;

TReg tmp = TReg(aux_vec_idxs[0]);
h->movi(tmp.s, 0);

TReg src = TReg(in_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);
h->fmaxnm(dst.s, src.s, tmp.s);

switch (exec_prc_) {
case Precision::FP16: {
h->movi(tmp.h, 0);
h->fmaxnm(dst.h, src.h, tmp.h);
break;
}
case Precision::FP32: {
h->movi(tmp.s, 0);
h->fmaxnm(dst.s, src.s, tmp.s);
break;
}
default: {
assert(!"unsupported precision");
}
}
}

} // namespace aarch64
Expand Down
20 changes: 15 additions & 5 deletions src/plugins/intel_cpu/src/nodes/executors/aarch64/jit_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,32 @@ bool JitEltwiseExecutor::isSupported(
return false;
}

{
const auto check_precisions = [&node](const std::set<InferenceEngine::Precision>& precisions) {
const auto& input_precisions = node->getOriginalInputPrecisions();
if (std::any_of(input_precisions.begin(),
input_precisions.end(),
[](const InferenceEngine::Precision& precision) { return precision != InferenceEngine::Precision::FP32; })) {
[&input_precisions, &precisions](const InferenceEngine::Precision& precision) {
return (input_precisions[0] != precision) || (precisions.find(precision) == precisions.end());
})) {
return false;
}
}

{
const auto& output_precisions = node->getOriginalOutputPrecisions();
if (std::any_of(output_precisions.begin(),
output_precisions.end(),
[](const InferenceEngine::Precision& precision) { return precision != InferenceEngine::Precision::FP32; })) {
[&input_precisions, &precisions](const InferenceEngine::Precision& precision) {
return (input_precisions[0] != precision) || (precisions.find(precision) == precisions.end());
})) {
return false;
}

return true;
};

const std::set<InferenceEngine::Precision> supported_precisions =
std::set<InferenceEngine::Precision>{InferenceEngine::Precision::FP16, InferenceEngine::Precision::FP32};
if (!check_precisions(supported_precisions)) {
return false;
}

if ((algorithm == Algorithm::EltwiseRelu) && ((alpha != 0.f) || (beta != 0.f) || (gamma != 0.f))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ void jit_uni_eltwise_generic<isa>::generate() {
preamble();

// Selects the precision used as exec_prc below.
const auto get_precision = []() {
// TODO: debug: precision selection is not completed -- FP32 is always returned for now.
const InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32;
// FP16 alternative, currently disabled:
//const InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP16;
return exec_prc;
};
const auto exec_prc = get_precision();
Expand Down Expand Up @@ -253,21 +255,35 @@ void jit_uni_eltwise_generic<isa>::generate() {

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_uni_eltwise_generic<isa>::uni_ldr(const TReg& data,
const XReg& ptr,
const XReg& ptr_reg,
const Precision& src_prc,
const Precision& dst_prc,
const bool broadcast,
const int32_t offset) {
if (src_prc != dst_prc) {
IE_THROW(Unexpected) << "src_prc != dst_prc is not supported";
IE_THROW(Unexpected) << "src_prc (" << src_prc << ") != dst_prc (" << dst_prc << ") is not supported";
}

switch (dst_prc) {
case Precision::FP16: {
if (broadcast) {
// TODO: move custom uni_ld1rw
if (offset == 0) {
ld1r(data.h, ptr(ptr_reg));
} else {
add_imm(X_DEFAULT_ADDR, ptr_reg, offset, X_TMP_0);
ld1r(data.h, ptr(X_DEFAULT_ADDR));
}
} else {
jit_generator::uni_ldr(data, ptr_reg, offset);
}
break;
}
case Precision::FP32: {
if (broadcast) {
jit_generator::uni_ld1rw(data.s, ptr, offset);
jit_generator::uni_ld1rw(data.s, ptr_reg, offset);
} else {
jit_generator::uni_ldr(data, ptr, offset);
jit_generator::uni_ldr(data, ptr_reg, offset);
}
break;
}
Expand All @@ -284,10 +300,11 @@ void jit_uni_eltwise_generic<isa>::uni_ldr(const SReg& data,
const Precision& dst_prc,
const int32_t offset) {
if (src_prc != dst_prc) {
IE_THROW(Unexpected) << "src_prc != dst_prc is not supported";
IE_THROW(Unexpected) << "src_prc (" << src_prc << ") != dst_prc (" << dst_prc << ") is not supported";
}

switch (dst_prc) {
case Precision::FP16:
case Precision::FP32: {
ldr(data, Xbyak_aarch64::ptr(ptr, offset));
break;
Expand All @@ -305,10 +322,11 @@ void jit_uni_eltwise_generic<isa>::uni_str(const XReg& ptr,
const Precision& dst_prc,
const int32_t offset) {
if (src_prc != dst_prc) {
IE_THROW(Unexpected) << "src_prc != dst_prc is not supported";
IE_THROW(Unexpected) << "src_prc (" << src_prc << ") != dst_prc (" << dst_prc << ") is not supported";
}

switch (dst_prc) {
case Precision::FP16:
case Precision::FP32: {
str(Xbyak_aarch64::QReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, offset));
break;
Expand All @@ -326,10 +344,11 @@ void jit_uni_eltwise_generic<isa>::uni_str(const XReg& ptr,
const Precision& dst_prc,
const int32_t offset) {
if (src_prc != dst_prc) {
IE_THROW(Unexpected) << "uni_str: src_prc != dst_prc is not supported";
IE_THROW(Unexpected) << "src_prc (" << src_prc << ") != dst_prc (" << dst_prc << ") is not supported";
}

switch (dst_prc) {
case Precision::FP16:
case Precision::FP32: {
str(data, Xbyak_aarch64::ptr(ptr, offset));
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
// X14 | src ptr | R14 | src ptr
// X15 | src ptr | R15 | temporary
// X16 | src ptr
// X17 | temporary
// X18 | temporary
// X19-30 | [not used]
Expand Down

0 comments on commit 57f18e9

Please sign in to comment.