Skip to content

Commit

Permalink
[CPU] Fix MLP segment fault if a new larger scratch created (openvino…
Browse files Browse the repository at this point in the history
…toolkit#25930)

### Details:
 - *Fix an MLP segmentation fault that may be caused by:*
   -  if a new, larger scratch buffer is created, the cached one becomes invalid
- The SiLU injector in
[master](https://github.com/openvinotoolkit/oneDNN/blame/6b99866a4531e38a74d1de36d5b366c54c5e6cc3/src/cpu/x64/injectors/jit_uni_eltwise_injector.cpp#L175-L188)
uses r15, but the register is currently not protected. The injector behavior changed in
master and does not affect releases/2024/3.
 - *...*

### Tickets:
 - *[148743](https://jira.devtools.intel.com/browse/CVS-148743)*
  • Loading branch information
luo-cheng2021 authored and mory91 committed Aug 13, 2024
1 parent 63bc4ca commit 868b046
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
7 changes: 4 additions & 3 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,21 @@ void GateUpCombine::generate() {
const auto zmm_up = zmm0;
const auto ymm_dst = ymm5;

// when save_state is false, push/pop will not be generated.
auto injector = std::make_shared<jit_uni_eltwise_injector_f32<dnnl::impl::cpu::x64::avx512_core>>(
this,
m_act_alg,
1.f,
1.0f,
1.f,
false, // save_state, state will be saved in our function
true, // save_state, true due to additional r15 is used.
Xbyak::Reg64(Xbyak::Operand::R10), // p_table
Xbyak::Opmask(1), // k_mask
true, // is_fwd
false, // use_dst
false, // preserve_vmm
false); // preserve_p_table
false); // preserve_p_table, false due to it will be saved in the function

push(r10);
xor_(loop_i, loop_i);
injector->load_table_addr();

Expand All @@ -317,6 +317,7 @@ void GateUpCombine::generate() {
cmp(loop_i, BN);
jl(loop_begin, T_NEAR);

pop(r10);
ret();

injector->prepare_table();
Expand Down
13 changes: 9 additions & 4 deletions src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ struct LLMMLP::Impl {
const LLMMLPNode::Config m_config;
DnnlScratchPadPtr m_scrachPad;
MemoryPtr m_scratchMem;
uint8_t* m_scratch_base = nullptr;

Linear gate_up;
Linear down;
Expand Down Expand Up @@ -200,7 +201,11 @@ struct LLMMLP::Impl {
}

void setM(int M) {
if (m_M < M) {
uint8_t* cur_scratch_base = nullptr;
if (m_scratchMem)
cur_scratch_base = m_scratchMem->getDataAs<uint8_t>();
// new M larger than previous or the scratch pointer is changed after the following allocation
if (m_M < M || cur_scratch_base != m_scratch_base) {
size_t total_scratch_size = M * m_N * sizeof(ov::bfloat16);
std::vector<size_t> scratch_offsets;
std::vector<size_t> scratch_C_sizes;
Expand All @@ -214,11 +219,11 @@ struct LLMMLP::Impl {
auto newMemDesc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::u8, Shape{total_scratch_size});
m_scratchMem = m_scrachPad->createScratchPadMem(newMemDesc);

auto* scratch_base = m_scratchMem->getDataAs<uint8_t>();
m_actUp.resize<ov::bfloat16>({static_cast<size_t>(M), static_cast<size_t>(m_N)}, reinterpret_cast<ov::bfloat16*>(scratch_base));
m_scratch_base = m_scratchMem->getDataAs<uint8_t>();
m_actUp.resize<ov::bfloat16>({static_cast<size_t>(M), static_cast<size_t>(m_N)}, reinterpret_cast<ov::bfloat16*>(m_scratch_base));

for (size_t ithr = 0; ithr < m_tempC.size(); ithr++) {
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(scratch_base + scratch_offsets[ithr]));
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(m_scratch_base + scratch_offsets[ithr]));
}
m_M = M;
}
Expand Down
11 changes: 8 additions & 3 deletions src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct QKVProjection::Impl {
QKVProjection * m_node;
DnnlScratchPadPtr m_scrachPad;
MemoryPtr m_scratchMem;
uint8_t* m_scratch_base = nullptr;
int m_M;

Impl(QKVProjection * pnode, DnnlScratchPadPtr scrachPad) : m_node(pnode), m_scrachPad(scrachPad) {
Expand Down Expand Up @@ -116,7 +117,11 @@ struct QKVProjection::Impl {
}

void setM(int M) {
if (m_M < M) {
uint8_t* cur_scratch_base = nullptr;
if (m_scratchMem)
cur_scratch_base = m_scratchMem->getDataAs<uint8_t>();
// new M larger than previous or the scratch pointer is changed after the following allocation
if (m_M < M || cur_scratch_base != m_scratch_base) {
size_t total_scratch_size = 0;
std::vector<size_t> scratch_offsets;
std::vector<size_t> scratch_C_sizes;
Expand All @@ -130,9 +135,9 @@ struct QKVProjection::Impl {
auto newMemDesc = std::make_shared<CpuBlockedMemoryDesc>(ov::element::u8, Shape{total_scratch_size});
m_scratchMem = m_scrachPad->createScratchPadMem(newMemDesc);

auto* scratch_base = m_scratchMem->getDataAs<uint8_t>();
m_scratch_base = m_scratchMem->getDataAs<uint8_t>();
for (size_t ithr = 0; ithr < m_tempC.size(); ithr++) {
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(scratch_base + scratch_offsets[ithr]));
m_tempC[ithr].resize<float>({1, scratch_C_sizes[ithr]}, reinterpret_cast<float*>(m_scratch_base + scratch_offsets[ithr]));
}

m_M = M;
Expand Down

0 comments on commit 868b046

Please sign in to comment.