diff --git a/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp index 007b7f0fefd887..f0f460d51713a5 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_emitter.hpp @@ -23,6 +23,11 @@ enum emitter_in_out_map { gpr_to_gpr, }; +// structure for storage of emitter parameters to hash in map +struct emitter_params { + virtual size_t hash() const = 0; +}; + struct emitter_context { virtual ~emitter_context() = default; }; diff --git a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp index 7604c9a82fbc43..da1589aa4497d4 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.cpp @@ -18,95 +18,125 @@ using namespace Xbyak::util; namespace ov { namespace intel_cpu { +size_t load_emitter_params::hash() const { + size_t seed = 0; + seed = hash_combine(seed, std::string("jit_load_emitter")); + seed = hash_combine(seed, src_prc_.getPrecVal()); + seed = hash_combine(seed, dst_prc_.getPrecVal()); + seed = hash_combine(seed, load_num_); + seed = hash_combine(seed, is_fill_); + seed = hash_combine(seed, fill_value_); + return seed; +} + +size_t store_emitter_params::hash() const { + size_t seed = 0; + seed = hash_combine(seed, std::string("jit_store_emitter")); + seed = hash_combine(seed, src_prc_.getPrecVal()); + seed = hash_combine(seed, dst_prc_.getPrecVal()); + seed = hash_combine(seed, store_num_); + return seed; +} + +static int get_aux_regs_for_avx512_mask(const size_t byte_size, const bool is_fill = false) { + if (mayiuse(cpu::x64::avx512_core)) { + if (!one_of(byte_size, 64, 32, 16) || is_fill) { + return 1; + } + } + return 0; +} + /// LOAD /// -jit_load_emitter::jit_load_emitter(jit_generator *host, cpu_isa_t host_isa, - Precision exec_prc, emitter_in_out_map in_out_type) -: jit_emitter(host, host_isa, exec_prc, in_out_type), name("unknown") { +jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + Precision src_prc, Precision dst_prc, int load_num, Precision exec_prc, + bool is_fill, std::string fill_value, emitter_in_out_map in_out_type) +: jit_emitter(host, host_isa, exec_prc, in_out_type), load_num_(load_num), src_prc_(src_prc), dst_prc_(dst_prc), + is_fill_(is_fill), fill_value_(fill_value), name_("unknown") { prepare_table(); - v_len_elt = get_vec_length() / exec_prc.size(); + load_size_ = load_num * src_prc.size(); + v_len_elt_ = get_vec_length() / exec_prc.size(); } size_t jit_load_emitter::get_inputs_num() const { return 1; } -// 0 for temp reg for mask load, 1 for table address size_t jit_load_emitter::aux_gprs_count() const { - return 2; + // 0 for temp reg for mask load in avx512 if needed + int count = get_aux_regs_for_avx512_mask(load_num_ * dst_prc_.size(), is_fill_); + + // 1 for table address + if (is_fill_) + count++; + + return count; } void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, - const emitter_context *emit_context) const { - const auto* load_emitter_context = dynamic_cast(emit_context); - if (load_emitter_context == nullptr) { - IE_THROW() << "Load emitter in " << name << " does not get load emmiter context."; - } - + const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, + const emitter_context *emit_context) 
const { + const int offset = in_idxs.size() == 2 ? in_idxs[1] : 0; if (host_isa_ == cpu::x64::sse41) { - emit_isa(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast(out_idxs[0]), - load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_); + emit_isa(Reg64(in_idxs[0]), static_cast(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx2) { - emit_isa(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast(out_idxs[0]), - load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_); + emit_isa(Reg64(in_idxs[0]), static_cast(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx512_core) { - emit_isa(Reg64(in_idxs[0]), load_emitter_context->offset_byte_, load_emitter_context->src_prc_, static_cast(out_idxs[0]), - load_emitter_context->dst_prc_, load_emitter_context->load_num_, load_emitter_context->is_fill_, load_emitter_context->fill_value_); + emit_isa(Reg64(in_idxs[0]), static_cast(out_idxs[0]), offset); } else { - IE_THROW() << "Load emitter in " << name << " is performed on unsupported isa(at least x64::sse41)."; + IE_THROW() << "Load emitter in " << name_ << " is performed on unsupported isa(at least x64::sse41)."; } } template -void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, int offset_byte, InferenceEngine::Precision src_prc, - const int out_vec_idx, InferenceEngine::Precision dst_prc, int load_num, bool is_fill, std::string fill_value) const { - bool matched_prc = (dst_prc == src_prc) || (dst_prc == Precision::FP32) || (dst_prc == Precision::I32); +void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_idx, const int offset) const { + bool matched_prc = (dst_prc_ == src_prc_) || (dst_prc_ == Precision::FP32) || (dst_prc_ == Precision::I32); if (!matched_prc) { - IE_THROW() << "Load emitter in " << name << " only support output precision of FP32 or I32 or the same precision as input."; + IE_THROW() << "Load emitter in " << name_ << " only support output precision of FP32 or I32 or the same precision as input."; } - if (load_num > (get_vec_length() / dst_prc.size())) { - IE_THROW() << "Load emitter in " << name << " have unexpected number of elements to load."; + if (load_num_ > (get_vec_length() / dst_prc_.size())) { + IE_THROW() << "Load emitter in " << name_ << " have unexpected number of elements to load."; } using Vmm = typename conditional3::type; // pure load - if (src_prc == dst_prc) { - load_bytes(Vmm(out_vec_idx), reg_src, offset_byte, load_num * src_prc.size(), is_fill, fill_value); + if (src_prc_ == dst_prc_) { + load_bytes(Vmm(out_vec_idx), reg_src, offset, load_size_); } else { // "pure load" + convert. dst_prc must be FP32 or I32. 
- switch (src_prc) { + switch (src_prc_) { case Precision::FP32: case Precision::I32: - load_bytes(Vmm(out_vec_idx), reg_src, offset_byte, load_num * src_prc.size(), is_fill, fill_value); + load_bytes(Vmm(out_vec_idx), reg_src, offset, load_size_); break; case Precision::I8: - load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, true, load_num * src_prc.size(), is_fill, fill_value); + load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, true, load_size_); break; case Precision::U8: - load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, false, load_num * src_prc.size(), is_fill, fill_value); + load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, load_size_); break; case Precision::I16: - load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, false, true, load_num * src_prc.size(), is_fill, fill_value); + load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, true, load_size_); break; case Precision::U16: - load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, false, false, load_num * src_prc.size(), is_fill, fill_value); + load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, false, load_size_); break; case Precision::BF16: - load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset_byte, true, false, load_num * src_prc.size(), is_fill, fill_value); + load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, true, false, load_size_); break; default: - IE_THROW() << "Load emitter in " << name << " has unsupported src precision to load."; + IE_THROW() << "Load emitter in " << name_ << " has unsupported src precision to load."; } } // post convert between I32 and FP32 - if (src_prc != dst_prc) { - switch (dst_prc) { + if (src_prc_ != dst_prc_) { + switch (dst_prc_) { case Precision::FP32: - if ((src_prc != Precision::FP32) && (src_prc != Precision::BF16)) + if ((src_prc_ != Precision::FP32) && (src_prc_ != Precision::BF16)) h->uni_vcvtdq2ps(Vmm(out_vec_idx), Vmm(out_vec_idx)); break; case Precision::I32: - if ((src_prc == Precision::FP32) || (src_prc == Precision::BF16)) { + if ((src_prc_ == Precision::FP32) || (src_prc_ == Precision::BF16)) { h->uni_vcvtps2dq(Vmm(out_vec_idx), Vmm(out_vec_idx)); } break; @@ -129,8 +159,7 @@ void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, int offset_byte, In * */ template -void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size, - bool is_fill, std::string fill_value) const { +void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -141,12 +170,12 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o // Ensure data fits completely inside the Xmm/Ymm/Zmm register if (load_size < 0 || load_size > 64) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_byte."; // check if proper number bytes fit inside the Xmm/Ymm register if (is_ymm && load_size > 32) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_byte."; if (is_xmm && load_size > 16) - IE_THROW() 
<< "Load emitter in " << name << " has unexpected number of values to load to xmm in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_byte."; auto xmm = Xbyak::Xmm(vmm.getIdx()); auto ymm = Xbyak::Ymm(vmm.getIdx()); @@ -229,7 +258,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o break; case 16: break; default: - IE_THROW() << "Load emitter in " << name<< " has unexpected number of values to load in load_byte."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_byte."; } if (has_xmm_block) { @@ -270,8 +299,8 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o } } - if (is_fill) - fill_with_default(vmm, fill_value, load_size / 4); + if (is_fill_) + fill_with_default(vmm, fill_value_, load_size / 4); } /** @@ -294,8 +323,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o */ template -void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, - int offset, bool is_signed, int load_size, bool is_fill, std::string fill_value) const { +void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -308,11 +336,11 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak // For Ymm register, load capacity is halved (32 * load_size <= 256) // For Xmm register, load capacity is halved further (32 * load_size <= 128) if (load_size < 0 || load_size > 16) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_bytes_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_bytes_to_dword_extension."; if (is_ymm && load_size > 8) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_bytes_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_bytes_to_dword_extension."; if (is_xmm && load_size > 4) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_bytes_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_bytes_to_dword_extension."; // For load_size == 4/8/16, do load/extension in one go switch (load_size) { @@ -365,8 +393,8 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak } } - if (is_fill) - fill_with_default(vmm, fill_value, load_size); + if (is_fill_) + fill_with_default(vmm, fill_value_, load_size); } /** @@ -388,8 +416,7 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak * [0.. 32] for ZMM version of the function. i.e. 
16 words -> 16 * 32 bit == 512 bit */ template -void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, - int offset, bool is_bf16, bool is_signed, int load_size, bool is_fill, std::string fill_value) const { +void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -402,11 +429,11 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak // For Ymm register, load capacity is halved (16/2(num) * 32 <= 128) // For Xmm register, load capacity is halved again (8/2(num) * 32 <= 128) if (load_size < 0 || load_size > 32) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load in load_words_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load in load_words_to_dword_extension."; if (is_ymm && load_size > 16) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to ymm in load_words_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to ymm in load_words_to_dword_extension."; if (is_xmm && load_size > 8) - IE_THROW() << "Load emitter in " << name << " has unexpected number of values to load to xmm in load_words_to_dword_extension."; + IE_THROW() << "Load emitter in " << name_ << " has unexpected number of values to load to xmm in load_words_to_dword_extension."; auto xmm = Xbyak::Xmm(vmm.getIdx()); auto ymm = Xbyak::Ymm(vmm.getIdx()); @@ -483,148 +510,157 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak } } - if (is_fill) - fill_with_default(vmm, fill_value, load_size / 2); + if (is_fill_) + fill_with_default(vmm, fill_value_, load_size / 2); } template - void jit_load_emitter::fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const { - constexpr bool is_xmm = std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - if (is_xmm || is_ymm) { - uint8 imm = 1; - imm = ~((imm << load_num) - imm); // shift load_num bit - h->uni_vblendps(vmm, vmm, table_val(fill_value), imm); - } else if (is_zmm) { - uint64_t tail_mask = 1; - tail_mask = ~((tail_mask << load_num) - tail_mask); - h->mov(Reg64(aux_gpr_idxs[0]), tail_mask); - h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); - h->vblendmps(vmm | k_mask, vmm, table_val(fill_value)); - } +void jit_load_emitter::fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; + + if (is_xmm || is_ymm) { + uint8 imm = 1; + imm = ~((imm << load_num) - imm); // shift load_num bit + h->uni_vblendps(vmm, vmm, table_val(fill_value), imm); + } else if (is_zmm) { + uint64_t tail_mask = 1; + tail_mask = ~((tail_mask << load_num) - tail_mask); + h->mov(Reg64(aux_gpr_idxs[0]), tail_mask); + h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); + h->vblendmps(vmm | k_mask, vmm, table_val(fill_value)); } +} void jit_load_emitter::register_table_entries() { - push_arg_entry_of("zero", 0x00000000, true); - push_arg_entry_of("int_one", 0x00000001, true); - push_arg_entry_of("float_one", 0x3f800000, true); - push_arg_entry_of("int32_min", 
0xcf000000, true); - push_arg_entry_of("float_min", 0xff7fffff, true); - push_arg_entry_of("int32_max", 0x4effffff, true); - push_arg_entry_of("float_max", 0x7f7fffff, true); + if (is_fill_) { + push_arg_entry_of("zero", 0x00000000, true); + push_arg_entry_of("int_one", 0x00000001, true); + push_arg_entry_of("float_one", 0x3f800000, true); + push_arg_entry_of("int32_min", 0xcf000000, true); + push_arg_entry_of("float_min", 0xff7fffff, true); + push_arg_entry_of("int32_max", 0x4effffff, true); + push_arg_entry_of("float_max", 0x7f7fffff, true); + } } /// STORE /// -jit_store_emitter::jit_store_emitter(jit_generator *host, cpu_isa_t host_isa, - Precision exec_prc, emitter_in_out_map in_out_type) -: jit_emitter(host, host_isa, exec_prc, in_out_type), name("unknown") { - v_len_elt = get_vec_length() / exec_prc.size(); +jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + Precision src_prc, Precision dst_prc, int store_num, Precision exec_prc, emitter_in_out_map in_out_type) +: jit_emitter(host, host_isa, exec_prc, in_out_type), store_num_(store_num), src_prc_(src_prc), dst_prc_(dst_prc), name_("unknown") { + v_len_elt_ = get_vec_length() / exec_prc.size(); + store_size_ = store_num * dst_prc.size(); if (!mayiuse(cpu::x64::avx512_core_bf16) && mayiuse(cpu::x64::avx512_core)) { - emu_vcvtneps2bf16.reset(new jit_emu_vcvtneps2bf16(host, host_isa)); + emu_vcvtneps2bf16_.reset(new jit_emu_vcvtneps2bf16(host, host_isa)); } } -// 0 for temp reg for mask store +// 0 for temp reg for mask store for avx512 size_t jit_store_emitter::aux_gprs_count() const { - return 1; + return get_aux_regs_for_avx512_mask(store_num_ * src_prc_.size()); } -// zero value, zeroed and passed from caller from performance standpoint(zeroed one time and not need preserve and restore status) size_t jit_store_emitter::aux_vecs_count() const { - return 1; + int count = 0; + + // to avoid src vmm pollution after data type conversion + if ((src_prc_.is_float() && !dst_prc_.is_float()) || + (!src_prc_.is_float() && dst_prc_.is_float()) || + (src_prc_ == Precision::FP32 && dst_prc_ == Precision::BF16)) + count++; + + // zero value, zeroed and passed from caller from performance standpoint(zeroed one time and not need preserve and restore status) + if (mayiuse(cpu::x64::avx512_core) && one_of(dst_prc_, Precision::U8, Precision::U16)) + count++; + + return count; } size_t jit_store_emitter::get_inputs_num() const { return 1; } void jit_store_emitter::emit_data() const { - if (emu_vcvtneps2bf16) - emu_vcvtneps2bf16->emit_data(); + if (emu_vcvtneps2bf16_) + emu_vcvtneps2bf16_->emit_data(); } void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, const emitter_context *emit_context) const { - const auto* store_emitter_context = dynamic_cast(emit_context); - if (store_emitter_context == nullptr) { - IE_THROW() << "Store emitter in " << name << " does not get store emmiter context."; - } + const int offset = in_idxs.size() == 2 ? 
in_idxs[1] : 0; if (host_isa_ == cpu::x64::sse41) { - emit_isa(static_cast(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]), - store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_); + emit_isa(static_cast(in_idxs[0]), Reg64(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx2) { - emit_isa(static_cast(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]), - store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_); + emit_isa(static_cast(in_idxs[0]), Reg64(out_idxs[0]), offset); } else if (host_isa_ == cpu::x64::avx512_core) { - emit_isa(static_cast(in_idxs[0]), store_emitter_context->src_prc_, Reg64(out_idxs[0]), - store_emitter_context->offset_byte_, store_emitter_context->dst_prc_, store_emitter_context->store_num_); + emit_isa(static_cast(in_idxs[0]), Reg64(out_idxs[0]), offset); } else { - IE_THROW() << "Store emitter in " << name << " is performed on unsupported isa(at least x64::sse41)."; + IE_THROW() << "Store emitter in " << name_ << " is performed on unsupported isa(at least x64::sse41)."; } } template - void jit_store_emitter::emit_isa(const int in_vec_idx, InferenceEngine::Precision src_prc, - const Xbyak::Reg64 ®_dst, int offset_byte, InferenceEngine::Precision dst_prc, int store_num) const { - bool matched_prc = (src_prc == dst_prc) || (src_prc == Precision::FP32) || (src_prc == Precision::I32); - if (!matched_prc) { - IE_THROW() << "Store emitter in " << name << " only support input precision of FP32 or I32 or the same precision as output."; - } - if ((src_prc == Precision::FP32) || (src_prc == Precision::I32)) { - if ((isa == cpu::x64::sse41 && store_num > 4) || (isa == cpu::x64::avx2 && store_num > 8) || - (isa == cpu::x64::avx512_core && store_num > 16) || store_num < 0) { - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store."; - } +void jit_store_emitter::emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_dst, const int offset) const { + bool matched_prc = (src_prc_ == dst_prc_) || (src_prc_ == Precision::FP32) || (src_prc_ == Precision::I32); + if (!matched_prc) { + IE_THROW() << "Store emitter in " << name_ << " only support input precision of FP32 or I32 or the same precision as output."; + } + if ((src_prc_ == Precision::FP32) || (src_prc_ == Precision::I32)) { + if ((isa == cpu::x64::sse41 && store_num_ > 4) || (isa == cpu::x64::avx2 && store_num_ > 8) || + (isa == cpu::x64::avx512_core && store_num_ > 16) || store_num_ < 0) { + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store."; } + } + using Vmm = typename conditional3::type; - using Vmm = typename conditional3::type; - - if (src_prc != dst_prc) { - switch (src_prc) { - case Precision::FP32: - if ((dst_prc != Precision::FP32) && (dst_prc != Precision::BF16)) { - h->uni_vcvtps2dq(Vmm(in_vec_idx), Vmm(in_vec_idx)); - } - break; - case Precision::I32: - if ((dst_prc == Precision::FP32) || (dst_prc == Precision::BF16)) - h->uni_vcvtdq2ps(Vmm(in_vec_idx), Vmm(in_vec_idx)); - break; - default: - break; - } + int data_idx = in_vec_idx; + if (src_prc_ != dst_prc_) { + switch (src_prc_) { + case Precision::FP32: + if ((dst_prc_ != Precision::FP32) && (dst_prc_ != Precision::BF16)) { + h->uni_vcvtps2dq(Vmm(aux_vec_idxs.back()), Vmm(data_idx)); + data_idx = aux_vec_idxs.back(); + } + break; + case Precision::I32: + if ((dst_prc_ == Precision::FP32) || (dst_prc_ == Precision::BF16)) { + 
h->uni_vcvtdq2ps(Vmm(aux_vec_idxs.back()), Vmm(data_idx)); + data_idx = aux_vec_idxs.back(); + } + break; + default: + break; } + } - if (src_prc == dst_prc) { - store_bytes(Vmm(in_vec_idx), reg_dst, offset_byte, store_num * dst_prc.size()); - } else { - switch (dst_prc) { - case Precision::FP32: - case Precision::I32: - store_bytes(Vmm(in_vec_idx), reg_dst, offset_byte, store_num * dst_prc.size()); - break; - case Precision::I8: - store_dword_to_byte_extension(Vmm(in_vec_idx), reg_dst, offset_byte, true, store_num); - break; - case Precision::U8: - store_dword_to_byte_extension(Vmm(in_vec_idx), reg_dst, offset_byte, false, store_num); - break; - case Precision::I16: - store_dword_to_word_extension(Vmm(in_vec_idx), reg_dst, offset_byte, false, true, store_num); - break; - case Precision::U16: - store_dword_to_word_extension(Vmm(in_vec_idx), reg_dst, offset_byte, false, false, store_num); - break; - case Precision::BF16: - store_dword_to_word_extension(Vmm(in_vec_idx), reg_dst, offset_byte, true, false, store_num); - break; - default: - IE_THROW() << "Store emitter in " << name << " has unsupported dst precision to store."; - } + if (src_prc_ == dst_prc_) { + store_bytes(Vmm(data_idx), reg_dst, offset, store_size_); + } else { + switch (dst_prc_) { + case Precision::FP32: + case Precision::I32: + store_bytes(Vmm(data_idx), reg_dst, offset, store_size_); + break; + case Precision::I8: + store_dword_to_byte_extension(Vmm(data_idx), reg_dst, offset, true, store_num_); + break; + case Precision::U8: + store_dword_to_byte_extension(Vmm(data_idx), reg_dst, offset, false, store_num_); + break; + case Precision::I16: + store_dword_to_word_extension(Vmm(data_idx), reg_dst, offset, false, true, store_num_); + break; + case Precision::U16: + store_dword_to_word_extension(Vmm(data_idx), reg_dst, offset, false, false, store_num_); + break; + case Precision::BF16: + store_dword_to_word_extension(Vmm(data_idx), reg_dst, offset, true, false, store_num_); + break; + default: + IE_THROW() << "Store emitter in " << name_ << " has unsupported dst precision to store."; } } - +} /** * store_bytes is the utility function to facilitate storing of * store_size (0 <= store_size <= 64) many contiguous bytes from the Xmm/Ymm/Zmm @@ -641,130 +677,130 @@ template * */ template - void jit_store_emitter::store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const { - constexpr bool is_xmm = std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - MAYBE_UNUSED(is_xmm); - MAYBE_UNUSED(is_ymm); - MAYBE_UNUSED(is_zmm); - - // Ensure data fits completely inside the Xmm/Ymm/Zmm register - if (store_size < 0 || store_size > 64) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_bytes."; - if (is_ymm && store_size > 32) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_bytes."; - if (is_xmm && store_size > 16) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_bytes."; - - auto xmm = Xbyak::Xmm(vmm.getIdx()); - auto ymm = Xbyak::Ymm(vmm.getIdx()); - auto zmm = Xbyak::Zmm(vmm.getIdx()); - - const auto addr = [&](int bytes_offset) { - return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; - }; - - auto store_byte_base = [&]() { - int start_bytes = 0; - int bytes_to_store = store_size; - - if (store_size > 32) { - h->uni_vmovdqu(addr(0), ymm); // store lower bits from zmm - start_bytes += 
32; - bytes_to_store -= 32; - h->vextractf64x4(ymm, zmm, 1); // load upper bits from zmm into ymm - } +void jit_store_emitter::store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; - if (bytes_to_store > 16) { - h->uni_vmovdqu(addr(start_bytes), xmm); // store lower bits from ymm - start_bytes += 16; - bytes_to_store -= 16; - h->vextractf128(xmm, ymm, 1); // load upper bits from ymm into xmm - } + MAYBE_UNUSED(is_xmm); + MAYBE_UNUSED(is_ymm); + MAYBE_UNUSED(is_zmm); - if (bytes_to_store >= 8 && bytes_to_store < 16) - h->uni_vmovq(addr(start_bytes), xmm); - // h->pextrq(addr(start_bytes), xmm, 0); - else if (bytes_to_store == 16) - h->uni_vmovdqu(addr(start_bytes), xmm); - - // 64/32/16/8 with one go - // tail 7 bytes for lower or upper xmm - switch (bytes_to_store) { - case 0: break; - case 1: h->uni_vpextrb(addr(start_bytes), xmm, 0); break; - case 2: h->uni_vpextrw(addr(start_bytes), xmm, 0); break; - case 3: - h->uni_vpextrw(addr(start_bytes), xmm, 0); - h->uni_vpextrb(addr(start_bytes + 2), xmm, 2); - break; - case 4: h->uni_vmovss(addr(start_bytes), xmm); break; - // h->uni_vpextrd(addr(start_bytes), xmm, 0); break; - case 5: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrb(addr(start_bytes + 4), xmm, 4); - break; - case 6: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); - break; - case 7: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); - h->uni_vpextrb(addr(start_bytes + 6), xmm, 6); - break; - case 8: break; - case 9: h->uni_vpextrb(addr(start_bytes + 8), xmm, 8); break; - case 10: h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); break; - case 11: - h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); - h->uni_vpextrb(addr(start_bytes + 10), xmm, 10); - break; - case 12: h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); break; - case 13: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrb(addr(start_bytes + 12), xmm, 12); - break; - case 14: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); - break; - case 15: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); - h->uni_vpextrb(addr(start_bytes + 14), xmm, 14); - break; - case 16: break; - default: - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_bytes."; - } - }; + // Ensure data fits completely inside the Xmm/Ymm/Zmm register + if (store_size < 0 || store_size > 64) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_bytes."; + if (is_ymm && store_size > 32) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_bytes."; + if (is_xmm && store_size > 16) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_bytes."; + + auto xmm = Xbyak::Xmm(vmm.getIdx()); + auto ymm = Xbyak::Ymm(vmm.getIdx()); + auto zmm = Xbyak::Zmm(vmm.getIdx()); - switch (store_size) { - case 64: - h->uni_vmovdqu(addr(0), zmm); + const auto addr = [&](int bytes_offset) { + return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; + }; + + auto store_byte_base = [&]() { + int start_bytes = 0; + int bytes_to_store = store_size; + + if (store_size > 32) { + h->uni_vmovdqu(addr(0), ymm); // store lower bits 
from zmm + start_bytes += 32; + bytes_to_store -= 32; + h->vextractf64x4(ymm, zmm, 1); // load upper bits from zmm into ymm + } + + if (bytes_to_store > 16) { + h->uni_vmovdqu(addr(start_bytes), xmm); // store lower bits from ymm + start_bytes += 16; + bytes_to_store -= 16; + h->vextractf128(xmm, ymm, 1); // load upper bits from ymm into xmm + } + + if (bytes_to_store >= 8 && bytes_to_store < 16) + h->uni_vmovq(addr(start_bytes), xmm); + // h->pextrq(addr(start_bytes), xmm, 0); + else if (bytes_to_store == 16) + h->uni_vmovdqu(addr(start_bytes), xmm); + + // 64/32/16/8 with one go + // tail 7 bytes for lower or upper xmm + switch (bytes_to_store) { + case 0: break; + case 1: h->uni_vpextrb(addr(start_bytes), xmm, 0); break; + case 2: h->uni_vpextrw(addr(start_bytes), xmm, 0); break; + case 3: + h->uni_vpextrw(addr(start_bytes), xmm, 0); + h->uni_vpextrb(addr(start_bytes + 2), xmm, 2); break; - case 32: - h->uni_vmovdqu(addr(0), ymm); + case 4: h->uni_vmovss(addr(start_bytes), xmm); break; + // h->uni_vpextrd(addr(start_bytes), xmm, 0); break; + case 5: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrb(addr(start_bytes + 4), xmm, 4); break; - case 16: - h->uni_vmovdqu(addr(0), xmm); + case 6: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); break; - default: - if (mayiuse(cpu::x64::avx512_core)) { - uint64_t mask = 1; - mask = (mask << store_size) - mask; - h->mov(Reg64(aux_gpr_idxs[0]), mask); - h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); - h->vmovdqu8(addr(0), zmm | k_mask); - } else { - store_byte_base(); - } + case 7: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); + h->uni_vpextrb(addr(start_bytes + 6), xmm, 6); + break; + case 8: break; + case 9: h->uni_vpextrb(addr(start_bytes + 8), xmm, 8); break; + case 10: h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); break; + case 11: + h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); + h->uni_vpextrb(addr(start_bytes + 10), xmm, 10); + break; + case 12: h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); break; + case 13: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrb(addr(start_bytes + 12), xmm, 12); + break; + case 14: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); + break; + case 15: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); + h->uni_vpextrb(addr(start_bytes + 14), xmm, 14); break; + case 16: break; + default: + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_bytes."; } + }; + + switch (store_size) { + case 64: + h->uni_vmovdqu(addr(0), zmm); + break; + case 32: + h->uni_vmovdqu(addr(0), ymm); + break; + case 16: + h->uni_vmovdqu(addr(0), xmm); + break; + default: + if (mayiuse(cpu::x64::avx512_core)) { + uint64_t mask = 1; + mask = (mask << store_size) - mask; + h->mov(Reg64(aux_gpr_idxs[0]), mask); + h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); + h->vmovdqu8(addr(0), zmm | k_mask); + } else { + store_byte_base(); + } + break; } +} /** * store_dword_to_byte_extension is the utility function to @@ -772,231 +808,235 @@ template * 2. store the packed byte into the memory referenced by ptr[reg + offset] address. 
*/ template - void jit_store_emitter::store_dword_to_byte_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int store_num) const { - constexpr bool is_xmm = std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - MAYBE_UNUSED(is_xmm); - MAYBE_UNUSED(is_ymm); - MAYBE_UNUSED(is_zmm); - - // Ensure data fits completely inside the Xmm/Ymm/Zmm register - // At most 8 dwords can fit inside the Ymm register - // At most 4 dwords can fit inside the Xmm register - if (store_num < 0 || store_num > 16) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_dword_to_byte_extension."; - if (is_ymm && store_num > 8) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_dword_to_byte_extension."; - if (is_xmm && store_num > 4) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_dword_to_byte_extension."; - - auto ymm = Xbyak::Ymm(vmm.getIdx()); - auto xmm = Xbyak::Xmm(vmm.getIdx()); - - const auto addr = [&](int bytes_offset) { - return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; - }; - - auto store_dword_to_byte_base = [&]() { - // db only available on avx512, need dw+wb to emulate - if (is_signed) - h->uni_vpackssdw(vmm, vmm, vmm); - else - h->uni_vpackusdw(vmm, vmm, vmm); - // gather 2(cross lane) 64 bits into lower vmm to store - // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] - if (is_ymm) { - h->vpermq(ymm, ymm, 0x08); // 00001000 +void jit_store_emitter::store_dword_to_byte_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int store_num) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; + + MAYBE_UNUSED(is_xmm); + MAYBE_UNUSED(is_ymm); + MAYBE_UNUSED(is_zmm); + + // Ensure data fits completely inside the Xmm/Ymm/Zmm register + // At most 8 dwords can fit inside the Ymm register + // At most 4 dwords can fit inside the Xmm register + if (store_num < 0 || store_num > 16) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_dword_to_byte_extension."; + if (is_ymm && store_num > 8) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_dword_to_byte_extension."; + if (is_xmm && store_num > 4) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_dword_to_byte_extension."; + + auto ymm = Xbyak::Ymm(vmm.getIdx()); + auto xmm = Xbyak::Xmm(vmm.getIdx()); + + const auto addr = [&](int bytes_offset) { + return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; + }; + + auto store_dword_to_byte_base = [&]() { + // db only available on avx512, need dw+wb to emulate + if (is_signed) + h->uni_vpackssdw(vmm, vmm, vmm); + else + h->uni_vpackusdw(vmm, vmm, vmm); + // gather 2(cross lane) 64 bits into lower vmm to store + // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] + if (is_ymm) { + h->vpermq(ymm, ymm, 0x08); // 00001000 + } + + if (is_signed) + h->uni_vpacksswb(vmm, vmm, vmm); + else + h->uni_vpackuswb(vmm, vmm, vmm); + + store_bytes(vmm, reg, offset, store_num); + }; + + switch (store_num) { + case 16: + // must support avx512F + if (is_signed) { + h->vpmovsdb(addr(0), vmm); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(vmm, vmm, zero); + h->vpmovusdb(addr(0), 
vmm); + } + break; + case 8: + if (mayiuse(cpu::x64::avx512_core)) { // ymm block on avx512F + VL + if (is_signed) { + h->vpmovsdb(addr(0), ymm); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(ymm, ymm, zero); + h->vpmovusdb(addr(0), ymm); + } + } else { + store_dword_to_byte_base(); + } + break; + case 4: + if (mayiuse(cpu::x64::avx512_core)) { // xmm block on avx512F + VL + if (is_signed) { + h->vpmovsdb(addr(0), xmm); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(xmm, xmm, zero); + h->vpmovusdb(addr(0), xmm); + } + } else { + store_dword_to_byte_base(); } + break; + default: + if (is_zmm) { // avx512F + unsigned int mask = 1; + mask = (mask << store_num) - mask; + h->mov(Reg32(aux_gpr_idxs[0]), mask); + h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); + if (is_signed) { + h->vpmovsdb(addr(0), vmm | k_mask); + } else { + Vmm zero(aux_vec_idxs[0]); + h->uni_vpxor(zero, zero, zero); + h->uni_vpmaxsd(vmm, vmm, zero); + h->vpmovusdb(addr(0), vmm | k_mask); + } + } else { + store_dword_to_byte_base(); + } + break; + } +} - if (is_signed) - h->uni_vpacksswb(vmm, vmm, vmm); - else - h->uni_vpackuswb(vmm, vmm, vmm); +/** +* store_dword_to_word_extension is the utility function to +* 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num words with singed or unsinged saturation. +* 2. store the packed words into the memory referenced by ptr[reg + offset] address. +*/ +template +void jit_store_emitter::store_dword_to_word_extension(const Vmm &vmm, const Xbyak::Reg64 ®, + int offset, bool is_bf16, bool is_signed, int store_num) const { + constexpr bool is_xmm = std::is_same::value; + constexpr bool is_ymm = std::is_same::value; + constexpr bool is_zmm = std::is_same::value; - store_bytes(vmm, reg, offset, store_num); - }; + MAYBE_UNUSED(is_xmm); + MAYBE_UNUSED(is_ymm); + MAYBE_UNUSED(is_zmm); + + // Ensure data fits completely inside the Xmm/Ymm/Zmm register + // At most 4 dwords can fit inside the Xmm register + // At most 8 dwords can fit inside the Ymm register + if (store_num < 0 || store_num > 16) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store in store_dword_to_word_extension."; + if (is_ymm && store_num > 8) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to ymm in store_dword_to_word_extension."; + if (is_xmm && store_num > 4) + IE_THROW() << "Store emitter in " << name_ << " has unexpected number of values to store to xmm in store_dword_to_word_extension."; + auto xmm = Xbyak::Xmm(vmm.getIdx()); + auto ymm = Xbyak::Ymm(vmm.getIdx()); + auto zmm = Xbyak::Zmm(vmm.getIdx()); + + auto store_dword_to_word_base = [&]() { + // direct mov_dw available only on avx512, emulate with pack_dw + permute + pure store + if (is_signed) + h->uni_vpackssdw(vmm, vmm, vmm); + else + h->uni_vpackusdw(vmm, vmm, vmm); + // gather 2/4(cross lane) 64 bits into lower vmm to store + // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] + // [ 128 | 128 ] |--> [ 128 | 128 ] + if (is_ymm) { + h->vpermq(ymm, ymm, 0x08); // 00001000 + } + + store_bytes(vmm, reg, offset, store_num * 2); + }; + + if (is_bf16) { + // to avoid src vmm pollution + if (src_prc_ == Precision::FP32) { + ymm = Ymm(aux_vec_idxs[0]); + } + if (mayiuse(cpu::x64::avx512_core_bf16)) { + h->vcvtneps2bf16(ymm, zmm); + } else { + emu_vcvtneps2bf16_->emit_code({static_cast(vmm.getIdx())}, {static_cast(ymm.getIdx())}); + } + if (store_num == 16) { + h->vmovdqu16(ptr[reg + 
offset], ymm); + } else { + store_bytes(ymm, reg, offset, store_num * 2); + } + } else { switch (store_num) { case 16: - // must support avx512F if (is_signed) { - h->vpmovsdb(addr(0), vmm); + h->vpmovsdw(ptr[reg + offset], vmm); // singed int32 saturate to signed int16. } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, vmm, zero); - h->vpmovusdb(addr(0), vmm); + h->uni_vpmaxsd(vmm, zero, vmm); // if singed bit is 1, set value as 0. + h->vpmovusdw(ptr[reg + offset], vmm); // unsinged int32 saturate to unsigned int16. } break; case 8: - if (mayiuse(cpu::x64::avx512_core)) { // ymm block on avx512F + VL + if (mayiuse(cpu::x64::avx512_core)) { if (is_signed) { - h->vpmovsdb(addr(0), ymm); + h->vpmovsdw(ptr[reg + offset], ymm); } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(ymm, ymm, zero); - h->vpmovusdb(addr(0), ymm); + h->uni_vpmaxsd(ymm, zero, ymm); + h->vpmovusdw(ptr[reg + offset], ymm); } } else { - store_dword_to_byte_base(); + store_dword_to_word_base(); } break; case 4: - if (mayiuse(cpu::x64::avx512_core)) { // xmm block on avx512F + VL + if (mayiuse(cpu::x64::avx512_core)) { if (is_signed) { - h->vpmovsdb(addr(0), xmm); + h->vpmovsdw(ptr[reg + offset], xmm); } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(xmm, xmm, zero); - h->vpmovusdb(addr(0), xmm); + h->uni_vpmaxsd(xmm, zero, xmm); + h->vpmovusdw(ptr[reg + offset], xmm); } } else { - store_dword_to_byte_base(); + store_dword_to_word_base(); } break; default: - if (is_zmm) { // avx512F + if (is_zmm) { unsigned int mask = 1; mask = (mask << store_num) - mask; h->mov(Reg32(aux_gpr_idxs[0]), mask); h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); if (is_signed) { - h->vpmovsdb(addr(0), vmm | k_mask); + h->vpmovsdw(ptr[reg + offset], vmm | k_mask); } else { Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, vmm, zero); - h->vpmovusdb(addr(0), vmm | k_mask); + h->uni_vpmaxsd(vmm, zero, vmm); + h->vpmovusdw(ptr[reg + offset], vmm | k_mask); } } else { - store_dword_to_byte_base(); + store_dword_to_word_base(); } break; } } - -/** -* store_dword_to_word_extension is the utility function to -* 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num words with singed or unsinged saturation. -* 2. store the packed words into the memory referenced by ptr[reg + offset] address. 
-*/ -template - void jit_store_emitter::store_dword_to_word_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, - bool is_bf16, bool is_signed, int store_num) const { - constexpr bool is_xmm = std::is_same::value; - constexpr bool is_ymm = std::is_same::value; - constexpr bool is_zmm = std::is_same::value; - - MAYBE_UNUSED(is_xmm); - MAYBE_UNUSED(is_ymm); - MAYBE_UNUSED(is_zmm); - - // Ensure data fits completely inside the Xmm/Ymm/Zmm register - // At most 4 dwords can fit inside the Xmm register - // At most 8 dwords can fit inside the Ymm register - if (store_num < 0 || store_num > 16) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store in store_dword_to_word_extension."; - if (is_ymm && store_num > 8) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to ymm in store_dword_to_word_extension."; - if (is_xmm && store_num > 4) - IE_THROW() << "Store emitter in " << name << " has unexpected number of values to store to xmm in store_dword_to_word_extension."; - - auto xmm = Xbyak::Xmm(vmm.getIdx()); - auto ymm = Xbyak::Ymm(vmm.getIdx()); - auto zmm = Xbyak::Zmm(vmm.getIdx()); - - auto store_dword_to_word_base = [&]() { - // direct mov_dw available only on avx512, emulate with pack_dw + permute + pure store - if (is_signed) - h->uni_vpackssdw(vmm, vmm, vmm); - else - h->uni_vpackusdw(vmm, vmm, vmm); - // gather 2/4(cross lane) 64 bits into lower vmm to store - // [y_3 y_2 y_1 y_0] |--> [y_0 y_0 y_2 y_0] - // [ 128 | 128 ] |--> [ 128 | 128 ] - if (is_ymm) { - h->vpermq(ymm, ymm, 0x08); // 00001000 - } - - store_bytes(vmm, reg, offset, store_num * 2); - }; - - if (is_bf16) { - if (mayiuse(cpu::x64::avx512_core_bf16)) { - h->vcvtneps2bf16(ymm, zmm); - } else { - emu_vcvtneps2bf16->emit_code({static_cast(vmm.getIdx())}, {static_cast(ymm.getIdx())}); - } - if (store_num == 16) { - h->vmovdqu16(ptr[reg + offset], ymm); - } else { - store_bytes(ymm, reg, offset, store_num * 2); - } - } else { - switch (store_num) { - case 16: - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], vmm); // singed int32 saturate to signed int16. - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, zero, vmm); // if singed bit is 1, set value as 0. - h->vpmovusdw(ptr[reg + offset], vmm); // unsinged int32 saturate to unsigned int16. 
- } - break; - case 8: - if (mayiuse(cpu::x64::avx512_core)) { - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], ymm); - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(ymm, zero, ymm); - h->vpmovusdw(ptr[reg + offset], ymm); - } - } else { - store_dword_to_word_base(); - } - break; - case 4: - if (mayiuse(cpu::x64::avx512_core)) { - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], xmm); - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(xmm, zero, xmm); - h->vpmovusdw(ptr[reg + offset], xmm); - } - } else { - store_dword_to_word_base(); - } - break; - default: - if (is_zmm) { - unsigned int mask = 1; - mask = (mask << store_num) - mask; - h->mov(Reg32(aux_gpr_idxs[0]), mask); - h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); - if (is_signed) { - h->vpmovsdw(ptr[reg + offset], vmm | k_mask); - } else { - Vmm zero(aux_vec_idxs[0]); - h->uni_vpxor(zero, zero, zero); - h->uni_vpmaxsd(vmm, zero, vmm); - h->vpmovusdw(ptr[reg + offset], vmm | k_mask); - } - } else { - store_dword_to_word_base(); - } - break; - } - } - } +} } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp index 2abcf8e0ca6bd4..3784a343d3fbe2 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_load_store_emitters.hpp @@ -15,40 +15,37 @@ using namespace InferenceEngine; namespace ov { namespace intel_cpu { -struct load_emitter_context : public emitter_context { - load_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), load_num_(8), - offset_byte_(0), is_fill_(false), fill_value_("zero") {} +struct load_emitter_params : public emitter_params { + load_emitter_params(Precision src_prc, Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero"): + src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value) {} - load_emitter_context(Precision src_prc, Precision dst_prc, int load_num, int offset_byte = 0, bool is_fill = false, std::string fill_value = "zero"): - src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), offset_byte_(offset_byte), is_fill_(is_fill), fill_value_(fill_value) {} + size_t hash() const override; - int offset_byte_; - int load_num_; Precision src_prc_; Precision dst_prc_; + int load_num_; bool is_fill_; std::string fill_value_; }; -struct store_emitter_context : public emitter_context { - store_emitter_context() : src_prc_(Precision::FP32), dst_prc_(Precision::FP32), - store_num_(8), offset_byte_(0) {} +struct store_emitter_params : public emitter_params { + store_emitter_params(Precision src_prc, Precision dst_prc, int store_num): + src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num) {} - store_emitter_context(Precision src_prc, Precision dst_prc, int store_num, int offset_byte = 0) - : src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num), offset_byte_(offset_byte) {} + size_t hash() const override; - int offset_byte_; - int store_num_; Precision src_prc_; Precision dst_prc_; + int store_num_; }; class jit_load_emitter : public jit_emitter { public: - jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec); + 
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, Precision src_prc, Precision dst_prc, int load_num, + Precision exec_prc = Precision::FP32, bool is_fill = false, std::string fill_value = "zero", + emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec); /** - * load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc. + * load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc, where offset_byte is in_idxs[1] * is_fill: when load_num can not fully fit in vector register, whether fill_value should be filled as default values. * fill_value: when load_num can not fully fit in vector register, what values should be filled as default values. * currently support "zero", "int_one", "float_one", "int32_min", "float_min", "int32_max" and "float_max". @@ -66,27 +63,23 @@ class jit_load_emitter : public jit_emitter { * dst_prc */ void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, - const emitter_context *emit_context) const override; + const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, + const emitter_context *emit_context) const override; size_t get_inputs_num() const override; private: template - void emit_isa(const Xbyak::Reg64 ®_src, int offset_byte, InferenceEngine::Precision src_prc, - const int out_vec_idx, InferenceEngine::Precision dst_prc, int load_num, bool is_fill = false, std::string fill_value = "zero") const; + void emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_idx, const int offset) const; template - void load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size, - bool is_fill = false, std::string fill_value = "zero") const; + void load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size) const; template - void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size, - bool is_fill = false, std::string fill_value = "zero") const; + void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size) const; template - void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size, - bool is_fill = false, std::string fill_value = "zero") const; + void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_bf16, bool is_signed, int load_size) const; template void fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const; @@ -95,17 +88,23 @@ class jit_load_emitter : public jit_emitter { size_t aux_gprs_count() const override; - std::string name; - int v_len_elt; // 4/8/16 + std::string name_; + int v_len_elt_; // 4/8/16 + int load_num_; + int load_size_; + Precision src_prc_; + Precision dst_prc_; + bool is_fill_; + std::string fill_value_; }; class jit_store_emitter : public jit_emitter { public: - jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr); + jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, Precision src_prc, Precision dst_prc, int store_num, + Precision 
exec_prc = Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr); /** - * store_num values with src_prc in Vmm[in_vec_idx] is stored to ptr[reg_dst + offset_byte] address as dst_prc data. + * store_num values with src_prc in Vmm[in_vec_idx] is stored to ptr[reg_dst + offset_byte] address as dst_prc data, where offset_byte is in_idxs[1] * supported src_prc and dst_prc pairs are as below(x indicate for support): * FP32 I32 I16 U16 I8 U8 BF16 --> src_prc * FP32 x x @@ -120,21 +119,20 @@ class jit_store_emitter : public jit_emitter { * note: FP32/I32-->BF16(x*) is supported only on at least avx512-core plateform */ void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, - const emitter_context *emit_context) const override; + const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs, + const emitter_context *emit_context) const override; size_t get_inputs_num() const override; void emit_data() const override; std::shared_ptr get_emu_vcvtneps2bf16() const { - return emu_vcvtneps2bf16; + return emu_vcvtneps2bf16_; } private: template - void emit_isa(const int in_vec_idx, InferenceEngine::Precision src_prc, - const Xbyak::Reg64 ®_dst, int offset_byte, InferenceEngine::Precision dst_prc, int store_num) const; + void emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_dst, const int offset) const; template void store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int store_size) const; @@ -148,9 +146,13 @@ class jit_store_emitter : public jit_emitter { size_t aux_gprs_count() const override; size_t aux_vecs_count() const override; - std::string name; - int v_len_elt; // 4/8/16 - std::shared_ptr emu_vcvtneps2bf16; + std::string name_; + int v_len_elt_; // 4/8/16 + int store_num_; + int store_size_; + Precision src_prc_; + Precision dst_prc_; + std::shared_ptr emu_vcvtneps2bf16_; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/nodes/interpolate.cpp b/src/plugins/intel_cpu/src/nodes/interpolate.cpp index 0c0df6ec67b4a8..7e562b1161470b 100644 --- a/src/plugins/intel_cpu/src/nodes/interpolate.cpp +++ b/src/plugins/intel_cpu/src/nodes/interpolate.cpp @@ -58,9 +58,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); - // dummy second reg_tmp_64 as no fill needed load_pool_gpr_idxs = {static_cast(reg_tmp_64.getIdx()), static_cast(reg_tmp_64.getIdx())}; store_pool_gpr_idxs = {static_cast(reg_tmp_64.getIdx())}; @@ -162,8 +159,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + emit_emitters_data(); for (auto& inj : eltwise_injectors) inj->prepare_table(); if ((jcp_.mode == InterpolateMode::cubic) && (jcp_.layout == InterpolateLayoutType::planar)) { @@ -176,6 +172,9 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi Xbyak::Ymm, Xbyak::Zmm>::type; const int vlen = cpu_isa_traits::vlen; + const int vector_step = vlen / sizeof(float); + const int tail_step = jcp_.C % vector_step; + const int scalar_step = 1; Xbyak::Reg64 reg_src = r8; Xbyak::Reg64 reg_src_aux = r15; @@ -246,8 +245,8 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi Xbyak::Label l_table_constant; Opmask k_mask = Xbyak::Opmask(1); - std::unique_ptr 
load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unordered_map> emitters; + std::vector store_pool_gpr_idxs; std::vector store_pool_vec_idxs; std::vector load_pool_gpr_idxs; @@ -256,20 +255,44 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi std::vector>> depthwise_injectors; std::vector>> quantization_injectors; - inline void load(const Xbyak::Reg64& reg_src, Vmm& vmm, const int& elt_num, const int& offset = 0) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm.getIdx())}, - std::make_shared(jcp_.src_prc, Precision::FP32, elt_num, offset), - {}, {load_pool_gpr_idxs}); + void emit_emitters_data() { + for (const auto& emitter : emitters) { + if (emitter.second) + emitter.second->emit_data(); + } + } + + inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + emit_load(reg_src, vmm_src, jcp_.src_prc, Precision::FP32, elt_num, offset); + } + + inline void load_weights(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + emit_load(reg_src, vmm_src, Precision::FP32, Precision::FP32, elt_num, offset); } - inline void store(const Vmm& vmm, const Xbyak::Reg64& reg_dst, const int& elt_num, const int& offset = 0) { - store_emitter->emit_code({static_cast(vmm.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.dst_prc, elt_num, offset), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); + + inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) { + const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num)); + } + + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)}, + {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs}); } - inline void load_weights(const Xbyak::Reg64& reg_weights, Vmm& vmm, const int& elt_num, const int& offset = 0) { - load_emitter->emit_code({static_cast(reg_weights.getIdx())}, {static_cast(vmm.getIdx())}, - std::make_shared(Precision::FP32, Precision::FP32, elt_num, offset), - {}, {load_pool_gpr_idxs}); + + inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) { + const auto seed = store_emitter_params(Precision::FP32, jcp_.dst_prc, elt_num).hash(); + if (!emitters[seed]) { + emitters[seed].reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, elt_num)); + } + + // for cases when Store emitter need 2 aux vmm we can use vmm_dst as second aux vmm + std::vector local_store_pool_vec_idxs = { static_cast(vmm_dst.getIdx()) }; + local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end()); + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx()), static_cast(offset)}, + {static_cast(reg_dst.getIdx())}, + {local_store_pool_vec_idxs}, {store_pool_gpr_idxs}); } void nn_planar() { @@ -303,7 +326,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi // reset index_w, index_w * dataSize done when built to avoid redundent compute mov(reg_index, reg_index_w); - int step = vlen / sizeof(float); Xbyak::Label nn_loop_label; Xbyak::Label nn_loop_end_label; @@ -312,7 +334,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi L(nn_loop_label); // inner loop { - cmp(reg_work_amount, step); + 
cmp(reg_work_amount, vector_step); jl(nn_loop_end_label, T_NEAR); uni_vmovdqu(vmm_index, ptr[reg_index]); @@ -320,17 +342,16 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi vgatherdps(vmm_val, ptr[reg_src_h + vmm_index], vmm_mask); if (attr_.post_ops_.len() != 0) apply_post_ops(jcp_.dst_prc, 1); - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, vector_step); - add(reg_dst, step * jcp_.dst_data_size); - add(reg_index, step * jcp_.indices_size); - sub(reg_work_amount, step); + add(reg_dst, vector_step * jcp_.dst_data_size); + add(reg_index, vector_step * jcp_.indices_size); + sub(reg_work_amount, vector_step); jmp(nn_loop_label, T_NEAR); } L(nn_loop_end_label); - step = 1; L(nn_tail_loop_label); { cmp(reg_work_amount, 1); @@ -340,14 +361,14 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_index_offset, dword[reg_index]); add(reg_src_aux, reg_index_offset); - load(reg_src_aux, vmm_val, step); + load(reg_src_aux, vmm_val, scalar_step); if (attr_.post_ops_.len() != 0) apply_post_ops(jcp_.dst_prc, 1); - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, scalar_step); - add(reg_dst, step * jcp_.dst_data_size); - add(reg_index, step * jcp_.indices_size); - sub(reg_work_amount, step); + add(reg_dst, scalar_step * jcp_.dst_data_size); + add(reg_index, scalar_step * jcp_.indices_size); + sub(reg_work_amount, scalar_step); jmp(nn_tail_loop_label, T_NEAR); } @@ -363,8 +384,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } void nn_blk() { - int step = vlen / sizeof(float); - Xbyak::Label nn_loop_label; Xbyak::Label nn_loop_end_label; L(nn_loop_label); @@ -376,22 +395,22 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_index_offset, dword[reg_index]); add(reg_src_aux, reg_index_offset); - load(reg_src_aux, vmm_val, step); + load(reg_src_aux, vmm_val, vector_step); if (attr_.post_ops_.len() != 0) apply_post_ops(jcp_.dst_prc, 0); - store(vmm_val, reg_dst, step); - add(reg_dst, step * jcp_.dst_data_size); + store(vmm_val, reg_dst, vector_step); + add(reg_dst, vector_step * jcp_.dst_data_size); if (isa == cpu::x64::sse41) { - add(reg_src_aux, step * jcp_.src_data_size); - load(reg_src_aux, vmm_val, step); + add(reg_src_aux, vector_step * jcp_.src_data_size); + load(reg_src_aux, vmm_val, vector_step); if (attr_.post_ops_.len() != 0) { - add(reg_oc_off, step * sizeof(float)); + add(reg_oc_off, vector_step * sizeof(float)); apply_post_ops(jcp_.dst_prc, 0); - sub(reg_oc_off, step * sizeof(float)); + sub(reg_oc_off, vector_step * sizeof(float)); } - store(vmm_val, reg_dst, step); - add(reg_dst, step * jcp_.dst_data_size); + store(vmm_val, reg_dst, vector_step); + add(reg_dst, vector_step * jcp_.dst_data_size); } add(reg_index, jcp_.indices_size); @@ -421,8 +440,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi cmp(reg_work_amount_out, 1); jl(out_loop_end, T_NEAR); - int step = vlen / sizeof(float); - //inner loop for C Xbyak::Label nn_loop_label; Xbyak::Label nn_loop_end_label; @@ -444,35 +461,34 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi L(nn_loop_label); { - cmp(reg_work_amount, step); + cmp(reg_work_amount, vector_step); jl(nn_loop_end_label, T_NEAR); - load(reg_src_aux, vmm_val, step); + load(reg_src_aux, vmm_val, vector_step); if (attr_.post_ops_.len() != 0) apply_post_ops(jcp_.dst_prc, 0); - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, 
vector_step); - add(reg_dst, step * jcp_.dst_data_size); - add(reg_src_aux, step * jcp_.src_data_size); - add(reg_oc_off, step * sizeof(float)); - sub(reg_work_amount, step); + add(reg_dst, vector_step * jcp_.dst_data_size); + add(reg_src_aux, vector_step * jcp_.src_data_size); + add(reg_oc_off, vector_step * sizeof(float)); + sub(reg_work_amount, vector_step); jmp(nn_loop_label, T_NEAR); } L(nn_loop_end_label); - int tail_num = jcp_.C % step; - if (tail_num != 0) { - load(reg_src_aux, vmm_val, tail_num); + if (tail_step != 0) { + load(reg_src_aux, vmm_val, tail_step); if (attr_.post_ops_.len() != 0) apply_post_ops(jcp_.dst_prc, 0); - store(vmm_val, reg_dst, tail_num); + store(vmm_val, reg_dst, tail_step); // check to remove below - add(reg_dst, tail_num * jcp_.dst_data_size); - add(reg_src_aux, tail_num * jcp_.src_data_size); - add(reg_oc_off, tail_num * sizeof(float)); - sub(reg_work_amount, tail_num); + add(reg_dst, tail_step * jcp_.dst_data_size); + add(reg_src_aux, tail_step * jcp_.src_data_size); + add(reg_oc_off, tail_step * sizeof(float)); + sub(reg_work_amount, tail_step); } add(reg_index, jcp_.indices_size); sub(reg_work_amount_out, 1); @@ -519,11 +535,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - int step = vlen / sizeof(float); - int blk = (isa == cpu::x64::sse41) ? (2 * step) : step; - int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (step * jcp_.dst_data_size) : + int blk = (isa == cpu::x64::sse41) ? (2 * vector_step) : vector_step; + int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.dst_data_size) : (blk * jcp_.OW * jcp_.OH * jcp_.OD * jcp_.dst_data_size); - int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (step * jcp_.src_data_size) : + int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? 
(step * jcp_.src_data_size) :
+ int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.src_data_size) :
 (blk * jcp_.IW * jcp_.IH * jcp_.ID * jcp_.src_data_size);

 Xbyak::Label main_loop_label;
@@ -535,29 +550,29 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 L(main_loop_label);
 {
 if (jcp_.layout == InterpolateLayoutType::by_channel) {
- cmp(reg_work_amount, step);
+ cmp(reg_work_amount, vector_step);
 jl(main_loop_end_label, T_NEAR);
 } else {
 cmp(reg_work_amount, 1);
 jl(main_loop_end_label, T_NEAR);
 }
 // progressive manner
- load(reg_src, vmm_valTL, step);
- load(reg_src_aux, vmm_valTR, step);
+ load(reg_src, vmm_valTL, vector_step);
+ load(reg_src_aux, vmm_valTR, vector_step);
 if (jcp_.spatial_dim_size == 1) {
 linear_onnx_worker_1d();
 }
 if (jcp_.spatial_dim_size > 1) {
- load(reg_src_aux1, vmm_valBL, step);
- load(reg_src_aux2, vmm_valBR, step);
+ load(reg_src_aux1, vmm_valBL, vector_step);
+ load(reg_src_aux2, vmm_valBR, vector_step);
 linear_onnx_worker_2d();
 }
 if (jcp_.spatial_dim_size > 2) {
 uni_vmovups(vmm_d_bias, vmm_valTR); // temporarily save front result to temp_vmm
- load(reg_src_aux4, vmm_valTL, step);
- load(reg_src_aux5, vmm_valTR, step);
- load(reg_src_aux6, vmm_valBL, step);
- load(reg_src_aux7, vmm_valBR, step);
+ load(reg_src_aux4, vmm_valTL, vector_step);
+ load(reg_src_aux5, vmm_valTR, vector_step);
+ load(reg_src_aux6, vmm_valBL, vector_step);
+ load(reg_src_aux7, vmm_valBR, vector_step);

 // 2d for end depth
 linear_onnx_worker_2d();
@@ -568,28 +583,28 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 if (attr_.post_ops_.len() != 0) {
 apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
- add(reg_oc_off, step * sizeof(float));
+ add(reg_oc_off, vector_step * sizeof(float));
 }
- store(vmm_valTR, reg_dst, step);
+ store(vmm_valTR, reg_dst, vector_step);

 if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) {
- int offset_src = step * jcp_.src_data_size;
- load(reg_src, vmm_valTL, step, offset_src);
- load(reg_src_aux, vmm_valTR, step, offset_src);
+ int offset_src = vector_step * jcp_.src_data_size;
+ load(reg_src, vmm_valTL, vector_step, offset_src);
+ load(reg_src_aux, vmm_valTR, vector_step, offset_src);
 if (jcp_.spatial_dim_size == 1) {
 linear_onnx_worker_1d();
 }
 if (jcp_.spatial_dim_size > 1) {
- load(reg_src_aux1, vmm_valBL, step, offset_src);
- load(reg_src_aux2, vmm_valBR, step, offset_src);
+ load(reg_src_aux1, vmm_valBL, vector_step, offset_src);
+ load(reg_src_aux2, vmm_valBR, vector_step, offset_src);
 linear_onnx_worker_2d();
 }
 if (jcp_.spatial_dim_size > 2) {
 uni_vmovups(vmm_d_bias, vmm_valTR); // temporarily save front result to temp_vmm
- load(reg_src_aux4, vmm_valTL, step, offset_src);
- load(reg_src_aux5, vmm_valTR, step, offset_src);
- load(reg_src_aux6, vmm_valBL, step, offset_src);
- load(reg_src_aux7, vmm_valBR, step, offset_src);
+ load(reg_src_aux4, vmm_valTL, vector_step, offset_src);
+ load(reg_src_aux5, vmm_valTR, vector_step, offset_src);
+ load(reg_src_aux6, vmm_valBL, vector_step, offset_src);
+ load(reg_src_aux7, vmm_valBR, vector_step, offset_src);
 // 2d for end depth
 linear_onnx_worker_2d();
 // 3rd dimension
@@ -599,10 +614,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi

 if (attr_.post_ops_.len() != 0) {
 apply_post_ops(jcp_.dst_prc, false);
- add(reg_oc_off, step * sizeof(float));
+ add(reg_oc_off, vector_step * sizeof(float));
 }
- int offset_dst = step * jcp_.dst_data_size;
- store(vmm_valTR, reg_dst, step, offset_dst);
+ int offset_dst = vector_step * jcp_.dst_data_size;
+ store(vmm_valTR, reg_dst, vector_step, offset_dst);
 }
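On sse41 a blocked channel block of 8 floats does not fit in one xmm (vector_step is 4 there), which is why the branch above repeats the whole load, worker, and store sequence once more at a one-register byte offset. A standalone sketch of that offset arithmetic, with illustrative sizes:

// Why the second pass exists: two xmm-wide passes cover one 8-float block.
#include <iostream>

int main() {
    const int vlen = 16;                          // xmm width in bytes (sse41)
    const int vector_step = vlen / sizeof(float); // 4 floats per pass
    const int blk_size = 8;                       // floats per channel block
    const int src_data_size = sizeof(float);      // illustrative data size
    for (int pass = 0; pass * vector_step < blk_size; ++pass) {
        int offset_src = pass * vector_step * src_data_size;
        std::cout << "pass " << pass << " loads at byte offset " << offset_src << "\n";
    }
    // prints offsets 0 and 16: the two halves of one 8-float block
}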
 add(reg_dst, dst_stride);
 add(reg_src, src_stride);
@@ -618,7 +633,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 add(reg_src_aux7, src_stride);
 }
 if (jcp_.layout == InterpolateLayoutType::by_channel) {
- sub(reg_work_amount, step); // work_amount is c
+ sub(reg_work_amount, vector_step); // work_amount is c
 } else {
 sub(reg_work_amount, 1); // work_amount = div_up(c, blk), no tails
 }
@@ -627,25 +642,24 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 }
 L(main_loop_end_label);

- int tail_num = jcp_.C % step;
- if ((jcp_.layout == InterpolateLayoutType::by_channel) && (tail_num != 0)) {
- load(reg_src, vmm_valTL, tail_num);
- load(reg_src_aux, vmm_valTR, tail_num);
+ if ((jcp_.layout == InterpolateLayoutType::by_channel) && (tail_step != 0)) {
+ load(reg_src, vmm_valTL, tail_step);
+ load(reg_src_aux, vmm_valTR, tail_step);
 if (jcp_.spatial_dim_size == 1) {
 linear_onnx_worker_1d();
 }
 if (jcp_.spatial_dim_size > 1) {
- load(reg_src_aux1, vmm_valBL, tail_num);
- load(reg_src_aux2, vmm_valBR, tail_num);
+ load(reg_src_aux1, vmm_valBL, tail_step);
+ load(reg_src_aux2, vmm_valBR, tail_step);
 linear_onnx_worker_2d();
 }
 if (jcp_.spatial_dim_size > 2) {
 uni_vmovups(vmm_d_bias, vmm_valTR); // temporarily save front result to temp_vmm
- load(reg_src_aux4, vmm_valTL, tail_num);
- load(reg_src_aux5, vmm_valTR, tail_num);
- load(reg_src_aux6, vmm_valBL, tail_num);
- load(reg_src_aux7, vmm_valBR, tail_num);
+ load(reg_src_aux4, vmm_valTL, tail_step);
+ load(reg_src_aux5, vmm_valTR, tail_step);
+ load(reg_src_aux6, vmm_valBL, tail_step);
+ load(reg_src_aux7, vmm_valBR, tail_step);
 // 2d for end depth
 linear_onnx_worker_2d();
 // 3rd dimension
@@ -655,10 +669,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 if (attr_.post_ops_.len() != 0) {
 apply_post_ops(jcp_.dst_prc, false); // vmm_val is vmm_valTR
- add(reg_oc_off, tail_num * sizeof(float));
+ add(reg_oc_off, tail_step * sizeof(float));
 }
- store(vmm_valTR, reg_dst, tail_num);
+ store(vmm_valTR, reg_dst, tail_step);
 }
 }
@@ -669,7 +683,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 mov(reg_src_aux, ptr[reg_params + GET_OFF(weight_ptr[0])]);
 mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);

- int step = vlen / sizeof(float);
 int index_stride = jcp_.OW * jcp_.OH * jcp_.OD * jcp_.indices_size;
 int weight_stride = jcp_.OW * jcp_.OH * jcp_.OD * sizeof(float);
@@ -679,7 +692,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 Xbyak::Label tail_loop_end_label;
 L(main_loop_label);
 {
- cmp(reg_work_amount, step);
+ cmp(reg_work_amount, vector_step);
 jl(main_loop_end_label, T_NEAR);

 uni_vmovdqu(vmm_index, ptr[reg_index]);
@@ -690,8 +703,8 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
 vgatherdps(vmm_valTR, ptr[reg_src + vmm_index], vmm_mask);

- load_weights(reg_src_aux, vmm_weightL, step);
- load_weights(reg_src_aux, vmm_weightR, step, weight_stride);
+ load_weights(reg_src_aux, vmm_weightL, vector_step);
+ load_weights(reg_src_aux, vmm_weightR, vector_step, weight_stride);

 // progressive manner
 if (jcp_.spatial_dim_size == 1) {
@@ -706,8 +719,8 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask);
 vgatherdps(vmm_valBR, ptr[reg_src + vmm_index], vmm_mask);
- load_weights(reg_src_aux, vmm_weightT, step, 2 * weight_stride);
- load_weights(reg_src_aux, vmm_weightB, step, 3 * weight_stride);
+ load_weights(reg_src_aux, vmm_weightT, vector_step, 2 * weight_stride);
+ load_weights(reg_src_aux, vmm_weightB, vector_step, 3 * weight_stride);

 linear_onnx_worker_2d();
 }
@@ -733,8 +746,8 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi

 linear_onnx_worker_2d();

- load_weights(reg_src_aux, vmm_weightE, step, 5 * weight_stride);
- load_weights(reg_src_aux, vmm_weightF, step, 4 * weight_stride);
+ load_weights(reg_src_aux, vmm_weightE, vector_step, 5 * weight_stride);
+ load_weights(reg_src_aux, vmm_weightF, vector_step, 4 * weight_stride);

 uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight
 uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight
@@ -743,18 +756,17 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 if (attr_.post_ops_.len() != 0) {
 apply_post_ops(jcp_.dst_prc, true); // vmm_val is vmm_valTR, broadcast is true
 }
- store(vmm_valTR, reg_dst, step);
+ store(vmm_valTR, reg_dst, vector_step);

- add(reg_dst, step * jcp_.dst_data_size);
- add(reg_src_aux, step * sizeof(float));
- add(reg_index, step * jcp_.indices_size);
- sub(reg_work_amount, step);
+ add(reg_dst, vector_step * jcp_.dst_data_size);
+ add(reg_src_aux, vector_step * sizeof(float));
+ add(reg_index, vector_step * jcp_.indices_size);
+ sub(reg_work_amount, vector_step);

 jmp(main_loop_label, T_NEAR);
 }
 L(main_loop_end_label);

- step = 1;
 L(tail_loop_label);
 {
 cmp(reg_work_amount, 1);
@@ -763,15 +775,15 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 mov(reg_src_aux1, reg_src);
 mov(reg_index_offset, dword[reg_index]);
 add(reg_src_aux1, reg_index_offset);
- load(reg_src_aux1, vmm_valTL, step);
+ load(reg_src_aux1, vmm_valTL, scalar_step);

 mov(reg_src_aux1, reg_src);
 mov(reg_index_offset, dword[reg_index + index_stride]);
 add(reg_src_aux1, reg_index_offset);
- load(reg_src_aux1, vmm_valTR, step);
+ load(reg_src_aux1, vmm_valTR, scalar_step);

- load_weights(reg_src_aux, vmm_weightL, step, 0);
- load_weights(reg_src_aux, vmm_weightR, step, weight_stride);
+ load_weights(reg_src_aux, vmm_weightL, scalar_step, 0);
+ load_weights(reg_src_aux, vmm_weightR, scalar_step, weight_stride);

 if (jcp_.spatial_dim_size == 1) {
 linear_onnx_worker_1d();
@@ -780,15 +792,15 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 mov(reg_src_aux1, reg_src);
 mov(reg_index_offset, dword[reg_index + 2 * index_stride]);
 add(reg_src_aux1, reg_index_offset);
- load(reg_src_aux1, vmm_valBL, step);
+ load(reg_src_aux1, vmm_valBL, scalar_step);

 mov(reg_src_aux1, reg_src);
 mov(reg_index_offset, dword[reg_index + 3 * index_stride]);
 add(reg_src_aux1, reg_index_offset);
- load(reg_src_aux1, vmm_valBR, step);
+ load(reg_src_aux1, vmm_valBR, scalar_step);

- load_weights(reg_src_aux, vmm_weightT, step, 2 * weight_stride);
- load_weights(reg_src_aux, vmm_weightB, step, 3 * weight_stride);
+ load_weights(reg_src_aux, vmm_weightT, scalar_step, 2 * weight_stride);
+ load_weights(reg_src_aux, vmm_weightB, scalar_step, 3 * weight_stride);

 linear_onnx_worker_2d();
 }
@@ -799,27 +811,27 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
 mov(reg_src_aux1, reg_src);
 mov(reg_index_offset, dword[reg_index + 4 * index_stride]);
 add(reg_src_aux1, reg_index_offset);
- load(reg_src_aux1,
vmm_valTL, step); + load(reg_src_aux1, vmm_valTL, scalar_step); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 5 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valTR, step); + load(reg_src_aux1, vmm_valTR, scalar_step); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 6 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valBL, step); + load(reg_src_aux1, vmm_valBL, scalar_step); mov(reg_src_aux1, reg_src); mov(reg_index_offset, dword[reg_index + 7 * index_stride]); add(reg_src_aux1, reg_index_offset); - load(reg_src_aux1, vmm_valBR, step); + load(reg_src_aux1, vmm_valBR, scalar_step); linear_onnx_worker_2d(); - load_weights(reg_src_aux, vmm_weightE, step, 5 * weight_stride); - load_weights(reg_src_aux, vmm_weightF, step, 4 * weight_stride); + load_weights(reg_src_aux, vmm_weightE, scalar_step, 5 * weight_stride); + load_weights(reg_src_aux, vmm_weightF, scalar_step, 4 * weight_stride); uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight @@ -828,12 +840,12 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, true); // process on vmm_val, vmm_val is vmm_valTR, and bc } - store(vmm_valTR, reg_dst, step); + store(vmm_valTR, reg_dst, scalar_step); - add(reg_dst, step * jcp_.dst_data_size); - add(reg_src_aux, step * sizeof(float)); - add(reg_index, step * jcp_.indices_size); - sub(reg_work_amount, step); + add(reg_dst, scalar_step * jcp_.dst_data_size); + add(reg_src_aux, scalar_step * sizeof(float)); + add(reg_index, scalar_step * jcp_.indices_size); + sub(reg_work_amount, scalar_step); jmp(tail_loop_label, T_NEAR); } @@ -876,8 +888,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi uni_vbroadcastss(vmm_weightY2, ptr[reg_src_aux1 + 2 * sizeof(float)]); uni_vbroadcastss(vmm_weightY3, ptr[reg_src_aux1 + 3 * sizeof(float)]); - int step = vlen / sizeof(float); - int blk = (isa == cpu::x64::sse41) ? (2 * step) : step; + int blk = (isa == cpu::x64::sse41) ? 
(2 * vector_step) : vector_step; Xbyak::Label main_loop_label; Xbyak::Label main_loop_end_label; @@ -886,7 +897,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi L(main_loop_label); { if (jcp_.layout == InterpolateLayoutType::by_channel) { - cmp(reg_work_amount, step); + cmp(reg_work_amount, vector_step); jl(main_loop_end_label, T_NEAR); } else { cmp(reg_work_amount, 1); @@ -899,14 +910,14 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value to post_ops and store - add(reg_oc_off, step * sizeof(float)); + add(reg_oc_off, vector_step * sizeof(float)); } - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, vector_step); if ((isa == cpu::x64::sse41) && (jcp_.layout == InterpolateLayoutType::block)) { // vmm is xmm here - add(reg_src, step * jcp_.src_data_size); - add(reg_dst, step * jcp_.dst_data_size); + add(reg_src, vector_step * jcp_.src_data_size); + add(reg_dst, vector_step * jcp_.dst_data_size); uni_vpxor(vmm_val, vmm_val, vmm_val); @@ -914,19 +925,19 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, false); - add(reg_oc_off, step * sizeof(float)); // second step for one blk + add(reg_oc_off, vector_step * sizeof(float)); // second vector_step for one blk } - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, vector_step); - sub(reg_src, step * jcp_.src_data_size); - sub(reg_dst, step * jcp_.dst_data_size); + sub(reg_src, vector_step * jcp_.src_data_size); + sub(reg_dst, vector_step * jcp_.dst_data_size); } if (jcp_.layout == InterpolateLayoutType::by_channel) { - int dst_stride = step * jcp_.dst_data_size; - int src_stride = step * jcp_.src_data_size; + int dst_stride = vector_step * jcp_.dst_data_size; + int src_stride = vector_step * jcp_.src_data_size; add(reg_dst, dst_stride); add(reg_src, src_stride); - sub(reg_work_amount, step); // work_amount is c + sub(reg_work_amount, vector_step); // work_amount is c } else { int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size; int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size; @@ -940,7 +951,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi L(main_loop_end_label); // only for by_channel layout for tails. 
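The step renames throughout these kernels all encode the same decomposition: consume vector_step elements per main-loop iteration, then finish the remainder with scalar_step (= 1) iterations, with tail_step covering the by_channel tails. A standalone scalar model of that decomposition, with illustrative sizes:

// Vector-then-scalar loop decomposition, mirroring the renamed constants.
#include <iostream>

int main() {
    const int vlen = 32;                          // ymm width in bytes (avx2)
    const int vector_step = vlen / sizeof(float); // 8 elements per iteration
    const int scalar_step = 1;
    int work_amount = 21;                         // illustrative element count
    int vector_iters = 0, scalar_iters = 0;
    while (work_amount >= vector_step) { work_amount -= vector_step; ++vector_iters; }
    while (work_amount >= scalar_step) { work_amount -= scalar_step; ++scalar_iters; }
    std::cout << vector_iters << " vector iterations, "
              << scalar_iters << " scalar iterations\n";  // 2 and 5
}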
- step = 1; L(tail_loop_label); { cmp(reg_work_amount, 1); @@ -953,15 +963,15 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value - add(reg_oc_off, step * sizeof(float)); + add(reg_oc_off, scalar_step * sizeof(float)); } - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, scalar_step); - int dst_stride = step * jcp_.dst_data_size; - int src_stride = step * jcp_.src_data_size; + int dst_stride = scalar_step * jcp_.dst_data_size; + int src_stride = scalar_step * jcp_.src_data_size; add(reg_dst, dst_stride); add(reg_src, src_stride); - sub(reg_work_amount, step); // work_amount is c + sub(reg_work_amount, scalar_step); // work_amount is c jmp(tail_loop_label, T_NEAR); } @@ -1020,7 +1030,6 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi mov(reg_weight_y, ptr[reg_params + GET_OFF(weight_ptr[0]) + sizeof(size_t)]); mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - int step = vlen / sizeof(float); int grid_len = 4; // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 @@ -1035,7 +1044,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi Xbyak::Label tail_loop_end_label; L(main_loop_label); { - cmp(reg_work_amount, step); + cmp(reg_work_amount, vector_step); jl(main_loop_end_label, T_NEAR); // vmm_tbl_y: (0 0 0 0 1 1 1 1 * index_size) --> (0 0 0 0 4 4 4 4) @@ -1111,19 +1120,18 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel } - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, vector_step); - add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence by dd() - add(reg_tbl_x, step * sizeof(int)); - add(reg_dst, step * jcp_.dst_data_size); + add(reg_tbl_y, vector_step * sizeof(int)); // sizeof(int): sequence by dd() + add(reg_tbl_x, vector_step * sizeof(int)); + add(reg_dst, vector_step * jcp_.dst_data_size); - sub(reg_work_amount, step); + sub(reg_work_amount, vector_step); jmp(main_loop_label, T_NEAR); } L(main_loop_end_label); - step = 1; L(tail_loop_label); { cmp(reg_work_amount, 1); @@ -1182,13 +1190,13 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi if (attr_.post_ops_.len() != 0) { apply_post_ops(jcp_.dst_prc, true); // oc_off is broadcast and always the same value for this channel } - store(vmm_val, reg_dst, step); + store(vmm_val, reg_dst, scalar_step); - add(reg_tbl_y, step * sizeof(int)); // sizeof(int): sequence with dd() - add(reg_tbl_x, step * sizeof(int)); - add(reg_dst, step * jcp_.dst_data_size); + add(reg_tbl_y, scalar_step * sizeof(int)); // sizeof(int): sequence with dd() + add(reg_tbl_x, scalar_step * sizeof(int)); + add(reg_dst, scalar_step * jcp_.dst_data_size); - sub(reg_work_amount, step); + sub(reg_work_amount, scalar_step); jmp(tail_loop_label, T_NEAR); } @@ -1264,7 +1272,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi return ptr[reg_table + index * vlen]; } - // always gather to Vmm, compute with Vmm, store with Xmm if scalar + // always gather to Vmm, compute with Vmm, store with Xmm if scalar_step inline void gather_i32_indices(Vmm vmm_src, const Xbyak::Reg64 &base, int offset, Vmm vmm_indices, int scale, Precision src_prc, bool is_scalar) { Xbyak::Address table_idx = ptr[base + offset + vmm_indices 
* scale]; diff --git a/src/plugins/intel_cpu/src/nodes/mvn.cpp b/src/plugins/intel_cpu/src/nodes/mvn.cpp index 506dbea8547733..70d2008cb9137c 100644 --- a/src/plugins/intel_cpu/src/nodes/mvn.cpp +++ b/src/plugins/intel_cpu/src/nodes/mvn.cpp @@ -110,7 +110,13 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); + tail_step = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / vector_step) * vector_step : + jcp_.C - (jcp_.C / vector_step) * vector_step; + + Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32; + load_vector_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, vector_step)); + load_tail_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, tail_step)); + load_tail_with_fill_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, dst_prc, tail_step, Precision::FP32, true)); this->preamble(); mov(reg_src, ptr[reg_params + GET_OFF(src)]); @@ -134,14 +140,11 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } } - tail_num = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step : - jcp_.C - (jcp_.C / step) * step; - load_pool_gpr_idxs = {static_cast(reg_load_store_mask.getIdx()), static_cast(reg_load_table.getIdx())}; if (jcp_.planar_layout) { worker_unroll(); - if (tail_num != 0) { + if (tail_step != 0) { worker_tail_planar(); } @@ -198,7 +201,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } Xbyak::Label label_empty_2half_sse42; - if (tail_num == 0) { + if (tail_step == 0) { cmp(reg_oc_off, static_cast(jcp_.C * sizeof(float))); jae(label_empty_2half_sse42, T_NEAR); @@ -210,7 +213,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k Xbyak::Label label_full_size; Xbyak::Label label_size_end; - cmp(reg_oc_off, static_cast((jcp_.C - step) * sizeof(float))); + cmp(reg_oc_off, static_cast((jcp_.C - vector_step) * sizeof(float))); jle(label_full_size, T_NEAR); // no need care and fill rest @@ -251,7 +254,9 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k this->postamble(); - load_emitter->emit_data(); + load_vector_emitter->emit_data(); + load_tail_emitter->emit_data(); + load_tail_with_fill_emitter->emit_data(); } private: @@ -259,8 +264,8 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k Xbyak::Ymm, Xbyak::Zmm>::type; const int vlen = cpu_isa_traits::vlen; - const int step = vlen / sizeof(float); - int tail_num = 0; + const int vector_step = vlen / sizeof(float); + int tail_step = 0; Xbyak::Reg64 reg_src = r8; Xbyak::Reg64 reg_mean = r9; @@ -286,15 +291,15 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k Xbyak::Opmask k_mask = Xbyak::Opmask(7); - std::unique_ptr load_emitter = nullptr; + std::unique_ptr load_vector_emitter = nullptr; + std::unique_ptr load_tail_emitter = nullptr; + std::unique_ptr load_tail_with_fill_emitter = nullptr; std::vector load_pool_gpr_idxs; inline void worker_full_size() { - Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? 
Precision::FP32 : Precision::I32; - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, dst_prc, step), - {}, {load_pool_gpr_idxs}); + load_vector_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, + {}, {load_pool_gpr_idxs}); if (jcp_.normalize_variance) { // all with float @@ -313,9 +318,7 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } inline void worker_tail_blk() { - Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32; - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, dst_prc, tail_num), + load_tail_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, {}, {load_pool_gpr_idxs}); if (jcp_.normalize_variance) { @@ -357,10 +360,8 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k } inline void worker_tail_planar() { - Precision dst_prc = isFloatCompatible(jcp_.src_prc) ? Precision::FP32 : Precision::I32; - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, dst_prc, tail_num, 0, true), - {}, {load_pool_gpr_idxs}); + load_tail_with_fill_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, + {}, {load_pool_gpr_idxs}); if (jcp_.normalize_variance) { if (!isFloatCompatible(jcp_.src_prc)) @@ -371,15 +372,15 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k uni_vpxor(vmm_zero, vmm_zero, vmm_zero); if (isa == cpu::x64::sse41) { uint8 imm = 1; - imm = ~((imm << tail_num) - imm); + imm = ~((imm << tail_step) - imm); blendps(vmm_val, vmm_zero, imm); } else if (isa == cpu::x64::avx2) { uint8 imm = 1; - imm = ~((imm << tail_num) - imm); + imm = ~((imm << tail_step) - imm); vblendps(vmm_val, vmm_val, vmm_zero, imm); } else if (isa == cpu::x64::avx512_core) { uint64_t tail_mask = 1; - tail_mask = ~((tail_mask << tail_num) - tail_mask); + tail_mask = ~((tail_mask << tail_step) - tail_mask); mov(reg_aux, tail_mask); kmovq(k_mask, reg_aux); vblendmps(vmm_val | k_mask, vmm_val, vmm_zero); @@ -435,8 +436,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator } } - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + tail_step = jcp_.planar_layout ? (jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / vector_step) * vector_step : + jcp_.C - (jcp_.C / vector_step) * vector_step; + + load_vector_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, Precision::FP32, vector_step)); + load_tail_emitter.reset(new jit_load_emitter(this, isa, jcp_.src_prc, Precision::FP32, tail_step)); + store_vector_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, vector_step)); + store_tail_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jcp_.dst_prc, tail_step)); this->preamble(); @@ -463,16 +469,13 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator uni_vpxor(vmm_zero, vmm_zero, vmm_zero); - tail_num = jcp_.planar_layout ? 
(jcp_.D * jcp_.H * jcp_.W) - ((jcp_.D * jcp_.H * jcp_.W) / step) * step : - jcp_.C - (jcp_.C / step) * step; - load_pool_gpr_idxs = {static_cast(reg_load_store_mask.getIdx()), static_cast(reg_load_table.getIdx())}; store_pool_gpr_idxs = {static_cast(reg_load_store_mask.getIdx())}; - store_pool_vec_idxs = {static_cast(vmm_zero.getIdx())}; + store_pool_vec_idxs = {static_cast(vmm_zero.getIdx()), static_cast(vmm_val.getIdx())}; if (jcp_.planar_layout) { worker_mvn_unroll(); - if (tail_num != 0) { + if (tail_step != 0) { worker_mvn(true); } } else { @@ -501,7 +504,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator } Xbyak::Label label_empty_2half_sse42; - if (tail_num == 0) { + if (tail_step == 0) { cmp(reg_oc_off, static_cast(jcp_.C * sizeof(float))); jae(label_empty_2half_sse42, T_NEAR); worker_mvn_unroll(); @@ -512,7 +515,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator Xbyak::Label label_full_size_block; Xbyak::Label label_size_end; - cmp(reg_oc_off, static_cast((jcp_.C - step) * sizeof(float))); + cmp(reg_oc_off, static_cast((jcp_.C - vector_step) * sizeof(float))); jle(label_full_size_block, T_NEAR); worker_mvn_unroll(true); @@ -530,8 +533,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + load_vector_emitter->emit_data(); + load_tail_emitter->emit_data(); + store_vector_emitter->emit_data(); + store_tail_emitter->emit_data(); for (auto& inj : eltwise_injectors) inj->prepare_table(); @@ -542,8 +547,8 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator Xbyak::Ymm, Xbyak::Zmm>::type; const int vlen = cpu_isa_traits::vlen; - const int step = vlen / sizeof(float); - int tail_num = 0; + const int vector_step = vlen / sizeof(float); + int tail_step = 0; Xbyak::Reg64 reg_src = r8; Xbyak::Reg64 reg_mean = r9; @@ -570,8 +575,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator Vmm vmm_d_weights = Vmm(5); Vmm vmm_d_bias = Vmm(6); - std::unique_ptr load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unique_ptr load_vector_emitter = nullptr; + std::unique_ptr load_tail_emitter = nullptr; + std::unique_ptr store_vector_emitter = nullptr; + std::unique_ptr store_tail_emitter = nullptr; std::vector>> eltwise_injectors; std::vector>> depthwise_injectors; @@ -582,9 +589,10 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator std::vector load_pool_gpr_idxs; inline void worker_mvn(bool is_tail) { - int elt_num = is_tail ? tail_num : step; + const auto& load_emitter = is_tail ? load_tail_emitter : load_vector_emitter; + const auto& store_emitter = is_tail ? 
store_tail_emitter : store_vector_emitter; + load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_val.getIdx())}, - std::make_shared(jcp_.src_prc, Precision::FP32, elt_num), {}, {load_pool_gpr_idxs}); uni_vsubps(vmm_val, vmm_val, vmm_mean); @@ -594,7 +602,6 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator apply_post_ops(jcp_.dst_prc, jcp_.planar_layout); store_emitter->emit_code({static_cast(vmm_val.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.dst_prc, elt_num), {store_pool_vec_idxs}, {store_pool_gpr_idxs}); } diff --git a/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp b/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp index 632fa5572a6d02..f11da0b2fdb79a 100644 --- a/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp +++ b/src/plugins/intel_cpu/src/nodes/non_max_suppression.cpp @@ -44,8 +44,9 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + load_vector_emitter.reset(new jit_load_emitter(this, isa, Precision::FP32, Precision::FP32, vector_step)); + load_scalar_emitter.reset(new jit_load_emitter(this, isa, Precision::FP32, Precision::FP32, scalar_step)); + exp_injector.reset(new jit_uni_eltwise_injector_f32(this, dnnl::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f)); this->preamble(); @@ -137,8 +138,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + load_vector_emitter->emit_data(); + load_scalar_emitter->emit_data(); prepare_table(); exp_injector->prepare_table(); @@ -147,6 +148,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator private: using Vmm = typename conditional3::type; uint32_t vlen = cpu_isa_traits::vlen; + const int vector_step = vlen / sizeof(float); + const int scalar_step = 1; Xbyak::Reg64 reg_boxes_coord0 = r8; Xbyak::Reg64 reg_boxes_coord1 = r9; @@ -172,8 +175,8 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator Xbyak::Reg64 reg_params = abi_param1; - std::unique_ptr load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unique_ptr load_vector_emitter = nullptr; + std::unique_ptr load_scalar_emitter = nullptr; std::vector store_pool_gpr_idxs; std::vector store_pool_vec_idxs; @@ -205,25 +208,24 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator std::shared_ptr> exp_injector; inline void hard_nms() { - int step = vlen / sizeof(float); Xbyak::Label main_loop_label_hard; Xbyak::Label main_loop_end_label_hard; Xbyak::Label tail_loop_label_hard; Xbyak::Label terminate_label_hard; L(main_loop_label_hard); { - cmp(reg_boxes_num, step); + cmp(reg_boxes_num, vector_step); jl(main_loop_end_label_hard, T_NEAR); - sub(reg_boxes_coord0, step * sizeof(float)); - sub(reg_boxes_coord1, step * sizeof(float)); - sub(reg_boxes_coord2, step * sizeof(float)); - sub(reg_boxes_coord3, step * sizeof(float)); + sub(reg_boxes_coord0, vector_step * sizeof(float)); + sub(reg_boxes_coord1, vector_step * sizeof(float)); + sub(reg_boxes_coord2, vector_step * sizeof(float)); + sub(reg_boxes_coord3, vector_step * sizeof(float)); // iou result is in vmm_temp3 - iou(step); + iou(vector_step); - sub(reg_boxes_num, step); + sub(reg_boxes_num, vector_step); suppressed_by_iou(false); @@ -236,21 
+238,20 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
 }
 L(main_loop_end_label_hard);

- step = 1;
 L(tail_loop_label_hard);
 {
 cmp(reg_boxes_num, 1);
 jl(terminate_label_hard, T_NEAR);

- sub(reg_boxes_coord0, step * sizeof(float));
- sub(reg_boxes_coord1, step * sizeof(float));
- sub(reg_boxes_coord2, step * sizeof(float));
- sub(reg_boxes_coord3, step * sizeof(float));
+ sub(reg_boxes_coord0, scalar_step * sizeof(float));
+ sub(reg_boxes_coord1, scalar_step * sizeof(float));
+ sub(reg_boxes_coord2, scalar_step * sizeof(float));
+ sub(reg_boxes_coord3, scalar_step * sizeof(float));

 // iou result is in vmm_temp3
- iou(step);
+ iou(scalar_step);

- sub(reg_boxes_num, step);
+ sub(reg_boxes_num, scalar_step);

 suppressed_by_iou(true);
@@ -267,7 +268,6 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
 inline void soft_nms() {
 uni_vbroadcastss(vmm_scale, ptr[reg_scale]);

- int step = vlen / sizeof(float);
 Xbyak::Label main_loop_label;
 Xbyak::Label main_loop_end_label;
 Xbyak::Label tail_loop_label;
@@ -277,17 +277,17 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
 Xbyak::Label tail_loop_label_soft;
 L(main_loop_label);
 {
- cmp(reg_boxes_num, step);
+ cmp(reg_boxes_num, vector_step);
 jl(main_loop_end_label, T_NEAR);

- sub(reg_boxes_coord0, step * sizeof(float));
- sub(reg_boxes_coord1, step * sizeof(float));
- sub(reg_boxes_coord2, step * sizeof(float));
- sub(reg_boxes_coord3, step * sizeof(float));
+ sub(reg_boxes_coord0, vector_step * sizeof(float));
+ sub(reg_boxes_coord1, vector_step * sizeof(float));
+ sub(reg_boxes_coord2, vector_step * sizeof(float));
+ sub(reg_boxes_coord3, vector_step * sizeof(float));

 // result(iou and weight) is in vmm_temp3
- iou(step);
- sub(reg_boxes_num, step);
+ iou(vector_step);
+ sub(reg_boxes_num, vector_step);

 // soft suppressed by iou_threshold
 if (jcp.is_soft_suppressed_by_iou) {
@@ -327,19 +327,18 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
 }
 L(main_loop_end_label);

- step = 1;
 L(tail_loop_label);
 {
 cmp(reg_boxes_num, 1);
 jl(terminate_label, T_NEAR);

- sub(reg_boxes_coord0, step * sizeof(float));
- sub(reg_boxes_coord1, step * sizeof(float));
- sub(reg_boxes_coord2, step * sizeof(float));
- sub(reg_boxes_coord3, step * sizeof(float));
+ sub(reg_boxes_coord0, scalar_step * sizeof(float));
+ sub(reg_boxes_coord1, scalar_step * sizeof(float));
+ sub(reg_boxes_coord2, scalar_step * sizeof(float));
+ sub(reg_boxes_coord3, scalar_step * sizeof(float));

- iou(step);
- sub(reg_boxes_num, step);
+ iou(scalar_step);
+ sub(reg_boxes_num, scalar_step);

 // soft suppressed by iou_threshold
 if (jcp.is_soft_suppressed_by_iou) {
@@ -427,8 +426,11 @@ struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator
 inline void iou(int ele_num) {
 auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_dst) {
+ if (ele_num != scalar_step && ele_num != vector_step)
+ IE_THROW() << "NMS JIT implementation supports the load emitter only with an element count of scalar_step or vector_step! Got: " << ele_num;
+
+ const auto& load_emitter = ele_num == 1 ? load_scalar_emitter : load_vector_emitter;
 load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_dst.getIdx())},
- std::make_shared(Precision::FP32, Precision::FP32, ele_num),
 {}, {load_pool_gpr_idxs});
 };
 load(reg_boxes_coord0, vmm_boxes_coord0);
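Because the load emitter is now specialized on its element count at construction time, iou() has to pick between the two pre-built instances instead of passing a count per call, rejecting any other count up front. A standalone sketch of that dispatch, using hypothetical stand-in names:

// Dispatch between pre-built scalar and vector emitters by element count.
#include <iostream>
#include <stdexcept>

struct fake_emitter { int elt_num; };  // hypothetical stand-in for jit_load_emitter

int main() {
    const int scalar_step = 1, vector_step = 8;   // avx2-style width, illustrative
    fake_emitter scalar_emitter{scalar_step};
    fake_emitter vector_emitter{vector_step};
    auto pick = [&](int ele_num) -> fake_emitter& {
        if (ele_num != scalar_step && ele_num != vector_step)
            throw std::runtime_error("unsupported element count");  // real code uses IE_THROW
        return ele_num == scalar_step ? scalar_emitter : vector_emitter;
    };
    std::cout << pick(8).elt_num << "\n";  // vector path
    std::cout << pick(1).elt_num << "\n";  // scalar path
}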
diff --git a/src/plugins/intel_cpu/src/nodes/roi_align.cpp b/src/plugins/intel_cpu/src/nodes/roi_align.cpp
index 4f2bdc6fc694a2..b10abcde8dd287 100644
--- a/src/plugins/intel_cpu/src/nodes/roi_align.cpp
+++ b/src/plugins/intel_cpu/src/nodes/roi_align.cpp
@@ -46,9 +46,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji
 };

 void generate() override {
- load_emitter.reset(new jit_load_emitter(this, isa));
- store_emitter.reset(new jit_store_emitter(this, isa));
-
 this->preamble();

 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
@@ -65,8 +62,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji

 this->postamble();

- load_emitter->emit_data();
- store_emitter->emit_data();
+ emit_emitters_data();
 }

 private:
@@ -107,10 +103,9 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji
 // [1] for reg_dst
 Xmm xmm_args_pool = Xmm(15);

- std::unique_ptr load_emitter = nullptr;
- std::vector load_pool_gpr_idxs;
+ std::unordered_map> emitters;

- std::unique_ptr store_emitter = nullptr;
+ std::vector load_pool_gpr_idxs;
 std::vector store_pool_gpr_idxs;
 std::vector store_pool_vec_idxs;
@@ -157,6 +152,57 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji

 reg64_t reg_params = abi_param1;

+ void emit_emitters_data() {
+ for (const auto& emitter : emitters) {
+ emitter.second->emit_data();
+ }
+ }
+
+ inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
+ emit_load(reg_src, vmm_src, jcp_.data_prc, Precision::FP32, elt_num, offset);
+ }
+
+ inline void load_buffer(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
+ emit_load(reg_src, vmm_src, Precision::FP32, Precision::FP32, elt_num, offset);
+ }
+
+ inline void load_idx(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
+ emit_load(reg_src, vmm_src, Precision::I32, Precision::I32, elt_num, offset);
+ }
+
+ inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
+ emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.data_prc, elt_num, offset);
+ }
+
+ inline void store_buffer(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
+ emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::FP32, elt_num, offset);
+ }
+
+ inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
+ const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
+ if (!emitters[seed]) {
+ emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
+ }
+
+ emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)},
+ {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
+ }
+
+ inline void emit_store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
+ const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash();
+ if (!emitters[seed]) {
+ emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num));
+ }
+
+ // for cases when the Store emitter needs 2 aux vmms we can use vmm_dst as the second aux vmm
+ std::vector
local_store_pool_vec_idxs = { static_cast(vmm_dst.getIdx()) }; + local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end()); + + emitters[seed]->emit_code({static_cast(vmm_dst.getIdx()), static_cast(offset)}, + {static_cast(reg_dst.getIdx())}, + {local_store_pool_vec_idxs}, {store_pool_gpr_idxs}); + } + void roi_align_cgather() { mov(reg_src_address, ptr[reg_params + GET_OFF(src)]); mov(reg_weights, ptr[reg_params + GET_OFF(weights)]); @@ -180,23 +226,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji imul(reg_src_stride, reg_src_stride, jcp_.data_size); } - auto store = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) { - store_emitter->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, jcp_.data_prc, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); - }; - - auto load_buf = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_src.getIdx())}, - std::make_shared(Precision::FP32, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); - }; - auto store_buf = [&](Vmm vmm_dst, Xbyak::Reg64 reg_dst, int elt_num) { - store_emitter->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_dst.getIdx())}, - std::make_shared(Precision::FP32, Precision::FP32, elt_num), - {store_pool_vec_idxs}, {store_pool_gpr_idxs}); - }; - // out loop for samples in bin Xbyak::Label out_loop_label; Xbyak::Label out_loop_end_label; @@ -228,13 +257,13 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji generate_samples(v_step); // now this sample value across channel reside in vmm_sample // compute with other samples in vmm_buf - load_buf(reg_buf, vmm_buf, v_step); + load_buffer(reg_buf, vmm_buf, v_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vaddps(vmm_buf, vmm_buf, vmm_sample); } else { uni_vmaxps(vmm_buf, vmm_buf, vmm_sample); } - store_buf(vmm_buf, reg_buf, v_step); + store_buffer(vmm_buf, reg_buf, v_step); if ((isa == cpu::x64::sse41) && (jcp_.layout == ROIAlignLayoutType::blk)) { add(reg_src0, x_step * jcp_.data_size); @@ -244,13 +273,13 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji add(reg_buf, x_step * sizeof(float)); generate_samples(x_step); - load_buf(reg_buf, vmm_buf, x_step); + load_buffer(reg_buf, vmm_buf, x_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vaddps(vmm_buf, vmm_buf, vmm_sample); } else { uni_vmaxps(vmm_buf, vmm_buf, vmm_sample); } - store_buf(vmm_buf, reg_buf, x_step); + store_buffer(vmm_buf, reg_buf, x_step); sub(reg_src0, x_step * jcp_.data_size); sub(reg_src1, x_step * jcp_.data_size); @@ -280,13 +309,13 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji jl(in_loop_tail_end_label, T_NEAR); generate_samples(tail_step); - load_buf(reg_buf, vmm_buf, tail_step); + load_buffer(reg_buf, vmm_buf, tail_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vaddps(vmm_buf, vmm_buf, vmm_sample); } else { uni_vmaxps(vmm_buf, vmm_buf, vmm_sample); } - store_buf(vmm_buf, reg_buf, tail_step); + store_buffer(vmm_buf, reg_buf, tail_step); int tail_src_stride = tail_step * jcp_.data_size; add(reg_src0, tail_src_stride); @@ -333,7 +362,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji cmp(reg_work_amount, v_step); jl(store_loop_main_end_label, T_NEAR); - load_buf(reg_buf, vmm_buf, v_step); + 
load_buffer(reg_buf, vmm_buf, v_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vmulps(vmm_buf, vmm_buf, vmm_scale); } @@ -343,7 +372,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji add(reg_buf, x_step * sizeof(float)); add(reg_dst, x_step * jcp_.data_size); - load_buf(reg_buf, vmm_buf, x_step); + load_buffer(reg_buf, vmm_buf, x_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vmulps(vmm_buf, vmm_buf, vmm_scale); } @@ -369,7 +398,7 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji cmp(reg_work_amount, tail_step); jl(store_loop_tail_end_label, T_NEAR); - load_buf(reg_buf, vmm_buf, tail_step); + load_buffer(reg_buf, vmm_buf, tail_step); if (jcp_.alg == Algorithm::ROIAlignAvg) { uni_vmulps(vmm_buf, vmm_buf, vmm_scale); } @@ -402,12 +431,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji } void generate_samples(int num) { - auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_src, int elt_num) { - load_emitter->emit_code({static_cast(reg_src.getIdx())}, {static_cast(vmm_src.getIdx())}, - std::make_shared(jcp_.data_prc, Precision::FP32, elt_num), - {}, {load_pool_gpr_idxs}); - }; - uni_vpxor(vmm_sample, vmm_sample, vmm_sample); load(reg_src0, vmm_src0, num); uni_vfmadd231ps(vmm_sample, vmm_src0, vmm_weights0); @@ -432,12 +455,6 @@ struct jit_uni_roi_align_kernel_f32 : public jit_uni_roi_align_kernel, public ji uni_vbroadcastss(vmm_scale, ptr[reg_tmp_64]); } - auto load_idx = [&](Xbyak::Reg64 reg_idx, Vmm vmm_idx, int elt_num) { - load_emitter->emit_code({static_cast(reg_idx.getIdx())}, {static_cast(vmm_idx.getIdx())}, - std::make_shared(Precision::I32, Precision::I32, elt_num), - {}, {load_pool_gpr_idxs}); - }; - Xbyak::Label main_loop_label; Xbyak::Label main_loop_end_label; Xbyak::Label tail_loop_label; diff --git a/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp b/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp index a4899c55949143..a93bfc5e8cacf2 100644 --- a/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp +++ b/src/plugins/intel_cpu/src/nodes/roi_pooling.cpp @@ -48,8 +48,9 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi }; void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); + load_emitter.reset(new jit_load_emitter(this, isa, jpp_.src_prc, Precision::FP32, step)); + store_emitter.reset(new jit_store_emitter(this, isa, Precision::FP32, jpp_.dst_prc, step)); + store_empty_roi_emitter.reset(new jit_store_emitter(this, isa, jpp_.src_prc, jpp_.dst_prc, step)); this->preamble(); @@ -93,6 +94,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi load_emitter->emit_data(); store_emitter->emit_data(); + store_empty_roi_emitter->emit_data(); } private: @@ -114,6 +116,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi std::vector load_pool_gpr_idxs; std::unique_ptr store_emitter = nullptr; + std::unique_ptr store_empty_roi_emitter = nullptr; std::vector store_pool_gpr_idxs; std::vector store_pool_vec_idxs; @@ -147,6 +150,12 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi Xbyak::Reg64 reg_load_table = r15; Xbyak::Reg64 reg_load_store_mask = abi_param1; + std::vector get_local_store_pool_vec_idxs(Vmm vmm) const { + std::vector local_store_pool_vec_idxs = { static_cast(vmm.getIdx()) }; + local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), 
store_pool_vec_idxs.begin(), store_pool_vec_idxs.end()); + return local_store_pool_vec_idxs; + } + void roi_pool_max(int c_blocks) { Label h_loop_label; Label w_loop_label; @@ -157,8 +166,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi for (int i = 0; i < c_blocks; i++) { Vmm vmm_max = get_acc_reg(i); - load_emitter->emit_code({static_cast(reg_input.getIdx())}, {static_cast(vmm_max.getIdx())}, - std::make_shared(jpp_.src_prc, Precision::FP32, step, i * src_c_off), + load_emitter->emit_code({static_cast(reg_input.getIdx()), static_cast(i * src_c_off)}, {static_cast(vmm_max.getIdx())}, {}, load_pool_gpr_idxs); } @@ -171,9 +179,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi Vmm vmm_max = get_acc_reg(i); Vmm vmm_src = get_src_reg(i); - load_emitter->emit_code({static_cast(aux_reg_input1.getIdx())}, {static_cast(vmm_src.getIdx())}, - std::make_shared(jpp_.src_prc, Precision::FP32, step, i * src_c_off), - {}, load_pool_gpr_idxs); + load_emitter->emit_code({static_cast(aux_reg_input1.getIdx()), static_cast(i * src_c_off)}, + {static_cast(vmm_src.getIdx())}, {}, load_pool_gpr_idxs); if (isa == cpu::x64::sse41) { movups(vmm_mask, vmm_max); @@ -206,9 +213,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi for (int i = 0; i < c_blocks; i++) { Vmm vmm_dst = get_acc_reg(i); - store_emitter->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_output.getIdx())}, - std::make_shared(Precision::FP32, jpp_.dst_prc, step, i * dst_c_off), - store_pool_vec_idxs, store_pool_gpr_idxs); + store_emitter->emit_code({static_cast(vmm_dst.getIdx()), static_cast(i * dst_c_off)}, {static_cast(reg_output.getIdx())}, + get_local_store_pool_vec_idxs(vmm_dst), store_pool_gpr_idxs); } } @@ -225,27 +231,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi for (int i = 0; i < c_blocks; i++) { const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_prc.size(); - const auto load_context = std::make_shared(jpp_.src_prc, Precision::FP32, step, src_c_off); mov(aux_reg_input, reg_input); - load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src00.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src00.getIdx())}, {}, load_pool_gpr_idxs); add(aux_reg_input, reg_xoff); - load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src01.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src01.getIdx())}, {}, load_pool_gpr_idxs); add(aux_reg_input, reg_yoff); - load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src11.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src11.getIdx())}, {}, load_pool_gpr_idxs); sub(aux_reg_input, reg_xoff); - load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src10.getIdx())}, - load_context, + load_emitter->emit_code({static_cast(aux_reg_input.getIdx()), static_cast(src_c_off)}, {static_cast(vmm_src10.getIdx())}, {}, load_pool_gpr_idxs); uni_vsubps(vmm_src01, vmm_src01, vmm_src00); @@ -259,9 +260,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size(); - 
store_emitter->emit_code({static_cast(vmm_src11.getIdx())}, {static_cast(reg_output.getIdx())}, - std::make_shared(Precision::FP32, jpp_.dst_prc, step, dst_c_off), - store_pool_vec_idxs, store_pool_gpr_idxs); + store_emitter->emit_code({static_cast(vmm_src11.getIdx()), static_cast(dst_c_off)}, {static_cast(reg_output.getIdx())}, + get_local_store_pool_vec_idxs(vmm_src11), store_pool_gpr_idxs); } } @@ -270,9 +270,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_prc.size(); for (int i = 0; i < c_blocks; i++) { - store_emitter->emit_code({static_cast(vmm_zero.getIdx())}, {static_cast(reg_output.getIdx())}, - std::make_shared(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off), - store_pool_vec_idxs, store_pool_gpr_idxs); + store_empty_roi_emitter->emit_code({static_cast(vmm_zero.getIdx()), static_cast(i * dst_c_off)}, + {static_cast(reg_output.getIdx())}, store_pool_vec_idxs, store_pool_gpr_idxs); } } diff --git a/src/plugins/intel_cpu/src/nodes/topk.cpp b/src/plugins/intel_cpu/src/nodes/topk.cpp index 48394afc831940..9c7161da34a9ea 100644 --- a/src/plugins/intel_cpu/src/nodes/topk.cpp +++ b/src/plugins/intel_cpu/src/nodes/topk.cpp @@ -82,9 +82,6 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato } void generate() override { - load_emitter.reset(new jit_load_emitter(this, isa)); - store_emitter.reset(new jit_store_emitter(this, isa)); - this->preamble(); mov(reg_src, ptr[reg_params + GET_OFF(src)]); @@ -123,8 +120,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato this->postamble(); - load_emitter->emit_data(); - store_emitter->emit_data(); + emit_emitters_data(); if (!shape_agnostic_alg) prepare_idx_table(); @@ -207,9 +203,8 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato Vmm vmm_zero = Vmm(0); // vmm_zero represents Vmm(0) when isa is avx512_core, otherwise vmm_mask represents Vmm(0) const Xbyak::Opmask k_mask = Xbyak::Opmask(1); - const int step = vlen / sizeof(float); - const int tail = jcp_.work_amount % step; - const int topk_tail = jcp_.top_k % step; + const int vector_step = vlen / sizeof(float); + const int tail_step = jcp_.work_amount % vector_step; int blk_stride = 0; // stride of channel blocks at the same space coordinate, only used in blocked layout with topk on channel unsigned char cmp_flg; @@ -217,13 +212,67 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generato Xbyak::Label l_table; - std::unique_ptr load_emitter = nullptr; - std::unique_ptr store_emitter = nullptr; + std::unordered_map> emitters; std::vector store_pool_gpr_idxs; std::vector load_pool_gpr_idxs; std::vector store_pool_vec_idxs; + void emit_emitters_data() { + for (const auto& emitter : emitters) { + emitter.second->emit_data(); + } + } + + inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + emit_load(reg_src, vmm_src, jcp_.precision, Precision::FP32, elt_num, offset); + } + + inline void load_i32_f32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + emit_load(reg_src, vmm_src, Precision::I32, Precision::FP32, elt_num, offset); + } + + inline void load_i32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) { + emit_load(reg_src, vmm_src, Precision::I32, Precision::I32, elt_num, offset); + } + + inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 
diff --git a/src/plugins/intel_cpu/src/nodes/topk.cpp b/src/plugins/intel_cpu/src/nodes/topk.cpp
index 48394afc831940..9c7161da34a9ea 100644
--- a/src/plugins/intel_cpu/src/nodes/topk.cpp
+++ b/src/plugins/intel_cpu/src/nodes/topk.cpp
@@ -82,9 +82,6 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
     }

     void generate() override {
-        load_emitter.reset(new jit_load_emitter(this, isa));
-        store_emitter.reset(new jit_store_emitter(this, isa));
-
         this->preamble();

         mov(reg_src, ptr[reg_params + GET_OFF(src)]);
@@ -123,8 +120,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator

         this->postamble();

-        load_emitter->emit_data();
-        store_emitter->emit_data();
+        emit_emitters_data();

         if (!shape_agnostic_alg)
             prepare_idx_table();
@@ -207,9 +203,8 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
     Vmm vmm_zero = Vmm(0); // vmm_zero represents Vmm(0) when isa is avx512_core, otherwise vmm_mask represents Vmm(0)

     const Xbyak::Opmask k_mask = Xbyak::Opmask(1);
-    const int step = vlen / sizeof(float);
-    const int tail = jcp_.work_amount % step;
-    const int topk_tail = jcp_.top_k % step;
+    const int vector_step = vlen / sizeof(float);
+    const int tail_step = jcp_.work_amount % vector_step;

     int blk_stride = 0;    // stride of channel blocks at the same space coordinate, only used in blocked layout with topk on channel
     unsigned char cmp_flg;
@@ -217,13 +212,67 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator

     Xbyak::Label l_table;

-    std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
-    std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
+    std::unordered_map<size_t, std::unique_ptr<jit_emitter>> emitters;

     std::vector<size_t> store_pool_gpr_idxs;
     std::vector<size_t> load_pool_gpr_idxs;
     std::vector<size_t> store_pool_vec_idxs;

+    void emit_emitters_data() {
+        for (const auto& emitter : emitters) {
+            emitter.second->emit_data();
+        }
+    }
+
+    inline void load(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
+        emit_load(reg_src, vmm_src, jcp_.precision, Precision::FP32, elt_num, offset);
+    }
+
+    inline void load_i32_f32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
+        emit_load(reg_src, vmm_src, Precision::I32, Precision::FP32, elt_num, offset);
+    }
+
+    inline void load_i32(Xbyak::Reg64 reg_src, Vmm vmm_src, const int elt_num, const int offset = 0) {
+        emit_load(reg_src, vmm_src, Precision::I32, Precision::I32, elt_num, offset);
+    }
+
+    inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
+        emit_store(vmm_dst, reg_dst, Precision::FP32, jcp_.precision, elt_num, offset);
+    }
+
+    inline void store_f32_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
+        emit_store(vmm_dst, reg_dst, Precision::FP32, Precision::I32, elt_num, offset);
+    }
+
+    inline void store_i32(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
+        emit_store(vmm_dst, reg_dst, Precision::I32, Precision::I32, elt_num, offset);
+    }
+
+    inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
+        const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
+        if (!emitters[seed]) {
+            emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
+        }
+
+        emitters[seed]->emit_code({static_cast<size_t>(reg_src.getIdx()), static_cast<size_t>(offset)},
+                                  {static_cast<size_t>(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
+    }
+
+    inline void emit_store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, Precision src_prc, Precision dst_prc, const int elt_num, const int offset = 0) {
+        const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash();
+        if (!emitters[seed]) {
+            emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num));
+        }
+
+        // when the store emitter needs two aux vmms, vmm_dst can serve as the second one
+        std::vector<size_t> local_store_pool_vec_idxs = { static_cast<size_t>(vmm_dst.getIdx()) };
+        local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
+
+        emitters[seed]->emit_code({static_cast<size_t>(vmm_dst.getIdx()), static_cast<size_t>(offset)},
+                                  {static_cast<size_t>(reg_dst.getIdx())},
+                                  {local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
+    }
+
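The get-or-create logic in emit_load/emit_store is the heart of this change: every distinct (src_prc, dst_prc, elt_num) combination hashes to one cached emitter, so repeated loads and stores of the same shape reuse a single instance and its data table. A self-contained sketch of the same scheme with simplified stand-in types (the sketch_* names are illustrative, not the real classes):

#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>

// boost-style hash_combine, in the spirit of load/store_emitter_params::hash()
static size_t hash_combine(size_t seed, size_t v) {
    return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct sketch_params {                       // stands in for load/store_emitter_params
    int src_prc, dst_prc, elt_num;
    size_t hash() const {
        size_t seed = hash_combine(0, std::hash<std::string>()("sketch_emitter"));
        seed = hash_combine(seed, static_cast<size_t>(src_prc));
        seed = hash_combine(seed, static_cast<size_t>(dst_prc));
        return hash_combine(seed, static_cast<size_t>(elt_num));
    }
};

struct sketch_emitter {};                    // stands in for jit_emitter

std::unordered_map<size_t, std::unique_ptr<sketch_emitter>> emitters;

sketch_emitter& get_or_create(const sketch_params& p) {
    auto& slot = emitters[p.hash()];         // default-constructs an empty slot once
    if (!slot)
        slot.reset(new sketch_emitter());    // created once per unique parameter set
    return *slot;                            // reused on every later call
}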
     inline void topk_loop() {
         if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
             if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
@@ -253,27 +302,27 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         Xbyak::Label topk_main_loop_end_label;
         L(topk_main_loop_label);
         {
-            cmp(reg_work_amount, step);
+            cmp(reg_work_amount, vector_step);
             jl(topk_main_loop_end_label, T_NEAR);

-            topk_bitonic(step);
+            topk_bitonic(vector_step);

-            add(reg_src, step * jcp_.data_size);
-            add(reg_dst, step * jcp_.data_size);
-            add(reg_dst_idx, step * sizeof(int));
-            sub(reg_work_amount, step);
+            add(reg_src, vector_step * jcp_.data_size);
+            add(reg_dst, vector_step * jcp_.data_size);
+            add(reg_dst_idx, vector_step * sizeof(int));
+            sub(reg_work_amount, vector_step);

             jmp(topk_main_loop_label, T_NEAR);
         }
         L(topk_main_loop_end_label);

         // tail
-        if (tail) {
+        if (tail_step) {
             Xbyak::Label topk_tail_loop_end_label;
-            cmp(reg_work_amount, tail);
+            cmp(reg_work_amount, tail_step);
             jl(topk_tail_loop_end_label, T_NEAR);

-            topk_bitonic(tail);
+            topk_bitonic(tail_step);

             L(topk_tail_loop_end_label);
         }
@@ -282,19 +331,11 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
     inline void topk_bitonic(int elt_num) {
         // src => prc
         for (int i = 0; i < jcp_.axis_dim; i++) {
-            load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
-                        {}, {load_pool_gpr_idxs});
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_prc.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            load(reg_src, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size);
+            store(vmm_tmp, reg_prc, elt_num, i * jcp_.sort_stride * jcp_.data_size);

-            load_emitter->emit_code({static_cast<size_t>(reg_table.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
-                        std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num, i * vlen),
-                        {}, {load_pool_gpr_idxs});
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_prc_idx.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            load_i32(reg_table, vmm_tmp, elt_num, i * vlen);
+            store_i32(vmm_tmp, reg_prc_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
         }

         // sort
@@ -305,19 +346,11 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         // prc => dst
         for (int i = 0; i < jcp_.top_k; i++) {
-            load_emitter->emit_code({static_cast<size_t>(reg_prc.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
-                        {}, {load_pool_gpr_idxs});
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            load(reg_prc, vmm_tmp, elt_num, i * jcp_.sort_stride * jcp_.data_size);
+            store(vmm_tmp, reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);

-            load_emitter->emit_code({static_cast<size_t>(reg_prc_idx.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
-                        std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
-                        {}, {load_pool_gpr_idxs});
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_dst_idx.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            load_i32(reg_prc_idx, vmm_tmp, elt_num, i * jcp_.sort_stride * sizeof(int));
+            store_i32(vmm_tmp, reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
         }
     }

@@ -330,46 +363,46 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         Xbyak::Label topk_main_loop_end_label;
         L(topk_main_loop_label);
         {
-            cmp(reg_work_amount, step);
+            cmp(reg_work_amount, vector_step);
             jl(topk_main_loop_end_label, T_NEAR);

             // src => prc
-            bitonic_BLK_on_channel_load(step);
+            bitonic_BLK_on_channel_load(vector_step);

             // sort
-            bitonic_sort_vector(step);
+            bitonic_sort_vector(vector_step);
             if (jcp_.sort_index) {
-                bitonic_sort_vector(step, false);
+                bitonic_sort_vector(vector_step, false);
             }

             // prc => dst
-            bitonic_BLK_on_channel_store(step);
+            bitonic_BLK_on_channel_store(vector_step);

-            add(reg_src, step * jcp_.blk_size * jcp_.data_size);
-            add(reg_dst, step * jcp_.blk_size * jcp_.data_size);
-            add(reg_dst_idx, step * jcp_.blk_size * sizeof(int));
-            sub(reg_work_amount, step);
+            add(reg_src, vector_step * jcp_.blk_size * jcp_.data_size);
+            add(reg_dst, vector_step * jcp_.blk_size * jcp_.data_size);
+            add(reg_dst_idx, vector_step * jcp_.blk_size * sizeof(int));
+            sub(reg_work_amount, vector_step);

             jmp(topk_main_loop_label, T_NEAR);
         }
         L(topk_main_loop_end_label);

         // tail exists because the working buffer has planar layout, though the source buffer has blocked layout
-        if (tail) {
+        if (tail_step) {
             Xbyak::Label topk_tail_loop_end_label;
-            cmp(reg_work_amount, tail);
+            cmp(reg_work_amount, tail_step);
             jl(topk_tail_loop_end_label, T_NEAR);

             // src => prc
-            bitonic_BLK_on_channel_load(tail);
+            bitonic_BLK_on_channel_load(tail_step);

-            bitonic_sort_vector(tail);
+            bitonic_sort_vector(tail_step);
             if (jcp_.sort_index) {
-                bitonic_sort_vector(tail, false);
+                bitonic_sort_vector(tail_step, false);
             }

             // prc => dst
-            bitonic_BLK_on_channel_store(tail);
+            bitonic_BLK_on_channel_store(tail_step);

             L(topk_tail_loop_end_label);
         }
@@ -437,40 +470,30 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
     inline void bitonic_swap_vector(int elt_num, bool cmp_val = true) {
         bitonic_get_addr(reg_prc, jcp_.data_size, 0);
-        load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_val_l.getIdx())},
-                    std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
-                    {}, {load_pool_gpr_idxs});
+        load(reg_aux_idx, vmm_val_l, elt_num);
+
         bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
-        load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
-                    std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
-                    {}, {load_pool_gpr_idxs});
+        load(reg_aux_idx, vmm_val_r, elt_num);
+
         bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
-        load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_idx_l.getIdx())},
-                    std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
-                    {}, {load_pool_gpr_idxs});
+        load_i32_f32(reg_aux_idx, vmm_idx_l, elt_num);
+
         bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
-        load_emitter->emit_code({static_cast<size_t>(reg_aux_idx.getIdx())}, {static_cast<size_t>(vmm_idx_r.getIdx())},
-                    std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
-                    {}, {load_pool_gpr_idxs});
+        load_i32_f32(reg_aux_idx, vmm_idx_r, elt_num);

         swap_vector(vmm_val_l, vmm_idx_l, vmm_val_r, vmm_idx_r, cmp_val);

         bitonic_get_addr(reg_prc, jcp_.data_size, 0);
-        store_emitter->emit_code({static_cast<size_t>(vmm_val_l.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
-                    std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
-                    {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+        store(vmm_val_l, reg_aux_idx, elt_num);
+
         bitonic_get_addr(reg_prc, jcp_.data_size, sizeof(int));
-        store_emitter->emit_code({static_cast<size_t>(vmm_val_r.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
-                    std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
-                    {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+        store(vmm_val_r, reg_aux_idx, elt_num);
+
         bitonic_get_addr(reg_prc_idx, sizeof(int), 0);
-        store_emitter->emit_code({static_cast<size_t>(vmm_idx_l.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
-                    std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
-                    {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+        store_f32_i32(vmm_idx_l, reg_aux_idx, elt_num);
+
         bitonic_get_addr(reg_prc_idx, sizeof(int), sizeof(int));
-        store_emitter->emit_code({static_cast<size_t>(vmm_idx_r.getIdx())}, {static_cast<size_t>(reg_aux_idx.getIdx())},
-                    std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
-                    {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+        store_f32_i32(vmm_idx_r, reg_aux_idx, elt_num);
     }
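bitonic_swap_vector loads a pair of (value, index) vectors, runs one compare-exchange through swap_vector, and writes both pairs back. A scalar model of that compare-exchange (the kernel applies it lane-wise with SIMD masks; the cmp_val flag, which switches the comparison key between value and index, is omitted here, and a max-TopK ordering is assumed):

#include <utility>

// One bitonic compare-exchange on a (value, index) pair, scalar model only.
inline void compare_exchange(float& val_l, int& idx_l,
                             float& val_r, int& idx_r) {
    if (val_l < val_r) {             // assumed descending order: larger value stays left
        std::swap(val_l, val_r);
        std::swap(idx_l, idx_r);     // indices always travel with their values
    }
}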
@@ -480,9 +503,9 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator

         // init dst
         mov(reg_i, 0);
-        sub(reg_heap_top_k, step);
-        topk_heap_load(reg_heap_k_sub_step, step);
-        add(reg_heap_top_k, step);
+        sub(reg_heap_top_k, vector_step);
+        topk_heap_load(reg_heap_k_sub_step, vector_step);
+        add(reg_heap_top_k, vector_step);
         topk_heap_load(reg_heap_top_k, 1);

         mov(reg_zero, 0);
@@ -569,7 +592,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         Xbyak::Label topk_init_loop_end_label;
         L(topk_init_loop_label);
         {
-            if (s == step) {
+            if (s == vector_step) {
                 cmp(reg_i, reg_end);
                 jg(topk_init_loop_end_label, T_NEAR);
             } else {
@@ -578,25 +601,18 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
             }

             get_addr_by_reg_idx(reg_heap_outer_aux, reg_src, reg_i, jcp_.data_size);
-            load_emitter->emit_code({static_cast<size_t>(reg_heap_outer_aux.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, s),
-                        {}, {load_pool_gpr_idxs});
+            load(reg_heap_outer_aux, vmm_tmp, s);
+
             get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst, reg_i, jcp_.data_size);
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_heap_outer_aux.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, s),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
-            if (s == step) {
+            store(vmm_tmp, reg_heap_outer_aux, s);
+            if (s == vector_step) {
                 table_to_vmm(vmm_tmp, reg_heap_seq_idx, reg_i, 0, sizeof(int));
             } else {
                 get_addr_by_reg_idx(reg_heap_outer_aux, reg_heap_seq_idx, reg_i, sizeof(int));
-                load_emitter->emit_code({static_cast<size_t>(reg_heap_outer_aux.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
-                            std::make_shared<load_emitter_context>(Precision::I32, Precision::I32, 1),
-                            {}, {load_pool_gpr_idxs});
+                load_i32(reg_heap_outer_aux, vmm_tmp, 1);
             }
             get_addr_by_reg_idx(reg_heap_outer_aux, reg_dst_idx, reg_i, sizeof(int));
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_heap_outer_aux.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, s),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            store_i32(vmm_tmp, reg_heap_outer_aux, s);

             add(reg_i, s);
             jmp(topk_init_loop_label, T_NEAR);
@@ -812,19 +828,19 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         Xbyak::Label topk_main_loop_end_label;
         L(topk_main_loop_label);
         {
-            cmp(reg_work_amount, step);
+            cmp(reg_work_amount, vector_step);
             jl(topk_main_loop_end_label, T_NEAR);

             if (jcp_.bubble_inplace) {
-                topk_bubble_inplace(step);
+                topk_bubble_inplace(vector_step);
             } else {
-                topk_bubble(step);
+                topk_bubble(vector_step);
             }

-            add(reg_src, step * jcp_.data_size);
-            add(reg_dst, step * jcp_.data_size);
-            add(reg_dst_idx, step * sizeof(int));
-            sub(reg_work_amount, step);
+            add(reg_src, vector_step * jcp_.data_size);
+            add(reg_dst, vector_step * jcp_.data_size);
+            add(reg_dst_idx, vector_step * sizeof(int));
+            sub(reg_work_amount, vector_step);

             jmp(topk_main_loop_label, T_NEAR);
         }
@@ -832,12 +848,12 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator

         // tail
         if (jcp_.bubble_inplace) {
-            if (tail) {
+            if (tail_step) {
                 Xbyak::Label topk_tail_loop_end_label;
-                cmp(reg_work_amount, tail);
+                cmp(reg_work_amount, tail_step);
                 jl(topk_tail_loop_end_label, T_NEAR);

-                topk_bubble_inplace(tail);
+                topk_bubble_inplace(tail_step);

                 L(topk_tail_loop_end_label);
             }
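The step/tail renames to vector_step/tail_step make the usual SIMD main-loop-plus-tail split explicit. In scalar form, the control flow that reg_work_amount drives in the generated code looks roughly like this (illustrative only, not generated code):

// Scalar outline of the loop structure above: full-width iterations while at
// least vector_step elements remain, then one pass over the remainder.
void topk_outer_loop(int work_amount, int vector_step) {
    while (work_amount >= vector_step) {   // cmp(reg_work_amount, vector_step); jl(...)
        // vectorized top-k body handles vector_step elements here
        work_amount -= vector_step;        // sub(reg_work_amount, vector_step)
    }
    const int tail_step = work_amount;     // equals jcp_.work_amount % vector_step
    if (tail_step) {
        // the same body is emitted once more with elt_num == tail_step
    }
}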
@@ -1015,19 +1031,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
             je(topk_init_loop_end_label, T_NEAR);

             get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
-            load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_tmp.getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
-                        {}, {load_pool_gpr_idxs});
+            load(reg_tmp, vmm_tmp, elt_num);
+
             get_addr_by_reg_idx(reg_tmp, reg_dst, reg_block_sort_stride_byte, reg_i);
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            store(vmm_tmp, reg_tmp, elt_num);

             table_to_vmm(vmm_tmp, reg_bubble_block_idx, reg_i, 0, vlen);
             get_addr_by_reg_idx(reg_tmp, reg_dst_idx, reg_block_sort_stride_byte, sizeof(int) / jcp_.data_size, reg_i);
-            store_emitter->emit_code({static_cast<size_t>(vmm_tmp.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::I32, Precision::I32, elt_num),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            store_i32(vmm_tmp, reg_tmp, elt_num);

             add(reg_i, 1);
             jmp(topk_init_loop_label, T_NEAR);
@@ -1047,9 +1057,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
             je(topk_update_loop_end_label, T_NEAR);

             get_addr_by_reg_idx(reg_tmp, reg_src, reg_block_sort_stride_byte, reg_i);
-            load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
-                        {}, {load_pool_gpr_idxs});
+            load(reg_tmp, vmm_val_r, elt_num);

             table_to_vmm(vmm_idx_r, reg_bubble_block_idx, reg_i, 0, vlen);
             uni_vcvtdq2ps(vmm_idx_r, vmm_idx_r);
@@ -1142,9 +1150,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
     inline void topk_bubble_inplace(int elt_num) {
         // load
         for (int i = 0; i < jcp_.top_k; i++) {
-            load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(i).getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
-                        {}, {load_pool_gpr_idxs});
+            load(reg_src, vmm_val(i), elt_num, i * jcp_.sort_stride * jcp_.data_size);
             uni_vmovdqu(vmm_idx(i), table_val(i));
             uni_vcvtdq2ps(vmm_idx(i), vmm_idx(i));
         }
@@ -1155,9 +1161,7 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
             }
         }
         for (int i = jcp_.top_k; i < jcp_.axis_dim; i++) {
-            load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(jcp_.top_k).getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num, i * jcp_.sort_stride * jcp_.data_size),
-                        {}, {load_pool_gpr_idxs});
+            load(reg_src, vmm_val(jcp_.top_k), elt_num, i * jcp_.sort_stride * jcp_.data_size);
             uni_vmovdqu(vmm_idx(jcp_.top_k), table_val(i));
             uni_vcvtdq2ps(vmm_idx(jcp_.top_k), vmm_idx(jcp_.top_k));
             for (int j = jcp_.top_k; j > 0; j--) {
@@ -1173,12 +1177,8 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         }
         // store
         for (int i = 0; i < jcp_.top_k; i++) {
-            store_emitter->emit_code({static_cast<size_t>(vmm_val(i).getIdx())}, {static_cast<size_t>(reg_dst.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num, i * jcp_.sort_stride * jcp_.data_size),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
-            store_emitter->emit_code({static_cast<size_t>(vmm_idx(i).getIdx())}, {static_cast<size_t>(reg_dst_idx.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num, i * jcp_.sort_stride * sizeof(int)),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            store(vmm_val(i), reg_dst, elt_num, i * jcp_.sort_stride * jcp_.data_size);
+            store_f32_i32(vmm_idx(i), reg_dst_idx, elt_num, i * jcp_.sort_stride * sizeof(int));
         }
     }
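topk_bubble_inplace keeps the running top-k candidates resident in vmm_val(0..top_k-1), loads each further element into vmm_val(top_k), and bubbles it upward with pairwise compare-exchanges. A scalar model of that pass (assuming max-TopK; the real kernel also carries the index registers along and works per SIMD lane):

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

// Scalar model of the in-register bubble pass. best[] plays vmm_val(0..k-1),
// tail plays vmm_val(k); one bottom-up bubble pass per new element suffices
// because best[] is kept sorted (descending) at all times.
std::vector<float> topk_bubble_model(const std::vector<float>& src, int k) {
    std::vector<float> best(src.begin(), src.begin() + k);
    std::sort(best.begin(), best.end(), std::greater<float>());
    for (size_t i = k; i < src.size(); i++) {
        float tail = src[i];
        if (best[k - 1] < tail)
            std::swap(best[k - 1], tail);          // admit the candidate at the bottom
        for (int j = k - 1; j > 0; j--)            // bubble it toward the top
            if (best[j - 1] < best[j])
                std::swap(best[j - 1], best[j]);
    }
    return best;  // the k largest values, descending
}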
@@ -1201,15 +1201,11 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator

         L(topk_load_sort_label);
         {
-            load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(0).getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step, 0),
-                        {}, {load_pool_gpr_idxs});
+            load(reg_src, vmm_val(0), vector_step, 0);
             uni_vmovdqu(vmm_idx(0), table_bubble_seq_idx(0));
             uni_vcvtdq2ps(vmm_idx(0), vmm_idx(0));
             if (isa == cpu::x64::sse41) {
-                load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
-                            std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step, 4 * jcp_.data_size),
-                            {}, {load_pool_gpr_idxs});
+                load(reg_src, vmm_val(1), vector_step, 4 * jcp_.data_size);
                 uni_vmovdqu(vmm_idx(1), table_bubble_seq_idx(4));
                 uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
                 swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
@@ -1225,17 +1221,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
                 jg(topk_iter_end_label, T_NEAR);

                 get_addr_by_reg_idx(reg_aux, reg_src, reg_i, jcp_.data_size, reg_seq_sort_stride);
-                load_emitter->emit_code({static_cast<size_t>(reg_aux.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
-                            std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step),
-                            {}, {load_pool_gpr_idxs});
+                load(reg_aux, vmm_val(1), vector_step);
                 table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 0, sizeof(int));
                 uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
                 swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
                 if (isa == cpu::x64::sse41) {
                     add(reg_aux, 4 * jcp_.data_size);
-                    load_emitter->emit_code({static_cast<size_t>(reg_aux.getIdx())}, {static_cast<size_t>(vmm_val(1).getIdx())},
-                                std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, step),
-                                {}, {load_pool_gpr_idxs});
+                    load(reg_aux, vmm_val(1), vector_step);
                     table_to_vmm(vmm_idx(1), reg_bubble_seq_idx, reg_i, 4, sizeof(int));
                     uni_vcvtdq2ps(vmm_idx(1), vmm_idx(1));
                     swap_vector(vmm_val(0), vmm_idx(0), vmm_val(1), vmm_idx(1));
@@ -1528,16 +1520,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         // load l
         mov(reg_tmp, reg_tmp_64);
         add(reg_tmp, reg_dst);
-        load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_l.getIdx())},
-                    std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
-                    {}, {load_pool_gpr_idxs});
+        load(reg_tmp, vmm_val_l, elt_num);
+
         reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
         mov(reg_tmp, reg_tmp_64);
         add(reg_tmp, reg_dst_idx);
         reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
-        load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_idx_l.getIdx())},
-                    std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
-                    {}, {load_pool_gpr_idxs});
+        load_i32_f32(reg_tmp, vmm_idx_l, elt_num);

         // load r
         Xbyak::Label topk_load_jmp_label;
@@ -1547,16 +1536,14 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
             add(reg_tmp_64, reg_block_sort_stride_byte);
             mov(reg_tmp, reg_tmp_64);
             add(reg_tmp, reg_dst);
-            load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_val_r.getIdx())},
-                        std::make_shared<load_emitter_context>(jcp_.precision, Precision::FP32, elt_num),
-                        {}, {load_pool_gpr_idxs});
+            load(reg_tmp, vmm_val_r, elt_num);
+
            reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
            mov(reg_tmp, reg_tmp_64);
            add(reg_tmp, reg_dst_idx);
            reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
-            load_emitter->emit_code({static_cast<size_t>(reg_tmp.getIdx())}, {static_cast<size_t>(vmm_idx_r.getIdx())},
-                        std::make_shared<load_emitter_context>(Precision::I32, Precision::FP32, elt_num),
-                        {}, {load_pool_gpr_idxs});
+            load_i32_f32(reg_tmp, vmm_idx_r, elt_num);
+
            sub(reg_tmp_64, reg_block_sort_stride_byte);
         }
         L(topk_load_jmp_label);
@@ -1566,16 +1553,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
         // store l
         mov(reg_tmp, reg_tmp_64);
         add(reg_tmp, reg_dst);
-        store_emitter->emit_code({static_cast<size_t>(vmm_val_l.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
-                    std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
-                    {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+        store(vmm_val_l, reg_tmp, elt_num);
+
        reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
        mov(reg_tmp, reg_tmp_64);
        add(reg_tmp, reg_dst_idx);
        reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
-        store_emitter->emit_code({static_cast<size_t>(vmm_idx_l.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
-                    std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
-                    {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+        store_f32_i32(vmm_idx_l, reg_tmp, elt_num);

         // store r
         Xbyak::Label topk_store_jmp_label;
@@ -1585,16 +1569,13 @@ struct jit_uni_topk_kernel_f32 : public jit_uni_topk_kernel, public jit_generator
             add(reg_tmp_64, reg_block_sort_stride_byte);
             mov(reg_tmp, reg_tmp_64);
             add(reg_tmp, reg_dst);
-            store_emitter->emit_code({static_cast<size_t>(vmm_val_r.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, jcp_.precision, elt_num),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            store(vmm_val_r, reg_tmp, elt_num);
+
            reg_shl(reg_tmp_64, sizeof(int) / jcp_.data_size);
            mov(reg_tmp, reg_tmp_64);
            add(reg_tmp, reg_dst_idx);
            reg_shr(reg_tmp_64, sizeof(int) / jcp_.data_size);
-            store_emitter->emit_code({static_cast<size_t>(vmm_idx_r.getIdx())}, {static_cast<size_t>(reg_tmp.getIdx())},
-                        std::make_shared<store_emitter_context>(Precision::FP32, Precision::I32, elt_num),
-                        {store_pool_vec_idxs}, {store_pool_gpr_idxs});
+            store_f32_i32(vmm_idx_r, reg_tmp, elt_num);
         }
         L(topk_store_jmp_label);
     }
diff --git a/src/plugins/intel_cpu/src/utils/jit_kernel.cpp b/src/plugins/intel_cpu/src/utils/jit_kernel.cpp
index f1c99751c1a7d5..0fedd26ea8decb 100644
--- a/src/plugins/intel_cpu/src/utils/jit_kernel.cpp
+++ b/src/plugins/intel_cpu/src/utils/jit_kernel.cpp
@@ -133,6 +133,11 @@ InferenceEngine::Precision type2precision<uint8_t>() {
     return InferenceEngine::Precision::U8;
 }

+template<>
+InferenceEngine::Precision type2precision<int8_t>() {
+    return InferenceEngine::Precision::I8;
+}
+
 cpu_isa_t get_current_isa() {
     if (mayiuse(cpu_isa_t::avx512_core))
         return cpu_isa_t::avx512_core;
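The new type2precision<int8_t> specialization follows the existing mapping idiom: the primary template is only declared, so requesting an unmapped type fails at build time, and each supported type contributes a one-line specialization. A sketch of the idiom with simplified stand-in names (not the real InferenceEngine types):

#include <cstdint>

enum class precision { FP32, I32, U8, I8 };   // simplified stand-in for InferenceEngine::Precision

template <typename T>
precision to_precision();                     // declared, never defined: unmapped types fail to link

template <> precision to_precision<float>()   { return precision::FP32; }
template <> precision to_precision<int32_t>() { return precision::I32; }
template <> precision to_precision<uint8_t>() { return precision::U8; }
template <> precision to_precision<int8_t>()  { return precision::I8; }   // the case this diff adds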
@@ -212,9 +217,7 @@ const void * consts_table::store(const void *data, size_t size) {
 }   // namespace internal

 jit_kernel::jit_kernel()
-    : jit_generator()
-    , _load_emitter(this, internal::get_current_isa())
-    , _store_emitter(this, internal::get_current_isa()) {
+    : jit_generator() {
     _free_x64regs.reserve(16);
     _free_rmmregs.reserve(16);
@@ -297,10 +300,10 @@ void jit_kernel::free(const Zmm & reg) {
 void jit_kernel::postamble() {
     jit_generator::postamble();
-    if (_is_load_emitter_used)
-        _load_emitter.emit_data();
-    if (_is_store_emitter_used)
-        _store_emitter.emit_data();
+    for (const auto& emitter : _emitters) {
+        if (emitter.second)
+            emitter.second->emit_data();
+    }
 }

 const AddressFrame & jit_kernel::address_frame(size_t size) const {
diff --git a/src/plugins/intel_cpu/src/utils/jit_kernel.hpp b/src/plugins/intel_cpu/src/utils/jit_kernel.hpp
index ce531d7806c78a..ce86feb427d961 100644
--- a/src/plugins/intel_cpu/src/utils/jit_kernel.hpp
+++ b/src/plugins/intel_cpu/src/utils/jit_kernel.hpp
@@ -697,11 +697,8 @@ struct jit_kernel : public dnnl::impl::cpu::x64::jit_generator {
 private:
     reg_indices _free_x64regs;
     reg_indices _free_rmmregs;
-    bool _is_load_emitter_used = false;
-    bool _is_store_emitter_used = false;
-    jit_load_emitter _load_emitter;
-    jit_store_emitter _store_emitter;
     internal::consts_table _consts;
+    std::unordered_map<size_t, std::unique_ptr<jit_emitter>> _emitters;
 };

 template<typename T>
@@ -746,17 +743,18 @@ void jit_kernel::load(const variable<DstT> & dst, const variable<SrcT> & src,
     const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
     const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());

-    _load_emitter.emit_code(
+    const auto src_prc = internal::type2precision<SrcT>();
+    const auto dst_prc = internal::type2precision<DstT>();
+
+    const auto key = load_emitter_params(src_prc, dst_prc, length).hash();
+    if (!_emitters[key]) {
+        _emitters[key].reset(new jit_load_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length));
+    }
+    _emitters[key]->emit_code(
         { static_cast<size_t>(static_cast<const Reg64&>(src).getIdx()) },
         { static_cast<size_t>(static_cast<const Xmm&>(dst).getIdx()) },
-        std::make_shared<load_emitter_context>(
-            internal::type2precision<SrcT>(),
-            internal::type2precision<DstT>(),
-            static_cast<int>(length)),
         pool_vec_idxs, pool_gpr_idxs);
-
-    _is_load_emitter_used = true;
 }

 template
@@ -788,17 +786,18 @@ void jit_kernel::store(const variable<DstT> & dst, const variable<SrcT> & src,
     const std::vector<size_t> pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end());
     const std::vector<size_t> pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end());

-    _store_emitter.emit_code(
+    const auto src_prc = internal::type2precision<SrcT>();
+    const auto dst_prc = internal::type2precision<DstT>();
+
+    const auto key = store_emitter_params(src_prc, dst_prc, length).hash();
+    if (!_emitters[key]) {
+        _emitters[key].reset(new jit_store_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length));
+    }
+    _emitters[key]->emit_code(
         { static_cast<size_t>(static_cast<const Xmm&>(src).getIdx()) },
         { static_cast<size_t>(static_cast<const Reg64&>(dst).getIdx()) },
-        std::make_shared<store_emitter_context>(
-            internal::type2precision<SrcT>(),
-            internal::type2precision<DstT>(),
-            static_cast<int>(length)),
         pool_vec_idxs, pool_gpr_idxs);
-
-    _is_store_emitter_used = true;
 }

 template
diff --git a/src/tests/unit/cpu/jit_kernel_test.cpp b/src/tests/unit/cpu/jit_kernel_test.cpp
index 73ae2c4e5cc5f6..54fe0f6a4f48d3 100644
--- a/src/tests/unit/cpu/jit_kernel_test.cpp
+++ b/src/tests/unit/cpu/jit_kernel_test.cpp
@@ -318,15 +318,30 @@ struct jit_variable_load_store_test_kernel {
 };

 TEST(JitKernel, variable_load_and_store) {
-    jit_variable_load_store_test_kernel<uint8_t> kernel;
-    if (mayiuse(cpu_isa_t::avx512_core)) {
-        kernel.test<16>();
-    }
-    if (mayiuse(cpu_isa_t::avx2)) {
-        kernel.test<8>();
+    {
+        jit_variable_load_store_test_kernel<uint8_t> kernel;
+        if (mayiuse(cpu_isa_t::avx512_core)) {
+            kernel.test<16>();
+        }
+        if (mayiuse(cpu_isa_t::avx2)) {
+            kernel.test<8>();
+        }
+        if (mayiuse(cpu_isa_t::sse41)) {
+            kernel.test<4>();
+        }
     }
-    if (mayiuse(cpu_isa_t::sse41)) {
-        kernel.test<4>();
+
+    {
+        jit_variable_load_store_test_kernel<int8_t> kernel;
+        if (mayiuse(cpu_isa_t::avx512_core)) {
+            kernel.test<16>();
+        }
+        if (mayiuse(cpu_isa_t::avx2)) {
+            kernel.test<8>();
+        }
+        if (mayiuse(cpu_isa_t::sse41)) {
+            kernel.test<4>();
+        }
     }
 }