Skip to content

Commit

Permalink
an overloaded emit() and size_t index
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Feb 1, 2021
1 parent 17195f4 commit 04665a6
Show file tree
Hide file tree
Showing 16 changed files with 253 additions and 232 deletions.
43 changes: 28 additions & 15 deletions inference-engine/src/mkldnn_plugin/nodes/common/emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ size_t jit_emitter::get_vec_length() const {
one_of(host_isa_, cpu::x64::avx2) ? 32 : 16;
}

void jit_emitter::push_vec(const Xbyak::Address &addr, int vec_idx) const {
void jit_emitter::push_vec(const Xbyak::Address &addr, size_t vec_idx) const {
if (host_isa_ == cpu::x64::sse41) {
h->uni_vmovups(addr, Xmm(vec_idx));
} else if (host_isa_ == cpu::x64::avx2) {
Expand All @@ -33,7 +33,7 @@ void jit_emitter::push_vec(const Xbyak::Address &addr, int vec_idx) const {
}
}

void jit_emitter::pop_vec(int vec_idx, const Xbyak::Address &addr) const {
void jit_emitter::pop_vec(size_t vec_idx, const Xbyak::Address &addr) const {
if (host_isa_ == cpu::x64::sse41) {
h->uni_vmovups(Xmm(vec_idx), addr);
} else if (host_isa_ == cpu::x64::avx2) {
Expand All @@ -56,8 +56,8 @@ std::set<InferenceEngine::Precision> jit_emitter::get_supported_precisions() {
return {InferenceEngine::Precision::FP32};
}

void jit_emitter::emitter_preamble(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs) {
void jit_emitter::emitter_preamble(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
using namespace Xbyak::util;
bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::vec_to_gpr);
bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::gpr_to_vec);
Expand All @@ -67,7 +67,7 @@ void jit_emitter::emitter_preamble(const std::vector<int> &in_idxs, const std::v

// For sse41 mask register has to be Xmm(0)
if (host_isa_ == cpu::x64::sse41 && aux_vecs_count() > 0) {
int idx = 0;
size_t idx = 0;
if (is_vec_input)
assert(std::find(in_idxs.begin(), in_idxs.end(), idx) == in_idxs.end());
if (is_vec_output)
Expand All @@ -88,7 +88,7 @@ void jit_emitter::emitter_preamble(const std::vector<int> &in_idxs, const std::v
}
}

for (int idx = 0; idx < get_max_vecs_count(); idx++) {
for (size_t idx = 0; idx < get_max_vecs_count(); idx++) {
if (aux_vec_idxs.size() >= aux_vecs_count()) break;

if (is_vec_input) {
Expand All @@ -109,8 +109,8 @@ void jit_emitter::emitter_preamble(const std::vector<int> &in_idxs, const std::v
for (auto idx : pool_gpr_idxs)
aux_gpr_idxs.push_back(idx);

for (int gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) {
int _idx = Operand::R15 - gpr_idx; // we allocate from the end
for (size_t gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) {
size_t _idx = Operand::R15 - gpr_idx; // we allocate from the end

if (aux_gpr_idxs.size() >= aux_gprs_count()) break;
if (_idx == Operand::RSP) continue;
Expand All @@ -134,13 +134,13 @@ void jit_emitter::emitter_preamble(const std::vector<int> &in_idxs, const std::v
aux_gpr_idxs.erase(aux_gpr_idxs.end() - 1);
}

for (int i = 0; i < preserved_gpr_idxs.size(); ++i)
for (size_t i = 0; i < preserved_gpr_idxs.size(); ++i)
h->push(Reg64(preserved_gpr_idxs[i]));

if (preserved_vec_idxs.size())
h->sub(h->rsp, preserved_vec_idxs.size() * get_vec_length());

for (int i = 0; i < preserved_vec_idxs.size(); ++i) {
for (size_t i = 0; i < preserved_vec_idxs.size(); ++i) {
push_vec(h->ptr[h->rsp + i * get_vec_length()], preserved_vec_idxs[i]);
}

Expand All @@ -151,7 +151,7 @@ void jit_emitter::emitter_preamble(const std::vector<int> &in_idxs, const std::v
void jit_emitter::emitter_postamble() {
using namespace Xbyak::util;

for (int i = 0; i < preserved_vec_idxs.size(); ++i)
for (size_t i = 0; i < preserved_vec_idxs.size(); ++i)
pop_vec(preserved_vec_idxs[i], h->ptr[h->rsp + i * get_vec_length()]);

if (preserved_vec_idxs.size())
Expand Down Expand Up @@ -198,12 +198,25 @@ void jit_emitter::prepare_table() {
}
}

void jit_emitter::emit(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs,
const std::shared_ptr<const emitter_context> &emit_context) {
void jit_emitter::emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);

emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, emit_context.get());
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);

emitter_postamble();
}

void jit_emitter::emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::shared_ptr<const emitter_context> &emit_context,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);

if (emit_context) {
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, emit_context.get());
} else {
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
}

emitter_postamble();
}
Expand Down
32 changes: 19 additions & 13 deletions inference-engine/src/mkldnn_plugin/nodes/common/emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ class jit_emitter {
k_mask = Xbyak::Opmask(1); // FIXME: in general case we need preserve k_mask state as well
}

virtual void emit(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs = {}, const std::vector<int> &pool_gpr_idxs = {},
const std::shared_ptr<const emitter_context> &emit_context = nullptr);
virtual void emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {});

virtual void emit(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::shared_ptr<const emitter_context> &emit_context,
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {});
virtual void emit_table();
virtual size_t get_inputs_num() = 0;
virtual size_t aux_vecs_count() const;
Expand Down Expand Up @@ -84,18 +87,21 @@ class jit_emitter {
_cmp_gt_os = mkldnn::impl::cpu::x64::jit_generator::_cmp_nle_us,
};

virtual void emit_impl(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs,
virtual void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {}

virtual void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) {}

virtual void emitter_preamble(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs);
virtual void emitter_preamble(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs);
virtual void emitter_postamble();

emitter_in_out_map in_out_type_;

std::vector<int> aux_vec_idxs;
std::vector<int> aux_gpr_idxs;
std::vector<size_t> aux_vec_idxs;
std::vector<size_t> aux_gpr_idxs;

static constexpr int k_mask_size = 8;

Expand Down Expand Up @@ -123,11 +129,11 @@ class jit_emitter {
}

private:
std::vector<int> preserved_vec_idxs;
std::vector<int> preserved_gpr_idxs;
std::vector<size_t> preserved_vec_idxs;
std::vector<size_t> preserved_gpr_idxs;

void push_vec(const Xbyak::Address &addr, int vec_idx) const;
void pop_vec(int vec_idx, const Xbyak::Address &addr) const;
void push_vec(const Xbyak::Address &addr, size_t vec_idx) const;
void pop_vec(size_t vec_idx, const Xbyak::Address &addr) const;

size_t table_off(std::string& key, size_t key_off_val_shift = 0) const {
// assumption: all table entries sharing the same key also
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ size_t jit_load_emitter::aux_gprs_count() const {
return 2;
}

void jit_load_emitter::emit_impl(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs,
void jit_load_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) {
const auto* load_emitter_context = dynamic_cast<const MKLDNNPlugin::load_emitter_context*>(emit_context);
if (load_emitter_context == nullptr) {
THROW_IE_EXCEPTION << "Load emitter in " << n->getName() << " does not get load emmiter context.";
}

if (host_isa_ == cpu::x64::sse41) {
emit_isa<cpu::x64::sse41>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, out_idxs[0],
emit_isa<cpu::x64::sse41>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast<int>(out_idxs[0]),
load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_);
} else if (host_isa_ == cpu::x64::avx2) {
emit_isa<cpu::x64::avx2>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, out_idxs[0],
emit_isa<cpu::x64::avx2>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast<int>(out_idxs[0]),
load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_);
} else if (host_isa_ == cpu::x64::avx512_common) {
emit_isa<cpu::x64::avx512_common>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, out_idxs[0],
emit_isa<cpu::x64::avx512_common>(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast<int>(out_idxs[0]),
load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_);
} else {
THROW_IE_EXCEPTION << "Load emitter in " << n->getName() << " is performed on unsupported isa(at least x64::sse41).";
Expand Down Expand Up @@ -510,21 +510,21 @@ size_t jit_store_emitter::aux_vecs_count() const {

size_t jit_store_emitter::get_inputs_num() { return 1; }

void jit_store_emitter::emit_impl(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs,
void jit_store_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) {
const auto* store_emitter_context = dynamic_cast<const MKLDNNPlugin::store_emitter_context*>(emit_context);
if (store_emitter_context == nullptr) {
THROW_IE_EXCEPTION << "Store emitter in " << n->getName() << " does not get store emmiter context.";
}
if (host_isa_ == cpu::x64::sse41) {
emit_isa<cpu::x64::sse41>(in_idxs[0], store_emitter_context->src_prc_, Reg64(out_idxs[0]),
emit_isa<cpu::x64::sse41>(static_cast<int>(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]),
store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_);
} else if (host_isa_ == cpu::x64::avx2) {
emit_isa<cpu::x64::avx2>(in_idxs[0], store_emitter_context->src_prc_, Reg64(out_idxs[0]),
emit_isa<cpu::x64::avx2>(static_cast<int>(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]),
store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_);
} else if (host_isa_ == cpu::x64::avx512_common) {
emit_isa<cpu::x64::avx512_common>(in_idxs[0], store_emitter_context->src_prc_, Reg64(out_idxs[0]),
emit_isa<cpu::x64::avx512_common>(static_cast<int>(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]),
store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_);
} else {
THROW_IE_EXCEPTION << "Store emitter in " << n->getName() << " is performed on unsupported isa(at least x64::sse41).";
Expand Down Expand Up @@ -829,7 +829,7 @@ template <typename Vmm>
if (mayiuse(cpu::x64::avx512_core_bf16)) {
h->vcvtneps2bf16(ymm, zmm);
} else {
emu_vcvtneps2bf16->emit({vmm.getIdx()}, {ymm.getIdx()});
emu_vcvtneps2bf16->emit({static_cast<size_t>(vmm.getIdx())}, {static_cast<size_t>(ymm.getIdx())});
}
if (store_num == 16) {
h->vmovdqu16(ptr[reg + offset], ymm);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class jit_load_emitter : public jit_emitter {
* \|/
* dst_prc
*/
void emit_impl(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs,
void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) override;

size_t get_inputs_num() override;
Expand Down Expand Up @@ -117,8 +117,8 @@ class jit_store_emitter : public jit_emitter {
* dst_prc
* note: FP32/I32-->BF16(x*) is supported only on at least avx512-core plateform
*/
void emit_impl(const std::vector<int> &in_idxs, const std::vector<int> &out_idxs,
const std::vector<int> &pool_vec_idxs, const std::vector<int> &pool_gpr_idxs,
void emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
const emitter_context *emit_context) override;

size_t get_inputs_num() override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ struct jit_uni_softmax_kernel_f32 : public jit_uni_softmax_kernel, public jit_ge
if (mayiuse(avx512_core_bf16))
vcvtneps2bf16(ymm_dst, vmm_dst);
else
emu_vcvtneps2bf16->emit({vmm_dst.getIdx()}, {ymm_dst.getIdx()});
emu_vcvtneps2bf16->emit({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(ymm_dst.getIdx())});
vmovdqu16(op, ymm_dst);
break;
default:
Expand Down
Loading

0 comments on commit 04665a6

Please sign in to comment.