From cfc8fefaaf2b2eb2f199adb75528389491c576d3 Mon Sep 17 00:00:00 2001 From: Lyamin-Roman Date: Fri, 31 May 2024 07:57:58 +0900 Subject: [PATCH] [GPU] Added RoPE support for Llama2 --- .../include/intel_gpu/primitives/rope.hpp | 12 ++++++++++-- .../intel_gpu/src/graph/impls/ocl/rope.cpp | 5 +++++ .../kernel_selector/cl_kernels/rope_ref.cl | 19 ++++++++++++++++++- .../kernels/rope/rope_kernel_base.cpp | 5 +++++ .../kernels/rope/rope_kernel_base.h | 1 + .../kernels/rope/rope_kernel_ref.cpp | 1 + src/plugins/intel_gpu/src/plugin/ops/rope.cpp | 8 +++++++- .../src/plugin/transformations_pipeline.cpp | 1 - .../subgraph_tests/rotary_pos_emb.cpp | 5 +++++ 9 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp index b9568c766412de..51bd0e3bfdfecd 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp @@ -19,14 +19,18 @@ struct rope : public primitive_base { /// @param id This primitive id /// @param inputs Inputs primitive ids /// @param config Specific RoPE config + /// @param gather_rank Required for correct processing of gather input (if applicable) rope(const primitive_id& id, const std::vector& inputs, const RoPE::Config& config, + size_t gather_rank = 0, const padding& output_padding = padding()) : primitive_base(id, inputs, {output_padding}), - config(config) {} + config(config), + gather_rank(gather_rank) {} RoPE::Config config; + size_t gather_rank; size_t hash() const override { size_t seed = primitive::hash(); @@ -40,6 +44,7 @@ struct rope : public primitive_base { seed = hash_combine(seed, config.rotary_ndims); seed = hash_combine(seed, config.slice_start); seed = hash_combine(seed, config.slice_stop); + seed = hash_combine(seed, gather_rank); return seed; } @@ -58,7 +63,8 @@ struct rope : public primitive_base { config.is_qwen == rhs_casted.config.is_qwen && config.rotary_ndims == rhs_casted.config.rotary_ndims && config.slice_start == rhs_casted.config.slice_start && - config.slice_stop == rhs_casted.config.slice_stop; + config.slice_stop == rhs_casted.config.slice_stop && + gather_rank == rhs_casted.gather_rank; } void save(BinaryOutputBuffer& ob) const override { @@ -73,6 +79,7 @@ struct rope : public primitive_base { ob << config.rotary_ndims; ob << config.slice_start; ob << config.slice_stop; + ob << gather_rank; } void load(BinaryInputBuffer& ib) override { @@ -87,6 +94,7 @@ struct rope : public primitive_base { ib >> config.rotary_ndims; ib >> config.slice_start; ib >> config.slice_stop; + ib >> gather_rank; } }; } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp index 2914bbfa24e865..a9620b1372d869 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp @@ -46,6 +46,11 @@ struct rope_impl : typed_primitive_impl_ocl { params.axis = primitive->config.is_qwen || primitive->config.is_chatglm ? 2 : 3; params.num_of_inputs = primitive->config.is_chatglm || primitive->config.is_interleaved ? 2 : 3; + if (primitive->gather_rank > 0) { + params.gather_rank = primitive->gather_rank; + params.num_of_inputs++; + } + params.is_qwen = primitive->config.is_qwen; params.is_chatglm = primitive->config.is_chatglm; params.transposed_input = primitive->config.input_trans0213; diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl index 76bc154ce3e95d..3004886730d588 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl @@ -91,6 +91,9 @@ KERNEL(rope_ref)( const __global INPUT0_TYPE* input, const __global INPUT1_TYPE* cos, const __global INPUT1_TYPE* sin, +#ifdef ENABLE_GATHER + const __global INPUT3_TYPE* gather, +#endif __global OUTPUT_TYPE* output) { const uint b = get_global_id(0); @@ -108,9 +111,23 @@ KERNEL(rope_ref)( #else uint input_idx = INPUT0_GET_INDEX(b, h, p, 0); #endif + uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0; - uint cos_sin_p = p < INPUT1_SIZE_Y ? p : 0; + uint cos_sin_p = p; +#ifdef ENABLE_GATHER + uint gather_b = b < INPUT3_BATCH_NUM ? b : 0; +#if GATHER_RANK == 4 + uint gather_h = h < INPUT3_FEATURE_NUM ? h : 0; + uint gather_p = p < INPUT3_SIZE_Y ? p : 0; + uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_h, gather_p, 0); +#else + uint gather_p = p < INPUT3_FEATURE_NUM ? p : 0; + uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_p, 0, 0); +#endif + cos_sin_p = gather[gather_idx]; +#endif + cos_sin_p = cos_sin_p < INPUT1_SIZE_Y ? cos_sin_p : 0; uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp index a46ccd5e0d96b0..b5c971e0c2712a 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp @@ -22,6 +22,11 @@ JitConstants RoPEKernelBase::GetJitConstants(const rope_params& params, RoPEKern jit.AddConstant(MakeJitConstant("ENABLE_IO_COPY", true)); } + if (params.gather_rank > 0) { + jit.AddConstant(MakeJitConstant("ENABLE_GATHER", true)); + jit.AddConstant(MakeJitConstant("GATHER_RANK", params.gather_rank)); + } + if (params.slice_stop - params.slice_start > 0) { jit.AddConstant(MakeJitConstant("ENABLE_SLICE", true)); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h index f3d94c95c05009..dde691bfb439de 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h @@ -20,6 +20,7 @@ struct rope_params : public base_params { size_t slice_stop; size_t axis; size_t num_of_inputs; + size_t gather_rank; bool is_qwen; bool is_chatglm; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp index 27d7efeb525db5..5ec125ef6f083c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp @@ -9,6 +9,7 @@ namespace kernel_selector { ParamsKey RoPEKernelRef::GetSupportedKey() const { ParamsKey k; + k.EnableInputDataType(Datatype::INT32); k.EnableInputDataType(Datatype::F16); k.EnableInputDataType(Datatype::F32); k.EnableOutputDataType(Datatype::F16); diff --git a/src/plugins/intel_gpu/src/plugin/ops/rope.cpp b/src/plugins/intel_gpu/src/plugin/ops/rope.cpp index 2b299890fe92a9..321342b3395660 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/rope.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/rope.cpp @@ -24,9 +24,15 @@ static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptrget_config(); + size_t gather_rank = 0; + if (config.gather_position_arg_id > 0) { + gather_rank = op->get_input_partial_shape(config.gather_position_arg_id).size(); + } + auto rope = cldnn::rope(layer_type_name_ID(op), inputs, - config); + config, + gather_rank); p.add_primitive(*op, rope); } diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 60a83f8b8c9b35..faac47298f517f 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -817,7 +817,6 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(); pass_config->disable(); - pass_config->disable(); pass_config->disable(); pass_config->disable(); diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp index bed957cc35fc0d..5bfcadb10c8205 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -18,5 +18,10 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b, ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestQwen7b::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2, + RoPETestLlama2, + ::testing::Values(ov::test::utils::DEVICE_GPU), + RoPETestLlama2::getTestCaseName); + } // namespace test } // namespace ov