[SYCL] Fix DMMV dequantization (ggerganov#9279)
Fixed DMMV dequantization for ncols == GGML_SYCL_DMMV_X
OuadiElfarouki authored and arthw committed Nov 15, 2024
1 parent d21f426 commit c6e616b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ggml/src/ggml-sycl/dmmv.cpp
@@ -78,8 +78,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
}

// sum up partial sums and write back result
- #pragma unroll
- for (int mask = warp_size / 2; mask > 0; mask >>= 1) {
+ const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
+ for (int mask = mask_start; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}
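
The effect of the new starting mask can be illustrated with a small host-side model of the XOR butterfly reduction. This is only a sketch, not the kernel code: `kWarpSize` and `lane0_after_reduce` are hypothetical names, the sub-group width of 32 is an assumption, and the array stands in for the per-lane `tmp` values that `dpct::permute_sub_group_by_xor` exchanges on the device.

```cpp
#include <array>

// Assumed sub-group width for this illustration; the real value of WARP_SIZE
// depends on how the SYCL backend is built.
constexpr int kWarpSize = 32;

// Hypothetical host-side model of the reduction loop in the diff above.
// At each step every lane adds the value held by lane (lane ^ mask),
// which is what an XOR permute across the sub-group provides.
// Returns the value lane 0 ends up with.
inline float lane0_after_reduce(std::array<float, kWarpSize> lanes, int mask_start) {
    for (int mask = mask_start; mask > 0; mask >>= 1) {
        std::array<float, kWarpSize> next = lanes;
        for (int lane = 0; lane < kWarpSize; ++lane) {
            next[lane] = lanes[lane] + lanes[lane ^ mask];
        }
        lanes = next;
    }
    return lanes[0];
}
```

In this model, starting at `kWarpSize >> 1` folds all 32 lanes into lane 0, while starting at `kWarpSize >> 2` folds only lanes 0 to 15, so whatever the upper half of the sub-group holds never reaches the result. That appears to be the intent of the patch for the `ncols == GGML_SYCL_DMMV_X` case, although the commit message does not spell out the reasoning.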