From e1af9ed88312ab35f6ac707fd2fbaafc4aba7825 Mon Sep 17 00:00:00 2001 From: ceciliapeng2011 Date: Wed, 11 Dec 2024 15:45:23 +0800 Subject: [PATCH] sdpa_opt kernel: seq_idx_end is frequently used in multiple stages... move it ahead and visible to all stages. --- .../kernel_selector/cl_kernels/sdpa_opt.cl | 58 +++++++------------ 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index 177af024e8a722..e78c2f04dc5ead 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -878,18 +878,20 @@ KERNEL(sdpa_opt)( __local SOFTMAX_ACCUMULATOR_TYPE slm_max_val_prev[TARGET_SEQ_LEN_BLOCK_SIZE]; __local SOFTMAX_ACCUMULATOR_TYPE slm_update_factor[TARGET_SEQ_LEN_BLOCK_SIZE]; - const uint query_len = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#if IS_PAGED_ATTENTION + const uint block_start_pos = blocked_indexes_start[target_seq_dim]; + const uint block_end_pos = blocked_indexes_end[target_seq_dim]; + const uint seq_idx_end = block_end_pos - block_start_pos; +#else + const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); +#endif { // Load Q input to SLM and transpose it #if IS_PAGED_ATTENTION - const uint block_start_pos = blocked_indexes_start[target_seq_dim]; - const uint block_end_pos = blocked_indexes_end[target_seq_dim]; uint query_offset = INPUT0_OFFSET + block_start_pos * (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) + num_heads_dim * HEAD_SIZE + head_size_idx; const uint query_pitch = (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM); - - const uint cur_target_seq_len_size = block_end_pos - block_start_pos; #else #ifdef INPUT0_DIMS_ORDER uint query_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, target_seq_idx, (head_size_idx)); @@ -899,7 +901,6 @@ KERNEL(sdpa_opt)( uint query_offset = INPUT0_GET_INDEX(b0_idx, b1_idx, target_seq_idx, (head_size_idx)); const uint query_pitch = HEAD_SIZE; #endif - const uint cur_target_seq_len_size = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); #endif uint query_local_offset = head_size_idx * TARGET_SEQ_LEN_BLOCK_SIZE; @@ -913,9 +914,9 @@ KERNEL(sdpa_opt)( const INPUT0_TYPE scale_val = INPUT0_VAL_ONE; #endif - if (cur_target_seq_len_size != TARGET_SEQ_LEN_BLOCK_SIZE) { + if (seq_idx_end != TARGET_SEQ_LEN_BLOCK_SIZE) { if (sgid * SUBGROUP_SIZE < HEAD_SIZE) { - for (uint seq_idx = 0; seq_idx < cur_target_seq_len_size; seq_idx++) { + for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { INPUT0_TYPE val = BLOCK_READN(INPUT0_TYPE, 1, query_input, query_offset); slm_query[query_local_offset] = val * scale_val; @@ -1023,7 +1024,7 @@ KERNEL(sdpa_opt)( b0_idx, b1_idx, #if IS_PAGED_ATTENTION - blocked_indexes_start[target_seq_dim] - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid, + block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid, #else target_seq_idx + sglid, #endif @@ -1176,7 +1177,7 @@ KERNEL(sdpa_opt)( // SoftMax calculation // each sg will compute a whole row of query const int key_len_in_kv_block = SEQ_LEN_PARTITION_SIZE; - for (uint m = sgid; m < query_len; m += SUBGROUPS_PER_WG) { + for (uint m = sgid; m < seq_idx_end; m += SUBGROUPS_PER_WG) { // rowmax SOFTMAX_ACCUMULATOR_TYPE qk_max_new; if (sglid < SUBGROUPS_PER_WG) { // TODO: if SUBGROUPS_PER_WG > 16? @@ -1198,14 +1199,16 @@ KERNEL(sdpa_opt)( exp_sum_new = sub_group_reduce_add(exp_sum_new); // update - float pre_exp_sum = slm_exp_sum_prev[m]; - float correction_factor = native_exp(max_val_prev - qk_max_new); - float pre_exp_sum_fixed = pre_exp_sum * correction_factor; - exp_sum_new += pre_exp_sum_fixed; - - slm_update_factor[m] = correction_factor; - slm_max_val_prev[m] = qk_max_new; - slm_exp_sum_prev[m] = exp_sum_new; + if (sglid == 0) { + float pre_exp_sum = slm_exp_sum_prev[m]; + float correction_factor = native_exp(max_val_prev - qk_max_new); + float pre_exp_sum_fixed = pre_exp_sum * correction_factor; + exp_sum_new += pre_exp_sum_fixed; + + slm_update_factor[m] = correction_factor; + slm_max_val_prev[m] = qk_max_new; + slm_exp_sum_prev[m] = exp_sum_new; + } } } barrier(CLK_LOCAL_MEM_FENCE); @@ -1428,14 +1431,6 @@ KERNEL(sdpa_opt)( { // Rescale acc_output_res values and save current iter results to global accumulator -#if IS_PAGED_ATTENTION - const uint block_start_pos_new = blocked_indexes_start[target_seq_dim]; - const uint block_end_pos_new = blocked_indexes_end[target_seq_dim]; - const uint seq_idx_end = block_end_pos_new - block_start_pos_new; -#else - const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#endif - for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { if (start_partition_idx > 0) { OUTPUT_TYPE updated_prev_res = TO_SOFTMAX_ACCUMULATOR_TYPE(output_acc[seq_idx]) * slm_update_factor[seq_idx]; @@ -1465,9 +1460,6 @@ KERNEL(sdpa_opt)( } #if IS_PAGED_ATTENTION - const uint block_start_pos_new = blocked_indexes_start[target_seq_dim]; - const uint block_end_pos_new = blocked_indexes_end[target_seq_dim]; - uint output_offset = block_start_pos_new * HEAD_SIZE * NUM_HEADS + num_heads_dim * HEAD_SIZE + sgid * SUBGROUP_SIZE; const uint output_pitch = HEAD_SIZE * NUM_HEADS; #else @@ -1475,13 +1467,7 @@ KERNEL(sdpa_opt)( const uint output_pitch = HEAD_SIZE; #endif -#if IS_PAGED_ATTENTION - if (block_start_pos_new + TARGET_SEQ_LEN_BLOCK_SIZE != block_end_pos_new) { - const uint seq_idx_end = block_end_pos_new - block_start_pos_new; -#else - if (get_global_id(1) == get_global_size(1) - 1) { - const uint seq_idx_end = min((uint)TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE); -#endif + if (TARGET_SEQ_LEN_BLOCK_SIZE > seq_idx_end) { for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) { output_acc[seq_idx] /= slm_exp_sum_prev[seq_idx]; OUTPUT_BLOCK_WRITE(output, output_offset, output_acc[seq_idx]);