Commit: test
ggerganov committed Nov 30, 2023
1 parent 580fe20 · commit fd70e7a
Showing 1 changed file with 69 additions and 5 deletions.
ggml-cuda.cu: 74 changes (69 additions & 5 deletions)
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 256
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -4719,11 +4720,12 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int

 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void soft_max_f32_warp(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid = threadIdx.x;
+    const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+
+    const int block_size = blockDim.x;
 
     float max_val = -INFINITY;
 
@@ -4763,6 +4765,66 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
 }
 }
 
+// use shared memory to reduce the number of global memory reads
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
+
+    buf[tid] = -INFINITY;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
+    }
+
+    __syncthreads();
+
+    // find the max value in the block
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] = max(buf[tid], buf[tid + i]);
+        }
+        __syncthreads();
+    }
+
+    float tmp = 0.f;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
+        tmp += val;
+        dst[ix] = val;
+    }
+
+    __syncthreads();
+
+    buf[tid] = tmp;
+
+    __syncthreads();
+
+    // sum up partial sums
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] += buf[tid + i];
+        }
+        __syncthreads();
+    }
+
+    const float inv_tmp = 1.f / buf[0];
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = rowx*ncols + col;
+        dst[i] *= inv_tmp;
+    }
+}
+
 static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;

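For context (not part of the commit), a minimal CPU sketch of what the new soft_max_f32 kernel computes for each row: scale the logits, add the optional mask row, subtract the row maximum for numerical stability, exponentiate, and normalize. The helper name soft_max_row_ref is hypothetical.

// hypothetical CPU reference for a single row, not part of ggml-cuda.cu
#include <cmath>

static void soft_max_row_ref(const float * x, const float * mask, float * dst, const int ncols, const float scale) {
    // row maximum of the scaled, masked logits (mask may be null)
    float max_val = -INFINITY;
    for (int col = 0; col < ncols; ++col) {
        max_val = std::fmax(max_val, x[col]*scale + (mask ? mask[col] : 0.0f));
    }

    // exponentiate with the max subtracted for numerical stability, accumulate the sum
    float sum = 0.0f;
    for (int col = 0; col < ncols; ++col) {
        const float val = std::exp((x[col]*scale + (mask ? mask[col] : 0.0f)) - max_val);
        dst[col] = val;
        sum += val;
    }

    // normalize so the row sums to 1
    const float inv_sum = 1.0f/sum;
    for (int col = 0; col < ncols; ++col) {
        dst[col] *= inv_sum;
    }
}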
@@ -5796,7 +5858,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
 }
 
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
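A note on the launch configuration above (my reading of the change, not stated in the commit): one block is launched per row, and the block size nth starts at one warp and doubles until it covers ncols_x or reaches CUDA_SOFT_MAX_BLOCK_SIZE. Keeping nth a power of two is what makes the kernel's halving reductions valid, and the cap matches the size of the shared buf array. A standalone sketch of that choice, assuming WARP_SIZE is 32 (its definition is outside this diff):

// hypothetical mirror of the launcher's block-size choice
static int soft_max_block_size(const int ncols_x) {
    int nth = 32;                                // WARP_SIZE, assumed to be 32
    while (nth < ncols_x && nth < 256) {         // 256 == CUDA_SOFT_MAX_BLOCK_SIZE
        nth *= 2;                                // powers of two keep the tree reductions valid
    }
    return nth;                                  // e.g. ncols_x = 10 -> 32, 100 -> 128, 4096 -> 256
}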