diff --git a/src/plugins/intel_cpu/src/nodes/rope.cpp b/src/plugins/intel_cpu/src/nodes/rope.cpp index ac95b0f31213de..f089b67a122beb 100644 --- a/src/plugins/intel_cpu/src/nodes/rope.cpp +++ b/src/plugins/intel_cpu/src/nodes/rope.cpp @@ -244,34 +244,67 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor { if (m_config.slice_stop - m_config.slice_start > 0) { t_src = t_src.slice(2, m_config.slice_start, m_config.slice_stop); } - auto seq_len = t_src.size(0); - auto batch_size = t_src.size(1); - - auto head_cnt = m_config.head_cnt; - auto head_size = m_config.head_size; - - auto rotary_dims = m_config.rotary_ndims; - - parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) { - auto* src = t_src.ptr(p, b, h * head_size); - // [length, batch_size, ndims//2, 2] - auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true); - auto* dst = t_dst.ptr(p, b, h, 0); + if (m_config.support_2d_rope) { + // src [batch, length, H x S] + auto seq_len = t_src.size(1); + auto batch_size = t_src.size(0); + + auto head_cnt = m_config.head_cnt; + auto head_size = m_config.head_size; + + auto rotary_dims = m_config.rotary_ndims; + + parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) { + // src [batch, length, H x S] + auto* src = t_src.ptr(b, p, h * head_size); + // [batch_size, length, ndims//2, 2] + auto* cos_sin = &t_cos_sin.at({b, p, 0, 0}, true); + auto* dst = t_dst.ptr(b, h, p, 0); + + if (m_rotaryKernel) { + execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr); + } else { + size_t i = 0; + for (; i < rotary_dims; i += 2) { + auto cosv = cos_sin[i]; + auto sinv = cos_sin[i + 1]; + dst[i] = cosv * src[i] - sinv * src[i + 1]; + dst[i + 1] = sinv * src[i] + cosv * src[i + 1]; + } + } - if (m_rotaryKernel) { - execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr); - } else { - size_t i = 0; - for (; i < rotary_dims; i += 2) { - auto cosv = cos_sin[i]; - auto sinv = cos_sin[i + 1]; - dst[i] = cosv * src[i] - sinv * src[i + 1]; - dst[i + 1] = sinv * src[i] + cosv * src[i + 1]; + memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T)); + }); + } else { + auto seq_len = t_src.size(0); + auto batch_size = t_src.size(1); + + auto head_cnt = m_config.head_cnt; + auto head_size = m_config.head_size; + + auto rotary_dims = m_config.rotary_ndims; + + parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) { + auto* src = t_src.ptr(p, b, h * head_size); + // [length, batch_size, ndims//2, 2] + auto* cos_sin = &t_cos_sin.at({p, b, 0, 0}, true); + auto* dst = t_dst.ptr(p, b, h, 0); + + if (m_rotaryKernel) { + execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr); + } else { + size_t i = 0; + for (; i < rotary_dims; i += 2) { + auto cosv = cos_sin[i]; + auto sinv = cos_sin[i + 1]; + dst[i] = cosv * src[i] - sinv * src[i + 1]; + dst[i + 1] = sinv * src[i] + cosv * src[i + 1]; + } } - } - memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T)); - }); + memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T)); + }); + } } }; diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 0e683482a97934..04808baaebec54 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -835,8 +835,8 @@ void Transformations::PostLpt() { // Execute before snippets. Otherwise FQ will be converted to Subgraph CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn); - CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion); - CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion); + CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true); + CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true); CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion); // MLP & QKV fusion optimizations is focused on throughput, only enabled on AMX-bf16 & LLM serving use cases. diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp index 7fd916e4300768..8cd8707e047878 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -50,5 +50,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJSlice, ::testing::Combine(::testing::Values(true, false), ::testing::Values(ov::test::utils::DEVICE_CPU)), RoPETestGPTJSlice::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, + RoPETestChatGLM2DRoPEStridedSlice, + ::testing::Values(ov::test::utils::DEVICE_CPU), + RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName); + } // namespace test } // namespace ov