Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IQ1_S: attempt to fix SYCL #6014

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 26 additions & 38 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4701,9 +4701,7 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restr
template<typename dst_t>
static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1,
const uint32_t *iq1s_grid,
const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) {
const uint32_t *iq1s_grid) {
const int i = item_ct1.get_group(2);
const block_iq1_s * x = (const block_iq1_s *) vx;

Expand All @@ -4712,14 +4710,14 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]);
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = d * (q[j] + delta);
}
#else
assert(false);
Expand Down Expand Up @@ -7616,27 +7614,23 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
static __dpct_inline__ float
vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs,
const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
const uint32_t *iq1s_grid) {
#if QK_K == 256
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;

const int ib32 = iqs;
const uint8_t * qs = bq1->qs + 4*ib32;
const int8_t * q8 = bq8_1[ib32].qs;
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int * q8 = (const int *)bq8_1[ib32].qs;
int sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]);
const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8));
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
grid[0] ^ signs[0], signs[0], std::minus<>());
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
grid[1] ^ signs[1], signs[1], std::minus<>());
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
q8 += 8;
}
const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
return d * sumi;
const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
int grid0 = grid[0] & 0x0f0f0f0f;
int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
sumi = dpct::dp4a(q8[2*l+1], grid1, dpct::dp4a(q8[2*l+0], grid0, sumi));
}
const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
const float d = d1q * (float)bq8_1[ib32].ds[0];
const float m = d1q * (float)bq8_1[ib32].ds[1];
return d * sumi + m * delta;
#else
assert(false);
return 0.f;
Expand Down Expand Up @@ -8456,7 +8450,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1,
const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
const uint32_t *iq1s_grid_ptr) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);

Expand Down Expand Up @@ -8484,7 +8478,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
(item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int

tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr);
}

// sum up partial sums and write back result
Expand Down Expand Up @@ -10227,16 +10221,12 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,

stream->submit([&](sycl::handler &cgh) {
auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr();
auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();

cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_s(
vx, y, item_ct1, iq1s_grid_ptr_ct1,
ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_ptr_ct1);
});
});
}
Expand Down Expand Up @@ -10967,19 +10957,17 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
iq1s_grid_gpu.init(*stream);
ksigns64.init(*stream);

stream->submit([&](sycl::handler &cgh) {
auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr();
auto ksigns64_ptr_ct1 = ksigns64.get_ptr();

cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(32)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1,
iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
iq1s_grid_ptr_ct1);
});
});
}
Expand Down
Loading