[SYCL] iq2_s #6052

Closed. Wants to merge 29 commits; the diff below shows the changes from 8 of the 29 commits.

Commits (all by abhilash1910):
08d3b40  iq2_s (Mar 14, 2024)
9b030b9  iq2_s (Mar 14, 2024)
81b6139  bug fix (Mar 14, 2024)
0af3ed7  bug fix (Mar 14, 2024)
87e5c86  allow iq quant (Mar 15, 2024)
f3a3ea1  Merge branch 'ggerganov:master' into iq2_s (Mar 17, 2024)
1641c52  Merge branch 'ggerganov:master' into iq2_s (Mar 18, 2024)
15617b8  format (Mar 18, 2024)
32589a6  supress assert (Mar 18, 2024)
a553def  refactor logic (Mar 19, 2024)
9fa92aa  fix build (Mar 19, 2024)
7f70fbe  Merge pull request #6 from ggerganov/iq2_s (Mar 19, 2024)
4b7aaae  Merge branch 'ggerganov:master' into iq2_s (Mar 19, 2024)
7466e4e  add quants (Mar 19, 2024)
f5fed74  add quant types from cuda (Mar 19, 2024)
36c7f02  Merge branch 'ggerganov:master' into iq2_s (Mar 25, 2024)
7ea2e15  fix format (Mar 25, 2024)
551f5a0  fix format (Mar 25, 2024)
ada101e  explicit add conditions fp32 (Mar 26, 2024)
d4b182c  refine condition (Mar 26, 2024)
e9377ba  add conditions (Mar 26, 2024)
69aaa3d  revert logic (Mar 27, 2024)
19772fa  add condition for iq2s (Mar 27, 2024)
871a135  iq2s and other quant logic add (Mar 27, 2024)
ff4ace5  disable to check perf (Mar 27, 2024)
8c07b8f  Merge branch 'ggerganov:master' into iq2_s (Mar 27, 2024)
4e6df37  enable with rebase (Mar 27, 2024)
619ce80  Update ggml-sycl.cpp (Mar 28, 2024)
935eabd  add condition (Mar 28, 2024)
ggml-sycl.cpp: 200 changes (197 additions, 3 deletions)

@@ -4839,6 +4839,36 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr

}

template<typename dst_t>
static void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1,
const uint64_t *iq2s_grid,
const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) {
const int i = item_ct1.get_group(2);
const block_iq2_s * x = (const block_iq2_s *) vx;

const int tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq2s_grid + qs[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(iq2s_grid + qs[2*il+1]);
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false);
#endif

}
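
For context, the block_iq2_s layout this kernel decodes comes from ggml-common.h; the sketch below restates that definition for reference (consult the header for the authoritative layout, ggml_half here is a stand-in typedef). Each super-block of QK_K = 256 weights carries one fp16 scale, grid-index bytes plus sign bytes in qs, extra high index bits in qh, and packed 4-bit sub-block scales:

#include <cstdint>

// Reference sketch of the iq2_s super-block (QK_K == 256), ~2.5 bits/weight.
#define QK_K 256
typedef uint16_t ggml_half;   // stand-in for ggml's fp16 type in this sketch
typedef struct {
    ggml_half d;              // fp16 super-block scale
    uint8_t qs[QK_K/4];       // 32 grid-index bytes, then 32 sign bytes
    uint8_t qh[QK_K/32];      // 2 extra high bits per group of 8 values
    uint8_t scales[QK_K/32];  // two 4-bit sub-block scales per byte
} block_iq2_s;

Each decoded value is d * 0.25f * (0.5f + scale_nibble) * grid_byte with a per-element sign flip; note that vec_dot_iq2_s_q8_1 below reads the sign bytes from the second half of qs (qs + QK_K/8).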


/*
DPCT1110:4: The total declared local variable size in device function
dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -7755,6 +7785,57 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
#endif
}

static __dpct_inline__ float
vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs,
const uint64_t *iq2s_grid, const uint64_t *ksigns64) {
#if QK_K == 256
const block_iq2_s * bq2 = (const block_iq2_s *) vbq;

const int ib32 = iqs;
const int8_t * q8 = bq8_1[ib32].qs;
const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
const uint8_t ls1 = bq2->scales[ib32] & 0xf;
const uint8_t ls2 = bq2->scales[ib32] >> 4;
int sumi1 = 0;
for (int l = 0; l < 2; ++l) {
const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
grid[0] ^ signs0, signs0, std::minus<>());
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
grid[1] ^ signs1, signs1, std::minus<>());
sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
q8 += 8;
}
int sumi2 = 0;
for (int l = 2; l < 4; ++l) {
const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
grid[0] ^ signs0, signs0, std::minus<>());
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
grid[1] ^ signs1, signs1, std::minus<>());
sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
q8 += 8;
}
const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#else
(void) ksigns64;
assert(false);
return 0.f;
#endif
}
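
The two dpct::vectorized_binary calls per grid word implement a branch-free, per-byte conditional negation. The scalar model below is illustrative only (not part of the PR) and shows the same trick in portable C++ so it can be sanity-checked on the host:

#include <cstdint>
#include <cstdio>

// Expand a 4-bit sign mask into 0xFF / 0x00 per byte: mirrors
// ((signs[l] & 0xf) * 0x01010101) & 0x08040201 compared byte-wise
// against 0x08040201 with std::equal_to.
static uint32_t expand_signs(uint8_t nibble) {
    const uint32_t m = (uint32_t(nibble) * 0x01010101u) & 0x08040201u;
    uint32_t out = 0;
    for (int b = 0; b < 4; ++b) {
        const uint8_t lhs = (m >> (8 * b)) & 0xff;
        const uint8_t rhs = (0x08040201u >> (8 * b)) & 0xff;
        out |= uint32_t(lhs == rhs ? 0xffu : 0x00u) << (8 * b);
    }
    return out;
}

// Per-byte (g ^ s) - s: s == 0xFF negates the byte (two's complement),
// s == 0x00 leaves it unchanged; this is the xor + std::minus pair.
static uint32_t apply_signs(uint32_t grid, uint32_t signs) {
    uint32_t out = 0;
    for (int b = 0; b < 4; ++b) {
        const int8_t g = int8_t((grid  >> (8 * b)) & 0xff);
        const int8_t s = int8_t((signs >> (8 * b)) & 0xff);
        out |= uint32_t(uint8_t((g ^ s) - s)) << (8 * b);
    }
    return out;
}

int main() {
    const uint32_t grid = 0x04030201;      // bytes 1, 2, 3, 4
    const uint32_t s = expand_signs(0x5);  // bits 0, 2 set: negate bytes 0, 2
    std::printf("%08x\n", apply_signs(grid, s));  // 04fd02ff = 4, -3, 2, -1
    return 0;
}

On the device the signed bytes then feed dpct::dp4a against the q8_1 quants, and the 4-bit scales enter at the end as (0.5f + ls) * sumi, matching the dequantization formula.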

template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
vec_dot_q_mul_mat_sycl_t vec_dot>
@@ -8611,6 +8692,53 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
}
}


template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq2_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1,
const uint64_t *iq2s_grid_ptr, const uint64_t *ksigns64_ptr ) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);

if (row >= nrows) {
return;
}

const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;

// partial sum for each thread
float tmp = 0.0f;

const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;

for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index

const int iby = i * (qk/QK8_1); // y block index that aligns with ibx

const int iqs =
vdr *
(item_ct1.get_local_id(2) %
(qi / vdr)); // x block quant index when casting the quants to int

tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs, iq2s_grid_ptr, ksigns64_ptr);
}

// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}

if (item_ct1.get_local_id(2) == 0) {
dst[row] = tmp;
}
}
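
The closing loop is the classic xor-butterfly sub-group reduction: after log2(WARP_SIZE) permute-and-add steps every lane holds the full dot product for the row, so lane 0 can store it. A host-side model of the pattern, illustrative only and assuming a 32-lane sub-group:

#include <array>
#include <cstdio>

int main() {
    std::array<float, 32> lane{};
    for (int i = 0; i < 32; ++i) lane[i] = float(i);  // one partial sum per lane

    // Each step models tmp += dpct::permute_sub_group_by_xor(sg, tmp, mask):
    // lane l adds the value currently held by lane (l ^ mask).
    for (int mask = 16; mask > 0; mask >>= 1) {
        std::array<float, 32> next = lane;
        for (int l = 0; l < 32; ++l) next[l] = lane[l] + lane[l ^ mask];
        lane = next;
    }
    std::printf("%g\n", lane[0]);  // 0 + 1 + ... + 31 = 496, in every lane
    return 0;
}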

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) {
@@ -10354,6 +10482,36 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
}
}

template <typename dst_t>
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
{
iq2s_grid.init(*stream);
ksigns_iq2xs.init(*stream);
kmask_iq2xs.init(*stream);

dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});

stream->submit([&](sycl::handler &cgh) {
auto iq2s_grid_ptr_ct1 = iq2s_grid.get_ptr();
auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();

cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_s(
vx, y, item_ct1, iq2s_grid_ptr_ct1,
ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
});
});
}
}
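
The launch pairs one 32-work-item group with each super-block: with il = tid/8 and ib = tid%8, every work-item writes 8 outputs, so 32 work-items cover 32 * 8 = 256 = QK_K values. A quick host-side check of that index map, illustrative only:

#include <cstdio>

int main() {
    bool covered[256] = {};
    for (int tid = 0; tid < 32; ++tid) {
        const int il = tid / 8, ib = tid % 8;  // as in dequantize_block_iq2_s
        for (int j = 0; j < 8; ++j) covered[32 * ib + 8 * il + j] = true;
    }
    int n = 0;
    for (bool c : covered) n += c;
    std::printf("%d\n", n);  // 256: each output written exactly once
    return 0;
}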


template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int k,
@@ -10408,6 +10566,8 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
default:
@@ -10452,6 +10612,8 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
return dequantize_row_iq3_s_sycl;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
default:
@@ -11097,6 +11259,35 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
}
}

static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
iq2s_grid.init(*stream);
ksigns64.init(*stream);

stream->submit([&](sycl::handler &cgh) {
auto iq2s_grid_ptr_ct1 = iq2s_grid.get_ptr();
auto ksigns64_ptr_ct1 = ksigns64.get_ptr();

cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(32)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1,
iq2s_grid_ptr_ct1, ksigns64_ptr_ct1);
});
});
}
}
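
The launch shape assigns one 32-lane sub-group per output row: block_dims packs GGML_SYCL_MMV_Y rows into each work-group, block_num_y work-groups cover nrows, and the reqd_sub_group_size(32) attribute keeps the xor-shuffle reduction in the kernel valid. A small model of the arithmetic, with placeholder values that are illustrative rather than taken from the PR:

#include <cstdio>

int main() {
    const int WARP_SIZE = 32, GGML_SYCL_MMV_Y = 2, nrows = 4096;
    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
    std::printf("work-groups=%d, items per group=%d\n",
                block_num_y, GGML_SYCL_MMV_Y * WARP_SIZE);  // 2048, 64
    return 0;
}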


static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols_x,
const int nrows_x, const int ncols_y,
@@ -13870,6 +14061,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= VER_GEN9 ? 128 : 64;
case GGML_TYPE_IQ3_S:
@@ -13940,6 +14132,9 @@ inline void ggml_sycl_op_mul_mat_vec_q(
case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
break;
default:
GGML_ASSERT(false);
break;
@@ -15435,7 +15630,7 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
#ifdef GGML_SYCL_FORCE_DMMV
const bool use_mul_mat_vec_q = false;
#else
-    const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
+    const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
Contributor:

Suggested change:
-    const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
+    const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;

This is a breaking change for both Intel and Nvidia targets. ggml_sycl_op_mul_mat_vec_q asserts that GGML_ASSERT(ggml_nrows(src1) == 1);
abhilash1910 (Collaborator, Author):

Yes, actually the assertion should not be mandatory, and the ggml_nrows == 1 condition won't allow the iq stages to run.

AidanBeltonS (Contributor), Mar 18, 2024:

I've tested with the assert suppressed. It fails on both Intel and Nvidia targets; in both cases the failing function lacks a case for the new quantization type.

Intel (Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1100 1.3 [1.3.28454]):
There is a failure in ggml_sycl_op_dequantize_mul_mat_vec of
MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1]): GGML_ASSERT: /home/aidanbelton/source/temp/llama.cpp/ggml-sycl.cpp:14217: false

Nvidia (NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.2]):
There is a failure in ggml_sycl_op_mul_mat_q of MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1]): GGML_ASSERT: /home/aidanbelton/source/temp/llama.cpp/ggml-sycl.cpp:14019: false. The issue is a missing case for the quantization type.

abhilash1910 (Collaborator, Author), Mar 19, 2024:

Actually, for both cases it should go to neither of those methods. It is strange that it falls back to the dequantize_mul_mat_vec or op_mul_mat_q path; for this type only the vectorized mul_mat_vec_q (MMVQ) path should be called.
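
One way to reconcile the two comments above, sketched purely for illustration (not part of this PR; the helper name is_iq_quant is hypothetical): keep the ggml_nrows(src1) == 1 guard for MMVQ and explicitly keep IQ types out of the batched paths that have no kernels for them, so dispatch cannot fall through to ggml_sycl_op_mul_mat_q or ggml_sycl_op_dequantize_mul_mat_vec:

// Hypothetical helper: IQ-family types that only the MMVQ path handles here.
static bool is_iq_quant(ggml_type t) {
    return t == GGML_TYPE_IQ1_S   || t == GGML_TYPE_IQ2_XXS ||
           t == GGML_TYPE_IQ2_XS  || t == GGML_TYPE_IQ2_S   ||
           t == GGML_TYPE_IQ3_XXS || t == GGML_TYPE_IQ3_S;
}

// Guarded dispatch sketch using the surrounding variables of ggml_sycl_mul_mat:
// MMVQ stays restricted to single-row batches, and IQ types never reach the
// paths whose switch statements would hit GGML_ASSERT(false).
const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC &&
                               ggml_is_quantized(src0->type) &&
                               ggml_nrows(src1) == 1;
const bool use_mul_mat_q = min_compute_capability >= VER_4VEC &&
                           ggml_is_quantized(src0->type) &&
                           !is_iq_quant(src0->type);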

#endif // GGML_SYCL_FORCE_DMMV

if (use_mul_mat_vec_q) {
@@ -17287,8 +17482,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
return false;
}
ggml_type a_type = a->type;
-        if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ2_S ||
-            a_type == GGML_TYPE_IQ4_XS) {
+        if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS) {
return false;
}
return true;