diff --git a/ggml-impl.h b/ggml-impl.h
index 19df66bceee4a0..7a36833f5e4d4c 100644
--- a/ggml-impl.h
+++ b/ggml-impl.h
@@ -1,3 +1,4 @@
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
 #pragma once

 #include "ggml.h"
@@ -207,6 +208,10 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {

 #endif // __ARM_NEON

+#ifdef __ARM_FEATURE_SVE
+#include <arm_sve.h>
+#endif // __ARM_FEATURE_SVE
+
 // precomputed f32 table for f16 (256 KB)
 // defined in ggml.c, initialized in ggml_init()
 extern float ggml_table_f32_f16[1 << 16];
diff --git a/ggml-quants.c b/ggml-quants.c
index 6336538f0e99ec..7c3a66a0133b24 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -1,3 +1,4 @@
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
 #include "ggml-quants.h"

 #include "ggml-impl.h"
@@ -10961,3 +10962,926 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest
     quantize_iq4_nl(x, y, 1, k, NULL, NULL);
 }

+// Routines to create the blocked formats
+// Note: the input is an array of pointers.
+// The exact interleaving format needed is different for GEMM (using SMMLA)
+// and GEMV (using SDOT) cases. For GEMM, we interleave 8 pairs of values
+// at a time (with the two nibbles separated at runtime to give 2x2x8
+// matrices). For GEMV, we need to interleave 4 pairs of values instead.
+block_q4_0x4 make_block_q4_0x4(const block_q4_0 * const in[4], unsigned int block_len) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i]->d;
+    }
+
+    for (int i = 0; i < QK4_0 * 2; i++) {
+        // We are interleaving 4 rows in blocks of 8, making a total of 32
+        // output bytes per block (2 MMLA input vectors). This repeats
+        // until we have processed the whole block.
+        //
+        // Per the comment above, for GEMV cases a similar process is used
+        // but with blocks of 4 instead, giving a single DOT input vector.
+        //
+        // In the case of q4, we add on 128 to convert the top nibble from
+        // "bias offset" form to pure sign form (this saves a subtract when
+        // we unpack it).
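+        //
+        // For example, with block_len == 8 (the GEMM layout) the index
+        // arithmetic below walks the sources as:
+        //   i =  0..7   -> row 0, bytes 0..7
+        //   i =  8..15  -> row 1, bytes 0..7
+        //   i = 16..23  -> row 2, bytes 0..7
+        //   i = 24..31  -> row 3, bytes 0..7
+        //   i = 32..39  -> row 0, bytes 8..15, and so on.
+        //
+        // Only the shifted-out top nibble benefits from the +128 trick; the
+        // low nibble is still corrected by subtracting 8 when the kernels
+        // unpack it.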
+ int src_offset = (i / (4 * block_len)) * block_len; + int src_id = (i % (4 * block_len)) / block_len; + src_offset += (i % block_len); + + out.qs[i] = in[src_id]->qs[src_offset] + 0x80; + } + + return out; +} + +// 8-block version - see comments in code above +block_q4_0x8 make_block_q4_0x8(const block_q4_0 * const in[8], unsigned int block_len) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i]->d; + } + + for (int i = 0; i < QK4_0 * 4; i++) { + int src_offset = (i / (8 * block_len)) * block_len; + int src_id = (i % (8 * block_len)) / block_len; + src_offset += (i % block_len); + + out.qs[i] = in[src_id]->qs[src_offset] + 0x80; + } + + return out; +} + +block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int block_len) { + block_q8_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i]->d; + } + + for (int i = 0; i < QK8_0 * 4; i++) { + int src_offset = (i / (4 * block_len)) * block_len; + int src_id = (i % (4 * block_len)) / block_len; + src_offset += (i % block_len); + + out.qs[i] = in[src_id]->qs[src_offset]; + } + + return out; +} + +// 8-block version - see comments in code above +block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len) { + block_q8_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i]->d; + } + + for (int i = 0; i < QK8_0 * 8; i++) { + int src_offset = (i / (8 * block_len)) * block_len; + int src_id = (i % (8 * block_len)) / block_len; + src_offset += (i % block_len); + + out.qs[i] = in[src_id]->qs[src_offset]; + } + + return out; +} + +void quantize_row_q8_0_and_make_block_q8_0x2(const float * restrict x, void * restrict vy, int k, int rows_interleaved) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x2 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv[rows_interleaved][8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + float id[rows_interleaved]; + + for (int row_iter = 0; row_iter < rows_interleaved; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 
1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#endif +} + +void quantize_row_q8_0_and_make_block_q8_0x4(const float * restrict x, void * restrict vy, int k, int rows_interleaved) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv[rows_interleaved][8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + float id[rows_interleaved]; + + for (int row_iter = 0; row_iter < rows_interleaved; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 
1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); + } + } +#endif +} + +inline int64_t roundup(const int64_t a, const int64_t b) { + int64_t rem = a % b; + + if (rem) { + return a + b - rem; + } else { + return a; + } +} + +void ggml_gemv_q4_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_NEON) + int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8); + + int64_t nb = n / QK4_0; + int64_t a_nb = n / QK8_0; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const block_q4_0x8 * b_ptr_start = vx; + const block_q8_0 * a_ptr_start = vy; + + for (int64_t y = 0; y < input_width; y++) { + for (int64_t x = x0 / 8; x < xend / 8; x++) { + // Pointers to LHS blocks + const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb); + // Pointers to RHS blocks + const block_q4_0x8 * b_ptr = b_ptr_start + (x * nb); + // Master FP accumulator + float32x4_t acc_row[2]; + acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f); + + for (int64_t b = 0; b < nb; b++) { + // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers) + const uint8x16_t rhs_raw_vec_0_0 = vld1q_u8(b_ptr[b].qs); + const uint8x16_t rhs_raw_vec_1_0 = vld1q_u8(b_ptr[b].qs + 16); 
+ const uint8x16_t rhs_raw_vec_0_1 = vld1q_u8(b_ptr[b].qs + 32); + const uint8x16_t rhs_raw_vec_1_1 = vld1q_u8(b_ptr[b].qs + 48); + const uint8x16_t rhs_raw_vec_0_2 = vld1q_u8(b_ptr[b].qs + 64); + const uint8x16_t rhs_raw_vec_1_2 = vld1q_u8(b_ptr[b].qs + 80); + const uint8x16_t rhs_raw_vec_0_3 = vld1q_u8(b_ptr[b].qs + 96); + const uint8x16_t rhs_raw_vec_1_3 = vld1q_u8(b_ptr[b].qs + 112); + + const int8x16_t rhs_vec_0_0_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_0, m4b)), s8b); + const int8x16_t rhs_vec_0_1_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_1, m4b)), s8b); + const int8x16_t rhs_vec_0_2_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_2, m4b)), s8b); + const int8x16_t rhs_vec_0_3_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_0_3, m4b)), s8b); + const int8x16_t rhs_vec_1_0_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_0, m4b)), s8b); + const int8x16_t rhs_vec_1_1_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_1, m4b)), s8b); + const int8x16_t rhs_vec_1_2_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_2, m4b)), s8b); + const int8x16_t rhs_vec_1_3_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_vec_1_3, m4b)), s8b); + + const int8x16_t rhs_vec_0_0_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_0), 4); + const int8x16_t rhs_vec_0_1_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_1), 4); + const int8x16_t rhs_vec_0_2_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_2), 4); + const int8x16_t rhs_vec_0_3_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_0_3), 4); + const int8x16_t rhs_vec_1_0_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_0), 4); + const int8x16_t rhs_vec_1_1_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_1), 4); + const int8x16_t rhs_vec_1_2_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_2), 4); + const int8x16_t rhs_vec_1_3_1 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_vec_1_3), 4); + + // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32 + const float16x8_t col_scale_f16 = vld1q_f16(b_ptr[b].d); + const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16)); + const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16)); + + const float16x4_t row_scale_f16 = vld1_dup_f16(&(a_ptr[b].d)); + const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16); + + const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs); + const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16); + + int32x4_t iacc0 = vdupq_n_s32(0); + int32x4_t iacc1 = vdupq_n_s32(0); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0); + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1); + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2); + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3); + iacc1 = vdotq_laneq_s32(iacc1, 
rhs_vec_1_3_1, lhs_vec_1, 3); + + acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32)); + acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32)); + } + + vst1q_f32(s + (y * output_channels + x * 8), acc_row[0]); + vst1q_f32(s + (y * output_channels + x * 8 + 4), acc_row[1]); + } + } +#endif +} + +void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_FEATURE_SVE) + int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8); + + int64_t nb = n / QK4_0; + int64_t a_nb = n / QK8_0; + + const svuint8_t m4b = svdup_u8(0x0F); + const svint8_t s8b = svdup_s8(0x8); + + const svbool_t ptrue = svptrue_b8(); + + const block_q4_0x8 * b_ptr_start = vx; + const block_q8_0 * a_ptr_start = vy; + + for (int64_t y = 0; y < input_width; y++) { + for (int64_t x = x0 / 8; x < xend / 8; x++) { + // Pointers to LHS blocks + const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb); + // Pointers to RHS blocks + const block_q4_0x8 * b_ptr = b_ptr_start + (x * nb); + + // Master FP accumulator + svfloat32_t acc_row = svdup_f32(0.0f); + + for (int64_t b = 0; b < nb; b++) { + // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers) + const svuint8_t rhs_raw_vec_0_0 = svld1_u8(ptrue, b_ptr[b].qs); + const svuint8_t rhs_raw_vec_0_1 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 1); + const svuint8_t rhs_raw_vec_0_2 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 2); + const svuint8_t rhs_raw_vec_0_3 = svld1_vnum_u8(ptrue, b_ptr[b].qs, 3); + + const svint8_t rhs_vec_0_0_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_0), 4); + const svint8_t rhs_vec_0_1_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_1), 4); + const svint8_t rhs_vec_0_2_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_2), 4); + const svint8_t rhs_vec_0_3_1 = svasr_n_s8_x(ptrue, svreinterpret_s8_u8(rhs_raw_vec_0_3), 4); + + const svint8_t rhs_vec_0_0_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_0, m4b)), s8b); + const svint8_t rhs_vec_0_1_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_1, m4b)), s8b); + const svint8_t rhs_vec_0_2_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_2, m4b)), s8b); + const svint8_t rhs_vec_0_3_0 = svsub_s8_x(ptrue, svreinterpret_s8_u8(svand_u8_x(ptrue, rhs_raw_vec_0_3, m4b)), s8b); + + // Scale values + const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d)); + const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16); + + const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d); + const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16); + + const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs); + const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16); + + svint32_t iacc = svdup_s32(0); + + iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0); + iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0); + + iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1); + iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1); + + iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2); + iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2); + + iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3); + iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3); + 
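+                // Each iacc lane now holds the int32 dot product of one output
+                // column with this 32-element block of the input row. Convert it
+                // to FP32 and scale by the product of the per-block scales, i.e.
+                // acc[c] += (float) iacc[c] * d_col[c] * d_row, before moving on
+                // to the next block.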
+ acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32)); + } + + svst1(ptrue, s + (y * output_channels + x * 8), acc_row); + } + } +#endif +} + +void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4); + + int64_t nb = n / QK4_0; + int64_t a_nb = n / QK8_0; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const block_q4_0x4 * b_ptr_start = vx; + const block_q8_0x4 * a_ptr_start = vy; + + for (int64_t y = 0; y < input_width / 4; y += rows / 4) { + for (int64_t x = x0 / 4; x < xend / 4; x++) { + const block_q8_0x4 * a_ptrs[rows / 4]; + + a_ptrs[0] = a_ptr_start + (y * a_nb); + for (int i = 0; i < (rows / 4) - 1; i++) { + a_ptrs[i + 1] = a_ptrs[i] + a_nb; + } + + const block_q4_0x4 * b_ptr = b_ptr_start + (x * nb); + + // Master FP accumulators + float32x4_t acc_rows[rows]; + for (int i = 0; i < rows; i++) { + acc_rows[i] = vdupq_n_f32(0.0f); + } + + for (int64_t b = 0; b < nb; b++) { + // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers) + const uint8x16_t rhs_raw_mat_01_0 = vld1q_u8(b_ptr[b].qs); + const uint8x16_t rhs_raw_mat_23_0 = vld1q_u8(b_ptr[b].qs + 16); + const uint8x16_t rhs_raw_mat_01_1 = vld1q_u8(b_ptr[b].qs + 32); + const uint8x16_t rhs_raw_mat_23_1 = vld1q_u8(b_ptr[b].qs + 48); + + // 4-bit -> 8-bit + const int8x16_t rhs_mat_01_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_0, m4b)), s8b); + const int8x16_t rhs_mat_23_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_0, m4b)), s8b); + const int8x16_t rhs_mat_01_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_1, m4b)), s8b); + const int8x16_t rhs_mat_23_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_1, m4b)), s8b); + const int8x16_t rhs_mat_01_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_0), 4); + const int8x16_t rhs_mat_23_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_0), 4); + const int8x16_t rhs_mat_01_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_1), 4); + const int8x16_t rhs_mat_23_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_1), 4); + + // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32 + const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d); + const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16); + + // Process LHS in pairs of rows + for (int rp = 0; rp < rows / 4; rp++) { + const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs); + const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16); + const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32); + const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48); + + const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64); + const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80); + const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96); + const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112); + + // Do the MMLAs into 2x2 matrices + const int32x4_t iacc_mat_00 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3); + const int32x4_t iacc_mat_01 = + 
vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3); + const int32x4_t iacc_mat_10 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3); + const int32x4_t iacc_mat_11 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3); + + // Straighten out to make 4 row vectors + const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); + const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); + + const float16x4_t row_scale_f16 = vld1_f16(a_ptrs[rp][b].d); + const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16); + + acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0)); + acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1)); + acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2)); + acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3)); + } + } + + for (int i = 0; i < rows; i++) { + vst1q_f32(s + ((y * 4 + i) * output_channels + x * 4), acc_rows[i]); + } + } + } +#endif +} + +void ggml_gemm_q4_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + int rows = 2; + int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4); + + int64_t nb = n / QK4_0; + int64_t a_nb = n / QK8_0; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const block_q4_0x4 * b_ptr_start = vx; + const block_q8_0x2 * a_ptr_start = vy; + + for (int64_t y = 0; y < input_width / 2; y += rows / 2) { + for (int64_t x = x0 / 4; x < xend / 4; x++) { + const block_q8_0x2 * a_ptrs[rows / 2]; + + a_ptrs[0] = a_ptr_start + (y * a_nb); + + const block_q4_0x4 * b_ptr = b_ptr_start + (x * nb); + + // Master FP accumulators + float32x4_t acc_rows[rows]; + acc_rows[0] = vdupq_n_f32(0.0f); + acc_rows[1] = vdupq_n_f32(0.0f); + + for (int64_t b = 0; b < nb; b++) { + // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers) + const uint8x16_t rhs_raw_mat_01_0 = vld1q_u8(b_ptr[b].qs); + const uint8x16_t rhs_raw_mat_23_0 = vld1q_u8(b_ptr[b].qs + 16); + const uint8x16_t rhs_raw_mat_01_1 = vld1q_u8(b_ptr[b].qs + 32); + const uint8x16_t rhs_raw_mat_23_1 = vld1q_u8(b_ptr[b].qs + 48); + + const int8x16_t rhs_mat_01_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_0, m4b)), s8b); + const int8x16_t rhs_mat_23_0 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_0, m4b)), 
s8b); + const int8x16_t rhs_mat_01_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_01_1, m4b)), s8b); + const int8x16_t rhs_mat_23_1 = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(rhs_raw_mat_23_1, m4b)), s8b); + + const int8x16_t rhs_mat_01_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_0), 4); + const int8x16_t rhs_mat_23_2 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_0), 4); + const int8x16_t rhs_mat_01_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_01_1), 4); + const int8x16_t rhs_mat_23_3 = vshrq_n_s8(vreinterpretq_s8_u8(rhs_raw_mat_23_1), 4); + + // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32 + const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d); + const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16); + + // Process LHS in pairs of rows + int rp = 0; + const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs); + const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 16); + + const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 32); + const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 48); + + // Do the MMLAs into 2x2 matrices + const int32x4_t iacc_mat_00 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3); + const int32x4_t iacc_mat_01 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3); + + // Straighten out to make 2 row vectors + const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + + const float16x4_t row_scale_f16_0 = vld1_dup_f16(&(a_ptrs[rp][b].d[0])); + const float32x4_t row_scale_f32_0 = vcvt_f32_f16(row_scale_f16_0); + const float16x4_t row_scale_f16_1 = vld1_dup_f16(&(a_ptrs[rp][b].d[1])); + const float32x4_t row_scale_f32_1 = vcvt_f32_f16(row_scale_f16_1); + + acc_rows[rp * 2] = vfmaq_f32(acc_rows[rp * 2], vcvtq_f32_s32(iacc_row_0), vmulq_f32(col_scale_f32, row_scale_f32_0)); + acc_rows[rp * 2 + 1] = vfmaq_f32(acc_rows[rp * 2 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_f32(col_scale_f32, row_scale_f32_1)); + } + + vst1q_f32(s + ((y * 2) * output_channels + x * 4), acc_rows[0]); + vst1q_f32(s + ((y * 2 + 1) * output_channels + x * 4), acc_rows[1]); + } + } +#endif +} + +void ggml_gemv_q8_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_NEON) + int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8); + + int64_t nb = n / QK8_0; + int64_t a_nb = n / QK8_0; + + const block_q8_0x8 * b_ptr_start = vx; + const block_q8_0 * a_ptr_start = vy; + + for (int64_t y = 0; y < input_width; y++) { + for (int64_t x = x0 / 8; x < xend / 8; x++) { + // Pointers to LHS blocks + const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb); + // Pointers to RHS blocks + const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb); + // Master FP accumulator + float32x4_t acc_row[2]; + acc_row[0] = acc_row[1] = vdupq_n_f32(0.0f); + + for (int64_t b = 0; b < nb; b++) { + // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers) + const int8x16_t 
rhs_vec_0_0_0 = vld1q_s8(b_ptr[b].qs); + const int8x16_t rhs_vec_1_0_0 = vld1q_s8(b_ptr[b].qs + 16); + const int8x16_t rhs_vec_0_1_0 = vld1q_s8(b_ptr[b].qs + 32); + const int8x16_t rhs_vec_1_1_0 = vld1q_s8(b_ptr[b].qs + 48); + const int8x16_t rhs_vec_0_2_0 = vld1q_s8(b_ptr[b].qs + 64); + const int8x16_t rhs_vec_1_2_0 = vld1q_s8(b_ptr[b].qs + 80); + const int8x16_t rhs_vec_0_3_0 = vld1q_s8(b_ptr[b].qs + 96); + const int8x16_t rhs_vec_1_3_0 = vld1q_s8(b_ptr[b].qs + 112); + const int8x16_t rhs_vec_0_0_1 = vld1q_s8(b_ptr[b].qs + 128); + const int8x16_t rhs_vec_1_0_1 = vld1q_s8(b_ptr[b].qs + 144); + const int8x16_t rhs_vec_0_1_1 = vld1q_s8(b_ptr[b].qs + 160); + const int8x16_t rhs_vec_1_1_1 = vld1q_s8(b_ptr[b].qs + 176); + const int8x16_t rhs_vec_0_2_1 = vld1q_s8(b_ptr[b].qs + 192); + const int8x16_t rhs_vec_1_2_1 = vld1q_s8(b_ptr[b].qs + 208); + const int8x16_t rhs_vec_0_3_1 = vld1q_s8(b_ptr[b].qs + 224); + const int8x16_t rhs_vec_1_3_1 = vld1q_s8(b_ptr[b].qs + 240); + + // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32 + const float16x8_t col_scale_f16 = vld1q_f16(b_ptr[b].d); + const float32x4_t col_scale_f32_0 = vcvt_f32_f16(vget_low_f16(col_scale_f16)); + const float32x4_t col_scale_f32_1 = vcvt_f32_f16(vget_high_f16(col_scale_f16)); + + const float16x4_t row_scale_f16 = vld1_dup_f16(&(a_ptr[b].d)); + const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16); + + const int8x16_t lhs_vec_0 = vld1q_s8(a_ptr[b].qs); + const int8x16_t lhs_vec_1 = vld1q_s8(a_ptr[b].qs + 16); + + int32x4_t iacc0 = vdupq_n_s32(0); + int32x4_t iacc1 = vdupq_n_s32(0); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_0, lhs_vec_0, 0); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_0_1, lhs_vec_1, 0); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_0, lhs_vec_0, 0); + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_0_1, lhs_vec_1, 0); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_0, lhs_vec_0, 1); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_1_1, lhs_vec_1, 1); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_0, lhs_vec_0, 1); + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_1_1, lhs_vec_1, 1); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_0, lhs_vec_0, 2); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_2_1, lhs_vec_1, 2); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_0, lhs_vec_0, 2); + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_2_1, lhs_vec_1, 2); + + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_0, lhs_vec_0, 3); + iacc0 = vdotq_laneq_s32(iacc0, rhs_vec_0_3_1, lhs_vec_1, 3); + + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_0, lhs_vec_0, 3); + iacc1 = vdotq_laneq_s32(iacc1, rhs_vec_1_3_1, lhs_vec_1, 3); + + acc_row[0] = vfmaq_f32(acc_row[0], vcvtq_f32_s32(iacc0), vmulq_f32(col_scale_f32_0, row_scale_f32)); + acc_row[1] = vfmaq_f32(acc_row[1], vcvtq_f32_s32(iacc1), vmulq_f32(col_scale_f32_1, row_scale_f32)); + } + + vst1q_f32(s + (y * output_channels + x * 8), acc_row[0]); + vst1q_f32(s + (y * output_channels + x * 8 + 4), acc_row[1]); + } + } +#endif +} + +void ggml_gemv_q8_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_FEATURE_SVE) + int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8); + + int64_t nb = n / QK8_0; + int64_t a_nb = n / QK8_0; + + const svbool_t ptrue = svptrue_b8(); + + const block_q8_0x8 * b_ptr_start = vx; + const block_q8_0 * a_ptr_start = vy; + + for 
(int64_t y = 0; y < input_width; y++) { + for (int64_t x = x0 / 8; x < xend / 8; x++) { + // Pointers to LHS blocks + const block_q8_0 * a_ptr = a_ptr_start + (y * a_nb); + // Pointers to RHS blocks + const block_q8_0x8 * b_ptr = b_ptr_start + (x * nb); + + // Master FP accumulator + svfloat32_t acc_row = svdup_f32(0.0f); + + for (int64_t b = 0; b < nb; b++) { + // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers) + const svint8_t rhs_vec_0_0_0 = svld1_s8(ptrue, b_ptr[b].qs); + const svint8_t rhs_vec_0_1_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 1); + const svint8_t rhs_vec_0_2_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 2); + const svint8_t rhs_vec_0_3_0 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 3); + const svint8_t rhs_vec_0_0_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 4); + const svint8_t rhs_vec_0_1_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 5); + const svint8_t rhs_vec_0_2_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 6); + const svint8_t rhs_vec_0_3_1 = svld1_vnum_s8(ptrue, b_ptr[b].qs, 7); + + // Scale values + const svfloat16_t col_scale_f16 = svreinterpret_f16_u32(svld1uh_u32(ptrue, (const uint16_t *) b_ptr[b].d)); + const svfloat32_t col_scale_f32 = svcvt_f32_f16_x(ptrue, col_scale_f16); + + const svfloat16_t row_scale_f16 = svdup_f16(a_ptr[b].d); + const svfloat32_t row_scale_f32 = svcvt_f32_f16_x(ptrue, row_scale_f16); + + const svint8_t lhs_vec_0 = svld1rq_s8(ptrue, a_ptr[b].qs); + const svint8_t lhs_vec_1 = svld1rq_s8(ptrue, a_ptr[b].qs + 16); + + svint32_t iacc = svdup_s32(0); + + iacc = svdot_lane(iacc, rhs_vec_0_0_0, lhs_vec_0, 0); + iacc = svdot_lane(iacc, rhs_vec_0_0_1, lhs_vec_1, 0); + + iacc = svdot_lane(iacc, rhs_vec_0_1_0, lhs_vec_0, 1); + iacc = svdot_lane(iacc, rhs_vec_0_1_1, lhs_vec_1, 1); + + iacc = svdot_lane(iacc, rhs_vec_0_2_0, lhs_vec_0, 2); + iacc = svdot_lane(iacc, rhs_vec_0_2_1, lhs_vec_1, 2); + + iacc = svdot_lane(iacc, rhs_vec_0_3_0, lhs_vec_0, 3); + iacc = svdot_lane(iacc, rhs_vec_0_3_1, lhs_vec_1, 3); + + acc_row = svmla_x(ptrue, acc_row, svcvt_f32_s32_x(ptrue, iacc), svmul_x(ptrue, col_scale_f32, row_scale_f32)); + } + + svst1(ptrue, s + (y * output_channels + x * 8), acc_row); + } + } +#endif +} + +void ggml_gemm_q8_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)4); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4); + + int64_t nb = n / QK8_0; + int64_t a_nb = n / QK8_0; + + const block_q8_0x4 * b_ptr_start = vx; + const block_q8_0x4 * a_ptr_start = vy; + + for (int64_t y = 0; y < input_width / 4; y += rows / 4) { + for (int64_t x = x0 / 4; x < xend / 4; x++) { + const block_q8_0x4 * a_ptrs[rows / 4]; + + a_ptrs[0] = a_ptr_start + (y * a_nb); + for (int i = 0; i < (rows / 4) - 1; i++) { + a_ptrs[i + 1] = a_ptrs[i] + a_nb; + } + + const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb); + + // Master FP accumulators + float32x4_t acc_rows[rows]; + for (int i = 0; i < rows; i++) { + acc_rows[i] = vdupq_n_f32(0.0f); + } + + for (int64_t b = 0; b < nb; b++) { + // Set up RHS - we need rhs_mat_* and col_scale_f32 (9 registers) + const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs); + const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16); + const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32); + const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48); + const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64); + const int8x16_t 
rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80); + const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96); + const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112); + + // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32 + const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d); + const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16); + + // Process LHS in pairs of rows + for (int rp = 0; rp < rows / 4; rp++) { + const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs); + const int8x16_t lhs_mat_23_0 = vld1q_s8(a_ptrs[rp][b].qs + 16); + const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 32); + const int8x16_t lhs_mat_23_1 = vld1q_s8(a_ptrs[rp][b].qs + 48); + + const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 64); + const int8x16_t lhs_mat_23_2 = vld1q_s8(a_ptrs[rp][b].qs + 80); + const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 96); + const int8x16_t lhs_mat_23_3 = vld1q_s8(a_ptrs[rp][b].qs + 112); + + // Do the MMLAs into 2x2 matrices + const int32x4_t iacc_mat_00 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3); + const int32x4_t iacc_mat_01 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3); + const int32x4_t iacc_mat_10 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_01_0), lhs_mat_23_1, rhs_mat_01_1), lhs_mat_23_2, rhs_mat_01_2), lhs_mat_23_3, rhs_mat_01_3); + const int32x4_t iacc_mat_11 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_23_0, rhs_mat_23_0), lhs_mat_23_1, rhs_mat_23_1), lhs_mat_23_2, rhs_mat_23_2), lhs_mat_23_3, rhs_mat_23_3); + + // Straighten out to make 4 row vectors + const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + const int32x4_t iacc_row_2 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); + const int32x4_t iacc_row_3 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_10), vreinterpretq_u64_s32(iacc_mat_11))); + + const float16x4_t row_scale_f16 = vld1_f16(a_ptrs[rp][b].d); + const float32x4_t row_scale_f32 = vcvt_f32_f16(row_scale_f16); + + acc_rows[rp * 4] = vfmaq_f32(acc_rows[rp * 4], vcvtq_f32_s32(iacc_row_0), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 0)); + acc_rows[rp * 4 + 1] = vfmaq_f32(acc_rows[rp * 4 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 1)); + acc_rows[rp * 4 + 2] = vfmaq_f32(acc_rows[rp * 4 + 2], vcvtq_f32_s32(iacc_row_2), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 2)); + acc_rows[rp * 4 + 3] = vfmaq_f32(acc_rows[rp * 4 + 3], vcvtq_f32_s32(iacc_row_3), vmulq_laneq_f32(col_scale_f32, row_scale_f32, 3)); + } + } + + for (int i = 0; i < rows; i++) { + vst1q_f32(s + ((y * 4 + i) * output_channels + x * 4), acc_rows[i]); + } + } + } +#endif +} + +void ggml_gemm_q8_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { +#if defined(__ARM_FEATURE_MATMUL_INT8) + int rows = 2; + int64_t x0 = roundup((ith * 
output_channels) / nth, (int64_t)4); + int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)4); + + int64_t nb = n / QK8_0; + int64_t a_nb = n / QK8_0; + + const block_q8_0x4 * b_ptr_start = vx; + const block_q8_0x2 * a_ptr_start = vy; + + for (int64_t y = 0; y < input_width / 2; y += rows / 2) { + for (int64_t x = x0 / 4; x < xend / 4; x++) { + const block_q8_0x2 * a_ptrs[rows / 2]; + + a_ptrs[0] = a_ptr_start + (y * a_nb); + + const block_q8_0x4 * b_ptr = b_ptr_start + (x * nb); + + // Master FP accumulators + float32x4_t acc_rows[rows]; + acc_rows[0] = vdupq_n_f32(0.0f); + acc_rows[1] = vdupq_n_f32(0.0f); + + for (int64_t b = 0; b < nb; b++) { + const int8x16_t rhs_mat_01_0 = vld1q_s8(b_ptr[b].qs); + const int8x16_t rhs_mat_23_0 = vld1q_s8(b_ptr[b].qs + 16); + const int8x16_t rhs_mat_01_1 = vld1q_s8(b_ptr[b].qs + 32); + const int8x16_t rhs_mat_23_1 = vld1q_s8(b_ptr[b].qs + 48); + const int8x16_t rhs_mat_01_2 = vld1q_s8(b_ptr[b].qs + 64); + const int8x16_t rhs_mat_23_2 = vld1q_s8(b_ptr[b].qs + 80); + const int8x16_t rhs_mat_01_3 = vld1q_s8(b_ptr[b].qs + 96); + const int8x16_t rhs_mat_23_3 = vld1q_s8(b_ptr[b].qs + 112); + + // Scale values - assemble the four row/column scales into a (64-bit) vector, then expand to FP32 + const float16x4_t col_scale_f16 = vld1_f16(b_ptr[b].d); + const float32x4_t col_scale_f32 = vcvt_f32_f16(col_scale_f16); + + // Process LHS in pairs of rows + int rp = 0; + const int8x16_t lhs_mat_01_0 = vld1q_s8(a_ptrs[rp][b].qs); + const int8x16_t lhs_mat_01_1 = vld1q_s8(a_ptrs[rp][b].qs + 16); + + const int8x16_t lhs_mat_01_2 = vld1q_s8(a_ptrs[rp][b].qs + 32); + const int8x16_t lhs_mat_01_3 = vld1q_s8(a_ptrs[rp][b].qs + 48); + + // Do the MMLAs into 2x2 matrices + const int32x4_t iacc_mat_00 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_01_0), lhs_mat_01_1, rhs_mat_01_1), lhs_mat_01_2, rhs_mat_01_2), lhs_mat_01_3, rhs_mat_01_3); + const int32x4_t iacc_mat_01 = + vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vdupq_n_s32(0), lhs_mat_01_0, rhs_mat_23_0), lhs_mat_01_1, rhs_mat_23_1), lhs_mat_01_2, rhs_mat_23_2), lhs_mat_01_3, rhs_mat_23_3); + + // Straighten out to make 2 row vectors + const int32x4_t iacc_row_0 = vreinterpretq_s32_u64(vtrn1q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + const int32x4_t iacc_row_1 = vreinterpretq_s32_u64(vtrn2q_u64(vreinterpretq_u64_s32(iacc_mat_00), vreinterpretq_u64_s32(iacc_mat_01))); + + const float16x4_t row_scale_f16_0 = vld1_dup_f16(&(a_ptrs[rp][b].d[0])); + const float32x4_t row_scale_f32_0 = vcvt_f32_f16(row_scale_f16_0); + const float16x4_t row_scale_f16_1 = vld1_dup_f16(&(a_ptrs[rp][b].d[1])); + const float32x4_t row_scale_f32_1 = vcvt_f32_f16(row_scale_f16_1); + + acc_rows[rp * 2] = vfmaq_f32(acc_rows[rp * 2], vcvtq_f32_s32(iacc_row_0), vmulq_f32(col_scale_f32, row_scale_f32_0)); + acc_rows[rp * 2 + 1] = vfmaq_f32(acc_rows[rp * 2 + 1], vcvtq_f32_s32(iacc_row_1), vmulq_f32(col_scale_f32, row_scale_f32_1)); + } + vst1q_f32(s + ((y * 2) * output_channels + x * 4), acc_rows[0]); + vst1q_f32(s + ((y * 2 + 1) * output_channels + x * 4), acc_rows[1]); + } + } +#endif +} diff --git a/ggml-quants.h b/ggml-quants.h index 113623b62938a4..056a557bd9bb45 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -1,3 +1,4 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. 
#pragma once #include "ggml-impl.h" @@ -54,6 +55,48 @@ typedef struct { } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); +typedef struct { + ggml_fp16_t d[4]; // deltas for 4 q4_0 blocks + uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks +} block_q4_0x4; +static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_fp16_t) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); + +typedef struct { + ggml_fp16_t d[8]; // deltas for 8 q4_0 blocks + uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks +} block_q4_0x8; +static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_fp16_t) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); + +typedef struct { + ggml_fp16_t d[16]; // deltas for 16 q4_0 blocks + uint8_t qs[QK4_0 * 8]; // nibbles / quants for 16 q4_0 blocks +} block_q4_0x16; +static_assert(sizeof(block_q4_0x16) == 16 * sizeof(ggml_fp16_t) + QK4_0 * 8, "wrong q4_0x16 block size/padding"); + +typedef struct { + ggml_fp16_t d[64]; // deltas for 64 q4_0 blocks + uint8_t qs[QK4_0 * 32];// nibbles / quants for 64 q4_0 blocks +} block_q4_0x64; +static_assert(sizeof(block_q4_0x64) == 64 * sizeof(ggml_fp16_t) + QK4_0 * 32, "wrong q4_0x64 block size/padding"); + +typedef struct { + ggml_fp16_t d[2]; // deltas for 2 q8_0 blocks + int8_t qs[QK8_0 * 2]; // quants for 2 q8_0 blocks +} block_q8_0x2; +static_assert(sizeof(block_q8_0x2) == 2 * sizeof(ggml_fp16_t) + QK8_0 * 2, "wrong q8_0x2 block size/padding"); + +typedef struct { + ggml_fp16_t d[4]; // deltas for 4 q8_0 blocks + int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks +} block_q8_0x4; +static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_fp16_t) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); + +typedef struct { + ggml_fp16_t d[8]; // deltas for 8 q8_0 blocks + int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks +} block_q8_0x8; +static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_fp16_t) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); + // // Super-block quantization structures // @@ -304,6 +347,25 @@ void iq2xs_free_impl(enum ggml_type type); void iq3xs_init_impl(int grid_size); void iq3xs_free_impl(int grid_size); +block_q4_0x4 make_block_q4_0x4(const block_q4_0 * const in[4], unsigned int block_len); +block_q4_0x8 make_block_q4_0x8(const block_q4_0 * const in[8], unsigned int block_len); +block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int block_len); +block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len); +void quantize_row_q8_0_and_make_block_q8_0x2(const float * restrict x, void * restrict vy, int k, int rows_interleaved); +void quantize_row_q8_0_and_make_block_q8_0x4(const float * restrict x, void * restrict vy, int k, int rows_interleaved); + +// GEMV +void ggml_gemv_q4_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); +void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); +void ggml_gemv_q8_0_q8_0_blocked8_neon(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); +void ggml_gemv_q8_0_q8_0_blocked8_sve(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); + +// GEMM +void ggml_gemm_q4_0_q8_0(const int n, 
int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); +void ggml_gemm_q4_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); +void ggml_gemm_q8_0_q8_0(const int n, int rows, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); +void ggml_gemm_q8_0_q8_0_2x4blocked_mmla(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/ggml.c b/ggml.c index 5b9fa741a64799..83121e2403827a 100644 --- a/ggml.c +++ b/ggml.c @@ -1,3 +1,4 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows #define _USE_MATH_DEFINES // For M_PI on MSVC @@ -417,6 +418,192 @@ int64_t ggml_cycles_per_ms(void) { #define ggml_perf_cycles_per_ms() 0 #endif +void rearrange_q4_0_weights_blocked8_neon(struct ggml_tensor * cur) { + block_q4_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data; + block_q4_0x8 * out_ptr_B_start = out_ptr_B; + int64_t nb = cur->ne[0] / QK4_0; + + for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) { + const block_q4_0 * in_ptrs[8]; + + in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 8 * nb); + for (int i = 0; i < 7; i++) { + in_ptrs[i + 1] = in_ptrs[i] + nb; + } + + for (int64_t x = 0; x < nb; x++) { + *out_ptr_B = make_block_q4_0x8(in_ptrs, 4); // block_len=4 for SDOT + out_ptr_B++; + + for (int i = 0; i < 8; i++) { + in_ptrs[i]++; + } + } + } + cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start; +} + +void rearrange_q4_0_weights_blocked8_sve(struct ggml_tensor * cur) { +#if defined(__ARM_FEATURE_SVE) + if (svcntw() != 8) { + printf("ggml_gemv_q4_0_q8_0_blocked8_sve: SVE VL != 256 - aborting. 
Use Arm Neon GEMV kernels\n"); + exit(1); + } + + block_q4_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data; + block_q4_0x8 * out_ptr_B_start = out_ptr_B; + int64_t nb = cur->ne[0] / QK4_0; + + for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) { + const block_q4_0 * in_ptrs[8]; + + in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 8 * nb); + for (int i = 0; i < 7; i++) { + in_ptrs[i + 1] = in_ptrs[i] + nb; + } + + for (int64_t x = 0; x < nb; x++) { + *out_ptr_B = make_block_q4_0x8(in_ptrs, 4); // block_len=4 for SDOT + out_ptr_B++; + + for (int i = 0; i < 8; i++) { + in_ptrs[i]++; + } + } + } + cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start; +#endif +} + +#if defined(__ARM_FEATURE_SVE) +static void (*_rearrange_q4_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q4_0_weights_blocked8_sve; +#elif defined(__ARM_NEON) +static void (*_rearrange_q4_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q4_0_weights_blocked8_neon; +#endif + +#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) +void rearrange_q4_0_weights_for_gemv(struct ggml_tensor * cur) { _rearrange_q4_0_weights_for_gemv(cur); } +#endif + +void rearrange_q4_0_weights_for_gemm(struct ggml_tensor * cur) { + block_q4_0x4 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data; + block_q4_0x4 * out_ptr_B_start = out_ptr_B; + int64_t nb = cur->ne[0] / QK4_0; + + for (int y_out = 0; y_out < cur->ne[1] / 4; y_out++) { + const block_q4_0 * in_ptrs[4]; + + in_ptrs[0] = (block_q4_0 *) cur->data + (y_out * 4 * nb); + for (int i = 0; i < 3; i++) { + in_ptrs[i + 1] = in_ptrs[i] + nb; + } + + for (int64_t x = 0; x < nb; x++) { + *out_ptr_B = + make_block_q4_0x4(in_ptrs, 8); // block_len=8 for SMMLA + out_ptr_B++; + + for (int i = 0; i < 4; i++) { + in_ptrs[i]++; + } + } + } + cur->rearranged_weight_gemm = (uint8_t *) out_ptr_B_start; +} + +void rearrange_q8_0_weights_blocked8_neon(struct ggml_tensor * cur) { + block_q8_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data; + block_q8_0x8 * out_ptr_B_start = out_ptr_B; + int64_t nb = cur->ne[0] / QK8_0; + + for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) { + const block_q8_0 * in_ptrs[8]; + + in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 8 * nb); + for (int i = 0; i < 7; i++) { + in_ptrs[i + 1] = in_ptrs[i] + nb; + } + + for (int64_t x = 0; x < nb; x++) { + *out_ptr_B = make_block_q8_0x8(in_ptrs, 4); // block_len=4 for SDOT + out_ptr_B++; + + for (int i = 0; i < 8; i++) { + in_ptrs[i]++; + } + } + } + cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start; +} + +void rearrange_q8_0_weights_blocked8_sve(struct ggml_tensor * cur) { +#if defined(__ARM_FEATURE_SVE) + if (svcntw() != 8) { + printf("ggml_gemv_q8_0_q8_0_blocked8_sve: SVE VL != 256 - aborting. 
Use Arm Neon GEMV kernels\n"); + exit(1); + } + + block_q8_0x8 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data; + block_q8_0x8 * out_ptr_B_start = out_ptr_B; + int64_t nb = cur->ne[0] / QK8_0; + + for (int y_out = 0; y_out < cur->ne[1] / 8; y_out++) { + const block_q8_0 * in_ptrs[8]; + + in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 8 * nb); + for (int i = 0; i < 7; i++) { + in_ptrs[i + 1] = in_ptrs[i] + nb; + } + + for (int64_t x = 0; x < nb; x++) { + *out_ptr_B = make_block_q8_0x8(in_ptrs, 4); // block_len=4 for SDOT + out_ptr_B++; + + for (int i = 0; i < 8; i++) { + in_ptrs[i]++; + } + } + } + cur->rearranged_weight_gemv = (uint8_t *) out_ptr_B_start; +#endif +} + +#if defined(__ARM_FEATURE_SVE) +static void (*_rearrange_q8_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q8_0_weights_blocked8_sve; +#elif defined(__ARM_NEON) +static void (*_rearrange_q8_0_weights_for_gemv)(struct ggml_tensor *) = &rearrange_q8_0_weights_blocked8_neon; +#endif + +#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) +void rearrange_q8_0_weights_for_gemv(struct ggml_tensor * cur) { _rearrange_q8_0_weights_for_gemv(cur); } +#endif + +void rearrange_q8_0_weights_for_gemm(struct ggml_tensor * cur) { + block_q8_0x4 * out_ptr_B = malloc(ggml_nbytes(cur)); // B_blocked->data; + block_q8_0x4 * out_ptr_B_start = out_ptr_B; + int64_t nb = cur->ne[0] / QK8_0; + + for (int y_out = 0; y_out < cur->ne[1] / 4; y_out++) { + const block_q8_0 * in_ptrs[4]; + + in_ptrs[0] = (block_q8_0 *) cur->data + (y_out * 4 * nb); + for (int i = 0; i < 3; i++) { + in_ptrs[i + 1] = in_ptrs[i] + nb; + } + + for (int64_t x = 0; x < nb; x++) { + *out_ptr_B = + make_block_q8_0x4(in_ptrs, 8); // block_len=8 for SMMLA + out_ptr_B++; + + for (int i = 0; i < 4; i++) { + in_ptrs[i]++; + } + } + } + cur->rearranged_weight_gemm = (uint8_t *) out_ptr_B_start; +} + // // cache line // @@ -1708,6 +1895,10 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { *s = idx; } +static void ggml_gemv_q4_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); + +static void ggml_gemv_q8_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth); + // // data types // @@ -2734,6 +2925,9 @@ static struct ggml_tensor * ggml_new_tensor_impl( /*.name =*/ { 0 }, /*.extra =*/ NULL, /*.padding =*/ { 0 }, + /*.rearranged_weight_gemv =*/ NULL, + /*.rearranged_weight_gemm =*/ NULL, + /*.weight_rearranged =*/ false, }; // TODO: this should not be needed as long as we don't rely on aligned SIMD loads @@ -10397,14 +10591,32 @@ static void ggml_compute_forward_mul_mat( assert(params->wsize >= ne11*ne12*ne13*row_size); GGML_ASSERT(src1->type == GGML_TYPE_F32); - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; +#if defined(__ARM_FEATURE_MATMUL_INT8) + if ((src0->weight_rearranged == true) && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) { + for (int64_t i11 = 0; i11 < ne11 / 4; ++i11) { + quantize_row_q8_0_and_make_block_q8_0x4((float *)((char *) src1->data + i11 * 4 * nb11), (void *) wdata, ne10, 4); + wdata += row_size * 4; + } + for (int64_t i11 = (ne11 / 4) * 4; i11 < ne11; ++i11) { + from_float_to_vec_dot((float *)((char *) 
src1->data + i11 * nb11), (void *) wdata, ne10); + wdata += row_size; + } + } +#endif +#if defined(__ARM_FEATURE_MATMUL_INT8) + else { +#endif + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } } } +#if defined(__ARM_FEATURE_MATMUL_INT8) } +#endif } return; @@ -10468,47 +10680,158 @@ static void ggml_compute_forward_mul_mat( // 16 * 2, accounting for mmla kernels float tmp[32]; - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) { - const int64_t i13 = (ir1/(ne12*ne1)); - const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; - const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); +#if defined(__ARM_FEATURE_MATMUL_INT8) && (defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE)) + if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (src0->weight_rearranged == true)) { + if (src0->type == GGML_TYPE_Q4_0) { + ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data), (const char *) src0->rearranged_weight_gemv, (const char *) wdata, ith, nth); // use Arm Neon/SVE GEMV kernels + } else if (src0->type == GGML_TYPE_Q8_0) { + ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data), (const char *) src0->rearranged_weight_gemv, (const char *) wdata, ith, nth); // use Arm Neon/SVE GEMV kernels + } + } + else if ((ggml_n_dims(src0) == 2) && (ne11 >= 16) && (src0->weight_rearranged == true)) { + // use batch-sized 16, 8, and 4 GEMM kernels + if (src0->type == GGML_TYPE_Q4_0) { + for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) { + ggml_gemm_q4_0_q8_0(ne00, 16, ne01, 16, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), ith, nth); + } + int rows_processed = (ne11 / 16) * 16; + for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) { + ggml_gemm_q4_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->rearranged_weight_gemm, + (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), ith, nth); + } + rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8; + for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) { + ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm, + (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth); + } + rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4; + for (int row_iter = rows_processed; row_iter < ne11; row_iter++) { + ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? 
(row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } else if (src0->type == GGML_TYPE_Q8_0) { + for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) { + ggml_gemm_q8_0_q8_0(ne00, 16, ne01, 16, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), ith, nth); + } + int rows_processed = (ne11 / 16) * 16; + for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) { + ggml_gemm_q8_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->rearranged_weight_gemm, + (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), ith, nth); + } + rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8; + for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) { + ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm, + (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth); + } + rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4; + for (int row_iter = rows_processed; row_iter < ne11; row_iter++) { + ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } + } else if ((ggml_n_dims(src0) == 2) && (ne11 >= 8) && (src0->weight_rearranged == true)) { + // use batch-sized 8, and 4 GEMM kernels + if (src0->type == GGML_TYPE_Q4_0) { + for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) { + ggml_gemm_q4_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 8) * row_size : (row_iter * 8 * nb11)), ith, nth); + } + int rows_processed = (ne11 / 8) * 8; + for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) { + ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm, + (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth); + } + for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) { + ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } else if (src0->type == GGML_TYPE_Q8_0) { + for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) { + ggml_gemm_q8_0_q8_0(ne00, 8, ne01, 8, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? 
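Every kernel call above selects its activation pointer with the same expression, `src1_cont || src1->type != vec_dot_type ? row * row_size : row * nb11`. The sketch below (hypothetical helper, not part of the patch) spells out the intent: when src1 is contiguous or has been quantized into params->wdata, rows sit back to back at row_size intervals; otherwise the tensor's own row stride nb11 is applied to the original data.

// Minimal sketch (hypothetical): the activation-row addressing used by every
// GEMM/GEMV call above.
#include <stdbool.h>
#include <stddef.h>

static const char * src1_row_ptr(const char * wdata, size_t row, bool packed_rows,
                                 size_t row_size, size_t nb11) {
    // packed_rows corresponds to (src1_cont || src1->type != vec_dot_type)
    return wdata + (packed_rows ? row * row_size : row * nb11);
}

int main(void) {
    char buf[256] = {0};
    // e.g. row 3 of densely packed q8_0 rows, 34 bytes per 32-element block
    return src1_row_ptr(buf, 3, true, 34, 64) == buf + 3 * 34 ? 0 : 1;
}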
(row_iter * 8) * row_size : (row_iter * 8 * nb11)), ith, nth); + } + int rows_processed = (ne11 / 8) * 8; + for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) { + ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->rearranged_weight_gemm, + (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), ith, nth); + } + for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) { + ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } + } else if ((ggml_n_dims(src0) == 2) && (ne11 >= 4) && (src0->weight_rearranged == true)) { + // use batch-sized 4 GEMM kernel + if (src0->type == GGML_TYPE_Q4_0) { + for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) { + ggml_gemm_q4_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), ith, nth); + } + for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) { + ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } else if (src0->type == GGML_TYPE_Q8_0) { + for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) { + ggml_gemm_q8_0_q8_0(ne00, 4, ne01, 4, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->rearranged_weight_gemm, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), ith, nth); + } + for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) { + ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } + } +#elif defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) + if ((ggml_n_dims(src0) == 2) && (src0->weight_rearranged == true)) { + if (src0->type == GGML_TYPE_Q4_0) { + for (int row_iter = 0; row_iter < ne11; row_iter++) { + ggml_gemv_q4_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } else if (src0->type == GGML_TYPE_Q8_0) { + for (int row_iter = 0; row_iter < ne11; row_iter++) { + ggml_gemv_q8_0_q8_0(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->rearranged_weight_gemv, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? 
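Taken together, the branches above form a small dispatch table keyed on the batch size and on whether SMMLA (i8mm) is available; when only Neon/SVE dot-product instructions are present, every row goes through the GEMV kernel. A condensed, hypothetical restatement of that decision logic (names illustrative only):

// Minimal sketch (hypothetical): the dispatch shape of the branches above.
// Batches of 2-3 rows, higher-dimensional weights, and tensors without
// rearranged copies all fall through to the generic vec_dot path.
#include <stdbool.h>

typedef enum { PATH_GENERIC, PATH_GEMV_ONLY, PATH_GEMM_16_8_4,
               PATH_GEMM_8_4, PATH_GEMM_4 } mulmat_path;

static mulmat_path pick_path(int src0_dims, long long ne11,
                             bool weight_rearranged, bool has_i8mm) {
    if (src0_dims != 2 || !weight_rearranged) return PATH_GENERIC;
    if (!has_i8mm)  return PATH_GEMV_ONLY;   // Neon/SVE SDOT kernels, row by row
    if (ne11 == 1)  return PATH_GEMV_ONLY;
    if (ne11 >= 16) return PATH_GEMM_16_8_4;
    if (ne11 >= 8)  return PATH_GEMM_8_4;
    if (ne11 >= 4)  return PATH_GEMM_4;
    return PATH_GENERIC;
}

int main(void) { return pick_path(2, 12, true, true) == PATH_GEMM_8_4 ? 0 : 1; }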
(row_iter)*row_size : (row_iter * nb11)), ith, nth); + } + } + } +#endif +#if defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) + else { +#endif + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) { + const int64_t i13 = (ir1/(ne12*ne1)); + const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; + const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); - // broadcast src0 into src1 - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; + // broadcast src0 into src1 + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; - const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03); + const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03); - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size - : (i11*nb11 + i12*nb12 + i13*nb13)); - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size + : (i11*nb11 + i12*nb12 + i13*nb13)); + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) { - vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc); - } + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) { + vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? 
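In the generic path kept in the else branch, the only non-obvious step is recovering the (i13, i12, i11) coordinates from the flat row index ir1 and mapping them onto the broadcast src0 planes via r2 and r3. A standalone sketch with made-up extents (not part of the patch):

// Minimal sketch (hypothetical, standalone): the index arithmetic of the
// generic path, with r2 = ne12/ne02 and r3 = ne13/ne03 chosen for a broadcast.
#include <stdio.h>
#include <stdint.h>

int main(void) {
    const int64_t ne1 = 5, ne12 = 2, ne13 = 3;   // example dst/src1 extents
    const int64_t r2 = 2, r3 = 3;                // broadcast ratios (ne02 = ne03 = 1)

    for (int64_t ir1 = 0; ir1 < ne1 * ne12 * ne13; ir1 += 7) {   // sample a few rows
        const int64_t i13 = ir1 / (ne12 * ne1);
        const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
        const int64_t i11 = ir1 - i13 * ne12 * ne1 - i12 * ne1;
        printf("ir1=%3lld -> i13=%lld i12=%lld i11=%lld | src0 plane i03=%lld i02=%lld\n",
               (long long) ir1, (long long) i13, (long long) i12, (long long) i11,
               (long long) (i13 / r3), (long long) (i12 / r2));
    }
    return 0;
}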
src1_col_stride : 0), nrc); + } - for (int cn = 0; cn < nrc; ++cn) { - memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + for (int cn = 0; cn < nrc; ++cn) { + memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } } } } +#if defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) } +#endif } // ggml_compute_forward_mul_mat_id @@ -21129,4 +21452,26 @@ int ggml_cpu_has_matmul_int8(void) { #endif } +#if defined(__ARM_FEATURE_SVE) +static void (*_ggml_gemv_q4_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q4_0_q8_0_blocked8_sve; +#elif defined(__ARM_NEON) +static void (*_ggml_gemv_q4_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q4_0_q8_0_blocked8_neon; +#endif + +#if defined(__ARM_FEATURE_SVE) +static void (*_ggml_gemv_q8_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q8_0_q8_0_blocked8_sve; +#elif defined(__ARM_NEON) +static void (*_ggml_gemv_q8_0_q8_0)(const int, int, int, float * restrict, const void * restrict, const void * restrict, int, int) = &ggml_gemv_q8_0_q8_0_blocked8_neon; +#endif + +#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) +static void ggml_gemv_q4_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { + _ggml_gemv_q4_0_q8_0(n, output_channels, input_width, s, vx, vy, ith, nth); +} + +static void ggml_gemv_q8_0_q8_0(const int n, int output_channels, int input_width, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { + _ggml_gemv_q8_0_q8_0(n, output_channels, input_width, s, vx, vy, ith, nth); +} +#endif + //////////////////////////////////////////////////////////////////////////////// diff --git a/ggml.h b/ggml.h index bed7a36a0ee6a3..ea131067550a32 100644 --- a/ggml.h +++ b/ggml.h @@ -1,3 +1,4 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. #pragma once // @@ -572,7 +573,11 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[8]; + char padding[9]; + + void * rearranged_weight_gemv; + void * rearranged_weight_gemm; + bool weight_rearranged; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -2341,6 +2346,15 @@ extern "C" { GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); + GGML_API void rearrange_q4_0_weights_blocked8_neon(struct ggml_tensor * cur); + GGML_API void rearrange_q4_0_weights_blocked8_sve(struct ggml_tensor * cur); + GGML_API void rearrange_q4_0_weights_for_gemv(struct ggml_tensor * cur); + GGML_API void rearrange_q4_0_weights_for_gemm(struct ggml_tensor * cur); + GGML_API void rearrange_q8_0_weights_blocked8_neon(struct ggml_tensor * cur); + GGML_API void rearrange_q8_0_weights_blocked8_sve(struct ggml_tensor * cur); + GGML_API void rearrange_q8_0_weights_for_gemv(struct ggml_tensor * cur); + GGML_API void rearrange_q8_0_weights_for_gemm(struct ggml_tensor * cur); + #ifdef __cplusplus } #endif diff --git a/llama.cpp b/llama.cpp index 259f2a3a3ea00d..9583544e7810eb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1,3 +1,4 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. 
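In that generic loop, vec_dot with nrc > 1 (the MMLA-based dot kernels) deposits nrc strips of up to 16 results into the stack buffer tmp[32], and the memcpy loop above then scatters each strip to its own dst row. A standalone sketch of that drain step (values made up, not part of the patch):

// Minimal sketch (hypothetical): draining tmp[32] into per-row destinations.
#include <string.h>
#include <stdio.h>

int main(void) {
    enum { NRC = 2, BLCK0 = 16 };
    float tmp[32];                       // same size as the buffer in the kernel
    float dst[NRC][BLCK0];

    for (int cn = 0; cn < NRC; cn++)     // pretend vec_dot filled the strips
        for (int j = 0; j < BLCK0; j++)
            tmp[cn * 16 + j] = (float)(cn * 100 + j);

    for (int cn = 0; cn < NRC; cn++)     // the memcpy loop from the diff, simplified
        memcpy(dst[cn], tmp + cn * 16, BLCK0 * sizeof(float));

    printf("dst[1][3] = %.0f\n", dst[1][3]);   // expect 103
    return 0;
}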
#define LLAMA_API_INTERNAL #include "llama.h" @@ -2823,6 +2824,32 @@ struct llama_model_loader { } } +#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) + if ((cur->type == GGML_TYPE_Q4_0) && (cur->ne[1] % 4 == 0)) { + cur->weight_rearranged = true; +#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) + rearrange_q4_0_weights_for_gemv(cur); // rearrange weights for Arm Neon/SVE GEMV kernels +#endif +#if defined(__ARM_FEATURE_MATMUL_INT8) + rearrange_q4_0_weights_for_gemm(cur); // rearrange weights for GEMM MMLA kernels +#endif + } + else if ((cur->type == GGML_TYPE_Q8_0) && (cur->ne[1] % 4 == 0)) { + cur->weight_rearranged = true; +#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) + rearrange_q8_0_weights_for_gemv(cur); // rearrange weights for Arm Neon/SVE GEMV kernels +#endif +#if defined(__ARM_FEATURE_MATMUL_INT8) + rearrange_q8_0_weights_for_gemm(cur); // rearrange weights for GEMM MMLA kernels +#endif + } + else { + cur->weight_rearranged = false; + } +#else + cur->weight_rearranged = false; +#endif + size_done += ggml_nbytes(cur); }
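On the loader side, the gate above is purely type- and shape-based: Q4_0 and Q8_0 weights whose second dimension is a multiple of 4 get the blocked copies, and each rearranged copy shown earlier is a separate malloc of ggml_nbytes(cur), so memory use grows by roughly one extra weight image per enabled layout. A hypothetical restatement of the predicate (stand-in names, not the ggml API):

// Minimal sketch (hypothetical stand-ins for ggml_type and the tensor shape):
// the condition used above to decide whether to build the rearranged copies.
#include <stdbool.h>

enum fake_type { FAKE_Q4_0, FAKE_Q8_0, FAKE_OTHER };

static bool should_rearrange(enum fake_type t, long long ne1) {
    return (t == FAKE_Q4_0 || t == FAKE_Q8_0) && (ne1 % 4 == 0);
}

int main(void) { return should_rearrange(FAKE_Q8_0, 4096) ? 0 : 1; }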