From 68088e5587e3ddc05637d4af456a6ccf618150e0 Mon Sep 17 00:00:00 2001 From: Nikolai Shchegolev Date: Thu, 24 Oct 2024 23:48:52 +0400 Subject: [PATCH] [CPU][OMP] Safe usage of threads num with buffers --- src/plugins/intel_cpu/src/nodes/ctc_loss.cpp | 5 +- src/plugins/intel_cpu/src/nodes/eltwise.cpp | 17 +++-- src/plugins/intel_cpu/src/nodes/gather.cpp | 12 ++-- src/plugins/intel_cpu/src/nodes/gather.h | 1 + .../intel_cpu/src/nodes/grid_sample.cpp | 12 ++-- .../intel_cpu/src/nodes/grid_sample.hpp | 2 +- .../kernels/scaled_attn/mha_single_token.cpp | 6 +- src/plugins/intel_cpu/src/nodes/llm_mlp.cpp | 41 ++++++----- src/plugins/intel_cpu/src/nodes/mha.cpp | 24 ++++--- src/plugins/intel_cpu/src/nodes/mha.h | 2 + src/plugins/intel_cpu/src/nodes/mvn.cpp | 70 +++++++++++++++---- src/plugins/intel_cpu/src/nodes/qkv_proj.cpp | 11 +-- src/plugins/intel_cpu/src/nodes/reduce.cpp | 8 ++- src/plugins/intel_cpu/src/nodes/roi_align.cpp | 6 +- .../intel_cpu/src/nodes/scaled_attn.cpp | 27 ++++--- .../intel_cpu/src/nodes/strided_slice.cpp | 7 +- .../intel_cpu/src/nodes/strided_slice.h | 1 + 17 files changed, 171 insertions(+), 81 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp b/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp index 78bb6fc0563e60..3161c9a0e87a84 100644 --- a/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp +++ b/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp @@ -84,7 +84,8 @@ void CTCLoss::execute(dnnl::stream strm) { std::vector decodedTargetLenB(batchNum, 0); std::vector> targetDB(batchNum); std::vector>> logProbabilitiesB(batchNum); - std::vector errorMsgB(parallel_get_max_threads()); + const auto threads_num = parallel_get_max_threads(); + std::vector errorMsgB(threads_num); auto threadBody_1 = [&](const int ithr, const int nthr) { size_t start(0lu), end(0lu); @@ -153,7 +154,7 @@ void CTCLoss::execute(dnnl::stream strm) { } // for batch }; // threadBody_1 - parallel_nt(0, threadBody_1); + parallel_nt(threads_num, threadBody_1); if (returnCode != 0) { std::string resErr(""); for (auto& err : errorMsgB) { diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/eltwise.cpp index 5c3a358dff9d38..c2d23bf9adc89e 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/eltwise.cpp @@ -1503,7 +1503,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { fullWorkAmount *= jep.dims[i]; } - size_t minimalConcurrency = parallel_get_max_threads(); + m_threads_num = static_cast(parallel_get_max_threads()); size_t minimalJitWorkAmount = 256; size_t currentJitWorkAmount = jep.dims[jep.dims.size() - 1]; int collapsedDims = 0; @@ -1516,6 +1516,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { for (size_t j = 1; j < inpDims.size(); j++) { if (inpDims[j].back() != inpDims[0].back()) { hasDifferentDims = true; + break; } } @@ -1538,7 +1539,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { } size_t nextJitWorkAmount = currentJitWorkAmount * jep.dims[jep.dims.size() - 2]; - if (fullWorkAmount / nextJitWorkAmount >= minimalConcurrency) { + if (fullWorkAmount / nextJitWorkAmount >= m_threads_num) { currentJitWorkAmount = nextJitWorkAmount; collapsedDims++; @@ -1622,8 +1623,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { if (_pKernel->jep_.input_size == optimalTensorRank) { // execute Optimized 6D - parallel_for5d(dims_out[0], dims_out[1], dims_out[2], dims_out[3], dims_out[4], - [&](size_t i0, size_t i1, size_t i2, size_t i3, size_t i4) { + auto d6_loop = 
[&](size_t i0, size_t i1, size_t i2, size_t i3, size_t i4) { auto args = jit_eltwise_call_args_indexes(); args.indexes[0] = i0; args.indexes[1] = i1; @@ -1632,7 +1632,11 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { args.indexes[4] = i4; (*_pKernel)(&args_ptrs, &args); - }); + }; + + parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) { + for_5d(ithr, nthr, dims_out[0], dims_out[1], dims_out[2], dims_out[3], dims_out[4], d6_loop); + }); } else { // execute Optimized Generic if (_pKernel->jep_.use_runtime_ptrs) { @@ -1642,7 +1646,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { _schedulerWorkAmount *= dims_out[i]; } } - parallel_nt(0, [&](const int ithr, const int nthr) { + parallel_nt(m_threads_num, [&](const int ithr, const int nthr) { size_t start = 0, end = 0; splitter(_schedulerWorkAmount, nthr, ithr, start, end); @@ -1676,6 +1680,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { std::unique_ptr _pKernel; size_t _schedulerWorkAmount = 0; size_t _batchDimIdx = 0; + size_t m_threads_num = 0lu; public: static const int optimalTensorRank = 6; diff --git a/src/plugins/intel_cpu/src/nodes/gather.cpp b/src/plugins/intel_cpu/src/nodes/gather.cpp index 81f6f36b84dd89..d2629fe8fe6811 100644 --- a/src/plugins/intel_cpu/src/nodes/gather.cpp +++ b/src/plugins/intel_cpu/src/nodes/gather.cpp @@ -253,6 +253,7 @@ void Gather::createPrimitive() { if (isInPlace()) { return; } + m_threads_num = parallel_get_max_threads(); #if defined(OPENVINO_ARCH_X86_64) uint64_t idxElPerVec = 1; if (!isDynamicNode()) { @@ -294,11 +295,10 @@ void Gather::createPrimitive() { if (!isDynamicNode()) { const uint64_t dataElPerVec = jitKernel->getDataElPerVec(); - const uint64_t nthr = parallel_get_max_threads(); - const uint64_t wpt = ((totalWork / dataElPerVec) / nthr + 1) * dataElPerVec; - execParamsPerThread.resize(nthr); + const uint64_t wpt = ((totalWork / dataElPerVec) / m_threads_num + 1) * dataElPerVec; + execParamsPerThread.resize(m_threads_num); - parallel_nt(nthr, [&](const int ithr, const int nthr) { + parallel_nt(m_threads_num, [&](const int ithr, const int nthr) { const uint64_t dstStart = std::min(wpt * ithr, totalWork); const uint64_t dstEnd = std::min(wpt * (ithr + 1), totalWork); @@ -469,7 +469,7 @@ void Gather::execute(dnnl::stream strm) { (*jitKernel)(&arg); }; - parallel_nt(0, threadBody); + parallel_nt(m_threads_num, threadBody); return; } @@ -543,7 +543,7 @@ void Gather::executeDynamicImpl(dnnl::stream strm) { (*jitKernel)(&arg); }; - parallel_nt(0, threadBody); + parallel_nt(m_threads_num, threadBody); return; } diff --git a/src/plugins/intel_cpu/src/nodes/gather.h b/src/plugins/intel_cpu/src/nodes/gather.h index 96dad228f65b59..6ee097e9a1fbab 100644 --- a/src/plugins/intel_cpu/src/nodes/gather.h +++ b/src/plugins/intel_cpu/src/nodes/gather.h @@ -110,6 +110,7 @@ class Gather : public Node { bool have_scalar_scale = false; size_t zp_group_size = 1u; size_t scale_group_size = 1u; + size_t m_threads_num = 0lu; std::shared_ptr jitKernel; }; diff --git a/src/plugins/intel_cpu/src/nodes/grid_sample.cpp b/src/plugins/intel_cpu/src/nodes/grid_sample.cpp index 618d6b39105689..c8eed21bb312f5 100644 --- a/src/plugins/intel_cpu/src/nodes/grid_sample.cpp +++ b/src/plugins/intel_cpu/src/nodes/grid_sample.cpp @@ -149,11 +149,11 @@ void GridSample::createPrimitive() { } jitKernel->create_ker(); - nthr = parallel_get_max_threads(); - execParamsPerThread.resize(nthr); + m_threads_num = parallel_get_max_threads(); + 
execParamsPerThread.resize(m_threads_num); if (!x64::mayiuse(x64::avx512_core)) { const auto dataElPerVec = jitKernel->getDataElPerVec(); - parallel_nt(nthr, [&](const int ithr, const int nthr) { + parallel_nt(m_threads_num, [&](const int ithr, const int nthr) { auto& p = execParamsPerThread[ithr]; p.srcHeightF.resize(dataElPerVec); @@ -197,9 +197,9 @@ void GridSample::prepareParams() { const auto& srcDataShape = dataMemPtr->getStaticDims(); const auto& dstShape = dstMemPtr->getStaticDims(); const uint64_t totalWork = dstShape[2] * dstShape[3]; - const uint64_t wpt = ((totalWork / dataElPerVec) / nthr + 1) * dataElPerVec; + const uint64_t wpt = ((totalWork / dataElPerVec) / m_threads_num + 1) * dataElPerVec; - parallel_nt(nthr, [&](const int ithr, const int nthr) { + parallel_nt(m_threads_num, [&](const int ithr, const int nthr) { const uint64_t dstStart = std::min(wpt * ithr, totalWork); const uint64_t dstEnd = std::min(wpt * (ithr + 1), totalWork); @@ -303,7 +303,7 @@ void GridSample::execute(dnnl::stream strm) { (*jitKernel)(&arg); }; - parallel_nt(nthr, threadBody); + parallel_nt(m_threads_num, threadBody); } void GridSample::executeDynamicImpl(dnnl::stream strm) { diff --git a/src/plugins/intel_cpu/src/nodes/grid_sample.hpp b/src/plugins/intel_cpu/src/nodes/grid_sample.hpp index 0d172bd5c3e055..b4468d58be9b52 100644 --- a/src/plugins/intel_cpu/src/nodes/grid_sample.hpp +++ b/src/plugins/intel_cpu/src/nodes/grid_sample.hpp @@ -62,7 +62,7 @@ class GridSample : public Node { ov::element::Type dataPrecision; ov::element::Type gridPrecision = ov::element::f32; - int nthr = 1; + size_t m_threads_num = 0lu; std::vector execParamsPerThread; static constexpr size_t IN_DATA = 0; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp index 1543c168403382..6b6df3c3181ee0 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -1068,11 +1068,15 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, } }); - parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) { + auto bhl_loop = [&](size_t b, size_t h, size_t pq) { auto* temp = buf_attn_score.ptr(0, b, pq, h); size_t temp_stride = buf_attn_score.stride(0); auto* dst = has_out_transpose ? 
output_emb.ptr(b, pq, h * SV) : output_emb.ptr(b, h, pq); attn_reduce(dst, temp, nthr, SV, temp_stride); + }; + + parallel_nt_static(nthr, [&](const int ithr, const int nthr) { + for_3d(ithr, nthr, B, H, q_len, bhl_loop); }); } diff --git a/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp b/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp index 13c46a7c976cfd..4a98b3158fc4d4 100644 --- a/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp +++ b/src/plugins/intel_cpu/src/nodes/llm_mlp.cpp @@ -53,19 +53,19 @@ class LinearKsplit2 { OPENVINO_ASSERT((N % REG_BLK_N_SIZE) == 0); OPENVINO_ASSERT((K % reg_blk_K_size) == 0); - auto nthr = parallel_get_max_threads(); + m_threads_num = parallel_get_max_threads(); auto num_blk_N = N / REG_BLK_N_SIZE; - works.resize(nthr); + works.resize(m_threads_num); - auto K_splits = 2; + size_t K_splits = 2lu; // split task on more cores is better on TBB - auto valid_nthr = nthr / 2; + auto valid_nthr = m_threads_num / 2; auto blkN_per_thread = (num_blk_N) / valid_nthr; auto blkN_leftover = num_blk_N - (blkN_per_thread * valid_nthr); auto start_blkN = 0; used_nthr = 0; - for (int ithr = 0; ithr < nthr; ithr += K_splits) { + for (size_t ithr = 0lu; ithr < m_threads_num; ithr += K_splits) { auto blkN = std::min(num_blk_N - start_blkN, blkN_per_thread); if (blkN_leftover > 0) { blkN_leftover--; @@ -78,7 +78,7 @@ class LinearKsplit2 { auto start_blkK = 0; auto num_blk_K = K / reg_blk_K_size; auto blkK_per_thread = (num_blk_K + 1) / 2; - for (int ik = 0; ik < K_splits; ik++) { + for (size_t ik = 0lu; ik < K_splits; ik++) { auto blk_K = std::min(num_blk_K - start_blkK, blkK_per_thread); auto& work = works[ithr + ik]; @@ -106,7 +106,7 @@ class LinearKsplit2 { wbuffer.alloc(works, weight_element_size); - ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) { + ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) { auto& work = works[ithr]; if (work) { if (is_quantized) { @@ -125,7 +125,7 @@ class LinearKsplit2 { float * w_scale) { static ReduceAdd2bh jit_reduce2cvt(true, std::is_same::value); - ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) { + ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) { auto& work = works[ithr]; auto& workC = work.m_C; if (work) { @@ -165,6 +165,9 @@ class LinearKsplit2 { } }); } + +private: + int m_threads_num = 0lu; }; template @@ -205,18 +208,18 @@ class LinearGateUp { // in unit of 32 OPENVINO_ASSERT((N % REG_BLK_N_SIZE) == 0); OPENVINO_ASSERT((K % reg_blk_K_size) == 0); - auto nthr = parallel_get_max_threads(); + m_threads_num = parallel_get_max_threads(); auto num_blk_N = N / REG_BLK_N_SIZE; - works.resize(nthr); + works.resize(m_threads_num); // split task on more cores is better on TBB - auto valid_nthr = nthr; + auto valid_nthr = m_threads_num; auto blkN_per_thread = (num_blk_N) / valid_nthr; auto blkN_leftover = num_blk_N - (blkN_per_thread * valid_nthr); auto start_blkN = 0; used_nthr = 0; - for (int ithr = 0; ithr < nthr; ithr ++) { + for (int ithr = 0; ithr < m_threads_num; ithr ++) { auto blkN = std::min(num_blk_N - start_blkN, blkN_per_thread); if (blkN_leftover > 0) { blkN_leftover--; @@ -243,7 +246,7 @@ class LinearGateUp { wbuffer.alloc(works, weight_element_size); DEBUG_LOG("Linear N,K=", N, ",", K, " used_nthr=", used_nthr); - ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) { + ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) { auto& work = works[ithr]; if (work) { if (quantized_int8) @@ -267,7 +270,7 @@ 
class LinearGateUp { const LLMMLPNode::Config& config, MatrixDynQuantPerRow& src_dq, float * w_scale) { - ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) { + ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) { auto& work = works[ithr]; if (work) { work.run(M, pA, strideA_in_bytes); @@ -303,6 +306,9 @@ class LinearGateUp { } }); } + +private: + int m_threads_num = 0lu; }; template @@ -384,8 +390,8 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase { reinterpret_cast(ptr)); }); - auto nthr = parallel_get_max_threads(); - for (int ithr = 0; ithr < nthr; ithr++) { + m_threads_num = parallel_get_max_threads(); + for (size_t ithr = 0lu; ithr < m_threads_num; ithr++) { auto C1_size = gate_up.works[ithr].set_C(M, reinterpret_cast(cur_scratch_base)); auto C2_size = down.works[ithr].set_C(M, reinterpret_cast(cur_scratch_base)); auto max_C_size = std::max(C1_size, C2_size); @@ -482,6 +488,9 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase { dstC += BM * strideC / sizeof(T); } } + +private: + size_t m_threads_num = 0lu; }; #else template diff --git a/src/plugins/intel_cpu/src/nodes/mha.cpp b/src/plugins/intel_cpu/src/nodes/mha.cpp index 7d082e99fa4f6a..9364058c5d19a2 100644 --- a/src/plugins/intel_cpu/src/nodes/mha.cpp +++ b/src/plugins/intel_cpu/src/nodes/mha.cpp @@ -934,7 +934,7 @@ void MHA::prepareParams() { bool isAMXSupported = mayiuse(avx512_core_amx); - size_t numThreads = parallel_get_max_threads(); + m_threads_num = parallel_get_max_threads(); size_t matmulOptimalM = 32; @@ -1072,21 +1072,21 @@ void MHA::prepareParams() { bufferCompensation1Size = rnd_up(N1, N1_blk); if (brgCopyAKernel0) { - bufferMatMul0In0.resize(numThreads * bufferMatMul0In0Size); + bufferMatMul0In0.resize(m_threads_num * bufferMatMul0In0Size); } - bufferMatMul0In1.resize(numThreads * bufferMatMul0In1Size); - bufferMatMul0Out.resize(numThreads * bufferMatMul0OutSize); - bufferMatMul1In1.resize(numThreads * bufferMatMul1In1Size); - bufferMatMul1Out.resize(numThreads * bufferMatMul1OutSize); + bufferMatMul0In1.resize(m_threads_num * bufferMatMul0In1Size); + bufferMatMul0Out.resize(m_threads_num * bufferMatMul0OutSize); + bufferMatMul1In1.resize(m_threads_num * bufferMatMul1In1Size); + bufferMatMul1Out.resize(m_threads_num * bufferMatMul1OutSize); if (brgemmCtx0.is_with_comp) { - bufferCompensation0.resize(numThreads * bufferCompensation0Size); + bufferCompensation0.resize(m_threads_num * bufferCompensation0Size); } if (brgemmCtx1.is_with_comp) { - bufferCompensation1.resize(numThreads * bufferCompensation1Size); + bufferCompensation1.resize(m_threads_num * bufferCompensation1Size); } if (brgemmCtx0.is_with_amx || brgemmCtx1.is_with_amx) { - wsp.resize(numThreads * wsp_size_per_thread); + wsp.resize(m_threads_num * wsp_size_per_thread); } { @@ -1224,7 +1224,7 @@ void MHA::mhaImpl() { auto outPrcSize = outputPrecision.size(); - parallel_for2d(dimsMatMul0Out[0], dimsMatMul0Out[1], [&](size_t i0, size_t i1) { + auto spatial_loop = [&](size_t i0, size_t i1) { size_t threadNum = parallel_get_thread_num(); auto pTranspose0In0_aux = pTranspose0In0 + (i0 * strTranspose0In0[0] + i1 * strTranspose0In0[2]) * inputPrecisions[0].size(); // order 0213 @@ -1417,6 +1417,10 @@ void MHA::mhaImpl() { (*convertReorderKernel)(&call_args); } } + }; + + parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) { + for_2d(ithr, nthr, dimsMatMul0Out[0], dimsMatMul0Out[1], spatial_loop); }); } diff --git a/src/plugins/intel_cpu/src/nodes/mha.h 
b/src/plugins/intel_cpu/src/nodes/mha.h index cd272c086e2190..36afe20224299a 100644 --- a/src/plugins/intel_cpu/src/nodes/mha.h +++ b/src/plugins/intel_cpu/src/nodes/mha.h @@ -238,6 +238,8 @@ class MHA : public Node { std::unique_ptr mulAddSoftmaxKernel; std::unique_ptr convertReorderKernel; std::unique_ptr convertTransposeKernel; + + size_t m_threads_num = 0lu; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes/mvn.cpp b/src/plugins/intel_cpu/src/nodes/mvn.cpp index 61aa4738b8f81f..6e517e6a794cf2 100644 --- a/src/plugins/intel_cpu/src/nodes/mvn.cpp +++ b/src/plugins/intel_cpu/src/nodes/mvn.cpp @@ -2180,6 +2180,11 @@ void MVN::MVNJitExecutor::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, co size_t C2 = C1 * D; size_t C3 = C2 * C; +#if (OV_THREAD == OV_THREAD_OMP) + const auto origin_nested_levels = get_max_nested_levels(); + set_max_nested_levels(origin_nested_levels + 1); +#endif // OV_THREAD == OV_THREAD_OMP + if (mvnAttrs.execAcrossChannels_) { parallel_for(N, [&](int b) { size_t cb = b * C3; @@ -2298,6 +2303,10 @@ void MVN::MVNJitExecutor::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, co } }); } + +#if (OV_THREAD == OV_THREAD_OMP) + set_max_nested_levels(origin_nested_levels); +#endif // OV_THREAD == OV_THREAD_OMP } void MVN::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data, const VectorDims& shape5d) { @@ -2313,6 +2322,11 @@ void MVN::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data, co size_t C2 = C1 * D; size_t C3 = C2 * C; +#if (OV_THREAD == OV_THREAD_OMP) + const auto origin_nested_levels = get_max_nested_levels(); + set_max_nested_levels(origin_nested_levels + 1); +#endif // OV_THREAD == OV_THREAD_OMP + parallel_for(N, [&](int b) { size_t cb = b * C3; if (mvnAttrs.execAcrossChannels_) { @@ -2399,6 +2413,10 @@ void MVN::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data, co }); } }); + +#if (OV_THREAD == OV_THREAD_OMP) + set_max_nested_levels(origin_nested_levels); +#endif // OV_THREAD == OV_THREAD_OMP } void MVN::MVNJitExecutor::mvn_nspc(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_, const VectorDims& shape5d) { @@ -2417,9 +2435,14 @@ void MVN::MVNJitExecutor::mvn_nspc(const uint8_t* src_data, uint8_t* dst_data, c const size_t H = shape5d[3]; const size_t W = shape5d[4]; - size_t threads_num = parallel_get_max_threads(); +#if (OV_THREAD == OV_THREAD_OMP) + const auto origin_nested_levels = get_max_nested_levels(); + set_max_nested_levels(origin_nested_levels + 1); +#endif // OV_THREAD == OV_THREAD_OMP + + const size_t threads_num = parallel_get_max_threads(); size_t aux_buffer_size = mvnAttrs.execAcrossChannels_ ? 
1 : rnd_up(C, blk_size) + blk_size; - parallel_for(N, [&](size_t b) { + auto b_loop = [&](size_t b) { std::vector mean_buffer(aux_buffer_size * threads_num, 0.f); std::vector variance_buffer; if (mvnAttrs.normalizeVariance_) { @@ -2429,7 +2452,7 @@ void MVN::MVNJitExecutor::mvn_nspc(const uint8_t* src_data, uint8_t* dst_data, c // kernel_type: 0 for mean, 1 for variance, 2 for normalization auto worker = [&](const bool across_channel, const int kernel_type) { - parallel_nt(0, [&](const int ithr, const int nthr) { + parallel_nt(threads_num, [&](const int ithr, const int nthr) { size_t start = 0, end = 0; splitter(D * H * W, nthr, ithr, start, end); @@ -2512,7 +2535,15 @@ void MVN::MVNJitExecutor::mvn_nspc(const uint8_t* src_data, uint8_t* dst_data, c } worker(false, 2); } + }; + + parallel_nt_static(threads_num, [&](const int ithr, const int nthr) { + for_1d(ithr, nthr, N, b_loop); }); + +#if (OV_THREAD == OV_THREAD_OMP) + set_max_nested_levels(origin_nested_levels); +#endif // OV_THREAD == OV_THREAD_OMP } void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_, const VectorDims& shape5d) { @@ -2529,15 +2560,15 @@ void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, co const size_t H = shape5d[3]; const size_t W = shape5d[4]; - size_t CB = div_up(C, blk_size); + const size_t CB = div_up(C, blk_size); - size_t C0 = W * blk_size; - size_t C1 = C0 * H; - size_t C2 = C1 * D; - size_t C3 = C2 * CB; - size_t C5 = C * D * H * W; + const size_t C0 = W * blk_size; + const size_t C1 = C0 * H; + const size_t C2 = C1 * D; + const size_t C3 = C2 * CB; + const size_t C5 = C * D * H * W; - size_t threads_num = parallel_get_max_threads(); + const size_t threads_num = parallel_get_max_threads(); size_t aux_buffer_size = mvnAttrs.execAcrossChannels_ ? 
blk_size : rnd_up(C, blk_size); aux_buffer_size += blk_size; std::vector mean_buffer(aux_buffer_size * threads_num); @@ -2562,7 +2593,11 @@ void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, co // // | // // \|/ ///////////////////////////////// - auto mean_buffer_ptr = &mean_buffer[aux_buffer_size * static_cast(parallel_get_thread_num())]; + auto thread_idx = static_cast(parallel_get_thread_num()); + if (thread_idx >= threads_num) { + return mean_internal; + } + auto mean_buffer_ptr = &mean_buffer[aux_buffer_size * thread_idx]; for (size_t i = 0; i < blk_size; i++) mean_buffer_ptr[i] = 0.f; @@ -2651,7 +2686,7 @@ void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, co // one thread for one C*W size(the same H) to get C size result for the same H, added to last group result // keep the compute order the same as planar - parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) { + auto dh_loop = [&](size_t thr_idx, size_t d, size_t h) { for (size_t cb = 0; cb < CB; cb++) { size_t src_offset = b_offset + cb * C2 + d * C1 + h * C0; auto mean_buffer_ptr = &mean_buffer[blk_size * cb + aux_buffer_size * thr_idx]; @@ -2665,6 +2700,10 @@ void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, co arg.post_op_data = post_ops_data_; (*mvn_mean_kernel)(&arg); } + }; + + parallel_nt_static(threads_num, [&](const int ithr, const int nthr) { + for_2d(ithr, nthr, D, H, dh_loop); }); for (size_t i = 1; i < threads_num; i++) { @@ -2678,7 +2717,7 @@ void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, co for (size_t i = 0; i < variance_buffer.size(); i++) variance_buffer[i] = 0.f; - parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) { + auto dh_loop = [&](size_t thr_idx, size_t d, size_t h) { for (size_t cb = 0; cb < CB; cb++) { size_t src_offset = b_offset + cb * C2 + d * C1 + h * C0; auto mean_buffer_ptr = &mean_buffer[blk_size * cb]; @@ -2694,7 +2733,12 @@ void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, co arg.post_op_data = post_ops_data_; (*mvn_variance_kernel)(&arg); } + }; + + parallel_nt_static(threads_num, [&](const int ithr, const int nthr) { + for_2d(ithr, nthr, D, H, dh_loop); }); + for (size_t i = 1; i < threads_num; i++) { for (size_t c = 0; c < C; c++) variance_buffer[c] += variance_buffer[c + aux_buffer_size * i]; diff --git a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp index 3260b12f1b5b4b..00c8b6f9b17c0b 100644 --- a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp +++ b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp @@ -60,6 +60,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase { MemoryPtr m_scratchMem; uint8_t* m_scratch_base = nullptr; int m_M = 0; + size_t m_threads_num = 0lu; MatrixDynQuantPerRow m_quant_act; @@ -79,11 +80,11 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase { auto K = w0.size(1); OPENVINO_ASSERT((K % cache_blk_k_size) == 0); - auto nthr = parallel_get_max_threads(); + m_threads_num = parallel_get_max_threads(); auto num_blk_K = K / cache_blk_k_size; int stride_in_bytes = K * weight_element_size; - works.resize(nthr); + works.resize(m_threads_num); int cur_work_id = 0; auto create_works = [&](void* pw, int output_id, int N, int valid_nthr) { @@ -119,7 +120,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase { auto proj_size0 = m_node->m_config.proj_size0; auto proj_size1 = m_node->m_config.proj_size1; auto proj_size2 = 
m_node->m_config.proj_size2; - auto n_group_workers = allocate_workers({proj_size0, proj_size1, proj_size2}, nthr); + auto n_group_workers = allocate_workers({proj_size0, proj_size1, proj_size2}, m_threads_num); if (m_node->m_config.weights_combined) { auto* ptr_weights = reinterpret_cast(w0.ptr_v()); @@ -140,7 +141,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase { wbuffer.alloc(works, weight_element_size); - ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) { + ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) { auto& work = works[ithr]; if (work) { if (quantized_int8) @@ -237,7 +238,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase { strideA = m_quant_act.K; } - ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) { + ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) { auto& work = works[ithr]; if (work) { work.run(BM, pA, strideA); diff --git a/src/plugins/intel_cpu/src/nodes/reduce.cpp b/src/plugins/intel_cpu/src/nodes/reduce.cpp index b40c50f957514f..6cfc94a02b9f3b 100644 --- a/src/plugins/intel_cpu/src/nodes/reduce.cpp +++ b/src/plugins/intel_cpu/src/nodes/reduce.cpp @@ -2742,12 +2742,12 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) { (*reduce_post_kernel)(&arg); }); } else if (layout == ReduceLayoutType::reduce_nspc) { - size_t num_threads = static_cast(parallel_get_max_threads()); + const size_t num_threads = static_cast(parallel_get_max_threads()); size_t OP = OB * OC >= num_threads ? OB * OC : OB * OC * OD; if (OP < num_threads && OW > blk_size) OP *= OH; size_t work_amount = OB * OC * OD * OH * OW / OP; - parallel_for(OP, [&](size_t op) { + auto op_loop = [&](size_t op) { const uint8_t *in_p = in_ptr + op * work_amount * intermediate_data_size; uint8_t *out_p = out_ptr + op * work_amount * dst_data_size; auto arg = jit_reduce_post_call_args(); @@ -2759,6 +2759,10 @@ inline void Reduce::reduce_kernel_post_process(uint8_t *out_ptr) { arg.divisor = &divisor; arg.post_op_data = static_cast(postOpsDataPtrs.data()); (*reduce_post_kernel)(&arg); + }; + + parallel_nt_static(num_threads, [&](const int ithr, const int nthr) { + for_1d(ithr, nthr, OP, op_loop); }); } else { size_t OCB = div_up(OC, blk_size); diff --git a/src/plugins/intel_cpu/src/nodes/roi_align.cpp b/src/plugins/intel_cpu/src/nodes/roi_align.cpp index eb1797279e1415..27f9426dca6af9 100644 --- a/src/plugins/intel_cpu/src/nodes/roi_align.cpp +++ b/src/plugins/intel_cpu/src/nodes/roi_align.cpp @@ -1076,7 +1076,7 @@ void ROIAlign::executeSpecified() { int bufSize = rnd_up(C, 16); size_t threadsNum = parallel_get_max_threads(); workingBuf.resize(bufSize * threadsNum, 0.f); - parallel_for3d(realRois, pooledH, pooledW, [&](int n, int yBinInd, int xBinInd) { + auto rhw_loop = [&](int n, int yBinInd, int xBinInd) { int numSamplesROI = numSamples[n]; // each sample have 4 values for srcAddressList and weight size_t binOffset = numSamplesROI * BLIParamsNum * pooledW * yBinInd + numSamplesROI * BLIParamsNum * xBinInd; @@ -1095,6 +1095,10 @@ void ROIAlign::executeSpecified() { arg.dst = static_cast(&dst[dstOffset]); arg.src_stride = lastBlockDim * W * H; // only valid for blk, nspc generate inside (*roi_align_kernel)(&arg); + }; + + parallel_nt_static(threadsNum, [&](const int ithr, const int nthr) { + for_3d(ithr, nthr, realRois, pooledH, pooledW, rhw_loop); }); } else { // one lane for one sample generation, then pooling all samples. 
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index e229ff4bb72c57..f9f853230c4dd6 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -217,6 +217,7 @@ struct MHAKernel { size_t wsp_size_per_thread = 0; using tag = dnnl::memory::format_tag; using dt = dnnl::memory::data_type; + size_t m_threads_num = 0lu; struct brgemmKey { size_t M; size_t N; @@ -315,21 +316,21 @@ struct MHAKernel { wv_gemm_ptr = wv_result.first; - size_t nthr = static_cast(parallel_get_max_threads()); + m_threads_num = static_cast(parallel_get_max_threads()); // wsp is used to compute beta when K is blocked wsp_size_per_thread = wv_gemm_ptr->get_wsp_size(); - wsp.resize(nthr * wsp_size_per_thread); + wsp.resize(m_threads_num * wsp_size_per_thread); // allocate scratch a/b, notice get_scratch_a_size/get_scratch_b_size returns in bytes size_t data_size = sizeof(T); - qk_scratch_a.resize({nthr, qk_gemm_ptr->get_scratch_a_size() / data_size}); - wv_scratch_a.resize({nthr, wv_gemm_ptr->get_scratch_a_size() / data_size}); + qk_scratch_a.resize({m_threads_num, qk_gemm_ptr->get_scratch_a_size() / data_size}); + wv_scratch_a.resize({m_threads_num, wv_gemm_ptr->get_scratch_a_size() / data_size}); qk_scratch_b.resize({B, Hk, qk_gemm_ptr->get_scratch_b_size() / data_size}); wv_scratch_b.resize({B, Hk, wv_gemm_ptr->get_scratch_b_size() / data_size}); const size_t m_block_size = qk_gemm_ptr->get_mblk_size(); - weight_score.resize({static_cast(parallel_get_max_threads()), H, m_block_size, kv_len}); + weight_score.resize({m_threads_num, H, m_block_size, kv_len}); if (has_out_transpose) { fp32_out.resize({B, q_len, H, head_size_v}); } else { @@ -367,7 +368,7 @@ struct MHAKernel { }); // attention - parallel_for3d(B, H, m_blocks, [&](size_t ithr, size_t b, size_t h, size_t m_blk) { + auto bhb_loop = [&](size_t ithr, size_t b, size_t h, size_t m_blk) { auto m_start = m_blk * m_block_size; auto m_end = std::min(m_start + m_block_size, q_len); auto m_cnt = m_end - m_start; @@ -456,6 +457,10 @@ struct MHAKernel { 1); } } + }; + + parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) { + for_3d(ithr, nthr, B, H, m_blocks, bhb_loop); }); } @@ -652,12 +657,14 @@ struct MHAKernel { size_t m_block_size; // buffer to hold qk temp std::vector qk_buffers; + size_t m_threads_num = 0lu; MHAKernel() = delete; explicit MHAKernel(GraphContext::CPtr ctx): context(ctx) { m_block_size = 4; select_nfltmax_at_0 = false; - qk_buffers.resize(parallel_get_max_threads()); + m_threads_num = parallel_get_max_threads(); + qk_buffers.resize(m_threads_num); } PlainTensor causal_mask; @@ -699,7 +706,7 @@ struct MHAKernel { auto m_blocks = (q_len + m_block_size - 1) / m_block_size; - parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) { + auto bhb_loop = [&](size_t b, size_t h, size_t m_blk) { auto thread_id = parallel_get_thread_num(); if (thread_id < 0) OPENVINO_THROW("The calling thread isn't initialized!"); @@ -801,6 +808,10 @@ struct MHAKernel { has_out_transpose ? &output_emb.at({b, m_start, h * head_size_v}) : &output_emb.at({b, h, m_start}), has_out_transpose ? 
output_emb.stride(1) : output_emb.stride(2), 1); + }; + + parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) { + for_3d(ithr, nthr, B, H, m_blocks, bhb_loop); }); } }; diff --git a/src/plugins/intel_cpu/src/nodes/strided_slice.cpp b/src/plugins/intel_cpu/src/nodes/strided_slice.cpp index 4f974cfe5e9748..13671c22d102ae 100644 --- a/src/plugins/intel_cpu/src/nodes/strided_slice.cpp +++ b/src/plugins/intel_cpu/src/nodes/strided_slice.cpp @@ -348,6 +348,7 @@ StridedSlice::StridedSliceCommonExecutor::StridedSliceCommonExecutor(const Strid dimsNormalization(); dimsGluing(); indicesCalculation(); + m_threads_num = parallel_get_max_threads(); } void StridedSlice::StridedSliceCommonExecutor::orderParametersByLayouts(const BlockedMemoryDescCPtr& blockedMemoryDesc) { @@ -642,8 +643,7 @@ void StridedSlice::StridedSliceCommonExecutor::dimsGluing() { for (size_t idx = secondDim.first + 1; idx < secondDim.second; idx++) params.attrs.begin[1] /= dstBlockedDimsBefore[idx]; - const size_t maxThreads = parallel_get_max_threads(); - if (params.dstBlockedDims[0] < maxThreads) { + if (params.dstBlockedDims[0] < m_threads_num) { params.dstBlockedDims[1] /= realDstDim; params.srcBlockedDims[1] /= realSrcDim; params.dstBlockedDims.insert(params.dstBlockedDims.begin() + 1, realDstDim); @@ -682,8 +682,7 @@ void StridedSlice::StridedSliceCommonExecutor::indicesCalculation() { dstIndices.resize(workAmount, 0); // should choose more optimal thread count - const size_t nthr = parallel_get_max_threads(); - nThreads = nthr > workAmount ? workAmount : nthr; + nThreads = m_threads_num > workAmount ? workAmount : m_threads_num; if (params.isOptimized) { indicesCalculationForOptimized(); diff --git a/src/plugins/intel_cpu/src/nodes/strided_slice.h b/src/plugins/intel_cpu/src/nodes/strided_slice.h index 5c5950520bda7d..bf698643271d7a 100644 --- a/src/plugins/intel_cpu/src/nodes/strided_slice.h +++ b/src/plugins/intel_cpu/src/nodes/strided_slice.h @@ -122,6 +122,7 @@ class StridedSlice : public Node { size_t workAmount = 0lu; size_t lastDstDim = 0lu; size_t srcShift = 0lu; + size_t m_threads_num = 0lu; }; using executorPtr = std::shared_ptr; executorPtr execPtr = nullptr;
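
Below is a minimal, self-contained sketch of the pattern this patch applies throughout: query the maximum thread count once, size per-thread buffers from that cached value, and launch the parallel region with exactly that many threads so a thread index can never run past the pre-allocated slots. It deliberately uses raw OpenMP (omp_get_max_threads / omp_get_thread_num) rather than OpenVINO's parallel_get_max_threads / parallel_nt_static wrappers used in the patch, and the Scratch type and buffer size are hypothetical placeholders, not code from this PR.

    #include <omp.h>
    #include <cstdio>
    #include <vector>

    // Hypothetical per-thread scratch slot (stand-in for the patch's
    // execParamsPerThread / bufferMatMul* style buffers).
    struct Scratch {
        std::vector<float> buf;
    };

    int main() {
        // Capture the thread count once, before any buffers are sized, and
        // reuse the same value for both allocation and the parallel region.
        const int threads_num = omp_get_max_threads();
        std::vector<Scratch> per_thread(threads_num);
        for (auto& s : per_thread)
            s.buf.resize(1024, 0.f);

        // Request exactly threads_num threads so omp_get_thread_num() can
        // never exceed the number of pre-allocated slots, even if the
        // runtime default changes (e.g. with nested parallelism) between
        // allocation and execution.
        #pragma omp parallel num_threads(threads_num)
        {
            const int ithr = omp_get_thread_num();
            Scratch& s = per_thread[ithr];
            for (size_t i = 0; i < s.buf.size(); ++i)
                s.buf[i] += 1.f;
        }

        std::printf("threads used: %d\n", threads_num);
        return 0;
    }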