From 18992ce7af0afff426daa94db8ee48ddac750229 Mon Sep 17 00:00:00 2001
From: Britt Lewis
Date: Fri, 17 Nov 2023 16:16:15 -0500
Subject: [PATCH] Handle odd attention head scales (falcon)

* https://github.com/ggerganov/llama.cpp/commit/469c9addef75893e6be12edda852d12e840bf064
* https://github.com/ggerganov/llama.cpp/issues/3754
---
 Sources/llmfarm_core_cpp/ggml/ggml-metal.m      | 17 +++++++++++++----
 Sources/llmfarm_core_cpp/metal/ggml-metal.metal | 10 +++++++++-
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/Sources/llmfarm_core_cpp/ggml/ggml-metal.m b/Sources/llmfarm_core_cpp/ggml/ggml-metal.m
index e7dbbb0..6e52f0a 100644
--- a/Sources/llmfarm_core_cpp/ggml/ggml-metal.m
+++ b/Sources/llmfarm_core_cpp/ggml/ggml-metal.m
@@ -62,6 +62,7 @@
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
@@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
@@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul);
     GGML_METAL_DEL_KERNEL(mul_row);
     GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(scale_4);
     GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
@@ -924,14 +927,20 @@ void ggml_metal_graph_compute(
                         const float scale = *(const float *) src1->data;
 
                         [encoder setComputePipelineState:ctx->pipeline_scale];
+                        int64_t n = ggml_nelements(dst);
+
+                        if (n % 4 == 0) {
+                            n /= 4;
+                            [encoder setComputePipelineState:ctx->pipeline_scale_4];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_scale];
+                        }
+
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                         [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                        const int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 4 == 0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_UNARY:
                     switch (ggml_get_unary_op(gf->nodes[i])) {
diff --git a/Sources/llmfarm_core_cpp/metal/ggml-metal.metal b/Sources/llmfarm_core_cpp/metal/ggml-metal.metal
index 69fc713..f4b4605 100644
--- a/Sources/llmfarm_core_cpp/metal/ggml-metal.metal
+++ b/Sources/llmfarm_core_cpp/metal/ggml-metal.metal
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
         device const float4 * src0,
         device       float4 * dst,
-        constant     float & scale,
+        constant     float  & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }