Skip to content

Commit

Permalink
[CPU] [ARM] Element-wise fp16 support #2
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Oct 3, 2023
1 parent 57f18e9 commit 4db8bb4
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "jit_uni_eltwise_generic.hpp"
#include "ie_ngraph_utils.hpp"

namespace ov {
namespace intel_cpu {
Expand Down Expand Up @@ -35,13 +36,7 @@ template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_uni_eltwise_generic<isa>::generate() {
preamble();

const auto get_precision = []() {
// TODO: debug: not completed
const InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32;
//const InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP16;
return exec_prc;
};
const auto exec_prc = get_precision();
auto const exec_prc = eltwise_precision_helper::get_precision(jep_.inputs_number, jep_.src_prc, eltwise_data_);

eltwise_emitter = create_eltwise_emitter(eltwise_data_.front(), exec_prc);
for (size_t i = 1; i < eltwise_data_.size(); ++i) {
Expand Down Expand Up @@ -398,7 +393,6 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerDynamic, ov::intel_cpu::aarch64::jit_power_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_emitter),
OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter));

Expand Down Expand Up @@ -459,6 +453,120 @@ void jit_uni_eltwise_generic<isa>::apply_post_ops() {
}
}

namespace {

// TODO: copy/paste: refactor
template<typename T>
struct SupportedPrecisions {
void operator()(std::set<std::vector<element::Type>> &precisions) {
precisions = T::get_supported_precisions();
}
};

// TODO: copy/paste: refactor
static void set_intersection(const std::set<std::vector<element::Type>>& precisions1,
const std::set<std::vector<element::Type>>& precisions2,
std::set<std::vector<element::Type>>& intersection) {
std::map<element::Type, size_t> intersection_types;

for (auto it1 = precisions1.begin(); it1 != precisions1.end(); ++it1) {
for (auto it2 = precisions2.begin(); it2 != precisions2.end(); ++it2) {
const auto& it1_precisions = *it1;
// all element types are equal
if (it1_precisions[0] == (*it2)[0]) {
// first precisions size is used
intersection_types.emplace(it1_precisions[0], it1_precisions.size());
}
}
}

for (auto it = intersection_types.begin(); it != intersection_types.end(); ++it) {
intersection.insert(std::vector<element::Type>(it->second, it->first));
}
}
} // namespace

InferenceEngine::Precision eltwise_precision_helper::get_precision(
const size_t inputs_number,
const InferenceEngine::Precision (&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<ov::intel_cpu::aarch64::EltwiseData>& eltwise_data) {
Precision exec_prc = Precision::UNSPECIFIED;

const auto algorithm = eltwise_data.front().algo;
std::set<std::vector<element::Type>> supported_precision_intersection = get_supported_precisions(algorithm);

// for element-wise operations all inputs must to have the same precisions
auto has_same_precision = [](const std::vector<element::Type>& precisions) {
return std::all_of(precisions.begin(), precisions.end(), [&precisions](const element::Type precision) {
return precision == precisions[0];
});
};

// TODO: should we convert all inputs to fp16 for PowerStatic
assert((algorithm == Algorithm::EltwisePowerStatic) ||
std::all_of(supported_precision_intersection.begin(),
supported_precision_intersection.end(),
has_same_precision));


for (size_t i = 1; i < eltwise_data.size(); ++i) {
std::set<std::vector<element::Type>> prcs = get_supported_precisions(eltwise_data[i].algo);
std::set<std::vector<element::Type>> prcs_intersect = {};

OPENVINO_ASSERT((algorithm == Algorithm::EltwisePowerStatic) ||
std::all_of(prcs.begin(), prcs.end(), has_same_precision),
"for element-wise nodes all precisions have to be equal");

set_intersection(supported_precision_intersection, prcs, prcs_intersect);

supported_precision_intersection = prcs_intersect;
}

static const element::Type exec_precisions_priority[] = {
element::f16,
element::f32
};

for (const auto prc : exec_precisions_priority) {
if (std::any_of(
supported_precision_intersection.begin(),
supported_precision_intersection.end(),
[&prc](const std::vector<element::Type>& precisions) { return std::find(precisions.begin(), precisions.end(), prc) != precisions.end(); })) {
exec_prc = InferenceEngine::details::convertPrecision(prc);
break;
}
}

for (size_t i = 0; i < inputs_number; i++) {
if (src_prc[i] != exec_prc) {
exec_prc = Precision::FP32;
break;
}
}

if (exec_prc == Precision::UNSPECIFIED) {
IE_THROW() << "Eltwise jitter failed to specify execution precision for Eltwise node";
}

return exec_prc;
}

std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_precisions(const Algorithm& algo) {
std::set<std::vector<element::Type>> precisions;

OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo,
OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter),
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_emitter));

if (precisions.empty())
IE_THROW() << "Unsupported operation type for Eltwise emitter";

return precisions;
}

template struct jit_uni_eltwise_generic<cpu_isa_t::asimd>;

} // namespace aarch64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, jit_generator {
std::vector<std::shared_ptr<jit_emitter>> post_op_emitters;
};

class eltwise_precision_helper {
public:
static InferenceEngine::Precision get_precision(const size_t inputs_number,
const InferenceEngine::Precision (&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<ov::intel_cpu::aarch64::EltwiseData>& eltwise_data);

private:
static std::set<std::vector<element::Type>> get_supported_precisions(const Algorithm& algo);
};

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov

0 comments on commit 4db8bb4

Please sign in to comment.