From 96517682ef8b0bb4315ab548850a30df867e33fa Mon Sep 17 00:00:00 2001 From: River Li Date: Mon, 16 Dec 2024 13:36:30 +0800 Subject: [PATCH] [GPU] rope optimization (#27907) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Details: - Optimize the RoPE OpenCL kernel to improve its performance - Test results show it improves RoPE performance by about 50% on average.
| kernel (batch=128, seq_length = 7) | base latency (ns) | optimized latency (ns) | latency decreased | test case | precision |
| -- | -- | -- | -- | -- | -- |
| rope_ref_5266667119713786613_0_0__sa | 921352 | 872395 | 5.31% | RoPETestQwen7b | f32 |
| rope_ref_2672092794364911740_0_0__sa | 1724374 | 514790 | 70.15% | RoPETestChatGLM | f32 |
| rope_ref_8061762790816124098_0_0__sa | 633019 | 127186 | 79.91% | RoPETestQwen7b | f16 |
| rope_ref_4392014836945391706_0_0__sa | 629791 | 518749 | 17.63% | RoPETestLlama2 | f32 |
| rope_ref_13829176589243505378_0_0__sa | 870312 | 259583 | 70.17% | RoPETestChatGLM | f32 |
| rope_ref_6813544162411765619_0_0__sa | 749895 | 421875 | 43.74% | RoPETestChatGLM | f16 |
| rope_ref_15054358246334082928_0_0__sa | 637708 | 45208 | 92.91% | RoPETestFlux | f32 |
| rope_ref_3898891400599565440_0_0__sa | 378333 | 335937 | 11.21% | RoPETestRotateHalfWithoutTranspose | f32 |
| rope_ref_18119704851383556529_0_0__sa | 371250 | 208645 | 43.80% | RoPETestChatGLM | f16 |
| rope_ref_17460680473512025171_0_0__sa | 299166 | 98958 | 66.92% | RoPETestFlux | f16 |
![image](https://github.com/user-attachments/assets/4328b1a7-18ec-485f-abd0-b0fe16785854) ### Tickets: - *CVS-157438* --- .../subgraph_tests/rotary_pos_emb.cpp | 24 +- .../intel_gpu/src/graph/impls/ocl/rope.cpp | 2 +- .../kernel_selector/cl_kernels/rope_opt.cl | 445 ++++++++++++++++++ .../kernel_selector/cl_kernels/rope_ref.cl | 201 -------- .../kernels/rope/rope_kernel_base.cpp | 3 +- .../kernels/rope/rope_kernel_opt.cpp | 107 +++++ .../kernels/rope/rope_kernel_opt.h | 25 + .../kernels/rope/rope_kernel_ref.cpp | 35 -- .../kernels/rope/rope_kernel_ref.h | 20 - .../kernels/rope/rope_kernel_selector.cpp | 4 +- .../subgraph_tests/rotary_pos_emb.cpp | 30 +- .../subgraph/rotary_pos_emb.hpp | 50 +- .../src/subgraph/rotary_pos_emb.cpp | 193 ++++---- 13 files changed, 760 insertions(+), 379 deletions(-) create mode 100644 src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_opt.cl delete mode 100644 src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.cpp create mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.h delete mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp delete mode 100644 src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp index 8cd8707e047878..20e77bb02897e5 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -9,51 +9,65 @@ namespace test { INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2StridedSlice, RoPETestLlama2StridedSlice, - ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Combine( + 
::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestLlama2StridedSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLMStridedSlice, RoPETestChatGLMStridedSlice, - ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestChatGLMStridedSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7bStridedSlice, RoPETestQwen7bStridedSlice, ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::element::f32), ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestQwen7bStridedSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJStridedSlice, RoPETestGPTJStridedSlice, ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::element::f32), ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestGPTJStridedSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2Slice, RoPETestLlama2Slice, - ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestLlama2Slice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLMSlice, RoPETestChatGLMSlice, - ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestChatGLMSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7bSlice, RoPETestQwen7bSlice, ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::element::f32), ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestQwen7bSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJSlice, RoPETestGPTJSlice, ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::element::f32), ::testing::Values(ov::test::utils::DEVICE_CPU)), 
RoPETestGPTJSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, RoPETestChatGLM2DRoPEStridedSlice, - ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName); } // namespace test diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp index d06e643c71ad18..1d2051461fd9c0 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp @@ -6,7 +6,7 @@ #include "rope_inst.h" #include "rope/rope_kernel_selector.h" -#include "rope/rope_kernel_ref.h" +#include "rope/rope_kernel_opt.h" namespace cldnn { namespace ocl { diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_opt.cl new file mode 100644 index 00000000000000..a5c71c84211c7e --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_opt.cl @@ -0,0 +1,445 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "include/fetch_utils.cl" + +#define INPUT_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_SIZE) +#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_SIZE) + +#define UNPACK_FLOAT_VEC_1(outputv, input1, input2) \ + outputv.s0 = convert_float(input1.s0); \ + outputv.s1 = convert_float(input1.s2); \ + outputv.s2 = convert_float(input1.s4); \ + outputv.s3 = convert_float(input1.s6); \ + outputv.s4 = convert_float(input2.s0); \ + outputv.s5 = convert_float(input2.s2); \ + outputv.s6 = convert_float(input2.s4); \ + outputv.s7 = convert_float(input2.s6); + +#define UNPACK_FLOAT_VEC_2(outputv, input1, input2) \ + outputv.s0 = convert_float(input1.s1); \ + outputv.s1 = convert_float(input1.s3); \ + outputv.s2 = convert_float(input1.s5); \ + outputv.s3 = convert_float(input1.s7); \ 
+ outputv.s4 = convert_float(input2.s1); \ + outputv.s5 = convert_float(input2.s3); \ + outputv.s6 = convert_float(input2.s5); \ + outputv.s7 = convert_float(input2.s7); + +#define UNPACK_HALF_VEC_1(outputv, input1) \ + outputv.s0 = convert_float(input1[0]); \ + outputv.s1 = convert_float(input1[2]); \ + outputv.s2 = convert_float(input1[4]); \ + outputv.s3 = convert_float(input1[6]); \ + outputv.s4 = convert_float(input1[8]); \ + outputv.s5 = convert_float(input1[10]); \ + outputv.s6 = convert_float(input1[12]); \ + outputv.s7 = convert_float(input1[14]); + +#define UNPACK_HALF_VEC_2(outputv, input1) \ + outputv.s0 = convert_float(input1[1]); \ + outputv.s1 = convert_float(input1[3]); \ + outputv.s2 = convert_float(input1[5]); \ + outputv.s3 = convert_float(input1[7]); \ + outputv.s4 = convert_float(input1[9]); \ + outputv.s5 = convert_float(input1[11]); \ + outputv.s6 = convert_float(input1[13]); \ + outputv.s7 = convert_float(input1[15]); + +#define UNPACK_HALF16_VEC_1(outputv, input1, input2) \ + outputv = (half16)(input1[0], \ + input1[2], \ + input1[4], \ + input1[6], \ + input1[8], \ + input1[10], \ + input1[12], \ + input1[14], \ + input2[0], \ + input2[2], \ + input2[4], \ + input2[6], \ + input2[8], \ + input2[10], \ + input2[12], \ + input2[14]); + +#define UNPACK_HALF16_VEC_2(outputv, input1, input2) \ + outputv = (half16)(input1[1], \ + input1[3], \ + input1[5], \ + input1[7], \ + input1[9], \ + input1[11], \ + input1[13], \ + input1[15], \ + input2[1], \ + input2[3], \ + input2[5], \ + input2[7], \ + input2[9], \ + input2[11], \ + input2[13], \ + input2[15]); + +#define PACK_HALF16_VEC_1(outputv, input1, input2) \ + outputv = (half16)(input1[0], \ + input2[0], \ + input1[1], \ + input2[1], \ + input1[2], \ + input2[2], \ + input1[3], \ + input2[3], \ + input1[4], \ + input2[4], \ + input1[5], \ + input2[5], \ + input1[6], \ + input2[6], \ + input1[7], \ + input2[7]); + +#define PACK_HALF16_VEC_2(outputv, input1, input2) \ + outputv = 
(half16)(input1[8], \ + input2[8], \ + input1[9], \ + input2[9], \ + input1[10], \ + input2[10], \ + input1[11], \ + input2[11], \ + input1[12], \ + input2[12], \ + input1[13], \ + input2[13], \ + input1[14], \ + input2[14], \ + input1[15], \ + input2[15]); + +#ifdef CHATGLM +KERNEL(rope_opt)( + OPTIONAL_SHAPE_INFO_ARG const __global INPUT0_TYPE* input, + const __global INPUT1_TYPE* cos_sin, + __global OUTPUT_TYPE* output) { +#if VEC_SIZE != 1 && VEC_SIZE != 8 && VEC_SIZE != 16 +# error "rope_opt.cl - VEC_SIZE must be one of {1, 8, 16}" +#endif + +#ifdef SUPPORT_2D_ROPE + const uint p = get_global_id(0) / HEAD_COUNT; + const uint h = get_global_id(0) % HEAD_COUNT; + const uint b = get_global_id(1); // sequence length + const uint rf = get_global_id(2); // max(HALF_ROTARY_NDIMS, HEAD_SIZE - ROTARY_NDIMS) + uint output_idx = OUTPUT_GET_INDEX(p, h, b, 0); +#else + const uint p = get_global_id(0); + const uint b = get_global_id(1); + const uint h = (uint)get_global_id(2) % HEAD_COUNT; + const uint rf = (uint)get_global_id(2) / HEAD_COUNT; + uint output_idx = OUTPUT_GET_INDEX(p, b, h, 0); +#endif + + uint r = rf < HALF_ROTARY_NDIMS ? rf * 2 * VEC_SIZE : 0; + uint f = rf < HEAD_SIZE - ROTARY_NDIMS ? rf * 2 * VEC_SIZE : 0; + + uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0); +#ifdef ENABLE_SLICE + input_idx += SLICED_FROM_START; +#endif + + uint cos_sin_p = p < INPUT1_BATCH_NUM ? p : 0; + uint cos_sin_b = b < INPUT1_FEATURE_NUM ? 
b : 0; + uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, cos_sin_b, 0, 0); + +#if VEC_SIZE == 1 + float cosv = convert_float(cos_sin[cos_sin_idx + r]); + float sinv = convert_float(cos_sin[cos_sin_idx + r + 1]); + + float in1 = convert_float(input[input_idx + r]); + float in2 = convert_float(input[input_idx + r + 1]); + + output[output_idx + r] = TO_OUTPUT_TYPE(cosv * in1 - sinv * in2); + output[output_idx + r + 1] = TO_OUTPUT_TYPE(sinv * in1 + cosv * in2); + + #ifdef ENABLE_IO_COPY + output[output_idx + ROTARY_NDIMS + f] = input[input_idx + ROTARY_NDIMS + f]; + output[output_idx + ROTARY_NDIMS + f + 1] = input[input_idx + ROTARY_NDIMS + f + 1]; + #endif +#elif VEC_SIZE == 8 + INPUT_VEC_TYPE inv1 = *(INPUT_VEC_TYPE*)(input + input_idx + r); + INPUT_VEC_TYPE inv2 = *(INPUT_VEC_TYPE*)(input + input_idx + r + VEC_SIZE); + INPUT_VEC_TYPE cossinv1 = *(INPUT_VEC_TYPE*)(cos_sin + cos_sin_idx + r); + INPUT_VEC_TYPE cossinv2 = *(INPUT_VEC_TYPE*)(cos_sin + cos_sin_idx + r + VEC_SIZE); + + float8 in1, in2, cosv, sinv; + UNPACK_FLOAT_VEC_1(in1, inv1, inv2); + UNPACK_FLOAT_VEC_2(in2, inv1, inv2); + UNPACK_FLOAT_VEC_1(cosv, cossinv1, cossinv2); + UNPACK_FLOAT_VEC_2(sinv, cossinv1, cossinv2); + float8 out1 = cosv * in1 - sinv * in2; + float8 out2 = sinv * in1 + cosv * in2; + + *(float8*)(output + output_idx + r) = + (float8)(out1.s0, out2.s0, out1.s1, out2.s1, out1.s2, out2.s2, out1.s3, out2.s3); + *(float8*)(output + output_idx + r + VEC_SIZE) = + (float8)(out1.s4, out2.s4, out1.s5, out2.s5, out1.s6, out2.s6, out1.s7, out2.s7); + + #ifdef ENABLE_IO_COPY + *(float8*)(output + output_idx + ROTARY_NDIMS + f) = *(float8*)(input + input_idx + ROTARY_NDIMS + f); + *(float8*)(output + output_idx + ROTARY_NDIMS + f + VEC_SIZE) = + *(float8*)(input + input_idx + ROTARY_NDIMS + f + VEC_SIZE); + #endif +#elif VEC_SIZE == 16 + unroll_for(int i = 0; i < 2; i += 1) { + INPUT_VEC_TYPE inv = *(INPUT_VEC_TYPE*)(input + input_idx + r + i * VEC_SIZE); + INPUT_VEC_TYPE cossinv = 
*(INPUT_VEC_TYPE*)(cos_sin + cos_sin_idx + r + i * VEC_SIZE); + float8 in1, in2, cosv, sinv; + UNPACK_HALF_VEC_1(in1, inv); + UNPACK_HALF_VEC_2(in2, inv); + UNPACK_HALF_VEC_1(cosv, cossinv); + UNPACK_HALF_VEC_2(sinv, cossinv); + float8 out1 = cosv * in1 - sinv * in2; + float8 out2 = sinv * in1 + cosv * in2; + + unroll_for(int j = 0; j < 8; j += 1) { + output[output_idx + r + i * VEC_SIZE + 2 * j] = TO_OUTPUT_TYPE(out1[j]); + output[output_idx + r + i * VEC_SIZE + 2 * j + 1] = TO_OUTPUT_TYPE(out2[j]); + } + } + #ifdef ENABLE_IO_COPY + *(float8*)(output + output_idx + ROTARY_NDIMS + f) = *(float8*)(input + input_idx + ROTARY_NDIMS + f); + *(float8*)(output + output_idx + ROTARY_NDIMS + f + VEC_SIZE) = + *(float8*)(input + input_idx + ROTARY_NDIMS + f + VEC_SIZE); + #endif +#endif +} +#endif + +#ifdef QWEN +KERNEL(rope_opt)( + OPTIONAL_SHAPE_INFO_ARG const __global INPUT0_TYPE* input, + const __global INPUT1_TYPE* cos, + const __global INPUT2_TYPE* sin, + __global OUTPUT_TYPE* output) { + const uint b = get_global_id(0); + const uint p = get_global_id(1); + const uint h = (uint)get_global_id(2) * VEC_SIZE / HALF_ROTARY_NDIMS; + const uint r = ((uint)get_global_id(2) * VEC_SIZE) % HALF_ROTARY_NDIMS; + + uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0); +#ifdef ENABLE_SLICE + input_idx += SLICED_FROM_START; +#endif + + uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; + uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM + ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM + : 0; + uint cos_sin_h = h < INPUT1_SIZE_Y ? 
h : 0; + +#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS + uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0); + + uint cos_idx = cos_sin_idx; + uint sin_idx = cos_sin_idx; +#else + uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0); + uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0); +#endif + + uint output_idx = OUTPUT_GET_INDEX(b, p, h, 0); + +#if VEC_SIZE == 1 + INPUT0_TYPE in1 = input[input_idx + r]; + INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r]; + + output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2; + + output[output_idx + HALF_ROTARY_NDIMS + r] = + cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 + sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1; +#else + INPUT_VEC_TYPE in1 = *(INPUT_VEC_TYPE*)(input + input_idx + r); + INPUT_VEC_TYPE in2 = *(INPUT_VEC_TYPE*)(input + input_idx + HALF_ROTARY_NDIMS + r); + INPUT_VEC_TYPE cos1 = *(INPUT_VEC_TYPE*)(cos + cos_idx + r); + INPUT_VEC_TYPE cos2 = *(INPUT_VEC_TYPE*)(cos + cos_idx + HALF_ROTARY_NDIMS + r); + INPUT_VEC_TYPE sin1 = *(INPUT_VEC_TYPE*)(sin + sin_idx + r); + INPUT_VEC_TYPE sin2 = *(INPUT_VEC_TYPE*)(sin + sin_idx + HALF_ROTARY_NDIMS + r); + + OUTPUT_VEC_TYPE out1 = cos1 * in1 - sin1 * in2; + OUTPUT_VEC_TYPE out2 = cos2 * in2 + sin2 * in1; + + *(OUTPUT_VEC_TYPE*)(output + output_idx + r) = out1; + *(OUTPUT_VEC_TYPE*)(output + output_idx + HALF_ROTARY_NDIMS + r) = out2; +#endif +} +#endif + +#ifdef RotateHalf +KERNEL(rope_opt) +(OPTIONAL_SHAPE_INFO_ARG const __global INPUT0_TYPE* input, + const __global INPUT1_TYPE* cos, + const __global INPUT2_TYPE* sin, +#ifdef ENABLE_GATHER + const __global INPUT3_TYPE* gather, +#endif + __global OUTPUT_TYPE* output) { + const uint b = get_global_id(0); + const uint h = get_global_id(1); + const uint p = ((uint)get_global_id(2) * VEC_SIZE) / HALF_ROTARY_NDIMS; + const uint r = ((uint)get_global_id(2) * VEC_SIZE) % HALF_ROTARY_NDIMS; + +#if ENABLE_TRANSPOSE + uint input_idx = INPUT0_GET_INDEX(b, p, h, 
0); +#else + uint input_idx = INPUT0_GET_INDEX(b, h, p, 0); + #ifdef ENABLE_SLICE + input_idx += SLICED_FROM_START; + #endif +#endif + + uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; + uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0; + uint cos_sin_p = p; +#ifdef ENABLE_GATHER + uint gather_b = b < INPUT3_BATCH_NUM ? b : 0; + #if GATHER_RANK == 4 + uint gather_h = h < INPUT3_FEATURE_NUM ? h : 0; + uint gather_p = p < INPUT3_SIZE_Y ? p : 0; + uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_h, gather_p, 0); + #else + uint gather_p = p < INPUT3_FEATURE_NUM ? p : 0; + uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_p, 0, 0); + #endif + cos_sin_p = gather[gather_idx]; +#endif + cos_sin_p = cos_sin_p < INPUT1_SIZE_Y ? cos_sin_p : 0; + +#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS + uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); + + uint cos_idx = cos_sin_idx; + uint sin_idx = cos_sin_idx; +#else + uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); + uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); +#endif + + uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0); + +#if VEC_SIZE == 1 + INPUT0_TYPE in1 = input[input_idx + r]; + INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r]; + + output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2; + + output[output_idx + HALF_ROTARY_NDIMS + r] = + cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 + sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1; +#else + INPUT_VEC_TYPE in1 = *(INPUT_VEC_TYPE*)(input + input_idx + r); + INPUT_VEC_TYPE in2 = *(INPUT_VEC_TYPE*)(input + input_idx + HALF_ROTARY_NDIMS + r); + INPUT_VEC_TYPE cos1 = *(INPUT_VEC_TYPE*)(cos + cos_idx + r); + INPUT_VEC_TYPE cos2 = *(INPUT_VEC_TYPE*)(cos + cos_idx + HALF_ROTARY_NDIMS + r); + INPUT_VEC_TYPE sin1 = *(INPUT_VEC_TYPE*)(sin + sin_idx + r); + INPUT_VEC_TYPE sin2 = *(INPUT_VEC_TYPE*)(sin + sin_idx + HALF_ROTARY_NDIMS + r); + + OUTPUT_VEC_TYPE out1 = cos1 * in1 - sin1 * in2; + OUTPUT_VEC_TYPE 
out2 = cos2 * in2 + sin2 * in1; + + *(OUTPUT_VEC_TYPE*)(output + output_idx + r) = out1; + *(OUTPUT_VEC_TYPE*)(output + output_idx + HALF_ROTARY_NDIMS + r) = out2; +#endif +} +#endif + +#ifdef RotateInterleaved +KERNEL(rope_opt)( + OPTIONAL_SHAPE_INFO_ARG const __global INPUT0_TYPE* input, + const __global INPUT1_TYPE* cos, + const __global INPUT2_TYPE* sin, + __global OUTPUT_TYPE* output) { +#if VEC_SIZE != 1 && VEC_SIZE != 8 && VEC_SIZE != 16 +# error "rope_opt.cl - VEC_SIZE must be one of {1, 8, 16}" +#endif + const uint b = get_global_id(0); + const uint h = get_global_id(1); + const uint p = ((uint)get_global_id(2) * VEC_SIZE) / HALF_ROTARY_NDIMS; + const uint r = 2 * (((uint)get_global_id(2) * VEC_SIZE) % HALF_ROTARY_NDIMS); + + uint input_idx = INPUT0_GET_INDEX(b, h, p, 0); + + uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; + uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0; + uint cos_sin_p = p < INPUT1_SIZE_Y ? p : 0; + +#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS + uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); + + uint cos_idx = cos_sin_idx; + uint sin_idx = cos_sin_idx; +#else + uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); + uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); +#endif + + uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0); + +#if VEC_SIZE == 1 + INPUT0_TYPE in1 = input[input_idx + r]; + INPUT0_TYPE in2 = input[input_idx + r + 1]; + + output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2; + output[output_idx + r + 1] = cos[cos_idx + r + 1] * in2 + sin[sin_idx + r + 1] * in1; +#elif VEC_SIZE == 8 + INPUT_VEC_TYPE inv1 = *(INPUT_VEC_TYPE*)(input + input_idx + r); + INPUT_VEC_TYPE inv2 = *(INPUT_VEC_TYPE*)(input + input_idx + r + VEC_SIZE); + INPUT_VEC_TYPE cosv1 = *(INPUT_VEC_TYPE*)(cos + cos_idx + r); + INPUT_VEC_TYPE sinv1 = *(INPUT_VEC_TYPE*)(sin + sin_idx + r); + INPUT_VEC_TYPE cosv2 = *(INPUT_VEC_TYPE*)(cos + cos_idx + r + VEC_SIZE); + INPUT_VEC_TYPE sinv2 = 
*(INPUT_VEC_TYPE*)(sin + sin_idx + r + VEC_SIZE); + + float8 in1, in2, cos1, sin1, cos2, sin2; + UNPACK_FLOAT_VEC_1(in1, inv1, inv2); + UNPACK_FLOAT_VEC_2(in2, inv1, inv2); + UNPACK_FLOAT_VEC_1(cos1, cosv1, cosv2); + UNPACK_FLOAT_VEC_2(cos2, cosv1, cosv2); + UNPACK_FLOAT_VEC_1(sin1, sinv1, sinv2); + UNPACK_FLOAT_VEC_2(sin2, sinv1, sinv2); + + float8 out1 = cos1 * in1 - sin1 * in2; + float8 out2 = sin2 * in1 + cos2 * in2; + + *(float8*)(output + output_idx + r) = + (float8)(out1.s0, out2.s0, out1.s1, out2.s1, out1.s2, out2.s2, out1.s3, out2.s3); + *(float8*)(output + output_idx + r + VEC_SIZE) = + (float8)(out1.s4, out2.s4, out1.s5, out2.s5, out1.s6, out2.s6, out1.s7, out2.s7); +#elif VEC_SIZE == 16 + INPUT_VEC_TYPE inv1 = *(INPUT_VEC_TYPE*)(input + input_idx + r); + INPUT_VEC_TYPE inv2 = *(INPUT_VEC_TYPE*)(input + input_idx + r + VEC_SIZE); + INPUT_VEC_TYPE cosv1 = *(INPUT_VEC_TYPE*)(cos + cos_idx + r); + INPUT_VEC_TYPE sinv1 = *(INPUT_VEC_TYPE*)(sin + sin_idx + r); + INPUT_VEC_TYPE cosv2 = *(INPUT_VEC_TYPE*)(cos + cos_idx + r + VEC_SIZE); + INPUT_VEC_TYPE sinv2 = *(INPUT_VEC_TYPE*)(sin + sin_idx + r + VEC_SIZE); + + INPUT_VEC_TYPE in1, in2, cos1, sin1, cos2, sin2; + UNPACK_HALF16_VEC_1(in1, inv1, inv2); + UNPACK_HALF16_VEC_2(in2, inv1, inv2); + UNPACK_HALF16_VEC_1(cos1, cosv1, cosv2); + UNPACK_HALF16_VEC_2(cos2, cosv1, cosv2); + UNPACK_HALF16_VEC_1(sin1, sinv1, sinv2); + UNPACK_HALF16_VEC_2(sin2, sinv1, sinv2); + + half16 out1 = cos1 * in1 - sin1 * in2; + half16 out2 = sin2 * in1 + cos2 * in2; + + half16 outputv1, outputv2; + PACK_HALF16_VEC_1(outputv1, out1, out2); + PACK_HALF16_VEC_2(outputv2, out1, out2); + + *(half16*)(output + output_idx + r) = outputv1; + *(half16*)(output + output_idx + r + VEC_SIZE) = outputv2; +#endif +} +#endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl deleted file mode 100644 index d429916b46d69a..00000000000000 --- 
a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "include/fetch_utils.cl" - -#ifdef CHATGLM -KERNEL(rope_ref)( - OPTIONAL_SHAPE_INFO_ARG - const __global INPUT0_TYPE* input, - const __global INPUT1_TYPE* cos_sin, - __global OUTPUT_TYPE* output) -{ -#ifdef SUPPORT_2D_ROPE - const uint p = get_global_id(0) / HEAD_COUNT; - const uint h = get_global_id(0) % HEAD_COUNT; - const uint b = get_global_id(1);//sequence length - const uint rf = get_global_id(2);//max(HALF_ROTARY_NDIMS, HEAD_SIZE - ROTARY_NDIMS) - uint output_idx = OUTPUT_GET_INDEX(p, h, b, 0); -#else - const uint p = get_global_id(0); - const uint b = get_global_id(1); - const uint h = (uint)get_global_id(2) % HEAD_COUNT; - const uint rf = (uint)get_global_id(2) / HEAD_COUNT; - uint output_idx = OUTPUT_GET_INDEX(p, b, h, 0); -#endif - - uint r = rf < HALF_ROTARY_NDIMS ? rf * 2 : 0; - uint f = rf < HEAD_SIZE - ROTARY_NDIMS ? rf * 2 : 0; - - uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0); -#ifdef ENABLE_SLICE - input_idx += SLICED_FROM_START; -#endif - - uint cos_sin_p = p < INPUT1_BATCH_NUM ? p : 0; - uint cos_sin_b = b < INPUT1_FEATURE_NUM ? 
b : 0; - uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, cos_sin_b, 0, 0); - - float cosv = convert_float(cos_sin[cos_sin_idx + r]); - float sinv = convert_float(cos_sin[cos_sin_idx + r + 1]); - - float in1 = convert_float(input[input_idx + r]); - float in2 = convert_float(input[input_idx + r + 1]); - - output[output_idx + r] = TO_OUTPUT_TYPE(cosv * in1 - sinv * in2); - output[output_idx + r + 1] = TO_OUTPUT_TYPE(sinv * in1 + cosv * in2); - -#ifdef ENABLE_IO_COPY - output[output_idx + ROTARY_NDIMS + f] = input[input_idx + ROTARY_NDIMS + f]; - output[output_idx + ROTARY_NDIMS + f + 1] = input[input_idx + ROTARY_NDIMS + f + 1]; -#endif -} -#endif - -#ifdef QWEN -KERNEL(rope_ref)( - OPTIONAL_SHAPE_INFO_ARG - const __global INPUT0_TYPE* input, - const __global INPUT1_TYPE* cos, - const __global INPUT2_TYPE* sin, - __global OUTPUT_TYPE* output) -{ - const uint b = get_global_id(0); - const uint p = get_global_id(1); - const uint h = (uint)get_global_id(2) / HALF_ROTARY_NDIMS; - const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS; - - uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0); -#ifdef ENABLE_SLICE - input_idx += SLICED_FROM_START; -#endif - - uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; - uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0; - uint cos_sin_h = h < INPUT1_SIZE_Y ? 
h : 0; - -#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS - uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0); - - uint cos_idx = cos_sin_idx; - uint sin_idx = cos_sin_idx; -#else - uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0); - uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0); -#endif - - uint output_idx = OUTPUT_GET_INDEX(b, p, h, 0); - - INPUT0_TYPE in1 = input[input_idx + r]; - INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r]; - - output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2; - - output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 + - sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1; -} -#endif - -#ifdef RotateHalf -KERNEL(rope_ref)( - OPTIONAL_SHAPE_INFO_ARG - const __global INPUT0_TYPE* input, - const __global INPUT1_TYPE* cos, - const __global INPUT2_TYPE* sin, -#ifdef ENABLE_GATHER - const __global INPUT3_TYPE* gather, -#endif - __global OUTPUT_TYPE* output) -{ - const uint b = get_global_id(0); - const uint h = get_global_id(1); - const uint p = (uint)get_global_id(2) / HALF_ROTARY_NDIMS; - const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS; - -#if ENABLE_TRANSPOSE - uint input_idx = INPUT0_GET_INDEX(b, p, h, 0); -#else - uint input_idx = INPUT0_GET_INDEX(b, h, p, 0); -#ifdef ENABLE_SLICE - input_idx += SLICED_FROM_START; -#endif -#endif - - uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; - uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0; - uint cos_sin_p = p; -#ifdef ENABLE_GATHER - uint gather_b = b < INPUT3_BATCH_NUM ? b : 0; -#if GATHER_RANK == 4 - uint gather_h = h < INPUT3_FEATURE_NUM ? h : 0; - uint gather_p = p < INPUT3_SIZE_Y ? p : 0; - uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_h, gather_p, 0); -#else - uint gather_p = p < INPUT3_FEATURE_NUM ? p : 0; - uint gather_idx = INPUT3_GET_INDEX(gather_b, gather_p, 0, 0); -#endif - cos_sin_p = gather[gather_idx]; -#endif - cos_sin_p = cos_sin_p < INPUT1_SIZE_Y ? 
cos_sin_p : 0; - -#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS - uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); - - uint cos_idx = cos_sin_idx; - uint sin_idx = cos_sin_idx; -#else - uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); - uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); -#endif - - uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0); - - INPUT0_TYPE in1 = input[input_idx + r]; - INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r]; - - output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2; - - output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 + - sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1; -} -#endif - -#ifdef RotateInterleaved -KERNEL(rope_ref)( - OPTIONAL_SHAPE_INFO_ARG - const __global INPUT0_TYPE* input, - const __global INPUT1_TYPE* cos, - const __global INPUT2_TYPE* sin, - __global OUTPUT_TYPE* output) -{ - const uint b = get_global_id(0); - const uint h = get_global_id(1); - const uint p = (uint)get_global_id(2) / HALF_ROTARY_NDIMS; - const uint r = 2 * ((uint)get_global_id(2) % HALF_ROTARY_NDIMS); - - uint input_idx = INPUT0_GET_INDEX(b, h, p, 0); - - uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; - uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0; - uint cos_sin_p = p < INPUT1_SIZE_Y ? 
p : 0; - -#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS - uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); - - uint cos_idx = cos_sin_idx; - uint sin_idx = cos_sin_idx; -#else - uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); - uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); -#endif - - uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0); - - INPUT0_TYPE in1 = input[input_idx + r]; - INPUT0_TYPE in2 = input[input_idx + r + 1]; - - output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2; - output[output_idx + r + 1] = cos[cos_idx + r + 1] * in2 + sin[sin_idx + r + 1] * in1; -} -#endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp index 98212254be9e3c..13df539c0c5508 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp @@ -70,7 +70,8 @@ RoPEKernelBase::DispatchData RoPEKernelBase::SetDefault(const rope_params& param if (params.is_qwen) { dispatchData.gws = {input.Batch().v, input.Feature().v, - params.head_cnt * std::max(params.rotary_ndims / 2ul, params.head_size - params.rotary_ndims)}; + params.head_cnt * + std::max(params.rotary_ndims / 2ul, params.head_size - params.rotary_ndims)}; } else if (params.is_chatglm) { if (params.support_2d_rope) { // input [batch_size, seq_length] diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.cpp new file mode 100644 index 00000000000000..40cf9094c458c3 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "rope_kernel_opt.h" +#include "kernel_selector_utils.h" 
+#include + +namespace kernel_selector { +ParamsKey RoPEKernelOpt::GetSupportedKey() const { + ParamsKey k; + k.EnableInputDataType(Datatype::INT32); + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDifferentTypes(); + k.EnableDynamicShapesSupport(); + return k; +} + +RoPEKernelBase::DispatchData RoPEKernelOpt::SetDefault(const rope_params& params) const { + DispatchData dispatchData; + const auto& input = params.inputs[0]; + const auto& output = params.outputs[0]; + + std::vector> dims_by_gws = { + {Tensor::DataChannelName::BATCH}, + {Tensor::DataChannelName::FEATURE}, + {Tensor::DataChannelName::Y, Tensor::DataChannelName::X}}; + + size_t vec_size = GetVecSize(params); + if (params.is_qwen) { + auto count = params.head_cnt * std::max(params.rotary_ndims / 2ul, params.head_size - params.rotary_ndims); + dispatchData.gws = {input.Batch().v, input.Feature().v, count / vec_size}; + } else if (params.is_chatglm) { + if (params.support_2d_rope) { + // input [batch_size, seq_length] + // output [batch_size, head_count, seq_length, half_rotary_ndims] + dispatchData.gws = {input.Batch().v * params.head_cnt, + input.Feature().v, + params.rotary_ndims / 2ul / vec_size}; + } else { + dispatchData.gws = {input.Batch().v, + input.Feature().v, + params.head_cnt * (params.rotary_ndims / 2ul / vec_size)}; + } + } else { + dispatchData.gws = {output.Batch().v, output.Feature().v, output.Y().v * params.rotary_ndims / 2ul / vec_size}; + } + + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, + params.engineInfo, + input.GetLayout(), + output.GetLayout(), + dims_by_gws); + + return dispatchData; +} + +JitConstants RoPEKernelOpt::GetJitConstants(const rope_params& params, 
RoPEKernelBase::DispatchData dispatchData) const { + JitConstants jit = RoPEKernelBase::GetJitConstants(params, dispatchData); + + jit.AddConstant(MakeJitConstant("VEC_SIZE", GetVecSize(params))); + return jit; +} + +size_t RoPEKernelOpt::GetVecSize(const rope_params& params) const { + const auto& input = params.inputs[0]; + size_t vec_size = 1; + switch (input.GetDType()) { + case Datatype::F16: + vec_size = 16; + break; + case Datatype::F32: + vec_size = 8; + break; + default: + vec_size = 1; + break; + } + if (params.rotary_ndims % (2 * vec_size) != 0) + vec_size = 1; + + if (params.is_qwen) { + auto count = params.head_cnt * std::max(params.rotary_ndims / 2ul, params.head_size - params.rotary_ndims); + if (count % vec_size != 0) + vec_size = 1; + } + + return vec_size; +} + +KernelsData RoPEKernelOpt::GetKernelsData(const Params& params) const { + return GetCommonKernelsData(params); +} + +KernelsPriority RoPEKernelOpt::GetKernelsPriority(const Params& /*params*/) const { + return FORCE_PRIORITY_8; +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.h new file mode 100644 index 00000000000000..91010628595d64 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_opt.h @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "rope_kernel_base.h" + +namespace kernel_selector { +class RoPEKernelOpt : public RoPEKernelBase { +public: + using Parent = RoPEKernelBase; + RoPEKernelOpt() : RoPEKernelBase("rope_opt") {} + virtual ~RoPEKernelOpt() {} + + KernelsData GetKernelsData(const Params& params) const override; + KernelsPriority GetKernelsPriority(const Params& params) const override; + ParamsKey GetSupportedKey() const override; +protected: + JitConstants GetJitConstants(const rope_params& params, DispatchData 
dispatchData) const override; + DispatchData SetDefault(const rope_params& params) const override; +private: + size_t GetVecSize(const rope_params& params) const; +}; +} // namespace kernel_selector \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp deleted file mode 100644 index 5ec125ef6f083c..00000000000000 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "rope_kernel_ref.h" -#include "kernel_selector_utils.h" -#include - -namespace kernel_selector { -ParamsKey RoPEKernelRef::GetSupportedKey() const { - ParamsKey k; - k.EnableInputDataType(Datatype::INT32); - k.EnableInputDataType(Datatype::F16); - k.EnableInputDataType(Datatype::F32); - k.EnableOutputDataType(Datatype::F16); - k.EnableOutputDataType(Datatype::F32); - k.EnableInputLayout(DataLayout::bfyx); - k.EnableOutputLayout(DataLayout::bfyx); - - k.EnableTensorOffset(); - k.EnableTensorPitches(); - k.EnableBatching(); - k.EnableDifferentTypes(); - k.EnableDynamicShapesSupport(); - return k; -} - -KernelsData RoPEKernelRef::GetKernelsData(const Params& params) const { - return GetCommonKernelsData(params); -} - -KernelsPriority RoPEKernelRef::GetKernelsPriority(const Params& /*params*/) const { - return FORCE_PRIORITY_9; -} -} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h deleted file mode 100644 index ceea2a17720ad1..00000000000000 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_ref.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include 
"rope_kernel_base.h" - -namespace kernel_selector { -class RoPEKernelRef : public RoPEKernelBase { -public: - using Parent = RoPEKernelBase; - RoPEKernelRef() : RoPEKernelBase("rope_ref") {} - virtual ~RoPEKernelRef() {} - - KernelsData GetKernelsData(const Params& params) const override; - KernelsPriority GetKernelsPriority(const Params& params) const override; - ParamsKey GetSupportedKey() const override; -}; -} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp index e5436971b90c09..63063ad2af646b 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_selector.cpp @@ -3,11 +3,11 @@ // #include "rope_kernel_selector.h" -#include "rope_kernel_ref.h" +#include "rope_kernel_opt.h" namespace kernel_selector { rope_kernel_selector::rope_kernel_selector() { - Attach(); + Attach(); } KernelsData rope_kernel_selector::GetBestKernels(const Params& params) const { diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp index 3f7fe91da86d93..98cb257776c388 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -9,49 +9,65 @@ namespace test { INSTANTIATE_TEST_SUITE_P(smoke_RoPETestFlux, RoPETestFlux, - ::testing::Values(ov::test::utils::DEVICE_GPU), + ::testing::Combine( + ::testing::Values(ov::element::f16, ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestFlux::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, RoPETestChatGLMStridedSlice, - 
::testing::Values(ov::test::utils::DEVICE_GPU), + ::testing::Combine( + ::testing::Values(ov::element::f16, ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestChatGLMStridedSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b, RoPETestQwen7bStridedSlice, ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::element::f16, ov::element::f32), ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestQwen7bStridedSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2, RoPETestLlama2StridedSlice, - ::testing::Values(ov::test::utils::DEVICE_GPU), + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestLlama2StridedSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestRotateHalfWithoutTranspose, RoPETestRotateHalfWithoutTranspose, - ::testing::Values(ov::test::utils::DEVICE_GPU), + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestRotateHalfWithoutTranspose::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, RoPETestChatGLMSlice, - ::testing::Values(ov::test::utils::DEVICE_GPU), + ::testing::Combine( + ::testing::Values(ov::element::f16, ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestChatGLMSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b, RoPETestQwen7bSlice, ::testing::Combine(::testing::Values(true, false), + ::testing::Values(ov::element::f16, ov::element::f32), ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestQwen7bSlice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2, RoPETestLlama2Slice, - ::testing::Values(ov::test::utils::DEVICE_GPU), + ::testing::Combine( + ::testing::Values(ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestLlama2Slice::getTestCaseName); INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, 
RoPETestChatGLM2DRoPEStridedSlice, - ::testing::Values(ov::test::utils::DEVICE_GPU), + ::testing::Combine( + ::testing::Values(ov::element::f16, ov::element::f32), + ::testing::Values(ov::test::utils::DEVICE_GPU)), RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName); diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp index 39cdb871710e64..4fd59740cac549 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp @@ -9,21 +9,25 @@ namespace ov { namespace test { -class RoPETestFlux : public SubgraphBaseTest, public testing::WithParamInterface { +using rope_params = std::tuple; +using rope_params_2 = std::tuple; + +class RoPETestFlux : public SubgraphBaseTest, public testing::WithParamInterface { private: std::shared_ptr build_rope_flux(int batch, int seq_length, int num_head, - int ndims); + int ndims, + ov::element::Type element_type); protected: void generate_inputs(const std::vector& targetInputStaticShapes) override; void SetUp() override; public: - static std::string getTestCaseName(const testing::TestParamInfo& obj); + static std::string getTestCaseName(const testing::TestParamInfo& obj); }; -class RoPETestLlama2StridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { +class RoPETestLlama2StridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { private: std::shared_ptr buildROPE_Llama2(int batch, int seq_length, @@ -37,47 +41,48 @@ class RoPETestLlama2StridedSlice : public SubgraphBaseTest, public testing::With void SetUp() override; public: - static std::string getTestCaseName(const testing::TestParamInfo& obj); + static std::string getTestCaseName(const testing::TestParamInfo& obj); }; -class 
RoPETestChatGLMStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { +class RoPETestChatGLMStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { private: - std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims); + std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims, ov::element::Type element_type); protected: ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1); void generate_inputs(const std::vector& targetInputStaticShapes) override; void SetUp() override; public: - static std::string getTestCaseName(const testing::TestParamInfo& obj); + static std::string getTestCaseName(const testing::TestParamInfo& obj); }; -class RoPETestQwen7bStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface> { +class RoPETestQwen7bStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { private: - std::shared_ptr buildROPE_QWen7b(bool specialReshape); + std::shared_ptr buildROPE_QWen7b(bool specialReshape, ov::element::Type element_type); protected: void generate_inputs(const std::vector& targetInputStaticShapes) override; void SetUp() override; public: - static std::string getTestCaseName(const testing::TestParamInfo>& obj); + static std::string getTestCaseName(const testing::TestParamInfo& obj); }; -class RoPETestGPTJStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface> { +class RoPETestGPTJStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { private: std::shared_ptr buildROPE_GPTJ(int num_head, int hidden_dims, int rotary_dims, - bool hasShapeOf); + bool hasShapeOf, + ov::element::Type element_type); protected: void generate_inputs(const std::vector& targetInputStaticShapes) override; void SetUp() override; public: - static std::string getTestCaseName(const testing::TestParamInfo>& obj); + static std::string getTestCaseName(const testing::TestParamInfo& obj); }; -class 
RoPETestRotateHalfWithoutTranspose : public SubgraphBaseTest, public testing::WithParamInterface { +class RoPETestRotateHalfWithoutTranspose : public SubgraphBaseTest, public testing::WithParamInterface { private: ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1); ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims); @@ -91,7 +96,7 @@ class RoPETestRotateHalfWithoutTranspose : public SubgraphBaseTest, public testi void SetUp() override; public: - static std::string getTestCaseName(const testing::TestParamInfo& obj); + static std::string getTestCaseName(const testing::TestParamInfo& obj); }; class RoPETestLlama2Slice : public RoPETestLlama2StridedSlice { @@ -107,14 +112,14 @@ class RoPETestLlama2Slice : public RoPETestLlama2StridedSlice { class RoPETestChatGLMSlice : public RoPETestChatGLMStridedSlice { private: - std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims); + std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims, ov::element::Type element_type); protected: void SetUp() override; }; class RoPETestQwen7bSlice : public RoPETestQwen7bStridedSlice { private: - std::shared_ptr buildROPE_Qwen7b(bool specialReshape); + std::shared_ptr buildROPE_Qwen7b(bool specialReshape, ov::element::Type element_type); protected: void SetUp() override; }; @@ -124,21 +129,22 @@ class RoPETestGPTJSlice : public RoPETestGPTJStridedSlice { std::shared_ptr buildROPE_GPTJ(int num_head, int hidden_dims, int rotary_dims, - bool hasShapeOf); + bool hasShapeOf, + ov::element::Type element_type); protected: void SetUp() override; }; -class RoPETestChatGLM2DRoPEStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { +class RoPETestChatGLM2DRoPEStridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { private: - std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims); + std::shared_ptr buildROPE_ChatGLM(int batch, int head_cnt, int 
rotary_dims, ov::element::Type element_type); protected: ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1); void generate_inputs(const std::vector& targetInputStaticShapes) override; void SetUp() override; public: - static std::string getTestCaseName(const testing::TestParamInfo& obj); + static std::string getTestCaseName(const testing::TestParamInfo& obj); }; } // namespace test diff --git a/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp index 1a078d9b49ebb7..ca8ecf39fb8276 100644 --- a/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp +++ b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp @@ -17,10 +17,11 @@ namespace test { std::shared_ptr RoPETestFlux::build_rope_flux(int batch, int seq_length, int num_head, - int ndims) { - auto x = std::make_shared(ov::element::f32, PartialShape{batch, num_head, seq_length, ndims}); - auto t_cos = std::make_shared(ov::element::f32, PartialShape{1, 1, seq_length, ndims}); - auto t_sin = std::make_shared(ov::element::f32, PartialShape{1, 1, seq_length, ndims}); + int ndims, + ov::element::Type element_type) { + auto x = std::make_shared(element_type, PartialShape{batch, num_head, seq_length, ndims}); + auto t_cos = std::make_shared(element_type, PartialShape{1, 1, seq_length, ndims}); + auto t_sin = std::make_shared(element_type, PartialShape{1, 1, seq_length, ndims}); auto x1_shape = makeConst(element::i64, ov::Shape({5}), {0, num_head, 0, -1, 2}); auto x1 = std::make_shared(x, x1_shape, true); @@ -28,7 +29,7 @@ std::shared_ptr RoPETestFlux::build_rope_flux(int batch, auto split_axis = makeConst(element::i64, ov::Shape(), {-1}); auto split = std::make_shared(x1, split_axis, 2); - auto minus_one = makeConst(element::f32, ov::Shape({}), {-1.0f}); + auto minus_one = makeConst(element_type, ov::Shape({}), {-1.0f}); auto x1_1_neg = 
std::make_shared(split->output(1), minus_one); auto x2 = std::make_shared(OutputVector{x1_1_neg->output(0), split->output(0)}, -1); @@ -68,9 +69,10 @@ void RoPETestFlux::generate_inputs(const std::vector& targetInputStat } void RoPETestFlux::SetUp() { - targetDevice = this->GetParam(); + ov::element::Type element_type; + std::tie(element_type, targetDevice) = this->GetParam(); - const int batch = 1; + const int batch = 128; const int seq_length = 7; const size_t max_position_embeddings = 2048; const size_t ndims = 128; @@ -82,17 +84,18 @@ void RoPETestFlux::SetUp() { {{1, 1, seq_length, ndims}, {{1, 1, seq_length, ndims}}} }; init_input_shapes(input_shapes); - function = build_rope_flux(batch, -1, num_head, ndims); + function = build_rope_flux(batch, -1, num_head, ndims, element_type); } -std::string RoPETestFlux::getTestCaseName(const testing::TestParamInfo& obj) { - std::string targetDevice = obj.param; +std::string RoPETestFlux::getTestCaseName(const testing::TestParamInfo& obj) { + std::string targetDevice; + ov::element::Type element_type; + std::tie(element_type, targetDevice) = obj.param; std::ostringstream result; - result << "targetDevice=" << targetDevice; + result << "targetDevice=" << targetDevice << ",element_type=" << element_type.to_string(); return result.str(); } - ov::OutputVector RoPETestLlama2StridedSlice::makeCosSinCache(int max_position_embeddings, int rotary_ndims) { std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f); @@ -229,9 +232,10 @@ void RoPETestLlama2StridedSlice::generate_inputs(const std::vector& t } void RoPETestLlama2StridedSlice::SetUp() { - targetDevice = this->GetParam(); + ov::element::Type element_type; + std::tie(element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const size_t max_position_embeddings = 2048; const size_t ndims = 128; @@ -242,16 +246,18 @@ void 
RoPETestLlama2StridedSlice::SetUp() { function = buildROPE_Llama2(batch, seq_length, max_position_embeddings, num_head, ndims); } -std::string RoPETestLlama2StridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { - std::string targetDevice = obj.param; +std::string RoPETestLlama2StridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { + ov::element::Type element_type; + std::string targetDevice; + std::tie(element_type, targetDevice) = obj.param; std::ostringstream result; - result << "targetDevice=" << targetDevice; + result << "targetDevice=" << targetDevice << ",element_type=" << element_type.to_string(); return result.str(); } -std::shared_ptr RoPETestChatGLMStridedSlice::buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) { - auto input = std::make_shared(ov::element::f32, PartialShape{-1, batch, 4096 + 256 + 256}); - auto cos_sin_cache = std::make_shared(ov::element::f32, PartialShape{32768, 32, 2}); +std::shared_ptr RoPETestChatGLMStridedSlice::buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims, ov::element::Type element_type) { + auto input = std::make_shared(element_type, PartialShape{-1, batch, 4096 + 256 + 256}); + auto cos_sin_cache = std::make_shared(element_type, PartialShape{32768, 32, 2}); auto position_ids = std::make_shared(ov::element::i32, PartialShape{-1, -1}); auto __module_transformer_index_67_Gather = @@ -371,29 +377,32 @@ void RoPETestChatGLMStridedSlice::generate_inputs(const std::vector& } void RoPETestChatGLMStridedSlice::SetUp() { - targetDevice = this->GetParam(); + ov::element::Type element_type; + std::tie(element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const int num_head = 32; const int rotary_dims = 64; InputShape inpShape = {{-1, batch, 4096 + 256 + 256}, {{seq_length, batch, 4096 + 256 + 256}}}; init_input_shapes({inpShape}); - function = buildROPE_ChatGLM(batch, num_head, rotary_dims); + function = 
buildROPE_ChatGLM(batch, num_head, rotary_dims, element_type); } -std::string RoPETestChatGLMStridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { - std::string targetDevice = obj.param; +std::string RoPETestChatGLMStridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { + std::string targetDevice; + ov::element::Type element_type; + std::tie(element_type, targetDevice) = obj.param; std::ostringstream result; - result << "targetDevice=" << targetDevice; + result << "targetDevice=" << targetDevice << ",element_type=" << element_type.to_string(); return result.str(); } -std::shared_ptr RoPETestQwen7bStridedSlice::buildROPE_QWen7b(bool specialReshape) { - auto input = std::make_shared(ov::element::f32, PartialShape{-1, -1, 4096 + 4096 + 4096}); - auto cos_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); - auto sin_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); +std::shared_ptr RoPETestQwen7bStridedSlice::buildROPE_QWen7b(bool specialReshape, ov::element::Type element_type) { + auto input = std::make_shared(element_type, PartialShape{-1, -1, 4096 + 4096 + 4096}); + auto cos_cache = std::make_shared(element_type, PartialShape{1, -1, 1, 128}); + auto sin_cache = std::make_shared(element_type, PartialShape{1, -1, 1, 128}); auto ListUnpack_389_VariadicSplit = makeOP({input, 2, {4096, 4096, -1}}); auto view_Reshape = @@ -459,7 +468,7 @@ std::shared_ptr RoPETestQwen7bStridedSlice::buildROPE_QWen7b(bool spe 1, }), {-1}); - auto Constant_296840 = makeOP({Constant_296840_compressed}, {{"destination_type", "f32"}}); + auto Constant_296840 = makeOP({Constant_296840_compressed}, {{"destination_type", element_type.to_string()}}); auto neg_Multiply_499 = makeOP({ListUnpack_496_Squeeze_0, Constant_296840}, {{"auto_broadcast", "numpy"}}); auto ListUnpack_496_Squeeze = makeOP({ListUnpack_496_Split->output(0), -2}); @@ -507,32 +516,34 @@ void RoPETestQwen7bStridedSlice::generate_inputs(const std::vector& t void 
RoPETestQwen7bStridedSlice::SetUp() { bool specialReshape; - std::tie(specialReshape, targetDevice) = this->GetParam(); - const int batch = 2; + ov::element::Type element_type; + std::tie(specialReshape, element_type, targetDevice) = this->GetParam(); + const int batch = 128; const int seq_length = 7; InputShape inpShape = {{batch, -1, 4096 + 4096 + 4096}, {{batch, seq_length, 4096 + 4096 + 4096}}}; init_input_shapes({inpShape}); - function = buildROPE_QWen7b(specialReshape); + function = buildROPE_QWen7b(specialReshape, element_type); } -std::string RoPETestQwen7bStridedSlice::getTestCaseName( - const testing::TestParamInfo>& obj) { +std::string RoPETestQwen7bStridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { bool specialReshape; + ov::element::Type element_type; std::string targetDevice; - std::tie(specialReshape, targetDevice) = obj.param; + std::tie(specialReshape, element_type, targetDevice) = obj.param; std::ostringstream result; result << "specialReshape=" << specialReshape << "_" - << "targetDevice=" << targetDevice; + << "targetDevice=" << targetDevice << "_element_type=" << element_type.to_string(); return result.str(); } std::shared_ptr RoPETestGPTJStridedSlice::buildROPE_GPTJ(int num_head, int hidden_dims, int rotary_dims, - bool hasShapeOf) { + bool hasShapeOf, + ov::element::Type element_type) { auto int32_max = std::numeric_limits::max(); - auto input = std::make_shared(ov::element::f32, PartialShape{-1, -1, num_head, hidden_dims}); - auto sincos = std::make_shared(ov::element::f32, PartialShape{-1, -1, rotary_dims}); + auto input = std::make_shared(element_type, PartialShape{-1, -1, num_head, hidden_dims}); + auto sincos = std::make_shared(element_type, PartialShape{-1, -1, rotary_dims}); auto slice_Slice_965 = makeOP({input, {0, 0, 0, 0}, {0, 0, 0, rotary_dims}, {1, 1, 1, 1}}, {{"begin_mask", {1, 1, 1, 0}}, @@ -555,7 +566,7 @@ std::shared_ptr RoPETestGPTJStridedSlice::buildROPE_GPTJ(int num_head } auto const_idx = 
makeConst(ov::element::i32, ov::Shape({static_cast(rotary_dims)}), gather_idx); - auto constant_155588 = makeConst(element::f32, + auto constant_155588 = makeConst(element_type, ov::Shape({ 1, 1, @@ -630,11 +641,11 @@ void RoPETestGPTJStridedSlice::generate_inputs(const std::vector& tar inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache}); } -std::string RoPETestGPTJStridedSlice::getTestCaseName( - const testing::TestParamInfo>& obj) { +std::string RoPETestGPTJStridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { bool hasShapeOf; + ov::element::Type element_type; std::string targetDevice; - std::tie(hasShapeOf, targetDevice) = obj.param; + std::tie(hasShapeOf, element_type, targetDevice) = obj.param; std::ostringstream result; result << "hasShapeOf=" << hasShapeOf << "_" << "targetDevice=" << targetDevice; @@ -643,9 +654,10 @@ std::string RoPETestGPTJStridedSlice::getTestCaseName( void RoPETestGPTJStridedSlice::SetUp() { bool hasShapeOf; - std::tie(hasShapeOf, targetDevice) = this->GetParam(); + ov::element::Type element_type; + std::tie(hasShapeOf, element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const int num_head = 16; const int hidden_dims = 256; @@ -654,7 +666,7 @@ void RoPETestGPTJStridedSlice::SetUp() { InputShape input = {{batch, seq_length, num_head, hidden_dims}, {{batch, seq_length, num_head, hidden_dims}}}; InputShape sincos = {{batch, seq_length, rotary_dims}, {{batch, seq_length, rotary_dims}}}; init_input_shapes({input, sincos}); - function = buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf); + function = buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf, element_type); } ov::OutputVector RoPETestRotateHalfWithoutTranspose::makeCosSinCache(int max_position_embeddings, int rotary_ndims) { @@ -791,9 +803,10 @@ void RoPETestRotateHalfWithoutTranspose::generate_inputs(const std::vectorGetParam(); + ov::element::Type 
element_type; + std::tie(element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const size_t max_position_embeddings = 2048; const size_t ndims = 128; @@ -804,17 +817,20 @@ void RoPETestRotateHalfWithoutTranspose::SetUp() { function = buildROPE_RotateHalfWithoutTranspose(batch, seq_length, max_position_embeddings, num_head, ndims); } -std::string RoPETestRotateHalfWithoutTranspose::getTestCaseName(const testing::TestParamInfo& obj) { - std::string targetDevice = obj.param; +std::string RoPETestRotateHalfWithoutTranspose::getTestCaseName(const testing::TestParamInfo& obj) { + ov::element::Type element_type; + std::string targetDevice; + std::tie(element_type, targetDevice) = obj.param; std::ostringstream result; - result << "targetDevice=" << targetDevice; + result << "targetDevice=" << targetDevice <<"_element_type=" << element_type.to_string(); return result.str(); } void RoPETestLlama2Slice::SetUp() { - targetDevice = this->GetParam(); + ov::element::Type element_type; + std::tie(element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const size_t max_position_embeddings = 2048; const size_t ndims = 128; @@ -879,21 +895,22 @@ std::shared_ptr RoPETestLlama2Slice::buildROPE_Llama2(int batch, } void RoPETestChatGLMSlice::SetUp() { - targetDevice = this->GetParam(); + ov::element::Type element_type; + std::tie(element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const int num_head = 32; const int rotary_dims = 64; InputShape inpShape = {{-1, batch, 4096 + 256 + 256}, {{seq_length, batch, 4096 + 256 + 256}}}; init_input_shapes({inpShape}); - function = RoPETestChatGLMSlice::buildROPE_ChatGLM(batch, num_head, rotary_dims); + function = RoPETestChatGLMSlice::buildROPE_ChatGLM(batch, num_head, rotary_dims, element_type); } -std::shared_ptr 
RoPETestChatGLMSlice::buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) { - auto input = std::make_shared(ov::element::f32, PartialShape{-1, batch, 4096 + 256 + 256}); - auto cos_sin_cache = std::make_shared(ov::element::f32, PartialShape{32768, 32, 2}); +std::shared_ptr RoPETestChatGLMSlice::buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims, ov::element::Type element_type) { + auto input = std::make_shared(element_type, PartialShape{-1, batch, 4096 + 256 + 256}); + auto cos_sin_cache = std::make_shared(element_type, PartialShape{32768, 32, 2}); auto position_ids = std::make_shared(ov::element::i32, PartialShape{-1, -1}); auto __module_transformer_index_67_Gather = @@ -962,18 +979,19 @@ std::shared_ptr RoPETestChatGLMSlice::buildROPE_ChatGLM(int batch, in void RoPETestQwen7bSlice::SetUp() { bool specialReshape; - std::tie(specialReshape, targetDevice) = this->GetParam(); - const int batch = 2; + ov::element::Type element_type; + std::tie(specialReshape, element_type, targetDevice) = this->GetParam(); + const int batch = 128; const int seq_length = 7; InputShape inpShape = {{batch, -1, 4096 + 4096 + 4096}, {{batch, seq_length, 4096 + 4096 + 4096}}}; init_input_shapes({inpShape}); - function = RoPETestQwen7bSlice::buildROPE_Qwen7b(specialReshape); + function = RoPETestQwen7bSlice::buildROPE_Qwen7b(specialReshape, element_type); } -std::shared_ptr RoPETestQwen7bSlice::buildROPE_Qwen7b(bool specialReshape) { - auto input = std::make_shared(ov::element::f32, PartialShape{-1, -1, 4096 + 4096 + 4096}); - auto cos_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); - auto sin_cache = std::make_shared(ov::element::f32, PartialShape{1, -1, 1, 128}); +std::shared_ptr RoPETestQwen7bSlice::buildROPE_Qwen7b(bool specialReshape, ov::element::Type element_type) { + auto input = std::make_shared(element_type, PartialShape{-1, -1, 4096 + 4096 + 4096}); + auto cos_cache = std::make_shared(element_type, PartialShape{1, -1, 1, 128}); + auto 
sin_cache = std::make_shared(element_type, PartialShape{1, -1, 1, 128}); auto ListUnpack_389_VariadicSplit = makeOP({input, 2, {4096, 4096, -1}}); auto view_Reshape = @@ -1013,7 +1031,7 @@ std::shared_ptr RoPETestQwen7bSlice::buildROPE_Qwen7b(bool specialRes 1, }), {-1}); - auto Constant_296840 = makeOP({Constant_296840_compressed}, {{"destination_type", "f32"}}); + auto Constant_296840 = makeOP({Constant_296840_compressed}, {{"destination_type", element_type.to_string()}}); auto neg_Multiply_499 = makeOP({ListUnpack_496_Squeeze_0, Constant_296840}, {{"auto_broadcast", "numpy"}}); auto ListUnpack_496_Squeeze = makeOP({ListUnpack_496_Split->output(0), -2}); @@ -1026,9 +1044,10 @@ std::shared_ptr RoPETestQwen7bSlice::buildROPE_Qwen7b(bool specialRes void RoPETestGPTJSlice::SetUp() { bool hasShapeOf; - std::tie(hasShapeOf, targetDevice) = this->GetParam(); + ov::element::Type element_type; + std::tie(hasShapeOf, element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const int num_head = 16; const int hidden_dims = 256; @@ -1037,16 +1056,17 @@ void RoPETestGPTJSlice::SetUp() { InputShape input = {{batch, seq_length, num_head, hidden_dims}, {{batch, seq_length, num_head, hidden_dims}}}; InputShape sincos = {{batch, seq_length, rotary_dims}, {{batch, seq_length, rotary_dims}}}; init_input_shapes({input, sincos}); - function = RoPETestGPTJSlice::buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf); + function = RoPETestGPTJSlice::buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf, element_type); } std::shared_ptr RoPETestGPTJSlice::buildROPE_GPTJ(int num_head, int hidden_dims, int rotary_dims, - bool hasShapeOf) { + bool hasShapeOf, + ov::element::Type element_type) { auto int32_max = std::numeric_limits::max(); - auto input = std::make_shared(ov::element::f32, PartialShape{-1, -1, num_head, hidden_dims}); - auto sincos = std::make_shared(ov::element::f32, PartialShape{-1, -1, 
rotary_dims}); + auto input = std::make_shared(element_type, PartialShape{-1, -1, num_head, hidden_dims}); + auto sincos = std::make_shared(element_type, PartialShape{-1, -1, rotary_dims}); auto slice_Slice_965 = makeOP({input, {0}, {rotary_dims}, {1}, {3}}); slice_Slice_965->set_friendly_name("slice_Slice_965"); @@ -1064,7 +1084,7 @@ std::shared_ptr RoPETestGPTJSlice::buildROPE_GPTJ(int num_head, } auto const_idx = makeConst(ov::element::i32, ov::Shape({static_cast(rotary_dims)}), gather_idx); - auto constant_155588 = makeConst(element::f32, + auto constant_155588 = makeConst(element_type, ov::Shape({ 1, 1, @@ -1107,9 +1127,9 @@ std::shared_ptr RoPETestGPTJSlice::buildROPE_GPTJ(int num_head, return std::make_shared(model_output, ov::ParameterVector{input, sincos}); } -std::shared_ptr RoPETestChatGLM2DRoPEStridedSlice::buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims) { - auto input = std::make_shared(ov::element::f32, PartialShape{batch, -1, 4096 + 256 + 256}); - auto cos_sin_cache = std::make_shared(ov::element::f32, PartialShape{32768, 32, 2}); +std::shared_ptr RoPETestChatGLM2DRoPEStridedSlice::buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims, ov::element::Type element_type) { + auto input = std::make_shared(element_type, PartialShape{batch, -1, 4096 + 256 + 256}); + auto cos_sin_cache = std::make_shared(element_type, PartialShape{32768, 32, 2}); auto position_ids = std::make_shared(ov::element::i32, PartialShape{-1, -1}); auto __module_transformer_index_67_Gather = @@ -1212,22 +1232,25 @@ void RoPETestChatGLM2DRoPEStridedSlice::generate_inputs(const std::vectorGetParam(); + ov::element::Type element_type; + std::tie(element_type, targetDevice) = this->GetParam(); - const int batch = 2; + const int batch = 128; const int seq_length = 7; const int num_head = 32; const int rotary_dims = 64; InputShape inpShape = {{batch, -1, 4096 + 256 + 256}, {{batch, seq_length, 4096 + 256 + 256}}}; init_input_shapes({inpShape}); - function = 
buildROPE_ChatGLM(-1, num_head, rotary_dims); + function = buildROPE_ChatGLM(-1, num_head, rotary_dims, element_type); } -std::string RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { - std::string targetDevice = obj.param; +std::string RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName(const testing::TestParamInfo& obj) { + ov::element::Type element_type; + std::string targetDevice; + std::tie(element_type, targetDevice) = obj.param; std::ostringstream result; - result << "targetDevice=" << targetDevice; + result << "targetDevice=" << targetDevice << "_element_type=" << element_type.to_string(); return result.str(); }