From 1f6700fa8c521f9cba1dcab3c21439ed9fe2fe91 Mon Sep 17 00:00:00 2001
From: Sergey Shlyapnikov
Date: Thu, 24 Oct 2024 13:54:58 +0400
Subject: [PATCH] [GPU] Handle runtime scale value for PagedAttention (#27204)

### Details:
 - Add support for non-constant scale input, as the current Paged Attention
   specification does not require this value to be strictly constant

---
 .../src/graph/impls/ocl/paged_attention.cpp   | 42 +++++++++++++++++--
 .../kernel_selector/cl_kernels/pa_sdpa_opt.cl |  9 +++-
 .../kernel_selector/cl_kernels/sdpa_opt.cl    | 14 +++++--
 .../kernels/sdpa/pa_sdpa_kernel_opt.cpp       | 12 +++++-
 .../kernels/sdpa/sdpa_kernel_base.h           |  2 +-
 .../kernels/sdpa/sdpa_kernel_opt.cpp          |  4 +-
 .../src/plugin/ops/paged_attention.cpp        |  9 ++--
 7 files changed, 78 insertions(+), 14 deletions(-)
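Before the diff itself, here is a minimal sketch (not part of the patch) of the dispatch rule the change introduces: a constant scale keeps being folded into the kernel as a JIT define, while a non-constant scale is passed as an extra runtime kernel input, which also shifts the index of the optional alibi input that follows it. The names `ScaleInfo`, `KernelPlan`, and `plan_scale` are hypothetical illustrations, not OpenVINO APIs.

```cpp
#include <cstddef>
#include <optional>

// Hypothetical illustration of the decision this patch makes for the scale input.
struct ScaleInfo {
    std::optional<float> const_value;  // set only when the scale node is a Constant
};

struct KernelPlan {
    bool   bake_scale_into_jit = false;  // corresponds to the SCALE_VAL JIT define
    float  jit_scale_value     = 0.0f;
    bool   pass_scale_as_input = false;  // corresponds to the HAS_SCALE_INPUT path
    size_t optional_inputs     = 0;      // scale (and alibi) appended after the mandatory inputs
};

inline KernelPlan plan_scale(const ScaleInfo& scale, bool has_alibi) {
    KernelPlan plan;
    if (scale.const_value.has_value()) {
        plan.bake_scale_into_jit = true;
        plan.jit_scale_value = *scale.const_value;
    } else {
        plan.pass_scale_as_input = true;   // the scale tensor becomes a kernel argument
        plan.optional_inputs++;
    }
    if (has_alibi)
        plan.optional_inputs++;            // alibi lands after the (optional) scale input
    return plan;
}
```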
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
index cfc1e17c87ac6e..9cf1a252564934 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -122,6 +122,10 @@ struct paged_attention_impl : multi_stage_primitive {
                              instance.value_memory_ptr(),
                              instance.subsequence_begins_memory_ptr() };

+        if (!desc->scale_val.has_value()) {
+            args.inputs.push_back(instance.input_memory_ptr(9));
+        }
+
         if (desc->has_alibi) {
             args.inputs.push_back(instance.alibi_memory_ptr());
         }
@@ -144,6 +148,10 @@ struct paged_attention_impl : multi_stage_primitive {
             args.inputs.push_back(instance.subsequence_begins_memory_ptr());
         }

+        if (!desc->scale_val.has_value()) {
+            args.inputs.push_back(instance.input_memory_ptr(9));
+        }
+
         if (desc->has_alibi) {
             args.inputs.push_back(instance.alibi_memory_ptr());
         }
@@ -343,8 +351,10 @@ struct paged_attention_impl : multi_stage_primitive {
         config.paged_attention_block_size = static_cast(paged_attention::block_size);

         if (desc->scale_val.has_value()) {
-            config.has_scale_val = true;
+            config.has_const_scale_val = true;
             config.scale_val = desc->scale_val.value();
+        } else {
+            config.has_const_scale_val = false;
         }

         if (desc->heads_num != desc->kv_heads_num) {
@@ -409,16 +419,22 @@ struct paged_attention_impl : multi_stage_primitive {
     }

     static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) {
+        const auto desc = impl_param.typed_desc();
         auto params = get_default_params(impl_param, is_dynamic);

         const auto& query_layout = impl_param.get_input_layout(0);
         const auto& key_layout = impl_param.get_input_layout(1);
         const auto& value_layout = impl_param.get_input_layout(2);
         const auto& subsequence_begins_layout = impl_param.get_input_layout(6);
+        const auto& scale_layout = impl_param.get_input_layout(9);
         const auto& alibi_layout = impl_param.get_input_layout(11);

         const auto has_alibi = alibi_layout.count() > 0;
+        const auto has_scale_input = !desc->scale_val.has_value();

         auto inputs_number = 4;
+        if (has_scale_input)
+            inputs_number++;
+
         if (has_alibi)
             inputs_number++;

@@ -429,6 +445,9 @@ struct paged_attention_impl : multi_stage_primitive {
         params.inputs[input_idx++] = convert_data_tensor(value_layout);
         params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout);

+        if (has_scale_input)
+            params.inputs[input_idx++] = convert_data_tensor(scale_layout);
+
         if (has_alibi)
             params.inputs[input_idx++] = convert_data_tensor(alibi_layout);

@@ -446,8 +465,12 @@ struct paged_attention_impl : multi_stage_primitive {
             {0, out_offsets_map.at(0)},
         };

+        input_idx = 4;
+        if (has_scale_input)
+            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(9)});
+
         if (has_alibi)
-            in_tensor_to_offset_map.insert({4, in_offsets_map.at(11)});
+            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});

         if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
             params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
@@ -458,6 +481,7 @@ struct paged_attention_impl : multi_stage_primitive {
     }

     static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) {
+        const auto desc = impl_param.typed_desc();
         auto params = get_default_params(impl_param, is_dynamic);

         const auto& query_layout = impl_param.get_input_layout(0);
@@ -467,10 +491,15 @@ struct paged_attention_impl : multi_stage_primitive {
         const auto& block_indices_layout = impl_param.get_input_layout(7);
         const auto& block_indices_begins_layout = impl_param.get_input_layout(8);
         const auto& subsequence_begins_layout = impl_param.get_input_layout(6);
+        const auto& scale_layout = impl_param.get_input_layout(9);
         const auto& alibi_layout = impl_param.get_input_layout(11);

         const auto has_alibi = alibi_layout.count() > 0;
+        const auto has_scale_input = !desc->scale_val.has_value();

         auto inputs_number = 7;
+        if (has_scale_input)
+            inputs_number++;
+
         if (has_alibi)
             inputs_number++;

@@ -485,6 +514,9 @@ struct paged_attention_impl : multi_stage_primitive {
         params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout);
         params.conf = get_sdpa_configuration(impl_param);

+        if (has_scale_input)
+            params.inputs[input_idx++] = convert_data_tensor(scale_layout);
+
         if (has_alibi)
             params.inputs[input_idx++] = convert_data_tensor(alibi_layout);

@@ -513,8 +545,12 @@ struct paged_attention_impl : multi_stage_primitive {
             {0, out_offsets_map.at(0)},
         };

+        input_idx = 7;
+        if (has_scale_input)
+            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(9)});
+
         if (has_alibi)
-            in_tensor_to_offset_map.insert({7, in_offsets_map.at(11)});
+            in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});

         params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);

diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
index 22b561e3d78661..a3bdd7e12dcd49 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
@@ -37,8 +37,11 @@ KERNEL(pa_sdpa_opt)(
 #if MULTI_TOKENS_PROCESSING
     const __global INPUT6_TYPE* subsequence_begins,
 #endif
+#if HAS_SCALE_INPUT
+    const __global SCALE_INPUT_TYPE* scale,
+#endif
 #if HAS_ALIBI
-    const __global INPUT7_TYPE* alibi_slopes,
+    const __global ALIBI_INPUT_TYPE* alibi_slopes,
 #endif
     __global OUTPUT_TYPE* output,
     __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
@@ -117,6 +120,8 @@ KERNEL(pa_sdpa_opt)(
         // Apply scale value directly to the query input to improve accuracy in case of a high range of input data
 #ifdef SCALE_VAL
         q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val;
+#else
+        q_val = *scale * q_val;
 #endif

         slm_query[query_idx_local] = q_val;
@@ -133,6 +138,8 @@ KERNEL(pa_sdpa_opt)(
             // Apply scale value directly to the query input to improve accuracy in case of a high range of input data
 #ifdef SCALE_VAL
             q_val[i] = TO_INPUT0_TYPE(SCALE_VAL) * q_val[i];
+#else
+            q_val[i] = *scale * q_val[i];
 #endif
         }
 #endif
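An aside on the kernel change above: both branches apply the scale to the query values before they are staged for the QK computation; only the source of the value differs (the JIT-baked `SCALE_VAL` versus the single-element `scale` buffer). A minimal C++ paraphrase of that idea, not the actual kernel code:

```cpp
#include <vector>

// Paraphrase of the pa_sdpa_opt.cl fragment: whichever source the scale comes
// from, it is multiplied into the query values up front.
inline void scale_query(std::vector<float>& q_vals, const float* scale_buffer) {
    const float scale = *scale_buffer;  // single-element buffer, as assumed by "*scale" in the kernel
    for (auto& q : q_vals)
        q = scale * q;
}
```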
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
index 948bd3c0f1a305..748f79115262e0 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -656,6 +656,14 @@ inline MASK_VECTOR_TYPE FUNC(load_attn_mask)(OPTIONAL_SHAPE_INFO_ARG
     return mask_vec;
 }

+#if IS_PAGED_ATTENTION && HAS_ALIBI
+#if HAS_SCALE_INPUT
+#define ALIBI_TYPE INPUT5_TYPE
+#else
+#define ALIBI_TYPE INPUT4_TYPE
+#endif
+#endif
+
 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
 KERNEL(sdpa_opt)(
     OPTIONAL_SHAPE_INFO_ARG
@@ -664,15 +672,15 @@ KERNEL(sdpa_opt)(
     const __global INPUT2_TYPE* value_input,
 #if IS_PAGED_ATTENTION
     const __global INPUT3_TYPE* subsequence_begins,
-#if HAS_ALIBI
-    const __global INPUT4_TYPE* alibi_slopes,
-#endif
 #endif
 #if HAS_ATTN_MASK_INPUT
     const __global INPUT3_TYPE* attn_mask,
 #endif
 #if HAS_SCALE_INPUT
     const __global INPUT4_TYPE* scale,
+#endif
+#if IS_PAGED_ATTENTION && HAS_ALIBI
+    const __global ALIBI_TYPE* alibi_slopes,
 #endif
     __global OUTPUT_TYPE* output,
 #ifdef BEAM_TABLE_TYPE
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
index 161c37ab3d3bf7..63c5e74160f652 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
@@ -176,11 +176,19 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS ? 1 : 0;
     jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1));

-    if (config.has_scale_val)
+    if (config.has_const_scale_val) {
         jit.AddConstant(MakeJitConstant("SCALE_VAL", config.scale_val));
+    } else {
+        const size_t scale_input_idx = 7;
+        jit.AddConstant(MakeJitConstant("HAS_SCALE_INPUT", 1));
+        jit.Merge(MakeTypeJitConstants(params.inputs[scale_input_idx].GetDType(), "SCALE_INPUT"));
+    }

-    if (params.conf.has_alibi_input)
+    if (params.conf.has_alibi_input) {
+        const size_t alibi_input_idx = config.has_const_scale_val ? 7 : 8;
         jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
+        jit.Merge(MakeTypeJitConstants(params.inputs[alibi_input_idx].GetDType(), "ALIBI_INPUT"));
+    }

     if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS)
         jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1));
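For reference, the hardcoded indices 7 and 8 in `PagedAttentionSDPAKernelOpt::GetJitConstants` above follow from the input list built in `get_pa_sdpa_params`: seven mandatory inputs, then the optional runtime scale, then the optional alibi slopes. A small sketch of that arithmetic (a helper written only for illustration, not existing code):

```cpp
#include <cstddef>

// Illustration of the optional-input index arithmetic used above: the runtime
// scale (when present) takes the slot right after the seven mandatory PA SDPA
// inputs, and the alibi input is shifted by one when the scale is non-constant.
struct pa_sdpa_optional_indices {
    static constexpr size_t mandatory_inputs = 7;

    static constexpr size_t scale_index() {
        return mandatory_inputs;  // 7
    }

    static constexpr size_t alibi_index(bool has_const_scale) {
        return has_const_scale ? mandatory_inputs : mandatory_inputs + 1;  // 7 or 8
    }
};

static_assert(pa_sdpa_optional_indices::scale_index() == 7, "scale follows the mandatory inputs");
static_assert(pa_sdpa_optional_indices::alibi_index(false) == 8, "alibi shifts when scale is a runtime input");
```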
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
index 6ea8d85527d19d..492e86ebcce5cc 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
@@ -93,7 +93,7 @@ struct sdpa_configuration {
     bool is_paged_attention = false;
     int64_t paged_attention_aligned_seq_len = -1;
     int64_t paged_attention_block_size = 0;
-    bool has_scale_val = false;
+    bool has_const_scale_val = false;
     float scale_val = 0.f;
 };

diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp
index 2f0174d0a45912..6942e5f8ea4357 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp
@@ -180,9 +180,11 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
             jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
         }

-        if (params.conf.has_scale_val) {
+        if (params.conf.has_const_scale_val) {
             jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE_INV", 1.0f / params.conf.scale_val));
             jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE", params.conf.scale_val));
+        } else {
+            jit.AddConstant(MakeJitConstant("HAS_SCALE_INPUT", 1));
         }
     } else if (params.inputs.size() <= 4) {
         jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE_INV", std::sqrt(static_cast(params.conf.head_size))));
diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
index e4e7dcb77e03fb..7425b096b6d324 100644
--- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
@@ -50,9 +50,12 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
     const size_t alibi_idx = 11;

     std::shared_ptr scale_const = std::dynamic_pointer_cast(op->get_input_node_shared_ptr(scale_idx));
-    OPENVINO_ASSERT(scale_const != nullptr);
-    OPENVINO_ASSERT(ov::shape_size(scale_const->get_output_shape(0)) == 1);
-    prim.scale_val = scale_const->cast_vector()[0];
+    if (scale_const) {
+        OPENVINO_ASSERT(ov::shape_size(scale_const->get_output_shape(0)) == 1);
+        prim.scale_val = scale_const->cast_vector()[0];
+    } else {
+        prim.scale_val = cldnn::optional_value();
+    }

     std::shared_ptr alibi_const = std::dynamic_pointer_cast(op->get_input_node_shared_ptr(alibi_idx));
     OPENVINO_ASSERT(alibi_const != nullptr);
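Finally, a hedged paraphrase of the plugin-side behavior introduced above, using stand-in types instead of the real ov/cldnn classes: a constant, single-element scale is folded into the primitive descriptor, while any other producer leaves the optional value empty so the scale is consumed as a runtime input (input index 9) by the GPU kernels.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>

// Stand-in for ov::op::v0::Constant, used only for this sketch.
struct constant_node {
    float value;
    std::size_t element_count;
};

// Mirrors the patched CreatePagedAttentionExtensionOp logic.
inline std::optional<float> extract_const_scale(const std::shared_ptr<constant_node>& scale_const) {
    if (!scale_const)
        return std::nullopt;                   // non-constant scale -> read from input 9 at runtime
    assert(scale_const->element_count == 1);   // the constant case still requires a scalar scale
    return scale_const->value;
}
```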