From f50803d57442b4e9eda1acc12dbb6cc3c39f653f Mon Sep 17 00:00:00 2001 From: Nikolai Shchegolev Date: Thu, 18 Jul 2024 15:14:23 +0400 Subject: [PATCH] [CPU][CACHE_DIR] hash optimization --- .../openvino/reference/utils/combine_hash.hpp | 15 + .../reference/utils/jit_generator.hpp | 98 +++ .../reference/utils/registers_pool.hpp | 312 ++++++++ src/core/reference/src/op/convert.cpp | 6 +- .../reference/src/op/utils/combine_hash.cpp | 666 ++++++++++++++++++ .../src/op/{ => utils}/jit_generator.cpp | 6 +- src/core/src/pass/serialize.cpp | 20 +- 7 files changed, 1101 insertions(+), 22 deletions(-) create mode 100644 src/core/reference/include/openvino/reference/utils/combine_hash.hpp create mode 100644 src/core/reference/include/openvino/reference/utils/jit_generator.hpp create mode 100644 src/core/reference/include/openvino/reference/utils/registers_pool.hpp create mode 100644 src/core/reference/src/op/utils/combine_hash.cpp rename src/core/reference/src/op/{ => utils}/jit_generator.cpp (96%) diff --git a/src/core/reference/include/openvino/reference/utils/combine_hash.hpp b/src/core/reference/include/openvino/reference/utils/combine_hash.hpp new file mode 100644 index 00000000000000..9f1cfdea812494 --- /dev/null +++ b/src/core/reference/include/openvino/reference/utils/combine_hash.hpp @@ -0,0 +1,15 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace runtime { + +size_t combine_hash(const void* src, size_t size); + +} // namespace runtime +} // namespace ov diff --git a/src/core/reference/include/openvino/reference/utils/jit_generator.hpp b/src/core/reference/include/openvino/reference/utils/jit_generator.hpp new file mode 100644 index 00000000000000..49c5cb6e0e959e --- /dev/null +++ b/src/core/reference/include/openvino/reference/utils/jit_generator.hpp @@ -0,0 +1,98 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#if defined _WIN32 && !defined NOMINMAX +#define NOMINMAX +#endif + +#include +#include + +namespace ov { +namespace reference { +namespace jit { +#ifdef XBYAK64 + static const Xbyak::Operand::Code abi_save_gpr_regs[] = { + Xbyak::Operand::RBX, + Xbyak::Operand::RBP, + Xbyak::Operand::R12, + Xbyak::Operand::R13, + Xbyak::Operand::R14, + Xbyak::Operand::R15, +#ifdef _WIN32 + Xbyak::Operand::RDI, + Xbyak::Operand::RSI, +#endif + }; + +#ifdef _WIN32 +#define abi_param1 Xbyak::Reg64(Xbyak::Operand::RCX) // RCX +#else +#define abi_param1 Xbyak::Reg64(Xbyak::Operand::RDI) // RDI +#endif +#endif // XBYAK64 + + typedef enum { + isa_any, + sse42, + avx, + avx2, + avx512_common, + avx512_core, + avx512_core_vnni, + avx512_mic, + avx512_mic_4ops, + avx512_core_bf16, + avx512_vpopcnt, + fp16, + pclmulqdq, + vpclmulqdq + } cpu_isa_t; + + class Generator : public Xbyak::CodeGenerator + { +#ifdef _WIN32 + static constexpr size_t xmm_to_preserve_start = 6; + static constexpr size_t xmm_to_preserve = 10; +#else + static constexpr size_t xmm_to_preserve_start = 0; + static constexpr size_t xmm_to_preserve = 0; +#endif + + static const size_t num_abi_save_gpr_regs = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]); + const size_t size_of_abi_save_regs; + + const Xbyak::Reg64 reg_EVEX_max_8b_offt; + static constexpr int EVEX_max_8b_offt = 0x200; + + public: + static constexpr size_t xmm_len = 16; + static constexpr size_t ymm_len = 32; + static constexpr size_t zmm_len = 64; + + const Xbyak::Reg64 param = abi_param1; + + static bool mayiuse(const cpu_isa_t cpu_isa); + static bool is_x64(); + + Generator(void* code_ptr = nullptr, size_t code_size = 16 * 1024); + void preamble(); + void postamble(); + + void foreach (const Xbyak::Reg64& idx, + size_t step, + const Xbyak::Reg64& end, + std::function && fn); + + template + void copy(const Xbyak::Reg64& dst, + const Xbyak::Reg64& src, + const Xbyak::Reg64& size); + }; + +} // namespace jit +} // namespace reference +} // namespace ov diff --git a/src/core/reference/include/openvino/reference/utils/registers_pool.hpp b/src/core/reference/include/openvino/reference/utils/registers_pool.hpp new file mode 100644 index 00000000000000..59ddd11596b980 --- /dev/null +++ b/src/core/reference/include/openvino/reference/utils/registers_pool.hpp @@ -0,0 +1,312 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "jit_generator.hpp" +#include "openvino/core/except.hpp" + +#include +#include +#include + +namespace ov { +namespace runtime { +namespace jit { + +class RegistersPool { +public: + using Ptr = std::shared_ptr; + using WeakPtr = std::weak_ptr; + static constexpr int anyIdx = -1; + + template + class Reg { + friend class RegistersPool; + public: + Reg() {} + Reg(const RegistersPool::Ptr& regPool) { initialize(regPool); } + Reg(const RegistersPool::Ptr& regPool, int requestedIdx) { initialize(regPool, requestedIdx); } + ~Reg() { release(); } + Reg& operator=(Reg&& other) noexcept { + release(); + reg = other.reg; + regPool = std::move(other.regPool); + return *this; + } + Reg(Reg&& other) noexcept : reg(other.reg), regPool(std::move(other.regPool)) {} + operator TReg&() { ensureValid(); return reg; } + operator const TReg&() const { ensureValid(); return reg; } + operator Xbyak::RegExp() const { ensureValid(); return reg; } + int getIdx() const { ensureValid(); return reg.getIdx(); } + friend Xbyak::RegExp operator+(const Reg& lhs, const Xbyak::RegExp& rhs) { + lhs.ensureValid(); + return lhs.operator Xbyak::RegExp() + rhs; + } + void release() { + if (auto pool = regPool.lock()) { + pool->returnToPool(reg); + regPool.reset(); + } + } + bool isInitialized() const { return !regPool.expired(); } + + private: + void ensureValid() const { + if (!isInitialized()) { + OPENVINO_THROW("RegistersPool::Reg is either not initialized or released"); + } + } + + void initialize(const RegistersPool::Ptr& pool, int requestedIdx = anyIdx) { + release(); + reg = TReg(pool->template getFree(requestedIdx)); + regPool = pool; + } + + private: + TReg reg; + RegistersPool::WeakPtr regPool; + }; + + virtual ~RegistersPool() { + checkUniqueAndUpdate(false); + } + + template + static Ptr create(std::initializer_list regsToExclude); + + static Ptr create(cpu_isa_t isa, std::initializer_list regsToExclude); + + template + size_t countFree() const { + if (std::is_base_of::value) { + return simdSet.countUnused(); + } else if (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value) { + return generalSet.countUnused(); + } else if (std::is_same::value) { + return countUnusedOpmask(); + } + } + +protected: + class PhysicalSet { + public: + PhysicalSet(int size) : isFreeIndexVector(size, true) {} + + void setAsUsed(size_t regIdx) { + if (regIdx >= isFreeIndexVector.size()) { + OPENVINO_THROW("regIdx is out of bounds in RegistersPool::PhysicalSet::setAsUsed()"); + } + if (!isFreeIndexVector[regIdx]) { + OPENVINO_THROW("Inconsistency in RegistersPool::PhysicalSet::setAsUsed()"); + } + isFreeIndexVector[regIdx] = false; + } + + void setAsUnused(size_t regIdx) { + if (regIdx >= isFreeIndexVector.size()) { + OPENVINO_THROW("regIdx is out of bounds in RegistersPool::PhysicalSet::setAsUsed()"); + } + if (isFreeIndexVector[regIdx]) { + OPENVINO_THROW("Inconsistency in RegistersPool::PhysicalSet::setAsUnused()"); + } + isFreeIndexVector[regIdx] = true; + } + + size_t getUnused(size_t requestedIdx) { + if (requestedIdx == static_cast(anyIdx)) { + return getFirstFreeIndex(); + } else { + if (requestedIdx >= isFreeIndexVector.size()) { + OPENVINO_THROW("requestedIdx is out of bounds in RegistersPool::PhysicalSet::getUnused()"); + } + if (!isFreeIndexVector[requestedIdx]) { + OPENVINO_THROW("The register with index #", requestedIdx, " already used in the RegistersPool"); + } + return requestedIdx; + } + } + + void exclude(Xbyak::Reg reg) { + isFreeIndexVector.at(reg.getIdx()) = false; + } + + size_t countUnused() const { + size_t count = 0; + for (const auto& isFree : isFreeIndexVector) { + if (isFree) { + ++count; + } + } + return count; + } + + private: + size_t getFirstFreeIndex() { + for (size_t c = 0; c < isFreeIndexVector.size(); ++c) { + if (isFreeIndexVector[c]) { + return c; + } + } + OPENVINO_THROW("Not enough registers in the RegistersPool"); + } + + private: + std::vector isFreeIndexVector; + }; + + virtual int getFreeOpmask(int requestedIdx) { OPENVINO_THROW("getFreeOpmask: The Opmask is not supported in current instruction set"); } + virtual void returnOpmaskToPool(int idx) { OPENVINO_THROW("returnOpmaskToPool: The Opmask is not supported in current instruction set"); } + virtual size_t countUnusedOpmask() const { OPENVINO_THROW("countUnusedOpmask: The Opmask is not supported in current instruction set"); } + + RegistersPool(int simdRegistersNumber) + : simdSet(simdRegistersNumber) { + checkUniqueAndUpdate(); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RSP)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RAX)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RCX)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RDI)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RBP)); + } + + RegistersPool(std::initializer_list regsToExclude, int simdRegistersNumber) + : simdSet(simdRegistersNumber) { + checkUniqueAndUpdate(); + for (auto& reg : regsToExclude) { + if (reg.isXMM() || reg.isYMM() || reg.isZMM()) { + simdSet.exclude(reg); + } else if (reg.isREG()) { + generalSet.exclude(reg); + } + } + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RSP)); + } + +private: + template + int getFree(int requestedIdx) { + if (std::is_base_of::value) { + auto idx = simdSet.getUnused(requestedIdx); + simdSet.setAsUsed(idx); + return idx; + } else if (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value) { + auto idx = generalSet.getUnused(requestedIdx); + generalSet.setAsUsed(idx); + return idx; + } else if (std::is_same::value) { + return getFreeOpmask(requestedIdx); + } + } + + template + void returnToPool(const TReg& reg) { + if (std::is_base_of::value) { + simdSet.setAsUnused(reg.getIdx()); + } else if (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value) { + generalSet.setAsUnused(reg.getIdx()); + } else if (std::is_same::value) { + returnOpmaskToPool(reg.getIdx()); + } + } + + void checkUniqueAndUpdate(bool isCtor = true) { + static thread_local bool isCreated = false; + if (isCtor) { + if (isCreated) { + OPENVINO_THROW("There should be only one instance of RegistersPool per thread"); + } + isCreated = true; + } else { + isCreated = false; + } + } + + PhysicalSet generalSet {16}; + PhysicalSet simdSet; +}; + +template +class IsaRegistersPool : public RegistersPool { +public: + IsaRegistersPool(std::initializer_list regsToExclude) : RegistersPool(regsToExclude, 32) {} +}; + +template <> +class IsaRegistersPool : public RegistersPool { +public: + IsaRegistersPool() : RegistersPool(32) { + opmaskSet.exclude(Xbyak::Opmask(0)); // the Opmask(0) has special meaning for some instructions, like gather instruction + } + + IsaRegistersPool(std::initializer_list regsToExclude) + : RegistersPool(regsToExclude, 32) { + for (auto& reg : regsToExclude) { + if (reg.isOPMASK()) { + opmaskSet.exclude(reg); + } + } + } + + int getFreeOpmask(int requestedIdx) override { + auto idx = opmaskSet.getUnused(requestedIdx); + opmaskSet.setAsUsed(idx); + return idx; + } + + void returnOpmaskToPool(int idx) override { + opmaskSet.setAsUnused(idx); + } + + size_t countUnusedOpmask() const override { + return opmaskSet.countUnused(); + } + +protected: + PhysicalSet opmaskSet {8}; +}; + +template <> +class IsaRegistersPool : public IsaRegistersPool { +public: + IsaRegistersPool(std::initializer_list regsToExclude) : IsaRegistersPool(regsToExclude) {} + IsaRegistersPool() : IsaRegistersPool() {} +}; + +template <> +class IsaRegistersPool : public IsaRegistersPool { +public: + IsaRegistersPool(std::initializer_list regsToExclude) : IsaRegistersPool(regsToExclude) {} + IsaRegistersPool() : IsaRegistersPool() {} +}; + +template +RegistersPool::Ptr RegistersPool::create(std::initializer_list regsToExclude) { + return std::make_shared>(regsToExclude); +} + +inline +RegistersPool::Ptr RegistersPool::create(cpu_isa_t isa, std::initializer_list regsToExclude) { +#define ISA_SWITCH_CASE(isa) case isa: return std::make_shared>(regsToExclude); + switch (isa) { + ISA_SWITCH_CASE(sse42) + ISA_SWITCH_CASE(avx) + ISA_SWITCH_CASE(avx2) + ISA_SWITCH_CASE(avx512_core) + ISA_SWITCH_CASE(avx512_core_vnni) + ISA_SWITCH_CASE(avx512_core_bf16) + case avx512_vpopcnt: return std::make_shared>(regsToExclude); + default: + OPENVINO_THROW("Invalid isa argument in RegistersPool::create(): ", isa); + } + OPENVINO_THROW("Invalid isa argument in RegistersPool::create()"); +#undef ISA_SWITCH_CASE +} + +} // namespace jit +} // namespace runtime +} // namespace ov diff --git a/src/core/reference/src/op/convert.cpp b/src/core/reference/src/op/convert.cpp index 5054121b5615c0..034734afd8fd2a 100644 --- a/src/core/reference/src/op/convert.cpp +++ b/src/core/reference/src/op/convert.cpp @@ -7,7 +7,7 @@ #include "openvino/reference/utils/convert_util.hpp" #ifdef OV_CORE_USE_XBYAK_JIT -# include "jit_generator.hpp" +# include "openvino/reference/utils/jit_generator.hpp" #endif #ifdef OV_CORE_USE_INTRINSICS @@ -256,7 +256,7 @@ class jit_convert_array : public jit::Generator { template static fn_t get() { - if (is_x64() && mayiuse(avx) && mayiuse(avx2) && mayiuse(fp16)) { + if (is_x64() && mayiuse(jit::avx) && mayiuse(jit::avx2) && mayiuse(jit::fp16)) { static const jit_convert_array::context_t context{{sizeof(src_t), &jit::Generator::copy}, {sizeof(dst_t), &jit::Generator::copy}, jit_convert_vec, @@ -460,7 +460,7 @@ class jit_count_out_of_range : public jit::Generator { template static fn_t get() { - if (is_x64() && mayiuse(avx2)) { + if (is_x64() && mayiuse(jit::avx2)) { static const jit_count_out_of_range::context_t context{ {sizeof(data_t), &jit::Generator::copy}, jit_count_out_of_range_vec_prepare, diff --git a/src/core/reference/src/op/utils/combine_hash.cpp b/src/core/reference/src/op/utils/combine_hash.cpp new file mode 100644 index 00000000000000..1835155becf711 --- /dev/null +++ b/src/core/reference/src/op/utils/combine_hash.cpp @@ -0,0 +1,666 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +// The CRC computation is used for x86. +// The calculations were taken from the article +// "Fast CRC Computation for Generic Polynomials Using PCLMULQDQ Instruction - Intel (December, 2009)". + +#include "openvino/core/visibility.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/reference/utils/combine_hash.hpp" + +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) +# include "openvino/reference/utils/registers_pool.hpp" +#endif // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64 + +#include + +namespace ov { +namespace runtime { + +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) +namespace jit { + +#define GET_OFF(field) offsetof(CombineHashCallArgs, field) +#define getReg64() RegistersPool::Reg(registersPool) +#define getVmm() RegistersPool::Reg(registersPool) +#define getXmm() RegistersPool::Reg(registersPool) + +struct CombineHashCompileParams { +}; + +struct CombineHashCallArgs { + const void* src_ptr; + void* dst_ptr; + uint64_t work_amount = 0lu; + uint64_t make_64_fold = 0lu; +}; + +typedef void (*fn_t)(const CombineHashCallArgs*); + +template +class CombineHash : public Generator { +public: + explicit CombineHash(const CombineHashCompileParams& jcp) : + m_jcp(jcp) { + if (isa == avx512_core) { + vlen = zmm_len; + } else if (isa == avx2) { + vlen = ymm_len; + } else { + OPENVINO_THROW("Unsupported isa: ", isa); + } + if (!mayiuse(cpu_isa_t::pclmulqdq)) { + OPENVINO_THROW("The current CPU does not support pclmulqdq instruction, which is required for the CRC algorithm."); + } + if (mayiuse(cpu_isa_t::vpclmulqdq)) { + is_vpclmulqdq = true; + } + + generate(); + } + + void generate() { + this->preamble(); + registersPool = RegistersPool::create(isa, {rax, rcx, rsp, rdi, k0}); + + r64_src = getReg64(); + r64_dst = getReg64(); + r64_work_amount = getReg64(); + r64_make_64_fold = getReg64(); + + mov(r64_src, ptr[r64_params + GET_OFF(src_ptr)]); + mov(r64_dst, ptr[r64_params + GET_OFF(dst_ptr)]); + mov(r64_work_amount, ptr[r64_params + GET_OFF(work_amount)]); + mov(r64_make_64_fold, ptr[r64_params + GET_OFF(make_64_fold)]); + + initVectors(); + bulkFold(v_dst); + restFold(v_dst); + tailFold(v_dst); + + registersPool.reset(); + this->postamble(); + } + + static fn_t get() { + static const CombineHashCompileParams params; + static CombineHash kernel(params); + + return (fn_t)kernel.getCode(); + } + + void fillRestWorkMask(const Xbyak::Opmask& k_dst_mask, + const Xbyak::Reg64& r64_work_rest) { + Xbyak::Label l_mv_mask; + auto rOnes = getReg64(); + + mov(rOnes, 0xFFFFFFFFFFFFFFFF); + cmp(r64_work_rest, 0x3f); + jg(l_mv_mask); + + shlx(rOnes, rOnes, r64_work_rest); + not_(rOnes); + + L(l_mv_mask); + kmovq(k_dst_mask, rOnes); + } + + void partialLoad(const Xbyak::Xmm& xmm_dst, + const Xbyak::Address& src_addr, + const Xbyak::Reg64& r64_load_num) { + Xbyak::Label l_partial, l_end; + + cmp(r64_load_num, xmm_len); + jl(l_partial, T_NEAR); + vmovdqu(xmm_dst, ptr[src_addr.getRegExp()]); + jmp(l_end, T_NEAR); + + L(l_partial); { + size_t offset = xmm_len; + + for (size_t j = 0lu; j < xmm_len - 1; j++) { + pinsrb(xmm_dst, ptr[src_addr.getRegExp() + offset], j); + cmp(r64_load_num, ++offset); + jle(l_end, T_NEAR); + } + } + + L(l_end); + } + + void partialLoad(const Xbyak::Ymm& ymm_dst, + const Xbyak::Address& src_addr, + const Xbyak::Reg64& r64_load_num) { + Xbyak::Label l_xmm, l_partial, l_end; + auto xmm_dst = Xbyak::Xmm(ymm_dst.getIdx()); + + cmp(r64_load_num, ymm_len); + jl(l_xmm, T_NEAR); + vmovdqu(ymm_dst, ptr[src_addr.getRegExp()]); + jmp(l_end, T_NEAR); + + L(l_xmm); + vpxorq(ymm_dst, ymm_dst, ymm_dst); + cmp(r64_load_num, xmm_len); + jl(l_partial, T_NEAR); + vmovdqu(xmm_dst, ptr[src_addr.getRegExp()]); + je(l_end, T_NEAR); + + { + Xbyak::Label l_rest_loop, l_perm; + size_t offset = xmm_len; + + vperm2f128(ymm_dst, ymm_dst, ymm_dst, 0x1); + for (size_t j = 0lu; j < xmm_len - 1; j++) { + pinsrb(xmm_dst, ptr[src_addr.getRegExp() + offset], j); + cmp(r64_load_num, ++offset); + jle(l_perm, T_NEAR); + } + L(l_perm); + vperm2f128(ymm_dst, ymm_dst, ymm_dst, 0x1); + } + jmp(l_end, T_NEAR); + + L(l_partial); { + size_t offset = xmm_len; + + for (size_t j = 0lu; j < xmm_len - 1; j++) { + pinsrb(xmm_dst, ptr[src_addr.getRegExp() + offset], j); + cmp(r64_load_num, ++offset); + jle(l_end, T_NEAR); + } + } + + L(l_end); + } + +private: + static constexpr uint64_t CHUNK_SIZE = 32; + static const uint64_t CRC_VAL; + static const uint64_t CONST_K[12]; + static const uint8_t SHUF_MASK[16]; + + using Vmm = typename std::conditional::type; + size_t vlen = xmm_len; + bool is_vpclmulqdq = false; + + CombineHashCompileParams m_jcp; + RegistersPool::Ptr registersPool; + + RegistersPool::Reg r64_src; + RegistersPool::Reg r64_dst; + RegistersPool::Reg r64_work_amount; + RegistersPool::Reg r64_make_64_fold; + + const Xbyak::Reg64 r64_params = abi_param1; + + // Vector registers + RegistersPool::Reg v_dst; + RegistersPool::Reg v_k_1_2; + RegistersPool::Reg v_k_4_5; + RegistersPool::Reg v_k_8_9; + RegistersPool::Reg v_k_16_17; + RegistersPool::Reg v_shuf_mask; + + size_t getVlen() { + return vlen; + } + + void initVectors(); + + void bulkFold(const Vmm& v_dst); + + void restFold(const Vmm& v_dst) { + Xbyak::Label l_fold_loop, l_end; + cmp(r64_work_amount, xmm_len); + jl(l_end, T_NEAR); + + auto xmm_shuf_mask = Xbyak::Xmm(v_shuf_mask.getIdx()); + auto xmm_k_1_2 = Xbyak::Xmm(v_k_1_2.getIdx()); + auto xmm_src = getXmm(); + auto xmm_dst = Xbyak::Xmm(v_dst.getIdx()); + auto xmm_aux = getXmm(); + + L(l_fold_loop); { + vmovdqu64(xmm_src, ptr[r64_src]); + vpshufb(xmm_src, xmm_src, xmm_shuf_mask); + + vpclmulqdq(xmm_aux, xmm_dst, xmm_k_1_2, 0b00000000); + vpclmulqdq(xmm_dst, xmm_dst, xmm_k_1_2, 0b00010001); + vpxorq(xmm_dst, xmm_dst, xmm_aux); + vpxorq(xmm_dst, xmm_dst, xmm_src); + + add(r64_src, xmm_len); + sub(r64_work_amount, xmm_len); + cmp(r64_work_amount, xmm_len); + jge(l_fold_loop, T_NEAR); + } + + L(l_end); + } + + void tailFold(const Vmm& v_dst); +}; + +template <> +void CombineHash::initVectors() { + auto r64_aux = getReg64(); + + v_k_1_2 = getVmm(); + mov(r64_aux, reinterpret_cast(CONST_K)); + vbroadcasti64x2(v_k_1_2, ptr[r64_aux]); + v_k_8_9 = getVmm(); + mov(r64_aux, reinterpret_cast(CONST_K + 6)); + vbroadcasti64x2(v_k_8_9, ptr[r64_aux]); + + v_shuf_mask = getVmm(); + mov(r64_aux, reinterpret_cast(SHUF_MASK)); + vbroadcasti64x2(v_shuf_mask, ptr[r64_aux]); + + v_dst = getVmm(); + auto xmm_dst = Xbyak::Xmm(v_dst.getIdx()); + auto xmm_shuf_mask = Xbyak::Xmm(v_shuf_mask.getIdx()); + auto xmm_aux = getXmm(); + auto k_rest_mask = RegistersPool::Reg(registersPool); + // Initial CRC + mov(r64_aux, CRC_VAL); + vpxorq(v_dst, v_dst, v_dst); + vpinsrq(xmm_dst, xmm_dst, r64_work_amount, 0x0); + vpinsrq(xmm_dst, xmm_dst, r64_aux, 0x1); + // First xor with source + fillRestWorkMask(k_rest_mask, r64_work_amount); + vmovdqu8(Xbyak::Xmm(xmm_aux.getIdx()) | k_rest_mask | T_z, ptr[r64_src]); + vpshufb(xmm_aux, xmm_aux, xmm_shuf_mask); + vpxorq(xmm_dst, xmm_dst, xmm_aux); + sub(r64_work_amount, xmm_len); + add(r64_src, xmm_len); +} + +template +void CombineHash::initVectors() { + auto r64_aux = getReg64(); + + v_k_1_2 = getVmm(); + mov(r64_aux, reinterpret_cast(CONST_K)); + vbroadcasti128(v_k_1_2, ptr[r64_aux]); + v_k_8_9 = getVmm(); + mov(r64_aux, reinterpret_cast(CONST_K + 6)); + vbroadcasti128(v_k_8_9, ptr[r64_aux]); + + v_shuf_mask = getVmm(); + mov(r64_aux, reinterpret_cast(SHUF_MASK)); + vbroadcasti128(v_shuf_mask, ptr[r64_aux]); + + v_dst = getVmm(); + auto xmm_dst = Xbyak::Xmm(v_dst.getIdx()); + auto xmm_shuf_mask = Xbyak::Xmm(v_shuf_mask.getIdx()); + auto xmm_aux = getXmm(); + auto k_rest_mask = RegistersPool::Reg(registersPool); + // Initial CRC + mov(r64_aux, CRC_VAL); + vpxorq(v_dst, v_dst, v_dst); + vpinsrq(xmm_dst, xmm_dst, r64_aux, 0x1); + // First xor with source + partialLoad(xmm_aux, ptr[r64_src], r64_work_amount); + vpshufb(xmm_aux, xmm_aux, xmm_shuf_mask); + vpxorq(xmm_dst, xmm_dst, xmm_aux); + sub(r64_work_amount, xmm_len); +} + +template <> +void CombineHash::bulkFold(const Vmm& v_dst) { + Xbyak::Label l_fold_loop, l_end; + cmp(r64_work_amount, zmm_len + 3 * xmm_len); + jl(l_end, T_NEAR); + + auto r64_aux = getReg64(); + + auto v_src_0 = getVmm(); + auto v_dst_0 = getVmm(); + auto v_dst_1 = getVmm(); + auto v_dst_2 = getVmm(); + auto& v_dst_3 = v_dst; + auto v_aux_0 = getVmm(); + + auto xmm_k_8_9 = Xbyak::Xmm(v_k_8_9.getIdx()); + auto xmm_k_1_2 = Xbyak::Xmm(v_k_1_2.getIdx()); + auto xmm_src_0 = Xbyak::Xmm(v_src_0.getIdx()); + auto xmm_src_1 = getXmm(); + auto xmm_dst_0 = Xbyak::Xmm(v_dst_0.getIdx()); + auto xmm_dst_1 = Xbyak::Xmm(v_dst_1.getIdx()); + auto xmm_dst_2 = Xbyak::Xmm(v_dst_2.getIdx()); + auto xmm_dst_3 = Xbyak::Xmm(v_dst_3.getIdx()); + auto xmm_aux_0 = Xbyak::Xmm(v_aux_0.getIdx()); + + vmovdqu64(v_dst_0, v_dst_3); + + if (!is_vpclmulqdq) { + prefetchnta(ptr[r64_src + 3 * xmm_len]); + vmovdqu64(xmm_dst_1, ptr[r64_src + 0 * xmm_len]); + vmovdqu64(xmm_dst_2, ptr[r64_src + 1 * xmm_len]); + vmovdqu64(xmm_dst_3, ptr[r64_src + 2 * xmm_len]); + } + + add(r64_src, 3 * xmm_len); + sub(r64_work_amount, zmm_len + 3 * xmm_len); + + L(l_fold_loop); { + vmovdqu64(v_src_0, ptr[r64_src]); + vpshufb(v_src_0, v_src_0, v_shuf_mask); + + if (is_vpclmulqdq) { + vpclmulqdq(v_aux_0, v_dst_0, v_k_8_9, 0b00000000); + vpclmulqdq(v_dst_0, v_dst_0, v_k_8_9, 0b00010001); + vpxorq(v_aux_0, v_aux_0, v_src_0); + vpxorq(v_dst_0, v_dst_0, v_aux_0); + } else { + // 0 + vpclmulqdq(xmm_aux_0, xmm_dst_0, xmm_k_8_9, 0b00000000); + vpclmulqdq(xmm_dst_0, xmm_dst_0, xmm_k_8_9, 0b00010001); + vpxorq(xmm_aux_0, xmm_aux_0, xmm_src_0); + vpxorq(xmm_dst_0, xmm_dst_0, xmm_aux_0); + // 1 + vextracti64x2(xmm_src_1, v_src_0, 0x1); + vpclmulqdq(xmm_aux_0, xmm_dst_1, xmm_k_8_9, 0b00000000); + vpclmulqdq(xmm_dst_1, xmm_dst_1, xmm_k_8_9, 0b00010001); + vpxorq(xmm_aux_0, xmm_aux_0, xmm_src_1); + vpxorq(xmm_dst_1, xmm_dst_1, xmm_aux_0); + // 2 + vextracti64x2(xmm_src_1, v_src_0, 0x2); + vpclmulqdq(xmm_aux_0, xmm_dst_2, xmm_k_8_9, 0b00000000); + vpclmulqdq(xmm_dst_2, xmm_dst_2, xmm_k_8_9, 0b00010001); + vpxorq(xmm_aux_0, xmm_aux_0, xmm_src_1); + vpxorq(xmm_dst_2, xmm_dst_2, xmm_aux_0); + // 3 + vextracti64x2(xmm_src_1, v_src_0, 0x3); + vpclmulqdq(xmm_aux_0, xmm_dst_3, xmm_k_8_9, 0b00000000); + vpclmulqdq(xmm_dst_3, xmm_dst_3, xmm_k_8_9, 0b00010001); + vpxorq(xmm_aux_0, xmm_aux_0, xmm_src_1); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); + } + + add(r64_src, zmm_len); + sub(r64_work_amount, zmm_len); + jge(l_fold_loop, T_NEAR); + } + add(r64_work_amount, zmm_len); + + if (is_vpclmulqdq) { + auto ymm_dst_0 = Xbyak::Ymm(v_dst_0.getIdx()); + auto ymm_dst_1 = Xbyak::Ymm(v_dst_1.getIdx()); + auto ymm_aux_0 = Xbyak::Ymm(v_aux_0.getIdx()); + + vextracti64x4(ymm_dst_1, v_dst_0, 0x1); + mov(r64_aux, reinterpret_cast(CONST_K + 2)); + vpclmulqdq(ymm_aux_0, ymm_dst_0, ptr[r64_aux], 0b00000000); + vpclmulqdq(ymm_dst_0, ymm_dst_0, ptr[r64_aux], 0b00010001); + vpxorq(ymm_dst_1, ymm_dst_1, ymm_aux_0); + vpxorq(ymm_dst_0, ymm_dst_0, ymm_dst_1); + + vextracti64x2(xmm_dst_3, ymm_dst_0, 0x1); + vpclmulqdq(xmm_aux_0, xmm_dst_0, xmm_k_1_2, 0b00000000); + vpclmulqdq(xmm_dst_0, xmm_dst_0, xmm_k_1_2, 0b00010001); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_0); + } else { + mov(r64_aux, reinterpret_cast(CONST_K + 4)); + vpclmulqdq(xmm_aux_0, xmm_dst_0, ptr[r64_aux], 0b00000000); + vpclmulqdq(xmm_dst_0, xmm_dst_0, ptr[r64_aux], 0b00010001); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_0); + + mov(r64_aux, reinterpret_cast(CONST_K + 2)); + vpclmulqdq(xmm_aux_0, xmm_dst_1, ptr[r64_aux], 0b00000000); + vpclmulqdq(xmm_dst_1, xmm_dst_1, ptr[r64_aux], 0b00010001); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_1); + + vpclmulqdq(xmm_aux_0, xmm_dst_2, xmm_k_1_2, 0b00000000); + vpclmulqdq(xmm_dst_2, xmm_dst_2, xmm_k_1_2, 0b00010001); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_2); + } + + L(l_end); +} + +template <> +void CombineHash::bulkFold(const Vmm& v_dst) { + Xbyak::Label l_fold_loop, l_end; + cmp(r64_work_amount, 2 * vlen - xmm_len); + jl(l_end, T_NEAR); + + auto r64_aux = getReg64(); + + auto v_src_0 = getVmm(); + auto v_dst_0 = getVmm(); + auto v_dst_1 = getVmm(); + auto v_dst_2 = getVmm(); + auto& v_dst_3 = v_dst; + auto v_aux_0 = getVmm(); + + auto xmm_k_4_5 = Xbyak::Xmm(v_k_4_5.getIdx()); + auto xmm_k_1_2 = Xbyak::Xmm(v_k_1_2.getIdx()); + auto xmm_src_0 = Xbyak::Xmm(v_src_0.getIdx()); + auto xmm_src_1 = getXmm(); + auto xmm_dst_0 = Xbyak::Xmm(v_dst_0.getIdx()); + auto xmm_dst_1 = Xbyak::Xmm(v_dst_1.getIdx()); + auto xmm_dst_2 = Xbyak::Xmm(v_dst_2.getIdx()); + auto xmm_dst_3 = Xbyak::Xmm(v_dst_3.getIdx()); + auto xmm_aux_0 = Xbyak::Xmm(v_aux_0.getIdx()); + + if (!is_vpclmulqdq) { + vmovdqu64(xmm_dst_1, ptr[r64_src + 0 * xmm_len]); + } + + add(r64_src, vlen - xmm_len); + sub(r64_work_amount, 2 * vlen - xmm_len); + + L(l_fold_loop); { + vmovdqu64(v_src_0, ptr[r64_src]); + vpshufb(v_src_0, v_src_0, v_shuf_mask); + + if (is_vpclmulqdq) { + vpclmulqdq(v_aux_0, v_dst_0, v_k_4_5, 0b00000000); + vpclmulqdq(v_dst_0, v_dst_0, v_k_4_5, 0b00010001); + vpxorq(v_aux_0, v_aux_0, v_src_0); + vpxorq(v_dst_0, v_dst_0, v_aux_0); + } else { + // 0 + vpclmulqdq(xmm_aux_0, xmm_dst_0, xmm_k_4_5, 0b00000000); + vpclmulqdq(xmm_dst_0, xmm_dst_0, xmm_k_4_5, 0b00010001); + vpxorq(xmm_aux_0, xmm_aux_0, xmm_src_0); + vpxorq(xmm_dst_0, xmm_dst_0, xmm_aux_0); + // 1 + vextracti128(xmm_src_1, v_src_0, 0x1); + vpclmulqdq(xmm_aux_0, xmm_dst_1, xmm_k_4_5, 0b00000000); + vpclmulqdq(xmm_dst_1, xmm_dst_1, xmm_k_4_5, 0b00010001); + vpxorq(xmm_aux_0, xmm_aux_0, xmm_src_1); + vpxorq(xmm_dst_1, xmm_dst_1, xmm_aux_0); + } + + add(r64_src, vlen); + sub(r64_work_amount, vlen); + jge(l_fold_loop, T_NEAR); + } + add(r64_work_amount, vlen); + + if (is_vpclmulqdq) { + auto ymm_dst_0 = Xbyak::Ymm(v_dst_0.getIdx()); + auto ymm_dst_1 = Xbyak::Ymm(v_dst_1.getIdx()); + auto ymm_aux_0 = Xbyak::Ymm(v_aux_0.getIdx()); + + vextracti128(xmm_dst_3, ymm_dst_0, 0x1); + vpclmulqdq(xmm_aux_0, xmm_dst_0, xmm_k_1_2, 0b00000000); + vpclmulqdq(xmm_dst_0, xmm_dst_0, xmm_k_1_2, 0b00010001); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_0); + } else { + vpclmulqdq(xmm_aux_0, xmm_dst_2, xmm_k_1_2, 0b00000000); + vpclmulqdq(xmm_dst_2, xmm_dst_2, xmm_k_1_2, 0b00010001); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); + vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_2); + } + + L(l_end); +} + + +template <> +void CombineHash::tailFold(const Vmm& v_dst) { + Xbyak::Label l_fold_to_64, l_save_128, l_end; + cmp(r64_work_amount, 0); + jle(l_fold_to_64, T_NEAR); + + auto r64_aux = getReg64(); + auto xmm_shuf_mask = Xbyak::Xmm(v_shuf_mask.getIdx()); + auto xmm_k_1_2 = Xbyak::Xmm(v_k_1_2.getIdx()); + auto xmm_src = getXmm(); + auto xmm_dst = Xbyak::Xmm(v_dst.getIdx()); + auto xmm_aux = getXmm(); + auto xmm_aux_1 = getXmm(); + auto xmm_aux_2 = getXmm(); + auto k_rest_mask = RegistersPool::Reg(registersPool); + + fillRestWorkMask(k_rest_mask, r64_work_amount); + + vpxorq(xmm_src, xmm_src, xmm_src); + vmovdqu8(Xbyak::Xmm(xmm_src.getIdx()) | k_rest_mask | T_z, ptr[r64_src]); + vpshufb(xmm_src, xmm_src, xmm_shuf_mask); + + vpclmulqdq(xmm_aux, xmm_dst, xmm_k_1_2, 0b00000000); + vpclmulqdq(xmm_dst, xmm_dst, xmm_k_1_2, 0b00010001); + vpxorq(xmm_aux, xmm_aux, xmm_src); + vpxorq(xmm_dst, xmm_dst, xmm_aux); + + L(l_fold_to_64); + cmp(r64_make_64_fold, 0); + je(l_save_128, T_NEAR); + + mov(r64_aux, reinterpret_cast(CONST_K + 8)); + vpclmulqdq(xmm_aux, xmm_dst, ptr[r64_aux], 0b00000001); + vpslldq(xmm_dst, xmm_dst, 0x8); + vpxorq(xmm_dst, xmm_dst, xmm_aux); + + mov(r64_aux, reinterpret_cast(CONST_K + 10)); + vmovdqu64(xmm_aux_2, ptr[r64_aux]); + vpclmulqdq(xmm_aux, xmm_dst, xmm_aux_2, 0b00000001); + mov(r64_aux, 0x0); + vpinsrq(xmm_aux_1, xmm_dst, r64_aux, 0x0); + vpxorq(xmm_aux, xmm_aux, xmm_aux_1); + vpinsrq(xmm_aux_1, xmm_aux, r64_aux, 0x0); + vpclmulqdq(xmm_aux, xmm_aux, xmm_aux_2, 0b00010001); + vpxorq(xmm_aux, xmm_aux, xmm_aux_1); + vpxorq(xmm_dst, xmm_dst, xmm_aux); + + vpextrq(ptr[r64_dst], xmm_dst, 0x0); + jmp(l_end, T_NEAR); + + + L(l_save_128); + vmovdqu64(ptr[r64_dst], xmm_dst); + + L(l_end); +} + +template <> +void CombineHash::tailFold(const Vmm& v_dst) { +} + +template +const uint64_t CombineHash::CRC_VAL = 0xffffffffffffffff; + +// P(x) = 0x42F0E1EBA9EA3693 +template +const uint64_t CombineHash::CONST_K[12] = { 0x05f5c3c7eb52fab6, 0x4eb938a7d257740e, // x^(64*1), x^(64*2) + 0x571bee0a227ef92b, 0x44bef2a201b5200c, // x^(64*3), x^(64*4) + 0x54819d8713758b2c, 0x4a6b90073eb0af5a, // x^(64*5), x^(64*6) + 0x5f6843ca540df020, 0xddf4b6981205b83f, // x^(64*7), x^(64*8) + 0x05f5c3c7eb52fab6, 0x0000000000000000, // x^(64*1), x^(64*1) mod P(x) + 0x578d29d06cc4f872, 0x42f0e1eba9ea3693 // floor(x^128/P(x)) - x^64, P(x) - x^64 + }; + +template +const uint8_t CombineHash::SHUF_MASK[] = { 0b00001111, 0b00001110, 0b00001101, 0b00001100, 0b00001011, 0b00001010, 0b00001001, 0b00001000, + 0b00000111, 0b00000110, 0b00000101, 0b00000100, 0b00000011, 0b00000010, 0b00000001, 0b00000000 }; + +} // namespace jit +#endif // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64 + +size_t combine_hash(const void* src, size_t size) { +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) + jit::fn_t kernel; + + if (jit::Generator::mayiuse(jit::avx512_core)) { + kernel = jit::CombineHash::get(); + } else if (jit::Generator::mayiuse(jit::avx2)) { + kernel = jit::CombineHash::get(); + } + + if (kernel) { + size_t res = 0lu; + + static const size_t block_size = 2lu * jit::Generator::zmm_len; + // There is no sense to perform parallel execution if there are less than 2 blocks. + if (size >= 2lu * block_size) { + const auto nthr = parallel_get_max_threads() / 2; // TODO: WA for Hyper Threading + std::vector intermediate(nthr * 2); // xmm_len * nthr + const uint64_t blocks = size / block_size; + const uint64_t el_per_thread = block_size * ((blocks + nthr - 1) / nthr); + + parallel_nt(nthr, [&](const int ithr, const int nthr) { + uint64_t start = ithr * el_per_thread; + if (start >= size) { + return; + } + uint64_t work_amount = (el_per_thread + start > size) ? size - start : el_per_thread; + + size_t res = 0lu; + jit::CombineHashCallArgs args; + + args.src_ptr = reinterpret_cast(src) + start; + args.dst_ptr = &intermediate[ithr * 2]; + args.work_amount = work_amount; + args.make_64_fold = 0lu; + kernel(&args); + }); + + + jit::CombineHashCallArgs args; + args.src_ptr = intermediate.data(); + args.dst_ptr = &res; + args.work_amount = ((size + el_per_thread - 1) / el_per_thread) * jit::Generator::xmm_len; + args.make_64_fold = 1lu; + kernel(&args); + } else { + jit::CombineHashCallArgs args; + args.src_ptr = src; + args.dst_ptr = &res; + args.work_amount = size; + args.make_64_fold = 1lu; + kernel(&args); + } + return res; + } +#endif // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64 + + constexpr auto cel_size = sizeof(size_t); + auto seed = static_cast(size); + const auto data = static_cast(src); + const auto d_end = std::next(data, size / cel_size); + // The constant value used as a magic number has been + // traditionally used e.g. in boost library's hash_combine. + // It happens to be derived from the golden ratio. + for (auto d = data; d != d_end; ++d) { + seed ^= *d + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + size_t last_bytes{0}; + std::memcpy(&last_bytes, d_end, size % cel_size); + seed ^= last_bytes + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; +} + +} // namespace runtime +} // namespace ov diff --git a/src/core/reference/src/op/jit_generator.cpp b/src/core/reference/src/op/utils/jit_generator.cpp similarity index 96% rename from src/core/reference/src/op/jit_generator.cpp rename to src/core/reference/src/op/utils/jit_generator.cpp index 7d7da06d5da8d5..174cbb9242acc4 100644 --- a/src/core/reference/src/op/jit_generator.cpp +++ b/src/core/reference/src/op/utils/jit_generator.cpp @@ -11,7 +11,7 @@ # endif # include -# include "jit_generator.hpp" +# include "openvino/reference/utils/jit_generator.hpp" # include "openvino/core/type/bfloat16.hpp" # include "openvino/core/type/float16.hpp" @@ -51,6 +51,10 @@ bool Generator::mayiuse(const cpu_isa_t cpu_isa) { return true && cpu.has(Cpu::tAVX512_VPOPCNTDQ); case fp16: return cpu.has(Cpu::tF16C); + case cpu_isa_t::pclmulqdq: + return cpu.has(Cpu::tPCLMULQDQ); + case cpu_isa_t::vpclmulqdq: + return cpu.has(Cpu::tVPCLMULQDQ); case isa_any: return true; } diff --git a/src/core/src/pass/serialize.cpp b/src/core/src/pass/serialize.cpp index 409dcad066d7a6..c36b681d9e034d 100644 --- a/src/core/src/pass/serialize.cpp +++ b/src/core/src/pass/serialize.cpp @@ -22,6 +22,7 @@ #include "openvino/opsets/opset1.hpp" #include "openvino/pass/constant_folding.hpp" #include "openvino/reference/convert.hpp" +#include "openvino/reference/utils/combine_hash.hpp" #include "openvino/runtime/aligned_buffer.hpp" #include "openvino/runtime/string_aligned_buffer.hpp" #include "openvino/util/file_util.hpp" @@ -69,23 +70,6 @@ std::string translate_type_name(const std::string& name) { return name; } -size_t hash_combine(const void* v, int64_t size) { - constexpr auto cel_size = sizeof(size_t); - auto seed = static_cast(size); - const auto data = static_cast(v); - const auto d_end = std::next(data, size / cel_size); - // The constant value used as a magic number has been - // traditionally used e.g. in boost library's hash_combine. - // It happens to be derived from the golden ratio. - for (auto d = data; d != d_end; ++d) { - seed ^= *d + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - size_t last_bytes{0}; - std::memcpy(&last_bytes, d_end, size % cel_size); - seed ^= last_bytes + 0x9e3779b9 + (seed << 6) + (seed >> 2); - return seed; -} - class ConstantWriter { public: using FilePosition = int64_t; @@ -132,7 +116,7 @@ class ConstantWriter { // the same hash for {2, 2} and {0, 128} arrays. // But even strong hashing algorithms sometimes give collisions. // Therefore we always have to compare values when finding a match in the hash multimap. - const HashValue hash = hash_combine(ptr_to_write, *new_size); + const HashValue hash = ov::runtime::combine_hash(ptr_to_write, *new_size); auto found = m_hash_to_file_positions.find(hash); // iterate over all matches of the key in the multimap while (found != m_hash_to_file_positions.end()) {