[CPU] Add bf16 support for ROI pooling #5187

Merged
178 changes: 140 additions & 38 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_roi_pooling_node.cpp
@@ -5,14 +5,23 @@
#include "mkldnn_roi_pooling_node.h"

#include <mkldnn.hpp>
#include <string>
#include <vector>
#include <math.h>
#include <mkldnn_extension_utils.h>
#include <cpu/x64/jit_generator.hpp>
#include "ie_parallel.hpp"
#include <mkldnn_selective_build.h>

#include <ngraph/opsets/opset2.hpp>

#include "ie_parallel.hpp"
#include "utils/bfloat16.hpp"
#include "emitters/jit_load_store_emitters.hpp"

#include <cpu/x64/jit_generator.hpp>

#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <cmath>

using namespace MKLDNNPlugin;
using namespace InferenceEngine;
using namespace mkldnn;
@@ -25,7 +34,7 @@ using namespace Xbyak;

template <cpu_isa_t isa>
struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pooling_kernel_f32)
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_roi_pooling_kernel_f32);

explicit jit_uni_roi_pooling_kernel_f32(jit_roi_pooling_params jcp) : jit_uni_roi_pooling_kernel(jcp), jit_generator() {}

@@ -35,14 +44,16 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
};

void generate() override {
load_emitter.reset(new jit_load_emitter(this, isa, nullptr));
store_emitter.reset(new jit_store_emitter(this, isa, nullptr));

this->preamble();

Label exit_label;
Label tail_label;

mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);

mov(reg_bin_area, ptr[this->param1 + GET_OFF(bin_area)]);
mov(reg_c_blocks, ptr[this->param1 + GET_OFF(c_blocks)]);

@@ -56,6 +67,10 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
mov(reg_xoff, ptr[this->param1 + GET_OFF(xoff)]);
}

load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
store_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx())};
store_pool_vec_idxs = {static_cast<size_t>(vmm_zero.getIdx())};

int nb_c_tail = jpp_.nb_c % jpp_.nb_c_blocking;
cmp(reg_c_blocks, jpp_.nb_c_blocking);
jne(nb_c_tail ? tail_label : exit_label, T_NEAR);
@@ -71,13 +86,18 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
L(exit_label);

this->postamble();

load_emitter->emit_data();
if (!mayiuse(avx512_core_bf16) && mayiuse(avx512_core) && store_emitter != nullptr && store_emitter->get_emu_vcvtneps2bf16() != nullptr)
store_emitter->get_emu_vcvtneps2bf16()->emit_data();
}

private:
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm, isa == cpu::x64::avx2,
Xbyak::Ymm, Xbyak::Zmm>::type;

const int vlen = cpu_isa_traits<isa>::vlen;
const int step = vlen / sizeof(float);

Vmm vmm_mask = Vmm(0);
Vmm vmm_zero = Vmm(0);
@@ -87,6 +107,13 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
Xmm xmm_xf = Xmm(1);
Vmm vmm_xf = Vmm(1);

std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
std::vector<size_t> load_pool_gpr_idxs;

std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
std::vector<size_t> store_pool_gpr_idxs;
std::vector<size_t> store_pool_vec_idxs;

Vmm get_acc_reg(int idx) { return Vmm(2*idx + 1); }
Vmm get_src_reg(int idx) { return Vmm(2*idx + 2); }

@@ -102,8 +129,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
reg64_t reg_kh = r10;
reg64_t reg_kw = r11;

reg64_t h_iter = r14;
reg64_t w_iter = r15;
reg64_t h_iter = r13;
reg64_t w_iter = r14;

reg64_t reg_c_blocks = rbx;
reg64_t reg_bin_area = rdx;
@@ -114,15 +141,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
reg64_t reg_yoff = h_iter;
reg64_t reg_xoff = r12;

Xbyak::Reg64 reg_load_table = r15;
Xbyak::Reg64 reg_load_store_mask = abi_param1;

void roi_pool_max(int c_blocks) {
Label h_loop_label;
Label w_loop_label;

mov(aux_reg_input, reg_input);

const int src_c_off = jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_max = get_acc_reg(i);
uni_vmovups(vmm_max, ptr[reg_input + i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float)]);

load_emitter->emit_code({static_cast<size_t>(reg_input.getIdx())}, {static_cast<size_t>(vmm_max.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
{}, load_pool_gpr_idxs);
}

xor_(h_iter, h_iter);
@@ -134,7 +168,10 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
Vmm vmm_max = get_acc_reg(i);
Vmm vmm_src = get_src_reg(i);

uni_vmovups(vmm_src, ptr[aux_reg_input1 + i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float)]);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input1.getIdx())}, {static_cast<size_t>(vmm_src.getIdx())},
std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", i * src_c_off),
{}, load_pool_gpr_idxs);

if (isa == cpu::x64::sse41) {
movups(vmm_mask, vmm_max);
cmpps(vmm_mask, vmm_src, _cmp_lt_os);
@@ -148,23 +185,27 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
}
}

add(aux_reg_input1, jpp_.c_block * sizeof(float));
add(aux_reg_input1, jpp_.c_block * jpp_.src_data_size);

inc(w_iter);
cmp(w_iter, reg_kw);
jl(w_loop_label, T_NEAR);
}

add(aux_reg_input, jpp_.iw * jpp_.c_block * sizeof(float));
add(aux_reg_input, jpp_.iw * jpp_.c_block * jpp_.src_data_size);

inc(h_iter);
cmp(h_iter, reg_kh);
jl(h_loop_label, T_NEAR);
}

const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
for (int i = 0; i < c_blocks; i++) {
Vmm vmm_dst = get_acc_reg(i);
uni_vmovups(ptr[reg_output + i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float)], vmm_dst);

store_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, i * dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
}
}

@@ -180,17 +221,29 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
Vmm vmm_src11 = get_src_reg(3);

for (int i = 0; i < c_blocks; i++) {
int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * sizeof(float);
const int src_c_off = i * jpp_.ih * jpp_.iw * jpp_.c_block * jpp_.src_data_size;
const auto load_context = std::make_shared<load_emitter_context>(jpp_.src_prc, Precision::FP32, step, false, "zero", src_c_off);

mov(aux_reg_input, reg_input);
uni_vmovups(vmm_src00, ptr[aux_reg_input + src_c_off]);

load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src00.getIdx())},
load_context,
{}, load_pool_gpr_idxs);
add(aux_reg_input, reg_xoff);
uni_vmovups(vmm_src01, ptr[aux_reg_input + src_c_off]);

load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src01.getIdx())},
load_context,
{}, load_pool_gpr_idxs);

add(aux_reg_input, reg_yoff);
uni_vmovups(vmm_src11, ptr[aux_reg_input + src_c_off]);
load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src11.getIdx())},
load_context,
{}, load_pool_gpr_idxs);
sub(aux_reg_input, reg_xoff);
uni_vmovups(vmm_src10, ptr[aux_reg_input + src_c_off]);

load_emitter->emit_code({static_cast<size_t>(aux_reg_input.getIdx())}, {static_cast<size_t>(vmm_src10.getIdx())},
load_context,
{}, load_pool_gpr_idxs);

uni_vsubps(vmm_src01, vmm_src01, vmm_src00);
uni_vfmadd213ps(vmm_src01, vmm_xf, vmm_src00);
@@ -201,15 +254,22 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
uni_vsubps(vmm_src11, vmm_src11, vmm_src01);
uni_vfmadd213ps(vmm_src11, vmm_yf, vmm_src01);

int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float);
uni_vmovups(ptr[reg_output + dst_c_off], vmm_src11);
const int dst_c_off = i * jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;

store_emitter->emit_code({static_cast<size_t>(vmm_src11.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(Precision::FP32, jpp_.dst_prc, step, dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
}
}

void empty_roi(int c_blocks) {
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);

const int dst_c_off = jpp_.oh * jpp_.ow * jpp_.c_block * jpp_.dst_data_size;
for (int i = 0; i < c_blocks; i++) {
uni_vmovups(ptr[reg_output + i * jpp_.oh * jpp_.ow * jpp_.c_block * sizeof(float)], vmm_zero);
store_emitter->emit_code({static_cast<size_t>(vmm_zero.getIdx())}, {static_cast<size_t>(reg_output.getIdx())},
std::make_shared<store_emitter_context>(jpp_.src_prc, jpp_.dst_prc, step, i * dst_c_off),
store_pool_vec_idxs, store_pool_gpr_idxs);
}
}

@@ -226,8 +286,8 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
roi_pool_bilinear(c_blocks);

if (isa == cpu::x64::sse41) {
add(reg_input, 4 * sizeof(float));
add(reg_output, 4 * sizeof(float));
add(reg_input, 4 * jpp_.src_data_size);
add(reg_output, 4 * jpp_.dst_data_size);

if (jpp_.alg == Algorithm::ROIPoolingMax)
roi_pool_max(c_blocks);
@@ -239,7 +299,7 @@ struct jit_uni_roi_pooling_kernel_f32 : public jit_uni_roi_pooling_kernel, publi
L(empty_roi_label);
empty_roi(c_blocks);
if (isa == cpu::x64::sse41) {
add(reg_output, 4 * sizeof(float));
add(reg_output, 4 * jpp_.dst_data_size);
empty_roi(c_blocks);
}

@@ -317,6 +377,18 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

runtimePrecision = getOriginalInputPrecisionAtPort(0);

if (!mayiuse(avx512_core)) {
if (runtimePrecision == Precision::BF16)
runtimePrecision = Precision::FP32;
}

auto dataType = MKLDNNExtensionUtils::IEPrecisionToDataType(runtimePrecision);

src_data_size = MKLDNNExtensionUtils::sizeOfDataType(dataType);
dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(dataType);

InferenceEngine::LayerConfig config;
config.dynBatchSupport = false;
config.inConfs.resize(2);
@@ -342,9 +414,9 @@ void MKLDNNROIPoolingNode::initSupportedPrimitiveDescriptors() {
impl_type = impl_desc_type::ref;
}

config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), memory::data_type::f32, format);
config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), memory::data_type::f32, memory::format_tag::nc);
config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), memory::data_type::f32, format);
config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), dataType, format);
config.inConfs[1].desc = MKLDNNMemoryDesc(getParentEdgeAt(1)->getDims(), dataType, memory::format_tag::nc);
config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), dataType, format);
supportedPrimitiveDescriptors.push_back({config, impl_type, format});
}

@@ -375,6 +447,12 @@ void MKLDNNROIPoolingNode::createPrimitive() {

jpp.nb_c_blocking = mayiuse(cpu::x64::avx512_common) ? 15 : 7;

auto selectedPD = getSelectedPrimitiveDescriptor();
jpp.src_prc = selectedPD->getConfig().inConfs[0].desc.getPrecision();
jpp.dst_prc = selectedPD->getConfig().outConfs[0].desc.getPrecision();
jpp.src_data_size = jpp.src_prc.size();
jpp.dst_data_size = jpp.dst_prc.size();

jpp.alg = getAlgorithm();

if (mayiuse(cpu::x64::avx512_common)) {
@@ -389,14 +467,15 @@ void MKLDNNROIPoolingNode::createPrimitive() {
roi_pooling_kernel->create_ker();
}

void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
template<typename T>
void MKLDNNROIPoolingNode::execute() {
auto &srcMemory0 = getParentEdgeAt(0)->getMemory();
auto &srcMemory1 = getParentEdgeAt(1)->getMemory();
auto &dstMemory = getChildEdgeAt(0)->getMemory();
auto &dstMemory = getChildEdgeAt(0)->getMemory();

const auto *src_data = reinterpret_cast<const float *>(srcMemory0.GetPtr());
const auto *src_roi = reinterpret_cast<const float *>(srcMemory1.GetPtr());
float *dst = reinterpret_cast<float *>(dstMemory.GetPtr());
const auto *src_data = reinterpret_cast<const T*>(srcMemory0.GetPtr());
const auto *src_roi = reinterpret_cast<const T*>(srcMemory1.GetPtr());
auto *dst = reinterpret_cast<T*>(dstMemory.GetPtr());
Contributor comment on lines +476 to +478:
NIT. Just use `auto` only, since the type is completely described in the reinterpret_cast template argument.
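A minimal sketch of what the reviewer is suggesting (hypothetical, not part of the diff): since reinterpret_cast<const T*> already spells out the pointer type, plain auto deduces exactly the same type.

    // current form in the PR
    const auto *src_data = reinterpret_cast<const T*>(srcMemory0.GetPtr());
    // reviewer's suggested form: auto alone still deduces `const T*` from the cast
    auto src_data = reinterpret_cast<const T*>(srcMemory0.GetPtr());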


auto selectedPrimitiveDescriptor = getSelectedPrimitiveDescriptor();
if (!selectedPrimitiveDescriptor)
@@ -405,16 +484,16 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {

auto src_strides = config.inConfs[0].desc.getBlockingDesc().getStrides();
auto dst_strides = config.outConfs[0].desc.getBlockingDesc().getStrides();
size_t src_roi_step = config.inConfs[1].desc.getBlockingDesc().getStrides()[0];

int cb_work = impl::utils::div_up(jpp.nb_c, jpp.nb_c_blocking);
int MB = jpp.mb;

size_t src_roi_step = config.inConfs[1].desc.getBlockingDesc().getStrides()[0];
int real_rois = 0;
for (; real_rois < MB; real_rois++) {
size_t roi_off = real_rois * src_roi_step;

const float *src_roi_ptr = &src_roi[roi_off];
const auto *src_roi_ptr = &src_roi[roi_off];
int roi_batch_ind = static_cast<int>(src_roi_ptr[0]);
if (roi_batch_ind == -1) {
break;
@@ -443,7 +522,7 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
(*roi_pooling_kernel)(&arg);
} else {
size_t roi_off = n * src_roi_step;
const float* src_roi_ptr = &src_roi[roi_off];
const auto *src_roi_ptr = &src_roi[roi_off];

int roi_batch_ind = static_cast<int>(src_roi_ptr[0]);

Expand Down Expand Up @@ -549,11 +628,12 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
arg.xf = in_x - left_x_index;
arg.yf = in_y - top_y_index;

arg.xoff = sizeof(float) * (right_x_index - left_x_index) * jpp.c_block;
arg.yoff = sizeof(float) * (bottom_y_index - top_y_index) * jpp.iw * jpp.c_block;
arg.xoff = sizeof(T) * (right_x_index - left_x_index) * jpp.c_block;
arg.yoff = sizeof(T) * (bottom_y_index - top_y_index) * jpp.iw * jpp.c_block;

arg.src = &src_data[roi_batch_ind * src_strides[0] + cb * src_strides[1] +
top_y_index * src_strides[2] + left_x_index * src_strides[3]];

arg.bin_area = 1;
} else {
for (int c = 0; c < 1; c++) {
Expand Down Expand Up @@ -583,6 +663,28 @@ void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
});
}

namespace {
struct ROIPoolingContext {
MKLDNNROIPoolingNode &node;
};
}

template<typename T>
struct MKLDNNROIPoolingNode::ROIPoolingExecute {
void operator()(ROIPoolingContext & ctx) {
ctx.node.execute<T>();
}
};

void MKLDNNROIPoolingNode::execute(mkldnn::stream strm) {
ROIPoolingContext ctx = {
*this
};
// enable conditional compilation
OV_SWITCH(MKLDNNPlugin, ROIPoolingExecute, ctx, runtimePrecision,
OV_CASE(Precision::FP32, float),
OV_CASE(Precision::BF16, bfloat16_t))
}

bool MKLDNNROIPoolingNode::created() const {
return getType() == ROIPooling;