From 1540832d6cc528703dd63e5d69c8e33bc866ec80 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Thu, 17 Oct 2024 15:43:10 +0800 Subject: [PATCH] [CPU]support glm4 rope (#27094) ### Details: - *Support Rope kernel of GLM4* - *the input data order has changed from (**[seq_length, batch, 4608]**) in **ChatGLM3** to (**[batch, seq_length, 4608]**) in **ChatGLM4**. Within RoPE process, the data order changes from (**[seq_length, batch, head_count, head_size]**) to (**[batch, head_count, seq_length, head_size]**) by permute operation added in **ChatGLM4**.* - *the RoPE cache data order has changed from (**[seq_length, batch, head_count, 2]**) in ChatGLM3 to (**[batch, head_count, seq_length, 2]**) in **ChatGLM4**.* - *Consequently, the output of RoPE has also changed from (**[seq_length, batch, head_count, head_size]**) in **ChatGLM3** to (**[batch, head_count, seq_length, head_size]**) in **ChatGLM4*** - *Due to these changes, the RoPE pattern matching needs to create something new, something different from what already existed ChatGLM pattern matching. Additionally, new kernels need to be added to accommodate these changes* ### Tickets: - *ticket-id* --- src/plugins/intel_cpu/src/nodes/rope.cpp | 83 +++++++++++++------ .../transformation_pipeline.cpp | 4 +- .../subgraph_tests/rotary_pos_emb.cpp | 6 ++ 3 files changed, 66 insertions(+), 27 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/rope.cpp b/src/plugins/intel_cpu/src/nodes/rope.cpp index ac95b0f31213de..f089b67a122beb 100644 --- a/src/plugins/intel_cpu/src/nodes/rope.cpp +++ b/src/plugins/intel_cpu/src/nodes/rope.cpp @@ -244,34 +244,67 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor { if (m_config.slice_stop - m_config.slice_start > 0) { t_src = t_src.slice(2, m_config.slice_start, m_config.slice_stop); } - auto seq_len = t_src.size(0); - auto batch_size = t_src.size(1); - - auto head_cnt = m_config.head_cnt; - auto head_size = m_config.head_size; - - auto rotary_dims = m_config.rotary_ndims; - - parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) { - auto* src = t_src.ptr(p, b, h * head_size); - // [length, batch_size, ndims//2, 2] - auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true); - auto* dst = t_dst.ptr(p, b, h, 0); + if (m_config.support_2d_rope) { + // src [batch, length, H x S] + auto seq_len = t_src.size(1); + auto batch_size = t_src.size(0); + + auto head_cnt = m_config.head_cnt; + auto head_size = m_config.head_size; + + auto rotary_dims = m_config.rotary_ndims; + + parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) { + // src [batch, length, H x S] + auto* src = t_src.ptr(b, p, h * head_size); + // [batch_size, length, ndims//2, 2] + auto* cos_sin = &t_cos_sin.at({b, p, 0, 0}, true); + auto* dst = t_dst.ptr(b, h, p, 0); + + if (m_rotaryKernel) { + execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr); + } else { + size_t i = 0; + for (; i < rotary_dims; i += 2) { + auto cosv = cos_sin[i]; + auto sinv = cos_sin[i + 1]; + dst[i] = cosv * src[i] - sinv * src[i + 1]; + dst[i + 1] = sinv * src[i] + cosv * src[i + 1]; + } + } - if (m_rotaryKernel) { - execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr); - } else { - size_t i = 0; - for (; i < rotary_dims; i += 2) { - auto cosv = cos_sin[i]; - auto sinv = cos_sin[i + 1]; - dst[i] = cosv * src[i] - sinv * src[i + 1]; - dst[i + 1] = sinv * src[i] + cosv * src[i + 1]; + memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T)); + }); + } else { + auto seq_len = t_src.size(0); + auto batch_size = t_src.size(1); + + auto head_cnt = m_config.head_cnt; + auto head_size = m_config.head_size; + + auto rotary_dims = m_config.rotary_ndims; + + parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) { + auto* src = t_src.ptr(p, b, h * head_size); + // [length, batch_size, ndims//2, 2] + auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true); + auto* dst = t_dst.ptr(p, b, h, 0); + + if (m_rotaryKernel) { + execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr); + } else { + size_t i = 0; + for (; i < rotary_dims; i += 2) { + auto cosv = cos_sin[i]; + auto sinv = cos_sin[i + 1]; + dst[i] = cosv * src[i] - sinv * src[i + 1]; + dst[i + 1] = sinv * src[i] + cosv * src[i + 1]; + } } - } - memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T)); - }); + memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T)); + }); + } } }; diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 0e683482a97934..04808baaebec54 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -835,8 +835,8 @@ void Transformations::PostLpt() { // Execute before snippets. Otherwise FQ will be converted to Subgraph CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn); - CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion); - CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion); + CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true); + CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true); CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion); // MLP & QKV fusion optimizations is focused on throughput, only enabled on AMX-bf16 & LLM serving use cases. diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp index 7fd916e4300768..8cd8707e047878 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -50,5 +50,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJSlice, ::testing::Combine(::testing::Values(true, false), ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestGPTJSlice::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, + RoPETestChatGLM2DRoPEStridedSlice, + ::testing::Values(ov::test::utils::DEVICE_CPU), + RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName); + } // namespace test } // namespace ov