Skip to content

Commit

Permalink
[CPU] Eltwise jit
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Oct 26, 2023
1 parent affa0b8 commit 984fd9f
Show file tree
Hide file tree
Showing 7 changed files with 485 additions and 110 deletions.
186 changes: 186 additions & 0 deletions src/plugins/intel_cpu/src/emitters/x64/jit_eltwise_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2243,5 +2243,191 @@ void jit_select_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
h->vblendmps(vmm_dst | k_mask, vmm_src1, vmm_src0);
}
}

/// BITWISE_AND ///
// Node-aware constructor. The ov::Node is accepted for interface parity with the
// other eltwise emitters but is not inspected here — emission depends only on exec_prc.
jit_bitwise_and_emitter::jit_bitwise_and_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

// Node-less constructor used when the emitter is created from precomputed eltwise data.
jit_bitwise_and_emitter::jit_bitwise_and_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

size_t jit_bitwise_and_emitter::get_inputs_num() const { return 2; }

// Reports the integer (src0, src1) precision pairs the JIT kernel accepts.
// Both inputs must share the same precision; floating-point types are not supported.
std::set<std::vector<element::Type>> jit_bitwise_and_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    std::set<std::vector<element::Type>> supported;
    for (const auto& prc : {element::i8, element::i16, element::i32, element::u8, element::u16}) {
        supported.insert({prc, prc});
    }
    return supported;
}

// Dispatches to the ISA-specialized template based on the host ISA selected at
// emitter construction time.
void jit_bitwise_and_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    switch (host_isa_) {
    case x64::sse41:
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx2:
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx512_core:
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        break;
    default:
        OPENVINO_ASSERT(!"unsupported isa");
    }
}

template <x64::cpu_isa_t isa>
void jit_bitwise_and_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

h->uni_vandps(vmm_dst, vmm_src0, vmm_src1);
}

/// BITWISE_NOT ///
// Node-aware constructor. prepare_table() registers the all-ones mask this
// emitter needs at emission time (see register_table_entries()).
jit_bitwise_not_emitter::jit_bitwise_not_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    prepare_table();
}

// Node-less constructor; also registers the all-ones mask table entry.
jit_bitwise_not_emitter::jit_bitwise_not_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    prepare_table();
}

size_t jit_bitwise_not_emitter::get_inputs_num() const { return 1; }

// Reports supported integer precisions. NOTE(review): entries are two-element
// pairs even though the op is unary, mirroring the binary bitwise emitters —
// confirm the precision helper does not expect single-element lists here.
std::set<std::vector<element::Type>> jit_bitwise_not_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    std::set<std::vector<element::Type>> supported;
    for (const auto& prc : {element::i8, element::i16, element::i32, element::u8, element::u16}) {
        supported.insert({prc, prc});
    }
    return supported;
}

size_t jit_bitwise_not_emitter::aux_vecs_count() const { return 1; }

// Dispatches to the ISA-specialized template based on the host ISA selected at
// emitter construction time.
void jit_bitwise_not_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    switch (host_isa_) {
    case x64::sse41:
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx2:
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx512_core:
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        break;
    default:
        OPENVINO_ASSERT(!"unsupported isa");
    }
}

// Emits bitwise NOT as AND-NOT against an all-ones mask (~x == ~x & 1...1).
// Fixes vs. previous revision: removed leftover std::cout debug prints (stray
// stdout writes from a JIT emitter at codegen time), removed the unused
// vmm_aux local, and branch consistently on the compile-time `isa` template
// parameter instead of mixing it with the runtime host_isa_ field.
template <x64::cpu_isa_t isa>
void jit_bitwise_not_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
    Vmm vmm_src = Vmm(in_vec_idxs[0]);
    Vmm vmm_dst = Vmm(out_vec_idxs[0]);

    if (isa == x64::sse41) {
        // Two-operand SSE andnps computes dst = ~dst & src, so move the source
        // into dst first, then AND-NOT with all-ones to obtain the complement.
        if (vmm_dst.getIdx() != vmm_src.getIdx()) {
            h->uni_vmovups(vmm_dst, vmm_src);
        }
        h->andnps(vmm_dst, table_val("all_bits"));
    } else if ((isa == x64::avx2) || (isa == x64::avx512_core)) {
        // Three-operand vandnps computes dst = ~src1 & src2; with src2 = all-ones
        // this is exactly ~src in a single instruction.
        h->vandnps(vmm_dst, vmm_src, table_val("all_bits"));
    } else {
        OPENVINO_ASSERT(!"unsupported isa");
    }
}

// Registers the all-ones 32-bit mask used by emit_isa() as the AND-NOT operand;
// the `true` flag broadcasts it across the vector register.
void jit_bitwise_not_emitter::register_table_entries() {
    push_arg_entry_of("all_bits", 0xFFFFFFFF, true);
}

/// BITWISE_OR ///
// Node-aware constructor; the node is accepted for interface parity and not inspected.
jit_bitwise_or_emitter::jit_bitwise_or_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

// Node-less constructor used when the emitter is created from precomputed eltwise data.
jit_bitwise_or_emitter::jit_bitwise_or_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

size_t jit_bitwise_or_emitter::get_inputs_num() const { return 2; }

// Reports the integer (src0, src1) precision pairs the JIT kernel accepts.
// Both inputs must share the same precision; floating-point types are not supported.
std::set<std::vector<element::Type>> jit_bitwise_or_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    std::set<std::vector<element::Type>> supported;
    for (const auto& prc : {element::i8, element::i16, element::i32, element::u8, element::u16}) {
        supported.insert({prc, prc});
    }
    return supported;
}

// Dispatches to the ISA-specialized template based on the host ISA selected at
// emitter construction time.
void jit_bitwise_or_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    switch (host_isa_) {
    case x64::sse41:
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx2:
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx512_core:
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        break;
    default:
        OPENVINO_ASSERT(!"unsupported isa");
    }
}

template <x64::cpu_isa_t isa>
void jit_bitwise_or_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

h->uni_vorps(vmm_dst, vmm_src0, vmm_src1);
}

/// BITWISE_XOR ///
// Node-aware constructor; the node is accepted for interface parity and not inspected.
jit_bitwise_xor_emitter::jit_bitwise_xor_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

// Node-less constructor used when the emitter is created from precomputed eltwise data.
jit_bitwise_xor_emitter::jit_bitwise_xor_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

size_t jit_bitwise_xor_emitter::get_inputs_num() const { return 2; }

// Reports the integer (src0, src1) precision pairs the JIT kernel accepts.
// Both inputs must share the same precision; floating-point types are not supported.
std::set<std::vector<element::Type>> jit_bitwise_xor_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    std::set<std::vector<element::Type>> supported;
    for (const auto& prc : {element::i8, element::i16, element::i32, element::u8, element::u16}) {
        supported.insert({prc, prc});
    }
    return supported;
}

// Dispatches to the ISA-specialized template based on the host ISA selected at
// emitter construction time.
void jit_bitwise_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    switch (host_isa_) {
    case x64::sse41:
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx2:
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx512_core:
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        break;
    default:
        OPENVINO_ASSERT(!"unsupported isa");
    }
}

template <x64::cpu_isa_t isa>
void jit_bitwise_xor_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

h->uni_vxorps(vmm_dst, vmm_src0, vmm_src1);
}

} // namespace intel_cpu
} // namespace ov
71 changes: 71 additions & 0 deletions src/plugins/intel_cpu/src/emitters/x64/jit_eltwise_emitters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,5 +669,76 @@ class jit_select_emitter : public jit_emitter {
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

// JIT emitter for element-wise bitwise AND of two integer inputs.
// The generated kernel operates on whole vector registers; the integer
// precisions it accepts are reported by get_supported_precisions().
class jit_bitwise_and_emitter : public jit_emitter {
public:
    jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                            InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
                            InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of data inputs consumed by the generated kernel (2).
    size_t get_inputs_num() const override;
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
    // Runtime dispatch over the host ISA to the specialized emit_isa<>().
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

// JIT emitter for element-wise bitwise NOT of one integer input.
// Implemented as AND-NOT against an all-ones constant registered in the
// emitter table (see register_table_entries()).
class jit_bitwise_not_emitter : public jit_emitter {
public:
    jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                            InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
                            InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of data inputs consumed by the generated kernel (1).
    size_t get_inputs_num() const override;
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);
    // Auxiliary vector registers requested from the kernel (currently 1).
    size_t aux_vecs_count() const override;

private:
    // Runtime dispatch over the host ISA to the specialized emit_isa<>().
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
    // Registers the broadcast all-ones mask used as the AND-NOT operand.
    void register_table_entries() override;
};

// JIT emitter for element-wise bitwise OR of two integer inputs.
// The generated kernel operates on whole vector registers; the integer
// precisions it accepts are reported by get_supported_precisions().
class jit_bitwise_or_emitter : public jit_emitter {
public:
    jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                           InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
                           InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of data inputs consumed by the generated kernel (2).
    size_t get_inputs_num() const override;
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
    // Runtime dispatch over the host ISA to the specialized emit_isa<>().
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

// JIT emitter for element-wise bitwise XOR of two integer inputs.
// The generated kernel operates on whole vector registers; the integer
// precisions it accepts are reported by get_supported_precisions().
class jit_bitwise_xor_emitter : public jit_emitter {
public:
    jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                            InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
                            InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of data inputs consumed by the generated kernel (2).
    size_t get_inputs_num() const override;
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
    // Runtime dispatch over the host ISA to the specialized emit_isa<>().
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

} // namespace intel_cpu
} // namespace ov
41 changes: 23 additions & 18 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ static void set_intersection(const std::set<std::vector<element::Type>>& precisi
InferenceEngine::Precision eltwise_precision_helper::get_precision(const size_t inputs_number,
const InferenceEngine::Precision(&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<Eltwise::EltwiseData>& eltwise_data) {
if (one_of(eltwise_data[0].algo,
Algorithm::EltwiseBitwiseAnd,
Algorithm::EltwiseBitwiseNot,
Algorithm::EltwiseBitwiseOr,
Algorithm::EltwiseBitwiseXor)) {
return InferenceEngine::Precision::I32;
}

Precision exec_prc = Precision::UNSPECIFIED;

std::set<std::vector<element::Type>> supported_precision_intersection = get_supported_precisions(eltwise_data.front().algo);
Expand Down Expand Up @@ -249,7 +257,11 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter),
OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter),
OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter),
OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter),
OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter));

if (precisions.empty())
IE_THROW() << "Unsupported operation type for Eltwise emitter";
Expand Down Expand Up @@ -623,7 +635,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter),
OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter),
OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter),
OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter),
OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter));

if (!ctx.emitter)
IE_THROW() << "Unsupported operation type for Eltwise emitter";
Expand Down Expand Up @@ -1792,7 +1808,7 @@ class EltwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
break;
case Algorithm::EltwiseIsNaN: *dst_ptr_f = std::isnan(src_f[0]); break;
case Algorithm::EltwiseSelect: *dst_ptr_f = src_f[0] ? src_f[1] : src_f[2]; break;
default: IE_THROW() << "Unsupported operation type for Eltwise executor";
default: OPENVINO_THROW("Unsupported operation type for Eltwise executor");
}
}
});
Expand Down Expand Up @@ -1849,7 +1865,7 @@ class BitwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
*dst_ptr_f = src_f[0] ^ src_f[1];
break;
}
default: IE_THROW() << "Unsupported operation type for Eltwise executor";
default: OPENVINO_THROW("Unsupported operation type for Eltwise executor");
}
}
});
Expand Down Expand Up @@ -2076,7 +2092,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
// if dim rank is greater than the maximum possible, we should use the reference execution
bool canUseOptimizedImpl = mayiuse(x64::sse41) && getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK;
// TODO: Add EltwiseLog algorithm support for JIT implementation
canUseOptimizedImpl &= !(one_of(getAlgorithm(), Algorithm::EltwiseLog) || isBitwise(getAlgorithm()));
canUseOptimizedImpl &= !one_of(getAlgorithm(), Algorithm::EltwiseLog);

bool canUseOptimizedShapeAgnosticImpl = isDynamicNode() && canUseOptimizedImpl;

Expand Down Expand Up @@ -2845,19 +2861,8 @@ bool Eltwise::canFuse(const NodePtr& node) const {
if (!mayiuse(x64::sse41) || getInputShapeAtPort(0).getRank() > MAX_ELTWISE_DIM_RANK)
return false;

// TODO: supported only via reference executor
if (one_of(getAlgorithm(),
Algorithm::EltwiseLog,
Algorithm::EltwiseBitwiseAnd,
Algorithm::EltwiseBitwiseNot,
Algorithm::EltwiseBitwiseOr,
Algorithm::EltwiseBitwiseXor) ||
one_of(node->getAlgorithm(),
Algorithm::EltwiseLog,
Algorithm::EltwiseBitwiseAnd,
Algorithm::EltwiseBitwiseNot,
Algorithm::EltwiseBitwiseOr,
Algorithm::EltwiseBitwiseXor))
// TODO: EltwiseLog is supported only via reference executor
if (getAlgorithm() == Algorithm::EltwiseLog || node->getAlgorithm() == Algorithm::EltwiseLog)
return false;

bool isIntegerNode = isIntegerComputeSupported(this);
Expand Down
Loading

0 comments on commit 984fd9f

Please sign in to comment.