[CPU][OMP] Fix for Interpolate node (openvinotoolkit#27184)
### Details:
 - *Fix threads number*
 - *...*

### Tickets:
 - *152606*
nshchego authored and CuriousPanCake committed Nov 6, 2024
1 parent 7934965 commit 00a1f74
Showing 3 changed files with 40 additions and 21 deletions.
24 changes: 18 additions & 6 deletions src/core/include/openvino/core/parallel.hpp
```diff
@@ -461,8 +461,10 @@ void parallel_for(const T0& D0, const F& func) {
         for_1d(ithr, nthr, D0, func);
     });
 #elif OV_THREAD == OV_THREAD_OMP
+    // Please note that this function does not guarantee execution on the same number of threads from call to call.
+    // Use the parallel_nt* functions if the procedure depends on a certain number of threads.
 #    pragma omp parallel
-    for_1d(parallel_get_thread_num(), parallel_get_num_threads(), D0, func);
+    { for_1d(parallel_get_thread_num(), parallel_get_num_threads(), D0, func); }
 #elif OV_THREAD == OV_THREAD_SEQ
     for_1d(0, 1, D0, func);
 #endif
```
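The comment added here warns about a real OpenMP behavior: the runtime is free to hand consecutive parallel regions different team sizes, for example when dynamic adjustment is enabled. A minimal standalone sketch of that effect (illustrative, not part of the patch):

```cpp
// Minimal sketch: with dynamic adjustment enabled, consecutive parallel
// regions may run with different team sizes, so code must not cache
// the thread count observed in one region and reuse it in another.
#include <omp.h>
#include <cstdio>

int main() {
    omp_set_dynamic(1);  // allow the runtime to shrink teams between calls
    for (int call = 0; call < 3; ++call) {
#pragma omp parallel
        {
#pragma omp single
            std::printf("call %d ran with %d threads\n", call, omp_get_num_threads());
        }
    }
    return 0;
}
```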
```diff
@@ -509,8 +511,10 @@ void parallel_for2d(const T0& D0, const T1& D1, const F& func) {
         for_2d(ithr, nthr, D0, D1, func);
     });
 #elif OV_THREAD == OV_THREAD_OMP
+    // Please note that this function does not guarantee execution on the same number of threads from call to call.
+    // Use the parallel_nt* functions if the procedure depends on a certain number of threads.
 #    pragma omp parallel
-    for_2d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, func);
+    { for_2d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, func); }
 #elif OV_THREAD == OV_THREAD_SEQ
     for_2d(0, 1, D0, D1, func);
 #endif
```
```diff
@@ -575,8 +579,10 @@ void parallel_for3d(const T0& D0, const T1& D1, const T2& D2, const F& func) {
         for_3d(ithr, nthr, D0, D1, D2, func);
     });
 #elif OV_THREAD == OV_THREAD_OMP
+    // Please note that this function does not guarantee execution on the same number of threads from call to call.
+    // Use the parallel_nt* functions if the procedure depends on a certain number of threads.
 #    pragma omp parallel
-    for_3d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, func);
+    { for_3d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, func); }
 #elif OV_THREAD == OV_THREAD_SEQ
     for_3d(0, 1, D0, D1, D2, func);
 #endif
```
```diff
@@ -645,8 +651,10 @@ void parallel_for4d(const T0& D0, const T1& D1, const T2& D2, const T3& D3, cons
         for_4d(ithr, nthr, D0, D1, D2, D3, func);
     });
 #elif OV_THREAD == OV_THREAD_OMP
+    // Please note that this function does not guarantee execution on the same number of threads from call to call.
+    // Use the parallel_nt* functions if the procedure depends on a certain number of threads.
 #    pragma omp parallel
-    for_4d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, D3, func);
+    { for_4d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, D3, func); }
 #elif OV_THREAD == OV_THREAD_SEQ
     for_4d(0, 1, D0, D1, D2, D3, func);
 #endif
```
```diff
@@ -703,8 +711,10 @@ void parallel_for5d(const T0& D0, const T1& D1, const T2& D2, const T3& D3, cons
         for_5d(ithr, nthr, D0, D1, D2, D3, D4, func);
     });
 #elif OV_THREAD == OV_THREAD_OMP
+    // Please note that this function does not guarantee execution on the same number of threads from call to call.
+    // Use the parallel_nt* functions if the procedure depends on a certain number of threads.
 #    pragma omp parallel
-    for_5d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, D3, D4, func);
+    { for_5d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, D3, D4, func); }
 #elif OV_THREAD == OV_THREAD_SEQ
     for_5d(0, 1, D0, D1, D2, D3, D4, func);
 #endif
```
```diff
@@ -763,8 +773,10 @@ void parallel_for6d(const T0& D0, const T1& D1, const T2& D2, const T3& D3, cons
         for_6d(ithr, nthr, D0, D1, D2, D3, D4, D5, func);
     });
 #elif OV_THREAD == OV_THREAD_OMP
+    // Please note that this function does not guarantee execution on the same number of threads from call to call.
+    // Use the parallel_nt* functions if the procedure depends on a certain number of threads.
 #    pragma omp parallel
-    for_6d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, D3, D4, D5, func);
+    { for_6d(parallel_get_thread_num(), parallel_get_num_threads(), D0, D1, D2, D3, D4, D5, func); }
 #elif OV_THREAD == OV_THREAD_SEQ
     for_6d(0, 1, D0, D1, D2, D3, D4, D5, func);
 #endif
```
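The same two comment lines are added to all six `parallel_forNd` variants; they point callers that depend on a stable thread count at the `parallel_nt*` helpers declared in this same header. As an illustration of the contract those helpers provide, here is a hedged reduction under plain OpenMP (`parallel_nt_sketch` is an invented name; only the `(ithr, nthr)` functor shape is taken from the diff):

```cpp
// Illustrative reduction of a parallel_nt-style dispatch: the caller pins
// the team size, and the functor receives (thread id, team size) so work
// can be partitioned deterministically from call to call.
#include <omp.h>

template <typename F>
void parallel_nt_sketch(int nthr, const F& func) {
    if (nthr <= 0)
        nthr = omp_get_max_threads();  // fall back to the runtime default
#pragma omp parallel num_threads(nthr)
    { func(omp_get_thread_num(), omp_get_num_threads()); }
}
```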
36 changes: 21 additions & 15 deletions src/plugins/intel_cpu/src/nodes/interpolate.cpp
```diff
@@ -2828,24 +2828,27 @@ void Interpolate::InterpolateJitExecutor::pillowCGathered(const uint8_t *in_ptr_
     bool xPass = IW != OW;
     bool yPass = IH != OH;
 
-    parallel_for(B, [&](size_t b) {
+    auto b_loop = [&](size_t b) {
         auto arg = jit_interpolate_call_args();
         arg.src_ptr[0] = in_ptr_ + (IW * IH * C * b) * srcDataSize;
         if (xPass && yPass) {
-            size_t threadsNum = parallel_get_num_threads();
-            size_t parallelNum = B;
+            size_t parallel_num = B;
             // IH * OW * C buf needed
             size_t buffer_size = static_cast<size_t>(OW * IH * C);
-            if (parallelNum < threadsNum) {
+            if (parallel_num < m_threads_num) {
                 arg.src_ptr[1] = static_cast<uint8_t*>(&pillow_working_buf[b * buffer_size * srcDataSize]);
             } else {
-                size_t threadsIdx = parallel_get_thread_num();
-                arg.src_ptr[1] = static_cast<uint8_t*>(&pillow_working_buf[threadsIdx * buffer_size * srcDataSize]);
+                size_t threads_idx = parallel_get_thread_num();
+                arg.src_ptr[1] = static_cast<uint8_t*>(&pillow_working_buf[threads_idx * buffer_size * srcDataSize]);
             }
         }
         arg.dst = out_ptr_ + (OW * OH * C * b) * dstDataSize;
         arg.weight_ptr[0] = reinterpret_cast<float*>(&auxTable[2]);
         (*interpolateKernel)(&arg);
+    };
+
+    parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) {
+        for_1d(ithr, nthr, B, b_loop);
     });
 }
```
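The pattern here is: size a per-thread scratch buffer once for a fixed team size, then force execution on exactly that many threads via `parallel_nt_static`, so `parallel_get_thread_num()` can never index past the buffer. A sketch of the same idea in plain OpenMP (`Scratch` and `per_thread_elems` are illustrative names, not from the patch):

```cpp
// Sketch of the fix's pattern under plain OpenMP: the buffer is sized for
// a recorded team size, and the parallel region is pinned to that size.
#include <omp.h>
#include <cstddef>
#include <vector>

struct Scratch {
    std::size_t threads_num = 0;  // analogue of m_threads_num
    std::vector<float> buf;

    void create(std::size_t per_thread_elems) {
        // Size the buffer once, for the thread count recorded here.
        threads_num = static_cast<std::size_t>(omp_get_max_threads());
        buf.resize(threads_num * per_thread_elems);
    }

    void run(std::size_t per_thread_elems) {
        // Pin the team to the recorded size, the analogue of
        // parallel_nt_static(m_threads_num, ...). Thread ids then stay
        // strictly below threads_num, so the offsets cannot overrun buf.
#pragma omp parallel num_threads(static_cast<int>(threads_num))
        {
            const auto tid = static_cast<std::size_t>(omp_get_thread_num());
            float* local = &buf[tid * per_thread_elems];
            local[0] = 0.0f;  // per-thread work on `local` goes here
        }
    }
};
```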

```diff
@@ -3706,16 +3709,15 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint
     //   |        |
     //   |        |
     //   ----
-    parallel_for2d(B, C, [&](size_t b, size_t c) {
+    auto bc_loop = [&](size_t b, size_t c) {
         const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * C * b + IW * IH * c) * srcDataSize;
         uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * C * b + OW * OH * c) * dstDataSize;
         uint8_t *xpass_out_ptr_nc = nullptr;
         const uint8_t *ypass_in_ptr_nc = nullptr;
         if (xPass && yPass) {
-            size_t threadsNum = parallel_get_num_threads();
-            size_t parallelNum = B * C;
+            size_t parallel_num = B * C;
             // IH * OW buf needed
-            if (parallelNum < threadsNum) {
+            if (parallel_num < m_threads_num) {
                 xpass_out_ptr_nc = static_cast<uint8_t*>(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]);
                 ypass_in_ptr_nc = static_cast<const uint8_t*>(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]);
             } else {
```
```diff
@@ -3770,23 +3772,27 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint
                 }
             }
         }
+    };
+
+    parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) {
+        for_2d(ithr, nthr, B, C, bc_loop);
     });
 }
 
 void Interpolate::InterpolateExecutorBase::create_pillow_working_buf(InterpolateLayoutType layout) {
     if (srcDimPad5d[3] == dstDim5d[3] || srcDimPad5d[4] == dstDim5d[4])
         return;
     size_t bufSize = srcDimPad5d[3] * dstDim5d[4] * srcDataSize;  // IH * OW
-    size_t threadsNum = parallel_get_max_threads();
+    m_threads_num = parallel_get_max_threads();
     if (layout == InterpolateLayoutType::planar) {
         // B and C execute in parallel, need separate buf
-        size_t parallelNum = srcDimPad5d[0] * srcDimPad5d[1];
-        bufSize *= std::min(threadsNum, parallelNum);
+        size_t parallel_num = srcDimPad5d[0] * srcDimPad5d[1];
+        bufSize *= std::min(m_threads_num, parallel_num);
     } else {
         bufSize *= srcDimPad5d[1];  // *C
         // B execute in parallel, need separate buf
-        size_t parallelNum = srcDimPad5d[0];
-        bufSize *= std::min(threadsNum, parallelNum);
+        size_t parallel_num = srcDimPad5d[0];
+        bufSize *= std::min(m_threads_num, parallel_num);
     }
     pillow_working_buf.resize(bufSize);
 }
```
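This is the heart of the fix: the working buffer holds `min(threads, work items)` chunks, so the later dispatch must not run with more threads than the buffer was sized for. A small worked example of the overrun a mismatched team size would produce (all numbers hypothetical, not from the commit):

```cpp
// Worked example: sizing and dispatch must agree on the thread count,
// or the last thread's chunk lands past the end of the buffer.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t chunk_bytes = 32 * 64 * 4;  // e.g. IH * OW * srcDataSize
    const std::size_t parallel_num = 24;          // e.g. B * C work items

    const std::size_t sized_for = 8;  // m_threads_num at buffer creation
    const std::size_t buf_bytes = chunk_bytes * std::min(sized_for, parallel_num);

    const std::size_t ran_with = 16;  // team size if the region is not pinned
    const std::size_t worst_end = ran_with * chunk_bytes;  // end of last thread's chunk

    std::printf("buffer: %zu bytes, worst access ends at byte %zu -> %s\n",
                buf_bytes, worst_end, worst_end > buf_bytes ? "OUT OF BOUNDS" : "ok");
    return 0;
}
```

Caching `m_threads_num` at buffer creation and dispatching with `parallel_nt_static(m_threads_num, ...)` in the two executors above closes exactly this gap: sizing and indexing now use the same thread count.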
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/interpolate.h
```diff
@@ -148,6 +148,7 @@ class Interpolate : public Node {
         size_t dataRank;
         std::vector<int> auxTable;
         std::vector<uint8_t> pillow_working_buf;
+        size_t m_threads_num = 0lu;
     };
     std::shared_ptr<InterpolateExecutorBase> execPtr = nullptr;
```

