Skip to content

Commit

Permalink
[CPU] Eltwise jit
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Oct 24, 2023
1 parent a44ec65 commit ff49ab2
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 21 deletions.
182 changes: 182 additions & 0 deletions src/plugins/intel_cpu/src/emitters/x64/jit_eltwise_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2243,5 +2243,187 @@ void jit_select_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
h->vblendmps(vmm_dst | k_mask, vmm_src1, vmm_src0);
}
}

/// BITWISE_AND ///
/// Node-based constructor; the node itself is not inspected — only host/ISA/precision are forwarded.
jit_bitwise_and_emitter::jit_bitwise_and_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

/// Node-less constructor used when the emitter is created from eltwise kernel data.
jit_bitwise_and_emitter::jit_bitwise_and_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

size_t jit_bitwise_and_emitter::get_inputs_num() const { return 2; }

std::set<std::vector<element::Type>> jit_bitwise_and_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    // Both inputs must share the same integer precision; emit one entry per supported type.
    std::set<std::vector<element::Type>> precisions;
    for (const auto& type : {element::i8, element::i16, element::i32, element::u8, element::u16}) {
        precisions.insert({type, type});
    }
    return precisions;
}

// Dispatches to the ISA-specialized implementation chosen at construction time.
void jit_bitwise_and_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    switch (host_isa_) {
    case x64::sse41:
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx2:
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx512_core:
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        break;
    default:
        assert(!"unsupported isa");
    }
}

template <x64::cpu_isa_t isa>
void jit_bitwise_and_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

h->uni_vandps(vmm_dst, vmm_src0, vmm_src1);
}

/// BITWISE_NOT ///
/// Node-based constructor; registers the constant table (all-ones mask) used by emit_isa.
jit_bitwise_not_emitter::jit_bitwise_not_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    prepare_table();
}

/// Node-less constructor; registers the constant table (all-ones mask) used by emit_isa.
jit_bitwise_not_emitter::jit_bitwise_not_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    prepare_table();
}

size_t jit_bitwise_not_emitter::get_inputs_num() const { return 1; }

std::set<std::vector<element::Type>> jit_bitwise_not_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    // NOTE(review): BitwiseNot is unary (get_inputs_num() == 1), yet each entry lists two
    // precisions like the binary emitters — confirm the intersection helper tolerates this,
    // or whether single-element entries are intended here.
    return {
        {element::i8, element::i8},
        {element::i16, element::i16},
        {element::i32, element::i32},
        {element::u8, element::u8},
        {element::u16, element::u16}
    };
}

size_t jit_bitwise_not_emitter::aux_vecs_count() const { return 1; }

// Dispatches to the ISA-specialized implementation chosen at construction time.
void jit_bitwise_not_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    switch (host_isa_) {
    case x64::sse41:
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx2:
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx512_core:
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        break;
    default:
        assert(!"unsupported isa");
    }
}

template <x64::cpu_isa_t isa>
void jit_bitwise_not_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
    Vmm vmm_src = Vmm(in_vec_idxs[0]);
    Vmm vmm_dst = Vmm(out_vec_idxs[0]);
    Vmm vmm_aux = Vmm(aux_vec_idxs[0]);

    if (isa == x64::sse41) {
        // BUGFIX: ANDNPS computes (~dst) & operand. The previous sequence
        //   uni_vmovups(dst, src); andnps(dst, src)
        // evaluated to (~src) & src == 0 for every input instead of ~src.
        // Correct form: AND-NOT the source against an all-ones mask,
        // since (~src) & 0xFFFFFFFF == ~src.
        if (vmm_dst.getIdx() != vmm_src.getIdx()) {
            h->uni_vmovups(vmm_dst, vmm_src);
        }
        h->andnps(vmm_dst, table_val("all_bits"));
    } else {
        // VEX/EVEX three-operand form: dst = (~src) & all_ones == ~src.
        h->vmovups(vmm_aux, table_val("all_bits"));
        h->vandnps(vmm_dst, vmm_src, vmm_aux);
    }
}

void jit_bitwise_not_emitter::register_table_entries() {
    // All-ones 32-bit pattern broadcast across the vector; serves as the AND-NOT mask.
    push_arg_entry_of("all_bits", 0xFFFFFFFF, true);
}

/// BITWISE_OR ///
/// Node-based constructor; the node itself is not inspected — only host/ISA/precision are forwarded.
jit_bitwise_or_emitter::jit_bitwise_or_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

/// Node-less constructor used when the emitter is created from eltwise kernel data.
jit_bitwise_or_emitter::jit_bitwise_or_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

size_t jit_bitwise_or_emitter::get_inputs_num() const { return 2; }

std::set<std::vector<element::Type>> jit_bitwise_or_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    // Both inputs must share the same integer precision; emit one entry per supported type.
    std::set<std::vector<element::Type>> precisions;
    for (const auto& type : {element::i8, element::i16, element::i32, element::u8, element::u16}) {
        precisions.insert({type, type});
    }
    return precisions;
}

// Dispatches to the ISA-specialized implementation chosen at construction time.
void jit_bitwise_or_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    switch (host_isa_) {
    case x64::sse41:
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx2:
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        break;
    case x64::avx512_core:
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        break;
    default:
        assert(!"unsupported isa");
    }
}

template <x64::cpu_isa_t isa>
void jit_bitwise_or_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

h->uni_vorps(vmm_dst, vmm_src0, vmm_src1);
}

/// BITWISE_XOR ///
/// Node-based constructor; the node itself is not inspected — only host/ISA/precision are forwarded.
jit_bitwise_xor_emitter::jit_bitwise_xor_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

/// Node-less constructor used when the emitter is created from eltwise kernel data.
jit_bitwise_xor_emitter::jit_bitwise_xor_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, Precision exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {}

size_t jit_bitwise_xor_emitter::get_inputs_num() const { return 2; }

std::set<std::vector<element::Type>> jit_bitwise_xor_emitter::get_supported_precisions(const std::shared_ptr<ngraph::Node>& node) {
    // Both inputs must share the same integer precision; emit one entry per supported type.
    std::set<std::vector<element::Type>> precisions;
    for (const auto& type : {element::i8, element::i16, element::i32, element::u8, element::u16}) {
        precisions.insert({type, type});
    }
    return precisions;
}

// Dispatches to the ISA-specialized implementation chosen at construction time.
// Consistency fix: the original used "}\nelse" here while every sibling emitter
// in this file uses the "} else if" brace style — normalized to match.
void jit_bitwise_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    if (host_isa_ == x64::sse41) {
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
    } else if (host_isa_ == x64::avx2) {
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
    } else if (host_isa_ == x64::avx512_core) {
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
    } else {
        assert(!"unsupported isa");
    }
}

template <x64::cpu_isa_t isa>
void jit_bitwise_xor_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

h->uni_vxorps(vmm_dst, vmm_src0, vmm_src1);
}

} // namespace intel_cpu
} // namespace ov
71 changes: 71 additions & 0 deletions src/plugins/intel_cpu/src/emitters/x64/jit_eltwise_emitters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,5 +669,76 @@ class jit_select_emitter : public jit_emitter {
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

// JIT emitter for element-wise bitwise AND of two integer inputs of the same precision.
class jit_bitwise_and_emitter : public jit_emitter {
public:
    jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of input vector registers (2 for this binary operation).
    size_t get_inputs_num() const override;
    // Per-input precision combinations the jit implementation supports.
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
    // ISA dispatch entry point called by the eltwise kernel.
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

// JIT emitter for element-wise bitwise NOT of a single integer input.
// Uses one auxiliary register and a constant table entry (all-ones mask).
class jit_bitwise_not_emitter : public jit_emitter {
public:
    jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of input vector registers (1 for this unary operation).
    size_t get_inputs_num() const override;
    // Per-input precision combinations the jit implementation supports.
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);
    // One aux register is required for the all-ones mask.
    size_t aux_vecs_count() const override;

private:
    // ISA dispatch entry point called by the eltwise kernel.
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
    // Registers the "all_bits" constant used as the AND-NOT mask.
    void register_table_entries() override;
};

// JIT emitter for element-wise bitwise OR of two integer inputs of the same precision.
class jit_bitwise_or_emitter : public jit_emitter {
public:
    jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of input vector registers (2 for this binary operation).
    size_t get_inputs_num() const override;
    // Per-input precision combinations the jit implementation supports.
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
    // ISA dispatch entry point called by the eltwise kernel.
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

// JIT emitter for element-wise bitwise XOR of two integer inputs of the same precision.
class jit_bitwise_xor_emitter : public jit_emitter {
public:
    jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
    jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
        InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);

    // Number of input vector registers (2 for this binary operation).
    size_t get_inputs_num() const override;
    // Per-input precision combinations the jit implementation supports.
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);

private:
    // ISA dispatch entry point called by the eltwise kernel.
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
};

} // namespace intel_cpu
} // namespace ov
37 changes: 21 additions & 16 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ static void set_intersection(const std::set<std::vector<element::Type>>& precisi
InferenceEngine::Precision eltwise_precision_helper::get_precision(const size_t inputs_number,
const InferenceEngine::Precision(&src_prc)[MAX_ELTWISE_INPUTS],
const std::vector<Eltwise::EltwiseData>& eltwise_data) {
// TODO: refactor
if ((eltwise_data[0].algo == Algorithm::EltwiseBitwiseAnd) ||
(eltwise_data[0].algo == Algorithm::EltwiseBitwiseNot) ||
(eltwise_data[0].algo == Algorithm::EltwiseBitwiseOr) ||
(eltwise_data[0].algo == Algorithm::EltwiseBitwiseXor)) {
return InferenceEngine::Precision::I32;
}

Precision exec_prc = Precision::UNSPECIFIED;

std::set<std::vector<element::Type>> supported_precision_intersection = get_supported_precisions(eltwise_data.front().algo);
Expand Down Expand Up @@ -249,7 +257,11 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter),
OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter),
OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter),
OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter),
OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter));

if (precisions.empty())
IE_THROW() << "Unsupported operation type for Eltwise emitter";
Expand Down Expand Up @@ -623,7 +635,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter),
OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter),
OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter));
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter),
OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter),
OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter),
OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter),
OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter));

if (!ctx.emitter)
IE_THROW() << "Unsupported operation type for Eltwise emitter";
Expand Down Expand Up @@ -2083,7 +2099,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
// if dim rank is greater than the maximum possible, we should use the reference execution
bool canUseOptimizedImpl = mayiuse(x64::sse41) && getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK;
// TODO: Add EltwiseLog algorithm support for JIT implementation
canUseOptimizedImpl &= !(one_of(getAlgorithm(), Algorithm::EltwiseLog) || isBitwise(getAlgorithm()));
canUseOptimizedImpl &= !one_of(getAlgorithm(), Algorithm::EltwiseLog);

bool canUseOptimizedShapeAgnosticImpl = isDynamicNode() && canUseOptimizedImpl;

Expand Down Expand Up @@ -2852,19 +2868,8 @@ bool Eltwise::canFuse(const NodePtr& node) const {
if (!mayiuse(x64::sse41) || getInputShapeAtPort(0).getRank() > MAX_ELTWISE_DIM_RANK)
return false;

// TODO: supported only via reference executor
if (one_of(getAlgorithm(),
Algorithm::EltwiseLog,
Algorithm::EltwiseBitwiseAnd,
Algorithm::EltwiseBitwiseNot,
Algorithm::EltwiseBitwiseOr,
Algorithm::EltwiseBitwiseXor) ||
one_of(node->getAlgorithm(),
Algorithm::EltwiseLog,
Algorithm::EltwiseBitwiseAnd,
Algorithm::EltwiseBitwiseNot,
Algorithm::EltwiseBitwiseOr,
Algorithm::EltwiseBitwiseXor))
// TODO: EltwiseLog is supported only via reference executor
if (getAlgorithm() == Algorithm::EltwiseLog || node->getAlgorithm() == Algorithm::EltwiseLog)
return false;

bool isIntegerNode = isIntegerComputeSupported(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ const auto params_4D_bitwise = ::testing::Combine(
::testing::Values(ov::element::Type_t::undefined),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(ov::AnyMap())),
::testing::Values(CPUSpecificParams({ nhwc, nhwc }, { nhwc }, {}, "ref")),
::testing::Values(CPUSpecificParams({ nhwc, nhwc }, { nhwc }, {}, {})),
::testing::Values(emptyFusingSpec),
::testing::Values(false));

Expand All @@ -267,7 +267,7 @@ const auto params_4D_bitwise_i16 = ::testing::Combine(
::testing::Values(ov::element::Type_t::undefined),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(ov::AnyMap())),
::testing::Values(CPUSpecificParams({ nhwc, nhwc }, { nhwc }, {}, "ref_I32$/")),
::testing::Values(CPUSpecificParams({ nhwc, nhwc }, { nhwc }, {}, "*_I32")),
::testing::Values(emptyFusingSpec),
::testing::Values(false));

Expand All @@ -280,12 +280,12 @@ const auto params_4D_bitwise_NOT = ::testing::Combine(
::testing::ValuesIn({ ngraph::helpers::EltwiseTypes::BITWISE_NOT }),
::testing::ValuesIn({ ngraph::helpers::InputLayerType::NONE }),
::testing::ValuesIn({ ov::test::utils::OpType::VECTOR }),
::testing::ValuesIn({ ov::element::Type_t::i8, ov::element::Type_t::u8, ov::element::Type_t::i32 }),
::testing::ValuesIn({ ov::element::Type_t::i8, /*ov::element::Type_t::u8,*/ ov::element::Type_t::i32}),
::testing::Values(ov::element::Type_t::undefined),
::testing::Values(ov::element::Type_t::undefined),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(ov::AnyMap())),
::testing::Values(CPUSpecificParams({ nhwc }, { nhwc }, {}, "ref")),
::testing::Values(CPUSpecificParams({ nhwc }, { nhwc }, {}, {})),
::testing::Values(emptyFusingSpec),
::testing::Values(false));

Expand All @@ -302,7 +302,7 @@ const auto params_4D_bitwise_NOT_i16 = ::testing::Combine(
::testing::Values(ov::element::Type_t::undefined),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(ov::AnyMap())),
::testing::Values(CPUSpecificParams({ nhwc }, { nhwc }, {}, "ref_I32$/")),
::testing::Values(CPUSpecificParams({ nhwc }, { nhwc }, {}, "*_I32")),
::testing::Values(emptyFusingSpec),
::testing::Values(false));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,12 @@ void CPUTestsBase::updateSelectedType(const std::string& primitiveType, const ov
selectedType = primitiveType;
}

if (selectedType.find("*") != std::string::npos) {
// like as regex
selectedType = primitiveType + "_" + selectedType;
return;
}

if (selectedType.find("$/") != std::string::npos) {
// like as regex
selectedType = selectedType.substr(0, selectedType.find("$/"));
Expand Down

0 comments on commit ff49ab2

Please sign in to comment.