Skip to content

Commit

Permalink
[CPU] Refactors jitters for nGraph interop (#4255)
Browse files Browse the repository at this point in the history
  • Loading branch information
Marina Kolpakova authored Feb 14, 2021
1 parent d406a5a commit 434e66d
Show file tree
Hide file tree
Showing 17 changed files with 630 additions and 353 deletions.
73 changes: 73 additions & 0 deletions inference-engine/src/mkldnn_plugin/emitters/jit_bf16_emitters.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include "jit_emitter.hpp"

namespace MKLDNNPlugin {

/**
 * @brief JIT emitter that emulates an FP32 -> BF16 down-conversion (the class name
 *        mirrors the vcvtneps2bf16 instruction) using AVX-512 integer and fixup
 *        instructions.
 *
 * The emitter consumes one vector input and requests two auxiliary vector
 * registers. Only avx512_common is handled; any other ISA trips a debug assertion.
 */
class jit_emu_vcvtneps2bf16 : public jit_emitter {
public:
    /**
     * @param host      JIT generator that receives the emitted instructions.
     * @param host_isa  ISA the code is generated for.
     * @param node      Node forwarded unchanged to the base emitter.
     * @param exec_prc  Execution precision forwarded to the base emitter (BF16 by default).
     */
    jit_emu_vcvtneps2bf16(mkldnn::impl::cpu::x64::jit_generator* host, mkldnn::impl::cpu::x64::cpu_isa_t host_isa, const MKLDNNNode* node,
                          InferenceEngine::Precision exec_prc = InferenceEngine::Precision::BF16)
        : jit_emitter(host, host_isa, node, exec_prc) {
        // NOTE(review): presumably this is what pulls in the constants declared by
        // register_table_entries() below -- confirm against jit_emitter::prepare_table().
        prepare_table();
    }

    /// The conversion takes exactly one input vector.
    size_t get_inputs_num() const override { return 1; }

private:
    /**
     * @brief Emits the conversion sequence for the first input/output vector index.
     *
     * Reads in_vec_idxs[0], writes out_vec_idxs[0] (as a Ymm) and clobbers the two
     * auxiliary vectors taken from aux_vec_idxs. The pool_* and emit_context
     * arguments are not used by this emitter.
     */
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs,
                   const std::vector<size_t>& pool_vec_idxs, const std::vector<size_t>& pool_gpr_idxs,
                   const emitter_context *emit_context) const override {
        if (host_isa_ == mkldnn::impl::cpu::x64::cpu_isa_t::avx512_common) {
            Xbyak::Zmm in = Xbyak::Zmm(in_vec_idxs[0]);
            Xbyak::Ymm out = Xbyak::Ymm(out_vec_idxs[0]);
            Xbyak::Zmm aux = Xbyak::Zmm(aux_vec_idxs[0]);
            Xbyak::Zmm aux1 = Xbyak::Zmm(aux_vec_idxs[1]);

            // aux = (in >> 16) & 1: the lowest bit of the upper 16-bit half of each lane.
            h->uni_vpsrld(aux, in, 16);
            h->vpandd(aux, aux, table_val("one"));
            // aux = in + 0x7fff + that bit: rounding bias added before truncation.
            h->uni_vmovups(aux1, table_val("even"));
            h->uni_vpaddd(aux, aux1, aux);
            h->uni_vpaddd(aux, in, aux);
            // Patch special-value lanes (NaN / +-inf) according to the "selector" table entry.
            h->vfixupimmps(aux, in, table_val("selector"), 0);
            // Keep only the upper 16 bits of every dword and narrow dwords to words.
            h->vpsrad(aux, aux, 16);
            h->vpmovdw(out, aux);
        } else {
            // Debug-only guard: with NDEBUG defined nothing is emitted for other ISAs.
            assert(!"unsupported isa");
        }
    }

    /**
     * @brief Places a 4-bit fixup response code into the nibble owned by an input class.
     * @param input   Input class code; selects which nibble of the selector is written.
     * @param output  Response code stored in that nibble.
     * @return The response code shifted into position, ready to be OR-ed with others.
     */
    static int encode_fixup_selector(int input, int output) {
        return ((output) << (4 * (input)));
    }

    /// Declares the constants referenced through table_val() in emit_impl().
    void register_table_entries() override {
        enum {
            fixup_input_code_qnan_ = 0,
            fixup_input_code_snan_ = 1,
            fixup_input_code_ninf_ = 4,
            fixup_input_code_pinf_ = 5,
            fixup_output_code_copy_input_ = 1,
            fixup_output_code_qnan_input_ = 2,
        };
        const int selector_int32 =
            /* snan input to qnan output (preserving input bits 0..21) */
            encode_fixup_selector(fixup_input_code_snan_, fixup_output_code_qnan_input_) |
            /* qnan input to qnan output (preserving input bits 0..21) */
            encode_fixup_selector(fixup_input_code_qnan_, fixup_output_code_qnan_input_) |
            /* neg inf input copied to output */
            encode_fixup_selector(fixup_input_code_ninf_, fixup_output_code_copy_input_) |
            /* pos inf input copied to output */
            encode_fixup_selector(fixup_input_code_pinf_, fixup_output_code_copy_input_);
        // NOTE(review): the trailing `true` looks like a broadcast flag -- confirm in jit_emitter.
        push_arg_entry_of("one", 0x00000001, true);
        push_arg_entry_of("even", 0x00007fff, true);
        push_arg_entry_of("selector", selector_int32, true);
    }

    /// Two scratch vectors are needed: `aux` and `aux1` in emit_impl().
    size_t aux_vecs_count() const override { return 2; }
};

} // namespace MKLDNNPlugin {

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "emitter.h"

#include "jit_emitter.hpp"
#include "utils/general_utils.h"

#include <vector>

using namespace mkldnn::impl::cpu;
Expand Down Expand Up @@ -57,7 +55,7 @@ std::set<InferenceEngine::Precision> jit_emitter::get_supported_precisions() {
}

void jit_emitter::emitter_preamble(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
using namespace Xbyak::util;
bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::vec_to_gpr);
bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::gpr_to_vec);
Expand Down Expand Up @@ -148,7 +146,8 @@ void jit_emitter::emitter_preamble(const std::vector<size_t> &in_idxs, const std
load_table_addr();
}

void jit_emitter::emitter_postamble() {

void jit_emitter::emitter_postamble() const {
using namespace Xbyak::util;

for (size_t i = 0; i < preserved_vec_idxs.size(); ++i)
Expand All @@ -167,9 +166,9 @@ void jit_emitter::emitter_postamble() {
aux_gpr_idxs.clear();
}

void jit_emitter::emit_table() {
void jit_emitter::emit_data() const {
h->align(64);
h->L(l_table);
h->L(*l_table.get());

// Assumption: entries can be inserted with dd, so they should be 4 bytes.
assert(sizeof(table_entry_val_t) == 4);
Expand Down Expand Up @@ -198,18 +197,18 @@ void jit_emitter::prepare_table() {
}
}

void jit_emitter::emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
void jit_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);

emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, nullptr);

emitter_postamble();
}

void jit_emitter::emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::shared_ptr<const emitter_context> &emit_context,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
void jit_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::shared_ptr<const emitter_context> &emit_context,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);

emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, emit_context.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <ie_common.h>
#include <cpu/x64/jit_generator.hpp>

#include "mkldnn_node.h"

#include <set>
Expand All @@ -25,20 +26,26 @@ struct emitter_context {

class jit_emitter {
public:
jit_emitter(mkldnn::impl::cpu::x64::jit_generator* host, mkldnn::impl::cpu::x64::cpu_isa_t host_isa, const MKLDNNNode* node,
jit_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const MKLDNNNode* node,
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec)
: h(host), host_isa_(host_isa), n(node), exec_prc_(exec_prc), in_out_type_(in_out_type) {
: h(host), host_isa_(host_isa), exec_prc_(exec_prc), in_out_type_(in_out_type), l_table (new Xbyak::Label()) {
k_mask = Xbyak::Opmask(1); // FIXME: in general case we need preserve k_mask state as well
}

virtual void emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {});
jit_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& n,
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec)
: h(host), host_isa_(host_isa), exec_prc_(exec_prc), in_out_type_(in_out_type), l_table (new Xbyak::Label()) {
k_mask = Xbyak::Opmask(1); // FIXME: in general case we need preserve k_mask state as well
}

virtual void emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {}) const;
virtual void emit_data() const;

virtual void emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
virtual void emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::shared_ptr<const emitter_context> &emit_context,
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {});
virtual void emit_table();
virtual size_t get_inputs_num() = 0;
virtual size_t get_inputs_num() const = 0;
virtual size_t aux_vecs_count() const;
static std::set<InferenceEngine::Precision> get_supported_precisions();

Expand All @@ -48,17 +55,15 @@ class jit_emitter {
size_t get_max_vecs_count() const;
size_t get_vec_length() const;

const MKLDNNNode* n;
mkldnn::impl::cpu::x64::jit_generator* h;
mkldnn::impl::cpu::x64::cpu_isa_t host_isa_;
InferenceEngine::Precision exec_prc_;

Xbyak::Opmask k_mask;

virtual void prepare_table();
virtual void register_table_entries() {}

void load_table_addr() { h->mov(p_table, l_table); }
void load_table_addr() const { h->mov(p_table, *l_table.get()); }

// we accept only 32bit hexadecimal table values to avoid any rounding
using table_entry_val_t = uint32_t;
Expand All @@ -75,8 +80,8 @@ class jit_emitter {
table_entry_bcast_t bcast;
};

Xbyak::Reg64 p_table;
Xbyak::Label l_table;
mutable Xbyak::Reg64 p_table;
mutable std::shared_ptr<Xbyak::Label> l_table;

enum {
_cmp_eq_oq = mkldnn::impl::cpu::x64::jit_generator::_cmp_eq_oq,
Expand All @@ -89,16 +94,16 @@ class jit_emitter {

virtual void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) = 0;
const emitter_context *emit_context) const = 0;

virtual void emitter_preamble(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs);
virtual void emitter_postamble();
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const;
virtual void emitter_postamble() const;

emitter_in_out_map in_out_type_;

std::vector<size_t> aux_vec_idxs;
std::vector<size_t> aux_gpr_idxs;
mutable std::vector<size_t> aux_vec_idxs;
mutable std::vector<size_t> aux_gpr_idxs;

static constexpr int k_mask_size = 8;

Expand Down Expand Up @@ -126,8 +131,8 @@ class jit_emitter {
}

private:
std::vector<size_t> preserved_vec_idxs;
std::vector<size_t> preserved_gpr_idxs;
mutable std::vector<size_t> preserved_vec_idxs;
mutable std::vector<size_t> preserved_gpr_idxs;

void push_vec(const Xbyak::Address &addr, size_t vec_idx) const;
void pop_vec(size_t vec_idx, const Xbyak::Address &addr) const;
Expand Down
Loading

0 comments on commit 434e66d

Please sign in to comment.