Skip to content

Commit

Permalink
[CPU] Add bf16 support for ROI pooling (#5187)
Browse files: browse the repository at this point in the history
  • Loading branch information
EgorDuplensky authored May 25, 2021
1 parent c3ca8d0 commit bae7f4b
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 44 deletions.
178 changes: 140 additions & 38 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
#include "mkldnn_roi_pooling_node.h"

#include <mkldnn.hpp>
#include <string>
#include <vector>
#include <math.h>
#include <mkldnn_extension_utils.h>
#include <cpu/x64/jit_generator.hpp>
#include "ie_parallel.hpp"
#include <mkldnn_selective_build.h>

#include <ngraph/opsets/opset2.hpp>

#include "ie_parallel.hpp"
#include "utils/bfloat16.hpp"
#include "emitters/jit_load_store_emitters.hpp"

#include <cpu/x64/jit_generator.hpp>

#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <cmath>

using namespace MKLDNNPlugin;
using namespace InferenceEngine;
using namespace mkldnn;
Expand All @@ -25,7 +34,7 @@ using namespace Xbyak;

template <cpu_isa_t isa>
struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pooling_kernel_f32)
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pooling_kernel_f32);

explicit jit_uni_roi_pooling_kernel_f32(jit_roi_pooling_params jcp) : jit_uni_roi_pooling_kernel(jcp), jit_generator() {}

Expand All @@ -35,14 +44,16 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
};

void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa, nullptr));
store_emitter.reset(new jit_store_emitter(this, isa, nullptr));

this->preamble();

Label exit_label;
Label tail_label;

mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);

mov(reg_bin_area, ptr[this->param1 + GET_OFF(bin_area)]);
mov(reg_c_blocks, ptr[this->param1 + GET_OFF(c_blocks)]);

Expand All @@ -56,6 +67,10 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
mov(reg_xoff, ptr[this->param1 + GET_OFF(xoff)]);
}

load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
store_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx())};
store_pool_vec_idxs = {static_cast<size_t>(vmm_zero.getIdx())};

int nb_c_tail = jpp_.nb_c % jpp_.nb_c_blocking;
cmp(reg_c_blocks, jpp_.nb_c_blocking);
jne(nb_c_tail ? tail_label : exit_label, T_NEAR);
Expand All @@ -71,13 +86,18 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
L(exit_label);

this->postamble();

load_emitter->emit_data();
if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
store_emitter->get_emu_vcvtneps2bf16()->emit_data();
}

private:
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm, isa == cpu::x64::avx2,
Xbyak::Ymm, Xbyak::Zmm>::type;

const int vlen = cpu_isa_traits<isa>::vlen;
const int step = vlen / sizeof(float);

Vmm vmm_mask = Vmm(0);
Vmm vmm_zero = Vmm(0);
Expand All @@ -87,6 +107,13 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
Xmm xmm_xf = Xmm(1);
Vmm vmm_xf = Vmm(1);

std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::vector<size_t> load_pool_gpr_idxs;

std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::vector<size_t> store_pool_gpr_idxs;
std::vector<size_t> store_pool_vec_idxs;

Vmm get_acc_reg(int idx) { return Vmm(2*idx + 1); }
Vmm get_src_reg(int idx) { return Vmm(2*idx + 2); }

Expand All @@ -102,8 +129,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
reg64_t reg_kh = r10;
reg64_t reg_kw = r11;

reg64_t h_iter = r14;
reg64_t w_iter = r15;
reg64_t h_iter = r13;
reg64_t w_iter = r14;

reg64_t reg_c_blocks = rbx;
reg64_t reg_bin_area = rdx;
Expand All @@ -114,15 +141,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
reg64_t reg_yoff = h_iter;
reg64_t reg_xoff = r12;

Xbyak::Reg64 reg_load_table = r15;
Xbyak::Reg64 reg_load_store_mask = abi_param1;

void roi_pool_max(int c_blocks) {
Label h_loop_label;
Label w_loop_label;

mov(aux_reg_input, reg_input);

const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_max = get_acc_reg(i);
uni_vmovups(vmm_max, ptr[reg_input + i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float)]);

load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx())}, {static_cast<size_t>(vmm_max.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
{}, load_pool_gpr_idxs);
}

xor_(h_iter, h_iter);
Expand All @@ -134,7 +168,10 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
Vmm vmm_max = get_acc_reg(i);
Vmm vmm_src = get_src_reg(i);

uni_vmovups(vmm_src, ptr[aux_reg_input1 + i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float)]);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
{}, load_pool_gpr_idxs);

if (isa == cpu::x64::sse41) {
movups(vmm_mask, vmm_max);
cmpps(vmm_mask, vmm_src, _cmp_lt_os);
Expand All @@ -148,23 +185,27 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
}
}

add(aux_reg_input1, jpp_.c_block * sizeof(float));
add(aux_reg_input1, jpp_.c_block * jpp_.src_data_size);

inc(w_iter);
cmp(w_iter, reg_kw);
jl(w_loop_label, T_NEAR);
}

add(aux_reg_input, jpp_.iw * jpp_.c_block * sizeof(float));
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_data_size);

inc(h_iter);
cmp(h_iter, reg_kh);
jl(h_loop_label, T_NEAR);
}

const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_dst = get_acc_reg(i);
uni_vmovups(ptr[reg_output + i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float)], vmm_dst);

store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, i * dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
}
}

Expand All @@ -180,17 +221,29 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
Vmm vmm_src11 = get_src_reg(3);

for (int i = 0; i < c_blocks; i++) {
int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float);
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", src_c_off);

mov(aux_reg_input, reg_input);
uni_vmovups(vmm_src00, ptr[aux_reg_input + src_c_off]);

load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src00.getIdx())},
load_context,
{}, load_pool_gpr_idxs);
add(aux_reg_input, reg_xoff);
uni_vmovups(vmm_src01, ptr[aux_reg_input + src_c_off]);

load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src01.getIdx())},
load_context,
{}, load_pool_gpr_idxs);

add(aux_reg_input, reg_yoff);
uni_vmovups(vmm_src11, ptr[aux_reg_input + src_c_off]);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src11.getIdx())},
load_context,
{}, load_pool_gpr_idxs);
sub(aux_reg_input, reg_xoff);
uni_vmovups(vmm_src10, ptr[aux_reg_input + src_c_off]);

load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src10.getIdx())},
load_context,
{}, load_pool_gpr_idxs);

uni_vsubps(vmm_src01, vmm_src01, vmm_src00);
uni_vfmadd213ps(vmm_src01, vmm_xf, vmm_src00);
Expand All @@ -201,15 +254,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
uni_vsubps(vmm_src11, vmm_src11, vmm_src01);
uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01);

int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float);
uni_vmovups(ptr[reg_output + dst_c_off], vmm_src11);
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;

store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
}
}

void empty_roi(int c_blocks) {
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);

const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
for (int i = 0; i < c_blocks; i++) {
uni_vmovups(ptr[reg_output + i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float)], vmm_zero);
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
}
}

Expand All @@ -226,8 +286,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
roi_pool_bilinear(c_blocks);

if (isa == cpu::x64::sse41) {
add(reg_input, 4 * sizeof(float));
add(reg_output, 4 * sizeof(float));
add(reg_input, 4 * jpp_.src_data_size);
add(reg_output, 4 * jpp_.dst_data_size);

if (jpp_.alg == Algorithm::ROIPoolingMax)
roi_pool_max(c_blocks);
Expand All @@ -239,7 +299,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
L(empty_roi_label);
empty_roi(c_blocks);
if (isa == cpu::x64::sse41) {
add(reg_output, 4 * sizeof(float));
add(reg_output, 4 * jpp_.dst_data_size);
empty_roi(c_blocks);
}

Expand Down Expand Up @@ -317,6 +377,18 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

runtimePrecision = getOriginalInputPrecisionAtPort(0);

if (!mayiuse(avx512_core)) {
if (runtimePrecision == Precision::BF16)
runtimePrecision = Precision::FP32;
}

auto dataType = MKLDNNExtensionUtils::IEPrecisionToDataType(runtimePrecision);

src_data_size = MKLDNNExtensionUtils::sizeOfDataType(dataType);
dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(dataType);

InferenceEngine::LayerConfig config;
config.dynBatchSupport = false;
config.inConfs.resize(2);
Expand All @@ -342,9 +414,9 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() {
impl_type = impl_desc_type::ref;
}

config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), memory::data_type::f32, format);
config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), memory::data_type::f32, memory::format_tag::nc);
config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), memory::data_type::f32, format);
config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), dataType, format);
config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), dataType, memory::format_tag::nc);
config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), dataType, format);
supportedPrimitiveDescriptors.push_back({config, impl_type, format});
}

Expand Down Expand Up @@ -375,6 +447,12 @@ void MKLDNNROIPoolingNode::createPrimitive() {

jpp.nb_c_blocking = mayiuse(cpu::x64::avx512_common) ? 15 : 7;

auto selectedPD = getSelectedPrimitiveDescriptor();
jpp.src_prc = selectedPD->getConfig().inConfs[0].desc.getPrecision();
jpp.dst_prc = selectedPD->getConfig().outConfs[0].desc.getPrecision();
jpp.src_data_size = jpp.src_prc.size();
jpp.dst_data_size = jpp.dst_prc.size();

jpp.alg = getAlgorithm();

if (mayiuse(cpu::x64::avx512_common)) {
Expand All @@ -389,14 +467,15 @@ void MKLDNNROIPoolingNode::createPrimitive() {
roi_pooling_kernel->create_ker();
}

void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
template<typename T>
void MKLDNNROIPoolingNode::execute() {
auto &srcMemory0 = getParentEdgeAt(0)->getMemory();
auto &srcMemory1 = getParentEdgeAt(1)->getMemory();
auto &dstMemory = getChildEdgeAt(0)->getMemory();
auto &dstMemory = getChildEdgeAt(0)->getMemory();

const auto *src_data = reinterpret_cast<const float *>(srcMemory0.GetPtr());
const auto *src_roi = reinterpret_cast<const float *>(srcMemory1.GetPtr());
float *dst = reinterpret_cast<float *>(dstMemory.GetPtr());
const auto *src_data = reinterpret_cast<const T*>(srcMemory0.GetPtr());
const auto *src_roi = reinterpret_cast<const T*>(srcMemory1.GetPtr());
auto *dst = reinterpret_cast<T*>(dstMemory.GetPtr());

auto selectedPrimitiveDescriptor = getSelectedPrimitiveDescriptor();
if (!selectedPrimitiveDescriptor)
Expand All @@ -405,16 +484,16 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {

auto src_strides = config.inConfs[0].desc.getBlockingDesc().getStrides();
auto dst_strides = config.outConfs[0].desc.getBlockingDesc().getStrides();
size_t src_roi_step = config.inConfs[1].desc.getBlockingDesc().getStrides()[0];

int cb_work = impl::utils::div_up(jpp.nb_c, jpp.nb_c_blocking);
int MB = jpp.mb;

size_t src_roi_step = config.inConfs[1].desc.getBlockingDesc().getStrides()[0];
int real_rois = 0;
for (; real_rois < MB; real_rois++) {
size_t roi_off = real_rois * src_roi_step;

const float *src_roi_ptr = &src_roi[roi_off];
const auto *src_roi_ptr = &src_roi[roi_off];
int roi_batch_ind = static_cast<int>(src_roi_ptr[0]);
if (roi_batch_ind == -1) {
break;
Expand Down Expand Up @@ -443,7 +522,7 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
(*roi_pooling_kernel)(&arg);
} else {
size_t roi_off = n * src_roi_step;
const float* src_roi_ptr = &src_roi[roi_off];
const auto *src_roi_ptr = &src_roi[roi_off];

int roi_batch_ind = static_cast<int>(src_roi_ptr[0]);

Expand Down Expand Up @@ -549,11 +628,12 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
arg.xf = in_x - left_x_index;
arg.yf = in_y - top_y_index;

arg.xoff = sizeof(float) * (right_x_index - left_x_index) * jpp.c_block;
arg.yoff = sizeof(float) * (bottom_y_index - top_y_index) * jpp.iw * jpp.c_block;
arg.xoff = sizeof(T) * (right_x_index - left_x_index) * jpp.c_block;
arg.yoff = sizeof(T) * (bottom_y_index - top_y_index) * jpp.iw * jpp.c_block;

arg.src = &src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] +
top_y_index * src_strides[2] + left_x_index * src_strides[3]];

arg.bin_area = 1;
} else {
for (int c = 0; c < 1; c++) {
Expand Down Expand Up @@ -583,6 +663,28 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
});
}

// File-local helper type used when dispatching execute() on the runtime precision.
namespace {

// Carries the node whose typed execute<T>() should be invoked.
struct ROIPoolingContext {
    MKLDNNROIPoolingNode& node;  // non-owning reference; the node must outlive the context
};

}  // namespace

// Functor named in the OV_SWITCH dispatch: for the selected element type T it
// forwards to the typed execute<T>() of the node stored in the context.
template <typename T>
struct MKLDNNROIPoolingNode::ROIPoolingExecute {
    void operator()(ROIPoolingContext& ctx) {
        MKLDNNROIPoolingNode& target = ctx.node;
        target.execute<T>();
    }
};

// Non-template entry point: wraps this node in a ROIPoolingContext and hands it
// to OV_SWITCH, which is given runtimePrecision and the FP32 -> float /
// BF16 -> bfloat16_t cases so that ROIPoolingExecute<T> runs for the matching type.
// The stream argument is not used here.
void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
    ROIPoolingContext ctx{*this};
    // enable conditional compilation
    OV_SWITCH(MKLDNNPlugin, ROIPoolingExecute, ctx, runtimePrecision,
              OV_CASE(Precision::FP32, float),
              OV_CASE(Precision::BF16, bfloat16_t))
}

bool MKLDNNROIPoolingNode::created() const {
return getType() == ROIPooling;
Expand Down
Loading

0 comments on commit bae7f4b

Please sign in to comment.