From cb17368067d1c5d36ccc3cfdede4ed1deac01296 Mon Sep 17 00:00:00 2001 From: Nikolai Shchegolev Date: Mon, 21 Oct 2024 15:50:56 +0400 Subject: [PATCH] [CPU][OMP] Fix for Interpolate node --- .../intel_cpu/src/nodes/interpolate.cpp | 36 +++++++++++-------- src/plugins/intel_cpu/src/nodes/interpolate.h | 1 + 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/interpolate.cpp b/src/plugins/intel_cpu/src/nodes/interpolate.cpp index ee6afa33827861..7eed5c1df9789b 100644 --- a/src/plugins/intel_cpu/src/nodes/interpolate.cpp +++ b/src/plugins/intel_cpu/src/nodes/interpolate.cpp @@ -2828,24 +2828,27 @@ void Interpolate::InterpolateJitExecutor::pillowCGathered(const uint8_t *in_ptr_ bool xPass = IW != OW; bool yPass = IH != OH; - parallel_for(B, [&](size_t b) { + auto b_loop = [&](size_t b) { auto arg = jit_interpolate_call_args(); arg.src_ptr[0] = in_ptr_ + (IW * IH * C * b) * srcDataSize; if (xPass && yPass) { - size_t threadsNum = parallel_get_num_threads(); - size_t parallelNum = B; + size_t parallel_num = B; // IH * OW * C buf needed size_t buffer_size = static_cast(OW * IH * C); - if (parallelNum < threadsNum) { + if (parallel_num < m_threads_num) { arg.src_ptr[1] = static_cast(&pillow_working_buf[b * buffer_size * srcDataSize]); } else { - size_t threadsIdx = parallel_get_thread_num(); - arg.src_ptr[1] = static_cast(&pillow_working_buf[threadsIdx * buffer_size * srcDataSize]); + size_t threads_idx = parallel_get_thread_num(); + arg.src_ptr[1] = static_cast(&pillow_working_buf[threads_idx * buffer_size * srcDataSize]); } } arg.dst = out_ptr_ + (OW * OH * C * b) * dstDataSize; arg.weight_ptr[0] = reinterpret_cast(&auxTable[2]); (*interpolateKernel)(&arg); + }; + + parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) { + for_1d(ithr, nthr, B, b_loop); }); } @@ -3706,16 +3709,15 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint // | | // | | // ---- - parallel_for2d(B, C, [&](size_t b, size_t c) { + auto bc_loop = [&](size_t b, size_t c) { const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * C * b + IW * IH * c) * srcDataSize; uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * C * b + OW * OH * c) * dstDataSize; uint8_t *xpass_out_ptr_nc = nullptr; const uint8_t *ypass_in_ptr_nc = nullptr; if (xPass && yPass) { - size_t threadsNum = parallel_get_num_threads(); - size_t parallelNum = B * C; + size_t parallel_num = B * C; // IH * OW buf needed - if (parallelNum < threadsNum) { + if (parallel_num < m_threads_num) { xpass_out_ptr_nc = static_cast(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]); ypass_in_ptr_nc = static_cast(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]); } else { @@ -3770,6 +3772,10 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint } } } + }; + + parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) { + for_2d(ithr, nthr, B, C, bc_loop); }); } @@ -3777,16 +3783,16 @@ void Interpolate::InterpolateExecutorBase::create_pillow_working_buf(Interpolate if (srcDimPad5d[3] == dstDim5d[3] || srcDimPad5d[4] == dstDim5d[4]) return; size_t bufSize = srcDimPad5d[3] * dstDim5d[4] * srcDataSize; // IH * OW - size_t threadsNum = parallel_get_max_threads(); + m_threads_num = parallel_get_max_threads(); if (layout == InterpolateLayoutType::planar) { // B and C execute in parallel, need separate buf - size_t parallelNum = srcDimPad5d[0] * srcDimPad5d[1]; - bufSize *= std::min(threadsNum, parallelNum); + size_t parallel_num = srcDimPad5d[0] * srcDimPad5d[1]; + bufSize *= std::min(m_threads_num, parallel_num); } else { bufSize *= srcDimPad5d[1]; // *C // B execute in parallel, need separate buf - size_t parallelNum = srcDimPad5d[0]; - bufSize *= std::min(threadsNum, parallelNum); + size_t parallel_num = srcDimPad5d[0]; + bufSize *= std::min(m_threads_num, parallel_num); } pillow_working_buf.resize(bufSize); } diff --git a/src/plugins/intel_cpu/src/nodes/interpolate.h b/src/plugins/intel_cpu/src/nodes/interpolate.h index 11f0e3104e5085..a43b354aa0306a 100644 --- a/src/plugins/intel_cpu/src/nodes/interpolate.h +++ b/src/plugins/intel_cpu/src/nodes/interpolate.h @@ -148,6 +148,7 @@ class Interpolate : public Node { size_t dataRank; std::vector auxTable; std::vector pillow_working_buf; + size_t m_threads_num = 0lu; }; std::shared_ptr execPtr = nullptr;