diff --git a/src/bindings/python/src/openvino/frontend/pytorch/gptq.py b/src/bindings/python/src/openvino/frontend/pytorch/gptq.py index a1c6aecc45d421..60a48c275d6681 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/gptq.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/gptq.py @@ -177,15 +177,3 @@ def unpatch_model(model): log.warning("Exception raised during GPTQ model unpatching. " "Depending on the exact issue it may lead to broken " "original model.\n%s", error) - - -def detect_gptq_model_raw(model): - return (model and getattr(model, 'config', None) and - getattr(model.config, 'quantization_config', None) and - model.config.quantization_config.quant_method == 'gptq') - - -def detect_gptq_model(model): - return (detect_gptq_model_raw(model) or - getattr(model, 'model', None) and - detect_gptq_model_raw(model.model)) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/quantized.py b/src/bindings/python/src/openvino/frontend/pytorch/quantized.py new file mode 100644 index 00000000000000..310e95cb9985d7 --- /dev/null +++ b/src/bindings/python/src/openvino/frontend/pytorch/quantized.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional +import torch +from openvino.frontend.pytorch import ModuleExtension, gptq +from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model + + +def detect_quantized_model(model: torch.nn.Module) -> Optional[str]: + """Detects the quantization method used in a given PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model to check for quantization. + + Returns: + str: The quantization method if available, otherwise None. + """ + if (model and getattr(model, "config", None) + and getattr(model.config, "quantization_config", None)): + return model.config.quantization_config.quant_method + if getattr(model, "model", None): + return detect_quantized_model(model.model) + return None + + +def patch_quantized(model: torch.nn.Module) -> None: + """Patches a model based on its quantization type ("awq" or "gptq"). + + Args: + model (torch.nn.Module): The model to patch. + + Raises: + RuntimeError: If the quantization type is unknown. + """ + quant_type = detect_quantized_model(model) + if quant_type == "awq": + extensions = {} + try: + from awq.modules.linear import WQLinear_GEMM + extensions[WQLinear_GEMM] = ModuleExtension( + WQLinear_GEMM, "ov_ext::awq_gemm", + convert=lambda module, target_op, *args, **kwargs: target_op( + args[0], module.qweight, module.qzeros, module.scales, + torch.tensor(module.group_size), + torch.tensor(module.w_bit), module.bias), + evaluate=lambda module, *args, **kwargs: torch.full( + list(args[0].shape[:-1]) + [module.out_features], 0.5, + dtype=torch.float32)) # type: ignore + except ImportError: + pass + patch_model(model, extensions, + "_openvino_quantized_patch_orig_forward") # type: ignore + elif quant_type == "gptq": + model._openvino_gptq_patched = True + gptq.patch_model(model) # type: ignore + else: + raise RuntimeError(f"Unknown quantization type: {quant_type}.") + + +def unpatch_quantized(model: torch.nn.Module) -> None: + """Reverts the patching applied to a quantized PyTorch model. + + Args: + model (torch.nn.Module): The model to unpatch. 
+ """ + if getattr(model, "_openvino_gptq_patched", False): + gptq.unpatch_model(model) # type: ignore + del model._openvino_gptq_patched + else: + unpatch_model(model, + "_openvino_quantized_patch_orig_forward") # type: ignore diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py index eb32a0a93c669b..6d8fdb1658793e 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py @@ -16,7 +16,7 @@ graph_has_ops, ) from openvino.runtime import opset11 as ops -from openvino.frontend.pytorch import gptq, patch_model +from openvino.frontend.pytorch import quantized, patch_model from openvino.frontend.pytorch.module_extension import ModuleExtension import inspect @@ -141,27 +141,25 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False) patch_model.patch_model( pt_module, self.module_extensions, orig_forward_name) - gptq_patched = False - if gptq.detect_gptq_model(pt_module): + patched = False + if quantized.detect_quantized_model(pt_module) is not None: try: - gptq.patch_model(pt_module) - gptq_patched = True + quantized.patch_quantized(pt_module) + patched = True except Exception as error: log.warning( - "Failed patching of AutoGPTQ model. Error message:\n%s" - "\nTracing of the model will likely be unsuccessful or incorrect", - error) - gptq.unpatch_model(pt_module) - gptq_patched = False + "Failed patching of AutoGPTQ model. Error message:\n" + "Tracing of the model will likely be unsuccessful or incorrect", + exc_info=error) + quantized.unpatch_quantized(pt_module) + patched = False try: scripted = torch.jit.trace( pt_module, **input_parameters, strict=False) finally: - if gptq_patched: - gptq.unpatch_model(pt_module) - if self.module_extensions: - patch_model.unpatch_model(pt_module, orig_forward_name) + if patched: + quantized.unpatch_quantized(pt_module) have_to_freeze_ops = ["prim::Uninitialized", "prim::unchecked_cast", "aten::append"] diff --git a/src/common/snippets/include/snippets/lowered/pass/solve_buffer_memory.hpp b/src/common/snippets/include/snippets/lowered/pass/solve_buffer_memory.hpp index 71b5f4ba6c6f96..4d3c9f95350f4b 100644 --- a/src/common/snippets/include/snippets/lowered/pass/solve_buffer_memory.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/solve_buffer_memory.hpp @@ -34,6 +34,10 @@ class SolveBufferMemory : public Pass { */ bool run(lowered::LinearIR& linear_ir) override; + // For the better performance data should be aligned with cache line size. + // The majority of CPUs have cache line size `64` bytes. 
+ constexpr static size_t byte_alignment = 64; + private: using Buffers = std::vector; /** @@ -64,8 +68,6 @@ class SolveBufferMemory : public Pass { void set_dynamic_buffer_offset(const Buffers& dynamic_buffer_expressions); size_t& m_static_buffer_scratchpad_size; - - constexpr static size_t m_alignment = 32; // 32 bytes for data alignment in allocated memory }; } // namespace pass diff --git a/src/common/snippets/src/lowered/pass/solve_buffer_memory.cpp b/src/common/snippets/src/lowered/pass/solve_buffer_memory.cpp index ca85cefd369099..ec7ab6c95eb89a 100644 --- a/src/common/snippets/src/lowered/pass/solve_buffer_memory.cpp +++ b/src/common/snippets/src/lowered/pass/solve_buffer_memory.cpp @@ -102,9 +102,8 @@ std::vector SolveBufferMemory::init_boxes(const Buffers& boxes.reserve(map_boxes.size()); for (auto& p : map_boxes) { auto& box = p.second; - // We use data alignment to put data in the line cache - // TODO [143395] : Please check if alignment is really needed here - box.size = utils::div_up(box.size, m_alignment); + // Align with cache line size. The experiments show that it affects performance. + box.size = utils::div_up(box.size, byte_alignment); boxes.push_back(box); } @@ -116,12 +115,12 @@ void SolveBufferMemory::solve_static_buffer_memory(const Buffers& static_buffer_ const auto boxes = init_boxes(static_buffer_expressions, linear_ir); ov::MemorySolver memSolver(boxes); - m_static_buffer_scratchpad_size = static_cast(memSolver.solve()) * m_alignment; // alignment in byte + m_static_buffer_scratchpad_size = static_cast(memSolver.solve()) * byte_alignment; // alignment in byte // Set offsets for Buffers for (const auto& buffer_expr : static_buffer_expressions) { const auto offset = static_cast(memSolver.get_offset(static_cast(buffer_expr->get_cluster_id()))); - buffer_expr->set_offset(offset * m_alignment); // alignment in byte + buffer_expr->set_offset(offset * byte_alignment); // alignment in byte } } diff --git a/src/common/snippets/src/runtime_configurator.cpp b/src/common/snippets/src/runtime_configurator.cpp index 96d13074d042ba..06beb8db94ae3d 100644 --- a/src/common/snippets/src/runtime_configurator.cpp +++ b/src/common/snippets/src/runtime_configurator.cpp @@ -8,6 +8,8 @@ #include "snippets/lowered/pass/init_loops.hpp" #include "snippets/lowered/pass/insert_specific_iterations.hpp" #include "snippets/lowered/pass/mha_parallel_wa_optimizer.hpp" +#include "snippets/lowered/pass/solve_buffer_memory.hpp" +#include "snippets/pass/split_dimension_m.hpp" #include "snippets/snippets_isa.hpp" #include "snippets/utils/loop_utils.hpp" #include "snippets/utils/utils.hpp" @@ -228,7 +230,8 @@ void RuntimeConfigurator::update_loop_info(const lowered::LinearIRCPtr& linear_i void RuntimeConfigurator::update_buffer_scratchpad_size(const lowered::LinearIRCPtr& linear_ir) const { const auto& loop_manager = linear_ir->get_loop_manager(); - m_config->buffer_scratchpad_size = linear_ir->get_static_buffer_scratchpad_size(); + // Align initial buffer scratchpad size with cache line size + m_config->buffer_scratchpad_size = utils::rnd_up(linear_ir->get_static_buffer_scratchpad_size(), lowered::pass::SolveBufferMemory::byte_alignment); auto is_not_executed = [&loop_manager](const lowered::ExpressionPtr& buffer_expr) { const auto& loop_ids = buffer_expr->get_loop_ids(); @@ -254,6 +257,9 @@ void RuntimeConfigurator::update_buffer_scratchpad_size(const lowered::LinearIRC additional_size = std::max(allocation_size * buffer_expr->get_node()->get_element_type().size(), additional_size); } + // Align 
with cache line size. The experiments shows that it affects performance. + additional_size = utils::rnd_up(additional_size, lowered::pass::SolveBufferMemory::byte_alignment); + cluster_offset = m_config->buffer_scratchpad_size; OPENVINO_ASSERT(!utils::is_dynamic_value(cluster_offset), "Offset of the cluster must be defined!"); m_config->buffer_scratchpad_size += additional_size; diff --git a/src/frontends/pytorch/src/op/linear.cpp b/src/frontends/pytorch/src/op/linear.cpp index 2d01dee84c151b..4a5ad4a6b0e73b 100644 --- a/src/frontends/pytorch/src/op/linear.cpp +++ b/src/frontends/pytorch/src/op/linear.cpp @@ -5,6 +5,10 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/add.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/subtract.hpp" #include "utils.hpp" namespace ov { @@ -12,6 +16,8 @@ namespace frontend { namespace pytorch { namespace op { +using namespace ov::op; + OutputVector translate_linear(const NodeContext& context) { // schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor num_inputs_check(context, 2, 3); @@ -20,17 +26,91 @@ OutputVector translate_linear(const NodeContext& context) { if (weight.get_element_type() == element::f16 || weight.get_element_type() == element::bf16) { // In case of patched linear it can have mixed fp16/bf16 and fp32 input type. // In other cases these conversion is not required. - weight = context.mark_node(std::make_shared(weight, x)); + weight = context.mark_node(std::make_shared(weight, x)); } - auto matmul = context.mark_node(std::make_shared(x, weight, false, true)); + auto matmul = context.mark_node(std::make_shared(x, weight, false, true)); if (!context.input_is_none(2)) { auto bias = context.get_input(2); if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) { // Same reason as for weight. 
-            bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, x));
+            bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, x));
+        }
+        matmul = context.mark_node(std::make_shared<v1::Add>(matmul, bias));
+    }
+    return {matmul};
+};
+
+namespace {
+uint32_t rearrange_awq_bits(uint32_t num) {
+    uint32_t result = 0;
+    uint32_t mask = 0xF;
+
+    // Rearrange each 4-bit part in accordance with the AWQ i32->u4 unpacking schema
+    result |= (num & (mask << 0)) << 0;
+    result |= (num & (mask << 16)) >> 12;
+    result |= (num & (mask << 4)) << 4;
+    result |= (num & (mask << 20)) >> 8;
+    result |= (num & (mask << 8)) << 8;
+    result |= (num & (mask << 24)) >> 4;
+    result |= (num & (mask << 12)) << 12;
+    result |= (num & (mask << 28)) >> 0;
+
+    return result;
+}
+
+Output<Node> rearrange_constant(const Output<Node>& c, uint32_t groups) {
+    auto constant = std::dynamic_pointer_cast<v0::Constant>(c.get_node_shared_ptr());
+    FRONT_END_OP_CONVERSION_CHECK(constant, "weight must be Constant.");
+    auto src = constant->get_data_ptr<uint32_t>();
+    auto initial_shape = constant->get_shape();
+    FRONT_END_OP_CONVERSION_CHECK(initial_shape.size() == 2, "Only 2D constants are supported.");
+    auto new_shape = Shape{initial_shape[0] / groups, groups, initial_shape[1] * 8};
+    auto new_qweight = std::make_shared<v0::Constant>(element::u4, new_shape);
+    auto dst = const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(new_qweight->get_data_ptr()));
+    for (size_t i = 0; i < shape_size(constant->get_shape()); i++) {
+        dst[i] = rearrange_awq_bits(src[i]);
+    }
+    return new_qweight;
+}
+}  // namespace
+
+OutputVector translate_linear_awq(const NodeContext& context) {
+    num_inputs_check(context, 4, 7);
+    auto x = context.get_input(0);
+    auto qweight = context.get_input(1);
+    auto qzeros = context.get_input(2);
+    auto scales = context.get_input(3);
+    auto groups = context.const_input<int64_t>(4);
+    auto bits = context.const_input<int64_t>(5);
+
+    FRONT_END_OP_CONVERSION_CHECK(bits == 4, "Only 4 bit AWQ is supported.");
+
+    auto new_qweight = rearrange_constant(qweight, static_cast<uint32_t>(groups));
+    auto new_qzeros = rearrange_constant(qzeros, 1);
+    new_qweight = context.mark_node(std::make_shared<v0::Convert>(new_qweight, scales.get_element_type()));
+    new_qzeros = context.mark_node(std::make_shared<v0::Convert>(new_qzeros, scales.get_element_type()));
+
+    auto w_s = context.mark_node(std::make_shared<v1::Subtract>(new_qweight, new_qzeros));
+    FRONT_END_OP_CONVERSION_CHECK(scales.get_partial_shape().is_static(), "Scales must be constant.");
+    auto scales_shape = scales.get_shape();
+    auto new_scales_shape =
+        v0::Constant::create(element::i32, {3}, std::vector<size_t>{scales_shape[0], 1, scales_shape[1]});
+    scales = context.mark_node(std::make_shared<v1::Reshape>(scales, new_scales_shape, false));
+    auto weight = context.mark_node(std::make_shared<v1::Multiply>(w_s, scales));
+    auto out_shape =
+        v0::Constant::create(element::i32, {2}, std::vector<int32_t>{static_cast<int32_t>(qweight.get_shape()[0]), -1});
+    weight = context.mark_node(std::make_shared<v1::Reshape>(weight, out_shape, false));
+    weight = context.mark_node(std::make_shared<v1::ConvertLike>(weight, x));
+
+    auto matmul = context.mark_node(std::make_shared<v0::MatMul>(x, weight, false, false));
+    if (!context.input_is_none(6)) {
+        auto bias = context.get_input(6);
+
+        if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) {
+            bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, x));
         }
-        matmul = context.mark_node(std::make_shared<ov::op::v1::Add>(matmul, bias));
+        matmul = context.mark_node(std::make_shared<v1::Add>(matmul, bias));
     }
     return {matmul};
 };
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 7307833430411f..ed375fd742d7ed 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -61,7 +61,6 @@ OP_CONVERTER(translate_clamp);
 OP_CONVERTER(translate_col2im);
 OP_CONVERTER(translate_constant);
 OP_CONVERTER(translate_conv_transposend);
-OP_CONVERTER(translate_conv1d_ext);
 OP_CONVERTER(translate_convnd);
 OP_CONVERTER(translate_convolution);
 OP_CONVERTER(translate_convolution_mode);
@@ -77,7 +76,6 @@ OP_CONVERTER(translate_dot);
 OP_CONVERTER(translate_elu);
 OP_CONVERTER(translate_embedding);
 OP_CONVERTER(translate_embedding_bag);
-OP_CONVERTER(translate_embedding_ext);
 OP_CONVERTER(translate_empty);
 OP_CONVERTER(translate_empty_like);
 OP_CONVERTER(translate_erf);
@@ -325,6 +323,10 @@ OP_CONVERTER(translate_unbind_int_fx);
 OP_CONVERTER(translate_unique2);
 OP_CONVERTER(translate_zeros_fx);
 OP_CONVERTER(translate_zeros_like_fx);
+// Extensions
+OP_CONVERTER(translate_conv1d_ext);
+OP_CONVERTER(translate_embedding_ext);
+OP_CONVERTER(translate_linear_awq);
 
 }  // namespace op
 
@@ -699,6 +701,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::zero", op::translate_zeros_like},
         {"aten::zeros", op::translate_zeros},
         {"aten::zeros_like", op::translate_zeros_like},
+        {"ov_ext::awq_gemm", op::translate_linear_awq},
         {"ov_ext::embedding", op::translate_embedding_ext},
         {"ov_ext::conv1d", op::translate_conv1d_ext},
         {"ov_ext::linear", op::translate_linear},
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index 752b9accb71d01..5cc7ec21f30911 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -42,7 +42,11 @@ using namespace ov::op;
 
 void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
     auto num_inputs = context.get_input_size();
-    FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs, "Got less inputs than expected");
+    FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs,
+                                  "Got fewer inputs ",
+                                  num_inputs,
+                                  " than expected ",
+                                  min_inputs);
     for (auto i = max_inputs; i < num_inputs; i++) {
         FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
     }
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp
index e68ab224407c7b..53d8fea05a8adf 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp
@@ -48,14 +48,11 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
     OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
                               "Jit emitter is called when the shapes are unknown");
 
-    const auto& in_subtensor = get_projected_subtensor(expr->get_input_port(0));
-    const auto K_blk = *++in_subtensor.rbegin();
-
     const auto& src_prc = brgemm_repack->get_src_element_type();
     const auto& wei_prc = brgemm_repack->get_input_element_type(0);
     const auto wei_N_blk = brgemm_utils::repacking::compute_inner_n_block(wei_prc);
     const auto is_transposed = get_is_transposed(expr);
-    const auto brgemm_type = get_brgemm_type(src_prc, K_blk, is_transposed);
+    const auto brgemm_type = get_brgemm_type(src_prc, is_transposed);
     const auto primitive_isa = brgemm_utils::get_primitive_isa(src_prc, with_amx(brgemm_type));
     m_with_comp = with_compensations(brgemm_type);
 
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp
index 057a3687ab8d16..6e70cbf2e8fe81 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp
@@ -5,10 +5,13 @@
 #include "jit_brgemm_emitter.hpp"
 
 #include "transformations/snippets/x64/op/brgemm_cpu.hpp"
-#include "snippets/utils/utils.hpp"
+#include "transformations/snippets/x64/op/brgemm_utils.hpp"
+#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
+#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
 #include "emitters/plugin/x64/utils.hpp"
+
+#include "snippets/utils/utils.hpp"
 #include "utils.hpp"
-#include "transformations/snippets/x64/op/brgemm_utils.hpp"
 
 using namespace Xbyak;
 using namespace dnnl::impl;
@@ -27,11 +30,14 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
     const auto& brg0Prc = brgemm_node->get_input_element_type(0);
     const auto& brg1Prc = brgemm_node->get_input_element_type(1);
     const auto brgemm_type = brgemm_node->get_type();
-    BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, with_amx(brgemm_type), with_compensations(brgemm_type),
-                                     brgemm_utils::get_primitive_isa(brg0Prc, with_amx(brgemm_type)));
-    m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr,
-                                                                            compiled_kernel_cache,
-                                                                            kernel_config);
+    m_is_with_amx = brgemm_utils::with_amx(brgemm_type);
+    if (m_is_with_amx) {
+        BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true));
+        m_kernel_executor = kernel_table->register_kernel<BrgemmAMXKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
+    } else {
+        BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, with_compensations(brgemm_type), brgemm_utils::get_primitive_isa(brg0Prc, false));
+        m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
+    }
     // Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
     // are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
     // for both static and the 1st dynamic shapes.
@@ -82,18 +88,32 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
     if (in.size() > 2)
         mem_ptrs_idxs.emplace_back(in[2]);
 
+    if (std::dynamic_pointer_cast<BrgemmKernelExecutor>(m_kernel_executor))
+        emit_call<BrgemmKernelExecutor>(mem_ptrs_idxs);
+    else if (std::dynamic_pointer_cast<BrgemmAMXKernelExecutor>(m_kernel_executor))
+        emit_call<BrgemmAMXKernelExecutor>(mem_ptrs_idxs);
+    else
+        OV_CPU_JIT_EMITTER_THROW("unknown executor type");
+}
+
+template <typename T, typename std::enable_if<std::is_base_of<BrgemmBaseKernelExecutor, T>::value, bool>::type>
+void jit_brgemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const {
     EmitABIRegSpills spill(h);
     spill.preamble();
 
-    h->mov(h->rbp, reinterpret_cast<uintptr_t>(BrgemmKernelExecutor::execute));
-    auto reserved_stack_size = sizeof(BrgemmKernelExecutor::call_args);
+    h->mov(h->rbp, reinterpret_cast<uintptr_t>(T::execute));
+    auto reserved_stack_size = sizeof(typename T::call_args);
     // Reserve memory on the stack
     h->sub(h->rsp, reserved_stack_size);
 
     const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value<size_t>);
     Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64();
 
-    const std::vector<size_t> brgemm_args_offsets {GET_OFF_BRGEMM_ARGS(A), GET_OFF_BRGEMM_ARGS(B), GET_OFF_BRGEMM_ARGS(C), GET_OFF_BRGEMM_ARGS(scratch)};
+#define GET_OFF_CALL_ARGS(field) offsetof(typename T::call_args, field)
+    const std::vector<size_t> brgemm_args_offsets = { GET_OFF_CALL_ARGS(A), GET_OFF_CALL_ARGS(B), GET_OFF_CALL_ARGS(C), GET_OFF_CALL_ARGS(scratch) };
+#undef GET_OFF_CALL_ARGS
+
     const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
     for (size_t i = 0; i < mem_ptrs.size(); i++) {
         if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]))
@@ -108,8 +128,10 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
         h->mov(h->qword[h->rsp + brgemm_args_offsets.back()], reinterpret_cast<uintptr_t>(nullptr));
 
     // abi_param1 always contains jit_snippets_call_args which has amx tile config for each thread
-    h->lea(h->r10, h->ptr[abi_param1 + GET_OFF(amx_tile_config)]);
-    h->mov(h->qword[h->rsp + GET_OFF_BRGEMM_ARGS(amx_tile_config)], h->r10);
+    if (std::is_same<T, BrgemmAMXKernelExecutor>()) {
+        h->lea(h->r10, h->ptr[abi_param1 + GET_OFF(amx_tile_config)]);
+        h->mov(h->qword[h->rsp + GET_OFF_BRGEMM_AMX_ARGS(amx_tile_config)], h->r10);
+    }
 
     h->mov(abi_param1, reinterpret_cast<uintptr_t>(m_kernel_executor.get()));
     h->mov(abi_param2, h->rsp);
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp
index baa6ed95473034..ccec1b68b18b20 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp
@@ -5,7 +5,7 @@
 #pragma once
 
 #include "emitters/plugin/x64/jit_emitter.hpp"
-#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
+#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp"
 
 namespace ov {
 namespace intel_cpu {
@@ -24,15 +24,21 @@ class jit_brgemm_emitter : public jit_emitter {
     void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
     void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
 
+    template <typename T,
+              typename std::enable_if<std::is_base_of<BrgemmBaseKernelExecutor, T>::value, bool>::type = true>
+    void emit_call(const std::vector<size_t>& mem_ptrs_idxs) const;
+
     // Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in runtime
     std::vector<size_t> m_memory_offsets{};
     // Note: cluster ids order: A, B, C (+ scratchpad, if needed).
Values can be dynamic_value if there is no buffer std::vector m_buffer_ids{}; - std::shared_ptr m_kernel_executor = nullptr; + std::shared_ptr m_kernel_executor = nullptr; #ifdef SNIPPETS_DEBUG_CAPS friend std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter); #endif + + bool m_is_with_amx {false}; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index fad1be5a5d1289..c57824526d6e20 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -1,136 +1,55 @@ -// Copyright (C) 2020-2023 Intel Corporation +// Copyright (C) 2020-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include "brgemm.hpp" -#include - #include "common/utils.hpp" #include "dnnl_extension_utils.h" -#include "snippets/lowered/loop_manager.hpp" + #include "snippets/lowered/pass/insert_specific_iterations.hpp" + #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" -#define DIM_CAST(X) static_cast(X) -#define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) using namespace Xbyak; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; -namespace { -size_t init_hash(dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, bool is_with_amx, - bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t isa) { - size_t seed = 0; -#define HASH(X) seed = hash_combine(seed, X) - HASH(dt_in0); HASH(dt_in1); - HASH(is_with_amx); HASH(is_with_comp); - HASH(isa); -#undef HASH - return seed; -} -} // namespace - namespace ov { namespace intel_cpu { -BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, - bool is_with_amx, bool is_with_comp, - dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) : - m_static_params(std::make_shared(in0_dtype, in1_dtype, - is_with_amx, is_with_comp, - primitive_isa)) { - m_hash = compute_hash(); -} - -bool BrgemmKernelConfig::is_completed() const { - return !utils::one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC) || is_empty(); -} - -bool BrgemmKernelConfig::operator==(const BrgemmKernelConfig& rhs) const { -#define EQ(X) X == rhs.X - return EQ(m_hash) && EQ(m_beta) && - EQ(m_M) && EQ(m_N) && EQ(m_K) && - EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) && - (EQ(m_static_params.get()) || *m_static_params == *(rhs.m_static_params)); -#undef EQ -} -void BrgemmKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta) { - // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) - // To process this case, we have to make this Config as empty (nullify runtime parameters) - if (utils::one_of(0, M, N, K)) { - m_M = 0; m_N = 0; m_K = 0; - m_LDA = 0; m_LDB = 0; m_LDC = 0; - m_beta = 0; - } else { - m_M = M; m_N = N; m_K = K; - m_LDA = LDA; m_LDB = LDB; m_LDC = LDC; - m_beta = beta; - } +BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, + bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : BrgemmBaseKernelConfig(), m_static_params(std::make_shared(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) { m_hash = compute_hash(); } -bool BrgemmKernelConfig::is_empty() const { - return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta); 
-} - -BrgemmKernelConfig::operator amx_tile_config_t() const { - amx_tile_config_t res; - res.M = m_M; res.N = m_N; res.K = m_K; - return res; -} - BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, - bool is_with_amx, bool is_with_comp, - dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) : - dt_in0(DTYPE_CAST(in0_dtype)), dt_in1(DTYPE_CAST(in1_dtype)), - is_with_amx(is_with_amx), is_with_comp(is_with_comp), - isa(primitive_isa), - hash(init_hash(dt_in0, dt_in1, is_with_amx, is_with_comp, isa)) { -} + bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(is_with_comp)), is_with_comp(is_with_comp) {} bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { -#define EQ(X) X == rhs.X - return EQ(hash) && EQ(dt_in0) && EQ(dt_in1)&& EQ(is_with_amx) && EQ(is_with_comp) && EQ(isa); -#undef EQ + return StaticBaseParams::operator==(rhs) && is_with_comp == rhs.is_with_comp; } -size_t BrgemmKernelConfig::compute_hash() const { - size_t seed = m_static_params->hash; -#define HASH(X) seed = hash_combine(seed, X) - HASH(m_M); HASH(m_N); HASH(m_K); - HASH(m_LDA); HASH(m_LDB); HASH(m_LDC); - HASH(m_beta); -#undef HASH - return seed; + +size_t BrgemmKernelConfig::StaticParams::compute_hash(bool is_with_comp) { + return hash_combine(0, is_with_comp); } #ifdef SNIPPETS_DEBUG_CAPS -#define PRINT(X) ss << #X << " = " << X << "\n" std::string BrgemmKernelConfig::StaticParams::to_string() const { std::stringstream ss; - PRINT(dt_in0); PRINT(dt_in1); - PRINT(is_with_amx); PRINT(is_with_comp); - PRINT(isa); - return ss.str(); -} - -std::string BrgemmKernelConfig::to_string() const { - std::stringstream ss; - ss << m_static_params->to_string() << "\n"; - PRINT(m_M); PRINT(m_N); PRINT(m_K); - PRINT(m_LDA); PRINT(m_LDB); PRINT(m_LDC); - PRINT(m_beta); + ss << StaticBaseParams::to_string(); + ss << "is_with_comp = " << is_with_comp << "\n"; return ss.str(); } -#undef PRINT #endif BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config) : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) { } - std::shared_ptr BrgemmKernelExecutor::compile_kernel(const BrgemmKernelConfig& config) const { std::shared_ptr compiled_kernel = std::make_shared(); @@ -138,203 +57,42 @@ std::shared_ptr BrgemmKernelExecutor::compile_kernel(const if (config.is_empty()) return compiled_kernel; - cpu::x64::brgemm_desc_t desc; - auto status = brgemm_desc_init(&desc, config.get_isa(), cpu::x64::brgemm_strd, - config.get_dt_in0(), config.get_dt_in1(), - false, false, cpu::x64::brgemm_row_major, 1.f, - config.get_beta(), - config.get_LDA(), config.get_LDB(), config.get_LDC(), - config.get_M(), config.get_N(), config.get_K(), nullptr); - OV_CPU_JIT_EMITTER_ASSERT(status == dnnl_success, "Cannot initialize brgemm descriptor due to invalid params"); - - if (config.is_with_amx()) { - status = brgemm_init_tiles(desc, compiled_kernel->palette); - OV_CPU_JIT_EMITTER_ASSERT(status == dnnl_success, "Cannot initialize brgemm tiles due to invalid params"); - } - - cpu::x64::brgemm_kernel_t* kernel_ = nullptr; - status = brgemm_kernel_create(&kernel_, desc); - OV_CPU_JIT_EMITTER_ASSERT(status == dnnl_success, "Cannot create brgemm kernel due to invalid params"); - compiled_kernel->compiled_kernel = std::unique_ptr(kernel_); + create_brgemm_kernel(compiled_kernel->brgemm_kernel, config.get_dt_in0(), config.get_dt_in1(), 
config.get_isa(), + config.get_M(), config.get_N(), config.get_K(), config.get_LDA(), config.get_LDB(), config.get_LDC(), config.get_beta()); return compiled_kernel; } -float BrgemmKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id, - const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info) { - // Find all Expanded loops with the same Unified loop information -> they were decomposed from this Unified Loop. - // Note that LoopInfo are normalized and sorted (due to NormalizedLoopIDs pass). - // It means that previous executed Loops have Loop ID less the current Loop ID. - // - If there is executed Loop (work_amount > 0) and evaluated before the current -> the current Brgemm should have `beta = 1`. - // - If there is not this Loop -> the current executed Brgemm should have `beta = 0`. - if (loop_id > 0) { - const auto& current_unified_loop_info = current_expanded_loop_info->get_unified_loop_info(); - // Check the previous Loops - --loop_id; - while (loop_id >= 0) { - const auto& expanded_loop_info = loop_manager->get_loop_info(loop_id); - if (expanded_loop_info->get_unified_loop_info() != current_unified_loop_info) - return 0; - if (expanded_loop_info->get_work_amount() > 0) { - // there is previous executed Brgemm with `beta = 0` -> the current Brgemm should have `beta = 1` - return 1; - } - --loop_id; - } - } - return 0; -} + void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmKernelConfig& config) const { - const auto& input_pds = expr->get_input_port_descriptors(); - const auto& output_pds = expr->get_output_port_descriptors(); - OV_CPU_JIT_EMITTER_ASSERT((input_pds.size() == 2 || input_pds.size() == 3) && output_pds.size() == 1, - "Invalid number of in/out port descriptors"); - - const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout()); - const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout()); - auto in0_subtensor = input_pds[0]->get_subtensor(); - auto in1_subtensor = input_pds[1]->get_subtensor(); - - // Need to update M, K, N - // 1. If the original value in subtensor is `FULL_DIM`, it means that - // Brgemm block should process full tensor by this dim -> take dimension from shape - // 2. Otherwise, Brgemm block processes part of the tensor by this dim - // (there is blocking by this dimension) -> take from Loop increment - - auto M = *++in0_subtensor.rbegin(); - auto K = *in0_subtensor.rbegin(); - auto N = *in1_subtensor.rbegin(); - - size_t loop_idx = 0; - const auto& loop_ids = expr->get_loop_ids(); - const auto& loop_manager = linear_ir->get_loop_manager(); - auto get_loop_info = [&](){ - OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed"); - return loop_manager->get_loop_info(loop_ids[loop_idx++]); - }; - - /* ------- Dimension M ----------*/ - if (ov::snippets::utils::is_full_dim_value(M)) { - M = *++in0_shape.rbegin(); - } else { - const auto& current_expanded_loop_info = get_loop_info(); - const auto& in_ports = current_expanded_loop_info->get_input_ports(); - const auto& out_ports = current_expanded_loop_info->get_output_ports(); - // Quick validation check: Should we check that port is really Brgemm port? 
- // If BrgemmCopyB in the Loop by M -> first input port will be BrgemmCopyB with `incremented=false` - // to avoid extra checks, we validate only first input port - // Note: We check `is_incremented` attribute only for not incremented ports because - // this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization - auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 1; }; - OPENVINO_ASSERT(in_ports.size() > 1 && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) && - out_ports.size() == 1 && check_port(out_ports.back()), - "Incorrect Loop by Brgemm dimension M"); - M = current_expanded_loop_info->get_increment(); - input_pds[0]->set_subtensor_dim(1, M); - output_pds[0]->set_subtensor_dim(1, M); - } - - /* ------- Dimension N ----------*/ - if (ov::snippets::utils::is_full_dim_value(N)) { - N = *in1_shape.rbegin(); - } else { - const auto& current_expanded_loop_info = get_loop_info(); - const auto& in_ports = current_expanded_loop_info->get_input_ports(); - const auto& out_ports = current_expanded_loop_info->get_output_ports(); - // Quick validation check: Should we check that port is really Brgemm port? - // Note: We check `is_incremented` attribute only for not incremented ports because - // this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization - auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 0; }; - OPENVINO_ASSERT(in_ports.size() >= 2 && !in_ports.front().is_incremented && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) && - out_ports.size() == 1 && check_port(out_ports.back()), - "Incorrect Loop by Brgemm dimension N"); - N = current_expanded_loop_info->get_increment(); - input_pds[1]->set_subtensor_dim(0, N); - output_pds[0]->set_subtensor_dim(0, N); - } - - /* ------- Dimension K ----------*/ - // 1. If Brgemm block processes full dimension K -> `beta = 0` - // 2. If Brgemm block processes part of the dimension K (there is blocking), need to find - // the most first executed Brgemm Block in Loops which iterate through dimension K (work_amount > 0). - // First of them will have `beta = 0`, other - `beta = 1` - float beta = 0; - if (ov::snippets::utils::is_full_dim_value(K)) { - K = *in0_shape.rbegin(); - } else { - const auto& current_expanded_loop_info = get_loop_info(); - const auto& in_ports = current_expanded_loop_info->get_input_ports(); - const auto& out_ports = current_expanded_loop_info->get_output_ports(); - // Quick validation check: Should we check that port is really Brgemm port? 
- // Note: We check `is_incremented` attribute only for not incremented ports because - // this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization - OPENVINO_ASSERT(in_ports.size() >= 2 && in_ports.front().dim_idx == 0 && in_ports.back().dim_idx == 1 && - out_ports.size() == 1 && !out_ports.front().is_incremented, - "Incorrect Loop by Brgemm dimension K"); - K = current_expanded_loop_info->get_increment(); - input_pds[0]->set_subtensor_dim(0, K); - input_pds[1]->set_subtensor_dim(1, K); - if (K > 0) - beta = get_beta(loop_manager, static_cast(loop_ids.back()), current_expanded_loop_info); - } - - const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0))); - const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0))); - auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1))); - const auto& brgemm_node = as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); - // In case of data repacking LDB is chosen in accordance with repacking buffer size - if (with_repacking(brgemm_node->get_type())) - LDB = brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)); - - config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta); + return BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); } void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_args* args) { + OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor"); auto kernel = executor->get_kernel(); const auto& config = static_cast(executor->get_config()); OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr compiler kernel or invalid config"); - const auto tile_config = args->amx_tile_config; - if (config.is_with_amx() && tile_config && !config.compatible(tile_config)) { - *tile_config = static_cast(config); - cpu::x64::amx_tile_configure(kernel->palette); - } - - cpu::x64::brgemm_kernel_params_t brgemm_p; // Note: compensations should be applied only once, so we do it only on the first iteration, when beta == 0 - size_t is_with_comp = config.get_beta() == 0 && config.is_with_comp(); - - brgemm_p.batch = nullptr; // default value - brgemm_p.ptr_A = args->A; - brgemm_p.ptr_B = args->B; - brgemm_p.ptr_C = args->C; - brgemm_p.ptr_D = args->C; - brgemm_p.ptr_buf = args->scratch; - brgemm_p.ptr_bias = nullptr; - brgemm_p.do_post_ops = is_with_comp; - brgemm_p.do_apply_comp = is_with_comp; - brgemm_p.skip_accm = 0; - brgemm_p.BS = 1; // default value - OV_CPU_JIT_EMITTER_ASSERT(kernel->compiled_kernel, "has nullptr kernel"); - (*kernel->compiled_kernel)(&brgemm_p); + const auto is_with_comp = config.get_beta() == 0 && config.is_with_comp(); + execute_brgemm_kernel(kernel->brgemm_kernel, args->A, args->B, args->C, args->scratch, is_with_comp); } #ifdef SNIPPETS_DEBUG_CAPS BrgemmKernelReferenceExecutor::BrgemmKernelReferenceExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config) : - BrgemmKernelExecutor(std::move(kernel_cache), std::move(config)) { -} + BrgemmKernelExecutor(std::move(kernel_cache), std::move(config)) {} std::shared_ptr BrgemmKernelReferenceExecutor::compile_kernel(const BrgemmKernelConfig& c) const { const auto& res = std::make_shared(); - res->compiled_kernel.reset(new brgemm_ref_kernel(c)); + res->brgemm_kernel.reset(new brgemm_ref_kernel(c)); return res; } brgemm_ref_kernel::brgemm_ref_kernel(BrgemmKernelConfig c) : m_config(std::move(c)) { - 
OV_CPU_JIT_EMITTER_ASSERT(!m_config.is_with_comp() && !m_config.is_with_amx(), - "brgemm_ref_kernel doesn't currently support compensations or amx"); + OV_CPU_JIT_EMITTER_ASSERT(!m_config.is_with_comp(), + "brgemm_ref_kernel doesn't currently support compensations"); OV_CPU_JIT_EMITTER_ASSERT(m_config.get_dt_in0() == m_config.get_dt_in1() && m_config.get_dt_in0() == dnnl_data_type_t::dnnl_f32, "brgemm_ref_kernel currently supports only fp32 inputs"); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp index 2549580c1a176c..1c3d1e18872aea 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp @@ -1,96 +1,61 @@ -// Copyright (C) 2020-2023 Intel Corporation +// Copyright (C) 2020-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #pragma once -#include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/snippets/jit_snippets_call_args.hpp" -#include "emitters/snippets/cpu_kernel_executor_table.hpp" -#include - -#include "snippets/lowered/loop_manager.hpp" -#include "snippets/lowered/loop_info.hpp" +#include "brgemm_base.hpp" namespace ov { namespace intel_cpu { -struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig { + +struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { public: BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, - bool is_with_amx, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); + bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); BrgemmKernelConfig() = delete; - bool is_completed() const override; - size_t hash() const override { return m_hash; } - bool operator==(const BrgemmKernelConfig& rhs) const; - bool operator!=(const BrgemmKernelConfig& rhs) const {return !(*this == rhs);} - std::unique_ptr get_clone_ptr() const override { - return std::unique_ptr( new BrgemmKernelConfig(*this)); - } - void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta); - bool is_empty() const; - dnnl_data_type_t get_dt_in0() const { return m_static_params->dt_in0; } - dnnl_data_type_t get_dt_in1() const { return m_static_params->dt_in1; } - - dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return m_static_params->isa; } - bool is_with_amx() const {return m_static_params->is_with_amx; } - bool is_with_comp() const { return m_static_params->is_with_comp; } - float get_beta() const { return m_beta; } - - dnnl_dim_t get_M() const { return m_M; } - dnnl_dim_t get_N() const { return m_N; } - dnnl_dim_t get_K() const { return m_K; } - - dnnl_dim_t get_LDA() const { return m_LDA; } - dnnl_dim_t get_LDB() const { return m_LDB; } - dnnl_dim_t get_LDC() const { return m_LDC; } - - explicit operator amx_tile_config_t() const; - inline bool compatible(amx_tile_config_t* rhs) const { - return rhs && rhs->M == m_M && rhs->N == m_N && rhs->K == m_K; + std::unique_ptr get_clone_ptr() const override { + return std::unique_ptr(new BrgemmKernelConfig(*this)); } -#ifdef SNIPPETS_DEBUG_CAPS - std::string to_string() const override; -#endif + bool is_with_comp() const { return m_static_params->is_with_comp; } private: - struct StaticParams { - StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, - bool is_with_amx, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); - const 
dnnl_data_type_t dt_in0 {dnnl_f32}, dt_in1 {dnnl_f32}; - const bool is_with_amx {false}; + struct StaticParams : StaticBaseParams { + StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); + const bool is_with_comp {false}; - const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::isa_undef}; - const size_t hash {0}; + bool operator==(const StaticParams& rhs) const; bool operator!=(const StaticParams& rhs) const { return !(*this == rhs); } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const; #endif + private: + static size_t compute_hash(bool is_with_comp); }; - size_t compute_hash() const; - std::shared_ptr m_static_params; - dnnl_dim_t m_M {0}, m_N {0}, m_K {0}, m_LDA {0}, m_LDB {0}, m_LDC {0}; - float m_beta {0}; - size_t m_hash {SIZE_MAX}; + + std::shared_ptr get_static_params() const override { return m_static_params; } + + std::shared_ptr m_static_params {nullptr}; }; +// The `update_kernel` method verifies that a compiled kernel is not nullptr. +// However, the compiled kernel might be empty in cases if nothing is to be compiled (`Config.is_empty() == true`). +// To cover this case, we wrap the `brgemm_kernel_t` in the separate structure which may contain empty `brgemm_kernel_t` struct BrgemmCompiledKernel { - std::unique_ptr compiled_kernel = nullptr; - // Note: Palette is treated as a part of a kernel because it is initialized during the kernel compilation stage. - // Each kernel need to store the pallet it was compiled with. - char palette[64] = {}; + std::shared_ptr brgemm_kernel = nullptr; }; -class BrgemmKernelExecutor : public CPUKernelExecutor { +class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor, + public CPUKernelExecutor { public: struct call_args { const void* A = nullptr; const void* B = nullptr; void* C = nullptr; void* scratch = nullptr; - amx_tile_config_t* amx_tile_config = nullptr; }; BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config); @@ -99,12 +64,10 @@ class BrgemmKernelExecutor : public CPUKernelExecutor compile_kernel(const BrgemmKernelConfig& c) const override; + void update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmKernelConfig& config) const override; - - static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id, - const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info); }; #define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmKernelExecutor::call_args, field) @@ -116,6 +79,7 @@ class BrgemmKernelReferenceExecutor : public BrgemmKernelExecutor { protected: std::shared_ptr compile_kernel(const BrgemmKernelConfig& c) const override; }; + struct brgemm_ref_kernel : public dnnl::impl::cpu::x64::brgemm_kernel_t { brgemm_ref_kernel(BrgemmKernelConfig c); void operator()(dnnl::impl::cpu::x64::brgemm_kernel_params_t *) const override; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp new file mode 100644 index 00000000000000..62c7236735f70e --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp @@ -0,0 +1,249 @@ +// Copyright (C) 2020-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "brgemm_amx.hpp" + +#include "transformations/snippets/x64/op/brgemm_utils.hpp" +#include 
"transformations/snippets/x64/op/brgemm_cpu.hpp" + +#include + + +#define INNER_K_BLK(dtype) static_cast((brgemm_utils::repacking::compute_inner_k_block(in0_dtype))) +#define VNNI_FACTOR(dtype) static_cast((brgemm_utils::compute_vnni_factor(in0_dtype))) +#define EQ(X) X == rhs.X +#define HASH(X) seed = hash_combine(seed, X) + + +using namespace Xbyak; +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::x64; + + +namespace ov { +namespace intel_cpu { + +BrgemmAMXKernelConfig::BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : BrgemmBaseKernelConfig(), m_static_params(std::make_shared(in0_dtype, in1_dtype, primitive_isa)) { + m_hash = compute_hash(); +} + +BrgemmAMXKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(INNER_K_BLK(in0_dtype), VNNI_FACTOR(in0_dtype))), + inner_k_blk(INNER_K_BLK(in0_dtype)), vnni_factor(VNNI_FACTOR(in0_dtype)) {} + +bool BrgemmAMXKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { + return StaticBaseParams::operator==(rhs) && EQ(inner_k_blk) && EQ(vnni_factor); +} + +size_t BrgemmAMXKernelConfig::StaticParams::compute_hash(dnnl_dim_t inner_k_blk, dnnl_dim_t vnni_factor) { + size_t seed = 0; + HASH(inner_k_blk); HASH(vnni_factor); + return seed; +} + +bool BrgemmAMXKernelConfig::need_copy_a(dnnl_dim_t K) const { + return K % get_vnni_factor() > 0; +} + +#ifdef SNIPPETS_DEBUG_CAPS +std::string BrgemmAMXKernelConfig::StaticParams::to_string() const { + std::stringstream ss; + ss << StaticBaseParams::to_string(); + ss << "inner_k_blk = " << inner_k_blk << "\n"; + ss << "vnni_factor = " << vnni_factor << "\n"; + return ss.str(); +} +#endif + +BrgemmAMXKernelExecutor::BrgemmAMXKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmAMXKernelConfig config) : + CPUKernelExecutor(std::move(kernel_cache), std::move(config)) {} + +namespace { +struct BrgemmCopyAKey { + BrgemmCopyAKey(cpu_isa_t isa, dnnl_data_type_t dt, dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t K_tail, dnnl_dim_t src_stride, dnnl_dim_t LDA) + : isa(isa), dt(dt), K{K}, K_blk{K_blk}, K_tail{K_tail}, src_stride{src_stride}, LDA{LDA} {} + + size_t hash() const { + size_t seed = 0; + HASH(isa); HASH(dt); HASH(K); HASH(K_blk); HASH(K_tail); HASH(src_stride); HASH(LDA); + return seed; + } + bool operator==(const BrgemmCopyAKey& rhs) const { + return EQ(isa) && EQ(dt) && EQ(K) && EQ(K_blk) && EQ(K_tail) && EQ(src_stride) && EQ(LDA); + } + + cpu_isa_t isa {cpu_isa_t::isa_undef}; + dnnl_data_type_t dt {dnnl_data_type_t::dnnl_data_type_undef}; + dnnl_dim_t K {0}, K_blk {0}, K_tail {0}, src_stride {0}, LDA {0}; +}; +} // namespace + +std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel(const BrgemmAMXKernelConfig& config) const { + std::shared_ptr compiled_kernel = std::make_shared(); + + // Brgemm is not executable - nothing to compile + if (config.is_empty()) + return compiled_kernel; + + const auto& cache = m_kernel_cache.lock(); + OPENVINO_ASSERT(cache, "Invalid kernel cache pointer in BrgemmAMXKernelExecutor::compile_kernel()"); + + auto brgemm_key = [&config](dnnl_dim_t K, dnnl_dim_t LDA, float beta) { + auto key = config; + key.update(config.get_M(), config.get_N(), K, LDA, config.get_LDB(), config.get_LDC(), beta); + return key; + }; + + auto brgemm_builder = [](const BrgemmAMXKernelConfig& k) { + 
std::shared_ptr ker = std::make_shared(); + create_brgemm_kernel(ker->brgemm_kernel, k.get_dt_in0(), k.get_dt_in1(), k.get_isa(), k.get_M(), k.get_N(), k.get_K(), + k.get_LDA(), k.get_LDB(), k.get_LDC(), k.get_beta(), true, ker->palette); + return ker; + }; + + auto brgemm_copy_a_builder = [](const BrgemmCopyAKey& k) { + std::shared_ptr ker {nullptr}; + create_brgemm_copy_a_kernel(ker, k.isa, k.dt, k.K, k.K_blk, k.K_tail, k.src_stride, k.LDA); + return ker; + }; + + auto K_tail = config.get_K() % config.get_inner_K_blk(); + auto K_body = config.get_K() - K_tail; + + float beta = config.get_beta(); + + // Brgemm Kernel for K_body + if (K_body != 0) { + const auto result = cache->getOrCreate(brgemm_key(K_body, config.get_LDA(), beta), brgemm_builder); + compiled_kernel->K_body_kernel = result.first; + beta = 1; + } + + // Brgemm Kernel for K_tail with BrgemmCopyA if needed + if (K_tail != 0) { + auto LDA = config.get_LDA(); + if (config.need_copy_a(K_tail)) { + const auto copy_A_src_stride = LDA * dnnl_data_type_size(config.get_dt_in0()); + K_tail = ov::snippets::utils::rnd_up(K_tail, config.get_vnni_factor()); + LDA = K_tail; + + const auto key = BrgemmCopyAKey(config.get_isa(), config.get_dt_in0(), config.get_K(), config.get_inner_K_blk(), K_tail, copy_A_src_stride, LDA); + const auto result = cache->getOrCreate(key, brgemm_copy_a_builder); + compiled_kernel->brgemm_copy_a_kernel = result.first; + } + + const auto result = cache->getOrCreate(brgemm_key(K_tail, LDA, beta), brgemm_builder); + compiled_kernel->K_tail_kernel = result.first; + } + + return compiled_kernel; +} + +void BrgemmAMXKernelExecutor::create_brgemm_copy_a_kernel(std::shared_ptr& kernel, + dnnl::impl::cpu::x64::cpu_isa_t isa, dnnl_data_type_t dt, + dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t K_tail, dnnl_dim_t src_stride, dnnl_dim_t LDA) { + matmul::brgemm_matmul_conf_t conf_; + conf_.src_tag = dnnl_abcd; // unused + conf_.K = K; + conf_.K_tail = K_tail; + conf_.K_blk = K_blk; + conf_.use_buffer_a_tail_only = false; + conf_.LDA = LDA; + conf_.has_zero_point_b = false; + conf_.s8s8_compensation_required = false; + conf_.wei_zp_type = dnnl::impl::cpu::x64::none; + conf_.src_zp_type = dnnl::impl::cpu::x64::none; + conf_.src_dt = dt; + conf_.copy_A_src_stride = src_stride; + conf_.a_dt_sz = dnnl_data_type_size(conf_.src_dt); + // copied A has the same precision of original + conf_.tr_a_dt_sz = dnnl_data_type_size(conf_.src_dt); + conf_.transposed_A = false; + conf_.isa = isa; + + std::unique_ptr brgemm_matmul_copy_a = nullptr; + OV_CPU_JIT_EMITTER_ASSERT(create_brgemm_matmul_copy_a(brgemm_matmul_copy_a, &conf_) == dnnl_success, + "Cannot create brgemm copy a kernel due to invalid params"); + kernel = std::move(brgemm_matmul_copy_a); +} + +void BrgemmAMXKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmAMXKernelConfig& config) const { + return BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); +} + +void BrgemmAMXKernelExecutor::configure_tiles_if_needed(amx_tile_config_t* config, const char* palette, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K) { + auto compatible = [&](amx_tile_config_t* rhs) { + return rhs && rhs->M == M && rhs->N == N && rhs->K == K; + }; + if (config && !compatible(config)) { + config->M = M; config->N = N; config->K = K; + cpu::x64::amx_tile_configure(palette); + } +} + +void BrgemmAMXKernelExecutor::execute_brgemm_copy_a_kernel(const std::shared_ptr& kernel, + const void* src, const void* tr_src, 
dnnl_dim_t M, dnnl_dim_t K) { + auto ctx = matmul::jit_brgemm_matmul_copy_a_t::ctx_t(); + + ctx.current_M_blk = M; + ctx.zp_b_compensation_buffer_ptr = nullptr; + ctx.zp_a_compensation_result_ptr = nullptr; + ctx.zp_b_neg_value_ptr = nullptr; + ctx.zp_ab_comp_ptr = nullptr; + ctx.src = src; + ctx.tr_src = tr_src; + ctx.current_K_start = 0; + ctx.current_K_blk = K; + + OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr brgemm_copy_a_kernel"); + (*kernel)(&ctx); +} + +void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, call_args* args) { + OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor"); + auto kernel = executor->get_kernel(); + const auto& config = static_cast(executor->get_config()); + OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr compiler kernel or invalid config"); + + const auto* src_ptr = args->A; + const auto* wei_ptr = args->B; + auto* scratch = args->scratch; + + const auto K_tail = config.get_K() % config.get_inner_K_blk(); + const auto K_body = config.get_K() - K_tail; + + if (K_body != 0) { + const auto& K_body_kernel = kernel->K_body_kernel; + configure_tiles_if_needed(args->amx_tile_config, K_body_kernel->palette, config.get_M(), config.get_N(), K_body); + execute_brgemm_kernel(K_body_kernel->brgemm_kernel, src_ptr, wei_ptr, args->C, scratch, false); + + src_ptr = src_ptr + K_body * dnnl_data_type_size(config.get_dt_in0()); + wei_ptr = wei_ptr + (K_body * config.get_LDB()) * dnnl_data_type_size(config.get_dt_in1()); + } + + if (K_tail != 0) { + if (config.need_copy_a(K_tail)) { + auto* tr_src = scratch + BrgemmCPU::SCRATCH_BYTE_SIZE; + + execute_brgemm_copy_a_kernel(kernel->brgemm_copy_a_kernel, src_ptr, tr_src, config.get_M(), K_tail); + src_ptr = tr_src; + } + + const auto& K_tail_kernel = kernel->K_tail_kernel; + configure_tiles_if_needed(args->amx_tile_config, K_tail_kernel->palette, config.get_M(), config.get_N(), K_tail); + execute_brgemm_kernel(K_tail_kernel->brgemm_kernel, src_ptr, wei_ptr, args->C, scratch, false); + } +} + +#undef INNER_K_BLK +#undef VNNI_FACTOR +#undef EQ +#undef HASH + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp new file mode 100644 index 00000000000000..a8544e5343b0ce --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp @@ -0,0 +1,102 @@ +// Copyright (C) 2020-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "brgemm_base.hpp" + +#include "emitters/plugin/x64/jit_emitter.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" + +#include +#include + + +namespace ov { +namespace intel_cpu { + +struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig { +public: + BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); + BrgemmAMXKernelConfig() = delete; + + std::unique_ptr get_clone_ptr() const override { + return std::unique_ptr(new BrgemmAMXKernelConfig(*this)); + } + + dnnl_dim_t get_inner_K_blk() const { return m_static_params->inner_k_blk; } + dnnl_dim_t get_vnni_factor() const { return m_static_params->vnni_factor; } + + bool need_copy_a(dnnl_dim_t K) const; + +private: + struct StaticParams : StaticBaseParams { + StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, 
+                     dnnl::impl::cpu::x64::cpu_isa_t primitive_isa);
+
+        const dnnl_dim_t inner_k_blk {0};
+        const dnnl_dim_t vnni_factor {0};
+
+        bool operator==(const StaticParams& rhs) const;
+        bool operator!=(const StaticParams& rhs) const { return !(*this == rhs); }
+#ifdef SNIPPETS_DEBUG_CAPS
+        std::string to_string() const;
+#endif
+    private:
+        static size_t compute_hash(dnnl_dim_t inner_k_blk, dnnl_dim_t vnni_factor);
+    };
+
+    std::shared_ptr<StaticBaseParams> get_static_params() const override { return m_static_params; }
+
+    std::shared_ptr<StaticParams> m_static_params {nullptr};
+};
+
+struct BrgemmAMXCompiledKernel {
+    struct BrgemmKernel {
+        std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgemm_kernel {nullptr};
+        // Note: Palette is treated as a part of a kernel because it is initialized during the kernel compilation stage.
+        // Each kernel needs to store the palette it was compiled with.
+        char palette[64] = {};
+    };
+
+    std::shared_ptr<BrgemmKernel> K_body_kernel {nullptr};
+    std::shared_ptr<BrgemmKernel> K_tail_kernel {nullptr};
+    std::shared_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t> brgemm_copy_a_kernel {nullptr};
+};
+
+class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor,
+                                public CPUKernelExecutor<BrgemmAMXKernelConfig, BrgemmAMXCompiledKernel> {
+public:
+    struct call_args {
+        const uint8_t* A = nullptr;
+        const uint8_t* B = nullptr;
+        void* C = nullptr;
+        uint8_t* scratch = nullptr;
+        amx_tile_config_t* amx_tile_config = nullptr;
+    };
+    BrgemmAMXKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmAMXKernelConfig config);
+
+    /** Function that will be called in runtime to execute the kernel */
+    static void execute(const BrgemmAMXKernelExecutor* executor, call_args* args);
+
+protected:
+    std::shared_ptr<BrgemmAMXCompiledKernel> compile_kernel(const BrgemmAMXKernelConfig& c) const override;
+
+    void update_config(const ov::snippets::lowered::ExpressionPtr& expr,
+                       const ov::snippets::lowered::LinearIRCPtr& linear_ir,
+                       BrgemmAMXKernelConfig& config) const override;
+
+    static void configure_tiles_if_needed(amx_tile_config_t* config, const char* palette, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K);
+
+    static void create_brgemm_copy_a_kernel(std::shared_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t>& kernel,
+                                            dnnl::impl::cpu::x64::cpu_isa_t isa, dnnl_data_type_t dt,
+                                            dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t K_tail, dnnl_dim_t src_stride, dnnl_dim_t LDA);
+
+    static void execute_brgemm_copy_a_kernel(const std::shared_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t>& kernel,
+                                             const void* src, const void* tr_src, dnnl_dim_t M, dnnl_dim_t K);
+};
+#define GET_OFF_BRGEMM_AMX_ARGS(field) offsetof(BrgemmAMXKernelExecutor::call_args, field)
+
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp
new file mode 100644
index 00000000000000..17b1f0e053b577
--- /dev/null
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp
@@ -0,0 +1,273 @@
+// Copyright (C) 2020-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "brgemm_base.hpp"
+
+#include "common/utils.hpp"
+#include "dnnl_extension_utils.h"
+#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
+#include "transformations/snippets/x64/op/brgemm_utils.hpp"
+
+#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
+#define DTYPE_CAST(X) static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(X))
+#define PRINT(X) ss << #X << " = " << X << "\n"
+#define EQ(X) X == rhs.X
+#define HASH(X) seed = hash_combine(seed, X)
+
+using namespace Xbyak;
+using namespace dnnl::impl;
+using namespace dnnl::impl::cpu::x64;
+
+namespace ov {
+namespace intel_cpu {
+
+bool BrgemmBaseKernelConfig::is_completed() const {
+    return !utils::one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC) || is_empty();
+}
+
+bool BrgemmBaseKernelConfig::is_empty() const {
+    return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta);
+}
+
+bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const {
+    return EQ(m_hash) && EQ(m_beta) &&
+           EQ(m_M) && EQ(m_N) && EQ(m_K) &&
+           EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) &&
+           (EQ(get_static_params()) || *get_static_params() == *(rhs.get_static_params()));
+}
+
+void BrgemmBaseKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta) {
+    // If M is zero, it means that Brgemm won't be executed (in a Loop with work_amount = 0, for example)
+    // To process this case, we have to make this Config empty (nullify the runtime parameters)
+    if (utils::one_of(0, M, N, K)) {
+        m_M = 0; m_N = 0; m_K = 0;
+        m_LDA = 0; m_LDB = 0; m_LDC = 0;
+        m_beta = 0;
+    } else {
+        m_M = M; m_N = N; m_K = K;
+        m_LDA = LDA; m_LDB = LDB; m_LDC = LDC;
+        m_beta = beta;
+    }
+    m_hash = compute_hash();
+}
+
+size_t BrgemmBaseKernelConfig::compute_hash() const {
+    size_t seed = get_static_params()->hash();
+    HASH(m_M); HASH(m_N); HASH(m_K);
+    HASH(m_LDA); HASH(m_LDB); HASH(m_LDC);
+    HASH(m_beta);
+    return seed;
+}
+
+BrgemmBaseKernelConfig::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype, const element::Type& in1_dtype,
+                                                           cpu_isa_t primitive_isa, size_t hash_seed)
+    : dt_in0(DTYPE_CAST(in0_dtype)), dt_in1(DTYPE_CAST(in1_dtype)), isa(primitive_isa), m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {}
+
+bool BrgemmBaseKernelConfig::StaticBaseParams::operator==(const StaticBaseParams& rhs) const {
+    return EQ(hash()) && EQ(dt_in0) && EQ(dt_in1) && EQ(isa);
+}
+
+size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, cpu_isa_t isa) {
+    size_t seed = hash_seed;
+    HASH(dt_in0); HASH(dt_in1); HASH(isa);
+    return seed;
+}
+
+#ifdef SNIPPETS_DEBUG_CAPS
+std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const {
+    std::stringstream ss;
+    PRINT(dt_in0); PRINT(dt_in1);
+    PRINT(isa);
+    return ss.str();
+}
+
+std::string BrgemmBaseKernelConfig::to_string() const {
+    std::stringstream ss;
+    ss << get_static_params()->to_string() << "\n";
+    PRINT(m_M); PRINT(m_N); PRINT(m_K);
+    PRINT(m_LDA); PRINT(m_LDB); PRINT(m_LDC);
+    PRINT(m_beta);
+    return ss.str();
+}
+#endif
+
+float BrgemmBaseKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id,
+                                         const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info) {
+    // Find all Expanded loops with the same Unified loop information -> they were decomposed from this Unified Loop.
+    // Note that LoopInfo are normalized and sorted (due to the NormalizedLoopIDs pass).
+    // It means that previously executed Loops have a Loop ID less than the current Loop ID.
+    // - If there is a Loop that is executed (work_amount > 0) and evaluated before the current one -> the current Brgemm should have `beta = 1`.
+    // - If there is no such Loop -> the current Brgemm should have `beta = 0`.
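In isolation, the rule above amounts to a scan backwards over the decomposed loops; the sketch below uses an illustrative stand-in struct, not the real ExpandedLoopInfo API:

    #include <cstddef>
    #include <vector>

    // Stand-in for ExpandedLoopInfo: each decomposed loop remembers which
    // Unified Loop it came from and how many iterations it will execute.
    struct LoopView {
        int unified_id;
        std::size_t work_amount;
    };

    // beta = 1 iff an earlier loop decomposed from the same Unified Loop
    // actually runs; otherwise the current Brgemm starts accumulation fresh.
    float pick_beta(const std::vector<LoopView>& loops, int current_id) {
        for (int id = current_id - 1; id >= 0; --id) {
            if (loops[id].unified_id != loops[current_id].unified_id)
                return 0.f;  // left the family of decomposed loops
            if (loops[id].work_amount > 0)
                return 1.f;  // a previous part of this K-loop already wrote into C
        }
        return 0.f;
    }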
+    if (loop_id > 0) {
+        const auto& current_unified_loop_info = current_expanded_loop_info->get_unified_loop_info();
+        // Check the previous Loops
+        --loop_id;
+        while (loop_id >= 0) {
+            const auto& expanded_loop_info = loop_manager->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_id);
+            if (expanded_loop_info->get_unified_loop_info() != current_unified_loop_info)
+                return 0;
+            if (expanded_loop_info->get_work_amount() > 0) {
+                // There is a previously executed Brgemm with `beta = 0` -> the current Brgemm should have `beta = 1`
+                return 1;
+            }
+            --loop_id;
+        }
+    }
+    return 0;
+}
+
+void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
+                                             const ov::snippets::lowered::LinearIRCPtr& linear_ir,
+                                             BrgemmBaseKernelConfig& config) {
+    const auto& input_pds = expr->get_input_port_descriptors();
+    const auto& output_pds = expr->get_output_port_descriptors();
+    OV_CPU_JIT_EMITTER_ASSERT((input_pds.size() == 2 || input_pds.size() == 3) && output_pds.size() == 1,
+                              "Invalid number of in/out port descriptors");
+
+    const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout());
+    const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout());
+    auto in0_subtensor = input_pds[0]->get_subtensor();
+    auto in1_subtensor = input_pds[1]->get_subtensor();
+
+    // Need to update M, K, N
+    // 1. If the original value in the subtensor is `FULL_DIM`, it means that
+    //    the Brgemm block should process the full tensor along this dimension -> take the dimension from the shape
+    // 2. Otherwise, the Brgemm block processes only a part of the tensor along this dimension
+    //    (there is blocking along this dimension) -> take it from the Loop increment
+
+    auto M = *++in0_subtensor.rbegin();
+    auto K = *in0_subtensor.rbegin();
+    auto N = *in1_subtensor.rbegin();
+
+    size_t loop_idx = 0;
+    const auto& loop_ids = expr->get_loop_ids();
+    const auto& loop_manager = linear_ir->get_loop_manager();
+    auto get_loop_info = [&]() {
+        OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missing");
+        return loop_manager->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_ids[loop_idx++]);
+    };
+
+    /* ------- Dimension M ----------*/
+    if (ov::snippets::utils::is_full_dim_value(M)) {
+        M = *++in0_shape.rbegin();
+    } else {
+        const auto& current_expanded_loop_info = get_loop_info();
+        const auto& in_ports = current_expanded_loop_info->get_input_ports();
+        const auto& out_ports = current_expanded_loop_info->get_output_ports();
+        // Quick validation check: Should we check that the port is really a Brgemm port?
+        // If BrgemmCopyB is inside the Loop by M, the first input port will be the BrgemmCopyB port with `incremented=false`;
+        // to avoid extra checks, we validate only the first input port.
+        // Note: We check the `is_incremented` attribute only for non-incremented ports because
+        // `is_incremented = true` can be changed by the `CleanRepeatedDataPointerShifts` optimization
+        auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 1; };
+        OPENVINO_ASSERT(in_ports.size() > 1 && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
+                        out_ports.size() == 1 && check_port(out_ports.back()),
+                        "Incorrect Loop by Brgemm dimension M");
+        M = current_expanded_loop_info->get_work_amount() > 0 ?
current_expanded_loop_info->get_increment() : 0;
+        input_pds[0]->set_subtensor_dim(1, M);
+        output_pds[0]->set_subtensor_dim(1, M);
+    }
+
+    /* ------- Dimension N ----------*/
+    if (ov::snippets::utils::is_full_dim_value(N)) {
+        N = *in1_shape.rbegin();
+    } else {
+        const auto& current_expanded_loop_info = get_loop_info();
+        const auto& in_ports = current_expanded_loop_info->get_input_ports();
+        const auto& out_ports = current_expanded_loop_info->get_output_ports();
+        // Quick validation check: Should we check that the port is really a Brgemm port?
+        // Note: We check the `is_incremented` attribute only for non-incremented ports because
+        // `is_incremented = true` can be changed by the `CleanRepeatedDataPointerShifts` optimization
+        auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 0; };
+        OPENVINO_ASSERT(in_ports.size() >= 2 && !in_ports.front().is_incremented && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
+                        out_ports.size() == 1 && check_port(out_ports.back()),
+                        "Incorrect Loop by Brgemm dimension N");
+        N = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
+        input_pds[1]->set_subtensor_dim(0, N);
+        output_pds[0]->set_subtensor_dim(0, N);
+    }
+
+    /* ------- Dimension K ----------*/
+    // 1. If the Brgemm block processes the full dimension K -> `beta = 0`
+    // 2. If the Brgemm block processes only a part of the dimension K (there is blocking), we need to find
+    //    the first executed Brgemm block among the Loops that iterate through dimension K (work_amount > 0).
+    //    The first of them will have `beta = 0`, the others `beta = 1`
+    float beta = 0;
+    if (ov::snippets::utils::is_full_dim_value(K)) {
+        K = *in0_shape.rbegin();
+    } else {
+        const auto& current_expanded_loop_info = get_loop_info();
+        const auto& in_ports = current_expanded_loop_info->get_input_ports();
+        const auto& out_ports = current_expanded_loop_info->get_output_ports();
+        // Quick validation check: Should we check that the port is really a Brgemm port?
+        // Note: We check the `is_incremented` attribute only for non-incremented ports because
+        // `is_incremented = true` can be changed by the `CleanRepeatedDataPointerShifts` optimization
+        OPENVINO_ASSERT(in_ports.size() >= 2 && in_ports.front().dim_idx == 0 && in_ports.back().dim_idx == 1 &&
+                        out_ports.size() == 1 && !out_ports.front().is_incremented,
+                        "Incorrect Loop by Brgemm dimension K");
+        K = current_expanded_loop_info->get_work_amount() > 0 ?
current_expanded_loop_info->get_increment() : 0; + input_pds[0]->set_subtensor_dim(0, K); + input_pds[1]->set_subtensor_dim(1, K); + if (K > 0) + beta = get_beta(loop_manager, static_cast(loop_ids.back()), current_expanded_loop_info); + } + + const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0))); + const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0))); + auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1))); + + const auto& brgemm_node = as_type_ptr(expr->get_node()); + OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); + // In case of data repacking LDB is chosen in accordance with repacking buffer size + if (with_repacking(brgemm_node->get_type())) + LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1))); + + config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta); +} + +void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr& kernel, dnnl_data_type_t dt0, dnnl_data_type_t dt1, + cpu_isa_t isa, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, + dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta, bool with_amx, char* palette) { + cpu::x64::brgemm_desc_t desc; + OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc, isa, cpu::x64::brgemm_strd, dt0, dt1, + false, false, cpu::x64::brgemm_row_major, 1.f, + beta, LDA, LDB, LDC, M, N, K, nullptr) == dnnl_success, + "Cannot initialize brgemm descriptor due to invalid params"); + + if (with_amx) { + OV_CPU_JIT_EMITTER_ASSERT(palette && brgemm_init_tiles(desc, palette) == dnnl_success, + "Cannot initialize brgemm tiles due to invalid params"); + } + + cpu::x64::brgemm_kernel_t* kernel_ = nullptr; + OV_CPU_JIT_EMITTER_ASSERT(brgemm_kernel_create(&kernel_, desc) == dnnl_success, "Cannot create brgemm kernel due to invalid params"); + kernel = std::unique_ptr(kernel_); +} + +void BrgemmBaseKernelExecutor::execute_brgemm_kernel(const std::shared_ptr& kernel, + const void* src, const void* wei, void* dst, void* scratch, bool with_comp) { + cpu::x64::brgemm_kernel_params_t brgemm_p; + brgemm_p.batch = nullptr; // default value + brgemm_p.ptr_A = src; + brgemm_p.ptr_B = wei; + brgemm_p.ptr_C = dst; + brgemm_p.ptr_D = dst; + brgemm_p.ptr_buf = scratch; + brgemm_p.ptr_bias = nullptr; + brgemm_p.do_post_ops = with_comp; + brgemm_p.do_apply_comp = with_comp; + brgemm_p.skip_accm = 0; + brgemm_p.BS = 1; // default value + OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel"); + (*kernel)(&brgemm_p); +} + +#undef DIM_CAST +#undef DTYPE_CAST +#undef PRINT +#undef EQ +#undef HASH + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp new file mode 100644 index 00000000000000..74a5c2b76daf65 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp @@ -0,0 +1,102 @@ +// Copyright (C) 2020-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/type/element_type.hpp" + +#include "cpu/x64/cpu_isa_traits.hpp" + +#include "emitters/plugin/x64/jit_emitter.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" +#include + +#include "snippets/lowered/loop_manager.hpp" +#include "snippets/lowered/loop_info.hpp" + +namespace ov { +namespace intel_cpu { + 
+struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConfig { +public: + BrgemmBaseKernelConfig() = default; + + bool is_completed() const override; + size_t hash() const override { return m_hash; } + + bool is_empty() const; + void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta); + + bool operator==(const BrgemmBaseKernelConfig& rhs) const; + bool operator!=(const BrgemmBaseKernelConfig& rhs) const {return !(*this == rhs);} + + dnnl_data_type_t get_dt_in0() const { return get_static_params()->dt_in0; } + dnnl_data_type_t get_dt_in1() const { return get_static_params()->dt_in1; } + + dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return get_static_params()->isa; } + float get_beta() const { return m_beta; } + + dnnl_dim_t get_M() const { return m_M; } + dnnl_dim_t get_N() const { return m_N; } + dnnl_dim_t get_K() const { return m_K; } + + dnnl_dim_t get_LDA() const { return m_LDA; } + dnnl_dim_t get_LDB() const { return m_LDB; } + dnnl_dim_t get_LDC() const { return m_LDC; } + +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const override; +#endif + +protected: + struct StaticBaseParams { + StaticBaseParams(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, size_t hash_seed); + virtual ~StaticBaseParams() = default; + + const dnnl_data_type_t dt_in0 {dnnl_f32}, dt_in1 {dnnl_f32}; + const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::isa_undef}; + + size_t hash() const { return m_hash; } + + bool operator==(const StaticBaseParams& rhs) const; + bool operator!=(const StaticBaseParams& rhs) const { return !(*this == rhs); } +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const; +#endif + protected: + static size_t compute_hash(size_t hash_seed, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, dnnl::impl::cpu::x64::cpu_isa_t isa); + + const size_t m_hash {0}; + }; + + virtual std::shared_ptr get_static_params() const = 0; + size_t compute_hash() const; + + dnnl_dim_t m_M {0}, m_N {0}, m_K {0}, m_LDA {0}, m_LDB {0}, m_LDC {0}; + float m_beta {0}; + size_t m_hash {SIZE_MAX}; +}; + +class BrgemmBaseKernelExecutor { +public: + virtual ~BrgemmBaseKernelExecutor() = default; +protected: + static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id, + const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info); + + static void update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmBaseKernelConfig& config); + + static void create_brgemm_kernel(std::shared_ptr& kernel, dnnl_data_type_t dt0, dnnl_data_type_t dt1, + dnnl::impl::cpu::x64::cpu_isa_t isa, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, + dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta, bool with_amx = false, char* palette = nullptr); + + static void execute_brgemm_kernel(const std::shared_ptr& kernel, const void* src, const void* wei, + void* dst, void* scratch, bool with_comp); +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp index 78563bc00aa228..269212edf1ab9b 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp @@ -11,6 +11,9 @@ #include "jit_brgemm_copy_b_emitter.hpp" #include "jit_kernel_emitter.hpp" #include 
"jit_snippets_emitters.hpp" +#include "kernel_executors/brgemm.hpp" +#include "kernel_executors/brgemm_amx.hpp" + #ifndef _WIN32 #include @@ -86,9 +89,12 @@ static std::string init_info_jit_store_memory_emitter(const jit_store_memory_emi std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter) { std::stringstream ss; - ss << "Emitter_type_name:jit_brgemm_emitter" - << emitter->m_kernel_executor->to_string() - << " m_memory_offset:" << vector_to_string(emitter->m_memory_offsets) + ss << "Emitter_type_name:jit_brgemm_emitter"; + if (const auto& common = std::dynamic_pointer_cast(emitter->m_kernel_executor)) + ss << common->to_string(); + if (const auto& amx = std::dynamic_pointer_cast(emitter->m_kernel_executor)) + ss << amx->to_string(); + ss << " m_memory_offset:" << vector_to_string(emitter->m_memory_offsets) << " m_buffer_ids:" << vector_to_string(emitter->m_buffer_ids); return ss.str(); diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index a23835d398cbe7..d5579fea23b6b1 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -32,7 +32,7 @@ #include "emitters/snippets/x64/cpu_generator.hpp" #include "transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp" #include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp" -#include "transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.hpp" +#include "transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.hpp" #include "transformations/snippets/x64/pass/remove_converts.hpp" #include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp" #include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp" @@ -694,7 +694,7 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const { SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::lowered::pass::InsertLoops, ov::intel_cpu::pass::FuseLoadStoreConvert); SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before, ov::snippets::lowered::pass::InsertBuffers, - ov::intel_cpu::pass::InsertBrgemmCopyBBuffers); + ov::intel_cpu::pass::InsertBrgemmCopyBuffers); #ifdef SNIPPETS_LIBXSMM_TPP SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before, ov::intel_cpu::pass::BrgemmCPUBlocking, diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp index 6a4fc83d409355..9088ced9c18649 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp @@ -44,18 +44,17 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) { #undef SUPPORT } -BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, const Dimension& K_dim, bool transpose_b) { +BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, bool transpose_b) { if (element_type_a == element::f32) return transpose_b ? 
BRGEMM_TYPE::REPACKING_ONLY : BRGEMM_TYPE::STAND_ALONE; OPENVINO_ASSERT(element_type_a != element::bf16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16), "BF16 precision is not supported on this hardware"); - const auto brgemmVNNIFactor = 4 / element_type_a.size(); if (one_of(element_type_a, element::u8, element::i8, element::bf16) && - dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) && - K_dim.is_static() && K_dim.get_length() % brgemmVNNIFactor == 0) + dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) return BRGEMM_TYPE::WITH_AMX; + // Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the // backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp if (element_type_a == ov::element::i8) @@ -87,6 +86,10 @@ size_t compute_inner_n_block(const ov::element::Type& precision) { } } +size_t compute_inner_k_block(const ov::element::Type& precision) { + return brgemm_utils::get_elems_in_vec(precision); +} + ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) { OPENVINO_ASSERT(ov::is_type(brgemm_expr->get_node()), "get_copy_b_expr must be called only for BrgemmCPU node"); const auto b_input_expr = brgemm_expr->get_input_port_connector(1)->get_source().get_expr(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp index 0d8e3f5fb6fc9b..672b67888eef9b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp @@ -23,7 +23,7 @@ enum class BRGEMM_TYPE { dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx); -BRGEMM_TYPE get_brgemm_type(const element::Type& element_type_a, const Dimension& K_dim, bool transpose_b); +BRGEMM_TYPE get_brgemm_type(const element::Type& element_type_a, bool transpose_b); inline bool stand_alone(BRGEMM_TYPE type) { return type == BRGEMM_TYPE::STAND_ALONE; } @@ -45,6 +45,8 @@ size_t get_elems_in_vec(const ov::element::Type& precision); namespace repacking { /// \brief Computes inner N block size used by OneDNN implementation. Depends on tensor precision size_t compute_inner_n_block(const ov::element::Type& precision); +/// \brief Computes inner K block size used by OneDNN implementation. 
Depends on tensor precision +size_t compute_inner_k_block(const ov::element::Type& precision); /** * @brief Computes leading dimension (LDB) which must be used in brgemm and brgemm_copy_b emitters * @param n_block N block size shared between BrgemmCPU and BrgemmCopyB node diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp index abb6147bac3588..50182765856777 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp @@ -60,15 +60,13 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto dimsMatMulIn0 = snippets::utils::get_planar_pshape(brgemm->input(0)); const auto dimsMatMulIn1 = snippets::utils::get_planar_pshape(brgemm->input(1)); - const auto K = *dimsMatMulIn0.rbegin(); - const auto& layout_a = brgemm_in0_desc->get_layout(); const auto& layout_b = brgemm_in1_desc->get_layout(); const auto& layout_c = brgemm_out_desc->get_layout(); const auto element_type_a = brgemm->get_input_element_type(0); const bool transpose_b = !layout_b.empty() && layout_b.back() != layout_b.size() - 1; - const auto brgemm_type = brgemm_utils::get_brgemm_type(element_type_a, K, transpose_b); + const auto brgemm_type = brgemm_utils::get_brgemm_type(element_type_a, transpose_b); const auto offset_a = brgemm->get_offset_a(); const auto offset_b = brgemm->get_offset_b(); const auto offset_c = brgemm->get_offset_c(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.cpp deleted file mode 100644 index bd8dd12bd39256..00000000000000 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (C) 2018-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "insert_brgemm_copy_b_buffers.hpp" - -#include "snippets/lowered/loop_manager.hpp" -#include "snippets/itt.hpp" - -#include "transformations/snippets/x64/op/brgemm_copy_b.hpp" -#include "expressions/brgemm_copy_b_buffer_expressions.hpp" - - -using namespace ov::intel_cpu::brgemm_utils::repacking; -using namespace ov::snippets::lowered; - -namespace ov { -namespace intel_cpu { -namespace pass { - -bool InsertBrgemmCopyBBuffers::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) { - OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::InsertBrgemmCopyBBuffers") - - const auto& factory = linear_ir.get_expr_factory(); - - auto insert_buffer = [&](const ExpressionPtr& copy_b_expr, size_t out_port, LinearIR::constExprIt insertion_pos) { - const auto& copy_b = ov::as_type_ptr(copy_b_expr->get_node()); - const auto& copy_b_out = copy_b_expr->get_output_port_connector(out_port); - const auto copy_b_consumers = copy_b_out->get_consumers(); - OPENVINO_ASSERT(copy_b_consumers.size() == 1, "BufferCopyB must have only one consumer on each out port - Brgemm"); - const auto& buffer_op = std::make_shared(copy_b->output(out_port)); - BufferExpressionPtr buffer_expr = nullptr; - if (out_port == 0) { - buffer_expr = factory->build(buffer_op, {copy_b_out}); - } else if (out_port == 1 && with_compensations(copy_b->get_type())) { - buffer_expr = factory->build(buffer_op, {copy_b_out}); - } else { - 
OPENVINO_THROW("BrgemmCopyB has incorrect output ports"); - } - return linear_ir.insert_expr(buffer_expr, LoopManager::get_common_outer_loops(copy_b_expr, copy_b_consumers.begin()->get_expr()), - true, insertion_pos, {copy_b_consumers}); - }; - - bool modified = false; - for (auto expr_it = begin; expr_it != end; ++expr_it) { - const auto expr = *expr_it; - if (auto copy_b = ov::as_type_ptr(expr->get_node())) { - for (size_t i = 0; i < expr->get_output_count(); ++i) { - expr_it = insert_buffer(expr, i, std::next(expr_it)); - } - modified = true; - } - } - return modified; -} - -} // namespace pass -} // namespace intel_cpu -} // namespace ov - diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.hpp deleted file mode 100644 index a08bc507aa60da..00000000000000 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (C) 2018-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "snippets/lowered/pass/pass.hpp" - -namespace ov { -namespace intel_cpu { -namespace pass { - -/** - * @interface InsertBrgemmCopyBBuffers - * @brief Insert Buffers after BrgemmCopyB with algorithm of allocation size calculation which - * distinguishes with common algorithm - * @ingroup snippets - */ -class InsertBrgemmCopyBBuffers: public snippets::lowered::pass::RangedPass { -public: - InsertBrgemmCopyBBuffers() = default; - OPENVINO_RTTI("InsertBrgemmCopyBBuffers", "Pass"); - bool run(snippets::lowered::LinearIR& linear_ir, snippets::lowered::LinearIR::constExprIt begin, snippets::lowered::LinearIR::constExprIt end) override; -}; - -} // namespace pass -} // namespace intel_cpu -} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp new file mode 100644 index 00000000000000..14134b1cd0980f --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "insert_brgemm_copy_buffers.hpp" + +#include "snippets/lowered/loop_manager.hpp" +#include "snippets/itt.hpp" + +#include "transformations/snippets/x64/op/brgemm_cpu.hpp" +#include "transformations/snippets/x64/op/brgemm_copy_b.hpp" +#include "expressions/brgemm_copy_b_buffer_expressions.hpp" + + +using namespace ov::intel_cpu::brgemm_utils::repacking; +using namespace ov::snippets::lowered; + +namespace ov { +namespace intel_cpu { +namespace pass { + +bool InsertBrgemmCopyBuffers::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) { + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::InsertBrgemmCopyBuffers") + + const auto& factory = linear_ir.get_expr_factory(); + + auto insert_copy_b_buffer = [&](const ExpressionPtr& copy_b_expr, size_t out_port, LinearIR::constExprIt insertion_pos) { + const auto& copy_b = ov::as_type_ptr(copy_b_expr->get_node()); + const auto& copy_b_out = copy_b_expr->get_output_port_connector(out_port); + const auto copy_b_consumers = copy_b_out->get_consumers(); + OPENVINO_ASSERT(copy_b_consumers.size() == 1, "BufferCopyB must have only one consumer on 
each out port - Brgemm"); + const auto& buffer_op = std::make_shared(copy_b->output(out_port)); + BufferExpressionPtr buffer_expr = nullptr; + if (out_port == 0) { + buffer_expr = factory->build(buffer_op, {copy_b_out}); + } else if (out_port == 1 && with_compensations(copy_b->get_type())) { + buffer_expr = factory->build(buffer_op, {copy_b_out}); + } else { + OPENVINO_THROW("BrgemmCopyB has incorrect output ports"); + } + return linear_ir.insert_expr(buffer_expr, LoopManager::get_common_outer_loops(copy_b_expr, copy_b_consumers.begin()->get_expr()), + true, insertion_pos, {copy_b_consumers}); + }; + + auto update_scratchpad = [](const ExpressionPtr& brgemm_expr, const BufferExpressionPtr& scratch_expr) { + OPENVINO_ASSERT(scratch_expr && scratch_expr->is_independent_memory(), "Incorrect Scratchpad buffer for Brgemm AMX"); + const auto src_dt = brgemm_expr->get_node()->get_input_element_type(0); + const auto in_subtensor = ov::snippets::utils::get_projected_subtensor(brgemm_expr->get_input_port(0)); + const auto shape0 = ov::snippets::utils::get_planar_vdims(brgemm_expr->get_input_port(0)); + const auto K_dim = shape0.back(); + const auto M_blk = *++in_subtensor.rbegin(); + OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(M_blk), "M blk cannot be dynamic!"); + + const auto vnni_factor = brgemm_utils::compute_vnni_factor(src_dt); + const auto inner_k_blk = brgemm_utils::repacking::compute_inner_k_block(src_dt); + const auto tile_scratch_size = BrgemmCPU::SCRATCH_BYTE_SIZE; + const auto current_scratch_size = scratch_expr->get_byte_size(); + OPENVINO_ASSERT(current_scratch_size == tile_scratch_size, + "Tile scratchpad for BrgemmAMX should have byte size ", tile_scratch_size); + size_t inner_k_size = 0; + if (ov::snippets::utils::is_dynamic_value(K_dim)) { + // In dynamic case we don't know exactly if we need repacking of MatMul first input. + // Because of that, we allocate maximum possible size for repacked data in compilation stage. + inner_k_size = inner_k_blk; + } else { + // In static case, we allocate buffer for repacked data only if we have to repack MatMul first input: + // only if `K_dim % inner_k_blk > 0` + const auto inner_k_tail = K_dim % inner_k_blk; + inner_k_size = inner_k_tail % vnni_factor > 0 ? 
ov::snippets::utils::rnd_up(inner_k_tail, vnni_factor) : 0;
+        }
+        const auto repacked_in0_size = M_blk * inner_k_size * src_dt.size();
+        scratch_expr->set_allocation_size(tile_scratch_size + repacked_in0_size);
+    };
+
+    bool modified = false;
+    for (auto expr_it = begin; expr_it != end; ++expr_it) {
+        const auto brgemm_expr = *expr_it;
+        if (const auto brgemm_cpu = ov::as_type_ptr<BrgemmCPU>(brgemm_expr->get_node())) {
+            if (brgemm_utils::with_repacking(brgemm_cpu->get_type())) {
+                // BrgemmCopyB might be extracted from the body
+                if (const auto copy_b_expr = brgemm_utils::repacking::get_copy_b_expr(brgemm_expr)) {
+                    auto insertion_it = std::next(linear_ir.find_before(expr_it, copy_b_expr));
+                    for (size_t i = 0; i < copy_b_expr->get_output_count(); ++i) {
+                        insertion_it = std::next(insert_copy_b_buffer(copy_b_expr, i, insertion_it));
+                    }
+                    modified = true;
+                }
+            }
+
+            if (brgemm_utils::with_amx(brgemm_cpu->get_type())) {
+                const auto& scratch_expr =
+                    ov::as_type_ptr<BufferExpression>(brgemm_expr->get_input_port_connector(2)->get_source().get_expr());
+                update_scratchpad(brgemm_expr, scratch_expr);
+                modified = true;
+            }
+        }
+    }
+    return modified;
+}
+
+}  // namespace pass
+}  // namespace intel_cpu
+}  // namespace ov
+
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.hpp
new file mode 100644
index 00000000000000..feca42ca3b8496
--- /dev/null
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.hpp
@@ -0,0 +1,29 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "snippets/lowered/pass/pass.hpp"
+
+namespace ov {
+namespace intel_cpu {
+namespace pass {
+
+/**
+ * @interface InsertBrgemmCopyBuffers
+ * @brief Insert Brgemm-specific buffers:
+ *        - after BrgemmCopyB, using an allocation-size calculation that differs from the common algorithm
+ *        - update the size of the `NewMemory` Buffer: add the allocation byte size for the data repacked from the first Brgemm input in the AMX scenario
+ * @ingroup snippets
+ */
+class InsertBrgemmCopyBuffers: public snippets::lowered::pass::RangedPass {
+public:
+    InsertBrgemmCopyBuffers() = default;
+    OPENVINO_RTTI("InsertBrgemmCopyBuffers", "0", snippets::lowered::pass::RangedPass);
+    bool run(snippets::lowered::LinearIR& linear_ir, snippets::lowered::LinearIR::constExprIt begin, snippets::lowered::LinearIR::constExprIt end) override;
+};
+
+}  // namespace pass
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp
index 764133d52a7fdd..b675a7c2da7d42 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp
@@ -563,18 +563,6 @@ std::vector<std::string> disabledTestPatterns() {
         retVector.emplace_back(R"(.*smoke_Snippets_MHA.*EnforceBF16.*)");
         retVector.emplace_back(R"(.*ConcatSDPTest.*bf16.*)");
     }
-    // [150842] Need to support dynamic K dimension of BF16|INT8 MatMul on AMX systems
-    if (ov::with_cpu_x86_avx512_core_amx()) {
-        retVector.emplace_back(R"(.*smoke_Snippets_MatMul/MatMul.CompareWithRefImpl/.*IS\[0\]=\[2.2.70.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
-        retVector.emplace_back(R"(.*smoke_Snippets_MatMul/MatMul.CompareWithRefImpl/.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
-        retVector.emplace_back(R"(.*smoke_Snippets_MatMulTransposeB.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
-        retVector.emplace_back(R"(.*smoke_Snippets_MatMulBias.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
-
-        retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeEnforceBF16_3D.*IS\[1\]=\[2.64.\?\].*)");
-        retVector.emplace_back(R"(.*smoke_Snippets_MHA.*BF16.*/MHA.*IS\[0\]=\[(\?|1).(\?|4).(\?|12).(\?|64)\].*)");
-        retVector.emplace_back(R"(.*smoke_Snippets_MHA.*BF16.*/MHA.*IS\[0\]=\[\?.\?.\?\].*)");
-        retVector.emplace_back(R"(.*smoke_Snippets_(MHAINT8MatMul|MHAQuantMatMul0|MHAFQAfterMatMul_4D|smoke_Snippets_MHAFQ).*IS\[0\]=\[\?.\?.\?\.\?].*)");
-    }
 #ifdef SNIPPETS_LIBXSMM_TPP
     // GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234)
     retVector.emplace_back(R"(.*smoke_Snippets_GroupNormalization.*)");
diff --git a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/buffer_allocation.cpp b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/buffer_allocation.cpp
index e31a8bebb95758..9ace85b3038afa 100644
--- a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/buffer_allocation.cpp
+++ b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/buffer_allocation.cpp
@@ -3,6 +3,7 @@
 //
 #include "openvino/opsets/opset.hpp"
+#include "openvino/runtime/system_conf.hpp"
 #include "snippets/snippets_isa.hpp"
 #include "snippets/lowered/linear_ir.hpp"
 #include "snippets/lowered/pass/mark_loops.hpp"
@@ -17,7 +18,7 @@
 #include "transformations/snippets/x64/shape_inference.hpp"
 #include "transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp"
-#include "transformations/snippets/x64/pass/lowered/insert_brgemm_copy_b_buffers.hpp"
+#include "transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.hpp"
 #include "transformations/snippets/x64/op/brgemm_cpu.hpp"
 #include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
@@ -90,7 +91,7 @@ class BufferAllocationCPUTest : public testing::TestWithParam
         pipeline.register_pass(m_vector_size);
         pipeline.register_pass();
         pipeline.register_pass();
-        pipeline.register_pass<ov::intel_cpu::pass::InsertBrgemmCopyBBuffers>();
+        pipeline.register_pass<ov::intel_cpu::pass::InsertBrgemmCopyBuffers>();
         pipeline.register_pass(m_vector_size);
         pipeline.register_pass();
@@ -255,6 +256,11 @@ TEST_P(MHAFP32BufferAllocationTest, BufferAllocationCPU) {
 }
 
 TEST_P(MHABF16AMXBufferAllocationTest, BufferAllocationCPU) {
+    // The scratchpad memory for AMX with CopyA (the dynamic case) has an allocation size that depends on the element count in a vector register.
+    // The current `expected_allocation_size` in the test therefore targets real AVX-512 platforms with 512-bit vector registers.
+    // If the test infrastructure only has AVX2, the allocation size will not match.
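For intuition, the width dependence reduces to a one-line relation (a sketch; the plugin's get_elems_in_vec may be implemented differently):

    #include <cstddef>

    // Elements that fit in one vector register; this is what drives compute_inner_k_block.
    constexpr std::size_t elems_in_vec(std::size_t vec_bytes, std::size_t dtype_bytes) {
        return vec_bytes / dtype_bytes;
    }
    static_assert(elems_in_vec(64, 2) == 32, "AVX-512 zmm, bf16");  // 512-bit registers
    static_assert(elems_in_vec(32, 2) == 16, "AVX2 ymm, bf16");     // 256-bit registers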
+ if (!with_cpu_x86_avx512_core()) + GTEST_SKIP(); Validate(); } @@ -363,7 +369,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_BufferAllocation_MHABF16AMXOptimizedWSpl ::testing::Values(dynamic_shapes), ::testing::Values(true), ::testing::Values(true), - ::testing::Values(32768), // only WSP buffers + ::testing::Values(34816), // only WSP buffers ::testing::Values(3), ::testing::Values(7)), BufferAllocationCPUTest::getTestCaseName); diff --git a/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp b/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp index 767128a5be2950..c4c27161b89fe4 100644 --- a/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp @@ -186,8 +186,15 @@ struct gemm_onednn : typed_primitive_onednn_impl { if (ret) { tag = convert_data_format(transposed_format); dnnl::memory::dims original_dims = dims; - for (size_t i = 0; i < original_dims.size(); ++i) { - dims[i] = original_dims[order[i]]; + if (is_input) { + for (size_t i = 0; i < original_dims.size(); ++i) { + dims[i] = original_dims[order[i]]; + } + } else { + // Get non-transposed dims for output dims + for (size_t i = 0; i < original_dims.size(); ++i) { + dims[order[i]] = original_dims[i]; + } } } else { std::ostringstream ostream; diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/transpose_matmul_fusion.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/transpose_matmul_fusion.cpp index b55c9e00bdab64..cc28dbab3660b9 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/transpose_matmul_fusion.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/transpose_matmul_fusion.cpp @@ -96,3 +96,95 @@ TEST_P(TransposeMatMulFusionOnGPU, CompareWithRefs){ }; } // namespace + + +//================================================================================= +// Transpose + MatMul + Transpose pattern fusion (TransposeMatMulTransposeMatcher) +//================================================================================= +namespace ov { +namespace test { + +using MatMulTransposeFusionParams = std::tuple; // input C shapes +class MatMulTransposeFusionOnGPU: public testing::WithParamInterface, + virtual public ov::test::SubgraphBaseTest { +public: + static std::string getTestCaseName(testing::TestParamInfo obj) { + ov::PartialShape input0; + ov::PartialShape input1; + ov::PartialShape input2; + + std::tie(input0, input1, input2) = obj.param; + + std::ostringstream result; + result << "device=(" << std::string(utils::DEVICE_GPU) << ")_"; + result << ov::test::utils::partialShape2str({input0}) << "_"; + result << ov::test::utils::partialShape2str({input1}) << "_"; + result << ov::test::utils::partialShape2str({input2}) << "_"; + return result.str(); + } +protected: + void SetUp() override { + targetDevice = ov::test::utils::DEVICE_GPU; + + ov::PartialShape shape1; + ov::PartialShape shape2; + ov::PartialShape shape3; + + std::tie(shape1, shape2, shape3) = GetParam(); + + InputShape input_shape1 = {shape1, {shape1.get_shape()}}; + InputShape input_shape2 = {shape2, {shape2.get_shape()}}; + InputShape input_shape3 = {shape3, {shape3.get_shape()}}; + init_input_shapes({input_shape1, input_shape2, input_shape3}); + + const auto param1 = std::make_shared(ov::element::f16, shape1); + const auto param2 = std::make_shared(ov::element::f16, shape2); + const auto param3 = std::make_shared(ov::element::f16, 
shape3); + + auto input2_shape = shape2.get_shape(); + + //input0 + const auto input0_order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {1, 0, 2, 3}); + const auto input0_transpose = std::make_shared(param1, input0_order); + const auto input0_shape_pattern = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, input2_shape); + const auto input0_reshape = std::make_shared(input0_transpose, input0_shape_pattern, false); + + //input1 + const auto input1_order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {0, 1, 3, 2}); + const auto input1_transpose = std::make_shared(param2, input1_order); + + // matmul & softmax + const auto matmul1 = std::make_shared(input0_reshape, input1_transpose, false, false); + const auto softmax = std::make_shared(matmul1, -1); + + // input3 + const auto input3_transpose = std::make_shared(param3, input0_order); + const auto input3_shape_pattern = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, input2_shape); + const auto input3_reshape = std::make_shared(input3_transpose, input3_shape_pattern, false); + + // target matmul + const auto matmul2 = std::make_shared(softmax, input3_reshape, false, false); + const auto order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {2, 0, 1, 3}); + const auto transpose = std::make_shared(matmul2, order); + + function = std::make_shared(transpose, ov::ParameterVector{param1, param2, param3}); + } +}; + + +} // namespace test +} // namespace ov + + +namespace { +INSTANTIATE_TEST_SUITE_P(smoke_MatMulTransposeFusion, MatMulTransposeFusionOnGPU, + ::testing::Values( + MatMulTransposeFusionParams({3, 8, 16, 1}, {2, 4, 3, 16}, {3, 8, 16, 1})), + MatMulTransposeFusionOnGPU::getTestCaseName); + +TEST_P(MatMulTransposeFusionOnGPU, CompareWithRefs){ + run(); +}; +} // namespace diff --git a/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py b/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py index 4de252e40442c4..5814c1a5427b9b 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py @@ -54,17 +54,13 @@ def create_add_placeholder_const_net(self, x_shape, y_shape, op_type): 'Maximum': tf.raw_ops.Maximum, 'Minimum': tf.raw_ops.Minimum, 'Mod': tf.raw_ops.Mod, - 'LogicalAnd': tf.raw_ops.LogicalAnd, - 'LogicalOr': tf.raw_ops.LogicalOr, 'FloorMod': tf.raw_ops.FloorMod, 'FloorDiv': tf.raw_ops.FloorDiv, 'Xdivy': tf.raw_ops.Xdivy, } input_type = np.float32 - if op_type in ["LogicalAnd", "LogicalOr", "LogicalXor"]: - input_type = bool - elif op_type in ['Pow']: + if op_type in ['Pow']: input_type = np.int32 self.input_type = input_type @@ -89,8 +85,7 @@ def create_add_placeholder_const_net(self, x_shape, y_shape, op_type): @pytest.mark.parametrize('y_shape', [[4], [2, 3, 4]]) @pytest.mark.parametrize("op_type", ['Add', 'AddV2', 'Sub', 'Mul', 'Div', 'RealDiv', 'SquaredDifference', 'Pow', - 'Maximum', 'Minimum', 'Mod', 'LogicalAnd', 'LogicalOr', 'FloorMod', - 'FloorDiv', 'Xdivy']) + 'Maximum', 'Minimum', 'Mod', 'FloorMod', 'FloorDiv', 'Xdivy']) @pytest.mark.nightly @pytest.mark.precommit @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', diff --git a/tests/layer_tests/tensorflow_tests/test_tf_LogicalBinaryOps.py b/tests/layer_tests/tensorflow_tests/test_tf_LogicalBinaryOps.py new file mode 100644 index 00000000000000..e89dc96fedc7c6 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_LogicalBinaryOps.py @@ -0,0 +1,54 @@ +# Copyright (C) 2018-2024 Intel Corporation +# 
SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + +rng = np.random.default_rng(23345) + + +class TestLogicalBinaryOps(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'x:0' in inputs_info, "Test error: inputs_info must contain `x`" + assert 'y:0' in inputs_info, "Test error: inputs_info must contain `y`" + x_shape = inputs_info['x:0'] + y_shape = inputs_info['y:0'] + + inputs_data = {} + inputs_data['x:0'] = rng.choice([True, False], x_shape).astype(bool) + inputs_data['y:0'] = rng.choice([True, False], y_shape).astype(bool) + return inputs_data + + def create_logical_binary_ops_net(self, x_shape, y_shape, op_type): + op_type_map = { + 'LogicalAnd': tf.raw_ops.LogicalAnd, + 'LogicalOr': tf.raw_ops.LogicalOr, + } + + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(bool, x_shape, 'x') + y = tf.compat.v1.placeholder(bool, y_shape, 'y') + op_type_map[op_type](x=x, y=y, name=op_type) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + @pytest.mark.parametrize('x_shape', [[], [4], [3, 4], [2, 3, 4]]) + @pytest.mark.parametrize('y_shape', [[2, 3, 4]]) + @pytest.mark.parametrize("op_type", ['LogicalAnd', 'LogicalOr']) + @pytest.mark.nightly + @pytest.mark.precommit + def test_logical_binary_op(self, x_shape, y_shape, op_type, + ie_device, precision, ir_version, + temp_dir, use_legacy_frontend): + self._test(*self.create_logical_binary_ops_net(x_shape=x_shape, y_shape=y_shape, op_type=op_type), + ie_device, precision, ir_version, + temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) diff --git a/tests/model_hub_tests/pytorch/detectron2_precommit b/tests/model_hub_tests/pytorch/detectron2_precommit index 155e4d2a359779..f98e44ad21871f 100644 --- a/tests/model_hub_tests/pytorch/detectron2_precommit +++ b/tests/model_hub_tests/pytorch/detectron2_precommit @@ -1,13 +1,8 @@ -COCO-Detection/faster_rcnn_R_50_C4_1x,none -COCO-Detection/faster_rcnn_R_50_DC5_3x,none COCO-Detection/faster_rcnn_R_50_FPN_1x,none COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x,none COCO-Detection/retinanet_R_50_FPN_1x,none COCO-Detection/rpn_R_50_C4_1x,none -COCO-Detection/rpn_R_50_FPN_1x,none COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none -COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x,none -COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x,none COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x,none COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x,none COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x,none @@ -19,8 +14,6 @@ LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x,none Misc/cascade_mask_rcnn_R_50_FPN_3x,none Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv,none -Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5,none -Misc/mask_rcnn_R_50_FPN_3x_gn,none Misc/mask_rcnn_R_50_FPN_3x_syncbn,none Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn,none PascalVOC-Detection/faster_rcnn_R_50_C4,none diff --git a/tests/model_hub_tests/pytorch/test_llm.py b/tests/model_hub_tests/pytorch/test_llm.py index e444f93db9d7ec..ba48634a070e39 100644 --- a/tests/model_hub_tests/pytorch/test_llm.py +++ b/tests/model_hub_tests/pytorch/test_llm.py @@ -15,10 +15,10 @@ from torch_utils import TestTorchConvertModel -def is_gptq_model(config): +def is_quantized_model(config): config_dict = config.to_dict() if 
not isinstance(config, dict) else config quantization_config = config_dict.get("quantization_config", None) - return quantization_config and quantization_config["quant_method"] == "gptq" + return quantization_config and quantization_config["quant_method"] in ["gptq", "awq"] def patch_gptq(): @@ -26,35 +26,83 @@ def patch_gptq(): orig_cuda_is_bf16_supported = torch.cuda.is_bf16_supported orig_cuda_get_device_capability = torch.cuda.get_device_capability orig_post_init_model = None + orig_gemm_forward = None torch.set_default_dtype(torch.float32) torch.cuda.is_available = lambda: True torch.cuda.is_bf16_supported = lambda: False torch.cuda.get_device_capability = lambda n: (9, 1) - from optimum.gptq import GPTQQuantizer + try: + from optimum.gptq import GPTQQuantizer - orig_post_init_model = GPTQQuantizer.post_init_model + orig_post_init_model = GPTQQuantizer.post_init_model - def post_init_model(self, model): - from auto_gptq import exllama_set_max_input_length + def post_init_model(self, model): + from auto_gptq import exllama_set_max_input_length - class StoreAttr(object): - pass + class StoreAttr(object): + pass - model.quantize_config = StoreAttr() - model.quantize_config.desc_act = self.desc_act - if self.desc_act and not self.disable_exllama and self.max_input_length is not None: - model = exllama_set_max_input_length(model, self.max_input_length) - return model + model.quantize_config = StoreAttr() + model.quantize_config.desc_act = self.desc_act + if self.desc_act and not self.disable_exllama and self.max_input_length is not None: + model = exllama_set_max_input_length(model, self.max_input_length) + return model + + GPTQQuantizer.post_init_model = post_init_model + except ImportError: + pass + + try: + # patch GEMM module to work without CUDA GPU + from awq.modules.linear.gemm import WQLinearMMFunction + from awq.utils.packing_utils import dequantize_gemm + + def new_forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + ): + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) - GPTQQuantizer.post_init_model = post_init_model - return (orig_cuda_is_available, orig_cuda_is_bf16_supported, orig_cuda_get_device_capability), orig_post_init_model + out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) + out = torch.matmul(x, out) + out = out + bias if bias is not None else out + out = out.reshape(out_shape) -def unpatch_gptq(orig_cuda_check, orig_post_init_model): - from optimum.gptq import GPTQQuantizer + if len(out.shape) == 2: + out = out.unsqueeze(0) + return out + + orig_gemm_forward = WQLinearMMFunction.forward + WQLinearMMFunction.forward = new_forward + except ImportError: + pass + return (orig_cuda_is_available, orig_cuda_is_bf16_supported, orig_cuda_get_device_capability), orig_post_init_model, orig_gemm_forward + + +def unpatch_gptq(orig_cuda_check, orig_post_init_model, orig_gemm_forward): torch.cuda.is_available, torch.cuda.is_bf16_supported, torch.cuda.get_device_capability = orig_cuda_check - GPTQQuantizer.post_init_model = orig_post_init_model + try: + from optimum.gptq import GPTQQuantizer + GPTQQuantizer.post_init_model = orig_post_init_model + except ImportError: + pass + try: + from awq.modules.linear.gemm import WQLinearMMFunction + WQLinearMMFunction.forward = orig_gemm_forward + except ImportError: + pass def to_numpy(t): @@ -88,7 +136,7 @@ def flattenize_outputs(outputs): class TestLLMModel(TestTorchConvertModel): def setup_class(self): 
self.infer_timeout = 1800 - self.cuda_available, self.gptq_postinit = None, None + self.cuda_available, self.gptq_postinit, self.orig_gemm_forward = None, None, None @retry(3, exceptions=(OSError,), delay=1) def load_model(self, name, type): @@ -99,11 +147,12 @@ def load_model(self, name, type): except Exception: config = {} model_kwargs = {"torchscript": True, "trust_remote_code": True} - is_gptq = is_gptq_model(config) + is_quant = is_quantized_model(config) is_gpt2 = name == "openai-community/gpt2" - if is_gptq: - self.cuda_available, self.gptq_postinit = patch_gptq() + if is_quant: + self.cuda_available, self.gptq_postinit, self.orig_gemm_forward = patch_gptq() + model_kwargs["torch_dtype"] = "auto" model_kwargs["torch_dtype"] = torch.float32 self.ov_config = {"DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"} elif is_gpt2: @@ -113,7 +162,7 @@ def load_model(self, name, type): t = AutoTokenizer.from_pretrained(name, trust_remote_code=True) self.model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs) - if is_gptq: + if is_quant: model = self.model else: assert self.model.config.torch_dtype in [ @@ -175,8 +224,8 @@ def convert_model_impl(self, model_obj): def teardown_method(self): # restore after gptq patching if self.cuda_available is not None: - unpatch_gptq(self.cuda_available, self.gptq_postinit) - self.cuda_available, self.gptq_postinit = None, None + unpatch_gptq(self.cuda_available, self.gptq_postinit, self.orig_gemm_forward) + self.cuda_available, self.gptq_postinit, self.orig_gemm_forward = None, None, None super().teardown_method() @staticmethod @@ -191,7 +240,8 @@ def get_pkv(model, tokenizer): @pytest.mark.parametrize("type,name", [ ("opt_gptq", "katuni4ka/opt-125m-gptq"), ("llama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"), - ("gpt2", "openai-community/gpt2") + ("gpt2", "openai-community/gpt2"), + ("llama_awq", "casperhansen/tinyllama-1b-awq") ]) @pytest.mark.precommit @pytest.mark.nightly @@ -210,6 +260,7 @@ def test_convert_model_precommit(self, name, type, ie_device): ("bloom_gptq", "sbolouki/bloom-1b7-gptq"), ("cohere_gptq", "shuyuej/aya-23-8B-GPTQ"), ("mbart_gptq", "Shivam098/opt-translation"), + ("llama_awq", "TheBloke/open-llama-3b-v2-wizard-evol-instuct-v2-196k-AWQ") ]) @pytest.mark.nightly def test_convert_model_nightly(self, name, type, ie_device): @@ -236,6 +287,8 @@ def test_convert_model_nightly(self, name, type, ie_device): marks=pytest.mark.xfail(reason="GPTQ QUANT_TYPE=cuda is not supported")), pytest.param("llama3_gptq", "TechxGenus/Meta-Llama-3-8B-GPTQ", marks=pytest.mark.xfail(reason="GPTQ QUANT_TYPE=cuda is not supported")), + ("qwen2_awq", "Qwen/Qwen2.5-Coder-32B-Instruct-AWQ"), + ("mixstral_awq", "TheBloke/SauerkrautLM-Mixtral-8x7B-AWQ"), ]) def test_convert_model_very_large(self, name, type, ie_device): self.run(model_name=name, model_link=type, ie_device=ie_device) diff --git a/tests/requirements_pytorch b/tests/requirements_pytorch index 56446beba12600..be304155e2afc0 100644 --- a/tests/requirements_pytorch +++ b/tests/requirements_pytorch @@ -19,6 +19,7 @@ pytest-html==4.1.1 pytest-xdist[psutil]==3.6.1 defusedxml==0.7.1 +autoawq==0.2.7; platform_system == "Linux" and platform_machine == "x86_64" auto-gptq==0.7.1; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12" av==13.0.0 basicsr==1.4.2; python_version < "3.12"
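As a closing illustration of the AMX changes above: the scratchpad growth computed by InsertBrgemmCopyBuffers can be reproduced with a few lines of arithmetic. This is a self-contained sketch; inner_k_blk, the M block, and the tile-scratch size below are illustrative assumptions (the real values come from BrgemmCPU::SCRATCH_BYTE_SIZE and compute_inner_k_block):

    #include <cstddef>
    #include <cstdio>

    std::size_t rnd_up(std::size_t v, std::size_t step) {
        return ((v + step - 1) / step) * step;
    }

    // tile_scratch + (M_blk x padded-K-tail) bytes for the repacked A tail
    std::size_t amx_scratch_bytes(std::size_t K, std::size_t inner_k_blk, std::size_t vnni_factor,
                                  std::size_t M_blk, std::size_t dtype_bytes, std::size_t tile_scratch) {
        const std::size_t k_tail = K % inner_k_blk;
        // Input A is repacked only when the K tail is not VNNI-aligned
        const std::size_t inner_k = (k_tail % vnni_factor > 0) ? rnd_up(k_tail, vnni_factor) : 0;
        return tile_scratch + M_blk * inner_k * dtype_bytes;
    }

    int main() {
        // bf16: 2 bytes, vnni_factor = 2; assume inner_k_blk = 32, M block = 32, 1 KiB tile scratch
        std::printf("%zu\n", amx_scratch_bytes(71, 32, 2, 32, 2, 1024));  // tail = 7 -> padded to 8 -> 1024 + 512
        std::printf("%zu\n", amx_scratch_bytes(64, 32, 2, 32, 2, 1024));  // tail = 0 -> no repacking -> 1024
        return 0;
    }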