Skip to content

Commit

Permalink
diffetent precisions support #2
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Oct 9, 2023
1 parent 4218ac3 commit d35c293
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 65 deletions.
73 changes: 11 additions & 62 deletions src/plugins/intel_cpu/src/emitters/aarch64/jit_eltwise_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,7 @@ void jit_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
}

std::set<std::vector<element::Type>> jit_add_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
return {
{element::f16, element::f16},
{element::f32, element::f32}
};
return {{element::f32, element::f32}};
}

/// MUL_ADD ///
Expand Down Expand Up @@ -147,10 +144,7 @@ void jit_mul_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
}

std::set<std::vector<element::Type>> jit_mul_add_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
return {
{element::f16, element::f16, element::f16},
{element::f32, element::f32, element::f32}
};
return {{element::f32, element::f32, element::f32}};
}

/// MULTIPLY ///
Expand Down Expand Up @@ -203,10 +197,7 @@ void jit_multiply_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, cons
}

std::set<std::vector<element::Type>> jit_multiply_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
return {
{element::f16, element::f16},
{element::f32, element::f32}
};
return {{element::f32, element::f32}};
}

/// POWER ///
Expand Down Expand Up @@ -249,10 +240,7 @@ void jit_power_emitter::register_table_entries() {
}

std::set<std::vector<element::Type>> jit_power_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
return {
{element::f16, element::f16},
{element::f32, element::f32}
};
return {{element::f32, element::f32}};
}

void jit_power_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
Expand Down Expand Up @@ -286,58 +274,19 @@ void jit_power_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const s

if (scale != 1.f) {
auto adr = table_val2("scale");
switch (exec_prc_) {
case Precision::FP16: {
h->ld1r(aux.h, adr);
//h->fmov(aux.h, -1.);
h->fmul(src.h, src.h, aux.h);
break;
}
case Precision::FP32: {
h->ld1r(aux.s, adr);
//h->fmov(aux.s, -1.);
h->fmul(src.s, src.s, aux.s);
break;
}
default: {
assert(!"unsupported precision");
}
}
h->ld1r(aux.s, adr);
//h->fmov(aux.s, -1.);
h->fmul(src.s, src.s, aux.s);
}

if (shift != 0.f) {
auto adr = table_val2("shift");
switch (exec_prc_) {
case Precision::FP16: {
h->ld1r(aux.h, adr);
h->fadd(src.h, src.h, aux.h);
break;
}
case Precision::FP32: {
h->ld1r(aux.s, adr);
h->fadd(src.s, src.s, aux.s);
break;
}
default: {
assert(!"unsupported precision");
}
}
h->ld1r(aux.s, adr);
h->fadd(src.s, src.s, aux.s);
}

if (power == 0.f) {
switch (exec_prc_) {
case Precision::FP16: {
h->fmov(dst.h, 1.);
break;
}
case Precision::FP32: {
h->fmov(dst.s, 1.);
break;
}
default: {
assert(!"unsupported precision");
}
}
h->fmov(dst.s, 1.);
return;
}

Expand Down Expand Up @@ -466,7 +415,7 @@ size_t jit_relu_emitter::get_inputs_count() const { return 1; }
size_t jit_relu_emitter::get_aux_vecs_count() const { return 1; }

std::set<std::vector<element::Type>> jit_relu_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
return {{element::f16}, {element::f32}};
return {{element::f32}};
}

void jit_relu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,10 @@ CPUTestsBase::CPUInfo CPUTestsBase::getCPUInfo() const {
std::string CPUTestsBase::getPrimitiveType(const ngraph::helpers::EltwiseTypes& eltwise_type,
const ov::element::Type_t& element_type,
const std::vector<std::pair<ov::PartialShape, std::vector<ov::Shape>>>& input_shapes) const {
if ((element_type == ov::element::f32) &&
((eltwise_type == ngraph::helpers::EltwiseTypes::ADD) ||
if ((eltwise_type == ngraph::helpers::EltwiseTypes::ADD) ||
(eltwise_type == ngraph::helpers::EltwiseTypes::MULTIPLY) ||
(eltwise_type == ngraph::helpers::EltwiseTypes::SUBTRACT) ||
(eltwise_type == ngraph::helpers::EltwiseTypes::DIVIDE))) {
(eltwise_type == ngraph::helpers::EltwiseTypes::DIVIDE)) {
return "jit";
}
return "acl";
Expand Down

0 comments on commit d35c293

Please sign in to comment.