From d89e0e09d78abfe477fbf544e00400560695cb33 Mon Sep 17 00:00:00 2001
From: zhangjidong <1119708529@qq.com>
Date: Wed, 24 Jan 2024 18:57:09 +0800
Subject: [PATCH 01/17] New features:

1. Sum_Rows:
   fix CUDA kernel overflow
   fix block shape error when nrows is too big
2. Im2Col:
   support batch in CUDA
   support f32 to f32 on both CPU and CUDA
3. DepthWiseConv: supported via Im2Col and MulMat
4. Pool_2d: support avg pooling in CUDA
5. HardSigmoid: implemented in CUDA
6. HardSwish: implemented in CUDA
---
 examples/llava/MobileVLM-README.md | 58 ++++++++++-
 ggml-cuda.cu | 156 ++++++++++++++++++++++++++---
 ggml.c | 130 +++++++++++++++++++++---
 ggml.h | 3 +-
 tests/test-backend-ops.cpp | 2 +-
 5 files changed, 317 insertions(+), 32 deletions(-)

diff --git a/examples/llava/MobileVLM-README.md b/examples/llava/MobileVLM-README.md
index c6258eba69a53..9eba791dadfef 100644
--- a/examples/llava/MobileVLM-README.md
+++ b/examples/llava/MobileVLM-README.md
@@ -111,17 +111,71 @@ llama_print_timings: eval time = 1279.03 ms / 18 runs ( 71.06 m
 llama_print_timings: total time = 34570.79 ms
 ```
 
+## Orin compile and run
+### compile
+```sh
+make LLAMA_CUBLAS=1 CUDA_DOCKER_ARCH=sm_87 LLAMA_CUDA_F16=1 -j 32
+```
+
+### run on Orin
+### case 1
+**input**
+```sh
+./llava-cli \
+    -m /data/local/tmp/ggml-model-q4_k.gguf \
+    --mmproj /data/local/tmp/mmproj-model-f16.gguf \
+    --image /data/local/tmp/demo.jpeg \
+    -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" \
+    --n-gpu-layers 999
+```
+**output**
+```sh
+
+encode_image_with_clip: image encoded in 296.62 ms by CLIP ( 2.06 ms per image patch)
+
+ Susan Wise Bauer
+
+llama_print_timings: load time = 1067.64 ms
+llama_print_timings: sample time = 1.53 ms / 6 runs ( 0.25 ms per token, 3934.43 tokens per second)
+llama_print_timings: prompt eval time = 306.84 ms / 246 tokens ( 1.25 ms per token, 801.72 tokens per second)
+llama_print_timings: eval time = 91.50 ms / 6 runs ( 15.25 ms per token, 65.58 tokens per second)
+llama_print_timings: total time = 1352.63 ms / 252 tokens
+```
+
+### case 2
+**input**
+```sh
+./llava-cli \
+    -m /data/local/tmp/ggml-model-q4_k.gguf \
+    --mmproj /data/local/tmp/mmproj-model-f16.gguf \
+    -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? ASSISTANT:" \
+    --n-gpu-layers 999
+
+```
+**output**
+```sh
+encode_image_with_clip: image encoded in 302.15 ms by CLIP ( 2.10 ms per image patch)
+
+ The image features a cat lying in the grass.
+
+llama_print_timings: load time = 1057.07 ms
+llama_print_timings: sample time = 3.27 ms / 11 runs ( 0.30 ms per token, 3360.83 tokens per second)
+llama_print_timings: prompt eval time = 213.60 ms / 232 tokens ( 0.92 ms per token, 1086.14 tokens per second)
+llama_print_timings: eval time = 166.65 ms / 11 runs ( 15.15 ms per token, 66.01 tokens per second)
+llama_print_timings: total time = 1365.47 ms / 243 tokens
+```
+
 ## Minor shortcomings
 The `n_patch` of the output in `ldp` is 1/4 of the input. To get a working implementation quickly, we uniformly modified the `clip_n_patches` function to return a quarter of the original patch count. When counting the time consumption, the calculated time will therefore be 4 times bigger than the real cost.
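For the new operators tracked in the TODO below, patch 01 extends the `ggml_im2col` API with an explicit destination type. A minimal usage sketch follows (not part of the patch; `ctx`, an F16 `kernel` tensor, and an F32 `input` tensor are assumed to exist, mirroring the call in `tests/test-backend-ops.cpp`):

```c
// Sketch only (not from the patch): lowering a 2D conv input to columns
// with the extended API. The last two arguments are the is_2D flag and the
// new dst_type; GGML_TYPE_F32 selects the f32 -> f32 path added for CPU and
// CUDA. Output shape: [N, OH, OW, IC * KH * KW]; the kernel tensor stays F16.
struct ggml_tensor * cols = ggml_im2col(ctx, kernel, input,
                                        1, 1,           // stride   s0, s1
                                        1, 1,           // padding  p0, p1
                                        1, 1,           // dilation d0, d1
                                        true,           // is_2D
                                        GGML_TYPE_F32); // dst_type
```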
## TODO -- [ ] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid` +- [x] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid` - [ ] Optimize LDP projector performance - Optimize the structure definition to avoid unnecessary memory rearrangements, to reduce the use of `ggml_permute_cpy`; - Optimize operator implementation (ARM CPU/NVIDIA GPU): such as depthwise conv, hardswish, hardsigmoid, etc. -- [ ] run MobileVLM on `Jetson Orin` +- [x] run MobileVLM on `Jetson Orin` - [ ] Support more model variants, such as `MobileVLM-3B`. diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7f460449eaa05..b211b1a8a1cd8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -512,6 +512,8 @@ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16 #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 +#define CUDA_HARDSIGMOID_BLOCK_SIZE 256 +#define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 @@ -811,6 +813,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) { dst[i] = fmaxf(x[i], 0); } +static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); +} + +static __global__ void hardswish_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); +} + static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -5656,12 +5676,14 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols, } static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.y; + const int row = blockIdx.x; const int col = threadIdx.x; float sum = 0.0f; - for (int i = col; i < ncols; i += blockDim.x) { + int i = col; + while(i < ncols) { sum += x[row * ncols + i]; + i += blockDim.x; } sum = warp_reduce_sum(sum); @@ -5978,9 +6000,43 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } +static __global__ void im2col_f32_f32( + const float * x, float * dst, int batch_offset, + int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, + int s0, int s1, int p0, int p1, int d0, int d1) { + const int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= pelements) { + return; + } + + const int ksize = OW * (KH > 1 ? 
KW : 1); + const int kx = i / ksize; + const int kd = kx * ksize; + const int ky = (i - kd) / OW; + const int ix = i % OW; + + const int oh = blockIdx.y; + const int batch = blockIdx.z / IC; + const int ic = blockIdx.z % IC; + + const int64_t iiw = ix * s0 + kx * d0 - p0; + const int64_t iih = oh * s1 + ky * d1 - p1; + + const int64_t offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = (0.0f); + } else { + const int64_t offset_src = ic * offset_delta + batch * batch_offset; + dst[offset_dst] = (x[offset_src + iih * IW + iiw]); + } +} + static __global__ void im2col_f32_f16( - const float * x, half * dst, - int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW, + const float * x, half * dst, int batch_offset, + int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { const int i = threadIdx.x + blockIdx.x * blockDim.x; if (i >= pelements) { @@ -5993,17 +6049,21 @@ static __global__ void im2col_f32_f16( const int ky = (i - kd) / OW; const int ix = i % OW; + const int oh = blockIdx.y; + const int batch = blockIdx.z / IC; + const int ic = blockIdx.z % IC; + const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = blockIdx.y * s1 + ky * d1 - p1; + const int64_t iih = oh * s1 + ky * d1 - p1; const int64_t offset_dst = - (blockIdx.y * OW + ix) * CHW + - (blockIdx.z * (KW * KH) + ky * KW + kx); + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst[offset_dst] = __float2half(0.0f); } else { - const int64_t offset_src = blockIdx.z * offset_delta; + const int64_t offset_src = ic * offset_delta + batch * batch_offset; dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); } } @@ -6221,6 +6281,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ relu_f32<<>>(x, dst, k); } +static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE; + hardsigmoid_f32<<>>(x, dst, k); +} + +static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE; + hardswish_f32<<>>(x, dst, k); +} + static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) { const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; leaky_relu_f32<<>>(x, dst, k, negative_slope); @@ -7276,7 +7346,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums(1, nrows, 1); + const dim3 block_nums(nrows, 1, 1); k_sum_rows_f32<<>>(x, dst, ncols); } @@ -7388,14 +7458,24 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con } } +static void im2col_f32_f32_cuda(const float* x, float* dst, + int IW, int IH, int OW, int OH, int KW, int KH, int IC, + int batch, int batch_offset, int offset_delta, + int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { + const int parallel_elements = OW * KW * KH; + const int num_blocks = 
(parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OH, batch * IC); + im2col_f32_f32<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); +} + static void im2col_f32_f16_cuda(const float* x, half* dst, int IW, int IH, int OW, int OH, int KW, int KH, int IC, - int offset_delta, + int batch, int batch_offset, int offset_delta, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { const int parallel_elements = OW * KW * KH; const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OH, IC); - im2col_f32_f16<<>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); + dim3 block_nums(num_blocks, OH, batch * IC); + im2col_f32_f16<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } // buffer pool for cuda @@ -7980,6 +8060,34 @@ static void ggml_cuda_op_relu( (void) src1_dd; } +static void ggml_cuda_op_hardsigmoid( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +static void ggml_cuda_op_hardswish( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + static void ggml_cuda_op_leaky_relu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { @@ -8612,7 +8720,7 @@ static void ggml_cuda_op_im2col( GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; @@ -8634,8 +8742,13 @@ static void ggml_cuda_op_im2col( const int64_t OW = dst->ne[1]; const size_t delta_offset = src1->nb[is_2D ? 
2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t batch = src1->ne[3]; + const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 - im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + if(dst->type == GGML_TYPE_F16) + im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + else + im2col_f32_f32_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); (void) src0; (void) src0_dd; @@ -9231,6 +9344,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); } +static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid); +} + +static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish); +} static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu); } @@ -10109,6 +10229,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_UNARY_OP_RELU: func = ggml_cuda_relu; break; + case GGML_UNARY_OP_HARDSIGMOID: + func = ggml_cuda_hardsigmoid; + break; + case GGML_UNARY_OP_HARDSWISH: + func = ggml_cuda_hardswish; + break; default: return false; } @@ -10917,6 +11043,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: return true; diff --git a/ggml.c b/ggml.c index ca98fde8ab239..1c74d80e3f524 100644 --- a/ggml.c +++ b/ggml.c @@ -5296,7 +5296,7 @@ GGML_API struct ggml_tensor * ggml_conv_1d( int s0, int p0, int d0) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -5374,16 +5374,15 @@ struct ggml_tensor * ggml_conv_depthwise_2d( int p1, int d0, int d1) { + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), - s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW] - - struct ggml_tensor * result = - ggml_mul_mat(ctx, - ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1), // [OC,1, KH, KW] => [1, OC, 1, KH * KW] - ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] + struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 
1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] + struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b); result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] return result; @@ -5404,7 +5403,8 @@ struct ggml_tensor * ggml_im2col( int p1, int d0, int d1, - bool is_2D) { + bool is_2D, + enum ggml_type dst_type) { if(is_2D) { GGML_ASSERT(a->ne[2] == b->ne[2]); @@ -5428,7 +5428,7 @@ struct ggml_tensor * ggml_im2col( is_2D ? b->ne[3] : 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; ggml_set_op_params(result, params, sizeof(params)); @@ -5453,7 +5453,7 @@ struct ggml_tensor * ggml_conv_2d( int p1, int d0, int d1) { - struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW] + struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = ggml_mul_mat(ctx, @@ -5579,12 +5579,28 @@ struct ggml_tensor * ggml_pool_2d( is_node = true; } + struct ggml_tensor * result; +#if defined(GGML_USE_CUBLAS) + if(!(op == GGML_OP_POOL_AVG)) { + GGML_ASSERT(false); + } + + const int64_t ne[4] = {k0, k1, 1, a->ne[2]}; + struct ggml_tensor * b = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); + struct ggml_tensor * im2col = ggml_im2col(ctx, b, new_a, + s0, s1, p0, p1, 1, 1, true, GGML_TYPE_F32); // [N * IC, OH, OW, KH * KW] + + result = ggml_sum_rows(ctx, im2col); + result = ggml_scale(ctx, result, 1. / (k0 * k1)); + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[2], a->ne[3]); +#else const int64_t ne[3] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), a->ne[2], }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); @@ -5592,7 +5608,7 @@ struct ggml_tensor * ggml_pool_2d( result->op = GGML_OP_POOL_2D; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - +#endif return result; } @@ -12416,6 +12432,92 @@ static void ggml_compute_forward_conv_transpose_1d( } } +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void ggml_compute_forward_im2col_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } +} + + // src0: kernel [OC, IC, KH, KW] // src1: image [N, IC, IH, IW] // dst: result [N, OH, OW, IC*KH*KW] @@ -12506,14 +12608,14 @@ static void ggml_compute_forward_im2col( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - switch (src0->type) { + switch (dst->type) { case GGML_TYPE_F16: { ggml_compute_forward_im2col_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - GGML_ASSERT(false); + ggml_compute_forward_im2col_f32(params, src0, src1, dst); } break; default: { diff --git a/ggml.h b/ggml.h index 1c49762716774..0f541b2e7be8a 100644 --- a/ggml.h +++ b/ggml.h @@ -1493,7 +1493,8 @@ extern "C" { int p1, int d0, int d1, - bool is_2D); + bool is_2D, + enum ggml_type dst_type); GGML_API struct ggml_tensor * ggml_conv_depthwise_2d( struct ggml_context * ctx, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55ce14e0d902c..ac1ae8ad2ef94 100644 --- 
a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1174,7 +1174,7 @@ struct test_im2col : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); - ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D); + ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F16); return out; } }; From b08c6b1ad82881227309390fbdcec62839116af9 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Fri, 26 Jan 2024 10:45:34 +0800 Subject: [PATCH 02/17] fix tabs instead of spaces --- ggml.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.h b/ggml.h index 0f541b2e7be8a..cf568b3e58c0d 100644 --- a/ggml.h +++ b/ggml.h @@ -1494,7 +1494,7 @@ extern "C" { int d0, int d1, bool is_2D, - enum ggml_type dst_type); + enum ggml_type dst_type); GGML_API struct ggml_tensor * ggml_conv_depthwise_2d( struct ggml_context * ctx, From c29a855453dd29f5a1cb8bc3edeeda2ae50fe460 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Fri, 26 Jan 2024 15:38:37 +0800 Subject: [PATCH 03/17] code clean --- ggml-cuda.cu | 66 +++++++++------------------------------------------- 1 file changed, 11 insertions(+), 55 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b211b1a8a1cd8..db136ae9cf4d6 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5680,10 +5680,8 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc const int col = threadIdx.x; float sum = 0.0f; - int i = col; - while(i < ncols) { + for (int i = col; i < ncols; i += blockDim.x) { sum += x[row * ncols + i]; - i += blockDim.x; } sum = warp_reduce_sum(sum); @@ -6000,8 +5998,9 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } -static __global__ void im2col_f32_f32( - const float * x, float * dst, int batch_offset, +template +static __global__ void im2col_kernel( + const float * x, T * dst, int batch_offset, int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { const int i = threadIdx.x + blockIdx.x * blockDim.x; @@ -6027,44 +6026,10 @@ static __global__ void im2col_f32_f32( (ic * (KW * KH) + ky * KW + kx); if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = (0.0f); - } else { - const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = (x[offset_src + iih * IW + iiw]); - } -} - -static __global__ void im2col_f32_f16( - const float * x, half * dst, int batch_offset, - int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, - int s0, int s1, int p0, int p1, int d0, int d1) { - const int i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= pelements) { - return; - } - - const int ksize = OW * (KH > 1 ? 
KW : 1); - const int kx = i / ksize; - const int kd = kx * ksize; - const int ky = (i - kd) / OW; - const int ix = i % OW; - - const int oh = blockIdx.y; - const int batch = blockIdx.z / IC; - const int ic = blockIdx.z % IC; - - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = oh * s1 + ky * d1 - p1; - - const int64_t offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = __float2half(0.0f); + dst[offset_dst] = 0.0f; } else { const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); + dst[offset_dst] = x[offset_src + iih * IW + iiw]; } } @@ -7458,24 +7423,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con } } -static void im2col_f32_f32_cuda(const float* x, float* dst, - int IW, int IH, int OW, int OH, int KW, int KH, int IC, - int batch, int batch_offset, int offset_delta, - int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { - const int parallel_elements = OW * KW * KH; - const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OH, batch * IC); - im2col_f32_f32<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); -} - -static void im2col_f32_f16_cuda(const float* x, half* dst, +template +static void im2col_cuda(const float* x, T* dst, int IW, int IH, int OW, int OH, int KW, int KH, int IC, int batch, int batch_offset, int offset_delta, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { const int parallel_elements = OW * KW * KH; const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; dim3 block_nums(num_blocks, OH, batch * IC); - im2col_f32_f16<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); + im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } // buffer pool for cuda @@ -8746,9 +8702,9 @@ static void ggml_cuda_op_im2col( const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 if(dst->type == GGML_TYPE_F16) - im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); else - im2col_f32_f32_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); (void) src0; (void) src0_dd; From ba5592c65336159e7de379f1e39a1354ce287b74 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Mon, 29 Jan 2024 10:31:41 +0800 Subject: [PATCH 04/17] CUDA POOL2D --- ggml-cuda.cu | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++ ggml.c | 16 ---------- 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index db136ae9cf4d6..4fb5ce78460da 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -530,6 +530,7 @@ static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + 
QK_K/8*sizeof(uint16 #define CUDA_PAD_BLOCK_SIZE 256 #define CUDA_ACC_BLOCK_SIZE 256 #define CUDA_IM2COL_BLOCK_SIZE 256 +#define CUDA_POOL2D_BLOCK_SIZE 256 #define CUDA_Q8_0_NE_ALIGN 2048 @@ -6033,6 +6034,48 @@ static __global__ void im2col_kernel( } } +template +static __global__ void pool2d_nchw_kernel( + const int ih, const int iw, const int oh, const int ow, + const int kh, const int kw, const int sh, const int sw, + const int ph, const int pw, + const Ti* src, To* dst, const enum ggml_op_pool op) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + const int I_HW = ih * iw; + const int O_HW = oh * ow; + const int nc = idx / (oh * ow); + const int cur_oh = idx % (oh * ow) / ow; + const int cur_ow = idx % (oh * ow) % ow; + const Ti* i_ptr = src + nc * I_HW; + To* o_ptr = dst + nc * O_HW; + const int start_h = cur_oh * sh - ph; + const int bh = max(0, start_h); + const int eh = min(ih, start_h + kh); + const int start_w = ow * sw - pw; + const int bw = max(0, start_w); + const int ew = min(iw, start_w + kw); + const To scale = 1. / ((eh - bh) * (ew - bw)); + To res = 0; + switch(op){ + case GGML_OP_POOL_AVG: res = 0; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; + } + for(int i = bh; i < eh; i += 1){ + for(int j = bw; j < ew; j += 1){ + #if __CUDA_ARCH__ >= 350 + Ti cur = __ldg(i_ptr + i * iw + j); + #else + Ti cur = i_ptr[i * iw + j]; + #endif + switch(op){ + case GGML_OP_POOL_AVG: res += cur * scale; break; + case GGML_OP_POOL_MAX: res = max(res, (To)cur); break; + } + } + } + o_ptr[cur_oh * ow + cur_ow] = res; +} + template static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { @@ -8670,6 +8713,38 @@ static void ggml_cuda_op_alibi( (void) src1_dd; } +static void ggml_cuda_op_pool2d( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = static_cast(opts[0]); + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + const int64_t IC = src0->ne[2]; + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int parallel_elements = N * OC * OH * OW; + const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE; + dim3 block_nums(num_blocks); + pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, src0_dd, dst_dd, op); + + (void) src0; + (void) src0_dd; +} + + static void ggml_cuda_op_im2col( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { @@ -10084,6 +10159,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } +static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d); +} + static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); } @@ -10265,6 +10344,9 @@ 
GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_IM2COL: func = ggml_cuda_im2col; break; + case GGML_OP_POOL_2D: + func = ggml_cuda_pool2d; + break; case GGML_OP_SUM_ROWS: func = ggml_cuda_sum_rows; break; diff --git a/ggml.c b/ggml.c index 1c74d80e3f524..ccb1bfb5e2645 100644 --- a/ggml.c +++ b/ggml.c @@ -5580,21 +5580,6 @@ struct ggml_tensor * ggml_pool_2d( } struct ggml_tensor * result; -#if defined(GGML_USE_CUBLAS) - if(!(op == GGML_OP_POOL_AVG)) { - GGML_ASSERT(false); - } - - const int64_t ne[4] = {k0, k1, 1, a->ne[2]}; - struct ggml_tensor * b = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); - struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); - struct ggml_tensor * im2col = ggml_im2col(ctx, b, new_a, - s0, s1, p0, p1, 1, 1, true, GGML_TYPE_F32); // [N * IC, OH, OW, KH * KW] - - result = ggml_sum_rows(ctx, im2col); - result = ggml_scale(ctx, result, 1. / (k0 * k1)); - result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[2], a->ne[3]); -#else const int64_t ne[3] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), @@ -5608,7 +5593,6 @@ struct ggml_tensor * ggml_pool_2d( result->op = GGML_OP_POOL_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; -#endif return result; } From 1a82788028337d114e7949ca3262ccb3869a5dd6 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Mon, 29 Jan 2024 11:21:03 +0800 Subject: [PATCH 05/17] ADD POOL2D test case in test-backend-ops.cpp --- ggml-cuda.cu | 1 + tests/test-backend-ops.cpp | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4fb5ce78460da..3387d54c9f900 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -11178,6 +11178,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ROPE: case GGML_OP_ALIBI: case GGML_OP_IM2COL: + case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ac1ae8ad2ef94..45f6fc36a396d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1140,6 +1140,40 @@ struct test_alibi : public test_case { } }; +// GGML_OP_POOL2D +struct test_pool2d : public test_case { + enum ggml_op_pool pool_type; + const ggml_type type_input; + const std::array ne_input; + // kernel size + const int k0; + const int k1; + // stride + const int s0; + const int s1; + // padding + const int p0; + const int p1; + + std::string vars() override { + return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1); + } + + test_pool2d(enum ggml_op_pool pool_type = GGML_OP_POOL_AVG, + ggml_type type_input = GGML_TYPE_F32, + std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] + int k0 = 3, int k1 = 3, + int s0 = 1, int s1 = 1, + int p0 = 1, int p1 = 1) + : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); + ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1); + return out; + } +}; + // GGML_OP_IM2COL struct test_im2col : public test_case { const ggml_type type_input; @@ -1502,6 +1536,8 @@ static bool test_backend(ggml_backend_t backend, 
test_mode mode, const char * op } } + test_cases.emplace_back(new test_pool2d(GGML_OP_POOL_AVG)); + test_cases.emplace_back(new test_pool2d(GGML_OP_POOL_MAX)); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1})); From 41a34cb3de0ab225be89340d716de478828c4fbb Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Tue, 30 Jan 2024 09:53:07 +0800 Subject: [PATCH 06/17] code clean --- ggml-cuda.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3387d54c9f900..a8f9265f14731 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -8726,7 +8726,6 @@ static void ggml_cuda_op_pool2d( const int p0 = opts[5]; const int p1 = opts[6]; - const int64_t IC = src0->ne[2]; const int64_t IH = src0->ne[1]; const int64_t IW = src0->ne[0]; From 1556d4ca17718417a6dad9bf73939625c2b2e7a0 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Tue, 30 Jan 2024 10:28:18 +0800 Subject: [PATCH 07/17] fix pool2d_kernel nits --- ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a8f9265f14731..3c9863daed9a6 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6038,9 +6038,11 @@ template static __global__ void pool2d_nchw_kernel( const int ih, const int iw, const int oh, const int ow, const int kh, const int kw, const int sh, const int sw, - const int ph, const int pw, + const int ph, const int pw, const int parallel_elements, const Ti* src, To* dst, const enum ggml_op_pool op) { int idx = threadIdx.x + blockIdx.x * blockDim.x; + if(idx >= parallel_elements) + return; const int I_HW = ih * iw; const int O_HW = oh * ow; const int nc = idx / (oh * ow); @@ -8737,7 +8739,7 @@ static void ggml_cuda_op_pool2d( const int parallel_elements = N * OC * OH * OW; const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE; dim3 block_nums(num_blocks); - pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, src0_dd, dst_dd, op); + pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op); (void) src0; (void) src0_dd; From 379f89fbbe4658522795b9ae8aff970fb0696b36 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Tue, 30 Jan 2024 11:06:37 +0800 Subject: [PATCH 08/17] fix bug in pool2d kernel --- ggml-cuda.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3c9863daed9a6..20e3b5efa7e3a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6045,15 +6045,15 @@ static __global__ void pool2d_nchw_kernel( return; const int I_HW = ih * iw; const int O_HW = oh * ow; - const int nc = idx / (oh * ow); - const int cur_oh = idx % (oh * ow) / ow; - const int cur_ow = idx % (oh * ow) % ow; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / ow; + const int cur_ow = idx % O_HW % ow; const Ti* i_ptr = src + nc * I_HW; To* o_ptr = dst + nc * O_HW; const int start_h = cur_oh * sh - ph; const int bh = max(0, start_h); const int eh = min(ih, start_h + kh); - const int start_w = ow * sw - pw; + const int start_w = cur_ow * sw - pw; const int bw = max(0, start_w); const int ew = min(iw, start_w + kw); const To scale = 1. 
/ ((eh - bh) * (ew - bw)); From 49f09aa72c0338817f3c50271b8e420a4bf66251 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Tue, 30 Jan 2024 20:36:08 +0800 Subject: [PATCH 09/17] fix avg pooling, count_include_pad nits --- ggml-cuda.cu | 2 +- ggml.c | 4 ++-- ggml.h | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 20e3b5efa7e3a..5980e29bdd851 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6056,7 +6056,7 @@ static __global__ void pool2d_nchw_kernel( const int start_w = cur_ow * sw - pw; const int bw = max(0, start_w); const int ew = min(iw, start_w + kw); - const To scale = 1. / ((eh - bh) * (ew - bw)); + const To scale = 1. / (kh * kw); To res = 0; switch(op){ case GGML_OP_POOL_AVG: res = 0; break; diff --git a/ggml.c b/ggml.c index ccb1bfb5e2645..ad0a503fd7420 100644 --- a/ggml.c +++ b/ggml.c @@ -5569,8 +5569,8 @@ struct ggml_tensor * ggml_pool_2d( int k1, int s0, int s1, - float p0, - float p1) { + int p0, + int p1) { bool is_node = false; diff --git a/ggml.h b/ggml.h index cf568b3e58c0d..5a56f73b77d79 100644 --- a/ggml.h +++ b/ggml.h @@ -1600,8 +1600,8 @@ extern "C" { int k1, int s0, int s1, - float p0, - float p1); + int p0, + int p1); // nearest interpolate // used in stable-diffusion From 04f10a2287d9e099fe61911e16360b55fc56c48e Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 30 Jan 2024 14:03:49 +0100 Subject: [PATCH 10/17] test-backend-ops : add more pool_2d tests --- tests/test-backend-ops.cpp | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 45f6fc36a396d..7cfe74b386342 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -228,6 +228,14 @@ static std::string var_to_str(ggml_type type) { return ggml_type_name(type); } +static std::string var_to_str(ggml_op_pool pool) { + switch (pool) { + case GGML_OP_POOL_AVG: return "avg"; + case GGML_OP_POOL_MAX: return "max"; + default: return std::to_string(pool); + } +} + #define VARS_TO_STR1(a) VAR_TO_STR(a) #define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b) #define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c) @@ -1159,7 +1167,7 @@ struct test_pool2d : public test_case { return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1); } - test_pool2d(enum ggml_op_pool pool_type = GGML_OP_POOL_AVG, + test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG, ggml_type type_input = GGML_TYPE_F32, std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] int k0 = 3, int k1 = 3, @@ -1536,8 +1544,24 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } - test_cases.emplace_back(new test_pool2d(GGML_OP_POOL_AVG)); - test_cases.emplace_back(new test_pool2d(GGML_OP_POOL_MAX)); + for (ggml_type type_input : {GGML_TYPE_F16, GGML_TYPE_F32}) { + for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { + for (int k0 : {1, 3}) { + for (int k1 : {1, 3}) { + for (int s0 : {1, 2}) { + for (int s1 : {1, 2}) { + for (int p0 : {0, 1}) { + for (int p1 : {0, 1}) { + test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1)); + } + } + } + } + } + } + } + } + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 
10}, {1, 2, 1, 1})); From caf2fc829447696107a5b3625638ad57a56d878a Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 30 Jan 2024 14:05:35 +0100 Subject: [PATCH 11/17] cuda : fix warnings and formatting --- ggml-cuda.cu | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5980e29bdd851..f9899f613f880 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6041,8 +6041,10 @@ static __global__ void pool2d_nchw_kernel( const int ph, const int pw, const int parallel_elements, const Ti* src, To* dst, const enum ggml_op_pool op) { int idx = threadIdx.x + blockIdx.x * blockDim.x; - if(idx >= parallel_elements) + if (idx >= parallel_elements) { return; + } + const int I_HW = ih * iw; const int O_HW = oh * ow; const int nc = idx / O_HW; @@ -6058,12 +6060,14 @@ static __global__ void pool2d_nchw_kernel( const int ew = min(iw, start_w + kw); const To scale = 1. / (kh * kw); To res = 0; - switch(op){ + + switch (op) { case GGML_OP_POOL_AVG: res = 0; break; case GGML_OP_POOL_MAX: res = -FLT_MAX; break; } - for(int i = bh; i < eh; i += 1){ - for(int j = bw; j < ew; j += 1){ + + for(int i = bh; i < eh; i += 1) { + for(int j = bw; j < ew; j += 1) { #if __CUDA_ARCH__ >= 350 Ti cur = __ldg(i_ptr + i * iw + j); #else @@ -8741,11 +8745,10 @@ static void ggml_cuda_op_pool2d( dim3 block_nums(num_blocks); pool2d_nchw_kernel<<>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op); - (void) src0; - (void) src0_dd; + (void) src1; + (void) src1_dd; } - static void ggml_cuda_op_im2col( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { From bdf3b8ad70cbfd39015b13e72bb33470f49d1f1d Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 30 Jan 2024 14:49:55 +0100 Subject: [PATCH 12/17] ggml : check types in release builds too in pool_2d --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index ad0a503fd7420..f913c213e67ed 100644 --- a/ggml.c +++ b/ggml.c @@ -12790,8 +12790,8 @@ static void ggml_compute_forward_pool_2d( const struct ggml_compute_params * params, const struct ggml_tensor * src, struct ggml_tensor * dst) { - assert(src->type == GGML_TYPE_F32); - assert(params->ith == 0); + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(params->ith == 0); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; From 8824e42786c0eb17a3b786486abf9089db92708a Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 30 Jan 2024 14:50:26 +0100 Subject: [PATCH 13/17] test-backend-ops : remove f16 pool_2d tests --- tests/test-backend-ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7cfe74b386342..0d029a76589a0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1544,7 +1544,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } - for (ggml_type type_input : {GGML_TYPE_F16, GGML_TYPE_F32}) { + for (ggml_type type_input : {GGML_TYPE_F32}) { for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { for (int k0 : {1, 3}) { for (int k1 : {1, 3}) { From 0d94da7cbb881e3b46c2f9b5539fef48a7bea700 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 30 Jan 2024 14:52:54 +0100 Subject: [PATCH 14/17] cuda : more style fixes --- ggml-cuda.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu 
b/ggml-cuda.cu index f9899f613f880..ed79f6abf6218 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6066,14 +6066,14 @@ static __global__ void pool2d_nchw_kernel( case GGML_OP_POOL_MAX: res = -FLT_MAX; break; } - for(int i = bh; i < eh; i += 1) { - for(int j = bw; j < ew; j += 1) { + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { #if __CUDA_ARCH__ >= 350 Ti cur = __ldg(i_ptr + i * iw + j); #else Ti cur = i_ptr[i * iw + j]; #endif - switch(op){ + switch (op) { case GGML_OP_POOL_AVG: res += cur * scale; break; case GGML_OP_POOL_MAX: res = max(res, (To)cur); break; } @@ -8780,10 +8780,11 @@ static void ggml_cuda_op_im2col( const int64_t batch = src1->ne[3]; const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 - if(dst->type == GGML_TYPE_F16) + if(dst->type == GGML_TYPE_F16) { im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); - else + } else { im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + } (void) src0; (void) src0_dd; From ca4ec6d867db7ddc9f5f01f4eec0a9c64812c858 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Tue, 30 Jan 2024 22:14:50 +0800 Subject: [PATCH 15/17] Add assert in ggml_cuda_op_pool2d --- ggml-cuda.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ed79f6abf6218..00f64a59e19ed 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -8723,6 +8723,9 @@ static void ggml_cuda_op_pool2d( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + const int32_t * opts = (const int32_t *)dst->op_params; enum ggml_op_pool op = static_cast(opts[0]); const int k0 = opts[1]; From 66dd123b0f54bd1a6baaaa1c25c322c0132d2312 Mon Sep 17 00:00:00 2001 From: zhangjidong <1119708529@qq.com> Date: Wed, 31 Jan 2024 10:22:09 +0800 Subject: [PATCH 16/17] pool2d float padding fallback --- ggml.c | 4 ++-- ggml.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml.c b/ggml.c index f913c213e67ed..be17e6dcd3892 100644 --- a/ggml.c +++ b/ggml.c @@ -5569,8 +5569,8 @@ struct ggml_tensor * ggml_pool_2d( int k1, int s0, int s1, - int p0, - int p1) { + float p0, + float p1) { bool is_node = false; diff --git a/ggml.h b/ggml.h index 5a56f73b77d79..cf568b3e58c0d 100644 --- a/ggml.h +++ b/ggml.h @@ -1600,8 +1600,8 @@ extern "C" { int k1, int s0, int s1, - int p0, - int p1); + float p0, + float p1); // nearest interpolate // used in stable-diffusion From 18fd0b0ccc8816d3b73738bc123e49d928fa0ff7 Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 31 Jan 2024 14:06:55 +0100 Subject: [PATCH 17/17] test-backend-ops : add dst_type to im2col --- tests/test-backend-ops.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0d029a76589a0..e0ac78a5f380d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -247,6 +247,7 @@ static std::string var_to_str(ggml_op_pool pool) { #define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i) #define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j) #define VARS_TO_STR11(a, b, c, 
d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k) +#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l) // accept FLT_MAX as infinity @@ -1186,6 +1187,7 @@ struct test_pool2d : public test_case { struct test_im2col : public test_case { const ggml_type type_input; const ggml_type type_kernel; + const ggml_type dst_type; const std::array ne_input; const std::array ne_kernel; // stride @@ -1201,22 +1203,22 @@ struct test_im2col : public test_case { const bool is_2D; std::string vars() override { - return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D); + return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D); } - test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, + test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32, std::array ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1] std::array ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1] int s0 = 1, int s1 = 1, int p0 = 1, int p1 = 1, int d0 = 1, int d1 = 1, bool is_2D = true) - : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {} + : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); - ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F16); + ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type); return out; } }; @@ -1562,6 +1564,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1})); @@ -1694,7 +1699,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } test_cases.emplace_back(new test_alibi()); - test_cases.emplace_back(new test_im2col()); test_cases.emplace_back(new test_concat(GGML_TYPE_F32)); test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
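To close the series out, here is a minimal sketch (not part of the patches; `ctx` and a float `input` tensor are assumed) of the pool2d entry point these new tests exercise; on a CUDA build it dispatches through `ggml_cuda_op_pool2d` to `pool2d_nchw_kernel`:

```c
// Average-pool a [W, H, C, N] F32 tensor with a 3x3 kernel, stride 1, pad 1.
// Per the final ggml.h in this series, k0/k1 and s0/s1 are int while the
// padding p0/p1 is float (restored by patch 16); GGML_OP_POOL_MAX takes the
// same arguments.
struct ggml_tensor * pooled = ggml_pool_2d(ctx, input,
                                           GGML_OP_POOL_AVG,
                                           3, 3,         // k0, k1
                                           1, 1,         // s0, s1
                                           1.0f, 1.0f);  // p0, p1
```

The added cases in `tests/test-backend-ops.cpp` sweep both pooling modes over kernel, stride, and padding combinations, and check both F16 and F32 im2col destination types against the CPU reference.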