Skip to content
This repository has been archived by the owner on Feb 6, 2024. It is now read-only.

Commit

Permalink
Handle odd attention head scales (falcon)
Browse files Browse the repository at this point in the history
  • Loading branch information
brittlewis12 committed Nov 17, 2023
1 parent c6bdce2 commit 18992ce
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
17 changes: 13 additions & 4 deletions Sources/llmfarm_core_cpp/ggml/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
GGML_METAL_DECL_KERNEL(mul);
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
GGML_METAL_DECL_KERNEL(scale);
GGML_METAL_DECL_KERNEL(scale_4);
GGML_METAL_DECL_KERNEL(silu);
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
Expand Down Expand Up @@ -249,6 +250,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
GGML_METAL_ADD_KERNEL(mul);
GGML_METAL_ADD_KERNEL(mul_row);
GGML_METAL_ADD_KERNEL(scale);
GGML_METAL_ADD_KERNEL(scale_4);
GGML_METAL_ADD_KERNEL(silu);
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
Expand Down Expand Up @@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul);
GGML_METAL_DEL_KERNEL(mul_row);
GGML_METAL_DEL_KERNEL(scale);
GGML_METAL_DEL_KERNEL(scale_4);
GGML_METAL_DEL_KERNEL(silu);
GGML_METAL_DEL_KERNEL(relu);
GGML_METAL_DEL_KERNEL(gelu);
Expand Down Expand Up @@ -924,14 +927,20 @@ void ggml_metal_graph_compute(
const float scale = *(const float *) src1->data;

[encoder setComputePipelineState:ctx->pipeline_scale];
int64_t n = ggml_nelements(dst);

if (n % 4 == 0) {
n /= 4;
[encoder setComputePipelineState:ctx->pipeline_scale_4];
} else {
[encoder setComputePipelineState:ctx->pipeline_scale];
}

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];

const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);

[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
Expand Down
10 changes: 9 additions & 1 deletion Sources/llmfarm_core_cpp/metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,17 @@ kernel void kernel_mul_row(
}

kernel void kernel_scale(
device const float * src0,
device float * dst,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}

kernel void kernel_scale_4(
device const float4 * src0,
device float4 * dst,
constant float & scale,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
Expand Down

0 comments on commit 18992ce

Please sign in to comment.