Commit

sdpa_opt kernel: seq_idx_end is frequently used in multiple stages; move it ahead and make it visible to all stages.
ceciliapeng2011 committed Dec 11, 2024
1 parent bf69a35 commit e1af9ed
Showing 1 changed file with 22 additions and 36 deletions.
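Note: the same per-block row count was previously computed independently in three places, as query_len before the Q-load stage, as cur_target_seq_len_size inside it, and as a locally scoped seq_idx_end in the rescale and output stages. This commit hoists a single definition to the top of the kernel body so that every stage reads the same value. The hoisted block, as it appears in the diff below:

    #if IS_PAGED_ATTENTION
        // Paged attention: the row count comes from the per-block index tables.
        const uint block_start_pos = blocked_indexes_start[target_seq_dim];
        const uint block_end_pos = blocked_indexes_end[target_seq_dim];
        const uint seq_idx_end = block_end_pos - block_start_pos;
    #else
        // Regular SDPA: clamp the last, possibly partial, block of the target sequence.
        const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
    #endif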
58 changes: 22 additions & 36 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -878,18 +878,20 @@ KERNEL(sdpa_opt)(
     __local SOFTMAX_ACCUMULATOR_TYPE slm_max_val_prev[TARGET_SEQ_LEN_BLOCK_SIZE];
     __local SOFTMAX_ACCUMULATOR_TYPE slm_update_factor[TARGET_SEQ_LEN_BLOCK_SIZE];
 
-    const uint query_len = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
+#if IS_PAGED_ATTENTION
+    const uint block_start_pos = blocked_indexes_start[target_seq_dim];
+    const uint block_end_pos = blocked_indexes_end[target_seq_dim];
+    const uint seq_idx_end = block_end_pos - block_start_pos;
+#else
+    const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
+#endif
     {
         // Load Q input to SLM and transpose it
 #if IS_PAGED_ATTENTION
-        const uint block_start_pos = blocked_indexes_start[target_seq_dim];
-        const uint block_end_pos = blocked_indexes_end[target_seq_dim];
         uint query_offset = INPUT0_OFFSET +
                             block_start_pos * (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) +
                             num_heads_dim * HEAD_SIZE + head_size_idx;
         const uint query_pitch = (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM);
-
-        const uint cur_target_seq_len_size = block_end_pos - block_start_pos;
 #else
 #ifdef INPUT0_DIMS_ORDER
         uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx, (head_size_idx));
@@ -899,7 +901,6 @@ KERNEL(sdpa_opt)(
         uint query_offset = INPUT0_GET_INDEX(b0_idx, b1_idx, target_seq_idx, (head_size_idx));
         const uint query_pitch = HEAD_SIZE;
 #endif
-        const uint cur_target_seq_len_size = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
 #endif
         uint query_local_offset = head_size_idx * TARGET_SEQ_LEN_BLOCK_SIZE;
 
@@ -913,9 +914,9 @@ KERNEL(sdpa_opt)(
         const INPUT0_TYPE scale_val = INPUT0_VAL_ONE;
 #endif
 
-        if (cur_target_seq_len_size != TARGET_SEQ_LEN_BLOCK_SIZE) {
+        if (seq_idx_end != TARGET_SEQ_LEN_BLOCK_SIZE) {
             if (sgid * SUBGROUP_SIZE < HEAD_SIZE) {
-                for (uint seq_idx = 0; seq_idx < cur_target_seq_len_size; seq_idx++) {
+                for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
                     INPUT0_TYPE val = BLOCK_READN(INPUT0_TYPE, 1, query_input, query_offset);
 
                     slm_query[query_local_offset] = val * scale_val;
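With seq_idx_end in scope here, the Q-load stage reuses it both to detect a partial block and as the bound of the per-row fallback loop. A partial block occurs when the target sequence length is not a multiple of TARGET_SEQ_LEN_BLOCK_SIZE; for example, with TARGET_SEQ_LEN = 100 and a block size of 16, the final block covers rows 96..99, so seq_idx_end = min(100 - 96, 16) = 4 there.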
@@ -1023,7 +1024,7 @@ KERNEL(sdpa_opt)(
                 b0_idx,
                 b1_idx,
 #if IS_PAGED_ATTENTION
-                blocked_indexes_start[target_seq_dim] - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid,
+                block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid,
 #else
                 target_seq_idx + sglid,
 #endif
@@ -1176,7 +1177,7 @@ KERNEL(sdpa_opt)(
         // SoftMax calculation
         // each sg will compute a whole row of query
         const int key_len_in_kv_block = SEQ_LEN_PARTITION_SIZE;
-        for (uint m = sgid; m < query_len; m += SUBGROUPS_PER_WG) {
+        for (uint m = sgid; m < seq_idx_end; m += SUBGROUPS_PER_WG) {
             // rowmax
             SOFTMAX_ACCUMULATOR_TYPE qk_max_new;
             if (sglid < SUBGROUPS_PER_WG) { // TODO: if SUBGROUPS_PER_WG > 16?
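In the non-paged path this is a pure rename: query_len was computed from exactly the same expression as the hoisted seq_idx_end. In the paged path the softmax loop bound is now derived from the block index tables (block_end_pos - block_start_pos) rather than from TARGET_SEQ_LEN, matching what the rescale and output stages already use.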
@@ -1198,14 +1199,16 @@ KERNEL(sdpa_opt)(
             exp_sum_new = sub_group_reduce_add(exp_sum_new);
 
             // update
-            float pre_exp_sum = slm_exp_sum_prev[m];
-            float correction_factor = native_exp(max_val_prev - qk_max_new);
-            float pre_exp_sum_fixed = pre_exp_sum * correction_factor;
-            exp_sum_new += pre_exp_sum_fixed;
-
-            slm_update_factor[m] = correction_factor;
-            slm_max_val_prev[m] = qk_max_new;
-            slm_exp_sum_prev[m] = exp_sum_new;
+            if (sglid == 0) {
+                float pre_exp_sum = slm_exp_sum_prev[m];
+                float correction_factor = native_exp(max_val_prev - qk_max_new);
+                float pre_exp_sum_fixed = pre_exp_sum * correction_factor;
+                exp_sum_new += pre_exp_sum_fixed;
+
+                slm_update_factor[m] = correction_factor;
+                slm_max_val_prev[m] = qk_max_new;
+                slm_exp_sum_prev[m] = exp_sum_new;
+            }
         }
     }
     barrier(CLK_LOCAL_MEM_FENCE);
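The guarded block is the running ("online") softmax update used by flash-attention-style kernels: when a later key partition raises a row's maximum, the exponential sum accumulated so far is rescaled by exp(max_prev - max_new) before the current partition's sum is added, and the same factor is stored in slm_update_factor so that previously accumulated outputs can be rescaled as well. After the sub_group_reduce_* calls every lane of the subgroup holds identical values, so restricting the stores to sglid == 0 drops redundant SLM writes without changing the result. A scalar C sketch of the per-row update (illustrative names only; expf stands in for native_exp, and exp_sum_cur is assumed to already be computed against the new maximum, as the kernel arranges):

    #include <math.h>

    /* Per-row running-softmax state, mirroring slm_max_val_prev / slm_exp_sum_prev. */
    typedef struct { float max_val; float exp_sum; } row_state;

    /* Fold one partition's row max and exp-sum into the running state; the
     * returned factor is what the kernel keeps in slm_update_factor[m]. */
    static float update_row_state(row_state *s, float row_max_cur, float exp_sum_cur) {
        float max_new = fmaxf(s->max_val, row_max_cur);   /* new running maximum */
        float correction_factor = expf(s->max_val - max_new);
        s->exp_sum = exp_sum_cur + s->exp_sum * correction_factor;
        s->max_val = max_new;
        return correction_factor;
    }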
@@ -1428,14 +1431,6 @@ KERNEL(sdpa_opt)(

     {
         // Rescale acc_output_res values and save current iter results to global accumulator
-#if IS_PAGED_ATTENTION
-        const uint block_start_pos_new = blocked_indexes_start[target_seq_dim];
-        const uint block_end_pos_new = blocked_indexes_end[target_seq_dim];
-        const uint seq_idx_end = block_end_pos_new - block_start_pos_new;
-#else
-        const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
-#endif
-
         for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
             if (start_partition_idx > 0) {
                 OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * slm_update_factor[seq_idx];
@@ -1465,23 +1460,14 @@ KERNEL(sdpa_opt)(
     }
 
 #if IS_PAGED_ATTENTION
-    const uint block_start_pos_new = blocked_indexes_start[target_seq_dim];
-    const uint block_end_pos_new = blocked_indexes_end[target_seq_dim];
-
     uint output_offset = block_start_pos_new * HEAD_SIZE * NUM_HEADS + num_heads_dim * HEAD_SIZE + sgid * SUBGROUP_SIZE;
     const uint output_pitch = HEAD_SIZE * NUM_HEADS;
 #else
     uint output_offset = OUTPUT_GET_INDEX(b0_idx, b1_idx, target_seq_idx, sgid * SUBGROUP_SIZE);
     const uint output_pitch = HEAD_SIZE;
 #endif
 
-#if IS_PAGED_ATTENTION
-    if (block_start_pos_new + TARGET_SEQ_LEN_BLOCK_SIZE != block_end_pos_new) {
-        const uint seq_idx_end = block_end_pos_new - block_start_pos_new;
-#else
-    if (get_global_id(1) == get_global_size(1) - 1) {
-        const uint seq_idx_end = min((uint)TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
-#endif
+    if (TARGET_SEQ_LEN_BLOCK_SIZE > seq_idx_end) {
         for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
             output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx];
             OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]);
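With the hoisted value in scope, the tail-of-sequence check collapses to one comparison: TARGET_SEQ_LEN_BLOCK_SIZE > seq_idx_end holds exactly when the block is partial, in both the paged and non-paged paths. (The old non-paged variant entered this path for the last block even when it was full; the new condition takes it only when the block genuinely has fewer rows, leaving full blocks to what is presumably the block-write path that follows.) The division by slm_exp_sum_prev is the online-softmax finalization: each accumulated output row is normalized once, at write-out, by the total exponential sum gathered across all key partitions.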