Skip to content

Commit

Permalink
[CPU][OMP] Fix for Interpolate node
Browse files Browse the repository at this point in the history
  • Loading branch information
nshchego committed Oct 22, 2024
1 parent adeb3d2 commit cb17368
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
36 changes: 21 additions & 15 deletions src/plugins/intel_cpu/src/nodes/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2828,24 +2828,27 @@ void Interpolate::InterpolateJitExecutor::pillowCGathered(const uint8_t *in_ptr_
bool xPass = IW != OW;
bool yPass = IH != OH;

parallel_for(B, [&](size_t b) {
auto b_loop = [&](size_t b) {
auto arg = jit_interpolate_call_args();
arg.src_ptr[0] = in_ptr_ + (IW * IH * C * b) * srcDataSize;
if (xPass && yPass) {
size_t threadsNum = parallel_get_num_threads();
size_t parallelNum = B;
size_t parallel_num = B;
// IH * OW * C buf needed
size_t buffer_size = static_cast<size_t>(OW * IH * C);
if (parallelNum < threadsNum) {
if (parallel_num < m_threads_num) {
arg.src_ptr[1] = static_cast<uint8_t*>(&pillow_working_buf[b * buffer_size * srcDataSize]);
} else {
size_t threadsIdx = parallel_get_thread_num();
arg.src_ptr[1] = static_cast<uint8_t*>(&pillow_working_buf[threadsIdx * buffer_size * srcDataSize]);
size_t threads_idx = parallel_get_thread_num();
arg.src_ptr[1] = static_cast<uint8_t*>(&pillow_working_buf[threads_idx * buffer_size * srcDataSize]);
}
}
arg.dst = out_ptr_ + (OW * OH * C * b) * dstDataSize;
arg.weight_ptr[0] = reinterpret_cast<float*>(&auxTable[2]);
(*interpolateKernel)(&arg);
};

parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) {
for_1d(ithr, nthr, B, b_loop);
});
}

Expand Down Expand Up @@ -3706,16 +3709,15 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint
// | |
// | |
// ----
parallel_for2d(B, C, [&](size_t b, size_t c) {
auto bc_loop = [&](size_t b, size_t c) {
const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * C * b + IW * IH * c) * srcDataSize;
uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * C * b + OW * OH * c) * dstDataSize;
uint8_t *xpass_out_ptr_nc = nullptr;
const uint8_t *ypass_in_ptr_nc = nullptr;
if (xPass && yPass) {
size_t threadsNum = parallel_get_num_threads();
size_t parallelNum = B * C;
size_t parallel_num = B * C;
// IH * OW buf needed
if (parallelNum < threadsNum) {
if (parallel_num < m_threads_num) {
xpass_out_ptr_nc = static_cast<uint8_t*>(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]);
ypass_in_ptr_nc = static_cast<const uint8_t*>(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]);
} else {
Expand Down Expand Up @@ -3770,23 +3772,27 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint
}
}
}
};

parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) {
for_2d(ithr, nthr, B, C, bc_loop);
});
}

void Interpolate::InterpolateExecutorBase::create_pillow_working_buf(InterpolateLayoutType layout) {
if (srcDimPad5d[3] == dstDim5d[3] || srcDimPad5d[4] == dstDim5d[4])
return;
size_t bufSize = srcDimPad5d[3] * dstDim5d[4] * srcDataSize; // IH * OW
size_t threadsNum = parallel_get_max_threads();
m_threads_num = parallel_get_max_threads();
if (layout == InterpolateLayoutType::planar) {
// B and C execute in parallel, need separate buf
size_t parallelNum = srcDimPad5d[0] * srcDimPad5d[1];
bufSize *= std::min(threadsNum, parallelNum);
size_t parallel_num = srcDimPad5d[0] * srcDimPad5d[1];
bufSize *= std::min(m_threads_num, parallel_num);
} else {
bufSize *= srcDimPad5d[1]; // *C
// B execute in parallel, need separate buf
size_t parallelNum = srcDimPad5d[0];
bufSize *= std::min(threadsNum, parallelNum);
size_t parallel_num = srcDimPad5d[0];
bufSize *= std::min(m_threads_num, parallel_num);
}
pillow_working_buf.resize(bufSize);
}
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/interpolate.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class Interpolate : public Node {
size_t dataRank;
std::vector<int> auxTable;
std::vector<uint8_t> pillow_working_buf;
size_t m_threads_num = 0lu;
};
std::shared_ptr<InterpolateExecutorBase> execPtr = nullptr;

Expand Down

0 comments on commit cb17368

Please sign in to comment.