Skip to content

Commit

Permalink
Vulkan: Fix float16 use on devices without float16 support + fix subg…
Browse files Browse the repository at this point in the history
…roup_size_control validation error (ggerganov#11161)

* Vulkan: Remove float16 use in shaders

* Fix validation error about subgroup_size_control extension
  • Loading branch information
0cc4m authored Jan 10, 2025
1 parent ee7136c commit c3f9d25
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 51 deletions.
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2277,6 +2277,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
device_extensions.push_back("VK_EXT_subgroup_size_control");
}

device->subgroup_size_control = device->subgroup_size_control &&
Expand All @@ -2285,7 +2286,6 @@ static vk_device ggml_vk_get_device(size_t idx) {

if (device->subgroup_size_control) {
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
device_extensions.push_back("VK_EXT_subgroup_size_control");
}

#if defined(VK_KHR_cooperative_matrix)
Expand Down
9 changes: 3 additions & 6 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
#version 450

#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

Expand All @@ -27,8 +24,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const

#if K_PER_ITER == 8
#if QUANT_R == 2
const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
Expand Down
24 changes: 12 additions & 12 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

Expand Down Expand Up @@ -40,9 +40,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
const FLOAT_TYPE dall = d.x;
const FLOAT_TYPE dmin = d.y;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
Expand All @@ -63,14 +63,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uvec2 qs16 = uvec2(unpack8(qs16_u16));

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
Expand Down
18 changes: 9 additions & 9 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

Expand Down Expand Up @@ -60,14 +60,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {

B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
Expand Down
12 changes: 6 additions & 6 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

Expand Down Expand Up @@ -45,7 +45,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

Expand Down Expand Up @@ -96,10 +96,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint32_t q4_15 = qs64_hi4.w;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);

const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
Expand Down
20 changes: 10 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

Expand Down Expand Up @@ -42,7 +42,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

Expand Down Expand Up @@ -105,14 +105,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint32_t q4_15 = qs64_80_hi4.w;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);

const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

Expand Down Expand Up @@ -77,10 +77,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uvec4 q3 = uvec4(unpack8(q3_u32));

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);

FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 4; ++l) {
Expand Down
1 change: 0 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : enable

layout (push_constant) uniform parameter
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
#if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_16bit_storage : require

#if defined(DATA_A_F32)
#define QUANT_K 1
Expand Down

0 comments on commit c3f9d25

Please sign in to comment.