Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Fix MLP segment fault if a new larger scratch created #25930

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,21 @@ void GateUpCombine::generate() {
const auto zmm_up = zmm0;
const auto ymm_dst = ymm5;

// when save_state is false, push/pop will not be generated.
auto injector = std::make_shared<jit_uni_eltwise_injector_f32<dnnl::impl::cpu::x64::avx512_core>>(
this,
m_act_alg,
1.f,
1.0f,
1.f,
false, // save_state, state will be saved in our function
true, // save_state, true due to additional r15 is used.
Xbyak::Reg64(Xbyak::Operand::R10), // p_table
Xbyak::Opmask(1), // k_mask
true, // is_fwd
false, // use_dst
false, // preserve_vmm
false); // preserve_p_table
false); // preserve_p_table, false due to it will be saved in the function

push(r10);
xor_(loop_i, loop_i);
injector->load_table_addr();

Expand All @@ -317,6 +317,7 @@ void GateUpCombine::generate() {
cmp(loop_i, BN);
jl(loop_begin, T_NEAR);

pop(r10);
ret();

injector->prepare_table();
Expand Down
13 changes: 9 additions & 4 deletions src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ struct LLMMLP::Impl {
const LLMMLPNode::Config m_config;
DnnlScratchPadPtr m_scrachPad;
MemoryPtr m_scratchMem;
uint8_t* m_scratch_base = nullptr;

Linear gate_up;
Linear down;
Expand Down Expand Up @@ -200,7 +201,11 @@ struct LLMMLP::Impl {
}

void setM(int M) {
if (m_M < M) {
uint8_t* cur_scratch_base = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing the memory pointer is ambiguous. The condition behind is that the scratch buffer isn't big enough. Could you check why the scratch buffer is not big enough ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is not big enough scratch here, it comes from situation:
1, each mlp layer will create a Memory object using a same scratch size such as 4M, then uses the scratch pointer to initialize the class member m_actUp, m_tempC, actually because the size is same, the pointer is same too.
2, some layers such as last Matmul lm_head may need a bigger scratch then scratch is re-created and the pointers used in m_actUp, m_tempC become invalid.

Here use pointer to detect the condition: changed scratch.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does different MLP layers share same LLMMLP executor even if they have different M, K, N inside ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no executor, the jit kernel which is a global variable does not use M, N, K to generate the kernel code.

if (m_scratchMem)
cur_scratch_base = m_scratchMem->getDataAs<uint8_t>();
// new M larger than previous or the scratch pointer is changed after the following allocation
if (m_M < M || cur_scratch_base != m_scratch_base) {
size_t total_scratch_size = M * m_N * sizeof(ov::bfloat16);
std::vector<size_t> scratch_offsets;
std::vector<size_t> scratch_C_sizes;
Expand All @@ -214,11 +219,11 @@ struct LLMMLP::Impl {
auto newMemDesc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::u8, Shape{total_scratch_size});
m_scratchMem = m_scrachPad->createScratchPadMem(newMemDesc);

auto* scratch_base = m_scratchMem->getDataAs<uint8_t>();
m_actUp.resize<ov::bfloat16>({static_cast<size_t>(M), static_cast<size_t>(m_N)}, reinterpret_cast<ov::bfloat16*>(scratch_base));
m_scratch_base = m_scratchMem->getDataAs<uint8_t>();
m_actUp.resize<ov::bfloat16>({static_cast<size_t>(M), static_cast<size_t>(m_N)}, reinterpret_cast<ov::bfloat16*>(m_scratch_base));

for (size_t ithr = 0; ithr < m_tempC.size(); ithr++) {
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(scratch_base + scratch_offsets[ithr]));
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(m_scratch_base + scratch_offsets[ithr]));
}
m_M = M;
}
Expand Down
11 changes: 8 additions & 3 deletions src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct QKVProjection::Impl {
QKVProjection * m_node;
DnnlScratchPadPtr m_scrachPad;
MemoryPtr m_scratchMem;
uint8_t* m_scratch_base = nullptr;
int m_M;

Impl(QKVProjection * pnode, DnnlScratchPadPtr scrachPad) : m_node(pnode), m_scrachPad(scrachPad) {
Expand Down Expand Up @@ -116,7 +117,11 @@ struct QKVProjection::Impl {
}

void setM(int M) {
if (m_M < M) {
uint8_t* cur_scratch_base = nullptr;
if (m_scratchMem)
zhangYiIntel marked this conversation as resolved.
Show resolved Hide resolved
cur_scratch_base = m_scratchMem->getDataAs<uint8_t>();
// new M larger than previous or the scratch pointer is changed after the following allocation
if (m_M < M || cur_scratch_base != m_scratch_base) {
size_t total_scratch_size = 0;
std::vector<size_t> scratch_offsets;
std::vector<size_t> scratch_C_sizes;
Expand All @@ -130,9 +135,9 @@ struct QKVProjection::Impl {
auto newMemDesc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::u8, Shape{total_scratch_size});
m_scratchMem = m_scrachPad->createScratchPadMem(newMemDesc);

auto* scratch_base = m_scratchMem->getDataAs<uint8_t>();
m_scratch_base = m_scratchMem->getDataAs<uint8_t>();
for (size_t ithr = 0; ithr < m_tempC.size(); ithr++) {
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(scratch_base + scratch_offsets[ithr]));
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(m_scratch_base + scratch_offsets[ithr]));
}

m_M = M;
Expand Down
Loading