Skip to content

Commit

Permalink
[CPU]support glm4 rope (#27094)
Browse files Browse the repository at this point in the history
### Details:
 - *Support the RoPE kernel of GLM4*
- *The input data order has changed from (**[seq_length, batch, 4608]**)
in **ChatGLM3** to (**[batch, seq_length, 4608]**) in **ChatGLM4**.
Within the RoPE process, the data order changes from (**[seq_length, batch,
head_count, head_size]**) to (**[batch, head_count, seq_length,
head_size]**) via a permute operation added in **ChatGLM4**.*
- *The RoPE cache data order has changed from (**[seq_length, batch,
head_count, 2]**) in **ChatGLM3** to (**[batch, head_count, seq_length,
2]**) in **ChatGLM4**.*
- *Consequently, the output of RoPE has also changed from
(**[seq_length, batch, head_count, head_size]**) in **ChatGLM3** to
(**[batch, head_count, seq_length, head_size]**) in **ChatGLM4**.*
- *Due to these changes, a new RoPE pattern matcher is required, distinct
from the existing ChatGLM pattern matching. Additionally, new kernels
need to be added to accommodate these changes.*

### Tickets:
 - *ticket-id*
  • Loading branch information
zhangYiIntel authored and pull[bot] committed Nov 27, 2024
1 parent 6381b01 commit 1540832
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 27 deletions.
83 changes: 58 additions & 25 deletions src/plugins/intel_cpu/src/nodes/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,34 +244,67 @@ struct RoPE::RoPEExecutorChatGLM : public RoPE::Executor {
if (m_config.slice_stop - m_config.slice_start > 0) {
t_src = t_src.slice(2, m_config.slice_start, m_config.slice_stop);
}
auto seq_len = t_src.size(0);
auto batch_size = t_src.size(1);

auto head_cnt = m_config.head_cnt;
auto head_size = m_config.head_size;

auto rotary_dims = m_config.rotary_ndims;

parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) {
auto* src = t_src.ptr<T>(p, b, h * head_size);
// [length, batch_size, ndims//2, 2]
auto* cos_sin = &t_cos_sin.at<float>({p, b, 0, 0}, true);
auto* dst = t_dst.ptr<T>(p, b, h, 0);
if (m_config.support_2d_rope) {
// src [batch, length, H x S]
auto seq_len = t_src.size(1);
auto batch_size = t_src.size(0);

auto head_cnt = m_config.head_cnt;
auto head_size = m_config.head_size;

auto rotary_dims = m_config.rotary_ndims;

parallel_for3d(batch_size, head_cnt, seq_len, [&](size_t b, size_t h, size_t p) {
// src [batch, length, H x S]
auto* src = t_src.ptr<T>(b, p, h * head_size);
// [batch_size, length, ndims//2, 2]
auto* cos_sin = &t_cos_sin.at<float>({b, p, 0, 0}, true);
auto* dst = t_dst.ptr<T>(b, h, p, 0);

if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr);
} else {
size_t i = 0;
for (; i < rotary_dims; i += 2) {
auto cosv = cos_sin[i];
auto sinv = cos_sin[i + 1];
dst[i] = cosv * src[i] - sinv * src[i + 1];
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
}
}

if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr);
} else {
size_t i = 0;
for (; i < rotary_dims; i += 2) {
auto cosv = cos_sin[i];
auto sinv = cos_sin[i + 1];
dst[i] = cosv * src[i] - sinv * src[i + 1];
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
});
} else {
auto seq_len = t_src.size(0);
auto batch_size = t_src.size(1);

auto head_cnt = m_config.head_cnt;
auto head_size = m_config.head_size;

auto rotary_dims = m_config.rotary_ndims;

parallel_for3d(seq_len, batch_size, head_cnt, [&](size_t p, size_t b, size_t h) {
auto* src = t_src.ptr<T>(p, b, h * head_size);
// [length, batch_size, ndims//2, 2]
auto* cos_sin = &t_cos_sin.at<float>({p, b, 0, 0}, true);
auto* dst = t_dst.ptr<T>(p, b, h, 0);

if (m_rotaryKernel) {
execJitKernel(m_rotaryKernel, src, dst, cos_sin, nullptr);
} else {
size_t i = 0;
for (; i < rotary_dims; i += 2) {
auto cosv = cos_sin[i];
auto sinv = cos_sin[i + 1];
dst[i] = cosv * src[i] - sinv * src[i + 1];
dst[i + 1] = sinv * src[i] + cosv * src[i + 1];
}
}
}

memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
});
memcpy(dst + rotary_dims, src + rotary_dims, (head_size - rotary_dims) * sizeof(T));
});
}
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -835,8 +835,8 @@ void Transformations::PostLpt() {
// Execute before snippets. Otherwise FQ will be converted to Subgraph
CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn);

CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion);
CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true);
CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true);
CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion);

// MLP & QKV fusion optimizations is focused on throughput, only enabled on AMX-bf16 & LLM serving use cases.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJSlice,
::testing::Combine(::testing::Values(true, false),
::testing::Values(ov::test::utils::DEVICE_CPU)),
RoPETestGPTJSlice::getTestCaseName);

// Instantiates the RoPETestChatGLM2DRoPEStridedSlice suite under the
// smoke_RoPETestChatGLM prefix. The only parameter value generated is the CPU
// device; test-case names come from the suite's own getTestCaseName.
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
RoPETestChatGLM2DRoPEStridedSlice,
::testing::Values(ov::test::utils::DEVICE_CPU),
RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName);

} // namespace test
} // namespace ov

0 comments on commit 1540832

Please sign in to comment.