[GPU] optimize sdpa_opt kernel softmax.
ceciliapeng2011 committed Dec 11, 2024
1 parent 2d78f2a commit bf69a35
Showing 1 changed file with 46 additions and 80 deletions.
126 changes: 46 additions & 80 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
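Note on the change: the diff below reworks the softmax stage of the SDPA kernel from a two-pass, whole-partition reduction (block-wide qk_max/exp_sum aggregation through flat slm_qk_max_vals/slm_exp_sum_vals buffers, followed by dividing qk_acc inside each partition) into a per-query-row online softmax. Each subgroup now owns one query row, reduces its row max and exp-sum with sub_group_reduce_max/sub_group_reduce_add, carries slm_max_val_prev/slm_exp_sum_prev across partitions, rescales the previously accumulated output by slm_update_factor = exp(max_prev - max_new), and divides by the total exp-sum only once, right before the final OUTPUT_BLOCK_WRITE. The plain-C sketch below is only an illustration of that update rule, not the kernel itself; its function and variable names (online_softmax_partition, running_max, running_sum, acc, PARTITION) are invented for the example.

/* Minimal scalar sketch of the online-softmax update used by the reworked kernel. */
#include <math.h>
#include <stdio.h>

#define PARTITION 4   /* stand-in for SEQ_LEN_PARTITION_SIZE */

/* Process one partition of attention scores for a single query row.
 * acc is the unnormalized output accumulator, v the value entries. */
static void online_softmax_partition(const float *scores, const float *v,
                                     float *running_max, float *running_sum,
                                     float *acc) {
    /* 1. Row max over this partition, merged with the previous running max. */
    float new_max = *running_max;
    for (int k = 0; k < PARTITION; ++k)
        new_max = fmaxf(new_max, scores[k]);

    /* 2. Correction factor rescales everything accumulated so far
     *    (kernel: slm_update_factor = native_exp(max_val_prev - qk_max_new)). */
    float correction = expf(*running_max - new_max);
    *acc *= correction;
    float exp_sum = *running_sum * correction;

    /* 3. Accumulate unnormalized probabilities and the partial output. */
    for (int k = 0; k < PARTITION; ++k) {
        float p = expf(scores[k] - new_max);
        exp_sum += p;
        *acc += p * v[k];
    }

    *running_max = new_max;
    *running_sum = exp_sum;
}

int main(void) {
    float scores[2][PARTITION] = {{0.1f, 0.7f, -0.3f, 0.2f},
                                  {1.5f, 0.4f, 0.9f, -1.0f}};
    float values[2][PARTITION] = {{1.0f, 2.0f, 3.0f, 4.0f},
                                  {5.0f, 6.0f, 7.0f, 8.0f}};
    float running_max = -INFINITY, running_sum = 0.0f, acc = 0.0f;

    for (int p = 0; p < 2; ++p)
        online_softmax_partition(scores[p], values[p],
                                 &running_max, &running_sum, &acc);

    /* Single normalization at the very end, as the kernel now does
     * just before OUTPUT_BLOCK_WRITE. */
    printf("output = %f\n", acc / running_sum);
    return 0;
}

Keeping the probabilities unnormalized until the final write is what appears to allow the kernel to drop the block-wide exp-sum aggregation and replace the per-partition rescaling ratios with a single correction factor per row.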
@@ -868,18 +868,17 @@ KERNEL(sdpa_opt)(
__local INPUT0_TYPE slm_query[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];

// SLM buffer for intermediate QK results
__local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
__local OUTPUT_TYPE slm_qk_vals[TARGET_SEQ_LEN_BLOCK_SIZE][SEQ_LEN_PARTITION_SIZE];

// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_vals[SUBGROUPS_PER_WG * TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[TARGET_SEQ_LEN_BLOCK_SIZE][SUBGROUPS_PER_WG];

// SLM buffers for SoftMax recalculation for current iteration based on the previous results
__local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_cur[TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_max_val_cur[TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sum_prev[TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_max_val_prev[TARGET_SEQ_LEN_BLOCK_SIZE];
__local SOFTMAX_ACCUMULATOR_TYPE slm_update_factor[TARGET_SEQ_LEN_BLOCK_SIZE];

const uint query_len = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
{
// Load Q input to SLM and transpose it
#if IS_PAGED_ATTENTION
@@ -985,8 +984,6 @@ KERNEL(sdpa_opt)(

__attribute__((opencl_unroll_hint(1)))
for (uint start_partition_idx = 0; start_partition_idx < SOURCE_SEQ_LEN; start_partition_idx += SEQ_LEN_PARTITION_SIZE) {
SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;

const uint seq_len = start_partition_idx + sgid * SUBGROUP_SIZE;
const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE);

Expand Down Expand Up @@ -1149,6 +1146,7 @@ KERNEL(sdpa_opt)(
}

{
SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
#if !APPLY_SCALES_TO_QUERY
#if HAS_SCALE_INPUT
@@ -1167,63 +1165,50 @@
qk_acc[i] = INPUT0_MIN_FUNC(INPUT0_MAX_FUNC(qk_acc[i], INPUT0_VAL_MIN), INPUT0_VAL_MAX);

qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i]));
slm_qk_vals[sglid][sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i];
}
}

{
slm_qk_max_vals[sgid * SUBGROUP_SIZE + sglid] = qk_max;
qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
slm_qk_max_vals[sglid][sgid] = qk_max;
}

barrier(CLK_LOCAL_MEM_FENCE);

{
// SoftMax calculation
SOFTMAX_ACCUMULATOR_TYPE qk_max_new = SOFTMAX_ACCUMULATOR_VAL_MIN;

for (uint i = 0; i < SUBGROUPS_PER_WG; i++) {
SOFTMAX_ACCUMULATOR_TYPE qk_max_val = slm_qk_max_vals[i * SUBGROUP_SIZE + sglid];
qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max_new, qk_max_val);
}

if (sgid == 0) {
slm_max_val_cur[sglid] = qk_max_new;
}

SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO;

for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
qk_acc[i] = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i]) - qk_max_new);
exp_sum_new += qk_acc[i];
}

{
slm_exp_sum_vals[sgid * SUBGROUP_SIZE + sglid] = exp_sum_new;
}

exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO;

barrier(CLK_LOCAL_MEM_FENCE);

for (uint i = 0; i < SUBGROUPS_PER_WG; i++) {
SOFTMAX_ACCUMULATOR_TYPE exp_sum = slm_exp_sum_vals[i * SUBGROUP_SIZE + sglid];
exp_sum_new += exp_sum;
}

for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
qk_acc[i] = qk_acc[i] / exp_sum_new;
}
// each sg will compute a whole row of query
const int key_len_in_kv_block = SEQ_LEN_PARTITION_SIZE;
for (uint m = sgid; m < query_len; m += SUBGROUPS_PER_WG) {
// rowmax
SOFTMAX_ACCUMULATOR_TYPE qk_max_new;
if (sglid < SUBGROUPS_PER_WG) { // TODO: if SUBGROUPS_PER_WG > 16?
qk_max_new = slm_qk_max_vals[m][sglid];
} else {
qk_max_new = SOFTMAX_ACCUMULATOR_VAL_MIN;
}
qk_max_new = sub_group_reduce_max(qk_max_new);
SOFTMAX_ACCUMULATOR_TYPE max_val_prev = slm_max_val_prev[m];
qk_max_new = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max_new, max_val_prev);

// softmax
SOFTMAX_ACCUMULATOR_TYPE exp_sum_new = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint k = sglid; k < key_len_in_kv_block; k += SUBGROUP_SIZE) { // FIXME key_len_in_kv_block
SOFTMAX_ACCUMULATOR_TYPE a = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[m][k]) - qk_max_new);
slm_qk_vals[m][k] = convert_half(a);
exp_sum_new += a;
}
exp_sum_new = sub_group_reduce_add(exp_sum_new);

if (sgid == 0) {
slm_exp_sum_cur[sglid] = exp_sum_new;
}
// update
float pre_exp_sum = slm_exp_sum_prev[m];
float correction_factor = native_exp(max_val_prev - qk_max_new);
float pre_exp_sum_fixed = pre_exp_sum * correction_factor;
exp_sum_new += pre_exp_sum_fixed;

for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
slm_qk_vals[sglid * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i];
slm_update_factor[m] = correction_factor;
slm_max_val_prev[m] = qk_max_new;
slm_exp_sum_prev[m] = exp_sum_new;
}

barrier(CLK_LOCAL_MEM_FENCE);
}
barrier(CLK_LOCAL_MEM_FENCE);

{
// QK*V calculation
@@ -1270,7 +1255,7 @@ KERNEL(sdpa_opt)(

MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
qk_val[seq_idx] = slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len + sglid];
qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len + sglid];
}

#if IS_KV_COMPRESSED
@@ -1346,7 +1331,7 @@ KERNEL(sdpa_opt)(

MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
qk_val[seq_idx] = slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + seq_len * SUBGROUP_SIZE + sglid];
qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len * SUBGROUP_SIZE + sglid];
}

unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) {
@@ -1376,11 +1361,9 @@ KERNEL(sdpa_opt)(
// QK*V leftovers processing
const uint seq_len_leftovers_start = ((seq_len_end / SUBGROUP_SIZE) * SUBGROUP_SIZE);
if (seq_len_leftovers_start != seq_len_end) {
uint qk_offset = min(seq_len_leftovers_start + sglid, seq_len_end - 1);
MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_val;
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
qk_val[seq_idx] = slm_qk_vals[qk_offset];
qk_offset += SEQ_LEN_PARTITION_SIZE;
qk_val[seq_idx] = slm_qk_vals[seq_idx][seq_len_leftovers_start+sglid];
}
#if IS_PAGED_ATTENTION
#ifdef BROADCAST_GROUP_SIZE
@@ -1445,13 +1428,6 @@ KERNEL(sdpa_opt)(

{
// Rescale acc_output_res values and save current iter results to global accumulator
SOFTMAX_ACCUMULATOR_TYPE exp_sum_prev = slm_exp_sum_prev[sglid];
SOFTMAX_ACCUMULATOR_TYPE exp_sum_cur = slm_exp_sum_cur[sglid];
SOFTMAX_ACCUMULATOR_TYPE max_val_prev = slm_max_val_prev[sglid];
SOFTMAX_ACCUMULATOR_TYPE max_val_cur = slm_max_val_cur[sglid];

barrier(CLK_LOCAL_MEM_FENCE);

#if IS_PAGED_ATTENTION
const uint block_start_pos_new = blocked_indexes_start[target_seq_dim];
const uint block_end_pos_new = blocked_indexes_end[target_seq_dim];
@@ -1461,23 +1437,11 @@ KERNEL(sdpa_opt)(
#endif

for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
SOFTMAX_ACCUMULATOR_TYPE total_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(sub_group_broadcast(max_val_prev, seq_idx), sub_group_broadcast(max_val_cur, seq_idx));
SOFTMAX_ACCUMULATOR_TYPE updated_exp_sum_prev = sub_group_broadcast(exp_sum_prev, seq_idx) * native_exp(sub_group_broadcast(max_val_prev, seq_idx) - total_max);
SOFTMAX_ACCUMULATOR_TYPE updated_exp_sum_cur = sub_group_broadcast(exp_sum_cur, seq_idx) * native_exp(sub_group_broadcast(max_val_cur, seq_idx) - total_max);
SOFTMAX_ACCUMULATOR_TYPE updated_total_exp_sum = updated_exp_sum_prev + updated_exp_sum_cur;

if (start_partition_idx > 0) {
OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * updated_exp_sum_prev / updated_total_exp_sum;;
acc_output_res[seq_idx] *= updated_exp_sum_cur / updated_total_exp_sum;
OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * slm_update_factor[seq_idx];
acc_output_res[seq_idx] += updated_prev_res;
}

output_acc[seq_idx] = acc_output_res[seq_idx];

if (sgid == 0 && sglid == 0) {
slm_exp_sum_prev[seq_idx] = updated_total_exp_sum;
slm_max_val_prev[seq_idx] = total_max;
}
}
}
}
@@ -1487,7 +1451,7 @@ KERNEL(sdpa_opt)(

if (sgid >= (SUBGROUPS_PER_WG / SG_SCALE_FACTOR)) {
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + (uint)get_local_id(2)] = output_acc[seq_idx];
slm_qk_vals[seq_idx][(uint)get_local_id(2)] = output_acc[seq_idx];
}
}

Expand All @@ -1496,7 +1460,7 @@ KERNEL(sdpa_opt)(
if (sgid < (SUBGROUPS_PER_WG / SG_SCALE_FACTOR)) {
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
unroll_for (uint i = 1; i < SG_SCALE_FACTOR; i++) {
output_acc[seq_idx] += slm_qk_vals[seq_idx * SEQ_LEN_PARTITION_SIZE + (i * HEAD_SIZE) + head_size_idx];
output_acc[seq_idx] += slm_qk_vals[seq_idx][(i * HEAD_SIZE) + head_size_idx];
}
}

Expand All @@ -1519,11 +1483,13 @@ KERNEL(sdpa_opt)(
const uint seq_idx_end = min((uint)TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
#endif
for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx];
OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]);
output_offset += output_pitch;
}
} else {
unroll_for (uint seq_idx = 0; seq_idx < TARGET_SEQ_LEN_BLOCK_SIZE; seq_idx++) {
output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx];
OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]);
output_offset += output_pitch;
}
