From e2e08f330c15271bd547b58ad9081ac128a0e872 Mon Sep 17 00:00:00 2001 From: Nikolai Shchegolev Date: Thu, 17 Oct 2024 11:40:33 +0400 Subject: [PATCH] 26. refactoring --- src/core/src/runtime/compute_hash.cpp | 547 +++++++++----------------- 1 file changed, 179 insertions(+), 368 deletions(-) diff --git a/src/core/src/runtime/compute_hash.cpp b/src/core/src/runtime/compute_hash.cpp index e44edb08a8eeb5..6e10d51564bfdb 100644 --- a/src/core/src/runtime/compute_hash.cpp +++ b/src/core/src/runtime/compute_hash.cpp @@ -89,17 +89,6 @@ constexpr uint64_t K_12_13_OFF = 7lu * 2lu * sizeof(uint64_t); constexpr uint64_t K_14_15_OFF = 8lu * 2lu * sizeof(uint64_t); constexpr uint64_t K_16_17_OFF = 9lu * 2lu * sizeof(uint64_t); -// static const uint64_t K_2_3[] = { K_2, 0x4eb938a7d257740e }; // x^(64*2), x^(64*3) -// static const uint64_t K_3_4[] = { 0x571bee0a227ef92b, 0x44bef2a201b5200c }; // x^(64*4), x^(64*5) -// static const uint64_t K_5_6[] = { 0x54819d8713758b2c, 0x4a6b90073eb0af5a }; // x^(64*6), x^(64*7) -// static const uint64_t K_7_8[] = { 0x5f6843ca540df020, 0xddf4b6981205b83f }; // x^(64*8), x^(64*9) -// static const uint64_t K_9_10[] = { 0x097c516e98bd2e73, 0x0b76477b31e22e7b }; // x^(64*10), x^(64*11) -// static const uint64_t K_11_12[] = { 0x9af04e1eff82d0dd, 0x6e82e609297f8fe8 }; // x^(64*12), x^(64*13) -// static const uint64_t K_13_14[] = { 0xe464f4df5fb60ac1, 0xb649c5b35a759cf2 }; // x^(64*14), x^(64*14) -// static const uint64_t K_15_16[] = { 0x05cf79dea9ac37d6, 0x001067e571d7d5c2 }; // x^(64*16), x^(64*17) -// static const uint64_t K_1_0[] = { K_2, 0x0000000000000000 }; // x^(64*1), x^(64*1) mod P(x) -// static const uint64_t K_P_P[] = { P_1, P_2 }; // floor(x^128/P(x)) - x^64, P(x) - x^64 - class HashBase : public Generator { protected: void (*ker_fn)(const ComputeHashCallArgs*); @@ -182,7 +171,6 @@ r64_tmp = getReg64(); RegistersPool::Reg r64_tmp; // Vector registers - // RegistersPool::Reg v_dst; RegistersPool::Reg v_k_2_3; RegistersPool::Reg v_shuf_mask; @@ -204,86 +192,10 @@ RegistersPool::Reg r64_tmp; void uni_vbroadcasti64x2(const Xbyak::Ymm& v_dst, const Xbyak::Address& v_src_0); - void fill_rest_work_mask(const Xbyak::Opmask& k_dst_mask, - const Xbyak::Reg64& r64_work_rest) { - Xbyak::Label l_mv_mask; - auto rOnes = getReg64(); - - mov(rOnes, 0xFFFFFFFFFFFFFFFF); - cmp(r64_work_rest, 0x3f); - jg(l_mv_mask); - - shlx(rOnes, rOnes, r64_work_rest); - not_(rOnes); - - L(l_mv_mask); - kmovq(k_dst_mask, rOnes); - } - - void partial_load(const Xbyak::Xmm& xmm_dst, - const Xbyak::Address& src_addr, - const Xbyak::Reg64& r64_load_num) { - Xbyak::Label l_partial, l_end; - - cmp(r64_load_num, xmm_len); - jl(l_partial, T_NEAR); - uni_vmovdqu64(xmm_dst, ptr[src_addr.getRegExp()]); - jmp(l_end, T_NEAR); - - L(l_partial); { - uni_vpxorq(xmm_dst, xmm_dst, xmm_dst); - for (size_t j = 0lu; j < xmm_len - 1; j++) { - cmp(r64_load_num, j); - jle(l_end, T_NEAR); - pinsrb(xmm_dst, ptr[src_addr.getRegExp() + j], j); - } - } - - L(l_end); - } + void partial_load(const Xbyak::Xmm& xmm_dst, const Xbyak::Address& src_addr, const Xbyak::Reg64& r64_load_num); - void partial_load(const Xbyak::Ymm& ymm_dst, - const Xbyak::Address& src_addr, - const Xbyak::Reg64& r64_load_num) { - Xbyak::Label l_xmm, l_partial, l_end; - auto xmm_dst = Xbyak::Xmm(ymm_dst.getIdx()); - - cmp(r64_load_num, ymm_len); - jl(l_xmm, T_NEAR); - uni_vmovdqu64(ymm_dst, ptr[src_addr.getRegExp()]); - jmp(l_end, T_NEAR); - - L(l_xmm); - uni_vpxorq(ymm_dst, ymm_dst, ymm_dst); - cmp(r64_load_num, xmm_len); - jl(l_partial, T_NEAR); - uni_vmovdqu64(xmm_dst, ptr[src_addr.getRegExp()]); - je(l_end, T_NEAR); - - { - Xbyak::Label l_rest_loop, l_perm; - - vperm2i128(ymm_dst, ymm_dst, ymm_dst, 0x1); - for (size_t j = 0lu; j < xmm_len - 1; j++) { - cmp(r64_load_num, xmm_len + j); - jle(l_perm, T_NEAR); - pinsrb(xmm_dst, ptr[src_addr.getRegExp() + xmm_len + j], j); - } - L(l_perm); - vperm2i128(ymm_dst, ymm_dst, ymm_dst, 0x1); - } - jmp(l_end, T_NEAR); - - L(l_partial); { - for (size_t j = 0lu; j < xmm_len - 1; j++) { - cmp(r64_load_num, j); - jle(l_end, T_NEAR); - pinsrb(xmm_dst, ptr[src_addr.getRegExp() + j], j); - } - } - - L(l_end); - } + void partial_load(const Xbyak::Ymm& ymm_dst, const Xbyak::Address& src_addr, const Xbyak::Reg64& r64_load_num); + }; class ComputeHash4 : public Generator { @@ -397,6 +309,97 @@ template void ComputeHash::uni_vbroadcasti64x2(const Xbyak::Ymm& v_dst, const Xbyak::Address& v_src_0) { vbroadcasti128(v_dst, v_src_0); } +template <> +void ComputeHash::partial_load(const Xbyak::Xmm& xmm_dst, + const Xbyak::Address& src_addr, + const Xbyak::Reg64& r64_load_num) { + Xbyak::Label l_mv_mask; + auto rOnes = getReg64(); + auto k_load_mask = RegistersPool::Reg(m_registers_pool); + + mov(rOnes, 0xFFFFFFFFFFFFFFFF); + cmp(r64_load_num, 0x3f); + jg(l_mv_mask); + + shlx(rOnes, rOnes, r64_load_num); + not_(rOnes); + + L(l_mv_mask); + kmovq(k_load_mask, rOnes); + + vmovdqu8(Vmm(xmm_dst.getIdx()) | k_load_mask | T_z, ptr[r64_src_ptr]); +} +template +void ComputeHash::partial_load(const Xbyak::Xmm& xmm_dst, + const Xbyak::Address& src_addr, + const Xbyak::Reg64& r64_load_num) { + Xbyak::Label l_partial, l_end; + + cmp(r64_load_num, xmm_len); + jl(l_partial, T_NEAR); + uni_vmovdqu64(xmm_dst, ptr[src_addr.getRegExp()]); + jmp(l_end, T_NEAR); + + L(l_partial); { + uni_vpxorq(xmm_dst, xmm_dst, xmm_dst); + for (size_t j = 0lu; j < xmm_len - 1; j++) { + cmp(r64_load_num, j); + jle(l_end, T_NEAR); + pinsrb(xmm_dst, ptr[src_addr.getRegExp() + j], j); + } + } + + L(l_end); +} +template <> +void ComputeHash::partial_load(const Xbyak::Ymm& xmm_dst, + const Xbyak::Address& src_addr, + const Xbyak::Reg64& r64_load_num) { + partial_load(Xbyak::Xmm(xmm_dst.getIdx()), src_addr, r64_load_num); +} +template +void ComputeHash::partial_load(const Xbyak::Ymm& ymm_dst, + const Xbyak::Address& src_addr, + const Xbyak::Reg64& r64_load_num) { + Xbyak::Label l_xmm, l_partial, l_end; + auto xmm_dst = Xbyak::Xmm(ymm_dst.getIdx()); + + cmp(r64_load_num, ymm_len); + jl(l_xmm, T_NEAR); + uni_vmovdqu64(ymm_dst, ptr[src_addr.getRegExp()]); + jmp(l_end, T_NEAR); + + L(l_xmm); + uni_vpxorq(ymm_dst, ymm_dst, ymm_dst); + cmp(r64_load_num, xmm_len); + jl(l_partial, T_NEAR); + uni_vmovdqu64(xmm_dst, ptr[src_addr.getRegExp()]); + je(l_end, T_NEAR); + + { + Xbyak::Label l_rest_loop, l_perm; + + vperm2i128(ymm_dst, ymm_dst, ymm_dst, 0x1); + for (size_t j = 0lu; j < xmm_len - 1; j++) { + cmp(r64_load_num, xmm_len + j); + jle(l_perm, T_NEAR); + pinsrb(xmm_dst, ptr[src_addr.getRegExp() + xmm_len + j], j); + } + L(l_perm); + vperm2i128(ymm_dst, ymm_dst, ymm_dst, 0x1); + } + jmp(l_end, T_NEAR); + + L(l_partial); { + for (size_t j = 0lu; j < xmm_len - 1; j++) { + cmp(r64_load_num, j); + jle(l_end, T_NEAR); + pinsrb(xmm_dst, ptr[src_addr.getRegExp() + j], j); + } + } + + L(l_end); +} template void ComputeHash::initialize(const Vmm& v_dst) { @@ -422,19 +425,12 @@ mov(r64_tmp, ptr[r64_params + GET_OFF(tmp_ptr)]); vpinsrq(xmm_aux, xmm_aux, r64_aux, 0x1); // First xor with source. - if (isa == avx512_core) { - auto k_rest_mask = RegistersPool::Reg(m_registers_pool); - fill_rest_work_mask(k_rest_mask, r64_work_amount); - vmovdqu8(Vmm(v_dst.getIdx()) | k_rest_mask | T_z, ptr[r64_src_ptr]); - } else { - partial_load(v_dst, ptr[r64_src_ptr], r64_work_amount); - } + partial_load(v_dst, ptr[r64_src_ptr], r64_work_amount); vpshufb(v_dst, v_dst, v_shuf_mask); pxor(xmm_dst, xmm_aux); // The SSE version is used to avoid zeroing out the rest of the Vmm. if (m_jcp.type == SINGLE_THREAD) { add(r64_src_ptr, xmm_len); } -// uni_vmovdqu64(ptr[r64_tmp], xmm_dst); } else if (m_jcp.type == N_THREAD) { uni_vmovdqu64(v_dst, ptr[r64_src_ptr]); vpshufb(v_dst, v_dst, v_shuf_mask); @@ -692,8 +688,6 @@ void ComputeHash::bulk_fold(const Vmm& v_dst) { } add(r64_work_amount, get_vlen()); -// uni_vmovdqu64(ptr[r64_tmp], xmm_dst_0); - if (m_jcp.type == SINGLE_THREAD) { if (is_vpclmulqdq) { vextracti128(xmm_dst_1, v_dst_0, 0x1); @@ -719,76 +713,58 @@ void ComputeHash::join(const Vmm& v_dst) { if (m_jcp.type != FINAL_FOLD) { return; } - // if (is_vpclmulqdq) { - // auto ymm_dst_0 = Xbyak::Ymm(v_dst_0.getIdx()); - // auto ymm_dst_1 = Xbyak::Ymm(v_dst_1.getIdx()); - // auto ymm_aux_0 = Xbyak::Ymm(v_aux_0.getIdx()); - - // vextracti64x4(ymm_dst_1, v_dst_0, 0x1); - // mov(r64_aux, reinterpret_cast(K_3_4)); - // vpclmulqdq(ymm_aux_0, ymm_dst_0, ptr[r64_aux], 0b00000000); - // vpclmulqdq(ymm_dst_0, ymm_dst_0, ptr[r64_aux], 0b00010001); - // uni_vpxorq(ymm_dst_1, ymm_dst_1, ymm_aux_0); - // uni_vpxorq(ymm_dst_0, ymm_dst_0, ymm_dst_1); - - // vextracti64x2(xmm_dst_3, ymm_dst_0, 0x1); - // vpclmulqdq(xmm_aux_0, xmm_dst_0, xmm_k_2_3, 0b00000000); - // vpclmulqdq(xmm_dst_0, xmm_dst_0, xmm_k_2_3, 0b00010001); - // uni_vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); - // uni_vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_0); - // } else { - mov(r64_aux, ptr[r64_params + GET_OFF(intermediate_ptr)]); - prefetcht0(ptr[r64_aux + 1024]); - - auto xmm_src_0 = getXmm(); - auto xmm_src_last = Xbyak::Xmm(v_dst.getIdx()); - auto xmm_aux_0 = getXmm(); - auto xmm_k_2_3 = Xbyak::Xmm(v_k_2_3.getIdx()); - - uni_vmovdqu64(xmm_src_last, ptr[r64_aux + xmm_len * 7]); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_14_15_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_14_15_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_12_13_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_12_13_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 2lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_10_11_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_10_11_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 3lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_8_9_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_8_9_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 4lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 5lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 6lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, xmm_k_2_3, 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, xmm_k_2_3, 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - // } + + mov(r64_aux, ptr[r64_params + GET_OFF(intermediate_ptr)]); + prefetcht0(ptr[r64_aux + 1024]); + + auto xmm_src_0 = getXmm(); + auto xmm_src_last = Xbyak::Xmm(v_dst.getIdx()); + auto xmm_aux_0 = getXmm(); + auto xmm_k_2_3 = Xbyak::Xmm(v_k_2_3.getIdx()); + + uni_vmovdqu64(xmm_src_last, ptr[r64_aux + xmm_len * 7]); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_14_15_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_14_15_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_12_13_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_12_13_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 2lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_10_11_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_10_11_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 3lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_8_9_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_8_9_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 4lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 5lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 6lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, xmm_k_2_3, 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, xmm_k_2_3, 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); } template @@ -796,52 +772,34 @@ void ComputeHash::join(const Vmm& v_dst) { if (m_jcp.type != FINAL_FOLD) { return; } - // if (is_vpclmulqdq) { - // auto ymm_dst_0 = Xbyak::Ymm(v_dst_0.getIdx()); - // auto ymm_dst_1 = Xbyak::Ymm(v_dst_1.getIdx()); - // auto ymm_aux_0 = Xbyak::Ymm(v_aux_0.getIdx()); - - // vextracti64x4(ymm_dst_1, v_dst_0, 0x1); - // mov(r64_aux, reinterpret_cast(K_3_4)); - // vpclmulqdq(ymm_aux_0, ymm_dst_0, ptr[r64_aux], 0b00000000); - // vpclmulqdq(ymm_dst_0, ymm_dst_0, ptr[r64_aux], 0b00010001); - // uni_vpxorq(ymm_dst_1, ymm_dst_1, ymm_aux_0); - // uni_vpxorq(ymm_dst_0, ymm_dst_0, ymm_dst_1); - - // vextracti64x2(xmm_dst_3, ymm_dst_0, 0x1); - // vpclmulqdq(xmm_aux_0, xmm_dst_0, xmm_k_2_3, 0b00000000); - // vpclmulqdq(xmm_dst_0, xmm_dst_0, xmm_k_2_3, 0b00010001); - // uni_vpxorq(xmm_dst_3, xmm_dst_3, xmm_aux_0); - // uni_vpxorq(xmm_dst_3, xmm_dst_3, xmm_dst_0); - // } else { - mov(r64_aux, ptr[r64_params + GET_OFF(intermediate_ptr)]); - prefetcht0(ptr[r64_aux + 1024]); - - auto xmm_src_0 = getXmm(); - auto xmm_src_last = Xbyak::Xmm(v_dst.getIdx()); - auto xmm_aux_0 = getXmm(); - auto xmm_k_2_3 = Xbyak::Xmm(v_k_2_3.getIdx()); - - uni_vmovdqu64(xmm_src_last, ptr[r64_aux + xmm_len * 3]); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 0lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 1lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - - uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 2lu]); - vpclmulqdq(xmm_aux_0, xmm_src_0, xmm_k_2_3, 0b00000000); - vpclmulqdq(xmm_src_0, xmm_src_0, xmm_k_2_3, 0b00010001); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); - uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); - // } + + mov(r64_aux, ptr[r64_params + GET_OFF(intermediate_ptr)]); + prefetcht0(ptr[r64_aux + 1024]); + + auto xmm_src_0 = getXmm(); + auto xmm_src_last = Xbyak::Xmm(v_dst.getIdx()); + auto xmm_aux_0 = getXmm(); + auto xmm_k_2_3 = Xbyak::Xmm(v_k_2_3.getIdx()); + + uni_vmovdqu64(xmm_src_last, ptr[r64_aux + xmm_len * 3]); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 0lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_6_7_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 1lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, ptr[r64_k_ptr + K_4_5_OFF], 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); + + uni_vmovdqu64(xmm_src_0, ptr[r64_aux + xmm_len * 2lu]); + vpclmulqdq(xmm_aux_0, xmm_src_0, xmm_k_2_3, 0b00000000); + vpclmulqdq(xmm_src_0, xmm_src_0, xmm_k_2_3, 0b00010001); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_aux_0); + uni_vpxorq(xmm_src_last, xmm_src_last, xmm_src_0); } template @@ -894,13 +852,7 @@ void ComputeHash::fold_to_64(const Vmm& v_dst) { auto xmm_aux_1 = getXmm(); auto xmm_aux_2 = getXmm(); - if (isa == avx512_core) { - auto k_rest_mask = RegistersPool::Reg(m_registers_pool); - fill_rest_work_mask(k_rest_mask, r64_work_amount); - vmovdqu8(Xbyak::Xmm(xmm_src.getIdx()) | k_rest_mask | T_z, ptr[r64_src_ptr]); - } else { - partial_load(xmm_src, ptr[r64_src_ptr], r64_work_amount); - } + partial_load(xmm_src, ptr[r64_src_ptr], r64_work_amount); vpshufb(xmm_src, xmm_src, xmm_shuf_mask); vpclmulqdq(xmm_aux, xmm_dst, xmm_k_2_3, 0b00000000); @@ -933,12 +885,6 @@ void ComputeHash::fold_to_64(const Vmm& v_dst) { vpextrq(ptr[r64_dst_ptr], xmm_dst, 0x0); } -// template -// const uint64_t ComputeHash::K12 = 0x7B4BC8789D65B2A5; - -// template -// const uint64_t ComputeHash::CRC_VAL = 0xffffffffffffffff; - // Auxiliary fn to obtain K constant multipliers. // uint32_t gen_k_value(int t, uint32_t poly = 0x04C11DB7) { // uint32_t gen_k_value(int t, uint32_t poly = 0xD663B05D) { @@ -1183,7 +1129,7 @@ uint64_t barrett_calc(uint64_t poly = 0x42F0E1EBA9EA3693, int bits = 64) { size_t compute_hash(const void* src, size_t size) { // if (size > 131072 * 2) - printf("combine_hash size: %lu\n", size); + // printf("combine_hash size: %lu\n", size); // static uint64_t counter = 0lu; // static uint64_t sum = 0lu; // // counter++; @@ -1375,143 +1321,10 @@ args.tmp_ptr = &(tmp_vec[0]); // std::cout << "[" << counter << "] compute_hash time: " << ms_int.count() << "; sum: " << sum << "; size: " << size << "; avg_time: " << sum / counter << " nanosec" << std::endl; // } // if (size >= 131072 * 2) - printf(" res: %lu\n", result); + // printf(" res: %lu\n", result); return result; } -// if (kernel) { -// size_t res = 0lu; - -// static const size_t block_size = 2lu * jit::Generator::zmm_len; // TODO: vlen -// // There is no sense to perform parallel execution if there are less than 2 blocks. -// // if (size >= 2lu * block_size) { -// if (size >= 20000000lu) { -// // static const auto nthr = parallel_get_max_threads() / 2; // TODO: WA for Hyper Threading -// static const auto nthr = 1lu; -// static std::vector intermediate(nthr * 2); // xmm_len * nthr -// const uint64_t blocks = size / block_size; -// const uint64_t el_per_thread = block_size * ((blocks + nthr - 1) / nthr); - -// std::vector tmp_vec(nthr * 4); -// std::vector tmp_vec_2(nthr * 4); - -// // if (!(counter == 104)) { -// // if (!(counter == 88 || counter == 92 || counter == 96 || counter == 100 || counter == 104 || counter == 108)) { -// parallel_nt(nthr, [&](const int ithr, const int nthr) { -// uint64_t start = ithr * el_per_thread; -// if (start >= size) { -// return; -// } -// uint64_t work_amount = (el_per_thread + start > size) ? size - start : el_per_thread; - -// jit::ComputeHashCallArgs args; - -// args.src_ptr = reinterpret_cast(src) + start; -// args.dst_ptr = &intermediate[ithr * 2]; -// args.work_amount = work_amount; -// args.make_64_fold = 0lu; -// args.tmp_ptr = &(tmp_vec[ithr * 4]); - -// kernel(&args); - -// // if (counter == 8) -// // printf(" [%d] start: %lu, work_amount: %lu\n", ithr, start, work_amount); -// // printf(" Parallel fold: %lu; tmp_vec {%lu; %lu; %lu; %lu}\n", -// // size, tmp_vec[ithr * 4], tmp_vec[ithr * 4 + 1], tmp_vec[ithr * 4 + 2], tmp_vec[ithr * 4 + 3]); -// }); -// // } else { -// // for (int ithr = 0; ithr < nthr; ithr++) { -// // uint64_t start = ithr * el_per_thread; -// // if (start >= size) { -// // continue; -// // } -// // uint64_t work_amount = (el_per_thread + start > size) ? size - start : el_per_thread; - -// // size_t res = 0lu; -// // jit::ComputeHashCallArgs args; - -// // args.src_ptr = reinterpret_cast(src) + start; -// // args.dst_ptr = &(intermediate[ithr * 2]); -// // args.work_amount = work_amount; -// // args.make_64_fold = 0lu; -// // args.tmp_ptr = &(tmp_vec[ithr * 2]); -// // kernel(&args); -// // } -// // } - -// // if (counter == 88 || counter == 92 || counter == 96 || counter == 100 || counter == 104 || counter == 108) { -// // std::cout << "Combine hash " << counter << " Hash: " ; -// // for (int i = 0; i < intermediate.size(); i++) { -// // std::cout << intermediate[i] << "; "; -// // } -// // std::cout << std::endl << " tmp vals: "; -// // for (int i = 0; i < tmp_vec.size(); i++) { -// // std::cout << tmp_vec[i] << "; "; -// // } -// // std::cout << std::endl; - -// // // auto data = reinterpret_cast(src);// + 131072; -// // // for (int i = 0; i < 131072; i++) { -// // // std::cout << static_cast(data[i]) << std::endl; -// // // } -// // } - -// jit::ComputeHashCallArgs args; -// args.src_ptr = intermediate.data(); -// args.dst_ptr = &res; -// args.work_amount = ((size + el_per_thread - 1) / el_per_thread) * jit::Generator::xmm_len; -// args.make_64_fold = 1lu; -// args.tmp_ptr = tmp_vec_2.data(); - -// kernel(&args); - -// // if (size == 2359296) -// // printf(" [single] work_amount: %lu\n", args.work_amount); -// // printf(" Final fold: %lu; tmp_vec {%lu; %lu; %lu; %lu}\n", size, tmp_vec_2[0], tmp_vec_2[1], tmp_vec_2[2], tmp_vec_2[3]); -// } else { -// std::vector tmp_vec(4, 0lu); -// jit::ComputeHashCallArgs args; -// args.src_ptr = src; -// args.dst_ptr = &res; -// args.work_amount = size; -// args.make_64_fold = 1lu; -// args.tmp_ptr = &(tmp_vec[0]); - -// kernel(&args); - -// // if (size > 200000lu) { -// // std::cout << "compute_hash size: " << size << "; tmp_vec: {" << tmp_vec[0] << "; " << tmp_vec[1] << "}" << std::endl; -// // if (size == 4) { -// // std::cout << " Seq size: " << size << "; src: {" << reinterpret_cast(src)[0] -// // << "} tmp_vec: {" << tmp_vec[0] << "; " << tmp_vec[1] << "; " << tmp_vec[2] << "; " << tmp_vec[3] << "}" << std::endl; -// // } -// // } -// } -// // static uint64_t counter = 0lu; -// // counter++; -// // // // if (counter < 200) { -// // if (size == 4) { -// // std::cout << "compute_hash(" << counter << ") kernel res: " << res << "; size: " << size << std::endl; -// // // if (res == 0 || size == 8) { -// // auto src_u8 = reinterpret_cast(src); -// // for (int i = 0; i < size; i++) { -// // std::cout << int(src_u8[i]) << "; "; -// // } -// // std::cout << std::endl; -// // // } -// // } - -// auto t2 = std::chrono::high_resolution_clock::now(); -// auto ms_int = std::chrono::duration_cast(t2 - t1); -// sum += ms_int.count(); -// // if (counter == 1 || counter == 8 || counter == 557 || counter == 564) -// // // if (size >= 100000 && size <= 200000) -// if (size > 200000) -// std::cout << "[" << counter << "] compute_hash time: " << ms_int.count() << "; sum: " << sum << "; size: " << size << "; avg_time: " << sum / counter << " nanosec" << std::endl; -// // std::cout << ms_int.count() << std::endl; -// printf(" res: %lu\n", res); -// return res; -// } #endif // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64 constexpr auto cel_size = sizeof(size_t); @@ -1527,9 +1340,7 @@ args.tmp_ptr = &(tmp_vec[0]); size_t last_bytes{0}; std::memcpy(&last_bytes, d_end, size % cel_size); seed ^= last_bytes + 0x9e3779b9 + (seed << 6) + (seed >> 2); -// static uint64_t counter = 0lu; -// if (counter++ < 100) -// std::cout << "compute_hash ref res: " << seed << "; size: " << size << std::endl; + return seed; }