diff --git a/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp index 007b7f0fefd887..f0f460d51713a5 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp @@ -23,6 +23,11 @@ enum emitter_in_out_map { gpr_to_gpr, }; +// structure for storage of emitter parameters to hash in map +struct emitter_params { + virtual size_t hash() const = 0; +}; + struct emitter_context { virtual ~emitter_context() = default; }; diff --git a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp index 7604c9a82fbc43..0cd6a88fdac937 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp @@ -18,95 +18,116 @@ using namespace Xbyak::util; namespace ov { namespace intel_cpu { +size_t load_emitter_params::hash() const { + size_t seed = 0; + seed = hash_combine(seed, std::string("jit_load_emitter")); + seed = hash_combine(seed, src_prc_.getPrecVal()); + seed = hash_combine(seed, dst_prc_.getPrecVal()); + seed = hash_combine(seed, load_num_); + seed = hash_combine(seed, is_fill_); + seed = hash_combine(seed, fill_value_); + return seed; +} + +size_t store_emitter_params::hash() const { + size_t seed = 0; + seed = hash_combine(seed, std::string("jit_store_emitter")); + seed = hash_combine(seed, src_prc_.getPrecVal()); + seed = hash_combine(seed, dst_prc_.getPrecVal()); + seed = hash_combine(seed, store_num_); + return seed; +} + /// LOAD /// -jit_load_emitter::jit_load_emitter(jit_generator *host, cpu_isa_t host_isa, - Precision exec_prc, emitter_in_out_map in_out_type) -: jit_emitter(host, host_isa, exec_prc, in_out_type), name("unknown") { +jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + int load_num, Precision src_prc, Precision dst_prc, 
Precision exec_prc, + bool is_fill, std::string fill_value, emitter_in_out_map in_out_type) +: jit_emitter(host, host_isa, exec_prc, in_out_type), load_num_(load_num), src_prc_(src_prc), dst_prc_(dst_prc), + is_fill_(is_fill), fill_value_(fill_value), name_("unknown") { prepare_table(); - v_len_elt = get_vec_length() / exec_prc.size(); + load_size_ = load_num * src_prc.size(); + v_len_elt_ = get_vec_length() / exec_prc.size(); } size_t jit_load_emitter::get_inputs_num() const { return 1; } -// 0 for temp reg for mask load, 1 for table address size_t jit_load_emitter::aux_gprs_count() const { - return 2; + // 0 for temp reg for mask load in avx512, 1 for table address + if (mayiuse(cpu::x64::avx512_core) && is_fill_) + return 2; + // 0 for temp reg for only mask load in avx512 or for table in sse and avx2 + else if ((mayiuse(cpu::x64::avx512_core) && !one_of(load_num_, 16, 8, 4)) || is_fill_) + return 1; + else + return 0; } void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, - const emitter_context *emit_context) const { - const auto* load_emitter_context = dynamic_cast(emit_context); - if (load_emitter_context == nullptr) { - IE_THROW() << "Load emitter in " << name << " does not get load emmiter context."; - } - + const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, + const emitter_context *emit_context) const { + const int offset = in_idxs.size() == 2 ? 
in_idxs[1] : 0; if (host_isa_ == cpu::x64::sse41) { - emit_isa(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast(out_idxs[0]), - load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_); + emit_isa(Reg64(in_idxs[0]), static_cast(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx2) { - emit_isa(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast(out_idxs[0]), - load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_); + emit_isa(Reg64(in_idxs[0]), static_cast(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx512_core) { - emit_isa(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast(out_idxs[0]), - load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_); + emit_isa(Reg64(in_idxs[0]), static_cast(out_idxs[0]), offset); } else { - IE_THROW() << "Load emitter in " << name << " is performed on unsupported isa(at least x64::sse41)."; + IE_THROW() << "Load emitter in " << name_ << " is performed on unsupported isa(at least x64::sse41)."; } } template -void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, int offset_byte, InferenceEngine::Precision src_prc, - const int out_vec_idx, InferenceEngine::Precision dst_prc, int load_num, bool is_fill, std::string fill_value) const { - bool matched_prc = (dst_prc == src_prc) || (dst_prc == Precision::FP32) || (dst_prc == Precision::I32); +void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_idx, const int offset) const { + bool matched_prc = (dst_prc_ == src_prc_) || (dst_prc_ == Precision::FP32) || (dst_prc_ == Precision::I32); if (!matched_prc) { - IE_THROW() << "Load emitter in " << name << " only support output 
precision of FP32 or I32 or the same precision as input."; + IE_THROW() << "Load emitter in " << name_ << " only support output precision of FP32 or I32 or the same precision as input."; } - if (load_num > (get_vec_length() / dst_prc.size())) { - IE_THROW() << "Load emitter in " << name << " have unexpected number of elements to load."; + if (load_num_ > (get_vec_length() / dst_prc_.size())) { + IE_THROW() << "Load emitter in " << name_ << " have unexpected number of elements to load."; } using Vmm = typename conditional3::type; // pure load - if (src_prc == dst_prc) { - load_bytes(Vmm(out_vec_idx), reg_src, offset_byte, load_num * src_prc.size(), is_fill, fill_value); + if (src_prc_ == dst_prc_) { + load_bytes(Vmm(out_vec_idx), reg_src, offset, load_size_); } else { // "pure load" + convert. dst_prc must be FP32 or I32. - switch (src_prc) { + switch (src_prc_) { case Precision::FP32: case Precision::I32: - load_bytes(Vmm(out_vec_idx), reg_src, offset_byte, load_num * src_prc.size(), is_fill, fill_value); + load_bytes(Vmm(out_vec_idx), reg_src, offset, load_size_); break; case Precision::I8: - load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, true, load_num * src_prc.size(), is_fill, fill_value); + load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, true, load_size_); break; case Precision::U8: - load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, false, load_num * src_prc.size(), is_fill, fill_value); + load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, load_size_); break; case Precision::I16: - load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, false, true, load_num * src_prc.size(), is_fill, fill_value); + load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, true, load_size_); break; case Precision::U16: - load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, false, false, load_num * src_prc.size(), is_fill, fill_value); + 
load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, false, load_size_); break; case Precision::BF16: - load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, true, false, load_num * src_prc.size(), is_fill, fill_value); + load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, true, false, load_size_); break; default: - IE_THROW() << "Load emitter in " << name << " has unsupported src precision to load."; + IE_THROW() << "Load emitter in " << name_ << " has unsupported src precision to load."; } } // post convert between I32 and FP32 - if (src_prc != dst_prc) { - switch (dst_prc) { + if (src_prc_ != dst_prc_) { + switch (dst_prc_) { case Precision::FP32: - if ((src_prc != Precision::FP32) && (src_prc != Precision::BF16)) + if ((src_prc_ != Precision::FP32) && (src_prc_ != Precision::BF16)) h->uni_vcvtdq2ps(Vmm(out_vec_idx), Vmm(out_vec_idx)); break; case Precision::I32: - if ((src_prc == Precision::FP32) || (src_prc == Precision::BF16)) { + if ((src_prc_ == Precision::FP32) || (src_prc_ == Precision::BF16)) { h->uni_vcvtps2dq(Vmm(out_vec_idx), Vmm(out_vec_idx)); } break; @@ -129,8 +150,7 @@ void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, int offset_byte, In * */ template -void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size, - bool is_fill, std::string fill_value) const { +void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -141,12 +161,12 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o // Ensure data fits completely inside the Xmm/Ymm/Zmm register if (load_size < 0 || load_size > 64) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has 
unexpected number of values to load in load_byte."; // check if proper number bytes fit inside the Xmm/Ymm register if (is_ymm && load_size > 32) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_byte."; if (is_xmm && load_size > 16) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_byte."; auto xmm = Xbyak::Xmm(vmm.getIdx()); auto ymm = Xbyak::Ymm(vmm.getIdx()); @@ -229,7 +249,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o break; case 16: break; default: - IE_THROW() << "Load emitter in " << name<< " has unexpected number of values to load in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_byte."; } if (has_xmm_block) { @@ -270,8 +290,8 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o } } - if (is_fill) - fill_with_default(vmm, fill_value, load_size / 4); + if (is_fill_) + fill_with_default(vmm, fill_value_, load_size / 4); } /** @@ -294,8 +314,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o */ template -void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, - int offset, bool is_signed, int load_size, bool is_fill, std::string fill_value) const { +void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -308,11 +327,11 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak // For Ymm register, 
load capacity is halved (32 * load_size <= 256) // For Xmm register, load capacity is halved further (32 * load_size <= 128) if (load_size < 0 || load_size > 16) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_bytes_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_bytes_to_dword_extension."; if (is_ymm && load_size > 8) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_bytes_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_bytes_to_dword_extension."; if (is_xmm && load_size > 4) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_bytes_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_bytes_to_dword_extension."; // For load_size == 4/8/16, do load/extension in one go switch (load_size) { @@ -365,8 +384,8 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak } } - if (is_fill) - fill_with_default(vmm, fill_value, load_size); + if (is_fill_) + fill_with_default(vmm, fill_value_, load_size); } /** @@ -388,8 +407,7 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak * [0.. 32] for ZMM version of the function. i.e. 
16 words -> 16 * 32 bit == 512 bit */ template -void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, - int offset, bool is_bf16, bool is_signed, int load_size, bool is_fill, std::string fill_value) const { +void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -402,11 +420,11 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak // For Ymm register, load capacity is halved (16/2(num) * 32 <= 128) // For Xmm register, load capacity is halved again (8/2(num) * 32 <= 128) if (load_size < 0 || load_size > 32) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_words_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_words_to_dword_extension."; if (is_ymm && load_size > 16) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_words_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_words_to_dword_extension."; if (is_xmm && load_size > 8) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_words_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_words_to_dword_extension."; auto xmm = Xbyak::Xmm(vmm.getIdx()); auto ymm = Xbyak::Ymm(vmm.getIdx()); @@ -483,148 +501,152 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak } } - if (is_fill) - fill_with_default(vmm, fill_value, load_size / 2); + if (is_fill_) + fill_with_default(vmm, fill_value_, load_size / 2); } template - void 
jit_load_emitter::fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const { - constexpr bool is_xmm = std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - if (is_xmm || is_ymm) { - uint8 imm = 1; - imm = ~((imm << load_num) - imm); // shift load_num bit - h->uni_vblendps(vmm, vmm, table_val(fill_value), imm); - } else if (is_zmm) { - uint64_t tail_mask = 1; - tail_mask = ~((tail_mask << load_num) - tail_mask); - h->mov(Reg64(aux_gpr_idxs[0]), tail_mask); - h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); - h->vblendmps(vmm | k_mask, vmm, table_val(fill_value)); - } +void jit_load_emitter::fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; + + if (is_xmm || is_ymm) { + uint8 imm = 1; + imm = ~((imm << load_num) - imm); // shift load_num bit + h->uni_vblendps(vmm, vmm, table_val(fill_value), imm); + } else if (is_zmm) { + uint64_t tail_mask = 1; + tail_mask = ~((tail_mask << load_num) - tail_mask); + h->mov(Reg64(aux_gpr_idxs[0]), tail_mask); + h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); + h->vblendmps(vmm | k_mask, vmm, table_val(fill_value)); } +} void jit_load_emitter::register_table_entries() { - push_arg_entry_of("zero", 0x00000000, true); - push_arg_entry_of("int_one", 0x00000001, true); - push_arg_entry_of("float_one", 0x3f800000, true); - push_arg_entry_of("int32_min", 0xcf000000, true); - push_arg_entry_of("float_min", 0xff7fffff, true); - push_arg_entry_of("int32_max", 0x4effffff, true); - push_arg_entry_of("float_max", 0x7f7fffff, true); + if (is_fill_) { + push_arg_entry_of("zero", 0x00000000, true); + push_arg_entry_of("int_one", 0x00000001, true); + push_arg_entry_of("float_one", 0x3f800000, true); + push_arg_entry_of("int32_min", 0xcf000000, true); + push_arg_entry_of("float_min", 0xff7fffff, 
true); + push_arg_entry_of("int32_max", 0x4effffff, true); + push_arg_entry_of("float_max", 0x7f7fffff, true); + } } /// STORE /// -jit_store_emitter::jit_store_emitter(jit_generator *host, cpu_isa_t host_isa, - Precision exec_prc, emitter_in_out_map in_out_type) -: jit_emitter(host, host_isa, exec_prc, in_out_type), name("unknown") { - v_len_elt = get_vec_length() / exec_prc.size(); +jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, int store_num, + Precision src_prc, Precision dst_prc, Precision exec_prc, emitter_in_out_map in_out_type) +: jit_emitter(host, host_isa, exec_prc, in_out_type), store_num_(store_num), src_prc_(src_prc), dst_prc_(dst_prc), name_("unknown") { + v_len_elt_ = get_vec_length() / exec_prc.size(); + store_size_ = store_num * dst_prc.size(); if (!mayiuse(cpu::x64::avx512_core_bf16) && mayiuse(cpu::x64::avx512_core)) { - emu_vcvtneps2bf16.reset(new jit_emu_vcvtneps2bf16(host, host_isa)); + emu_vcvtneps2bf16_.reset(new jit_emu_vcvtneps2bf16(host, host_isa)); } } -// 0 for temp reg for mask store +// 0 for temp reg for mask store for avx512 size_t jit_store_emitter::aux_gprs_count() const { - return 1; + return mayiuse(cpu::x64::avx512_core) && !one_of(store_num_, 16, 8, 4) ? 
1 : 0; } -// zero value, zeroed and passed from caller from performance standpoint(zeroed one time and not need preserve and restore status) size_t jit_store_emitter::aux_vecs_count() const { - return 1; + int count = 0; + // to avoid src vmm pollution after data type conversion + if ((src_prc_.is_float() && !dst_prc_.is_float()) || (!src_prc_.is_float() && dst_prc_.is_float())) + count++; + // zero value, zeroed and passed from caller from performance standpoint(zeroed one time and not need preserve and restore status) + if ((mayiuse(cpu::x64::avx512_core) && (!one_of(store_num_, 16, 8, 4)) || (one_of(dst_prc_, Precision::U8, Precision::U16)))) + count++; + return count; } size_t jit_store_emitter::get_inputs_num() const { return 1; } void jit_store_emitter::emit_data() const { - if (emu_vcvtneps2bf16) - emu_vcvtneps2bf16->emit_data(); + if (emu_vcvtneps2bf16_) + emu_vcvtneps2bf16_->emit_data(); } void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, const emitter_context *emit_context) const { - const auto* store_emitter_context = dynamic_cast(emit_context); - if (store_emitter_context == nullptr) { - IE_THROW() << "Store emitter in " << name << " does not get store emmiter context."; - } + const int offset = in_idxs.size() == 2 ? 
in_idxs[1] : 0; if (host_isa_ == cpu::x64::sse41) { - emit_isa(static_cast(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]), - store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_); + emit_isa(static_cast(in_idxs[0]), Reg64(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx2) { - emit_isa(static_cast(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]), - store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_); + emit_isa(static_cast(in_idxs[0]), Reg64(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx512_core) { - emit_isa(static_cast(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]), - store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_); + emit_isa(static_cast(in_idxs[0]), Reg64(out_idxs[0]), offset); } else { - IE_THROW() << "Store emitter in " << name << " is performed on unsupported isa(at least x64::sse41)."; + IE_THROW() << "Store emitter in " << name_ << " is performed on unsupported isa(at least x64::sse41)."; } } template - void jit_store_emitter::emit_isa(const int in_vec_idx, InferenceEngine::Precision src_prc, - const Xbyak::Reg64 ®_dst, int offset_byte, InferenceEngine::Precision dst_prc, int store_num) const { - bool matched_prc = (src_prc == dst_prc) || (src_prc == Precision::FP32) || (src_prc == Precision::I32); - if (!matched_prc) { - IE_THROW() << "Store emitter in " << name << " only support input precision of FP32 or I32 or the same precision as output."; - } - if ((src_prc == Precision::FP32) || (src_prc == Precision::I32)) { - if ((isa == cpu::x64::sse41 && store_num > 4) || (isa == cpu::x64::avx2 && store_num > 8) || - (isa == cpu::x64::avx512_core && store_num > 16) || store_num < 0) { - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store."; - } +void jit_store_emitter::emit_isa(const int 
in_vec_idx, const Xbyak::Reg64 ®_dst, const int offset) const { + bool matched_prc = (src_prc_ == dst_prc_) || (src_prc_ == Precision::FP32) || (src_prc_ == Precision::I32); + if (!matched_prc) { + IE_THROW() << "Store emitter in " << name_ << " only support input precision of FP32 or I32 or the same precision as output."; + } + if ((src_prc_ == Precision::FP32) || (src_prc_ == Precision::I32)) { + if ((isa == cpu::x64::sse41 && store_num_ > 4) || (isa == cpu::x64::avx2 && store_num_ > 8) || + (isa == cpu::x64::avx512_core && store_num_ > 16) || store_num_ < 0) { + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store."; } + } + using Vmm = typename conditional3::type; - using Vmm = typename conditional3::type; - - if (src_prc != dst_prc) { - switch (src_prc) { - case Precision::FP32: - if ((dst_prc != Precision::FP32) && (dst_prc != Precision::BF16)) { - h->uni_vcvtps2dq(Vmm(in_vec_idx), Vmm(in_vec_idx)); - } - break; - case Precision::I32: - if ((dst_prc == Precision::FP32) || (dst_prc == Precision::BF16)) - h->uni_vcvtdq2ps(Vmm(in_vec_idx), Vmm(in_vec_idx)); - break; - default: - break; - } + int data_idx = in_vec_idx; + if (src_prc_ != dst_prc_) { + switch (src_prc_) { + case Precision::FP32: + if ((dst_prc_ != Precision::FP32) && (dst_prc_ != Precision::BF16)) { + h->uni_vcvtps2dq(Vmm(aux_vec_idxs.back()), Vmm(data_idx)); + data_idx = aux_vec_idxs.back(); + } + break; + case Precision::I32: + if ((dst_prc_ == Precision::FP32) || (dst_prc_ == Precision::BF16)) { + h->uni_vcvtdq2ps(Vmm(aux_vec_idxs.back()), Vmm(data_idx)); + data_idx = aux_vec_idxs.back(); + } + break; + default: + break; } + } - if (src_prc == dst_prc) { - store_bytes(Vmm(in_vec_idx), reg_dst, offset_byte, store_num * dst_prc.size()); - } else { - switch (dst_prc) { - case Precision::FP32: - case Precision::I32: - store_bytes(Vmm(in_vec_idx), reg_dst, offset_byte, store_num * dst_prc.size()); - break; - case Precision::I8: - 
store_dword_to_byte_extension(Vmm(in_vec_idx), reg_dst, offset_byte, true, store_num); - break; - case Precision::U8: - store_dword_to_byte_extension(Vmm(in_vec_idx), reg_dst, offset_byte, false, store_num); - break; - case Precision::I16: - store_dword_to_word_extension(Vmm(in_vec_idx), reg_dst, offset_byte, false, true, store_num); - break; - case Precision::U16: - store_dword_to_word_extension(Vmm(in_vec_idx), reg_dst, offset_byte, false, false, store_num); - break; - case Precision::BF16: - store_dword_to_word_extension(Vmm(in_vec_idx), reg_dst, offset_byte, true, false, store_num); - break; - default: - IE_THROW() << "Store emitter in " << name << " has unsupported dst precision to store."; - } + if (src_prc_ == dst_prc_) { + store_bytes(Vmm(data_idx), reg_dst, offset, store_size_); + } else { + switch (dst_prc_) { + case Precision::FP32: + case Precision::I32: + store_bytes(Vmm(data_idx), reg_dst, offset, store_size_); + break; + case Precision::I8: + store_dword_to_byte_extension(Vmm(data_idx), reg_dst, offset, true, store_num_); + break; + case Precision::U8: + store_dword_to_byte_extension(Vmm(data_idx), reg_dst, offset, false, store_num_); + break; + case Precision::I16: + store_dword_to_word_extension(Vmm(data_idx), reg_dst, offset, false, true, store_num_); + break; + case Precision::U16: + store_dword_to_word_extension(Vmm(data_idx), reg_dst, offset, false, false, store_num_); + break; + case Precision::BF16: + store_dword_to_word_extension(Vmm(data_idx), reg_dst, offset, true, false, store_num_); + break; + default: + IE_THROW() << "Store emitter in " << name_ << " has unsupported dst precision to store."; } } - +} /** * store_bytes is the utility function to facilitate storing of * store_size (0 <= store_size <= 64) many contiguous bytes from the Xmm/Ymm/Zmm @@ -641,130 +663,130 @@ template * */ template - void jit_store_emitter::store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const { - constexpr bool is_xmm = 
std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - MAYBE_UNUSED(is_xmm); - MAYBE_UNUSED(is_ymm); - MAYBE_UNUSED(is_zmm); - - // Ensure data fits completely inside the Xmm/Ymm/Zmm register - if (store_size < 0 || store_size > 64) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_bytes."; - if (is_ymm && store_size > 32) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_bytes."; - if (is_xmm && store_size > 16) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_bytes."; - - auto xmm = Xbyak::Xmm(vmm.getIdx()); - auto ymm = Xbyak::Ymm(vmm.getIdx()); - auto zmm = Xbyak::Zmm(vmm.getIdx()); - - const auto addr = [&](int bytes_offset) { - return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; - }; - - auto store_byte_base = [&]() { - int start_bytes = 0; - int bytes_to_store = store_size; - - if (store_size > 32) { - h->uni_vmovdqu(addr(0), ymm); // store lower bits from zmm - start_bytes += 32; - bytes_to_store -= 32; - h->vextractf64x4(ymm, zmm, 1); // load upper bits from zmm into ymm - } +void jit_store_emitter::store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; - if (bytes_to_store > 16) { - h->uni_vmovdqu(addr(start_bytes), xmm); // store lower bits from ymm - start_bytes += 16; - bytes_to_store -= 16; - h->vextractf128(xmm, ymm, 1); // load upper bits from ymm into xmm - } + MAYBE_UNUSED(is_xmm); + MAYBE_UNUSED(is_ymm); + MAYBE_UNUSED(is_zmm); - if (bytes_to_store >= 8 && bytes_to_store < 16) - h->uni_vmovq(addr(start_bytes), xmm); - // h->pextrq(addr(start_bytes), xmm, 0); - else if (bytes_to_store == 16) - h->uni_vmovdqu(addr(start_bytes), xmm); - - // 64/32/16/8 
with one go - // tail 7 bytes for lower or upper xmm - switch (bytes_to_store) { - case 0: break; - case 1: h->uni_vpextrb(addr(start_bytes), xmm, 0); break; - case 2: h->uni_vpextrw(addr(start_bytes), xmm, 0); break; - case 3: - h->uni_vpextrw(addr(start_bytes), xmm, 0); - h->uni_vpextrb(addr(start_bytes + 2), xmm, 2); - break; - case 4: h->uni_vmovss(addr(start_bytes), xmm); break; - // h->uni_vpextrd(addr(start_bytes), xmm, 0); break; - case 5: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrb(addr(start_bytes + 4), xmm, 4); - break; - case 6: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); - break; - case 7: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); - h->uni_vpextrb(addr(start_bytes + 6), xmm, 6); - break; - case 8: break; - case 9: h->uni_vpextrb(addr(start_bytes + 8), xmm, 8); break; - case 10: h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); break; - case 11: - h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); - h->uni_vpextrb(addr(start_bytes + 10), xmm, 10); - break; - case 12: h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); break; - case 13: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrb(addr(start_bytes + 12), xmm, 12); - break; - case 14: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); - break; - case 15: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); - h->uni_vpextrb(addr(start_bytes + 14), xmm, 14); - break; - case 16: break; - default: - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_bytes."; - } - }; + // Ensure data fits completely inside the Xmm/Ymm/Zmm register + if (store_size < 0 || store_size > 64) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_bytes."; + if (is_ymm && store_size > 32) + IE_THROW() << "Store emitter in " << name_ << 
" has unexpected number of values to store to ymm in store_bytes."; + if (is_xmm && store_size > 16) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_bytes."; + + auto xmm = Xbyak::Xmm(vmm.getIdx()); + auto ymm = Xbyak::Ymm(vmm.getIdx()); + auto zmm = Xbyak::Zmm(vmm.getIdx()); + + const auto addr = [&](int bytes_offset) { + return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; + }; + + auto store_byte_base = [&]() { + int start_bytes = 0; + int bytes_to_store = store_size; + + if (store_size > 32) { + h->uni_vmovdqu(addr(0), ymm); // store lower bits from zmm + start_bytes += 32; + bytes_to_store -= 32; + h->vextractf64x4(ymm, zmm, 1); // load upper bits from zmm into ymm + } + + if (bytes_to_store > 16) { + h->uni_vmovdqu(addr(start_bytes), xmm); // store lower bits from ymm + start_bytes += 16; + bytes_to_store -= 16; + h->vextractf128(xmm, ymm, 1); // load upper bits from ymm into xmm + } - switch (store_size) { - case 64: - h->uni_vmovdqu(addr(0), zmm); + if (bytes_to_store >= 8 && bytes_to_store < 16) + h->uni_vmovq(addr(start_bytes), xmm); + // h->pextrq(addr(start_bytes), xmm, 0); + else if (bytes_to_store == 16) + h->uni_vmovdqu(addr(start_bytes), xmm); + + // 64/32/16/8 with one go + // tail 7 bytes for lower or upper xmm + switch (bytes_to_store) { + case 0: break; + case 1: h->uni_vpextrb(addr(start_bytes), xmm, 0); break; + case 2: h->uni_vpextrw(addr(start_bytes), xmm, 0); break; + case 3: + h->uni_vpextrw(addr(start_bytes), xmm, 0); + h->uni_vpextrb(addr(start_bytes + 2), xmm, 2); break; - case 32: - h->uni_vmovdqu(addr(0), ymm); + case 4: h->uni_vmovss(addr(start_bytes), xmm); break; + // h->uni_vpextrd(addr(start_bytes), xmm, 0); break; + case 5: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrb(addr(start_bytes + 4), xmm, 4); break; - case 16: - h->uni_vmovdqu(addr(0), xmm); + case 6: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); break; 
- default: - if (mayiuse(cpu::x64::avx512_core)) { - uint64_t mask = 1; - mask = (mask << store_size) - mask; - h->mov(Reg64(aux_gpr_idxs[0]), mask); - h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); - h->vmovdqu8(addr(0), zmm | k_mask); - } else { - store_byte_base(); - } + case 7: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); + h->uni_vpextrb(addr(start_bytes + 6), xmm, 6); break; + case 8: break; + case 9: h->uni_vpextrb(addr(start_bytes + 8), xmm, 8); break; + case 10: h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); break; + case 11: + h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); + h->uni_vpextrb(addr(start_bytes + 10), xmm, 10); + break; + case 12: h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); break; + case 13: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrb(addr(start_bytes + 12), xmm, 12); + break; + case 14: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); + break; + case 15: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); + h->uni_vpextrb(addr(start_bytes + 14), xmm, 14); + break; + case 16: break; + default: + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_bytes."; } + }; + + switch (store_size) { + case 64: + h->uni_vmovdqu(addr(0), zmm); + break; + case 32: + h->uni_vmovdqu(addr(0), ymm); + break; + case 16: + h->uni_vmovdqu(addr(0), xmm); + break; + default: + if (mayiuse(cpu::x64::avx512_core)) { + uint64_t mask = 1; + mask = (mask << store_size) - mask; + h->mov(Reg64(aux_gpr_idxs[0]), mask); + h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); + h->vmovdqu8(addr(0), zmm | k_mask); + } else { + store_byte_base(); + } + break; } +} /** * store_dword_to_byte_extension is the utility function to @@ -772,231 +794,231 @@ template * 2. store the packed byte into the memory referenced by ptr[reg + offset] address. 
*/ template - void jit_store_emitter::store_dword_to_byte_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int store_num) const { - constexpr bool is_xmm = std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - MAYBE_UNUSED(is_xmm); - MAYBE_UNUSED(is_ymm); - MAYBE_UNUSED(is_zmm); - - // Ensure data fits completely inside the Xmm/Ymm/Zmm register - // At most 8 dwords can fit inside the Ymm register - // At most 4 dwords can fit inside the Xmm register - if (store_num < 0 || store_num > 16) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_dword_to_byte_extension."; - if (is_ymm && store_num > 8) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_dword_to_byte_extension."; - if (is_xmm && store_num > 4) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_dword_to_byte_extension."; - - auto ymm = Xbyak::Ymm(vmm.getIdx()); - auto xmm = Xbyak::Xmm(vmm.getIdx()); - - const auto addr = [&](int bytes_offset) { - return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; - }; - - auto store_dword_to_byte_base = [&]() { - // db only available on avx512, need dw+wb to emulate - if (is_signed) - h->uni_vpackssdw(vmm, vmm, vmm); - else - h->uni_vpackusdw(vmm, vmm, vmm); - // gather 2(cross lane) 64 bits into lower vmm to store - // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] - if (is_ymm) { - h->vpermq(ymm, ymm, 0x08); // 00001000 +void jit_store_emitter::store_dword_to_byte_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int store_num) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; + + MAYBE_UNUSED(is_xmm); + MAYBE_UNUSED(is_ymm); + MAYBE_UNUSED(is_zmm); + + // Ensure data fits completely inside the 
Xmm/Ymm/Zmm register + // At most 8 dwords can fit inside the Ymm register + // At most 4 dwords can fit inside the Xmm register + if (store_num < 0 || store_num > 16) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_dword_to_byte_extension."; + if (is_ymm && store_num > 8) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_dword_to_byte_extension."; + if (is_xmm && store_num > 4) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_dword_to_byte_extension."; + + auto ymm = Xbyak::Ymm(vmm.getIdx()); + auto xmm = Xbyak::Xmm(vmm.getIdx()); + + const auto addr = [&](int bytes_offset) { + return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; + }; + + auto store_dword_to_byte_base = [&]() { + // db only available on avx512, need dw+wb to emulate + if (is_signed) + h->uni_vpackssdw(vmm, vmm, vmm); + else + h->uni_vpackusdw(vmm, vmm, vmm); + // gather 2(cross lane) 64 bits into lower vmm to store + // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] + if (is_ymm) { + h->vpermq(ymm, ymm, 0x08); // 00001000 + } + + if (is_signed) + h->uni_vpacksswb(vmm, vmm, vmm); + else + h->uni_vpackuswb(vmm, vmm, vmm); + + store_bytes(vmm, reg, offset, store_num); + }; + + switch (store_num) { + case 16: + // must support avx512F + if (is_signed) { + h->vpmovsdb(addr(0), vmm); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(vmm, vmm, zero); + h->vpmovusdb(addr(0), vmm); + } + break; + case 8: + if (mayiuse(cpu::x64::avx512_core)) { // ymm block on avx512F + VL + if (is_signed) { + h->vpmovsdb(addr(0), ymm); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(ymm, ymm, zero); + h->vpmovusdb(addr(0), ymm); + } + } else { + store_dword_to_byte_base(); + } + break; + case 4: + if (mayiuse(cpu::x64::avx512_core)) { // xmm block on avx512F + VL + if 
(is_signed) { + h->vpmovsdb(addr(0), xmm); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(xmm, xmm, zero); + h->vpmovusdb(addr(0), xmm); + } + } else { + store_dword_to_byte_base(); + } + break; + default: + if (is_zmm) { // avx512F + unsigned int mask = 1; + mask = (mask << store_num) - mask; + h->mov(Reg32(aux_gpr_idxs[0]), mask); + h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); + if (is_signed) { + h->vpmovsdb(addr(0), vmm | k_mask); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(vmm, vmm, zero); + h->vpmovusdb(addr(0), vmm | k_mask); + } + } else { + store_dword_to_byte_base(); } + break; + } +} - if (is_signed) - h->uni_vpacksswb(vmm, vmm, vmm); - else - h->uni_vpackuswb(vmm, vmm, vmm); +/** +* store_dword_to_word_extension is the utility function to +* 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num words with singed or unsinged saturation. +* 2. store the packed words into the memory referenced by ptr[reg + offset] address. 
+*/ +template +void jit_store_emitter::store_dword_to_word_extension(const Vmm &vmm, const Xbyak::Reg64 ®, + int offset, bool is_bf16, bool is_signed, int store_num) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; + + MAYBE_UNUSED(is_xmm); + MAYBE_UNUSED(is_ymm); + MAYBE_UNUSED(is_zmm); + + // Ensure data fits completely inside the Xmm/Ymm/Zmm register + // At most 4 dwords can fit inside the Xmm register + // At most 8 dwords can fit inside the Ymm register + if (store_num < 0 || store_num > 16) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_dword_to_word_extension."; + if (is_ymm && store_num > 8) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_dword_to_word_extension."; + if (is_xmm && store_num > 4) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_dword_to_word_extension."; + + auto xmm = Xbyak::Xmm(vmm.getIdx()); + auto ymm = Xbyak::Ymm(vmm.getIdx()); + auto zmm = Xbyak::Zmm(vmm.getIdx()); + + auto store_dword_to_word_base = [&]() { + // direct mov_dw available only on avx512, emulate with pack_dw + permute + pure store + if (is_signed) + h->uni_vpackssdw(vmm, vmm, vmm); + else + h->uni_vpackusdw(vmm, vmm, vmm); + // gather 2/4(cross lane) 64 bits into lower vmm to store + // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] + // [ 128 | 128 ] |--> [ 128 | 128 ] + if (is_ymm) { + h->vpermq(ymm, ymm, 0x08); // 00001000 + } - store_bytes(vmm, reg, offset, store_num); - }; + store_bytes(vmm, reg, offset, store_num * 2); + }; + if (is_bf16) { + if (mayiuse(cpu::x64::avx512_core_bf16)) { + h->vcvtneps2bf16(ymm, zmm); + } else { + emu_vcvtneps2bf16_->emit_code({static_cast(vmm.getIdx())}, {static_cast(ymm.getIdx())}); + } + if (store_num == 16) { + h->vmovdqu16(ptr[reg + offset], ymm); + } else { + 
store_bytes(ymm, reg, offset, store_num * 2); + } + } else { switch (store_num) { case 16: - // must support avx512F if (is_signed) { - h->vpmovsdb(addr(0), vmm); + h->vpmovsdw(ptr[reg + offset], vmm); // singed int32 saturate to signed int16. } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, vmm, zero); - h->vpmovusdb(addr(0), vmm); + h->uni_vpmaxsd(vmm, zero, vmm); // if singed bit is 1, set value as 0. + h->vpmovusdw(ptr[reg + offset], vmm); // unsinged int32 saturate to unsigned int16. } break; case 8: - if (mayiuse(cpu::x64::avx512_core)) { // ymm block on avx512F + VL + if (mayiuse(cpu::x64::avx512_core)) { if (is_signed) { - h->vpmovsdb(addr(0), ymm); + h->vpmovsdw(ptr[reg + offset], ymm); } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(ymm, ymm, zero); - h->vpmovusdb(addr(0), ymm); + h->uni_vpmaxsd(ymm, zero, ymm); + h->vpmovusdw(ptr[reg + offset], ymm); } } else { - store_dword_to_byte_base(); + store_dword_to_word_base(); } break; case 4: - if (mayiuse(cpu::x64::avx512_core)) { // xmm block on avx512F + VL + if (mayiuse(cpu::x64::avx512_core)) { if (is_signed) { - h->vpmovsdb(addr(0), xmm); + h->vpmovsdw(ptr[reg + offset], xmm); } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(xmm, xmm, zero); - h->vpmovusdb(addr(0), xmm); + h->uni_vpmaxsd(xmm, zero, xmm); + h->vpmovusdw(ptr[reg + offset], xmm); } } else { - store_dword_to_byte_base(); + store_dword_to_word_base(); } break; default: - if (is_zmm) { // avx512F + if (is_zmm) { unsigned int mask = 1; mask = (mask << store_num) - mask; h->mov(Reg32(aux_gpr_idxs[0]), mask); h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); if (is_signed) { - h->vpmovsdb(addr(0), vmm | k_mask); + h->vpmovsdw(ptr[reg + offset], vmm | k_mask); } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, vmm, zero); - h->vpmovusdb(addr(0), vmm | k_mask); + h->uni_vpmaxsd(vmm, zero, vmm); + 
h->vpmovusdw(ptr[reg + offset], vmm | k_mask); } } else { - store_dword_to_byte_base(); + store_dword_to_word_base(); } break; } } - -/** -* store_dword_to_word_extension is the utility function to -* 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num words with singed or unsinged saturation. -* 2. store the packed words into the memory referenced by ptr[reg + offset] address. -*/ -template - void jit_store_emitter::store_dword_to_word_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, - bool is_bf16, bool is_signed, int store_num) const { - constexpr bool is_xmm = std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - MAYBE_UNUSED(is_xmm); - MAYBE_UNUSED(is_ymm); - MAYBE_UNUSED(is_zmm); - - // Ensure data fits completely inside the Xmm/Ymm/Zmm register - // At most 4 dwords can fit inside the Xmm register - // At most 8 dwords can fit inside the Ymm register - if (store_num < 0 || store_num > 16) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_dword_to_word_extension."; - if (is_ymm && store_num > 8) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_dword_to_word_extension."; - if (is_xmm && store_num > 4) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_dword_to_word_extension."; - - auto xmm = Xbyak::Xmm(vmm.getIdx()); - auto ymm = Xbyak::Ymm(vmm.getIdx()); - auto zmm = Xbyak::Zmm(vmm.getIdx()); - - auto store_dword_to_word_base = [&]() { - // direct mov_dw available only on avx512, emulate with pack_dw + permute + pure store - if (is_signed) - h->uni_vpackssdw(vmm, vmm, vmm); - else - h->uni_vpackusdw(vmm, vmm, vmm); - // gather 2/4(cross lane) 64 bits into lower vmm to store - // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] - // [ 128 | 128 ] |--> [ 128 | 128 ] - if (is_ymm) { - h->vpermq(ymm, ymm, 
0x08); // 00001000 - } - - store_bytes(vmm, reg, offset, store_num * 2); - }; - - if (is_bf16) { - if (mayiuse(cpu::x64::avx512_core_bf16)) { - h->vcvtneps2bf16(ymm, zmm); - } else { - emu_vcvtneps2bf16->emit_code({static_cast(vmm.getIdx())}, {static_cast(ymm.getIdx())}); - } - if (store_num == 16) { - h->vmovdqu16(ptr[reg + offset], ymm); - } else { - store_bytes(ymm, reg, offset, store_num * 2); - } - } else { - switch (store_num) { - case 16: - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], vmm); // singed int32 saturate to signed int16. - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, zero, vmm); // if singed bit is 1, set value as 0. - h->vpmovusdw(ptr[reg + offset], vmm); // unsinged int32 saturate to unsigned int16. - } - break; - case 8: - if (mayiuse(cpu::x64::avx512_core)) { - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], ymm); - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(ymm, zero, ymm); - h->vpmovusdw(ptr[reg + offset], ymm); - } - } else { - store_dword_to_word_base(); - } - break; - case 4: - if (mayiuse(cpu::x64::avx512_core)) { - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], xmm); - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(xmm, zero, xmm); - h->vpmovusdw(ptr[reg + offset], xmm); - } - } else { - store_dword_to_word_base(); - } - break; - default: - if (is_zmm) { - unsigned int mask = 1; - mask = (mask << store_num) - mask; - h->mov(Reg32(aux_gpr_idxs[0]), mask); - h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], vmm | k_mask); - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, zero, vmm); - h->vpmovusdw(ptr[reg + offset], vmm | k_mask); - } - } else { - store_dword_to_word_base(); - } - break; - } - } - } +} } // namespace intel_cpu } // namespace ov diff --git 
a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp index 2abcf8e0ca6bd4..f4672f9fdfcda0 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp @@ -15,29 +15,25 @@ using namespace InferenceEngine; namespace ov { namespace intel_cpu { -struct load_emitter_context : public emitter_context { - load_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), load_num_(8), - offset_byte_(0), is_fill_(false), fill_value_("zero") {} +struct load_emitter_params : public emitter_params { + load_emitter_params(Precision src_prc, Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero"): + src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value) {} - load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, int offset_byte = 0, bool is_fill = false, std::string fill_value = "zero"): - src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), offset_byte_(offset_byte), is_fill_(is_fill), fill_value_(fill_value) {} + size_t hash() const override; - int offset_byte_; - int load_num_; Precision src_prc_; Precision dst_prc_; + int load_num_; bool is_fill_; std::string fill_value_; }; -struct store_emitter_context : public emitter_context { - store_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), - store_num_(8), offset_byte_(0) {} +struct store_emitter_params : public emitter_params { + store_emitter_params(Precision src_prc, Precision dst_prc, int store_num): + src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num) {} - store_emitter_context(Precision src_prc, Precision dst_prc, int store_num, int offset_byte = 0) - : src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num), offset_byte_(offset_byte) {} + size_t hash() const override; - int offset_byte_; int store_num_; Precision 
src_prc_; Precision dst_prc_; @@ -45,10 +41,11 @@ struct store_emitter_context : public emitter_context { class jit_load_emitter : public jit_emitter { public: - jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec); + jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, int load_num, Precision src_prc, Precision dst_prc, + Precision exec_prc = Precision::FP32, bool is_fill = false, std::string fill_value = "zero", + emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec); /** - * load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc. + * load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc, where offset_byte is in_idxs[1] * is_fill: when load_num can not fully fit in vector register, whether fill_value should be filled as default values. * fill_value: when load_num can not fully fit in vector register, what values should be filled as default values. * currently support "zero", "int_one", "float_one", "int32_min", "float_min", "int32_max" and "float_max". 
@@ -66,27 +63,23 @@ class jit_load_emitter : public jit_emitter { * dst_prc */ void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, - const emitter_context *emit_context) const override; + const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, + const emitter_context *emit_context) const override; size_t get_inputs_num() const override; private: template - void emit_isa(const Xbyak::Reg64 ®_src, int offset_byte, InferenceEngine::Precision src_prc, - const int out_vec_idx, InferenceEngine::Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero") const; + void emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_idx, const int offset) const; template - void load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size, - bool is_fill = false, std::string fill_value = "zero") const; + void load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size) const; template - void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size, - bool is_fill = false, std::string fill_value = "zero") const; + void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size) const; template - void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size, - bool is_fill = false, std::string fill_value = "zero") const; + void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size) const; template void fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const; @@ -95,17 +88,23 @@ class jit_load_emitter : public jit_emitter { size_t aux_gprs_count() const override; - std::string name; - int v_len_elt; // 4/8/16 + std::string name_; + int v_len_elt_; // 4/8/16 + int 
load_num_; + int load_size_; + Precision src_prc_; + Precision dst_prc_; + bool is_fill_; + std::string fill_value_; }; class jit_store_emitter : public jit_emitter { public: - jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr); + jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, int size_num, Precision src_prc, Precision dst_prc, + Precision exec_prc = Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr); /** - * store_num values with src_prc in Vmm[in_vec_idx] is stored to ptr[reg_dst + offset_byte] address as dst_prc data. + * store_num values with src_prc in Vmm[in_vec_idx] is stored to ptr[reg_dst + offset_byte] address as dst_prc data, where offset_byte is in_idxs[1] * supported src_prc and dst_prc pairs are as below(x indicate for support): * FP32 I32 I16 U16 I8 U8 BF16 --> src_prc * FP32 x x @@ -120,21 +119,20 @@ class jit_store_emitter : public jit_emitter { * note: FP32/I32-->BF16(x*) is supported only on at least avx512-core plateform */ void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, - const emitter_context *emit_context) const override; + const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, + const emitter_context *emit_context) const override; size_t get_inputs_num() const override; void emit_data() const override; std::shared_ptr get_emu_vcvtneps2bf16() const { - return emu_vcvtneps2bf16; + return emu_vcvtneps2bf16_; } private: template - void emit_isa(const int in_vec_idx, InferenceEngine::Precision src_prc, - const Xbyak::Reg64 ®_dst, int offset_byte, InferenceEngine::Precision dst_prc, int store_num) const; + void emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_dst, 
const int offset) const; template void store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const; @@ -148,9 +146,13 @@ class jit_store_emitter : public jit_emitter { size_t aux_gprs_count() const override; size_t aux_vecs_count() const override; - std::string name; - int v_len_elt; // 4/8/16 - std::shared_ptr emu_vcvtneps2bf16; + std::string name_; + int v_len_elt_; // 4/8/16 + int store_num_; + int store_size_; + Precision src_prc_; + Precision dst_prc_; + std::shared_ptr emu_vcvtneps2bf16_; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/color_convert.cpp b/src/plugins/intel_cpu/src/nodes/color_convert.cpp index b2ef4f31755d04..e31fe58955258f 100644 --- a/src/plugins/intel_cpu/src/nodes/color_convert.cpp +++ b/src/plugins/intel_cpu/src/nodes/color_convert.cpp @@ -422,6 +422,9 @@ class JitConverter : public jit_uni_converter { template void JitConverter::generate() { + init_load(N); + init_store(N); + preamble(); // Get arguments addresses @@ -776,6 +779,10 @@ class JitConverter : public jit_uni_converter { template void JitConverter::generate() { + init_load(N); + init_load(N / 2); + init_store(N); + preamble(); // Get arguments addresses diff --git a/src/plugins/intel_cpu/src/nodes/interpolate.cpp b/src/plugins/intel_cpu/src/nodes/interpolate.cpp index 0c0df6ec67b4a8..afe4b8132064e2 100644 --- a/src/plugins/intel_cpu/src/nodes/interpolate.cpp +++ b/src/plugins/intel_cpu/src/nodes/interpolate.cpp @@ -58,8 +58,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + init_emitters(); // dummy second reg_tmp_64 as no fill needed load_pool_gpr_idxs = {static_cast(reg_tmp_64.getIdx()), static_cast(reg_tmp_64.getIdx())}; @@ -162,8 +161,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi this->postamble(); - 
load_emitter->emit_data(); - store_emitter->emit_data(); + emit_emitters_data(); for (auto& inj : eltwise_injectors) inj->prepare_table(); if ((jcp_.mode == InterpolateMode::cubic) && (jcp_.layout == InterpolateLayoutType::planar)) { @@ -176,6 +174,9 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi Xbyak::Ymm, Xbyak::Zmm>::type; const int vlen = cpu_isa_traits::vlen; + const int step = vlen / sizeof(float); + const int tail_num = jcp_.C % step; + const int scalar = 1; Xbyak::Reg64 reg_src = r8; Xbyak::Reg64 reg_src_aux = r15; @@ -246,8 +247,8 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi Xbyak::Label l_table_constant; Opmask k_mask = Xbyak::Opmask(1); - std::unique_ptr load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unordered_map> emitters; + std::vector store_pool_gpr_idxs; std::vector store_pool_vec_idxs; std::vector load_pool_gpr_idxs; @@ -256,20 +257,57 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi std::vector>> depthwise_injectors; std::vector>> quantization_injectors; - inline void load(const Xbyak::Reg64& reg_src, Vmm& vmm, const int& elt_num, const int& offset = 0) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm.getIdx())}, - std::make_shared(jcp_.src_prc, Precision::FP32, elt_num, offset), - {}, {load_pool_gpr_idxs}); + void init_emitters() { + const std::set steps = { step, tail_num, scalar}; + for (auto step : steps) { + emitters[load_emitter_params(jcp_.src_prc, Precision::FP32, step).hash()].reset( + new jit_load_emitter(this, isa, step, jcp_.src_prc, Precision::FP32)); + if (jcp_.src_prc != Precision::FP32 && step != tail_num) { // to avoid repeated keys and extra load_weights emitter for tail_num case + emitters[load_emitter_params(Precision::FP32, Precision::FP32, step).hash()].reset( + new jit_load_emitter(this, isa, step, Precision::FP32, Precision::FP32)); + } + + 
emitters[store_emitter_params(Precision::FP32, jcp_.dst_prc, step).hash()].reset( + new jit_store_emitter(this, isa, step, Precision::FP32, jcp_.dst_prc)); + } } - inline void store(const Vmm& vmm, const Xbyak::Reg64& reg_dst, const int& elt_num, const int& offset = 0) { - store_emitter->emit_code({static_cast(vmm.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.dst_prc, elt_num, offset), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + + void emit_emitters_data() { + for (const auto& emitter : emitters) { + if (emitter.second) + emitter.second->emit_data(); } - inline void load_weights(const Xbyak::Reg64& reg_weights, Vmm& vmm, const int& elt_num, const int& offset = 0) { - load_emitter->emit_code({static_cast(reg_weights.getIdx())}, {static_cast(vmm.getIdx())}, - std::make_shared(Precision::FP32, Precision::FP32, elt_num, offset), - {}, {load_pool_gpr_idxs}); + } + + inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + const auto seed = load_emitter_params(jcp_.src_prc, Precision::FP32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << jcp_.src_prc << "." << Precision::FP32 << "." << elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)}, + {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void load_weights(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + const auto seed = load_emitter_params(Precision::FP32, Precision::FP32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << Precision::FP32 << "." << Precision::FP32 << "." 
<< elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)}, + {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) { + const auto seed = store_emitter_params(Precision::FP32, jcp_.dst_prc, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Store emitter <" << Precision::FP32 << "." << jcp_.dst_prc << "." << elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx()), static_cast(offset)}, + {static_cast(reg_dst.getIdx())}, + {store_pool_vec_idxs}, {store_pool_gpr_idxs}); } void nn_planar() { @@ -303,7 +341,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi // reset index_w, index_w * dataSize done when built to avoid redundent compute mov(reg_index, reg_index_w); - int step = vlen / sizeof(float); Xbyak::Label nn_loop_label; Xbyak::Label nn_loop_end_label; @@ -330,7 +367,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } L(nn_loop_end_label); - step = 1; L(nn_tail_loop_label); { cmp(reg_work_amount, 1); @@ -340,14 +376,14 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_index_offset, dword[reg_index]); add(reg_src_aux, reg_index_offset); - load(reg_src_aux, vmm_val, step); + load(reg_src_aux, vmm_val, scalar); if (attr_.post_ops_.len() != 0) apply_post_ops(jcp_.dst_prc, 1); - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, scalar); - add(reg_dst, step * jcp_.dst_data_size); - add(reg_index, step * jcp_.indices_size); - sub(reg_work_amount, step); + add(reg_dst, scalar * jcp_.dst_data_size); + add(reg_index, scalar * jcp_.indices_size); + sub(reg_work_amount, scalar); jmp(nn_tail_loop_label, T_NEAR); } @@ -363,8 +399,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi 
} void nn_blk() { - int step = vlen / sizeof(float); - Xbyak::Label nn_loop_label; Xbyak::Label nn_loop_end_label; L(nn_loop_label); @@ -421,8 +455,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi cmp(reg_work_amount_out, 1); jl(out_loop_end, T_NEAR); - int step = vlen / sizeof(float); - //inner loop for C Xbyak::Label nn_loop_label; Xbyak::Label nn_loop_end_label; @@ -461,7 +493,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } L(nn_loop_end_label); - int tail_num = jcp_.C % step; if (tail_num != 0) { load(reg_src_aux, vmm_val, tail_num); if (attr_.post_ops_.len() != 0) @@ -519,7 +550,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - int step = vlen / sizeof(float); int blk = (isa == cpu::x64::sse41) ? (2 * step) : step; int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (step * jcp_.dst_data_size) : (blk * jcp_.OW * jcp_.OH * jcp_.OD * jcp_.dst_data_size); @@ -627,7 +657,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } L(main_loop_end_label); - int tail_num = jcp_.C % step; if ((jcp_.layout == InterpolateLayoutType::by_channel) && (tail_num != 0)) { load(reg_src, vmm_valTL, tail_num); load(reg_src_aux, vmm_valTR, tail_num); @@ -669,7 +698,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0])]); mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - int step = vlen / sizeof(float); int index_stride = jcp_.OW * jcp_.OH * jcp_.OD * jcp_.indices_size; int weight_stride = jcp_.OW * jcp_.OH * jcp_.OD * sizeof(float); @@ -754,7 +782,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } L(main_loop_end_label); - step = 1; L(tail_loop_label); { cmp(reg_work_amount, 1); @@ -763,15 +790,15 @@ struct 
jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valTL, step); + load(reg_src_aux1, vmm_valTL, scalar); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valTR, step); + load(reg_src_aux1, vmm_valTR, scalar); - load_weights(reg_src_aux, vmm_weightL, step, 0); - load_weights(reg_src_aux, vmm_weightR, step, weight_stride); + load_weights(reg_src_aux, vmm_weightL, scalar, 0); + load_weights(reg_src_aux, vmm_weightR, scalar, weight_stride); if (jcp_.spatial_dim_size == 1) { linear_onnx_worker_1d(); @@ -780,15 +807,15 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 2 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valBL, step); + load(reg_src_aux1, vmm_valBL, scalar); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 3 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valBR, step); + load(reg_src_aux1, vmm_valBR, scalar); - load_weights(reg_src_aux, vmm_weightT, step, 2 * weight_stride); - load_weights(reg_src_aux, vmm_weightB, step, 3 * weight_stride); + load_weights(reg_src_aux, vmm_weightT, scalar, 2 * weight_stride); + load_weights(reg_src_aux, vmm_weightB, scalar, 3 * weight_stride); linear_onnx_worker_2d(); } @@ -799,27 +826,27 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 4 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valTL, step); + load(reg_src_aux1, vmm_valTL, scalar); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 5 * index_stride]); add(reg_src_aux1, reg_index_offset); - 
load(reg_src_aux1, vmm_valTR, step); + load(reg_src_aux1, vmm_valTR, scalar); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 6 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valBL, step); + load(reg_src_aux1, vmm_valBL, scalar); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 7 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valBR, step); + load(reg_src_aux1, vmm_valBR, scalar); linear_onnx_worker_2d(); - load_weights(reg_src_aux, vmm_weightE, step, 5 * weight_stride); - load_weights(reg_src_aux, vmm_weightF, step, 4 * weight_stride); + load_weights(reg_src_aux, vmm_weightE, scalar, 5 * weight_stride); + load_weights(reg_src_aux, vmm_weightF, scalar, 4 * weight_stride); uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight @@ -828,12 +855,12 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, true); // process on vmm_val, vmm_val is vmm_valTR, and bc } - store(vmm_valTR, reg_dst, step); + store(vmm_valTR, reg_dst, scalar); - add(reg_dst, step * jcp_.dst_data_size); - add(reg_src_aux, step * sizeof(float)); - add(reg_index, step * jcp_.indices_size); - sub(reg_work_amount, step); + add(reg_dst, scalar * jcp_.dst_data_size); + add(reg_src_aux, scalar * sizeof(float)); + add(reg_index, scalar * jcp_.indices_size); + sub(reg_work_amount, scalar); jmp(tail_loop_label, T_NEAR); } @@ -876,7 +903,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi uni_vbroadcastss(vmm_weightY2, ptr[reg_src_aux1 + 2 * sizeof(float)]); uni_vbroadcastss(vmm_weightY3, ptr[reg_src_aux1 + 3 * sizeof(float)]); - int step = vlen / sizeof(float); int blk = (isa == cpu::x64::sse41) ? 
(2 * step) : step; Xbyak::Label main_loop_label; @@ -940,7 +966,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi L(main_loop_end_label); // only for by_channel layout for tails. - step = 1; L(tail_loop_label); { cmp(reg_work_amount, 1); @@ -953,15 +978,15 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value - add(reg_oc_off, step * sizeof(float)); + add(reg_oc_off, scalar * sizeof(float)); } - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, scalar); - int dst_stride = step * jcp_.dst_data_size; - int src_stride = step * jcp_.src_data_size; + int dst_stride = scalar * jcp_.dst_data_size; + int src_stride = scalar * jcp_.src_data_size; add(reg_dst, dst_stride); add(reg_src, src_stride); - sub(reg_work_amount, step); // work_amount is c + sub(reg_work_amount, scalar); // work_amount is c jmp(tail_loop_label, T_NEAR); } @@ -1020,7 +1045,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_weight_y, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]); mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - int step = vlen / sizeof(float); int grid_len = 4; // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 @@ -1123,7 +1147,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } L(main_loop_end_label); - step = 1; L(tail_loop_label); { cmp(reg_work_amount, 1); @@ -1182,13 +1205,13 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel } - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, scalar); - add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence with dd() - add(reg_tbl_x, step * sizeof(int)); - add(reg_dst, step * 
jcp_.dst_data_size); + add(reg_tbl_y, scalar * sizeof(int)); // sizeof(int): sequence with dd() + add(reg_tbl_x, scalar * sizeof(int)); + add(reg_dst, scalar * jcp_.dst_data_size); - sub(reg_work_amount, step); + sub(reg_work_amount, scalar); jmp(tail_loop_label, T_NEAR); } diff --git a/src/plugins/intel_cpu/src/nodes/mvn.cpp b/src/plugins/intel_cpu/src/nodes/mvn.cpp index 506dbea8547733..ba72e6b5feb00e 100644 --- a/src/plugins/intel_cpu/src/nodes/mvn.cpp +++ b/src/plugins/intel_cpu/src/nodes/mvn.cpp @@ -110,7 +110,13 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); + tail_num = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step : + jcp_.C - (jcp_.C / step) * step; + + Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32; + load_vector_emitter.reset(new jit_load_emitter(this, isa, step, jcp_.src_prc, dst_prc)); + load_tail_emitter.reset(new jit_load_emitter(this, isa, tail_num, jcp_.src_prc, dst_prc)); + load_tail_with_fill_emitter.reset(new jit_load_emitter(this, isa, tail_num, jcp_.src_prc, dst_prc, Precision::FP32, true)); this->preamble(); mov(reg_src, ptr[reg_params + GET_OFF(src)]); @@ -134,9 +140,6 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } } - tail_num = jcp_.planar_layout ? 
(jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step : - jcp_.C - (jcp_.C / step) * step; - load_pool_gpr_idxs = {static_cast(reg_load_store_mask.getIdx()), static_cast(reg_load_table.getIdx())}; if (jcp_.planar_layout) { @@ -251,7 +254,9 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k this->postamble(); - load_emitter->emit_data(); + load_vector_emitter->emit_data(); + load_tail_emitter->emit_data(); + load_tail_with_fill_emitter->emit_data(); } private: @@ -286,15 +291,15 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k Xbyak::Opmask k_mask = Xbyak::Opmask(7); - std::unique_ptr load_emitter = nullptr; + std::unique_ptr load_vector_emitter = nullptr; + std::unique_ptr load_tail_emitter = nullptr; + std::unique_ptr load_tail_with_fill_emitter = nullptr; std::vector load_pool_gpr_idxs; inline void worker_full_size() { - Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32; - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, dst_prc, step), - {}, {load_pool_gpr_idxs}); + load_vector_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, + {}, {load_pool_gpr_idxs}); if (jcp_.normalize_variance) { // all with float @@ -313,9 +318,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } inline void worker_tail_blk() { - Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? 
Precision::FP32 : Precision::I32; - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, dst_prc, tail_num), + load_tail_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, {}, {load_pool_gpr_idxs}); if (jcp_.normalize_variance) { @@ -357,10 +360,8 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } inline void worker_tail_planar() { - Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32; - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, dst_prc, tail_num, 0, true), - {}, {load_pool_gpr_idxs}); + load_tail_with_fill_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, + {}, {load_pool_gpr_idxs}); if (jcp_.normalize_variance) { if (!isFloatCompatible(jcp_.src_prc)) @@ -435,8 +436,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator } } - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + tail_num = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step : + jcp_.C - (jcp_.C / step) * step; + + load_vector_emitter.reset(new jit_load_emitter(this, isa, step, jcp_.src_prc, Precision::FP32)); + load_tail_emitter.reset(new jit_load_emitter(this, isa, tail_num, jcp_.src_prc, Precision::FP32)); + store_vector_emitter.reset(new jit_store_emitter(this, isa, step, Precision::FP32, jcp_.dst_prc)); + store_tail_emitter.reset(new jit_store_emitter(this, isa, tail_num, Precision::FP32, jcp_.dst_prc)); this->preamble(); @@ -463,9 +469,6 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator uni_vpxor(vmm_zero, vmm_zero, vmm_zero); - tail_num = jcp_.planar_layout ? 
(jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step : - jcp_.C - (jcp_.C / step) * step; - load_pool_gpr_idxs = {static_cast(reg_load_store_mask.getIdx()), static_cast(reg_load_table.getIdx())}; store_pool_gpr_idxs = {static_cast(reg_load_store_mask.getIdx())}; store_pool_vec_idxs = {static_cast(vmm_zero.getIdx())}; @@ -530,8 +533,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + load_vector_emitter->emit_data(); + load_tail_emitter->emit_data(); + store_vector_emitter->emit_data(); + store_tail_emitter->emit_data(); for (auto& inj : eltwise_injectors) inj->prepare_table(); @@ -570,8 +575,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator Vmm vmm_d_weights = Vmm(5); Vmm vmm_d_bias = Vmm(6); - std::unique_ptr load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unique_ptr load_vector_emitter = nullptr; + std::unique_ptr load_tail_emitter = nullptr; + std::unique_ptr store_vector_emitter = nullptr; + std::unique_ptr store_tail_emitter = nullptr; std::vector>> eltwise_injectors; std::vector>> depthwise_injectors; @@ -582,9 +589,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator std::vector load_pool_gpr_idxs; inline void worker_mvn(bool is_tail) { - int elt_num = is_tail ? tail_num : step; + const auto& load_emitter = is_tail ? load_tail_emitter : load_vector_emitter; + const auto& store_emitter = is_tail ? 
store_tail_emitter : store_vector_emitter; + load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, Precision::FP32, elt_num), {}, {load_pool_gpr_idxs}); uni_vsubps(vmm_val, vmm_val, vmm_mean); @@ -594,7 +602,6 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator apply_post_ops(jcp_.dst_prc, jcp_.planar_layout); store_emitter->emit_code({static_cast(vmm_val.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.dst_prc, elt_num), {store_pool_vec_idxs}, {store_pool_gpr_idxs}); } diff --git a/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp b/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp index 632fa5572a6d02..bd2b4464c82c16 100644 --- a/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp +++ b/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp @@ -44,8 +44,9 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + load_vector_emitter.reset(new jit_load_emitter(this, isa, step, Precision::FP32, Precision::FP32)); + load_scalar_emitter.reset(new jit_load_emitter(this, isa, 1, Precision::FP32, Precision::FP32)); + exp_injector.reset(new jit_uni_eltwise_injector_f32(this, dnnl::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f)); this->preamble(); @@ -137,8 +138,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + load_vector_emitter->emit_data(); + load_scalar_emitter->emit_data(); prepare_table(); exp_injector->prepare_table(); @@ -147,6 +148,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator private: using Vmm = typename conditional3::type; uint32_t vlen = cpu_isa_traits::vlen; + const int step = vlen / 
sizeof(float); + const int scalar = 1; Xbyak::Reg64 reg_boxes_coord0 = r8; Xbyak::Reg64 reg_boxes_coord1 = r9; @@ -172,8 +175,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator Xbyak::Reg64 reg_params = abi_param1; - std::unique_ptr load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unique_ptr load_vector_emitter = nullptr; + std::unique_ptr load_scalar_emitter = nullptr; std::vector store_pool_gpr_idxs; std::vector store_pool_vec_idxs; @@ -205,7 +208,6 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator std::shared_ptr> exp_injector; inline void hard_nms() { - int step = vlen / sizeof(float); Xbyak::Label main_loop_label_hard; Xbyak::Label main_loop_end_label_hard; Xbyak::Label tail_loop_label_hard; @@ -236,21 +238,20 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator } L(main_loop_end_label_hard); - step = 1; L(tail_loop_label_hard); { cmp(reg_boxes_num, 1); jl(terminate_label_hard, T_NEAR); - sub(reg_boxes_coord0, step * sizeof(float)); - sub(reg_boxes_coord1, step * sizeof(float)); - sub(reg_boxes_coord2, step * sizeof(float)); - sub(reg_boxes_coord3, step * sizeof(float)); + sub(reg_boxes_coord0, scalar * sizeof(float)); + sub(reg_boxes_coord1, scalar * sizeof(float)); + sub(reg_boxes_coord2, scalar * sizeof(float)); + sub(reg_boxes_coord3, scalar * sizeof(float)); // iou result is in vmm_temp3 - iou(step); + iou(scalar); - sub(reg_boxes_num, step); + sub(reg_boxes_num, scalar); suppressed_by_iou(true); @@ -267,7 +268,6 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator inline void soft_nms() { uni_vbroadcastss(vmm_scale, ptr[reg_scale]); - int step = vlen / sizeof(float); Xbyak::Label main_loop_label; Xbyak::Label main_loop_end_label; Xbyak::Label tail_loop_label; @@ -327,19 +327,18 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator } L(main_loop_end_label); - step = 1; 
L(tail_loop_label); { cmp(reg_boxes_num, 1); jl(terminate_label, T_NEAR); - sub(reg_boxes_coord0, step * sizeof(float)); - sub(reg_boxes_coord1, step * sizeof(float)); - sub(reg_boxes_coord2, step * sizeof(float)); - sub(reg_boxes_coord3, step * sizeof(float)); + sub(reg_boxes_coord0, scalar * sizeof(float)); + sub(reg_boxes_coord1, scalar * sizeof(float)); + sub(reg_boxes_coord2, scalar * sizeof(float)); + sub(reg_boxes_coord3, scalar * sizeof(float)); - iou(step); - sub(reg_boxes_num, step); + iou(scalar); + sub(reg_boxes_num, scalar); // soft suppressed by iou_threshold if (jcp.is_soft_suppressed_by_iou) { @@ -427,8 +426,11 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator inline void iou(int ele_num) { auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_dst) { + if (ele_num != scalar && ele_num != step) + IE_THROW() << "NMS JIT implementation supports load emitter with only element count 1 or step! Get: " << ele_num; + + const auto& load_emitter = ele_num == 1 ? 
load_scalar_emitter : load_vector_emitter; load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_dst.getIdx())}, - std::make_shared(Precision::FP32, Precision::FP32, ele_num), {}, {load_pool_gpr_idxs}); }; load(reg_boxes_coord0, vmm_boxes_coord0); diff --git a/src/plugins/intel_cpu/src/nodes/roi_align.cpp b/src/plugins/intel_cpu/src/nodes/roi_align.cpp index 4f2bdc6fc694a2..13df00b26855a2 100644 --- a/src/plugins/intel_cpu/src/nodes/roi_align.cpp +++ b/src/plugins/intel_cpu/src/nodes/roi_align.cpp @@ -46,8 +46,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji }; void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + init_emitters(); this->preamble(); @@ -65,8 +64,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + emit_emitters_data(); } private: @@ -107,10 +105,9 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji // [1] for reg_dst Xmm xmm_args_pool = Xmm(15); - std::unique_ptr load_emitter = nullptr; - std::vector load_pool_gpr_idxs; + std::unordered_map> emitters; - std::unique_ptr store_emitter = nullptr; + std::vector load_pool_gpr_idxs; std::vector store_pool_gpr_idxs; std::vector store_pool_vec_idxs; @@ -157,6 +154,82 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji reg64_t reg_params = abi_param1; + void init_emitters() { + const std::set steps = { v_step, x_step, 1}; + for (auto step : steps) { + emitters[load_emitter_params(jcp_.data_prc, Precision::FP32, step).hash()].reset( + new jit_load_emitter(this, isa, step, jcp_.data_prc, Precision::FP32)); + if (jcp_.data_prc != Precision::FP32) { // to avoid repeated keys + emitters[load_emitter_params(Precision::FP32, Precision::FP32, step).hash()].reset( + new jit_load_emitter(this, isa, 
step, Precision::FP32, Precision::FP32)); + } + if (step != 1) { // scalar load_emitter isn't needed for indexes + emitters[load_emitter_params(Precision::I32, Precision::I32, step).hash()].reset( + new jit_load_emitter(this, isa, step, Precision::I32, Precision::I32)); + } + + emitters[store_emitter_params(Precision::FP32, jcp_.data_prc, step).hash()].reset( + new jit_store_emitter(this, isa, step, Precision::FP32, jcp_.data_prc)); + if (jcp_.data_prc != Precision::FP32) { // to avoid repeated keys + emitters[store_emitter_params(Precision::FP32, Precision::FP32, step).hash()].reset( + new jit_store_emitter(this, isa, step, Precision::FP32, Precision::FP32)); + } + } + } + + void emit_emitters_data() { + for (const auto& emitter : emitters) { + emitter.second->emit_data(); + } + } + + inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num) { + const auto seed = load_emitter_params(jcp_.data_prc, Precision::FP32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << jcp_.data_prc << "." << Precision::FP32 << "." << elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void load_buffer(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num) { + const auto seed = load_emitter_params(Precision::FP32, Precision::FP32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << Precision::FP32 << "." << Precision::FP32 << "." << elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void load_idx(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num) { + const auto seed = load_emitter_params(Precision::I32, Precision::I32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << Precision::I32 << "." << Precision::I32 << "." 
<< elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num) { + const auto seed = store_emitter_params(Precision::FP32, jcp_.data_prc, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Store emitter <" << Precision::FP32 << "." << jcp_.data_prc << "." << elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_dst.getIdx())}, + {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + } + + inline void store_buffer(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num) { + const auto seed = store_emitter_params(Precision::FP32, Precision::FP32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Store emitter <" << Precision::FP32 << "." << Precision::FP32 << "." << elt_num << "> wasn't inited for Interpolate!"; + } + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_dst.getIdx())}, + {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + } + void roi_align_cgather() { mov(reg_src_address, ptr[reg_params + GET_OFF(src)]); mov(reg_weights, ptr[reg_params + GET_OFF(weights)]); @@ -180,23 +253,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji imul(reg_src_stride, reg_src_stride, jcp_.data_size); } - auto store = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) { - store_emitter->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.data_prc, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); - }; - - auto load_buf = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_src.getIdx())}, - std::make_shared(Precision::FP32, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); - }; - auto 
store_buf = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) { - store_emitter->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, Precision::FP32, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); - }; - // out loop for samples in bin Xbyak::Label out_loop_label; Xbyak::Label out_loop_end_label; @@ -228,13 +284,13 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji generate_samples(v_step); // now this sample value across channel reside in vmm_sample // compute with other samples in vmm_buf - load_buf(reg_buf, vmm_buf, v_step); + load_buffer(reg_buf, vmm_buf, v_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vaddps(vmm_buf, vmm_buf, vmm_sample); } else { uni_vmaxps(vmm_buf, vmm_buf, vmm_sample); } - store_buf(vmm_buf, reg_buf, v_step); + store_buffer(vmm_buf, reg_buf, v_step); if ((isa == cpu::x64::sse41) && (jcp_.layout == ROIAlignLayoutType::blk)) { add(reg_src0, x_step * jcp_.data_size); @@ -244,13 +300,13 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji add(reg_buf, x_step * sizeof(float)); generate_samples(x_step); - load_buf(reg_buf, vmm_buf, x_step); + load_buffer(reg_buf, vmm_buf, x_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vaddps(vmm_buf, vmm_buf, vmm_sample); } else { uni_vmaxps(vmm_buf, vmm_buf, vmm_sample); } - store_buf(vmm_buf, reg_buf, x_step); + store_buffer(vmm_buf, reg_buf, x_step); sub(reg_src0, x_step * jcp_.data_size); sub(reg_src1, x_step * jcp_.data_size); @@ -280,13 +336,13 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji jl(in_loop_tail_end_label, T_NEAR); generate_samples(tail_step); - load_buf(reg_buf, vmm_buf, tail_step); + load_buffer(reg_buf, vmm_buf, tail_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vaddps(vmm_buf, vmm_buf, vmm_sample); } else { uni_vmaxps(vmm_buf, vmm_buf, vmm_sample); } - store_buf(vmm_buf, reg_buf, tail_step); + 
store_buffer(vmm_buf, reg_buf, tail_step); int tail_src_stride = tail_step * jcp_.data_size; add(reg_src0, tail_src_stride); @@ -333,7 +389,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji cmp(reg_work_amount, v_step); jl(store_loop_main_end_label, T_NEAR); - load_buf(reg_buf, vmm_buf, v_step); + load_buffer(reg_buf, vmm_buf, v_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vmulps(vmm_buf, vmm_buf, vmm_scale); } @@ -343,7 +399,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji add(reg_buf, x_step * sizeof(float)); add(reg_dst, x_step * jcp_.data_size); - load_buf(reg_buf, vmm_buf, x_step); + load_buffer(reg_buf, vmm_buf, x_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vmulps(vmm_buf, vmm_buf, vmm_scale); } @@ -369,7 +425,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji cmp(reg_work_amount, tail_step); jl(store_loop_tail_end_label, T_NEAR); - load_buf(reg_buf, vmm_buf, tail_step); + load_buffer(reg_buf, vmm_buf, tail_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vmulps(vmm_buf, vmm_buf, vmm_scale); } @@ -402,12 +458,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji } void generate_samples(int num) { - auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_src.getIdx())}, - std::make_shared(jcp_.data_prc, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); - }; - uni_vpxor(vmm_sample, vmm_sample, vmm_sample); load(reg_src0, vmm_src0, num); uni_vfmadd231ps(vmm_sample, vmm_src0, vmm_weights0); @@ -432,12 +482,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji uni_vbroadcastss(vmm_scale, ptr[reg_tmp_64]); } - auto load_idx = [&](Xbyak::Reg64 reg_idx, Vmm vmm_idx, int elt_num) { - load_emitter->emit_code({static_cast(reg_idx.getIdx())}, {static_cast(vmm_idx.getIdx())}, - 
std::make_shared(Precision::I32, Precision::I32, elt_num), - {}, {load_pool_gpr_idxs}); - }; - Xbyak::Label main_loop_label; Xbyak::Label main_loop_end_label; Xbyak::Label tail_loop_label; diff --git a/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp b/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp index a4899c55949143..e3657bb8fffa83 100644 --- a/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp +++ b/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp @@ -48,8 +48,9 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi }; void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + load_emitter.reset(new jit_load_emitter(this, isa, step, jpp_.src_prc, Precision::FP32)); + store_emitter.reset(new jit_store_emitter(this, isa, step, Precision::FP32, jpp_.dst_prc)); + store_empty_roi_emitter.reset(new jit_store_emitter(this, isa, step, jpp_.src_prc, jpp_.dst_prc)); this->preamble(); @@ -93,6 +94,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi load_emitter->emit_data(); store_emitter->emit_data(); + store_empty_roi_emitter->emit_data(); } private: @@ -114,6 +116,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi std::vector load_pool_gpr_idxs; std::unique_ptr store_emitter = nullptr; + std::unique_ptr store_empty_roi_emitter = nullptr; std::vector store_pool_gpr_idxs; std::vector store_pool_vec_idxs; @@ -157,8 +160,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi for (int i = 0; i < c_blocks; i++) { Vmm vmm_max = get_acc_reg(i); - load_emitter->emit_code({static_cast(reg_input.getIdx())}, {static_cast(vmm_max.getIdx())}, - std::make_shared(jpp_.src_prc, Precision::FP32, step, i * src_c_off), + load_emitter->emit_code({static_cast(reg_input.getIdx()), static_cast(i * src_c_off)}, {static_cast(vmm_max.getIdx())}, {}, load_pool_gpr_idxs); } @@ -171,9 +173,8 @@ 
struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi Vmm vmm_max = get_acc_reg(i); Vmm vmm_src = get_src_reg(i); - load_emitter->emit_code({static_cast(aux_reg_input1.getIdx())}, {static_cast(vmm_src.getIdx())}, - std::make_shared(jpp_.src_prc, Precision::FP32, step, i * src_c_off), - {}, load_pool_gpr_idxs); + load_emitter->emit_code({static_cast(aux_reg_input1.getIdx()), static_cast(i * src_c_off)}, + {static_cast(vmm_src.getIdx())}, {}, load_pool_gpr_idxs); if (isa == cpu::x64::sse41) { movups(vmm_mask, vmm_max); @@ -206,8 +207,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi for (int i = 0; i < c_blocks; i++) { Vmm vmm_dst = get_acc_reg(i); - store_emitter->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_output.getIdx())}, - std::make_shared(Precision::FP32, jpp_.dst_prc, step, i * dst_c_off), + store_emitter->emit_code({static_cast(vmm_dst.getIdx()), static_cast(i * dst_c_off)}, {static_cast(reg_output.getIdx())}, store_pool_vec_idxs, store_pool_gpr_idxs); } } @@ -225,27 +225,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi for (int i = 0; i < c_blocks; i++) { const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size(); - const auto load_context = std::make_shared(jpp_.src_prc, Precision::FP32, step, src_c_off); mov(aux_reg_input, reg_input); - load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src00.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src00.getIdx())}, {}, load_pool_gpr_idxs); add(aux_reg_input, reg_xoff); - load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src01.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src01.getIdx())}, {}, load_pool_gpr_idxs); add(aux_reg_input, reg_yoff); - 
load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src11.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src11.getIdx())}, {}, load_pool_gpr_idxs); sub(aux_reg_input, reg_xoff); - load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src10.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src10.getIdx())}, {}, load_pool_gpr_idxs); uni_vsubps(vmm_src01, vmm_src01, vmm_src00); @@ -259,8 +254,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size(); - store_emitter->emit_code({static_cast(vmm_src11.getIdx())}, {static_cast(reg_output.getIdx())}, - std::make_shared(Precision::FP32, jpp_.dst_prc, step, dst_c_off), + store_emitter->emit_code({static_cast(vmm_src11.getIdx()), static_cast(dst_c_off)}, {static_cast(reg_output.getIdx())}, store_pool_vec_idxs, store_pool_gpr_idxs); } } @@ -270,9 +264,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size(); for (int i = 0; i < c_blocks; i++) { - store_emitter->emit_code({static_cast(vmm_zero.getIdx())}, {static_cast(reg_output.getIdx())}, - std::make_shared(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off), - store_pool_vec_idxs, store_pool_gpr_idxs); + store_empty_roi_emitter->emit_code({static_cast(vmm_zero.getIdx()), static_cast(i * dst_c_off)}, + {static_cast(reg_output.getIdx())}, store_pool_vec_idxs, store_pool_gpr_idxs); } } diff --git a/src/plugins/intel_cpu/src/nodes/topk.cpp b/src/plugins/intel_cpu/src/nodes/topk.cpp index 011f42e06fa225..4a44a45dc9b49d 100644 --- a/src/plugins/intel_cpu/src/nodes/topk.cpp +++ b/src/plugins/intel_cpu/src/nodes/topk.cpp @@ -82,8 +82,7 @@ struct 
jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + init_emitters(); this->preamble(); @@ -123,8 +122,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + emit_emitters_data(); if (!shape_agnostic_alg) prepare_idx_table(); @@ -209,7 +207,6 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato const Xbyak::Opmask k_mask = Xbyak::Opmask(1); const int step = vlen / sizeof(float); const int tail = jcp_.work_amount % step; - const int topk_tail = jcp_.top_k % step; int blk_stride = 0; // stride of channel blocks at the same space coordinate, only used in blocked layout with topk on channel unsigned char cmp_flg; @@ -217,13 +214,101 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato Xbyak::Label l_table; - std::unique_ptr load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unordered_map> emitters; std::vector store_pool_gpr_idxs; std::vector load_pool_gpr_idxs; std::vector store_pool_vec_idxs; + void init_emitters() { + const std::set steps = { step, tail, 1}; + for (auto step : steps) { + emitters[load_emitter_params(jcp_.precision, Precision::FP32, step).hash()].reset( + new jit_load_emitter(this, isa, step, jcp_.precision, Precision::FP32)); + emitters[load_emitter_params(Precision::I32, Precision::I32, step).hash()].reset( + new jit_load_emitter(this, isa, step, Precision::I32, Precision::I32)); + if (Precision::I32 != jcp_.precision) { // to avoid repeated keys + emitters[load_emitter_params(Precision::I32, Precision::FP32, step).hash()].reset( + new jit_load_emitter(this, isa, step, Precision::I32, Precision::FP32)); + } + + emitters[store_emitter_params(Precision::FP32, jcp_.precision, step).hash()].reset( + 
new jit_store_emitter(this, isa, step, Precision::FP32, jcp_.precision)); + emitters[store_emitter_params(Precision::I32, Precision::I32, step).hash()].reset( + new jit_store_emitter(this, isa, step, Precision::I32, Precision::I32)); + if (Precision::I32 != jcp_.precision) { // to avoid repeated keys + emitters[store_emitter_params(Precision::FP32, Precision::I32, step).hash()].reset( + new jit_store_emitter(this, isa, step, Precision::FP32, Precision::I32)); + } + } + } + + void emit_emitters_data() { + for (const auto& emitter : emitters) { + emitter.second->emit_data(); + } + } + + inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + const auto seed = load_emitter_params(jcp_.precision, Precision::FP32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << jcp_.precision << "." << Precision::FP32 << "." << elt_num << "> wasn't inited for TopK!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)}, + {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void load_i32_f32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + const auto seed = load_emitter_params(Precision::I32, Precision::FP32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << Precision::I32 << "." << Precision::FP32 << "." << elt_num << "> wasn't inited for TopK!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)}, + {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void load_i32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + const auto seed = load_emitter_params(Precision::I32, Precision::I32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Load emitter <" << Precision::I32 << "." << Precision::I32 << "." 
<< elt_num << "> wasn't inited for TopK!"; + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)}, + {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); + } + + inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) { + const auto seed = store_emitter_params(Precision::FP32, jcp_.precision, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Store emitter <" << Precision::FP32 << "." << jcp_.precision << "." << elt_num << "> wasn't inited for TopK!"; + } + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx()), static_cast(offset)}, + {static_cast(reg_dst.getIdx())}, {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + } + + inline void store_f32_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) { + const auto seed = store_emitter_params(Precision::FP32, Precision::I32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Store emitter <" << Precision::FP32 << "." << Precision::I32 << "." << elt_num << "> wasn't inited for TopK!"; + } + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx()), static_cast(offset)}, + {static_cast(reg_dst.getIdx())}, {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + } + + inline void store_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) { + const auto seed = store_emitter_params(Precision::I32, Precision::I32, elt_num).hash(); + if (!emitters[seed]) { + IE_THROW() << "Store emitter <" << Precision::I32 << "." << Precision::I32 << "." 
<< elt_num << "> wasn't inited for TopK!"; + } + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx()), static_cast(offset)}, + {static_cast(reg_dst.getIdx())}, {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + } + inline void topk_loop() { if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) { if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) { @@ -282,19 +367,11 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato inline void topk_bitonic(int elt_num) { // src => prc for (int i = 0; i < jcp_.axis_dim; i++) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_tmp.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size), - {}, {load_pool_gpr_idxs}); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, {static_cast(reg_prc.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + load(reg_src, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size); + store(vmm_tmp, reg_prc, elt_num, i * jcp_.sort_stride * jcp_.data_size); - load_emitter->emit_code({static_cast(reg_table.getIdx())}, {static_cast(vmm_tmp.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, elt_num, i * vlen), - {}, {load_pool_gpr_idxs}); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, {static_cast(reg_prc_idx.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + load_i32(reg_table, vmm_tmp, elt_num, i * vlen); + store_i32(vmm_tmp, reg_prc_idx, elt_num, i * jcp_.sort_stride * sizeof(int)); } // sort @@ -305,19 +382,11 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato // prc => dst for (int i = 0; i < jcp_.top_k; i++) { - load_emitter->emit_code({static_cast(reg_prc.getIdx())}, 
{static_cast(vmm_tmp.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size), - {}, {load_pool_gpr_idxs}); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + load(reg_prc, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size); + store(vmm_tmp, reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size); - load_emitter->emit_code({static_cast(reg_prc_idx.getIdx())}, {static_cast(vmm_tmp.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)), - {}, {load_pool_gpr_idxs}); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, {static_cast(reg_dst_idx.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + load_i32(reg_prc_idx, vmm_tmp, elt_num, i * jcp_.sort_stride * sizeof(int)); + store_i32(vmm_tmp, reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int)); } } @@ -437,40 +506,30 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato inline void bitonic_swap_vector(int elt_num, bool cmp_val = true) { bitonic_get_addr(reg_prc, jcp_.data_size, 0); - load_emitter->emit_code({static_cast(reg_aux_idx.getIdx())}, {static_cast(vmm_val_l.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load(reg_aux_idx, vmm_val_l, elt_num); + bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int)); - load_emitter->emit_code({static_cast(reg_aux_idx.getIdx())}, {static_cast(vmm_val_r.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load(reg_aux_idx, vmm_val_r, elt_num); + bitonic_get_addr(reg_prc_idx, sizeof(int), 0); - 
load_emitter->emit_code({static_cast(reg_aux_idx.getIdx())}, {static_cast(vmm_idx_l.getIdx())}, - std::make_shared(Precision::I32, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load_i32_f32(reg_aux_idx, vmm_idx_l, elt_num); + bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int)); - load_emitter->emit_code({static_cast(reg_aux_idx.getIdx())}, {static_cast(vmm_idx_r.getIdx())}, - std::make_shared(Precision::I32, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load_i32_f32(reg_aux_idx, vmm_idx_r, elt_num); swap_vector(vmm_val_l, vmm_idx_l, vmm_val_r, vmm_idx_r, cmp_val); bitonic_get_addr(reg_prc, jcp_.data_size, 0); - store_emitter->emit_code({static_cast(vmm_val_l.getIdx())}, {static_cast(reg_aux_idx.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store(vmm_val_l, reg_aux_idx, elt_num); + bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int)); - store_emitter->emit_code({static_cast(vmm_val_r.getIdx())}, {static_cast(reg_aux_idx.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store(vmm_val_r, reg_aux_idx, elt_num); + bitonic_get_addr(reg_prc_idx, sizeof(int), 0); - store_emitter->emit_code({static_cast(vmm_idx_l.getIdx())}, {static_cast(reg_aux_idx.getIdx())}, - std::make_shared(Precision::FP32, Precision::I32, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store_f32_i32(vmm_idx_l, reg_aux_idx, elt_num); + bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int)); - store_emitter->emit_code({static_cast(vmm_idx_r.getIdx())}, {static_cast(reg_aux_idx.getIdx())}, - std::make_shared(Precision::FP32, Precision::I32, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store_f32_i32(vmm_idx_r, reg_aux_idx, elt_num); } inline void topk_heap_sorting() { @@ -588,25 +647,18 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato } 
get_addr_by_reg_idx(reg_heap_outer_aux, reg_src, reg_i, jcp_.data_size); - load_emitter->emit_code({static_cast(reg_heap_outer_aux.getIdx())}, {static_cast(vmm_tmp.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, s), - {}, {load_pool_gpr_idxs}); + load(reg_heap_outer_aux, vmm_tmp, s); + get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst, reg_i, jcp_.data_size); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, {static_cast(reg_heap_outer_aux.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, s), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store(vmm_tmp, reg_heap_outer_aux, s); if (s == step) { table_to_vmm(vmm_tmp, reg_heap_seq_idx, reg_i, 0, sizeof(int)); } else { get_addr_by_reg_idx(reg_heap_outer_aux, reg_heap_seq_idx, reg_i, sizeof(int)); - load_emitter->emit_code({static_cast(reg_heap_outer_aux.getIdx())}, {static_cast(vmm_tmp.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, 1), - {}, {load_pool_gpr_idxs}); + load_i32(reg_heap_outer_aux, vmm_tmp, 1); } get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst_idx, reg_i, sizeof(int)); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, {static_cast(reg_heap_outer_aux.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, s), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store_i32(vmm_tmp, reg_heap_outer_aux, s); add(reg_i, s); jmp(topk_init_loop_label, T_NEAR); @@ -1025,19 +1077,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato je(topk_init_loop_end_label, T_NEAR); get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i); - load_emitter->emit_code({static_cast(reg_tmp.getIdx())}, {static_cast(vmm_tmp.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load(reg_tmp, vmm_tmp, elt_num); get_addr_by_reg_idx(reg_tmp, reg_dst, reg_block_sort_stride_byte, reg_i); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, 
{static_cast(reg_tmp.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store(vmm_tmp, reg_tmp, elt_num); table_to_vmm(vmm_tmp, reg_bubble_block_idx, reg_i, 0, vlen); get_addr_by_reg_idx(reg_tmp, reg_dst_idx, reg_block_sort_stride_byte, sizeof(int) / jcp_.data_size, reg_i); - store_emitter->emit_code({static_cast(vmm_tmp.getIdx())}, {static_cast(reg_tmp.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store_i32(vmm_tmp, reg_tmp, elt_num); add(reg_i, 1); jmp(topk_init_loop_label, T_NEAR); @@ -1057,9 +1103,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato je(topk_update_loop_end_label, T_NEAR); get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i); - load_emitter->emit_code({static_cast(reg_tmp.getIdx())}, {static_cast(vmm_val_r.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load(reg_tmp, vmm_val_r, elt_num); table_to_vmm(vmm_idx_r, reg_bubble_block_idx, reg_i, 0, vlen); uni_vcvtdq2ps(vmm_idx_r, vmm_idx_r); @@ -1152,9 +1196,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato inline void topk_bubble_inplace(int elt_num) { // load for (int i = 0; i < jcp_.top_k; i++) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val(i).getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size), - {}, {load_pool_gpr_idxs}); + load(reg_src, vmm_val(i), elt_num, i * jcp_.sort_stride * jcp_.data_size); uni_vmovdqu(vmm_idx(i), table_val(i)); uni_vcvtdq2ps(vmm_idx(i), vmm_idx(i)); } @@ -1165,9 +1207,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato } } for (int i = jcp_.top_k; i < jcp_.axis_dim; i++) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, 
{static_cast(vmm_val(jcp_.top_k).getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size), - {}, {load_pool_gpr_idxs}); + load(reg_src, vmm_val(jcp_.top_k), elt_num, i * jcp_.sort_stride * jcp_.data_size); uni_vmovdqu(vmm_idx(jcp_.top_k), table_val(i)); uni_vcvtdq2ps(vmm_idx(jcp_.top_k), vmm_idx(jcp_.top_k)); for (int j = jcp_.top_k; j > 0; j--) { @@ -1183,12 +1223,8 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato } // store for (int i = 0; i < jcp_.top_k; i++) { - store_emitter->emit_code({static_cast(vmm_val(i).getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); - store_emitter->emit_code({static_cast(vmm_idx(i).getIdx())}, {static_cast(reg_dst_idx.getIdx())}, - std::make_shared(Precision::FP32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store(vmm_val(i), reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size); + store_f32_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int)); } } @@ -1211,15 +1247,11 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato L(topk_load_sort_label); { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val(0).getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, step, 0), - {}, {load_pool_gpr_idxs}); + load(reg_src, vmm_val(0), step, 0); uni_vmovdqu(vmm_idx(0), table_bubble_seq_idx(0)); uni_vcvtdq2ps(vmm_idx(0), vmm_idx(0)); if (isa == cpu::x64::sse41) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val(1).getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, step, 4 * jcp_.data_size), - {}, {load_pool_gpr_idxs}); + load(reg_src, vmm_val(1), step, 4 * jcp_.data_size); uni_vmovdqu(vmm_idx(1), 
table_bubble_seq_idx(4)); uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1)); swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1)); @@ -1235,17 +1267,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato jg(topk_iter_end_label, T_NEAR); get_addr_by_reg_idx(reg_aux, reg_src, reg_i, jcp_.data_size, reg_seq_sort_stride); - load_emitter->emit_code({static_cast(reg_aux.getIdx())}, {static_cast(vmm_val(1).getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, step), - {}, {load_pool_gpr_idxs}); + load(reg_aux, vmm_val(1), step); table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int)); uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1)); swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1)); if (isa == cpu::x64::sse41) { add(reg_aux, 4 * jcp_.data_size); - load_emitter->emit_code({static_cast(reg_aux.getIdx())}, {static_cast(vmm_val(1).getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, step), - {}, {load_pool_gpr_idxs}); + load(reg_aux, vmm_val(1), step); table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 4, sizeof(int)); uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1)); swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1)); @@ -1538,16 +1566,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato // load l mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst); - load_emitter->emit_code({static_cast(reg_tmp.getIdx())}, {static_cast(vmm_val_l.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load(reg_tmp, vmm_val_l, elt_num); + reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size); mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst_idx); reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size); - load_emitter->emit_code({static_cast(reg_tmp.getIdx())}, {static_cast(vmm_idx_l.getIdx())}, - std::make_shared(Precision::I32, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load_i32_f32(reg_tmp, vmm_idx_l, elt_num); // load r Xbyak::Label 
topk_load_jmp_label; @@ -1557,16 +1582,14 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato add(reg_tmp_64, reg_block_sort_stride_byte); mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst); - load_emitter->emit_code({static_cast(reg_tmp.getIdx())}, {static_cast(vmm_val_r.getIdx())}, - std::make_shared(jcp_.precision, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load(reg_tmp, vmm_val_r, elt_num); + reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size); mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst_idx); reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size); - load_emitter->emit_code({static_cast(reg_tmp.getIdx())}, {static_cast(vmm_idx_r.getIdx())}, - std::make_shared(Precision::I32, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); + load_i32_f32(reg_tmp, vmm_idx_r, elt_num); + sub(reg_tmp_64, reg_block_sort_stride_byte); } L(topk_load_jmp_label); @@ -1576,16 +1599,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato // store l mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst); - store_emitter->emit_code({static_cast(vmm_val_l.getIdx())}, {static_cast(reg_tmp.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store(vmm_val_l, reg_tmp, elt_num); + reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size); mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst_idx); reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size); - store_emitter->emit_code({static_cast(vmm_idx_l.getIdx())}, {static_cast(reg_tmp.getIdx())}, - std::make_shared(Precision::FP32, Precision::I32, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store_f32_i32(vmm_idx_l, reg_tmp, elt_num); // store r Xbyak::Label topk_store_jmp_label; @@ -1595,16 +1615,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato add(reg_tmp_64, reg_block_sort_stride_byte); mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst); - 
store_emitter->emit_code({static_cast(vmm_val_r.getIdx())}, {static_cast(reg_tmp.getIdx())}, - std::make_shared(Precision::FP32, jcp_.precision, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store(vmm_val_r, reg_tmp, elt_num); + reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size); mov(reg_tmp, reg_tmp_64); add(reg_tmp, reg_dst_idx); reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size); - store_emitter->emit_code({static_cast(vmm_idx_r.getIdx())}, {static_cast(reg_tmp.getIdx())}, - std::make_shared(Precision::FP32, Precision::I32, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + store_f32_i32(vmm_idx_r, reg_tmp, elt_num); } L(topk_store_jmp_label); } diff --git a/src/plugins/intel_cpu/src/utils/jit_kernel.cpp b/src/plugins/intel_cpu/src/utils/jit_kernel.cpp index f1c99751c1a7d5..520718a55b38ab 100644 --- a/src/plugins/intel_cpu/src/utils/jit_kernel.cpp +++ b/src/plugins/intel_cpu/src/utils/jit_kernel.cpp @@ -212,9 +212,7 @@ const void * consts_table::store(const void *data, size_t size) { } // namespace internal jit_kernel::jit_kernel() - : jit_generator() - , _load_emitter(this, internal::get_current_isa()) - , _store_emitter(this, internal::get_current_isa()) { + : jit_generator() { _free_rmmregs.reserve(16); _free_rmmregs.reserve(16); @@ -297,10 +295,10 @@ void jit_kernel::free(const Zmm & reg) { void jit_kernel::postamble() { jit_generator::postamble(); - if (_is_load_emitter_used) - _load_emitter.emit_data(); - if (_is_store_emitter_used) - _store_emitter.emit_data(); + for (const auto& emitter : _emitters) { + if (emitter.second) + emitter.second->emit_data(); + } } const AddressFrame & jit_kernel::address_frame(size_t size) const { diff --git a/src/plugins/intel_cpu/src/utils/jit_kernel.hpp b/src/plugins/intel_cpu/src/utils/jit_kernel.hpp index ce531d7806c78a..bfb79ee9518d0a 100644 --- a/src/plugins/intel_cpu/src/utils/jit_kernel.hpp +++ b/src/plugins/intel_cpu/src/utils/jit_kernel.hpp @@ -651,6 +651,11 @@ struct jit_kernel : public 
dnnl::impl::cpu::x64::jit_generator { const Xbyak::Reg64& src, const Xbyak::Reg64& size); + template + void init_load(size_t length); + template + void init_store(size_t length); + template void load(const variable & dst, const variable & src, size_t length = N); template @@ -697,11 +702,8 @@ struct jit_kernel : public dnnl::impl::cpu::x64::jit_generator { private: reg_indices _free_x64regs; reg_indices _free_rmmregs; - bool _is_load_emitter_used = false; - bool _is_store_emitter_used = false; - jit_load_emitter _load_emitter; - jit_store_emitter _store_emitter; internal::consts_table _consts; + std::unordered_map> _emitters; }; template @@ -733,6 +735,30 @@ void jit_kernel::copy(const Xbyak::Address& dst, free(p); } +template +void jit_kernel::init_load(size_t length) { + using src_type = typename std::remove_cv< + typename std::remove_pointer::type>::type; + using dst_type = typename std::remove_cv< + typename std::remove_pointer::type>::type; + const auto src_prc = internal::type2precision(); + const auto dst_prc = internal::type2precision(); + _emitters[load_emitter_params(src_prc, dst_prc, length).hash()].reset( + new jit_load_emitter(this, internal::get_current_isa(), length, src_prc, dst_prc)); +} + +template +void jit_kernel::init_store(size_t length) { + using src_type = typename std::remove_cv< + typename std::remove_pointer::type>::type; + using dst_type = typename std::remove_cv< + typename std::remove_pointer::type>::type; + const auto src_prc = internal::type2precision(); + const auto dst_prc = internal::type2precision(); + _emitters[store_emitter_params(src_prc, dst_prc, length).hash()].reset( + new jit_store_emitter(this, internal::get_current_isa(), length, src_prc, dst_prc)); +} + template void jit_kernel::load(const variable & dst, const variable & src, size_t length) { static_assert(std::is_same::reg_type, const Xbyak::Reg64>::value, @@ -746,17 +772,18 @@ void jit_kernel::load(const variable & dst, const variable & src, const std::vector 
pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end()); const std::vector pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end()); - _load_emitter.emit_code( + const auto src_prc = internal::type2precision(); + const auto dst_prc = internal::type2precision(); + + const auto key = load_emitter_params(src_prc, dst_prc, length).hash(); + if (!_emitters[key]) { + IE_THROW() << "Load emitter <" << length << "." << src_prc << "." << dst_prc << "> wasn't inited!"; + } + _emitters[key]->emit_code( { static_cast(static_cast(src).getIdx()) }, { static_cast(static_cast(dst).getIdx()) }, - std::make_shared( - internal::type2precision(), - internal::type2precision(), - static_cast(length)), pool_vec_idxs, pool_gpr_idxs); - - _is_load_emitter_used = true; } template @@ -788,17 +815,18 @@ void jit_kernel::store(const variable & dst, const variable & src const std::vector pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end()); const std::vector pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end()); - _store_emitter.emit_code( + const auto src_prc = internal::type2precision(); + const auto dst_prc = internal::type2precision(); + + const auto key = store_emitter_params(src_prc, dst_prc, length).hash(); + if (!_emitters[key]) { + IE_THROW() << "Store emitter <" << length << "." << src_prc << "." 
<< dst_prc << "> wasn't inited!"; + } + _emitters[key]->emit_code( { static_cast(static_cast(src).getIdx()) }, { static_cast(static_cast(dst).getIdx()) }, - std::make_shared( - internal::type2precision(), - internal::type2precision(), - static_cast(length)), pool_vec_idxs, pool_gpr_idxs); - - _is_store_emitter_used = true; } template diff --git a/src/tests/unit/cpu/jit_kernel_test.cpp b/src/tests/unit/cpu/jit_kernel_test.cpp index 73ae2c4e5cc5f6..8b9bc54648677a 100644 --- a/src/tests/unit/cpu/jit_kernel_test.cpp +++ b/src/tests/unit/cpu/jit_kernel_test.cpp @@ -159,6 +159,9 @@ struct jit_variable_test_kernel { } void generate() override { + init_load(N); + init_store(N); + preamble(); auto a_ptr = arg(&Params::a); @@ -301,6 +304,9 @@ struct jit_variable_load_store_test_kernel { class kernel_impl : public jit_test_kernel { public: void generate() override { + jit_kernel::init_load(N); + jit_kernel::init_store(N); + jit_kernel::preamble(); auto src_ptr = jit_kernel::arg(&Params::src);