diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp index dc872b47e80129..5cda942f408ccb 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.cpp @@ -3,6 +3,7 @@ // #include "jit_uni_eltwise_generic.hpp" +#include "ie_ngraph_utils.hpp" namespace ov { namespace intel_cpu { @@ -35,13 +36,7 @@ template void jit_uni_eltwise_generic::generate() { preamble(); - const auto get_precision = []() { - // TODO: debug: not completed - const InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32; - //const InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP16; - return exec_prc; - }; - const auto exec_prc = get_precision(); + auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_); eltwise_emitter = create_eltwise_emitter(eltwise_data_.front(), exec_prc); for (size_t i = 1; i < eltwise_data_.size(); ++i) { @@ -398,7 +393,6 @@ std::shared_ptr jit_uni_eltwise_generic::create_eltwise_emitte OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter), OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter), OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter), - OV_CASE(Algorithm::EltwisePowerDynamic, ov::intel_cpu::aarch64::jit_power_emitter), OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_emitter), OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter)); @@ -459,6 +453,120 @@ void jit_uni_eltwise_generic::apply_post_ops() { } } +namespace { + +// TODO: copy/paste: refactor +template +struct SupportedPrecisions { + void operator()(std::set> &precisions) { + precisions = T::get_supported_precisions(); + } +}; + +// TODO: copy/paste: refactor +static void set_intersection(const std::set>& precisions1, + const std::set>& precisions2, + std::set>& intersection) { + std::map intersection_types; + + for (auto it1 = precisions1.begin(); it1 != precisions1.end(); ++it1) { + for (auto it2 = precisions2.begin(); it2 != precisions2.end(); ++it2) { + const auto& it1_precisions = *it1; + // all element types are equal + if (it1_precisions[0] == (*it2)[0]) { + // first precisions size is used + intersection_types.emplace(it1_precisions[0], it1_precisions.size()); + } + } + } + + for (auto it = intersection_types.begin(); it != intersection_types.end(); ++it) { + intersection.insert(std::vector(it->second, it->first)); + } +} +} // namespace + +InferenceEngine::Precision eltwise_precision_helper::get_precision( + const size_t inputs_number, + const InferenceEngine::Precision (&src_prc)[MAX_ELTWISE_INPUTS], + const std::vector& eltwise_data) { + Precision exec_prc = Precision::UNSPECIFIED; + + const auto algorithm = eltwise_data.front().algo; + std::set> supported_precision_intersection = get_supported_precisions(algorithm); + + // for element-wise operations all inputs must to have the same precisions + auto has_same_precision = [](const std::vector& precisions) { + return std::all_of(precisions.begin(), precisions.end(), [&precisions](const element::Type precision) { + return precision == precisions[0]; + }); + }; + + // TODO: should we convert all inputs to fp16 for PowerStatic + assert((algorithm == Algorithm::EltwisePowerStatic) || + std::all_of(supported_precision_intersection.begin(), + supported_precision_intersection.end(), + has_same_precision)); + + + for (size_t i = 1; i < eltwise_data.size(); ++i) { + std::set> prcs = get_supported_precisions(eltwise_data[i].algo); + std::set> prcs_intersect = {}; + + OPENVINO_ASSERT((algorithm == Algorithm::EltwisePowerStatic) || + std::all_of(prcs.begin(), prcs.end(), has_same_precision), + "for element-wise nodes all precisions have to be equal"); + + set_intersection(supported_precision_intersection, prcs, prcs_intersect); + + supported_precision_intersection = prcs_intersect; + } + + static const element::Type exec_precisions_priority[] = { + element::f16, + element::f32 + }; + + for (const auto prc : exec_precisions_priority) { + if (std::any_of( + supported_precision_intersection.begin(), + supported_precision_intersection.end(), + [&prc](const std::vector& precisions) { return std::find(precisions.begin(), precisions.end(), prc) != precisions.end(); })) { + exec_prc = InferenceEngine::details::convertPrecision(prc); + break; + } + } + + for (size_t i = 0; i < inputs_number; i++) { + if (src_prc[i] != exec_prc) { + exec_prc = Precision::FP32; + break; + } + } + + if (exec_prc == Precision::UNSPECIFIED) { + IE_THROW() << "Eltwise jitter failed to specify execution precision for Eltwise node"; + } + + return exec_prc; +} + +std::set> eltwise_precision_helper::get_supported_precisions(const Algorithm& algo) { + std::set> precisions; + + OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo, + OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter), + OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), + OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), + OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), + OV_CASE(Algorithm::EltwisePowerStatic, jit_power_emitter)); + + if (precisions.empty()) + IE_THROW() << "Unsupported operation type for Eltwise emitter"; + + return precisions; +} + template struct jit_uni_eltwise_generic; } // namespace aarch64 diff --git a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp index 358c29d78c5c57..98fee8f0387921 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/aarch64/jit_uni_eltwise_generic.hpp @@ -228,6 +228,16 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator { std::vector> post_op_emitters; }; +class eltwise_precision_helper { +public: + static InferenceEngine::Precision get_precision(const size_t inputs_number, + const InferenceEngine::Precision (&src_prc)[MAX_ELTWISE_INPUTS], + const std::vector& eltwise_data); + +private: + static std::set> get_supported_precisions(const Algorithm& algo); +}; + } // namespace aarch64 } // namespace intel_cpu } // namespace ov