diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.cpp index f3189ba6332201..73fe0bab6eb802 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.cpp @@ -7,7 +7,9 @@ #include #include #include +#include #include "ie_parallel.hpp" +#include "utils/bfloat16.hpp" using namespace mkldnn; @@ -133,6 +135,12 @@ void MKLDNNTransposeNode::createPrimitive() { if (getSelectedPrimitiveDescriptor() == nullptr) IE_THROW() << "Preferable primitive descriptor is not set."; + if (getParentEdgeAt(0)->getMemory().GetDesc().isPlainFormat() && + std::find(optimizedOrders.begin(), optimizedOrders.end(), order) != optimizedOrders.end()) { + isOptimized = true; + return; + } + PermuteParams params; params.data_size = getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc.getPrecision().size(); params.order = order; @@ -148,508 +156,123 @@ void MKLDNNTransposeNode::createPrimitive() { permuteKernel = std::unique_ptr(new PermuteKernel(params)); } -static void transpose_to_0231(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - // Supports only NCHW to NHWC - int block_size = 1; - if (!srcMemPtr->GetDesc().isPlainFormat()) { - const auto &blk_desc = srcMemPtr->GetDescriptor().data.format_desc.blocking; - auto found = std::find(blk_desc.inner_idxs, blk_desc.inner_idxs + blk_desc.inner_nblks, 1); - auto pos = std::distance(found, blk_desc.inner_idxs); - block_size = blk_desc.inner_blks[pos]; - } - - const int C = srcMemPtr->GetDims()[1]; - const int H = srcMemPtr->GetDims()[2]; - const int W = srcMemPtr->GetDims()[3]; - - // NHWC - const int src_stride = H * W * block_size; - - parallel_for3d(MB, H, W, [&](int n, int h, int w) { - int src_off = n * C * H * W + (h * W + w) * block_size; - int dst_off = n * H * W * C + h * W * C + w * C; - - for (int c = 0; c < C; c += block_size) { - for (int b = 0; b < block_size; b++) { - dst_data[dst_off] = src_data[src_off + b]; - dst_off++; - } - - src_off += src_stride; - } - }); -} - -static void transpose_to_0213(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - int block_size = 1; - if (!srcMemPtr->GetDesc().isPlainFormat()) { - const auto &blk_desc = srcMemPtr->GetDescriptor().data.format_desc.blocking; - auto found = std::find(blk_desc.inner_idxs, blk_desc.inner_idxs + blk_desc.inner_nblks, 1); - auto pos = std::distance(found, blk_desc.inner_idxs); - block_size = blk_desc.inner_blks[pos]; - } - - const int C = srcMemPtr->GetDims()[1]; - const int H = srcMemPtr->GetDims()[2]; - const int W = srcMemPtr->GetDims()[3]; - - parallel_for3d(MB, C/block_size, H, [&](int n, int c, int h) { - for (int w = 0; w < W; w++) { - int src_off = n*C*H*W + (c*H*W + h*W + w)*block_size; - int dst_off = n*C*H*W + (h*C*W + w + c*W)*block_size; - for (int b = 0; b < block_size; b++) { - dst_data[dst_off + b] = src_data[src_off + b]; - } - } - }); -} - -static void transpose_to_0312(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int C = srcMemPtr->GetDims()[1]; - const int H = srcMemPtr->GetDims()[2]; - const int W = srcMemPtr->GetDims()[3]; - - parallel_for3d(MB, C, H, [&](int n, int c, int h) { - for (int w = 0; w < W; w++) { - int src_off = n*C*H*W + c*H*W + h*W + w; - int dst_off = n*W*C*H + w*C*H + c*H + h; - dst_data[dst_off] = src_data[src_off]; - } - }); -} - -template -static void transpose_to_014253(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int C = srcMemPtr->GetDims()[1]; - const int CH = scale_H > 0 ? static_cast(scale_H) : srcMemPtr->GetDims()[2]; - const int CW = scale_W > 0 ? static_cast(scale_W) : srcMemPtr->GetDims()[3]; - const int H = srcMemPtr->GetDims()[4]; - const int W = srcMemPtr->GetDims()[5]; - - int src_off = 0; - int dst_off = 0; - - for (int n = 0; n < MB; n++) { - for (int c = 0; c < C; c++) { - for (int h = 0; h < H; h++) { - for (int ch = 0; ch < CH; ch++) { - for (int w = 0; w < W; w++) { - for (int cw = 0; cw < CW; cw++) { - src_off = n * C * CH * CW * H * W + - c * CH * CW * H * W + - ch * CW * H * W + - cw * H * W + - h * W + - w; - - dst_data[dst_off] = src_data[src_off]; - dst_off++; - } - } - } - } - } - } -} - -static void transpose_to_3012(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int C = srcMemPtr->GetDims()[1]; - const int H = srcMemPtr->GetDims()[2]; - const int W = srcMemPtr->GetDims()[3]; - - int src_off = 0; - int dst_off = 0; - - for (int w = 0; w < W; w++) { - for (int n = 0; n < MB; n++) { - for (int c = 0; c < C; c++) { - for (int h = 0; h < H; h++) { - src_off = n * C * H * W + - c * H * W + - h * W + - w; - - dst_data[dst_off] = src_data[src_off]; - dst_off++; - } - } - } - } -} - -static void transpose_to_021(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int C = srcMemPtr->GetDims()[1]; - const int S = srcMemPtr->GetDims()[2]; - - parallel_for2d(MB, S, [&](int n, int s) { - int src_off = 0; - int dst_off = 0; - - for (int c = 0; c < C; c++) { - src_off = n * C * S + - c * S + - s; - dst_off = n * S * C + - s * C + - c; - - dst_data[dst_off] = src_data[src_off]; - } - }); -} - -static void transpose_to_034152(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); +template +static void transpose_to_0312(const int MB, const MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { + const auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); + auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); const int DIM1 = srcMemPtr->GetDims()[1]; const int DIM2 = srcMemPtr->GetDims()[2]; const int DIM3 = srcMemPtr->GetDims()[3]; - const int DIM4 = srcMemPtr->GetDims()[4]; - const int DIM5 = srcMemPtr->GetDims()[5]; - - int src_off = 0; - int dst_off = 0; - - for (int n = 0; n < MB; n++) { - for (int dim3 = 0; dim3 < DIM3; dim3++) { - for (int dim4 = 0; dim4 < DIM4; dim4++) { - for (int dim1 = 0; dim1 < DIM1; dim1++) { - for (int dim5 = 0; dim5 < DIM5; dim5++) { - for (int dim2 = 0; dim2 < DIM2; dim2++) { - src_off = n * DIM1 * DIM2 * DIM3 * DIM4 * DIM5 + - dim1 * DIM2 * DIM3 * DIM4 * DIM5 + - dim2 * DIM3 * DIM4 * DIM5 + - dim3 * DIM4 * DIM5 + - dim4 * DIM5 + - dim5; - - dst_data[dst_off] = src_data[src_off]; - dst_off++; - } - } - } - } - } - } -} - -static void transpose_to_0132(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - int src_block_size = 1; - if (!srcMemPtr->GetDesc().isPlainFormat()) { - const auto &blk_desc = srcMemPtr->GetDescriptor().data.format_desc.blocking; - auto found = std::find(blk_desc.inner_idxs, blk_desc.inner_idxs + blk_desc.inner_nblks, 1); - auto pos = std::distance(found, blk_desc.inner_idxs); - src_block_size = blk_desc.inner_blks[pos]; - } - const int C = srcMemPtr->GetDims()[1]; - const int H = srcMemPtr->GetDims()[2]; - const int W = srcMemPtr->GetDims()[3]; - - parallel_for3d(MB, C/src_block_size, H, [&](int n, int c, int h) { - for (int w = 0; w < W; w++) { - int src_off = n*C*H*W + (c*H*W + h*W + w)*src_block_size; - int dst_off = n*C*H*W + c*H*W*src_block_size + w*H + h; - for (int b = 0; b < src_block_size; b++) { - dst_data[dst_off + b*H*W] = src_data[src_off + b]; - } - } - }); -} - -static void transpose_to_03142(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); + parallel_for3d(MB, DIM1, DIM2, [&](const int n, const int dim1, const int dim2) { + for (int dim3 = 0; dim3 < DIM3; ++dim3) { + const int src_off = n * DIM1 * DIM2 * DIM3 + + dim1 * DIM2 * DIM3 + + dim2 * DIM3 + + dim3; + const int dst_off = n * DIM1 * DIM2 * DIM3 + + dim3 * DIM1 * DIM2 + + dim1 * DIM2 + + dim2; - const int DIM1 = srcMemPtr->GetDims()[1]; - const int DIM2 = srcMemPtr->GetDims()[2]; - const int DIM3 = srcMemPtr->GetDims()[3]; - const int DIM4 = srcMemPtr->GetDims()[4]; - - int src_off = 0; - int dst_off = 0; - - for (int n = 0; n < MB; n++) { - for (int dim3 = 0; dim3 < DIM3; dim3++) { - for (int dim1 = 0; dim1 < DIM1; dim1++) { - for (int dim4 = 0; dim4 < DIM4; dim4++) { - for (int dim2 = 0; dim2 < DIM2; dim2++) { - src_off = n * DIM1 * DIM2 * DIM3 * DIM4 + - dim1 * DIM2 * DIM3 * DIM4 + - dim2 * DIM3 * DIM4 + - dim3 * DIM4 + - dim4; - - dst_data[dst_off] = src_data[src_off]; - dst_off++; - } - } - } - } - } -} - -static void transpose_to_1203(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int C = srcMemPtr->GetDims()[1]; - const int H = srcMemPtr->GetDims()[2]; - const int W = srcMemPtr->GetDims()[3]; - - parallel_for3d(MB, C, H, [&](int n, int c, int h) { - for (int w = 0; w < W; w++) { - int src_off = n * C * H * W + c * H * W + h * W + w; - int dst_off = c * H * MB * W + h * MB * W + n * W + w; dst_data[dst_off] = src_data[src_off]; } }); } -static void transpose_to_02134(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); +template +static void transpose_to_04123(const int MB, const MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { + const auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); + auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); const int DIM1 = srcMemPtr->GetDims()[1]; const int DIM2 = srcMemPtr->GetDims()[2]; const int DIM3 = srcMemPtr->GetDims()[3]; const int DIM4 = srcMemPtr->GetDims()[4]; - parallel_for4d(MB, DIM2, DIM1, DIM3, [&](int n, int dim2, int dim1, int dim3) { - for (int dim4 = 0; dim4 < DIM4; dim4++) { - int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 + - dim1 * DIM2 * DIM3 * DIM4 + - dim2 * DIM3 * DIM4 + - dim3 * DIM4 + - dim4; - int dst_off = n * DIM2 * DIM1 * DIM3 * DIM4 + - dim2 * DIM1 * DIM3 * DIM4 + - dim1 * DIM3 * DIM4 + - dim3 * DIM4 + - dim4; + parallel_for4d(MB, DIM1, DIM2, DIM3, [&](const int n, const int dim1, const int dim2, const int dim3) { + for (int dim4 = 0; dim4 < DIM4; ++dim4) { + const int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 + + dim1 * DIM2 * DIM3 * DIM4 + + dim2 * DIM3 * DIM4 + + dim3 * DIM4 + + dim4; + const int dst_off = n * DIM1 * DIM2 * DIM3 * DIM4 + + dim4 * DIM1 * DIM2 * DIM3 + + dim1 * DIM2 * DIM3 + + dim2 * DIM3 + + dim3; dst_data[dst_off] = src_data[src_off]; } }); } -static void transpose_to_02431(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int DIM1 = srcMemPtr->GetDims()[1]; - const int DIM2 = srcMemPtr->GetDims()[2]; - const int DIM3 = srcMemPtr->GetDims()[3]; - const int DIM4 = srcMemPtr->GetDims()[4]; - - parallel_for4d(MB, DIM2, DIM4, DIM3, [&](int n, int dim2, int dim4, int dim3) { - for (int dim1 = 0; dim1 < DIM1; dim1++) { - int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 + - dim1 * DIM2 * DIM3 * DIM4 + - dim2 * DIM3 * DIM4 + - dim3 * DIM4 + - dim4; - int dst_off = n * DIM2 * DIM4 * DIM3 * DIM1 + - dim2 * DIM4 * DIM3 * DIM1 + - dim4 * DIM3 * DIM1 + - dim3 * DIM1 + - dim1; - - dst_data[dst_off] = src_data[src_off]; - } - }); -} - -static void transpose_to_04231(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int DIM1 = srcMemPtr->GetDims()[1]; - const int DIM2 = srcMemPtr->GetDims()[2]; - const int DIM3 = srcMemPtr->GetDims()[3]; - const int DIM4 = srcMemPtr->GetDims()[4]; - - parallel_for4d(MB, DIM4, DIM2, DIM3, [&](int n, int dim4, int dim2, int dim3) { - for (int dim1 = 0; dim1 < DIM1; dim1++) { - int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 + - dim1 * DIM2 * DIM3 * DIM4 + - dim2 * DIM3 * DIM4 + - dim3 * DIM4 + - dim4; - int dst_off = n * DIM4 * DIM2 * DIM3 * DIM1 + - dim4 * DIM2 * DIM3 * DIM1 + - dim2 * DIM3 * DIM1 + - dim3 * DIM1 + - dim1; - - dst_data[dst_off] = src_data[src_off]; - } - }); -} - -static void transpose_to_102(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int C = srcMemPtr->GetDims()[1]; - const int S = srcMemPtr->GetDims()[2]; - - parallel_for2d(MB, S, [&](int n, int s) { - int src_off = 0; - int dst_off = 0; - - for (int c = 0; c < C; c++) { - src_off = n * C * S + - c * S + - s; - dst_off = c * MB * S + - n * S + - s; - - dst_data[dst_off] = src_data[src_off]; - } - }); -} - -static void transpose_to_02341(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); +template +static void transpose_to_051234(const int MB, const MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { + const auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); + auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); const int DIM1 = srcMemPtr->GetDims()[1]; const int DIM2 = srcMemPtr->GetDims()[2]; const int DIM3 = srcMemPtr->GetDims()[3]; const int DIM4 = srcMemPtr->GetDims()[4]; + const int DIM5 = srcMemPtr->GetDims()[5]; - parallel_for4d(MB, DIM2, DIM3, DIM4, [&](int n, int dim2, int dim3, int dim4) { - for (int dim1 = 0; dim1 < DIM1; dim1++) { - int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 + - dim1 * DIM2 * DIM3 * DIM4 + - dim2 * DIM3 * DIM4 + - dim3 * DIM4 + - dim4; - int dst_off = n * DIM2 * DIM3 * DIM4 * DIM1 + - dim2 * DIM3 * DIM4 * DIM1 + - dim3 * DIM4 * DIM1 + - dim4 * DIM1 + - dim1; + parallel_for5d(MB, DIM1, DIM2, DIM3, DIM4, [&](const int n, const int dim1, const int dim2, const int dim3, const int dim4) { + for (int dim5 = 0; dim5 < DIM5; ++dim5) { + const int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 * DIM5 + + dim1 * DIM2 * DIM3 * DIM4 * DIM5 + + dim2 * DIM3 * DIM4 * DIM5 + + dim3 * DIM4 * DIM5 + + dim4 * DIM5 + + dim5; + const int dst_off = n * DIM5 * DIM1 * DIM2 * DIM3 * DIM4 + + dim5 * DIM1 * DIM2 * DIM3 * DIM4 + + dim1 * DIM2 * DIM3 * DIM4 + + dim2 * DIM3 * DIM4 + + dim3 * DIM4 + + dim4; dst_data[dst_off] = src_data[src_off]; } }); } -static void transpose_to_04123(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - auto src_data = reinterpret_cast(srcMemPtr->GetPtr()); - auto dst_data = reinterpret_cast(dstMemPtr->GetPtr()); - - const int DIM1 = srcMemPtr->GetDims()[1]; - const int DIM2 = srcMemPtr->GetDims()[2]; - const int DIM3 = srcMemPtr->GetDims()[3]; - const int DIM4 = srcMemPtr->GetDims()[4]; - - parallel_for4d(MB, DIM4, DIM1, DIM2, [&](int n, int dim4, int dim1, int dim2) { - for (int dim3 = 0; dim3 < DIM3; dim3++) { - int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 + - dim1 * DIM2 * DIM3 * DIM4 + - dim2 * DIM3 * DIM4 + - dim3 * DIM4 + - dim4; - int dst_off = n * DIM4 * DIM1 * DIM2 * DIM3 + - dim4 * DIM1 * DIM2 * DIM3 + - dim1 * DIM2 * DIM3 + - dim2 * DIM3 + - dim3; - - dst_data[dst_off] = src_data[src_off]; - } - }); +template +void MKLDNNTransposeNode::optimizedExecute(const int MB, const MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { + switch (srcMemPtr->GetDims().size()) { + case 4: + transpose_to_0312(MB, srcMemPtr, dstMemPtr); + break; + case 5: + transpose_to_04123(MB, srcMemPtr, dstMemPtr); + break; + case 6: + transpose_to_051234(MB, srcMemPtr, dstMemPtr); + break; + default: + IE_THROW() << "Transpose '" << getName() << "' supports optimized execution with only 4D, 5D and 6D shapes"; + } } -const std::multimap MKLDNNTransposeNode::OptimizedCases = { - {{0, 2, 3, 1}, MKLDNNTransposeNode::TransposeImpl(transpose_to_0231, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return true; - })}, // NCHW -> NHWC case - {{0, 1, 4, 2, 5, 3}, MKLDNNTransposeNode::TransposeImpl(transpose_to_014253<2, 2>, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat() && srcMemPtr->GetDims()[2] == 2 && srcMemPtr->GetDims()[3] == 2; - })}, // Dense upsample convolution case (scale = 2) - {{0, 1, 4, 2, 5, 3}, MKLDNNTransposeNode::TransposeImpl(transpose_to_014253<0, 0>, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, // Dense upsample convolution case (generic) - {{3, 0, 1, 2}, MKLDNNTransposeNode::TransposeImpl(transpose_to_3012, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat() && MB == srcMemPtr->GetDims()[0]; - })}, // LPR case - {{0, 2, 1, 3}, MKLDNNTransposeNode::TransposeImpl(transpose_to_0213, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, // shufflenet - {{0, 2, 1}, MKLDNNTransposeNode::TransposeImpl(transpose_to_021, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, // self attention block - {{0, 3, 4, 1, 5, 2}, MKLDNNTransposeNode::TransposeImpl(transpose_to_034152, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, // learning-to-see-in-the-dark-sony - {{0, 1, 3, 2}, MKLDNNTransposeNode::TransposeImpl(transpose_to_0132, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return true; - })}, - {{0, 3, 1, 4, 2}, MKLDNNTransposeNode::TransposeImpl(transpose_to_03142, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, - {{1, 2, 0, 3}, MKLDNNTransposeNode::TransposeImpl(transpose_to_1203, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat() && MB == srcMemPtr->GetDims()[0]; - })}, - {{0, 2, 1, 3, 4}, MKLDNNTransposeNode::TransposeImpl(transpose_to_02134, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, - {{0, 2, 4, 3, 1}, MKLDNNTransposeNode::TransposeImpl(transpose_to_02431, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, - {{0, 4, 2, 3, 1}, MKLDNNTransposeNode::TransposeImpl(transpose_to_04231, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, - {{0, 3, 1, 2}, MKLDNNTransposeNode::TransposeImpl(transpose_to_0312, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, - {{1, 0, 2}, MKLDNNTransposeNode::TransposeImpl(transpose_to_102, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat() && MB == srcMemPtr->GetDims()[0]; - })}, - {{0, 2, 3, 4, 1}, MKLDNNTransposeNode::TransposeImpl(transpose_to_02341, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, - {{0, 4, 1, 2, 3}, MKLDNNTransposeNode::TransposeImpl(transpose_to_04123, [](int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) { - return srcMemPtr->GetDesc().isPlainFormat(); - })}, -}; - void MKLDNNTransposeNode::execute(mkldnn::stream strm) { auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr(); int MB = batchToProcess(); - if (prec == Precision::FP32 && !getParentEdgeAt(0)->getMemory().GetDesc().isTailCFormat()) { - for (const auto &impl : OptimizedCases) { - if (impl.first == order && impl.second.isValidParams(MB, srcMemPtr, dstMemPtr)) { - impl.second.execute(MB, srcMemPtr, dstMemPtr); - return; - } - } + if (isOptimized) { + const auto precision = getParentEdgeAt(0)->getDesc().getPrecision(); + TransposeContext ctx = {this, srcMemPtr, dstMemPtr, MB}; + OV_SWITCH(MKLDNNPlugin, TransposeOptimizedEmitter, ctx, precision, + OV_CASE(InferenceEngine::Precision::FP32, float), + OV_CASE(InferenceEngine::Precision::I32, int32_t), + OV_CASE(InferenceEngine::Precision::BF16, bfloat16_t), + OV_CASE(InferenceEngine::Precision::I8, int8_t), + OV_CASE(InferenceEngine::Precision::U8, uint8_t)); + + return; } const uint8_t* srcData = reinterpret_cast(srcMemPtr->GetPtr()); diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.h index 7ba21e5ba829c7..6de2cefa5d8cab 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_transpose_node.h @@ -34,20 +34,33 @@ class MKLDNNTransposeNode : public MKLDNNNode { } private: + template void optimizedExecute(const int MB, const MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr); + InferenceEngine::SizeVector order; InferenceEngine::Precision prec; + bool isOptimized = false; - typedef std::function transposeImpl; - typedef std::function isApplicable; - struct TransposeImpl { - TransposeImpl(transposeImpl f0, isApplicable f1): execute(std::move(f0)), isValidParams(std::move(f1)) {} - - transposeImpl execute; - isApplicable isValidParams; + const std::vector> optimizedOrders = { + std::vector{0, 3, 1, 2}, + std::vector{0, 4, 1, 2, 3}, + std::vector{0, 5, 1, 2, 3, 4}, }; - static const std::multimap OptimizedCases; std::unique_ptr permuteKernel; + + struct TransposeContext { + MKLDNNTransposeNode* nodePtr; + MKLDNNMemoryPtr srcMemPtr; + MKLDNNMemoryPtr dstMemPtr; + int MB; + }; + + template + struct TransposeOptimizedEmitter { + void operator()(TransposeContext& ctx) { + ctx.nodePtr->optimizedExecute(ctx.MB, ctx.srcMemPtr, ctx.dstMemPtr); + } + }; }; } // namespace MKLDNNPlugin diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/transpose.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/transpose.cpp index ee508a38ca7005..d7f10c16572d84 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/transpose.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/transpose.cpp @@ -56,7 +56,7 @@ INSTANTIATE_TEST_CASE_P(smoke_Transpose4D, TransposeLayerTest, std::vector> inputShape5D = {{2, 2, 2, 2, 2}, {1, 10, 2, 3, 4}, {2, 3, 4, 5, 6}}; std::vector> order5D = { {}, {0, 1, 2, 3, 4}, {1, 0, 2, 3, 4}, {4, 3, 2, 1, 0}, {0, 2, 3, 4, 1}, - {1, 4, 2, 3, 0}, {2, 4, 1, 0, 3}, {3, 0, 2, 1, 4}, {4, 1, 0, 3, 2} + {1, 4, 2, 3, 0}, {2, 4, 1, 0, 3}, {3, 0, 2, 1, 4}, {4, 1, 0, 3, 2}, {0, 4, 1, 2, 3}, }; INSTANTIATE_TEST_CASE_P(smoke_Transpose5D, TransposeLayerTest, @@ -70,4 +70,23 @@ INSTANTIATE_TEST_CASE_P(smoke_Transpose5D, TransposeLayerTest, ::testing::ValuesIn(inputShape5D), ::testing::Values(CommonTestUtils::DEVICE_CPU)), TransposeLayerTest::getTestCaseName); + +std::vector> inputShape6D = {{2, 2, 2, 2, 2, 2}, {1, 10, 2, 3, 4, 5}, {2, 3, 4, 5, 6, 7}}; +std::vector> order6D = { + {}, {0, 1, 2, 3, 4, 5}, {1, 0, 2, 3, 4, 5}, {5, 4, 3, 2, 1, 0}, {0, 2, 3, 4, 5, 1}, + {1, 5, 4, 2, 3, 0}, {2, 5, 4, 1, 0, 3}, {3, 0, 2, 1, 4, 5}, {5, 1, 0, 4, 3, 2}, {0, 5, 1, 2, 3, 4}, +}; + +INSTANTIATE_TEST_CASE_P(smoke_Transpose6D, TransposeLayerTest, + ::testing::Combine( + ::testing::ValuesIn(order6D), + ::testing::ValuesIn(netPrecisions), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Precision::UNSPECIFIED), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::Values(InferenceEngine::Layout::ANY), + ::testing::ValuesIn(inputShape6D), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + TransposeLayerTest::getTestCaseName); + } // namespace