Skip to content

Commit

Permalink
[CPU][OMP] Safe usage of threads num with buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
nshchego committed Oct 28, 2024
1 parent 6f001e9 commit 68088e5
Show file tree
Hide file tree
Showing 17 changed files with 171 additions and 81 deletions.
5 changes: 3 additions & 2 deletions src/plugins/intel_cpu/src/nodes/ctc_loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ void CTCLoss::execute(dnnl::stream strm) {
std::vector<int> decodedTargetLenB(batchNum, 0);
std::vector<std::vector<int>> targetDB(batchNum);
std::vector<std::vector<std::vector<float>>> logProbabilitiesB(batchNum);
std::vector<std::string> errorMsgB(parallel_get_max_threads());
const auto threads_num = parallel_get_max_threads();
std::vector<std::string> errorMsgB(threads_num);

auto threadBody_1 = [&](const int ithr, const int nthr) {
size_t start(0lu), end(0lu);
Expand Down Expand Up @@ -153,7 +154,7 @@ void CTCLoss::execute(dnnl::stream strm) {
} // for batch
}; // threadBody_1

parallel_nt(0, threadBody_1);
parallel_nt(threads_num, threadBody_1);
if (returnCode != 0) {
std::string resErr("");
for (auto& err : errorMsgB) {
Expand Down
17 changes: 11 additions & 6 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1503,7 +1503,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
fullWorkAmount *= jep.dims[i];
}

size_t minimalConcurrency = parallel_get_max_threads();
m_threads_num = static_cast<size_t>(parallel_get_max_threads());
size_t minimalJitWorkAmount = 256;
size_t currentJitWorkAmount = jep.dims[jep.dims.size() - 1];
int collapsedDims = 0;
Expand All @@ -1516,6 +1516,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
for (size_t j = 1; j < inpDims.size(); j++) {
if (inpDims[j].back() != inpDims[0].back()) {
hasDifferentDims = true;
break;
}
}

Expand All @@ -1538,7 +1539,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
}

size_t nextJitWorkAmount = currentJitWorkAmount * jep.dims[jep.dims.size() - 2];
if (fullWorkAmount / nextJitWorkAmount >= minimalConcurrency) {
if (fullWorkAmount / nextJitWorkAmount >= m_threads_num) {
currentJitWorkAmount = nextJitWorkAmount;
collapsedDims++;

Expand Down Expand Up @@ -1622,8 +1623,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {

if (_pKernel->jep_.input_size == optimalTensorRank) {
// execute Optimized 6D
parallel_for5d(dims_out[0], dims_out[1], dims_out[2], dims_out[3], dims_out[4],
[&](size_t i0, size_t i1, size_t i2, size_t i3, size_t i4) {
auto d6_loop = [&](size_t i0, size_t i1, size_t i2, size_t i3, size_t i4) {
auto args = jit_eltwise_call_args_indexes();
args.indexes[0] = i0;
args.indexes[1] = i1;
Expand All @@ -1632,7 +1632,11 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
args.indexes[4] = i4;

(*_pKernel)(&args_ptrs, &args);
});
};

parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) {
for_5d(ithr, nthr, dims_out[0], dims_out[1], dims_out[2], dims_out[3], dims_out[4], d6_loop);
});
} else {
// execute Optimized Generic
if (_pKernel->jep_.use_runtime_ptrs) {
Expand All @@ -1642,7 +1646,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
_schedulerWorkAmount *= dims_out[i];
}
}
parallel_nt(0, [&](const int ithr, const int nthr) {
parallel_nt(m_threads_num, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
splitter(_schedulerWorkAmount, nthr, ithr, start, end);

Expand Down Expand Up @@ -1676,6 +1680,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
std::unique_ptr<jit_uni_eltwise_kernel> _pKernel;
size_t _schedulerWorkAmount = 0;
size_t _batchDimIdx = 0;
size_t m_threads_num = 0lu;

public:
static const int optimalTensorRank = 6;
Expand Down
12 changes: 6 additions & 6 deletions src/plugins/intel_cpu/src/nodes/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ void Gather::createPrimitive() {
if (isInPlace()) {
return;
}
m_threads_num = parallel_get_max_threads();
#if defined(OPENVINO_ARCH_X86_64)
uint64_t idxElPerVec = 1;
if (!isDynamicNode()) {
Expand Down Expand Up @@ -294,11 +295,10 @@ void Gather::createPrimitive() {

if (!isDynamicNode()) {
const uint64_t dataElPerVec = jitKernel->getDataElPerVec();
const uint64_t nthr = parallel_get_max_threads();
const uint64_t wpt = ((totalWork / dataElPerVec) / nthr + 1) * dataElPerVec;
execParamsPerThread.resize(nthr);
const uint64_t wpt = ((totalWork / dataElPerVec) / m_threads_num + 1) * dataElPerVec;
execParamsPerThread.resize(m_threads_num);

parallel_nt(nthr, [&](const int ithr, const int nthr) {
parallel_nt(m_threads_num, [&](const int ithr, const int nthr) {
const uint64_t dstStart = std::min(wpt * ithr, totalWork);
const uint64_t dstEnd = std::min(wpt * (ithr + 1), totalWork);

Expand Down Expand Up @@ -469,7 +469,7 @@ void Gather::execute(dnnl::stream strm) {
(*jitKernel)(&arg);
};

parallel_nt(0, threadBody);
parallel_nt(m_threads_num, threadBody);

return;
}
Expand Down Expand Up @@ -543,7 +543,7 @@ void Gather::executeDynamicImpl(dnnl::stream strm) {
(*jitKernel)(&arg);
};

parallel_nt(0, threadBody);
parallel_nt(m_threads_num, threadBody);

return;
}
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class Gather : public Node {
bool have_scalar_scale = false;
size_t zp_group_size = 1u;
size_t scale_group_size = 1u;
size_t m_threads_num = 0lu;

std::shared_ptr<jitGatherKernelBase> jitKernel;
};
Expand Down
12 changes: 6 additions & 6 deletions src/plugins/intel_cpu/src/nodes/grid_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ void GridSample::createPrimitive() {
}
jitKernel->create_ker();

nthr = parallel_get_max_threads();
execParamsPerThread.resize(nthr);
m_threads_num = parallel_get_max_threads();
execParamsPerThread.resize(m_threads_num);
if (!x64::mayiuse(x64::avx512_core)) {
const auto dataElPerVec = jitKernel->getDataElPerVec();
parallel_nt(nthr, [&](const int ithr, const int nthr) {
parallel_nt(m_threads_num, [&](const int ithr, const int nthr) {
auto& p = execParamsPerThread[ithr];

p.srcHeightF.resize(dataElPerVec);
Expand Down Expand Up @@ -197,9 +197,9 @@ void GridSample::prepareParams() {
const auto& srcDataShape = dataMemPtr->getStaticDims();
const auto& dstShape = dstMemPtr->getStaticDims();
const uint64_t totalWork = dstShape[2] * dstShape[3];
const uint64_t wpt = ((totalWork / dataElPerVec) / nthr + 1) * dataElPerVec;
const uint64_t wpt = ((totalWork / dataElPerVec) / m_threads_num + 1) * dataElPerVec;

parallel_nt(nthr, [&](const int ithr, const int nthr) {
parallel_nt(m_threads_num, [&](const int ithr, const int nthr) {
const uint64_t dstStart = std::min(wpt * ithr, totalWork);
const uint64_t dstEnd = std::min(wpt * (ithr + 1), totalWork);

Expand Down Expand Up @@ -303,7 +303,7 @@ void GridSample::execute(dnnl::stream strm) {
(*jitKernel)(&arg);
};

parallel_nt(nthr, threadBody);
parallel_nt(m_threads_num, threadBody);
}

void GridSample::executeDynamicImpl(dnnl::stream strm) {
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/grid_sample.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class GridSample : public Node {
ov::element::Type dataPrecision;
ov::element::Type gridPrecision = ov::element::f32;

int nthr = 1;
size_t m_threads_num = 0lu;
std::vector<threadExecParams> execParamsPerThread;

static constexpr size_t IN_DATA = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1068,11 +1068,15 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
}
});

parallel_for3d(B, H, q_len, [&](size_t b, size_t h, size_t pq) {
auto bhl_loop = [&](size_t b, size_t h, size_t pq) {
auto* temp = buf_attn_score.ptr<T3>(0, b, pq, h);
size_t temp_stride = buf_attn_score.stride(0);
auto* dst = has_out_transpose ? output_emb.ptr<T>(b, pq, h * SV) : output_emb.ptr<T>(b, h, pq);
attn_reduce(dst, temp, nthr, SV, temp_stride);
};

parallel_nt_static(nthr, [&](const int ithr, const int nthr) {
for_3d(ithr, nthr, B, H, q_len, bhl_loop);
});
}

Expand Down
41 changes: 25 additions & 16 deletions src/plugins/intel_cpu/src/nodes/llm_mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@ class LinearKsplit2 {

OPENVINO_ASSERT((N % REG_BLK_N_SIZE) == 0);
OPENVINO_ASSERT((K % reg_blk_K_size) == 0);
auto nthr = parallel_get_max_threads();
m_threads_num = parallel_get_max_threads();
auto num_blk_N = N / REG_BLK_N_SIZE;
works.resize(nthr);
works.resize(m_threads_num);

auto K_splits = 2;
size_t K_splits = 2lu;
// split task on more cores is better on TBB
auto valid_nthr = nthr / 2;
auto valid_nthr = m_threads_num / 2;
auto blkN_per_thread = (num_blk_N) / valid_nthr;
auto blkN_leftover = num_blk_N - (blkN_per_thread * valid_nthr);
auto start_blkN = 0;
used_nthr = 0;

for (int ithr = 0; ithr < nthr; ithr += K_splits) {
for (size_t ithr = 0lu; ithr < m_threads_num; ithr += K_splits) {
auto blkN = std::min(num_blk_N - start_blkN, blkN_per_thread);
if (blkN_leftover > 0) {
blkN_leftover--;
Expand All @@ -78,7 +78,7 @@ class LinearKsplit2 {
auto start_blkK = 0;
auto num_blk_K = K / reg_blk_K_size;
auto blkK_per_thread = (num_blk_K + 1) / 2;
for (int ik = 0; ik < K_splits; ik++) {
for (size_t ik = 0lu; ik < K_splits; ik++) {
auto blk_K = std::min(num_blk_K - start_blkK, blkK_per_thread);

auto& work = works[ithr + ik];
Expand Down Expand Up @@ -106,7 +106,7 @@ class LinearKsplit2 {

wbuffer.alloc(works, weight_element_size);

ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) {
ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
auto& work = works[ithr];
if (work) {
if (is_quantized) {
Expand All @@ -125,7 +125,7 @@ class LinearKsplit2 {
float * w_scale) {
static ReduceAdd2bh jit_reduce2cvt(true, std::is_same<T, ov::float16>::value);

ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) {
ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
auto& work = works[ithr];
auto& workC = work.m_C;
if (work) {
Expand Down Expand Up @@ -165,6 +165,9 @@ class LinearKsplit2 {
}
});
}

private:
int m_threads_num = 0lu;
};

template<typename T>
Expand Down Expand Up @@ -205,18 +208,18 @@ class LinearGateUp {
// in unit of 32
OPENVINO_ASSERT((N % REG_BLK_N_SIZE) == 0);
OPENVINO_ASSERT((K % reg_blk_K_size) == 0);
auto nthr = parallel_get_max_threads();
m_threads_num = parallel_get_max_threads();
auto num_blk_N = N / REG_BLK_N_SIZE;
works.resize(nthr);
works.resize(m_threads_num);

// split task on more cores is better on TBB
auto valid_nthr = nthr;
auto valid_nthr = m_threads_num;
auto blkN_per_thread = (num_blk_N) / valid_nthr;
auto blkN_leftover = num_blk_N - (blkN_per_thread * valid_nthr);
auto start_blkN = 0;
used_nthr = 0;

for (int ithr = 0; ithr < nthr; ithr ++) {
for (int ithr = 0; ithr < m_threads_num; ithr ++) {
auto blkN = std::min(num_blk_N - start_blkN, blkN_per_thread);
if (blkN_leftover > 0) {
blkN_leftover--;
Expand All @@ -243,7 +246,7 @@ class LinearGateUp {
wbuffer.alloc(works, weight_element_size);

DEBUG_LOG("Linear N,K=", N, ",", K, " used_nthr=", used_nthr);
ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) {
ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
auto& work = works[ithr];
if (work) {
if (quantized_int8)
Expand All @@ -267,7 +270,7 @@ class LinearGateUp {
const LLMMLPNode::Config& config,
MatrixDynQuantPerRow& src_dq,
float * w_scale) {
ov::parallel_nt_static(0, [&](const size_t ithr, const size_t nthr) {
ov::parallel_nt_static(m_threads_num, [&](const size_t ithr, const size_t nthr) {
auto& work = works[ithr];
if (work) {
work.run(M, pA, strideA_in_bytes);
Expand Down Expand Up @@ -303,6 +306,9 @@ class LinearGateUp {
}
});
}

private:
int m_threads_num = 0lu;
};

template<typename T>
Expand Down Expand Up @@ -384,8 +390,8 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
reinterpret_cast<T*>(ptr));
});

auto nthr = parallel_get_max_threads();
for (int ithr = 0; ithr < nthr; ithr++) {
m_threads_num = parallel_get_max_threads();
for (size_t ithr = 0lu; ithr < m_threads_num; ithr++) {
auto C1_size = gate_up.works[ithr].set_C(M, reinterpret_cast<float*>(cur_scratch_base));
auto C2_size = down.works[ithr].set_C(M, reinterpret_cast<float*>(cur_scratch_base));
auto max_C_size = std::max(C1_size, C2_size);
Expand Down Expand Up @@ -482,6 +488,9 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
dstC += BM * strideC / sizeof(T);
}
}

private:
size_t m_threads_num = 0lu;
};
#else
template<typename T>
Expand Down
24 changes: 14 additions & 10 deletions src/plugins/intel_cpu/src/nodes/mha.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ void MHA::prepareParams() {

bool isAMXSupported = mayiuse(avx512_core_amx);

size_t numThreads = parallel_get_max_threads();
m_threads_num = parallel_get_max_threads();

size_t matmulOptimalM = 32;

Expand Down Expand Up @@ -1072,21 +1072,21 @@ void MHA::prepareParams() {
bufferCompensation1Size = rnd_up(N1, N1_blk);

if (brgCopyAKernel0) {
bufferMatMul0In0.resize(numThreads * bufferMatMul0In0Size);
bufferMatMul0In0.resize(m_threads_num * bufferMatMul0In0Size);
}
bufferMatMul0In1.resize(numThreads * bufferMatMul0In1Size);
bufferMatMul0Out.resize(numThreads * bufferMatMul0OutSize);
bufferMatMul1In1.resize(numThreads * bufferMatMul1In1Size);
bufferMatMul1Out.resize(numThreads * bufferMatMul1OutSize);
bufferMatMul0In1.resize(m_threads_num * bufferMatMul0In1Size);
bufferMatMul0Out.resize(m_threads_num * bufferMatMul0OutSize);
bufferMatMul1In1.resize(m_threads_num * bufferMatMul1In1Size);
bufferMatMul1Out.resize(m_threads_num * bufferMatMul1OutSize);
if (brgemmCtx0.is_with_comp) {
bufferCompensation0.resize(numThreads * bufferCompensation0Size);
bufferCompensation0.resize(m_threads_num * bufferCompensation0Size);
}
if (brgemmCtx1.is_with_comp) {
bufferCompensation1.resize(numThreads * bufferCompensation1Size);
bufferCompensation1.resize(m_threads_num * bufferCompensation1Size);
}

if (brgemmCtx0.is_with_amx || brgemmCtx1.is_with_amx) {
wsp.resize(numThreads * wsp_size_per_thread);
wsp.resize(m_threads_num * wsp_size_per_thread);
}

{
Expand Down Expand Up @@ -1224,7 +1224,7 @@ void MHA::mhaImpl() {

auto outPrcSize = outputPrecision.size();

parallel_for2d(dimsMatMul0Out[0], dimsMatMul0Out[1], [&](size_t i0, size_t i1) {
auto spatial_loop = [&](size_t i0, size_t i1) {
size_t threadNum = parallel_get_thread_num();

auto pTranspose0In0_aux = pTranspose0In0 + (i0 * strTranspose0In0[0] + i1 * strTranspose0In0[2]) * inputPrecisions[0].size(); // order 0213
Expand Down Expand Up @@ -1417,6 +1417,10 @@ void MHA::mhaImpl() {
(*convertReorderKernel)(&call_args);
}
}
};

parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) {
for_2d(ithr, nthr, dimsMatMul0Out[0], dimsMatMul0Out[1], spatial_loop);
});
}

Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes/mha.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ class MHA : public Node {
std::unique_ptr<jit_uni_mul_add_softmax_kernel> mulAddSoftmaxKernel;
std::unique_ptr<jit_uni_convert_reorder_kernel> convertReorderKernel;
std::unique_ptr<jit_uni_convert_transpose_kernel> convertTransposeKernel;

size_t m_threads_num = 0lu;
};

} // namespace node
Expand Down
Loading

0 comments on commit 68088e5

Please sign in to comment.