Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : add ggml_soft_max_ext #4256

Merged
merged 14 commits into from
Dec 1, 2023
Merged
2 changes: 1 addition & 1 deletion examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
Expand Down
35 changes: 21 additions & 14 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int

// the CUDA soft max implementation differs from the CPU implementation
// instead of doubles floats are used
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.x*blockIdx.x + threadIdx.x;
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = blockDim.y;
const int tid = threadIdx.y;

float max_val = -INFINITY;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
max_val = max(max_val, x[i]);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
}

// find the max value in the block
Expand All @@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
float tmp = 0.f;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const float val = expf(x[i] - max_val);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
dst[i] = val;
dst[ix] = val;
}

// sum up partial sums
Expand All @@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
const float inv_tmp = 1.f / tmp;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
}
}
Expand Down Expand Up @@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
const dim3 block_dims(1, WARP_SIZE, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}

static void im2col_f32_f16_cuda(const float * x, half * dst,
Expand Down Expand Up @@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
ggerganov marked this conversation as resolved.
Show resolved Hide resolved

soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));

soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_scale(
Expand Down
15 changes: 10 additions & 5 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1036,11 +1036,16 @@ void ggml_metal_graph_compute(
nth /= 2;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];

const float scale = ((float *) dst->op_params)[0];

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
Expand Down
26 changes: 16 additions & 10 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,12 @@ kernel void kernel_gelu(

kernel void kernel_soft_max(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
Expand All @@ -194,14 +196,15 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

float max = simd_max(lmax);
Expand All @@ -225,7 +228,7 @@ kernel void kernel_soft_max(
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// wish to compute it twice.
Expand Down Expand Up @@ -257,10 +260,12 @@ kernel void kernel_soft_max(

kernel void kernel_soft_max_4(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
Expand All @@ -271,14 +276,15 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
Expand All @@ -303,7 +309,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
Expand Down
76 changes: 60 additions & 16 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -4826,7 +4826,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale,
bool inplace) {
GGML_ASSERT(ggml_is_contiguous(a));
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
}

bool is_node = false;

if (a->grad) {
Expand All @@ -4835,23 +4845,35 @@ static struct ggml_tensor * ggml_soft_max_impl(

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

float params[] = { scale };
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = mask;

return result;
}

struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, false);
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
}

struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, true);
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
}

struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale) {
return ggml_soft_max_impl(ctx, a, mask, scale, false);
}

// ggml_soft_max_back
Expand Down Expand Up @@ -10551,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
static void ggml_compute_forward_soft_max_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

float scale = 1.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));

// TODO: handle transposed/permuted matrices

const int ith = params->ith;
const int nth = params->nth;

const int64_t ne11 = src1 ? src1->ne[1] : 1;

const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);

Expand All @@ -10575,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

for (int i1 = ir0; i1 < ir1; i1++) {
float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);

// broadcast the mask across rows
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;

float * wp = wdata;
for (int i = 0; i < nc; i++) {
wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
}

#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(sp[i]));
assert(!isnan(wp[i]));
}
#endif

float max = -INFINITY;
ggml_vec_max_f32(nc, &max, sp);
ggml_vec_max_f32(nc, &max, wp);

ggml_float sum = 0.0;

uint16_t scvt;
for (int i = 0; i < nc; i++) {
if (sp[i] == -INFINITY) {
if (wp[i] == -INFINITY) {
dp[i] = 0.0f;
} else {
// const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
// const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
memcpy(&scvt, &s, sizeof(scvt));
const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
sum += (ggml_float)val;
Expand All @@ -10622,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32(
static void ggml_compute_forward_soft_max(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_soft_max_f32(params, src0, dst);
ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
} break;
default:
{
Expand Down Expand Up @@ -13863,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_SOFT_MAX:
{
ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SOFT_MAX_BACK:
{
Expand Down Expand Up @@ -15899,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
} break;
case GGML_OP_SOFT_MAX:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));

cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
GGML_ASSERT(node->src[0]->ne[3] == 1);
Expand Down
8 changes: 8 additions & 0 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// fused soft_max(a*scale + mask)
// mask is optional
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale);

GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
Expand Down
Loading
Loading