[Snippets] MHA: blocking by K and N for bf16/int8 precisions (#23292)
### Details:
- *Added brgemm blocking support for the bf16 and int8 precisions: in this case the blocking loops are shared between the BrgemmCopyB and BrgemmCPU nodes (illustrated by the sketch below the ticket list)*
- *Reduced the allocation shapes of the input brgemm buffers in the low-precision case*

### Tickets:
 - *CVS-115165*
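
The following minimal, self-contained C++ sketch illustrates what the two bullets mean in practice; it is not plugin code. In OpenVINO the loops are created by the snippets blocking passes, and the two inner steps correspond to the BrgemmCopyB and BrgemmCPU nodes operating on bf16/int8 data, whereas the sketch uses plain float arithmetic and made-up block sizes purely for illustration.

```cpp
// Self-contained sketch (not plugin code): shared N/K blocking where the
// repacking step and the matmul step live inside the same loops. In the plugin
// these steps are the BrgemmCopyB and BrgemmCPU nodes; the float math and the
// buffer layout here are illustrative assumptions only.
#include <algorithm>
#include <cstddef>
#include <vector>

void blocked_matmul(const std::vector<float>& A,  // M x K, row-major
                    const std::vector<float>& B,  // K x N, row-major
                    std::vector<float>& C,        // M x N, row-major, zero-initialized
                    size_t M, size_t N, size_t K, size_t N_blk, size_t K_blk) {
    // The repacking buffer holds a single K_blk x N_blk block of B instead of
    // the whole K x N matrix -- the "reduced allocation shapes" part of the change.
    std::vector<float> repacked(K_blk * N_blk);
    for (size_t n0 = 0; n0 < N; n0 += N_blk) {        // shared N blocking loop
        const size_t nb = std::min(N_blk, N - n0);
        for (size_t k0 = 0; k0 < K; k0 += K_blk) {    // shared K blocking loop
            const size_t kb = std::min(K_blk, K - k0);
            // "BrgemmCopyB": repack the current block of B into the small buffer.
            for (size_t k = 0; k < kb; ++k)
                for (size_t n = 0; n < nb; ++n)
                    repacked[k * N_blk + n] = B[(k0 + k) * N + (n0 + n)];
            // "BrgemmCPU": consume the repacked block right away.
            for (size_t m = 0; m < M; ++m)
                for (size_t k = 0; k < kb; ++k)
                    for (size_t n = 0; n < nb; ++n)
                        C[m * N + (n0 + n)] += A[m * K + (k0 + k)] * repacked[k * N_blk + n];
        }
    }
}
```

Because each repacked block is consumed immediately inside the shared loops, the repacking buffer only has to hold one block rather than the whole weight matrix, which is where the reduced allocation shapes come from.
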
v-Golubev authored Apr 25, 2024
1 parent 1de6329 commit 6f0e530
Showing 15 changed files with 258 additions and 207 deletions.
6 changes: 4 additions & 2 deletions src/common/snippets/src/lowered/pass/cleanup_loop_offsets.cpp
@@ -31,13 +31,15 @@ bool CleanupLoopOffsets::run(lowered::LinearIR& linear_ir, lowered::LinearIR::co
}
if (auto outer_loop_end = as_type_ptr<op::LoopEndStatic>(next_node)) {
const auto& is_incremented = loop_end->get_is_incremented();
const auto& data_sizes = loop_end->get_element_type_sizes();
auto fin_offsets = loop_end->get_finalization_offsets();
std::unordered_map<PortConnectorPtr, size_t> per_port_connector_offset;
const auto& loop_inputs = expr_it->get()->get_input_port_connectors();
for (size_t i = 0; i < fin_offsets.size(); i++)
per_port_connector_offset[loop_inputs[i]] = i;

const auto outer_is_incremented = outer_loop_end->get_is_incremented();
const auto& outer_is_incremented = outer_loop_end->get_is_incremented();
const auto& outer_data_sizes = outer_loop_end->get_element_type_sizes();
const auto outer_increment = static_cast<int64_t>(outer_loop_end->get_increment());
auto outer_ptr_increments = outer_loop_end->get_ptr_increments();
const auto& outer_loop_inputs = next_expr_it->get()->get_input_port_connectors();
@@ -47,7 +49,7 @@ bool CleanupLoopOffsets::run(lowered::LinearIR& linear_ir, lowered::LinearIR::co
const auto& managed_connector = outer_loop_inputs[i];
const auto& found = per_port_connector_offset.find(managed_connector);
if (found != per_port_connector_offset.end()) {
if (!is_incremented[found->second])
if (!is_incremented[found->second] || outer_data_sizes[i] != data_sizes[found->second])
continue;
// Since data ptr is incremented on [ptr_increment x increment],
// we should guarantee proportionality of ptr shifts.
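
The extra `outer_data_sizes[i] != data_sizes[found->second]` condition appears to guard the offset folding against ports whose inner and outer loops track elements of different byte sizes, which can now happen with the mixed-precision blocking. The tiny standalone example below uses hypothetical numbers (not taken from the pass) to show why folding an element-count offset across differing element sizes would break the proportionality mentioned in the comment just above.

```cpp
// Hypothetical numbers only: the same finalization offset, expressed in
// elements, lands on different byte shifts once the per-port element sizes
// disagree, so CleanupLoopOffsets now skips folding it for such ports.
#include <cstdint>
#include <iostream>

int main() {
    const int64_t fin_offset_elements = -64;  // inner-loop finalization offset
    const int64_t inner_elem_size = 2;        // e.g. bf16 on the inner loop's port
    const int64_t outer_elem_size = 4;        // e.g. f32 on the matching outer port

    std::cout << "inner loop really shifts by "
              << fin_offset_elements * inner_elem_size << " bytes\n";  // -128
    std::cout << "folded into the outer loop it would shift by "
              << fin_offset_elements * outer_elem_size << " bytes\n";  // -256
    return 0;
}
```
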
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/pass/fuse_loops.cpp
@@ -248,7 +248,7 @@ bool FuseLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, l

const auto upper_loop_id = upper_loop_ids[loop_idx];
OPENVINO_ASSERT(current_loop_id != upper_loop_id,
"Loops cannot have parents of entry points with the same identifier");
"Loops cannot have parents of entry points with the same identifier (", upper_loop_id, ")");
if (fuse_upper_into_current(linear_ir, loop_manager, entry_point.expr_port, current_loop_id, upper_loop_id,
current_loop_begin_pos, current_loop_end_pos)) {
was_fusion_up = true;
@@ -30,9 +30,6 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
if (!brgemm_repack)
OV_CPU_JIT_EMITTER_THROW("expects BrgemmCopyB node");

m_brgemm_prc_in0 = brgemm_repack->get_src_element_type();
m_brgemm_prc_in1 = brgemm_repack->get_input_element_type(0);
m_brgemmVNNIFactor = 4 / m_brgemm_prc_in0.size();
m_with_comp = brgemm_repack->is_with_compensations();
m_in_offset = brgemm_repack->get_offset_in();
m_out_offset = brgemm_repack->get_offset_out();
@@ -49,36 +46,52 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
leading_dimension = jit_brgemm_emitter::get_in_leading_dim(original_shape, layout);
}

m_N = *(transposed_shape.rbegin());
m_K = *(transposed_shape.rbegin() + 1);
const auto& in_subtensor = in_desc->get_subtensor();
m_N_blk = *in_subtensor.rbegin();
m_K_blk = *++in_subtensor.rbegin();
OV_CPU_JIT_EMITTER_ASSERT(m_N_blk <= *transposed_shape.rbegin() && m_K_blk <= *++transposed_shape.rbegin(),
"BrgemmCopyB has incompatible subtensor dimensions");
m_inner_N_block = brgemm_repack->get_n_inner_block_size();
m_inner_N_tail = m_N_blk % m_inner_N_block;

m_N_blk = brgemm_repack->get_n_block_size();
m_K_blk = brgemm_repack->get_k_block_size();
OV_CPU_JIT_EMITTER_ASSERT(expr->get_output_port_descriptor(0)->get_subtensor() == in_subtensor, "output and input subtensors must be equal");
if (m_with_comp) {
const auto& compensations_subtensor = expr->get_output_port_descriptor(1)->get_subtensor();
OV_CPU_JIT_EMITTER_ASSERT(
*compensations_subtensor.rbegin() == m_N_blk && *++compensations_subtensor.rbegin() == 1,
"compensations subtensor must be {1, m_N_blk}");
}

OV_CPU_JIT_EMITTER_ASSERT(!one_of(m_brg_weight_etype, element::bf16, element::i8), "doesn't support precision ", m_brg_weight_etype);
const auto repacking_buffer_shape = brgemm_repack->get_repacking_buffer_shape();
OV_CPU_JIT_EMITTER_ASSERT(!repacking_buffer_shape.empty(), "Repacking buffer shape mustn't be empty");
const auto& LDB = repacking_buffer_shape.back();

m_N_tail = m_N % m_N_blk;
m_K_tail = m_K % m_K_blk;
m_LDB = m_brgemm_prc_in1 == ov::element::f32 ? leading_dimension : rnd_up(m_N, m_N_blk);
const auto& brg_src_etype = brgemm_repack->get_src_element_type();
m_brg_weight_etype = brgemm_repack->get_input_element_type(0);
m_brgemmVNNIFactor = brgemm_repack->get_brgemm_vnni_factor();

const auto dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(m_brgemm_prc_in0));
const auto dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(m_brgemm_prc_in1));
const auto use_amx = mayiuse(avx512_core_amx) && brg_src_etype != ov::element::f32 &&
(m_K_blk % m_brgemmVNNIFactor == 0) && (m_N_blk % m_brgemmVNNIFactor == 0);

const bool isAMXSupported = mayiuse(avx512_core_amx);
const auto use_amx = isAMXSupported && m_brgemm_prc_in0 != ov::element::f32 && (m_K % m_brgemmVNNIFactor == 0) && (m_N % m_brgemmVNNIFactor == 0);
init_brgemm_copy(m_kernel, leading_dimension, m_N_blk, m_N_tail, m_LDB, m_K - m_K_tail, use_amx, dt_in0, dt_in1);
const auto src_dt = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(brg_src_etype));
const auto wei_dt = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(m_brg_weight_etype));

init_brgemm_copy(m_kernel, leading_dimension, m_inner_N_block, m_inner_N_tail, LDB, m_K_blk, use_amx, src_dt, wei_dt);
}

void jit_brgemm_copy_b_emitter::init_brgemm_copy(std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t>& kernel,
size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const {
size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
bool is_with_amx, dnnl_data_type_t src_dt, dnnl_data_type_t wei_dt) const {
matmul::brgemm_matmul_conf_t brgCopyKernelConf;
brgCopyKernelConf.src_dt = dt_in0;
brgCopyKernelConf.wei_dt = dt_in1;
brgCopyKernelConf.src_dt = src_dt;
brgCopyKernelConf.wei_dt = wei_dt;
brgCopyKernelConf.wei_n_blk = static_cast<int>(N_blk);
brgCopyKernelConf.wei_tag = dnnl_abcd; // What's about other ranks?
brgCopyKernelConf.copy_B_wei_stride = 0;
brgCopyKernelConf.LDB = static_cast<dim_t>(LDB);
brgCopyKernelConf.N = static_cast<dim_t>(N);
brgCopyKernelConf.N_tail = static_cast<dim_t>(N_tail);
brgCopyKernelConf.N_tail = static_cast<dim_t>(N_tail);
brgCopyKernelConf.N_blk = static_cast<dim_t>(N_blk);
brgCopyKernelConf.K = static_cast<dim_t>(K);
brgCopyKernelConf.K_blk = static_cast<dim_t>(K);
@@ -91,50 +104,47 @@ void jit_brgemm_copy_b_emitter::init_brgemm_copy(std::unique_ptr<matmul::jit_brg
brgCopyKernelConf.isa = avx512_core_amx;
brgCopyKernelConf.s8s8_compensation_required = false;
} else {
brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : avx512_core_vnni;
brgCopyKernelConf.s8s8_compensation_required = dt_in0 == dnnl_data_type_t::dnnl_s8;
brgCopyKernelConf.isa = src_dt == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : avx512_core_vnni;
brgCopyKernelConf.s8s8_compensation_required = src_dt == dnnl_data_type_t::dnnl_s8;
}

brgCopyKernelConf.has_zero_point_a = false;
brgCopyKernelConf.has_zero_point_b = false;
brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none;

auto status = matmul::create_brgemm_matmul_copy_b(kernel, &brgCopyKernelConf);
if (status != dnnl_success)
OV_CPU_JIT_EMITTER_THROW("cannot create kernel due to invalid params");
OV_CPU_JIT_EMITTER_ASSERT(status == dnnl_success, "cannot create kernel due to invalid params");
}

void jit_brgemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in,
const std::vector<size_t>& out) const {
if (host_isa_ == cpu::x64::avx512_core) {
Xbyak::Reg64 src(static_cast<int>(in[0]));
Xbyak::Reg64 dst(static_cast<int>(out[0]));
Xbyak::Reg64 comp(static_cast<int>(0)); // Compensations. Default reg idx is 0 if there aren't the compensations
if (m_with_comp) {
if (out.size() != 2) {
OV_CPU_JIT_EMITTER_THROW("with compensations requires separate register for them");
}
comp = Xbyak::Reg64(static_cast<int>(out[1]));
}

const size_t data_size = m_brgemm_prc_in1.size();
for (size_t nb = 0; nb < div_up(m_N, m_N_blk); nb++) {
const size_t offset_in = m_in_offset + nb * m_N_blk * data_size;
const size_t offset_out = m_out_offset + nb * m_N_blk * m_brgemmVNNIFactor * data_size;
const size_t offset_comp = m_with_comp ? m_comp_offset + nb * m_N_blk * sizeof(int32_t) : 0;

const bool is_N_tail = (m_N - nb * m_N_blk < m_N_blk);
const auto current_N_blk = is_N_tail ? m_N_tail : m_N_blk;

emit_kernel_call(m_kernel.get(), src, dst, comp, current_N_blk, m_K, offset_in, offset_out, offset_comp);
}
} else {
OV_CPU_JIT_EMITTER_THROW("requires at least avx512_core instruction set");
void jit_brgemm_copy_b_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1, "expects 1 input");
OV_CPU_JIT_EMITTER_ASSERT((m_with_comp && out.size() == 2) || (!m_with_comp && out.size() == 1),
"expects 2 outputs if there are compensations");
}

void jit_brgemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);
OV_CPU_JIT_EMITTER_ASSERT(host_isa_ == cpu::x64::avx512_core, "requires at least avx512_core instruction set");

Xbyak::Reg64 src(static_cast<int>(in[0]));
Xbyak::Reg64 dst(static_cast<int>(out[0]));
Xbyak::Reg64 comp(static_cast<int>(m_with_comp ? out[1] : 0));

const size_t data_size = m_brg_weight_etype.size();
for (size_t nb = 0; nb < div_up(m_N_blk, m_inner_N_block); nb++) {
const size_t offset_in = m_in_offset + nb * m_inner_N_block * data_size;
const size_t offset_out = m_out_offset + nb * m_inner_N_block * m_brgemmVNNIFactor * data_size;
const size_t offset_comp = m_with_comp ? m_comp_offset + nb * m_inner_N_block * sizeof(int32_t) : 0;

const bool is_N_tail = (m_N_blk - nb * m_inner_N_block < m_inner_N_block);
const auto current_N_blk = is_N_tail ? m_inner_N_tail : m_inner_N_block;

emit_kernel_call(m_kernel.get(), src, dst, comp, current_N_blk, m_K_blk, offset_in, offset_out, offset_comp);
}
}

void jit_brgemm_copy_b_emitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp,
size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const {
size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const {
const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) {
h->uni_vmovq(reg, xmm);
if (bytes_offset) h->add(reg, bytes_offset);
@@ -199,11 +209,12 @@ void jit_brgemm_copy_b_emitter::emit_kernel_call(const matmul::jit_brgemm_matmul
internal_call_postamble();
}

void jit_brgemm_copy_b_emitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kernel, const void *src,
const void *dst, const void *comp, size_t N, size_t K) {
if (!kernel)
OV_CPU_JIT_EMITTER_THROW("Kernel hasn't been created");

void jit_brgemm_copy_b_emitter::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel,
const void* src,
const void* dst,
const void* comp,
size_t N,
size_t K) {
auto ctx = dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t::ctx_t();
ctx.current_N_blk = N;
ctx.src = src;
Expand All @@ -214,6 +225,7 @@ void jit_brgemm_copy_b_emitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kern
ctx.current_K_start = 0;
ctx.current_K_iters = K;

OV_CPU_JIT_EMITTER_ASSERT(kernel, "Kernel hasn't been created");
(*kernel)(&ctx);
}

@@ -23,6 +23,7 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
}

private:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

void init_brgemm_copy(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t>& kernel,
@@ -36,18 +37,24 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
const void* src, const void* dst, const void* comp, size_t N, size_t K);

std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> m_kernel;
ov::element::Type m_brg_weight_etype;

ov::element::Type m_brgemm_prc_in0, m_brgemm_prc_in1;
size_t m_N, m_N_blk, m_N_tail;
size_t m_K, m_K_blk, m_K_tail;
size_t m_LDB;
size_t m_brgemmVNNIFactor;
bool m_with_comp = false;
// Block size which is set by snippets: it is usually shared between brgemm and brgemm_copy_b nodes
size_t m_N_blk = 0lu;
// Block size which is used by the internal OneDNN implementation.
// It is used in snippets emitter to iterate through input/output data and call OneDNN kernel
size_t m_inner_N_block = 0lu;
size_t m_inner_N_tail = 0lu;

size_t m_K_blk = 0lu;
size_t m_brgemmVNNIFactor = 0lu;

size_t m_in_offset = 0lu;
size_t m_out_offset = 0lu;
size_t m_comp_offset = 0lu;

bool m_with_comp = false;

#ifdef SNIPPETS_DEBUG_CAPS
friend std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter *emitter);
#endif
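
The new comments above distinguish the block size chosen by snippets (`m_N_blk`) from the block size the oneDNN copy_b kernel works with (`m_inner_N_block`). The small standalone loop below (made-up sizes, not plugin code) mirrors how `emit_impl()` walks one snippets block in inner-block steps.

```cpp
// Sketch of the emitter's iteration: one snippets N block is processed in
// steps of the oneDNN repacking block. Sizes are illustrative assumptions.
#include <cstddef>
#include <iostream>

int main() {
    const size_t N_blk = 96;          // block chosen by the snippets blocking pass
    const size_t inner_N_block = 64;  // block used by the oneDNN copy_b kernel
    const size_t inner_N_tail = N_blk % inner_N_block;  // 32

    for (size_t nb = 0; nb * inner_N_block < N_blk; ++nb) {
        const bool is_tail = (N_blk - nb * inner_N_block < inner_N_block);
        const size_t current = is_tail ? inner_N_tail : inner_N_block;
        std::cout << "copy_b call " << nb << ": N = " << current << '\n';  // 64, then 32
    }
    return 0;
}
```
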
@@ -78,9 +78,9 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, const ov

init_in_scheduling_params(input_0_desc);
if (brgemm_node->is_with_data_repacking()) {
const auto& brgemm_copy = brgemm_node->get_brgemm_copy();
const auto& allocated_shape = brgemm_copy->get_data_repacking_shape(input_1_desc->get_shape());
leading_dimensions.push_back(*allocated_shape.rbegin());
const auto repacking_buffer_shape = brgemm_node->get_brgemm_copy()->get_repacking_buffer_shape();
OV_CPU_JIT_EMITTER_ASSERT(!repacking_buffer_shape.empty(), "Repacking buffer shape mustn't be empty");
leading_dimensions.push_back(repacking_buffer_shape.back());
} else {
init_in_scheduling_params(input_1_desc);
}
@@ -89,7 +89,6 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, const ov
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);

m_with_comp = brgemm_node->is_with_compensations();
m_with_scratch = brgemm_node->is_with_scratchpad();

const auto& output_subtensor = output_desc->get_subtensor();
Expand All @@ -113,6 +112,7 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, const ov
m_ctx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(brg1Prc));
m_ctx.beta = brgemm_node->get_beta();
m_ctx.is_with_amx = brgemm_node->is_amx();
m_ctx.is_with_comp = brgemm_node->is_with_compensations();

init_brgemm_kernel(m_ctx, m_kernel);

@@ -155,8 +155,6 @@ void jit_brgemm_emitter::init_brgemm_kernel(brgemmCtx& ctx, std::unique_ptr<brge

status = brgemm_init_tiles(desc, ctx.palette);

ctx.is_with_comp = ctx.dt_in0 == dnnl_data_type_t::dnnl_s8 && !ctx.is_with_amx;

brgemm_kernel_t* kernel_ = nullptr;
status = brgemm_kernel_create(&kernel_, desc);
if (status != dnnl_success)
@@ -58,7 +58,6 @@ class jit_brgemm_emitter : public jit_emitter {
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> m_kernel = nullptr;

bool m_with_scratch = false;
bool m_with_comp = false;

size_t m_load_offset_a = 0lu;
size_t m_load_offset_b = 0lu;
19 changes: 7 additions & 12 deletions src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp
@@ -119,29 +119,24 @@ std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter) {
<< " m_load_offset_b:" << emitter->m_load_offset_b
<< " m_load_offset_scratch:" << emitter->m_load_offset_scratch
<< " m_store_offset_c:" << emitter->m_store_offset_c
<< " m_with_scratch:" << emitter->m_with_scratch
<< " m_with_comp:" << emitter->m_with_comp;
<< " m_with_scratch:" << emitter->m_with_scratch;

return ss.str();
}

std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter *emitter) {
std::stringstream ss;
ss << "Emitter_type_name:jit_brgemm_copy_b_emitter"
<< " m_LDB:" << emitter->m_LDB
<< " m_K:" << emitter->m_K
<< " m_K_blk:" << emitter->m_K_blk
<< " m_K_tail:" << emitter->m_K_tail
<< " m_N:" << emitter->m_N
<< " m_brg_weight_etype:" << emitter->m_brg_weight_etype
<< " m_N_blk:" << emitter->m_N_blk
<< " m_N_tail:" << emitter->m_N_tail
<< " m_brgemm_prc_in0:" << emitter->m_brgemm_prc_in0
<< " m_brgemm_prc_in1:" << emitter->m_brgemm_prc_in1
<< " m_inner_N_block:" << emitter->m_inner_N_block
<< " m_inner_N_tail:" << emitter->m_inner_N_tail
<< " m_K_blk:" << emitter->m_K_blk
<< " m_brgemmVNNIFactor:" << emitter->m_brgemmVNNIFactor
<< " m_with_comp:" << emitter->m_with_comp
<< " m_in_offset:" << emitter->m_in_offset
<< " m_out_offset:" << emitter->m_out_offset
<< ",m_comp_offset:" << emitter->m_comp_offset;
<< " m_comp_offset:" << emitter->m_comp_offset
<< " m_with_comp:" << emitter->m_with_comp;

return ss.str();
}