Skip to content

Commit

Permalink
Revert "15"
Browse files Browse the repository at this point in the history
This reverts commit be7ce83.
  • Loading branch information
nshchego committed Oct 3, 2024
1 parent be7ce83 commit 937bba3
Showing 1 changed file with 33 additions and 39 deletions.
72 changes: 33 additions & 39 deletions src/core/reference/src/op/utils/combine_hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,7 @@ void CombineHash<avx512_core>::bulkFold(const Vmm& v_dst) {
}
add(r64_work_amount, vlen);

vmovdqu64(ptr[r64_tmp], xmm_dst_2);
vmovdqu64(ptr[r64_tmp + xmm_len], xmm_dst_3);
vmovdqu64(ptr[r64_tmp], xmm_dst_3);

if (is_vpclmulqdq) {
auto ymm_dst_0 = Xbyak::Ymm(v_dst_0.getIdx());
Expand Down Expand Up @@ -437,8 +436,7 @@ vmovdqu64(ptr[r64_tmp + xmm_len], xmm_dst_3);
template <>
void CombineHash<avx2>::bulkFold(const Vmm& v_dst) {
Xbyak::Label l_fold_loop, l_end;
// cmp(r64_work_amount, 2 * vlen - xmm_len);
cmp(r64_work_amount, 4 * vlen - xmm_len);
cmp(r64_work_amount, 2 * vlen - xmm_len);
jl(l_end, T_NEAR);

auto r64_aux = getReg64();
Expand All @@ -461,8 +459,7 @@ void CombineHash<avx2>::bulkFold(const Vmm& v_dst) {
// auto xmm_dst_3 = Xbyak::Xmm(v_dst_3.getIdx());
auto xmm_aux_0 = Xbyak::Xmm(v_aux_0.getIdx());

// mov(r64_aux, reinterpret_cast<uintptr_t>(CONST_K + 2));
mov(r64_aux, reinterpret_cast<uintptr_t>(CONST_K + 6));
mov(r64_aux, reinterpret_cast<uintptr_t>(CONST_K + 2));
vbroadcasti128(v_k_loop, ptr[r64_aux]);

vmovdqu(v_dst_0, v_dst_1);
Expand All @@ -471,47 +468,45 @@ void CombineHash<avx2>::bulkFold(const Vmm& v_dst) {
vmovdqu(xmm_aux_0, ptr[r64_src]);
vpshufb(xmm_aux_0, xmm_aux_0, Xbyak::Xmm(v_shuf_mask.getIdx()));
vinserti128(v_dst_0, v_dst_0, xmm_aux_0, 0x1);
vmovdqu(v_dst_1, ptr[r64_src + xmm_len]);
vpshufb(v_dst_1, v_dst_1, v_shuf_mask);
add(r64_src, vlen + xmm_len);
// vmovdqu(v_dst_1, ptr[r64_src + xmm_len]);
// vpshufb(v_dst_1, v_dst_1, v_shuf_mask);
// add(r64_src, vlen + xmm_len);
} else {
vmovdqu(xmm_dst_1, ptr[r64_src]);
vpshufb(xmm_dst_1, xmm_dst_1, Xbyak::Xmm(v_shuf_mask.getIdx()));
add(r64_src, xmm_len);
}

// add(r64_src, xmm_len);
// sub(r64_work_amount, 2 * vlen - xmm_len); // Check
sub(r64_work_amount, 3 * vlen + xmm_len); // Check
add(r64_src, xmm_len);
sub(r64_work_amount, 2 * vlen - xmm_len); // Check

L(l_fold_loop); {
vmovdqu(v_src_0, ptr[r64_src]);
vpshufb(v_src_0, v_src_0, v_shuf_mask);
add(r64_src, vlen);

if (is_vpclmulqdq) {
// vpclmulqdq(v_aux_0, v_dst_0, v_k_loop, 0b00000000);
// vpclmulqdq(v_dst_0, v_dst_0, v_k_loop, 0b00010001);
// vpxor(v_aux_0, v_aux_0, v_src_0);
// vpxor(v_dst_0, v_dst_0, v_aux_0);

// 0
vmovdqu(v_src_0, ptr[r64_src]);
vpshufb(v_src_0, v_src_0, v_shuf_mask);
add(r64_src, vlen);

vpclmulqdq(v_aux_0, v_dst_0, v_k_loop, 0b00000000);
vpclmulqdq(v_dst_0, v_dst_0, v_k_loop, 0b00010001);
vpxor(v_aux_0, v_aux_0, v_src_0);
vpxor(v_dst_0, v_dst_0, v_aux_0);

// 1
vmovdqu(v_src_0, ptr[r64_src]);
vpshufb(v_src_0, v_src_0, v_shuf_mask);
add(r64_src, vlen);
// // 1
// vmovdqu(v_src_0, ptr[r64_src]);
// vpshufb(v_src_0, v_src_0, v_shuf_mask);
// add(r64_src, vlen);

vpclmulqdq(v_aux_0, v_dst_1, v_k_loop, 0b00000000);
vpclmulqdq(v_dst_1, v_dst_1, v_k_loop, 0b00010001);
vpxor(v_aux_0, v_aux_0, v_src_0);
vpxor(v_dst_1, v_dst_1, v_aux_0);
// vpclmulqdq(v_aux_0, v_dst_1, v_k_loop, 0b00000000);
// vpclmulqdq(v_dst_1, v_dst_1, v_k_loop, 0b00010001);
// vpxor(v_aux_0, v_aux_0, v_src_1);
// vpxor(v_dst_1, v_dst_1, v_aux_0);

sub(r64_work_amount, vlen * 2lu);
// sub(r64_work_amount, vlen * 2lu);
} else {
// 0
vpclmulqdq(xmm_aux_0, xmm_dst_0, xmm_k_loop, 0b00000000);
Expand All @@ -525,17 +520,16 @@ void CombineHash<avx2>::bulkFold(const Vmm& v_dst) {
vpxor(xmm_aux_0, xmm_aux_0, xmm_src_1);
vpxor(xmm_dst_1, xmm_dst_1, xmm_aux_0);

add(r64_src, vlen);
sub(r64_work_amount, vlen);
// add(r64_src, vlen);
// sub(r64_work_amount, vlen);
}

// sub(r64_work_amount, vlen);
sub(r64_work_amount, vlen);
jge(l_fold_loop, T_NEAR);
}
// add(r64_work_amount, vlen);
add(r64_work_amount, vlen * 2lu);
add(r64_work_amount, vlen);

vmovdqu(ptr[r64_tmp], v_dst_1);
vmovdqu(ptr[r64_tmp], xmm_dst_0);

if (is_vpclmulqdq) {
vextracti128(xmm_dst_1, v_dst_0, 0x1);
Expand Down Expand Up @@ -1010,8 +1004,8 @@ size_t combine_hash(const void* src, size_t size) {
const uint64_t blocks = size / block_size;
const uint64_t el_per_thread = block_size * ((blocks + nthr - 1) / nthr);

std::vector<uint64_t> tmp_vec(nthr * 4);
std::vector<uint64_t> tmp_vec_2(nthr * 4);
std::vector<uint64_t> tmp_vec(nthr * 2);
std::vector<uint64_t> tmp_vec_2(nthr * 2);

// if (!(counter == 104)) {
// if (!(counter == 88 || counter == 92 || counter == 96 || counter == 100 || counter == 104 || counter == 108)) {
Expand All @@ -1028,10 +1022,10 @@ std::vector<uint64_t> tmp_vec_2(nthr * 4);
args.dst_ptr = &intermediate[ithr * 2];
args.work_amount = work_amount;
args.make_64_fold = 0lu;
args.tmp_ptr = &(tmp_vec[ithr * 4]);
args.tmp_ptr = &(tmp_vec[ithr * 2]);
// if (counter == 8)
// printf(" [%d] start: %lu, work_amount: %lu\n", ithr, start, work_amount);
printf("size >= 200000: %lu -> parallel_nt tmp_vec {%lu; %lu; %lu; %lu}\n", size, tmp_vec[0], tmp_vec[1], tmp_vec[2], tmp_vec[3]);
printf("size >= 200000: %lu -> parallel_nt tmp_vec {%lu;%lu}\n", size, tmp_vec[0], tmp_vec[1]);
kernel(&args);
});
// } else {
Expand Down Expand Up @@ -1079,10 +1073,10 @@ printf("size >= 200000: %lu -> parallel_nt tmp_vec {%lu; %lu; %lu; %lu}\n", size
args.tmp_ptr = tmp_vec_2.data();
// if (size == 2359296)
// printf(" [single] work_amount: %lu\n", args.work_amount);
printf("size >= 200000: %lu -> fold tmp_vec {%lu; %lu; %lu; %lu}\n", size, tmp_vec_2[0], tmp_vec_2[1], tmp_vec_2[2], tmp_vec_2[3]);
printf("size >= 200000: %lu -> fold tmp_vec {%lu;%lu}\n", size, tmp_vec_2[0], tmp_vec_2[1]);
kernel(&args);
} else {
std::vector<uint64_t> tmp_vec(4);
std::vector<uint64_t> tmp_vec(2);

jit::CombineHashCallArgs args;
args.src_ptr = src;
Expand All @@ -1095,7 +1089,7 @@ if (size > 16) {
// std::cout << "combine_hash size: " << size << "; tmp_vec: {" << tmp_vec[0] << "; " << tmp_vec[1] << "}" << std::endl;
// if (size == 4) {
std::cout << "combine_hash size: " << size << "; src: {" << reinterpret_cast<const int*>(src)[0]
<< "} tmp_vec: {" << tmp_vec[0] << "; " << tmp_vec[1] << "; " << tmp_vec[2] << "; " << tmp_vec[3] << "}" << std::endl;
<< "} tmp_vec: {" << tmp_vec[0] << "; " << tmp_vec[1] << "}" << std::endl;
// }
}
}
Expand Down

0 comments on commit 937bba3

Please sign in to comment.