Skip to content

Commit

Permalink
vulkan : handle ggml_scale for n%8 != 0
Browse files Browse the repository at this point in the history
  • Loading branch information
cebtenzzre committed Nov 23, 2023
1 parent 2a41ba7 commit 6474fc8
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ if (LLAMA_KOMPUTE)
# Compile our shaders
compile_shader(SOURCES
kompute/op_scale.comp
kompute/op_scale_8.comp
kompute/op_add.comp
kompute/op_addrow.comp
kompute/op_mul.comp
Expand Down Expand Up @@ -508,6 +509,7 @@ if (LLAMA_KOMPUTE)
# Create a custom target for our generated shaders
add_custom_target(generated_shaders DEPENDS
shaderop_scale.h
shaderop_scale_8.h
shaderop_add.h
shaderop_addrow.h
shaderop_mul.h
Expand Down
29 changes: 20 additions & 9 deletions ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

// These are generated at build time by cmake custom command
#include "shaderop_scale.h"
#include "shaderop_scale_8.h"
#include "shaderop_add.h"
#include "shaderop_addrow.h"
#include "shaderop_mul.h"
Expand Down Expand Up @@ -724,8 +725,12 @@ void ggml_vk_scale(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inOff, uint32_t outOff,
uint32_t size, float scale) {
const static auto spirv = getSpirvShader(kp::shader_data::op_scale_comp_spv,
kp::shader_data::op_scale_comp_spv_len);
const static auto spirv_1 = getSpirvShader(
kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
);
const static auto spirv_8 = getSpirvShader(
kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
);

struct PushConstants {
uint32_t inOff, outOff;
Expand All @@ -735,11 +740,19 @@ void ggml_vk_scale(kp::Sequence& seq,
scale
};

const auto * spirv = &spirv_1;
std::string name(__func__);
if (size % 8 == 0) {
size /= 8;
name += "_8";
spirv = &spirv_8;
}

std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__))
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
else {
s_algo = komputeManager()->getAlgorithm(__func__);
if (!komputeManager()->hasAlgorithm(name)) {
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(name);
s_algo->setTensors({in, out});
s_algo->setWorkgroup({size});
s_algo->setPushConstants<PushConstants>({pushConsts});
Expand Down Expand Up @@ -1416,9 +1429,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
case GGML_OP_SCALE:
{
const float scale = *(const float *) src1->data;
int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 8 == 0);
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, n/8, scale);
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
} break;
case GGML_OP_UNARY:
{
Expand Down
10 changes: 3 additions & 7 deletions kompute/op_scale.comp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ layout(push_constant) uniform PushConstants {
} pcs;

void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;

for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
}
const uint i = gl_WorkGroupID.x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
31 changes: 31 additions & 0 deletions kompute/op_scale_8.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
*
* This software is licensed under the terms of the Software for Open Models License (SOM),
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
*/

#version 450

#include "common.comp"

layout(local_size_x = 1) in;

layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };

layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
float scale;
} pcs;

void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;

for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
}

0 comments on commit 6474fc8

Please sign in to comment.