diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp index 3b1d24fcbf8a0b..2564c25ade1731 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp @@ -281,21 +281,21 @@ void GateUpCombine::generate() { const auto zmm_up = zmm0; const auto ymm_dst = ymm5; - // when save_state is false, push/pop will not be generated. auto injector = std::make_shared>( this, m_act_alg, 1.f, 1.0f, 1.f, - false, // save_state, state will be saved in our function + true, // save_state, true due to additional r15 is used. Xbyak::Reg64(Xbyak::Operand::R10), // p_table Xbyak::Opmask(1), // k_mask true, // is_fwd false, // use_dst false, // preserve_vmm - false); // preserve_p_table + false); // preserve_p_table, false due to it will be saved in the function + push(r10); xor_(loop_i, loop_i); injector->load_table_addr(); @@ -317,6 +317,7 @@ void GateUpCombine::generate() { cmp(loop_i, BN); jl(loop_begin, T_NEAR); + pop(r10); ret(); injector->prepare_table(); diff --git a/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp b/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp index dc63bdf740e3ae..f6ba63357e60b4 100644 --- a/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp +++ b/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp @@ -166,6 +166,7 @@ struct LLMMLP::Impl { const LLMMLPNode::Config m_config; DnnlScratchPadPtr m_scrachPad; MemoryPtr m_scratchMem; + uint8_t* m_scratch_base = nullptr; Linear gate_up; Linear down; @@ -200,7 +201,11 @@ struct LLMMLP::Impl { } void setM(int M) { - if (m_M < M) { + uint8_t* cur_scratch_base = nullptr; + if (m_scratchMem) + cur_scratch_base = m_scratchMem->getDataAs(); + // new M larger than previous or the scratch pointer is changed after the following allocation + if (m_M < M || cur_scratch_base != m_scratch_base) { size_t total_scratch_size = M * m_N * sizeof(ov::bfloat16); std::vector scratch_offsets; std::vector scratch_C_sizes; @@ -214,11 +219,11 @@ struct LLMMLP::Impl { auto newMemDesc = std::make_shared(ov::element::u8, Shape{total_scratch_size}); m_scratchMem = m_scrachPad->createScratchPadMem(newMemDesc); - auto* scratch_base = m_scratchMem->getDataAs(); - m_actUp.resize({static_cast(M), static_cast(m_N)}, reinterpret_cast(scratch_base)); + m_scratch_base = m_scratchMem->getDataAs(); + m_actUp.resize({static_cast(M), static_cast(m_N)}, reinterpret_cast(m_scratch_base)); for (size_t ithr = 0; ithr < m_tempC.size(); ithr++) { - m_tempC[ithr].resize({1, scratch_C_sizes[ithr]}, reinterpret_cast(scratch_base + scratch_offsets[ithr])); + m_tempC[ithr].resize({1, scratch_C_sizes[ithr]}, reinterpret_cast(m_scratch_base + scratch_offsets[ithr])); } m_M = M; } diff --git a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp index 74cfecf7aee83c..bc3c20d860ca94 100644 --- a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp +++ b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp @@ -48,6 +48,7 @@ struct QKVProjection::Impl { QKVProjection * m_node; DnnlScratchPadPtr m_scrachPad; MemoryPtr m_scratchMem; + uint8_t* m_scratch_base = nullptr; int m_M; Impl(QKVProjection * pnode, DnnlScratchPadPtr scrachPad) : m_node(pnode), m_scrachPad(scrachPad) { @@ -116,7 +117,11 @@ struct QKVProjection::Impl { } void setM(int M) { - if (m_M < M) { + uint8_t* cur_scratch_base = nullptr; + if (m_scratchMem) + cur_scratch_base = m_scratchMem->getDataAs(); + // new M larger than previous or the scratch pointer is changed after the following allocation + if (m_M < M || cur_scratch_base != m_scratch_base) { size_t total_scratch_size = 0; std::vector scratch_offsets; std::vector scratch_C_sizes; @@ -130,9 +135,9 @@ struct QKVProjection::Impl { auto newMemDesc = std::make_shared(ov::element::u8, Shape{total_scratch_size}); m_scratchMem = m_scrachPad->createScratchPadMem(newMemDesc); - auto* scratch_base = m_scratchMem->getDataAs(); + m_scratch_base = m_scratchMem->getDataAs(); for (size_t ithr = 0; ithr < m_tempC.size(); ithr++) { - m_tempC[ithr].resize({1, scratch_C_sizes[ithr]}, reinterpret_cast(scratch_base + scratch_offsets[ithr])); + m_tempC[ithr].resize({1, scratch_C_sizes[ithr]}, reinterpret_cast(m_scratch_base + scratch_offsets[ithr])); } m_M = M;