diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp index 32c769d9083da8..4ac6ea84940d55 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp @@ -145,6 +145,7 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_ jitters[ov::intel_cpu::BrgemmWithCompensationsCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmWithScratchEmitter); jitters[ov::intel_cpu::BrgemmAMXCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmWithScratchEmitter); jitters[ov::intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBEmitter); + jitters[ov::intel_cpu::BrgemmCopyBWithCompensations::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBWithCompensationsEmitter); } size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const { @@ -168,7 +169,7 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const { } ngraph::snippets::TargetMachine::opRegType ov::intel_cpu::CPUTargetMachine::get_specific_op_reg_type(const std::shared_ptr& op) const { - if (std::dynamic_pointer_cast(op)) + if (std::dynamic_pointer_cast(op)) return gpr2gpr; else return vec2vec; diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp index 093db9f502f7d4..28765e5e263831 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp @@ -66,7 +66,7 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, if (std::dynamic_pointer_cast(emitter) || std::dynamic_pointer_cast(emitter) || std::dynamic_pointer_cast(emitter) || - std::dynamic_pointer_cast(emitter)) + std::dynamic_pointer_cast(emitter)) in_physical_regs = map_regs(in_abstract_regs, gpr_map_pool); else in_physical_regs = std::move(in_abstract_regs); @@ -178,7 +178,7 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: // Brgemm is a special case since it incorporates input and output (we use onednn kernel) // Just like Load & Store it requires offsets calculation const auto is_brgemm = std::dynamic_pointer_cast(emitter) || - std::dynamic_pointer_cast(emitter); + std::dynamic_pointer_cast(emitter); return emitter_type == gpr_to_vec || emitter_type == vec_to_gpr || is_brgemm; }); // Note that we can't use reg_indexes_idx or reg_const_params_idx to store data pointers because these two @@ -1178,21 +1178,18 @@ void BrgemmWithScratchEmitter::kernel_call(const brgemm_kernel_t *brg_kernel, h->add(h->rsp, num_args_passed_on_stack * gpr_size); } -BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) +BrgemmCopyBBaseEmitter::BrgemmCopyBBaseEmitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const std::shared_ptr& n) : jit_emitter(h, isa, n) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; - const auto brgemm_repack = ov::as_type_ptr(n); + const auto brgemm_repack = ov::as_type_ptr(n); if (!brgemm_repack) - IE_THROW() << "BrgemmCopyBEmitters expects BrgemmCopyB node"; + IE_THROW() << "BrgemmCopyBBaseEmitter expects BrgemmCopyBBase node"; brgemm_prc_in0 = brgemm_repack->get_src_element_type(); brgemm_prc_in1 = brgemm_repack->get_input_element_type(0); brgemmVNNIFactor = 4 / brgemm_prc_in0.size(); - with_comp = brgemm_repack->is_with_comp(); - in_offset = brgemm_repack->get_offset_in(); - out_offset = brgemm_repack->get_offset_out(); - if (with_comp) - comp_offset = brgemm_repack->get_offset_comp(); auto layout = ngraph::snippets::utils::get_node_output_layout(brgemm_repack->get_input_node_shared_ptr(0)); const auto& original_shape = brgemm_repack->get_input_shape(0); @@ -1230,11 +1227,14 @@ BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, d const auto dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(brgemm_prc_in0))); const auto dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(brgemm_prc_in1))); init_brgemm_copy(kernel, leading_dimension, N_blk, N_tail, LDB, K - K_tail, use_amx, dt_in0, dt_in1); + + in_offset = brgemm_repack->get_offset_in(); + out_offset = brgemm_repack->get_offset_out(); } -void BrgemmCopyBEmitter::init_brgemm_copy(std::unique_ptr& kernel, - size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, - bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const { +void BrgemmCopyBBaseEmitter::init_brgemm_copy(std::unique_ptr& kernel, + size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, + bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const { matmul::brgemm_matmul_conf_t brgCopyKernelConf; brgCopyKernelConf.src_dt = dt_in0; brgCopyKernelConf.wei_dt = dt_in1; @@ -1269,40 +1269,37 @@ void BrgemmCopyBEmitter::init_brgemm_copy(std::unique_ptr& in, - const std::vector& out, - const std::vector& pool, - const std::vector& gpr, - const ov::intel_cpu::emitter_context *emit_context) const { - if (host_isa_ == cpu::x64::avx512_core) { - Reg64 src(static_cast(in[0])); - Reg64 dst(static_cast(out[0])); - Reg64 comp(static_cast(0)); // Compensations. Default reg idx is 0 if there aren't the compensations - if (with_comp) { - if (out.size() != 2) { - IE_THROW() << "BrgemmCopyBEmitter with compensations requires separate register for them"; - } - comp = Reg64(static_cast(out[1])); - } +void BrgemmCopyBBaseEmitter::emit_impl(const std::vector& in, + const std::vector& out, + const std::vector& pool, + const std::vector& gpr, + const ov::intel_cpu::emitter_context *emit_context) const { + if (host_isa_ != cpu::x64::avx512_core) { + IE_THROW() << "BrgemmCopyBBaseEmitter requires at least avx512_core instruction set"; + } + + const auto in_size = in.size(); + const auto out_size = out.size(); + std::vector regs(in_size + out_size); + for (size_t i = 0; i < in_size; ++i) + regs[i] = Reg64(static_cast(in[i])); + for (size_t i = 0; i < out_size; ++i) + regs[in_size + i] = Reg64(static_cast(out[i])); - const size_t data_size = brgemm_prc_in1.size(); - for (size_t nb = 0; nb < div_up(N, N_blk); nb++) { - const size_t offset_in = in_offset + nb * N_blk * data_size; - const size_t offset_out = out_offset + nb * N_blk * brgemmVNNIFactor * data_size; - const size_t offset_comp = with_comp ? comp_offset + nb * N_blk * sizeof(int32_t) : 0; + const size_t data_size = brgemm_prc_in1.size(); + for (size_t nb = 0; nb < div_up(N, N_blk); nb++) { + const auto offsets = init_kernel_offsets(nb, N_blk, brgemmVNNIFactor, data_size); - const bool is_N_tail = (N - nb * N_blk < N_blk); - const auto current_N_blk = is_N_tail ? N_tail : N_blk; + const bool is_N_tail = (N - nb * N_blk < N_blk); + const auto current_N_blk = is_N_tail ? N_tail : N_blk; - emit_kernel_call(kernel.get(), src, dst, comp, current_N_blk, K, offset_in, offset_out, offset_comp); - } - } else { - IE_THROW() << "BrgemmCopyBEmitter requires at least avx512_core instruction set"; + emit_kernel_call(kernel.get(), regs, offsets, current_N_blk, K); } } -void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp, - size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const { +void BrgemmCopyBBaseEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, + const std::vector& regs, const std::vector& offsets, + size_t N, size_t K) const { size_t gpr_size = 8; Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; @@ -1331,18 +1328,74 @@ void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b for (size_t i = 0; i < get_max_vecs_count(); ++i) h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Zmm(i)); - const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { - h->uni_vmovq(reg, xmm); - if (bytes_offset) h->add(reg, bytes_offset); - }; -#ifdef _WIN32 - const auto push_value = [&](size_t value, size_t index) { - // Firstly we need to move integer to GPR. Then we can move value from GPR to stack - h->mov(abi_not_param1, value); - h->mov(h->qword[h->rsp + index * gpr_size], abi_not_param1); - }; -#endif + kernel_call(kernel, regs, offsets); + // restore vector registers + for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { + h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]); + } + h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); + + // restore k registers + for (int i = n_k_regs_to_save - 1; i >= 0; --i) { + if (mayiuse(avx512_core)) + h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + else + h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + } + h->add(h->rsp, n_k_regs_to_save * k_mask_size); + + // restore gpr registers + for (int i = n_gprs_to_save - 1; i >= 0; --i) + h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); + h->add(h->rsp, n_gprs_to_save * gpr_size); +} + +void BrgemmCopyBBaseEmitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kernel, const void *src, + const void *dst, const void *comp, size_t N, size_t K) { + if (!kernel) + IE_THROW() << "Kernel for `brgemm_copy_b` hasn't been created"; + + auto ctx = dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); + ctx.current_N_blk = N; + ctx.src = src; + ctx.tr_src = dst; + ctx.compensation_ptr = comp; + ctx.zp_a_compensation_ptr = nullptr; + ctx.zp_a_neg_value_ptr = nullptr; + ctx.current_K_start = 0; + ctx.current_K_iters = K; + + (*kernel)(&ctx); +} + +void BrgemmCopyBBaseEmitter::data_ptr(Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) const { + h->uni_vmovq(reg, xmm); + if (bytes_offset) h->add(reg, bytes_offset); +} +void BrgemmCopyBBaseEmitter::push_value(size_t value, size_t index, size_t gpr_size) const { + // Firstly we need to move integer to GPR. Then we can move value from GPR to stack + h->mov(abi_not_param1, value); + h->mov(h->qword[h->rsp + index * gpr_size], abi_not_param1); +} + +BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) + : BrgemmCopyBBaseEmitter(h, isa, n) {} + +std::vector BrgemmCopyBEmitter::init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const { + const size_t offset_in = in_offset + nb * N_blk * data_size; + const size_t offset_out = out_offset + nb * N_blk * brgemmVNNIFactor * data_size; + return { offset_in, offset_out }; +} + +void BrgemmCopyBEmitter::kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, + const std::vector& regs, + const std::vector& offsets) const { + if (regs.size() != offsets.size() || regs.size() != 2) { + IE_THROW() << "BrgemmCopyBEmitter got unexpected register count and offset count: " << regs.size() << " and " << offsets.size(); + } + + size_t gpr_size = 8; size_t num_args_passed_on_stack = 0; // save function address in gpr to pass in call instruction const auto &kernel_overload = static_castuni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. // It's likely that a more efficient solution exists. - h->uni_vmovq(Xmm(0), src); - h->uni_vmovq(Xmm(1), dst); - if (with_comp) - h->uni_vmovq(Xmm(2), comp); + h->uni_vmovq(Xmm(0), regs[0]); + h->uni_vmovq(Xmm(1), regs[1]); // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. h->mov(abi_param1, reinterpret_cast(kernel)); - data_ptr(Xmm(0), abi_param2, offset_in); - data_ptr(Xmm(1), abi_param3, offset_out); - if (with_comp) { - data_ptr(Xmm(2), abi_param4, offset_comp); - } else { - h->mov(abi_param4, reinterpret_cast(nullptr)); - } + data_ptr(Xmm(0), abi_param2, offsets[0]); + data_ptr(Xmm(1), abi_param3, offsets[1]); + h->mov(abi_param4, reinterpret_cast(nullptr)); #ifdef _WIN32 // Before function call we should allocate stack area for @@ -1378,8 +1425,8 @@ void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b size_t abi_param_count = sizeof(abi_param_regs) / sizeof(abi_param_regs[0]); h->sub(h->rsp, num_args_passed_on_stack * gpr_size); - push_value(N, abi_param_count + 0); - push_value(K, abi_param_count + 1); + push_value(N, abi_param_count + 0, gpr_size); + push_value(K, abi_param_count + 1, gpr_size); #else h->mov(abi_param5, N); h->mov(abi_param6, K); @@ -1395,43 +1442,80 @@ void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b h->add(h->rsp, h->rbx); if (num_args_passed_on_stack > 0) h->add(h->rsp, gpr_size * num_args_passed_on_stack); - // restore vector registers - for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { - h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]); - } - h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); +} - // restore k registers - for (int i = n_k_regs_to_save - 1; i >= 0; --i) { - if (mayiuse(avx512_core)) - h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - else - h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - } - h->add(h->rsp, n_k_regs_to_save * k_mask_size); +BrgemmCopyBWithCompensationsEmitter::BrgemmCopyBWithCompensationsEmitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const std::shared_ptr& n) + : BrgemmCopyBBaseEmitter(h, isa, n) { + const auto brgemm_repack = ov::as_type_ptr(n); + if (!brgemm_repack) + IE_THROW() << "BrgemmCopyBWithCompensationsEmitter expects BrgemmCopyBWithCompensations node"; + comp_offset = brgemm_repack->get_offset_comp(); +} - // restore gpr registers - for (int i = n_gprs_to_save - 1; i >= 0; --i) - h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); - h->add(h->rsp, n_gprs_to_save * gpr_size); +std::vector BrgemmCopyBWithCompensationsEmitter::init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const { + const size_t offset_in = in_offset + nb * N_blk * data_size; + const size_t offset_out = out_offset + nb * N_blk * brgemmVNNIFactor * data_size; + const size_t offset_comp = comp_offset + nb * N_blk * sizeof(int32_t); + return { offset_in, offset_out, offset_comp }; } -void BrgemmCopyBEmitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kernel, const void *src, - const void *dst, const void *comp, size_t N, size_t K) { - if (!kernel) - IE_THROW() << "Kernel for `brgemm_copy_b` hasn't been created"; +void BrgemmCopyBWithCompensationsEmitter::kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, + const std::vector& regs, + const std::vector& offsets) const { + if (regs.size() != offsets.size() || regs.size() != 3) { + IE_THROW() << "BrgemmCopyBEmitter got unexpected register count and offset count: " << regs.size() << " and " << offsets.size(); + } - auto ctx = dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); - ctx.current_N_blk = N; - ctx.src = src; - ctx.tr_src = dst; - ctx.compensation_ptr = comp; - ctx.zp_a_compensation_ptr = nullptr; - ctx.zp_a_neg_value_ptr = nullptr; - ctx.current_K_start = 0; - ctx.current_K_iters = K; + size_t gpr_size = 8; + size_t num_args_passed_on_stack = 0; + // save function address in gpr to pass in call instruction + const auto &kernel_overload = static_cast(execute); + h->mov(h->rbp, reinterpret_cast(kernel_overload)); + // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted + // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. + // It's likely that a more efficient solution exists. + h->uni_vmovq(Xmm(0), regs[0]); + h->uni_vmovq(Xmm(1), regs[1]); + h->uni_vmovq(Xmm(2), regs[2]); + // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. + h->mov(abi_param1, reinterpret_cast(kernel)); - (*kernel)(&ctx); + data_ptr(Xmm(0), abi_param2, offsets[0]); + data_ptr(Xmm(1), abi_param3, offsets[1]); + data_ptr(Xmm(2), abi_param4, offsets[2]); + +#ifdef _WIN32 + // Before function call we should allocate stack area for + // - register parameters - ABI parameters (shadow space) + // - stack parameters - remaining parameters + num_args_passed_on_stack = 6; // count of function kernel_overload() parameters + size_t abi_param_count = sizeof(abi_param_regs) / sizeof(abi_param_regs[0]); + + h->sub(h->rsp, num_args_passed_on_stack * gpr_size); + push_value(N, abi_param_count + 0, gpr_size); + push_value(K, abi_param_count + 1, gpr_size); +#else + h->mov(abi_param5, N); + h->mov(abi_param6, K); +#endif + // align stack on 16-byte as ABI requires + // note that RBX must not be changed by the callee + h->mov(h->rbx, h->rsp); + h->and_(h->rbx, 0xf); + h->sub(h->rsp, h->rbx); + + h->call(h->rbp); + + h->add(h->rsp, h->rbx); + if (num_args_passed_on_stack > 0) + h->add(h->rsp, gpr_size * num_args_passed_on_stack); } HorizonMaxEmitter::HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp index 21661c88d24717..1d23aa28586b48 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp @@ -442,14 +442,14 @@ class BrgemmWithScratchEmitter : public BrgemmBaseEmitter { size_t load_offset_scratch = 0; }; - -class BrgemmCopyBEmitter : public jit_emitter { +// Base class for BrgemmCopyB emitters with common interface +class BrgemmCopyBBaseEmitter : public jit_emitter { public: - BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); + BrgemmCopyBBaseEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); - size_t get_inputs_num() const override {return 2;} + size_t get_inputs_num() const override {return 1;} -private: +protected: void emit_impl(const std::vector& in, const std::vector& out, const std::vector& pool, @@ -459,11 +459,18 @@ class BrgemmCopyBEmitter : public jit_emitter { void init_brgemm_copy(std::unique_ptr& kernel, size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const; - void emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp, - size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const; + void emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, + const std::vector& regs, const std::vector& offsets, + size_t N, size_t K) const; + + virtual void kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, const std::vector& regs, const std::vector& offsets) const = 0; + virtual std::vector init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const = 0; static void execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, size_t N, size_t K); + inline void data_ptr(Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) const; + inline void push_value(size_t value, size_t index, size_t gpr_size) const; + std::unique_ptr kernel; ov::element::Type brgemm_prc_in0, brgemm_prc_in1; @@ -475,6 +482,26 @@ class BrgemmCopyBEmitter : public jit_emitter { size_t in_offset = 0lu; size_t out_offset = 0lu; +}; + +class BrgemmCopyBEmitter : public BrgemmCopyBBaseEmitter { +public: + BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); + +protected: + void kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, const std::vector& regs, const std::vector& offsets) const override; + std::vector init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const override; +}; + +class BrgemmCopyBWithCompensationsEmitter : public BrgemmCopyBBaseEmitter { +public: + BrgemmCopyBWithCompensationsEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); + +protected: + void kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, const std::vector& regs, const std::vector& offsets) const override; + std::vector init_kernel_offsets(size_t nb, size_t N_blk, size_t brgemmVNNIFactor, size_t data_size) const override; + +private: size_t comp_offset = 0lu; }; diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index 1837587327aa68..63b287533ebdb5 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -61,6 +61,7 @@ std::map Extension::getOpSets() { NGRAPH_OP(BrgemmWithCompensationsCPU, ov::intel_cpu) NGRAPH_OP(BrgemmAMXCPU, ov::intel_cpu) NGRAPH_OP(BrgemmCopyB, ov::intel_cpu) + NGRAPH_OP(BrgemmCopyBWithCompensations, ov::intel_cpu) #undef NGRAPH_OP return opset; diff --git a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp index 9370ad642b31d2..e663a17bf4503a 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp @@ -46,7 +46,6 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto N = *dimsMatMulIn1.rbegin(); const auto element_type_a = brgemm->get_input_element_type(0); - const auto element_type_b = brgemm->get_input_element_type(1); const auto brgemmVNNIFactor = 4 / element_type_a.size(); const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_int8) || dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_bf16); @@ -64,7 +63,12 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { offset_a, offset_b, offset_c); } else { const auto layoutIn1 = ngraph::snippets::utils::get_node_output_layout(brgemm->input_value(1).get_node_shared_ptr()); - const auto brgemmRepackIn1 = std::make_shared(brgemm->input_value(1), element_type_a, with_comp, offset_b); + std::shared_ptr brgemmRepackIn1 = nullptr; + if (with_comp) { + brgemmRepackIn1 = std::make_shared(brgemm->input_value(1), element_type_a, offset_b); + } else { + brgemmRepackIn1 = std::make_shared(brgemm->input_value(1), element_type_a, offset_b); + } const auto buffer = std::make_shared(brgemmRepackIn1->output(0)); if (with_amx) { diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp index 5f4b54b110fbae..f6553cdbe34a59 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp @@ -12,41 +12,23 @@ using namespace std; using namespace ov; -intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, const bool with_comp, - const size_t offset_in, const size_t offset_out0, const size_t offset_out1) - : ngraph::snippets::op::MemoryAccess({x}), m_with_comp(with_comp), m_src_type(src_type) { +intel_cpu::BrgemmCopyBBase::BrgemmCopyBBase(const Output& x, const element::Type src_type, + const size_t offset_in, const size_t offset_out) + : ngraph::snippets::op::MemoryAccess({x}), m_src_type(src_type) { set_input_port_descriptor({0, offset_in}, 0); - set_output_port_descriptor({0, offset_out0}, 0); - if (with_comp) { - set_output_port_descriptor({0, offset_out1}, 1); - set_output_size(2); - } else { - set_output_size(1); - } - constructor_validate_and_infer_types(); + set_output_port_descriptor({0, offset_out}, 0); } -bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { - INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes); - MemoryAccess::visit_attributes(visitor); - visitor.on_attribute("with_comp", m_with_comp); - visitor.on_attribute("src_type", m_src_type); - return true; -} - -void intel_cpu::BrgemmCopyB::validate_and_infer_types() { - INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types); +void intel_cpu::BrgemmCopyBBase::validate_and_infer_types() { + INTERNAL_OP_SCOPE(BrgemmCopyBBase_validate_and_infer_types); const auto element_type = get_input_element_type(0); NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8), - "BrgemmCopyB doesn't support element type" + element_type.get_type_name()); + "BrgemmCopyBBase doesn't support element type" + element_type.get_type_name()); const auto pshape = ngraph::snippets::utils::get_port_planar_shape(input_value(0)); if (pshape.is_dynamic()) { set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()}); - if (m_with_comp) { - set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()}); - } return; } @@ -58,16 +40,60 @@ void intel_cpu::BrgemmCopyB::validate_and_infer_types() { set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)), ov::Dimension(rnd_up(N, N_blk))}); - if (m_with_comp) { - set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))}); - } +} + +bool intel_cpu::BrgemmCopyBBase::visit_attributes(AttributeVisitor& visitor) { + INTERNAL_OP_SCOPE(BrgemmCopyBBase_visit_attributes); + MemoryAccess::visit_attributes(visitor); + visitor.on_attribute("src_type", m_src_type); + return true; +} + +intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, + const size_t offset_in, const size_t offset_out) + : BrgemmCopyBBase(x, src_type, offset_in, offset_out) { + set_output_size(1); + constructor_validate_and_infer_types(); } std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const { - INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs); + INTERNAL_OP_SCOPE(BrgemmCopyB_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_src_type, m_with_comp, + return std::make_shared(new_args.at(0), m_src_type, get_offset_in(), - get_offset_out(), - m_with_comp ? get_offset_comp() : 0); + get_offset_out()); +} + +intel_cpu::BrgemmCopyBWithCompensations::BrgemmCopyBWithCompensations(const Output& x, const element::Type src_type, + const size_t offset_in, const size_t offset_out0, const size_t offset_out1) + : BrgemmCopyBBase(x, src_type, offset_in, offset_out0) { + set_output_port_descriptor({0, offset_out1}, 1); + set_output_size(2); + constructor_validate_and_infer_types(); +} + +void intel_cpu::BrgemmCopyBWithCompensations::validate_and_infer_types() { + INTERNAL_OP_SCOPE(BrgemmCopyBWithCompensations_validate_and_infer_types); + BrgemmCopyBBase::validate_and_infer_types(); + + const auto pshape = ngraph::snippets::utils::get_port_planar_shape(input_value(0)); + if (pshape.is_dynamic()) { + set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()}); + return; + } + + const auto shape = pshape.get_shape(); + const auto N = *shape.rbegin(); + const auto N_blk = get_input_element_type(0) == element::bf16 ? 32 : 64; + + set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))}); +} + +std::shared_ptr intel_cpu::BrgemmCopyBWithCompensations::clone_with_new_inputs(const OutputVector& new_args) const { + INTERNAL_OP_SCOPE(BrgemmCopyBWithCompensations_clone_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args.at(0), m_src_type, + get_offset_in(), + get_offset_out(), + get_offset_comp()); } diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp index da118d5cd35cc9..6df10b7df5400d 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp @@ -10,32 +10,64 @@ namespace ov { namespace intel_cpu { /** -* @interface BrgemmCopyB -* @brief The operation for data repacking of Brgemm with input non-fp32 precisions +* @interface BrgemmCopyBBase +* @brief The base class with the common interface for data repacking of Brgemm with input non-fp32 precisions * @ingroup snippets */ -class BrgemmCopyB : public ngraph::snippets::op::MemoryAccess { +class BrgemmCopyBBase : public ngraph::snippets::op::MemoryAccess { public: - OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", MemoryAccess); - BrgemmCopyB(const Output& x, const element::Type src_type, const bool with_comp = false, - const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu); - BrgemmCopyB() = default; + OPENVINO_OP("BrgemmCopyBBase", "SnippetsOpset", MemoryAccess); + BrgemmCopyBBase() = default; size_t get_offset_in() const { return get_input_port_descriptor(0).m_offset; } size_t get_offset_out() const { return get_output_port_descriptor(0).m_offset; } - size_t get_offset_comp() const { return get_output_port_descriptor(1).m_offset; } element::Type get_src_element_type() const { return m_src_type; } - bool is_with_comp() const { return m_with_comp; } - bool visit_attributes(AttributeVisitor& visitor) override; void validate_and_infer_types() override; + bool visit_attributes(AttributeVisitor& visitor) override; bool has_evaluate() const override { return false; } + +protected: + BrgemmCopyBBase(const Output& x, const element::Type src_type, + const size_t offset_in = 0lu, const size_t offset_out = 0lu); + + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override { return nullptr; }; + + element::Type m_src_type; // src element type of the corresponding BRGEMM (first input) +}; + +/** +* @interface BrgemmCopyB +* @brief The operation for data repacking of Brgemm with input non-fp32 precisions without compensations (doesn't have 2nd output) +* @ingroup snippets +*/ +class BrgemmCopyB : public BrgemmCopyBBase { +public: + OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", BrgemmCopyBBase); + BrgemmCopyB(const Output& x, const element::Type src_type, + const size_t offset_in = 0lu, const size_t offset_out = 0lu); + BrgemmCopyB() = default; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; +}; -private: - bool m_with_comp = false; - element::Type m_src_type; // src element type of the corresponding BRGEMM +/** +* @interface BrgemmCopyBWithCompensations +* @brief The operation for data repacking of Brgemm with input non-fp32 precisions with compensations (has 2nd output) +* @ingroup snippets +*/ +class BrgemmCopyBWithCompensations : public BrgemmCopyBBase { +public: + OPENVINO_OP(" BrgemmCopyBWithCompensations", "SnippetsOpset", BrgemmCopyBBase); + BrgemmCopyBWithCompensations(const Output& x, const element::Type src_type, + const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu); + BrgemmCopyBWithCompensations() = default; + + size_t get_offset_comp() const { return get_output_port_descriptor(1).m_offset; } + + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp index fdf71ea207d93a..381ec75f9bc600 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp @@ -30,9 +30,9 @@ std::shared_ptr BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a // ============================= BrgemmWithRepackingCPU ============================== -std::shared_ptr BrgemmWithRepackingCPU::get_brgemm_copy() const { +std::shared_ptr BrgemmWithRepackingCPU::get_brgemm_copy() const { if (const auto buffer = ov::as_type_ptr(get_input_node_shared_ptr(1))) { - return ov::as_type_ptr(buffer->get_input_node_shared_ptr(0)); + return ov::as_type_ptr(buffer->get_input_node_shared_ptr(0)); } return nullptr; } diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp index a4340245f2cadc..c5e04ceec75ce5 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp @@ -58,7 +58,7 @@ class BrgemmWithRepackingCPU : public ngraph::snippets::op::Brgemm { OPENVINO_OP("BrgemmWithRepackingCPU", "SnippetsOpset", ngraph::snippets::op::Brgemm); BrgemmWithRepackingCPU() = default; - std::shared_ptr get_brgemm_copy() const; + std::shared_ptr get_brgemm_copy() const; protected: void validate_output();