From e228757b35a940a83c44c8701f44731f1801b7d8 Mon Sep 17 00:00:00 2001 From: Egor Duplensky Date: Fri, 2 Apr 2021 13:32:27 +0300 Subject: [PATCH] Add bf16 support for ROI pooling --- .../nodes/mkldnn_roi_pooling_node.cpp | 197 +++++++++++++----- .../nodes/mkldnn_roi_pooling_node.h | 16 +- .../cpu/single_layer_tests/roi_pooling.cpp | 162 ++++++++++++++ 3 files changed, 322 insertions(+), 53 deletions(-) create mode 100644 inference-engine/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp index c0033280eeb7eb..38c19b783ea784 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp @@ -4,14 +4,20 @@ #include "mkldnn_roi_pooling_node.h" -#include #include +#include +#include + +#include +#include "ie_parallel.hpp" +#include "utils/bfloat16.hpp" +#include "emitters/jit_load_store_emitters.hpp" + +#include + #include #include #include -#include -#include -#include "ie_parallel.hpp" using namespace MKLDNNPlugin; using namespace InferenceEngine; @@ -25,7 +31,7 @@ using namespace Xbyak; template struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pooling_kernel_f32) + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pooling_kernel_f32); explicit jit_uni_roi_pooling_kernel_f32(jit_roi_pooling_params jcp) : jit_uni_roi_pooling_kernel(jcp), jit_generator() {} @@ -35,6 +41,9 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi }; void generate() override { + load_emitter.reset(new jit_load_emitter(this, isa, nullptr)); + store_emitter.reset(new jit_store_emitter(this, isa, nullptr)); + this->preamble(); Label exit_label; @@ -42,7 +51,6 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi mov(reg_input, ptr[this->param1 + GET_OFF(src)]); mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); - mov(reg_bin_area, ptr[this->param1 + GET_OFF(bin_area)]); mov(reg_c_blocks, ptr[this->param1 + GET_OFF(c_blocks)]); @@ -71,6 +79,9 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi L(exit_label); this->postamble(); + + load_emitter->emit_data(); + store_emitter->emit_data(); } private: @@ -78,6 +89,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi Xbyak::Ymm, Xbyak::Zmm>::type; const int vlen = cpu_isa_traits::vlen; + const int step = vlen / sizeof(float); Vmm vmm_mask = Vmm(0); Vmm vmm_zero = Vmm(0); @@ -87,6 +99,12 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi Xmm xmm_xf = Xmm(1); Vmm vmm_xf = Vmm(1); + std::unique_ptr load_emitter = nullptr; + std::unique_ptr store_emitter = nullptr; + + std::vector store_pool_gpr_idxs; + std::vector store_pool_vec_idxs; + Vmm get_acc_reg(int idx) { return Vmm(2*idx + 1); } Vmm get_src_reg(int idx) { return Vmm(2*idx + 2); } @@ -114,15 +132,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi reg64_t reg_yoff = h_iter; reg64_t reg_xoff = r12; + Xbyak::Reg64 reg_load_table = r13; + Xbyak::Reg64 reg_load_store_mask = rcx; + void roi_pool_max(int c_blocks) { Label h_loop_label; Label w_loop_label; mov(aux_reg_input, reg_input); + int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size; for (int i = 0; i < c_blocks; i++) { Vmm vmm_max = get_acc_reg(i); - uni_vmovups(vmm_max, ptr[reg_input + i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float)]); + + load_emitter->emit_code({static_cast(reg_input.getIdx())}, {static_cast(vmm_max.getIdx())}, + std::make_shared(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off), + {}, {}); } xor_(h_iter, h_iter); @@ -134,7 +159,10 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi Vmm vmm_max = get_acc_reg(i); Vmm vmm_src = get_src_reg(i); - uni_vmovups(vmm_src, ptr[aux_reg_input1 + i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float)]); + load_emitter->emit_code({static_cast(aux_reg_input1.getIdx())}, {static_cast(vmm_src.getIdx())}, + std::make_shared(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off), + {}, {}); + if (isa == cpu::x64::sse41) { movups(vmm_mask, vmm_max); cmpps(vmm_mask, vmm_src, _cmp_lt_os); @@ -148,23 +176,27 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi } } - add(aux_reg_input1, jpp_.c_block * sizeof(float)); + add(aux_reg_input1, jpp_.c_block * jpp_.src_data_size); inc(w_iter); cmp(w_iter, reg_kw); jl(w_loop_label, T_NEAR); } - add(aux_reg_input, jpp_.iw * jpp_.c_block * sizeof(float)); + add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_data_size); inc(h_iter); cmp(h_iter, reg_kh); jl(h_loop_label, T_NEAR); } + int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size; for (int i = 0; i < c_blocks; i++) { Vmm vmm_dst = get_acc_reg(i); - uni_vmovups(ptr[reg_output + i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float)], vmm_dst); + + store_emitter->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(reg_output.getIdx())}, + std::make_shared(Precision::FP32, jpp_.dst_prc, step, i * dst_c_off), + {}, {}); } } @@ -180,17 +212,29 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi Vmm vmm_src11 = get_src_reg(3); for (int i = 0; i < c_blocks; i++) { - int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float); + int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size; + auto load_context = std::make_shared(jpp_.src_prc, Precision::FP32, step, false, "zero", src_c_off); mov(aux_reg_input, reg_input); - uni_vmovups(vmm_src00, ptr[aux_reg_input + src_c_off]); + + load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src00.getIdx())}, + load_context, + {}, {}); add(aux_reg_input, reg_xoff); - uni_vmovups(vmm_src01, ptr[aux_reg_input + src_c_off]); + + load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src01.getIdx())}, + load_context, + {}, {}); add(aux_reg_input, reg_yoff); - uni_vmovups(vmm_src11, ptr[aux_reg_input + src_c_off]); + load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src11.getIdx())}, + load_context, + {}, {}); sub(aux_reg_input, reg_xoff); - uni_vmovups(vmm_src10, ptr[aux_reg_input + src_c_off]); + + load_emitter->emit_code({static_cast(aux_reg_input.getIdx())}, {static_cast(vmm_src10.getIdx())}, + load_context, + {}, {}); uni_vsubps(vmm_src01, vmm_src01, vmm_src00); uni_vfmadd213ps(vmm_src01, vmm_xf, vmm_src00); @@ -201,15 +245,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi uni_vsubps(vmm_src11, vmm_src11, vmm_src01); uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01); - int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float); - uni_vmovups(ptr[reg_output + dst_c_off], vmm_src11); + int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size; + + store_emitter->emit_code({static_cast(vmm_src11.getIdx())}, {static_cast(reg_output.getIdx())}, + std::make_shared(Precision::FP32, jpp_.dst_prc, step, dst_c_off), + {}, {}); } } void empty_roi(int c_blocks) { uni_vpxor(vmm_zero, vmm_zero, vmm_zero); + + int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size; for (int i = 0; i < c_blocks; i++) { - uni_vmovups(ptr[reg_output + i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float)], vmm_zero); + store_emitter->emit_code({static_cast(vmm_zero.getIdx())}, {static_cast(reg_output.getIdx())}, + std::make_shared(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off), + {store_pool_vec_idxs}, {store_pool_gpr_idxs}); } } @@ -226,8 +277,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi roi_pool_bilinear(c_blocks); if (isa == cpu::x64::sse41) { - add(reg_input, 4 * sizeof(float)); - add(reg_output, 4 * sizeof(float)); + add(reg_input, 4 * jpp_.src_data_size); + add(reg_output, 4 * jpp_.dst_data_size); if (jpp_.alg == ROIPoolingOpType::Max) roi_pool_max(c_blocks); @@ -239,7 +290,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi L(empty_roi_label); empty_roi(c_blocks); if (isa == cpu::x64::sse41) { - add(reg_output, 4 * sizeof(float)); + add(reg_output, 4 * jpp_.dst_data_size); empty_roi(c_blocks); } @@ -300,6 +351,18 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; + Precision runtimePrecision = getCnnLayer()->insData[0].lock()->getPrecision(); + + if (!mayiuse(avx512_core)) { + if (runtimePrecision == Precision::BF16) + runtimePrecision = Precision::FP32; + } + + auto dataType = MKLDNNExtensionUtils::IEPrecisionToDataType(runtimePrecision); + + src_data_size = MKLDNNExtensionUtils::sizeOfDataType(dataType); + dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(dataType); + InferenceEngine::LayerConfig config; config.dynBatchSupport = false; config.inConfs.resize(2); @@ -325,9 +388,9 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() { impl_type = impl_desc_type::ref; } - config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), memory::data_type::f32, format); - config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), memory::data_type::f32, memory::format_tag::nc); - config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), memory::data_type::f32, format); + config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), dataType, format); + config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), dataType, memory::format_tag::nc); + config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), dataType, format); supportedPrimitiveDescriptors.push_back({config, impl_type, format}); } @@ -358,6 +421,12 @@ void MKLDNNROIPoolingNode::createPrimitive() { jpp.nb_c_blocking = mayiuse(cpu::x64::avx512_common) ? 15 : 7; + auto selectedPD = getSelectedPrimitiveDescriptor(); + jpp.src_prc = selectedPD->getConfig().inConfs[0].desc.getPrecision(); + jpp.dst_prc = selectedPD->getConfig().outConfs[0].desc.getPrecision(); + jpp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jpp.src_prc)); + jpp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jpp.dst_prc)); + jpp.alg = opType; if (mayiuse(cpu::x64::avx512_common)) { @@ -372,14 +441,15 @@ void MKLDNNROIPoolingNode::createPrimitive() { roi_pooling_kernel->create_ker(); } -void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { +template +void MKLDNNROIPoolingNode::execute() { auto &srcMemory0 = getParentEdgeAt(0)->getMemory(); auto &srcMemory1 = getParentEdgeAt(1)->getMemory(); - auto &dstMemory = getChildEdgeAt(0)->getMemory(); + auto &dstMemory = getChildEdgeAt(0)->getMemory(); - const auto *src_data = reinterpret_cast(srcMemory0.GetPtr()); - const auto *src_roi = reinterpret_cast(srcMemory1.GetPtr()); - float *dst = reinterpret_cast(dstMemory.GetPtr()); + const auto *src_data = reinterpret_cast(srcMemory0.GetPtr()); + const auto *src_roi = reinterpret_cast(srcMemory1.GetPtr()); + auto *dst = reinterpret_cast(dstMemory.GetPtr()); auto selectedPrimitiveDescriptor = getSelectedPrimitiveDescriptor(); if (!selectedPrimitiveDescriptor) @@ -388,16 +458,16 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { auto src_strides = config.inConfs[0].desc.getBlockingDesc().getStrides(); auto dst_strides = config.outConfs[0].desc.getBlockingDesc().getStrides(); + size_t src_roi_step = config.inConfs[1].desc.getBlockingDesc().getStrides()[0]; int cb_work = impl::utils::div_up(jpp.nb_c, jpp.nb_c_blocking); int MB = jpp.mb; - size_t src_roi_step = config.inConfs[1].desc.getBlockingDesc().getStrides()[0]; int real_rois = 0; for (; real_rois < MB; real_rois++) { size_t roi_off = real_rois * src_roi_step; - const float *src_roi_ptr = &src_roi[roi_off]; + const T* src_roi_ptr = &src_roi[roi_off]; int roi_batch_ind = static_cast(src_roi_ptr[0]); if (roi_batch_ind == -1) { break; @@ -426,7 +496,7 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { (*roi_pooling_kernel)(&arg); } else { size_t roi_off = n * src_roi_step; - const float* src_roi_ptr = &src_roi[roi_off]; + const T* src_roi_ptr = &src_roi[roi_off]; int roi_batch_ind = static_cast(src_roi_ptr[0]); @@ -480,7 +550,7 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { } else { for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - float batch_data = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + + T batch_data = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + h * src_strides[2] + w * src_strides[3] + c]; if (batch_data > dst[pool_index]) { @@ -492,17 +562,17 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { } } } else { - float roi_start_w_ = src_roi_ptr[1]; - float roi_start_h_ = src_roi_ptr[2]; - float roi_end_w_ = src_roi_ptr[3]; - float roi_end_h_ = src_roi_ptr[4]; + T roi_start_w_ = src_roi_ptr[1]; + T roi_start_h_ = src_roi_ptr[2]; + T roi_end_w_ = src_roi_ptr[3]; + T roi_end_h_ = src_roi_ptr[4]; - float height_scale = (jpp.pooled_h > 1 ? ((roi_end_h_ - roi_start_h_) * (jpp.ih - 1)) / (jpp.pooled_h - 1) : 0); - float width_scale = (jpp.pooled_w > 1 ? ((roi_end_w_ - roi_start_w_) * (jpp.iw - 1)) / (jpp.pooled_w - 1) : 0); + T height_scale = (jpp.pooled_h > 1 ? ((roi_end_h_ - roi_start_h_) * (jpp.ih - 1)) / (jpp.pooled_h - 1) : 0); + T width_scale = (jpp.pooled_w > 1 ? ((roi_end_w_ - roi_start_w_) * (jpp.iw - 1)) / (jpp.pooled_w - 1) : 0); - float in_y = (jpp.pooled_h > 1 ? (oh * height_scale + roi_start_h_ * (jpp.ih - 1)) : + T in_y = (jpp.pooled_h > 1 ? (oh * height_scale + roi_start_h_ * (jpp.ih - 1)) : 0.5 * (roi_start_h_ + roi_end_h_) * (jpp.ih - 1)); - float in_x = (jpp.pooled_w > 1 ? (ow * width_scale + roi_start_w_ * (jpp.iw - 1)) : + T in_x = (jpp.pooled_w > 1 ? (ow * width_scale + roi_start_w_ * (jpp.iw - 1)) : 0.5 * (roi_start_w_ + roi_end_w_) * (jpp.iw - 1)); if (in_y < 0 || in_y > jpp.ih - 1 || in_x < 0 || in_x > jpp.iw - 1) { @@ -532,28 +602,29 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { arg.xf = in_x - left_x_index; arg.yf = in_y - top_y_index; - arg.xoff = sizeof(float) * (right_x_index - left_x_index) * jpp.c_block; - arg.yoff = sizeof(float) * (bottom_y_index - top_y_index) * jpp.iw * jpp.c_block; + arg.xoff = sizeof(T) * (right_x_index - left_x_index) * jpp.c_block; + arg.yoff = sizeof(T) * (bottom_y_index - top_y_index) * jpp.iw * jpp.c_block; arg.src = &src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + top_y_index * src_strides[2] + left_x_index * src_strides[3]]; + arg.bin_area = 1; } else { for (int c = 0; c < 1; c++) { - const float top_left = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + + const T top_left = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + top_y_index * src_strides[2] + left_x_index * src_strides[3] + c]; - const float top_right = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + + const T top_right = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + top_y_index * src_strides[2] + right_x_index * src_strides[3] + c]; - const float bottom_left = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + + const T bottom_left = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + bottom_y_index * src_strides[2] + left_x_index * src_strides[3] + c]; - const float bottom_right = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + + const T bottom_right = src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] + bottom_y_index * src_strides[2] + right_x_index * src_strides[3] + c]; - const float top = top_left + (top_right - top_left) * (in_x - left_x_index); - const float bottom = bottom_left + (bottom_right - bottom_left) * (in_x - left_x_index); + const T top = top_left + (top_right - top_left) * (in_x - left_x_index); + const T bottom = bottom_left + (bottom_right - bottom_left) * (in_x - left_x_index); dst[n * dst_strides[0] + cb * dst_strides[1] + oh * dst_strides[2] + ow * dst_strides[3] + c] = - top + (bottom - top) * (in_y - top_y_index); + top + (bottom - top) * (in_y - top_y_index); } } } @@ -566,6 +637,30 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { }); } +namespace { +struct ROIPoolingContext { + MKLDNNROIPoolingNode &node; +}; +} + +template +struct MKLDNNROIPoolingNode::ROIPoolingExecute { + // using dataT = typename T::type; + + void operator()(ROIPoolingContext & ctx) { + ctx.node.execute(); + } +}; + +void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) { + ROIPoolingContext ctx = { + *this + }; + // enable conditional compilation + OV_SWITCH(MKLDNNPlugin, ROIPoolingExecute, ctx, runtimePrecision, + OV_CASE(Precision::FP32, float), + OV_CASE(Precision::BF16, bfloat16_t)) +} bool MKLDNNROIPoolingNode::created() const { return getType() == ROIPooling; diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.h index f3b19aa2328aa2..5be21ce6b851d1 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.h @@ -27,12 +27,17 @@ struct jit_roi_pooling_params { int pooled_h; int pooled_w; + InferenceEngine::Precision src_prc; + InferenceEngine::Precision dst_prc; + int src_data_size; + int dst_data_size; + ROIPoolingOpType alg; }; struct jit_roi_pooling_call_args { - const float *src; - float *dst; + const void *src; + void *dst; size_t kh; size_t kw; @@ -73,8 +78,15 @@ class MKLDNNROIPoolingNode : public MKLDNNNode { void createPrimitive() override; void execute(mkldnn::stream strm) override; bool created() const override; + template void execute(); + template struct ROIPoolingExecute; private: + size_t src_data_size; + size_t dst_data_size; + + InferenceEngine::Precision runtimePrecision; + int pooled_h = 0; int pooled_w = 0; float spatial_scale = 0; diff --git a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp new file mode 100644 index 00000000000000..366697ecad6257 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/roi_pooling.cpp @@ -0,0 +1,162 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include +#include "test_utils/cpu_test_utils.hpp" + +using namespace InferenceEngine; +using namespace CPUTestUtils; + +namespace CPULayerTestsDefinitions { + +typedef std::tuple ROIPoolingCPUTestParamsSet; + +class ROIPoolingCPULayerTest : public testing::WithParamInterface, + virtual public LayerTestsUtils::LayerTestsCommon, + public CPUTestsBase { +public: + static std::string getTestCaseName(testing::TestParamInfo obj) { + LayerTestsDefinitions::roiPoolingParamsTuple basicParamsSet; + CPUSpecificParams cpuParams; + std::vector inputShape; + std::vector coordsShape; + std::vector poolShape; + float spatial_scale; + ngraph::helpers::ROIPoolingTypes pool_method; + InferenceEngine::Precision netPrecision; + std::string targetDevice; + + std::tie(basicParamsSet, cpuParams) = obj.param; + std::tie(inputShape, coordsShape, poolShape, spatial_scale, pool_method, netPrecision, targetDevice) = basicParamsSet; + + std::ostringstream result; + + result << LayerTestsDefinitions::ROIPoolingLayerTest::getTestCaseName( + testing::TestParamInfo(basicParamsSet, 0)); + result << CPUTestsBase::getTestCaseName(cpuParams); + + return result.str(); + } + +protected: + void GenerateInputs() { + auto feat_map_shape = cnnNetwork.getInputShapes().begin()->second; + + const auto is_roi_max_mode = (pool_method == ngraph::helpers::ROIPoolingTypes::ROI_MAX); + + const int height = is_roi_max_mode ? feat_map_shape[2] / spatial_scale : 1; + const int width = is_roi_max_mode ? feat_map_shape[3] / spatial_scale : 1; + + size_t it = 0; + for (const auto &input : cnnNetwork.getInputsInfo()) { + const auto &info = input.second; + InferenceEngine::Blob::Ptr blob; + + if (it == 1) { + blob = make_blob_with_precision(info->getTensorDesc()); + blob->allocate(); + CommonTestUtils::fill_data_roi(blob->buffer(), blob->size(), feat_map_shape[0] - 1, height, width, 1.0f, is_roi_max_mode); + } else { + blob = GenerateInput(*info); + } + inputs.push_back(blob); + it++; + } + } + + void SetUp() { + LayerTestsDefinitions::roiPoolingParamsTuple basicParamsSet; + CPUSpecificParams cpuParams; + InferenceEngine::SizeVector inputShape; + InferenceEngine::SizeVector coordsShape; + InferenceEngine::SizeVector poolShape; + InferenceEngine::Precision netPrecision; + + // threshold = 0.08f; + + std::tie(basicParamsSet, cpuParams) = this->GetParam(); + std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; + std::tie(inputShape, coordsShape, poolShape, spatial_scale, pool_method, netPrecision, targetDevice) = basicParamsSet; + + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShape, coordsShape}); + auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)); + + std::shared_ptr roi_pooling = ngraph::builder::makeROIPooling(paramOuts[0], paramOuts[1], poolShape, spatial_scale, pool_method); + ngraph::ResultVector results{std::make_shared(roi_pooling)}; + + function = makeNgraphFunction(ngPrc, params, roi_pooling, "roi_pooling"); + + selectedType = getPrimitiveType() + "_FP32"; + } + +private: + ngraph::helpers::ROIPoolingTypes pool_method; + float spatial_scale; +}; + +TEST_P(ROIPoolingCPULayerTest, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + Run(); + CheckPluginRelatedResults(executableNetwork, "ROIPooling"); +} + +namespace { + +std::vector filterCPUInfoForDevice() { + std::vector resCPUParams; + if (with_cpu_x86_avx512f()) { + resCPUParams.push_back(CPUSpecificParams{{nChw16c, nc}, {nChw16c}, {"jit_avx512"}, "jit_avx512"}); + } else if (with_cpu_x86_avx2()) { + resCPUParams.push_back(CPUSpecificParams{{nChw8c, nc}, {nChw8c}, {"jit_avx2"}, "jit_avx2"}); + } else if (with_cpu_x86_sse42()) { + resCPUParams.push_back(CPUSpecificParams{{nChw8c, nc}, {nChw8c}, {"jit_sse42"}, "jit_sse42"}); + } else { + resCPUParams.push_back(CPUSpecificParams{{nChw8c, nc}, {nChw8c}, {"ref"}, "ref"}); + } + return resCPUParams; +} + +const std::vector> inShapes = {{1, 3, 8, 8}, {3, 4, 50, 50}}; + +const std::vector> pooledShapes_max = {{1, 1}, {2, 2}, {3, 3}, {6, 6}}; + +const std::vector> pooledShapes_bilinear = {{1, 1}, {2, 2}, {3, 3}, {6, 6}}; + +const std::vector> coordShapes = {{1, 5}}; + // {3, 5}, + // {5, 5}}; + +const std::vector netPRCs = {InferenceEngine::Precision::BF16, InferenceEngine::Precision::FP32}; + +const std::vector spatial_scales = {0.625f, 1.f}; + +const auto test_ROIPooling_max = ::testing::Combine(::testing::ValuesIn(inShapes), + ::testing::ValuesIn(coordShapes), + ::testing::ValuesIn(pooledShapes_max), + ::testing::ValuesIn(spatial_scales), + ::testing::Values(ngraph::helpers::ROIPoolingTypes::ROI_MAX), + ::testing::ValuesIn(netPRCs), + ::testing::Values(CommonTestUtils::DEVICE_CPU)); + +const auto test_ROIPooling_bilinear = ::testing::Combine(::testing::ValuesIn(inShapes), + ::testing::ValuesIn(coordShapes), + ::testing::ValuesIn(pooledShapes_bilinear), + ::testing::Values(spatial_scales[1]), + ::testing::Values(ngraph::helpers::ROIPoolingTypes::ROI_BILINEAR), + ::testing::ValuesIn(netPRCs), + ::testing::Values(CommonTestUtils::DEVICE_CPU)); + +INSTANTIATE_TEST_CASE_P(smoke_ROIPoolingCPU_max, + ROIPoolingCPULayerTest, + ::testing::Combine(test_ROIPooling_max, + ::testing::ValuesIn(filterCPUInfoForDevice())), + ROIPoolingCPULayerTest::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_ROIPoolingCPU_bilinear, + ROIPoolingCPULayerTest, + ::testing::Combine(test_ROIPooling_bilinear, + ::testing::ValuesIn(filterCPUInfoForDevice())), + ROIPoolingCPULayerTest::getTestCaseName); +} // namespace +} // namespace CPULayerTestsDefinitions