From 7b0504c46285b3b0319555a9d78b80cba86b364a Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Thu, 25 Jul 2024 08:38:10 +0400 Subject: [PATCH] [Snippets] Moved UpdateBrgemms to KernelExecutor::update_config --- .../snippets/kernel_executor_table.hpp | 12 ++-- .../include/snippets/runtime_configurator.hpp | 16 ++--- src/common/snippets/src/op/subgraph.cpp | 4 +- .../snippets/src/runtime_configurator.cpp | 48 +++++++------ .../snippets/cpu_runtime_configurator.cpp | 70 +++---------------- .../snippets/cpu_runtime_configurator.hpp | 20 ++---- .../snippets/x64/kernel_executors/brgemm.cpp | 64 ++++++++++------- .../snippets/x64/kernel_executors/brgemm.hpp | 4 +- 8 files changed, 97 insertions(+), 141 deletions(-) diff --git a/src/common/snippets/include/snippets/kernel_executor_table.hpp b/src/common/snippets/include/snippets/kernel_executor_table.hpp index 46f9cd04b923ba..cf6a6b2174c5c1 100644 --- a/src/common/snippets/include/snippets/kernel_executor_table.hpp +++ b/src/common/snippets/include/snippets/kernel_executor_table.hpp @@ -43,7 +43,7 @@ class KernelExecutorBase { * @brief Update current kernel config in accordance with the passed expression. Corresponding kernel is recompiled if necessary. * This method should be called to update KernelExecutor based on runtime info (e.g. shapes) available through expression ptr */ - virtual void update_by_expression(const lowered::ExpressionPtr& expr) = 0; + virtual void update_by_expression(const lowered::ExpressionPtr& expr, const lowered::LinearIR& linear_ir) = 0; /** * @brief Replace current kernel config with the provided value. Corresponding kernel is recompiled if necessary. * This method should be called to restore a saved state of the executor, that was configured using update_by_expression(). @@ -70,8 +70,8 @@ class KernelExecutor : public KernelExecutorBase { explicit KernelExecutor(Conf c) : KernelExecutorBase(), m_config{std::move(c)} {} // Note: override when final is redundant, but needed to avoid warnings on some compilers - void update_by_expression(const lowered::ExpressionPtr& expr) override final { // NOLINT - update_config(expr, m_config); + void update_by_expression(const lowered::ExpressionPtr& expr, const lowered::LinearIR& linear_ir) override final { // NOLINT + update_config(expr, linear_ir, m_config); OPENVINO_ASSERT(m_config.is_completed(), "Failed to update kernel config in update_by_expression"); update_kernel(m_config, m_kernel); OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor"); @@ -103,7 +103,7 @@ class KernelExecutor : public KernelExecutorBase { protected: /*** Updates stored kernel config based on runtime info from expression (e.g. new input shapes). */ - virtual void update_config(const lowered::ExpressionPtr& expr, Conf& config) const = 0; + virtual void update_config(const lowered::ExpressionPtr& expr, const lowered::LinearIR& linear_ir, Conf& config) const = 0; /*** Updates stored kernel in accordance with the passed config. Recompilation of the kernel is * performed if necessary. */ virtual void update_kernel(const Conf& c, std::shared_ptr& kernel) const = 0; @@ -130,9 +130,9 @@ class KernelExecutorTable { return m_table.at(expr); } /*** Updates every registered KernelExecutor in accordance with the corresponding expression */ - void update_state() const { + void update_state(const lowered::LinearIR& linear_ir) const { for (const auto& record : m_table) - record.second->update_by_expression(record.first); + record.second->update_by_expression(record.first, linear_ir); } /*** Returns lambda function that contains current state of the table, and restores this state when called */ diff --git a/src/common/snippets/include/snippets/runtime_configurator.hpp b/src/common/snippets/include/snippets/runtime_configurator.hpp index 059771d961df82..ed56088f16dd2a 100644 --- a/src/common/snippets/include/snippets/runtime_configurator.hpp +++ b/src/common/snippets/include/snippets/runtime_configurator.hpp @@ -61,7 +61,7 @@ class RuntimeConfigurator { * @param linear_ir LinearIR * @return updated config */ - const std::shared_ptr& get_updated_config(const std::shared_ptr& linear_ir); + const std::shared_ptr& get_updated_config(const lowered::LinearIR& linear_ir); /*** Returns pointer to KernelExecutorTable owned by the config */ const std::shared_ptr& get_kernel_executor_table() const { return m_config->kernel_executor_table; } @@ -70,19 +70,19 @@ class RuntimeConfigurator { * @brief Update RuntimeConfig based on LinearIR * @param linear_ir LinearIR */ - virtual void update(const std::shared_ptr& linear_ir); + virtual void update(const lowered::LinearIR& linear_ir); /** * @brief Allocate and intialize fields in RuntimeConfig and RuntimeConfigurator * @param linear_ir LinearIR */ - virtual void initialization(const std::shared_ptr& linear_ir); + virtual void initialization(const lowered::LinearIR& linear_ir); /** * @brief Initializes input and data information of LinearIR: * descriptors (that contains shapes and layouts) and data_sizes * @param linear_ir LinearIR */ - void init_data_info(const std::shared_ptr& linear_ir); + void init_data_info(const lowered::LinearIR& linear_ir); /** * @brief Initializes information of buffers: * - static buffer_scratchpad_size @@ -90,23 +90,23 @@ class RuntimeConfigurator { * - clusters with dynamic buffers (`m_dynamic_buffer_clusters`) for the quick access in `update()` * @param linear_ir LinearIR */ - void init_buffer_info(const std::shared_ptr& linear_ir); + void init_buffer_info(const lowered::LinearIR& linear_ir); /** * @brief Initializes tensor rank of config * @param linear_ir LinearIR */ - virtual void init_tensor_rank(const std::shared_ptr& linear_ir) const; + virtual void init_tensor_rank(const lowered::LinearIR& linear_ir) const; /** * @brief Update Loop informations in LinearIR: Unified and ExpandedLoopInfo * @param linear_ir LinearIR */ - void update_loop_info(const std::shared_ptr& linear_ir) const; + void update_loop_info(const lowered::LinearIR& linear_ir) const; /** * @brief Update Buffer scratchpad size and offsets if needed * Note: `update_loop_info` must be called before * @param linear_ir LinearIR */ - void update_buffer_scratchpad_size(const std::shared_ptr& linear_ir) const; + void update_buffer_scratchpad_size(const lowered::LinearIR& linear_ir) const; /** * @brief Calculate data offsets of LinearIR and update these values in RuntimeConfig */ diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index a33d478ee3929d..598ca30b7cd077 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -552,14 +552,14 @@ snippets::Schedule Subgraph::generate(const void* compile_params) const { exec_table->replace_key_expression(expression_map.at(expr.get()), expr); // Some kernel executors might've been registered during code emission. // We need to update them, so appropriate kernels will be compiled. - exec_table->update_state(); + exec_table->update_state(*m_linear_ir); return {std::move(lowering_result)}; } const std::shared_ptr& Subgraph::update_runtime_config() const { OPENVINO_ASSERT(m_generator, "Generator has not been inited!"); OPENVINO_ASSERT(m_linear_ir, "LoweredLinearIR has not been inited!"); - return m_generator->get_target_machine()->get_runtime_configurator()->get_updated_config(m_linear_ir); + return m_generator->get_target_machine()->get_runtime_configurator()->get_updated_config(*m_linear_ir); } void Subgraph::print() const { diff --git a/src/common/snippets/src/runtime_configurator.cpp b/src/common/snippets/src/runtime_configurator.cpp index 8a1eb1bfa65f78..7c4c7bf1c7bd18 100644 --- a/src/common/snippets/src/runtime_configurator.cpp +++ b/src/common/snippets/src/runtime_configurator.cpp @@ -35,7 +35,7 @@ RuntimeConfigurator::RuntimeConfigurator(std::shared_ptr c) : OPENVINO_ASSERT(m_config, "Runtime config is nullptr!"); } -const std::shared_ptr& RuntimeConfigurator::get_updated_config(const std::shared_ptr& linear_ir) { +const std::shared_ptr& RuntimeConfigurator::get_updated_config(const lowered::LinearIR& linear_ir) { // First initialization if (m_io_num == 0) initialization(linear_ir); @@ -44,7 +44,7 @@ const std::shared_ptr& RuntimeConfigurator::get_updated_config(co return m_config; } -void RuntimeConfigurator::initialization(const std::shared_ptr& linear_ir) { +void RuntimeConfigurator::initialization(const lowered::LinearIR& linear_ir) { init_data_info(linear_ir); init_tensor_rank(linear_ir); init_buffer_info(linear_ir); @@ -52,28 +52,28 @@ void RuntimeConfigurator::initialization(const std::shared_ptr 0, "LinearIR must have parameters and results"); m_latest_shapes.resize(m_io_num); m_config->io_data_offsets.resize(m_io_num); - m_config->tile_rank = linear_ir->get_config().m_loop_depth; + m_config->tile_rank = linear_ir.get_config().m_loop_depth; } -void RuntimeConfigurator::update(const std::shared_ptr& linear_ir) { - if (linear_ir->is_dynamic()) { +void RuntimeConfigurator::update(const lowered::LinearIR& linear_ir) { + if (linear_ir.is_dynamic()) { update_loop_info(linear_ir); update_buffer_scratchpad_size(linear_ir); } - m_config->master_shape = linear_ir->get_master_shape(); + m_config->master_shape = linear_ir.get_master_shape(); update_data_offsets(); update_latest_shapes(); } -void RuntimeConfigurator::init_tensor_rank(const std::shared_ptr& linear_ir) const { - m_config->tensor_rank = linear_ir->get_master_shape().size(); +void RuntimeConfigurator::init_tensor_rank(const lowered::LinearIR& linear_ir) const { + m_config->tensor_rank = linear_ir.get_master_shape().size(); } -void RuntimeConfigurator::init_data_info(const std::shared_ptr& linear_ir) { - const auto& parameters = linear_ir->get_parameters(); - const auto& results = linear_ir->get_results(); +void RuntimeConfigurator::init_data_info(const lowered::LinearIR& linear_ir) { + const auto& parameters = linear_ir.get_parameters(); + const auto& results = linear_ir.get_results(); m_in_num = parameters.size(); m_io_num = m_in_num + results.size(); m_io_descs.reserve(m_io_num); @@ -113,11 +113,11 @@ void RuntimeConfigurator::init_data_info(const std::shared_ptr& linear_ir) { +void RuntimeConfigurator::init_buffer_info(const lowered::LinearIR& linear_ir) { std::map> dynamic_buffer_clusters, static_buffer_clusters; // All needed checks are in Validate pass - const auto& buffer_expressions = linear_ir->get_buffers(); + const auto& buffer_expressions = linear_ir.get_buffers(); for (const auto& buffer_expr : buffer_expressions) { const auto buffer = ov::as_type_ptr(buffer_expr->get_node()); OPENVINO_ASSERT(buffer, "Expected Buffer ops in Buffer expressions of LinearIR"); @@ -128,7 +128,7 @@ void RuntimeConfigurator::init_buffer_info(const std::shared_ptrbuffer_scratchpad_size = linear_ir->get_static_buffer_scratchpad_size(); + m_config->buffer_scratchpad_size = linear_ir.get_static_buffer_scratchpad_size(); m_config->buffer_cluster_offsets.resize(cluster_count, utils::get_dynamic_value()); for (const auto& p : static_buffer_clusters) { @@ -143,7 +143,7 @@ void RuntimeConfigurator::init_buffer_info(const std::shared_ptr& linear_ir) const { +void RuntimeConfigurator::update_loop_info(const lowered::LinearIR& linear_ir) const { // Initialized UnifiedLoopInfo struct CurrentUnifiedLoopInfo { size_t current_work_amount = 0; @@ -152,7 +152,7 @@ void RuntimeConfigurator::update_loop_info(const std::shared_ptr initializated_info_map; - const auto& loop_map = linear_ir->get_loop_manager()->get_map(); + const auto& loop_map = linear_ir.get_loop_manager()->get_map(); for (const auto& p : loop_map) { const auto& expanded_loop_info = ov::as_type_ptr(p.second); OPENVINO_ASSERT(expanded_loop_info, "UpdateLoopInfo expects ExpandedLoopInfo in LoopManager"); @@ -180,17 +180,19 @@ void RuntimeConfigurator::update_loop_info(const std::shared_ptrset_work_amount( - lowered::pass::InsertSpecificIterations::get_decomposed_loop_work_amount(current_unified_loop_info, decomposed_loop_type, current_work_amount)); + const auto work_amount = + lowered::pass::InsertSpecificIterations::get_decomposed_loop_work_amount(current_unified_loop_info, decomposed_loop_type, current_work_amount); + expanded_loop_info->set_work_amount(work_amount); // Update remaining Loop work amount - current_work_amount -= expanded_loop_info->get_work_amount(); + current_work_amount -= work_amount; // Update only `finalization offsets`. `Ptr increments` are always zeroed in this case auto updated_finalization_offsets = current_work_amount > 0 ? std::vector(finalization_offsets.size(), 0) : finalization_offsets; if (expanded_loop_info->is_evaluate_once()) { + expanded_loop_info->set_increment(work_amount); // work_amount is equal to increment in cases with `evaluate_once` for (size_t i = 0; i < updated_finalization_offsets.size(); ++i) - updated_finalization_offsets[i] += ptr_increments[i] * expanded_loop_info->get_work_amount(); + updated_finalization_offsets[i] += ptr_increments[i] * work_amount; } else { expanded_loop_info->update_ptr_increments(ptr_increments); } @@ -198,9 +200,9 @@ void RuntimeConfigurator::update_loop_info(const std::shared_ptr& linear_ir) const { - const auto& loop_manager = linear_ir->get_loop_manager(); - m_config->buffer_scratchpad_size = linear_ir->get_static_buffer_scratchpad_size(); +void RuntimeConfigurator::update_buffer_scratchpad_size(const lowered::LinearIR& linear_ir) const { + const auto& loop_manager = linear_ir.get_loop_manager(); + m_config->buffer_scratchpad_size = linear_ir.get_static_buffer_scratchpad_size(); for (const auto& p : m_dynamic_buffer_clusters) { const auto& cluster_id = p.first; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp index 14d21652010a5e..89fd1ffffaf571 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp @@ -15,63 +15,24 @@ namespace intel_cpu { CPURuntimeConfigurator::CPURuntimeConfigurator() : ov::snippets::RuntimeConfigurator(std::make_shared()) { } -void CPURuntimeConfigurator::update(const std::shared_ptr& linear_ir) { - if (linear_ir->is_dynamic()) { - const auto& loop_manager = linear_ir->get_loop_manager(); - update_loop_info(linear_ir); - update_loop_args(loop_manager); - // Update Brgemm should be before `update_buffer_scratchpad_size` - // because `ComputeAllocationSize` depends on subtensors which are updated in `update_brgemms` - update_brgemms(loop_manager); - update_buffer_scratchpad_size(linear_ir); - get_kernel_executor_table()->update_state(); - } - - m_config->master_shape = linear_ir->get_master_shape(); - - update_data_offsets(); - update_latest_shapes(); -} - -void CPURuntimeConfigurator::initialization(const std::shared_ptr& linear_ir) { - RuntimeConfigurator::initialization(linear_ir); - - for (const auto& expr : *linear_ir) { - if (ov::is_type(expr->get_node())) { - const auto& in0_desc = expr->get_input_port_descriptor(0); - const auto& in1_desc = expr->get_input_port_descriptor(1); - const auto& out_desc = expr->get_output_port_descriptor(0); - - const auto& in0_subtensor = in0_desc->get_subtensor(); - const auto& in1_subtensor = in1_desc->get_subtensor(); - const auto& out_subtensor = out_desc->get_subtensor(); +void CPURuntimeConfigurator::update(const ov::snippets::lowered::LinearIR& linear_ir) { + RuntimeConfigurator::update(linear_ir); - // TODO [146125]: At the moment only blocking by dynamic M is supported - // So we save Brgemm with only dynamic M - // If there are other dynamic dimensions, throw exception for now - OPENVINO_ASSERT(!snippets::utils::is_dynamic_value(*in0_subtensor.crbegin()) && - !snippets::utils::is_dynamic_value(*in1_subtensor.crbegin()) && - !snippets::utils::is_dynamic_value(*(++in1_subtensor.crbegin())) && - !snippets::utils::is_dynamic_value(*out_subtensor.crbegin()), - "CPURuntimeConfigurator supports only dynamic M in Brgemm subtensors"); - OPENVINO_ASSERT(*(++in0_subtensor.crbegin()) == *(++out_subtensor.crbegin()), - "Incorrect values in subtensors of BrgemmCPU"); - - if (snippets::utils::is_dynamic_value(*(++in0_subtensor.crbegin()))) - m_dynamic_brgemms.insert(expr); - } + if (linear_ir.is_dynamic()) { + get_kernel_executor_table()->update_state(linear_ir); + update_loop_args(linear_ir); } } -void CPURuntimeConfigurator::init_tensor_rank(const std::shared_ptr& linear_ir) const { - m_config->tensor_rank = std::max(linear_ir->get_master_shape().size(), rank6D); +void CPURuntimeConfigurator::init_tensor_rank(const ov::snippets::lowered::LinearIR& linear_ir) const { + m_config->tensor_rank = std::max(linear_ir.get_master_shape().size(), rank6D); } -void CPURuntimeConfigurator::update_loop_args(const ov::snippets::lowered::LoopManagerPtr& loop_manager) const { +void CPURuntimeConfigurator::update_loop_args(const ov::snippets::lowered::LinearIR& linear_ir) const { const auto& cpu_config = ov::as_type_ptr(m_config); OPENVINO_ASSERT(cpu_config, "CPURuntimeConfigurator expects CPURuntimeConfig"); - const auto& loop_map = loop_manager->get_map(); + const auto& loop_map = linear_ir.get_loop_manager()->get_map(); cpu_config->loop_args.resize(loop_map.size()); for (const auto& loop : loop_map) { const auto& idx = loop.first; @@ -90,18 +51,5 @@ void CPURuntimeConfigurator::update_loop_args(const ov::snippets::lowered::LoopM } } -void CPURuntimeConfigurator::update_brgemms(const ov::snippets::lowered::LoopManagerPtr& loop_manager) const { - for (const auto& brgemm_expr : m_dynamic_brgemms) { - const auto& loop_ids = brgemm_expr->get_loop_ids(); - OPENVINO_ASSERT(!loop_ids.empty(), "Dynamic Brgemm must be in loops"); - // TODO [146125]: Loop by M is first one in `loop_ids` - const auto& expanded_loop_info = loop_manager->get_loop_info(loop_ids.front()); - const auto& block_size_m = expanded_loop_info->get_work_amount(); - - brgemm_expr->get_input_port_descriptor(0)->set_subtensor_dim(1, block_size_m); - brgemm_expr->get_output_port_descriptor(0)->set_subtensor_dim(1, block_size_m); - } -} - } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp index 39ab1977f878d1..22c524ffdb3cad 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp @@ -29,31 +29,19 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator { * @brief Update RuntimeConfig based on LinearIR * @param linear_ir LinearIR */ - void update(const std::shared_ptr& linear_ir) override; - /** - * @brief Allocate and intialize fields in RuntimeConfig and RuntimeConfigurator - * @param linear_ir LinearIR - */ - void initialization(const std::shared_ptr& linear_ir) override; + void update(const ov::snippets::lowered::LinearIR& linear_ir) override; /** * @brief Initializes tensor rank of config * @param linear_ir LinearIR */ - void init_tensor_rank(const std::shared_ptr& linear_ir) const override; + void init_tensor_rank(const ov::snippets::lowered::LinearIR& linear_ir) const override; /** * @brief Calculate Loop parameters of Loop emitters and update these values in CPURuntimeConfig - * @param loop_manager Loop Manager - */ - void update_loop_args(const ov::snippets::lowered::LoopManagerPtr& loop_manager) const; - /** - * @brief Update subtensors of Brgemms - * @param loop_manager Loop Manager + * @param linear_ir LinearIR */ - void update_brgemms(const ov::snippets::lowered::LoopManagerPtr& loop_manager) const; + void update_loop_args(const ov::snippets::lowered::LinearIR& linear_ir) const; const size_t rank6D = 6; - // Brgemm expressions with subtensors with dynamic values - std::unordered_set m_dynamic_brgemms = {}; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index fb15ada10c504f..0fbca49c7b09f6 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -4,6 +4,8 @@ #include "brgemm.hpp" +#include "snippets/lowered/loop_manager.hpp" + #include #include "common/utils.hpp" #include "dnnl_extension_utils.h" @@ -153,31 +155,47 @@ std::shared_ptr BrgemmKernelExecutor::compile_kernel(const return compiled_kernel; } -void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, BrgemmKernelConfig& config) const { - auto get_projected_input_subtensor = [](const snippets::lowered::PortDescriptorPtr& desc) { - // Note: for output shape you will need get_preordered_vdims() - auto shape = snippets::utils::get_planar_vdims(desc->get_shape(), desc->get_layout()); - auto subtensor = desc->get_subtensor(); - OV_CPU_JIT_EMITTER_ASSERT(subtensor.size() <= shape.size() && subtensor.size() == 2, - "Invalid subtensor + shape combination"); - auto shape_it = shape.rbegin(); - for (auto sub_it = subtensor.rbegin(); sub_it != subtensor.rend(); sub_it++, shape_it++) { - *sub_it = std::min(*sub_it, *shape_it); - } - return subtensor; - }; +void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIR& linear_ir, + BrgemmKernelConfig& config) const { const auto& input_pds = expr->get_input_port_descriptors(); const auto& output_pds = expr->get_output_port_descriptors(); OV_CPU_JIT_EMITTER_ASSERT((input_pds.size() == 2 || input_pds.size() == 3) && output_pds.size() == 1, "Invalid number of in/out port descriptors"); - // Update runtime-defined config fields: - // Matrix A (first input) + + const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout()); + const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout()); + auto in0_subtensor = input_pds[0]->get_subtensor(); + auto in1_subtensor = input_pds[1]->get_subtensor(); + + auto M = *++in0_subtensor.rbegin(); + auto K = *in0_subtensor.rbegin(); + auto N = *in1_subtensor.rbegin(); + + if (ov::snippets::utils::is_full_dim_value(M)) { + M = *++in0_shape.rbegin(); + } else if (ov::snippets::utils::is_dynamic_value(M)) { + const auto& loop_ids = expr->get_loop_ids(); + OPENVINO_ASSERT(!loop_ids.empty(), "Loop by dimension M is missed"); + // TODO [146125]: Loop by M is first one in `loop_ids` + const auto& expanded_loop_info = linear_ir.get_loop_manager()->get_loop_info(loop_ids.front()); + M = expanded_loop_info->get_increment(); + } + + if (ov::snippets::utils::is_full_dim_value(K)) { + K = *in0_shape.rbegin(); + } else if (ov::snippets::utils::is_dynamic_value(K)) { + OPENVINO_THROW("Dynamic K is not supported"); + } + + if (ov::snippets::utils::is_full_dim_value(N)) { + N = *in1_shape.rbegin(); + } else if (ov::snippets::utils::is_dynamic_value(N)) { + OPENVINO_THROW("Dynamic N is not supported"); + } + const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0))); - const auto& in0_subtensor = get_projected_input_subtensor(input_pds[0]); - const auto K = DIM_CAST(*in0_subtensor.rbegin()); - const auto M = DIM_CAST(*++in0_subtensor.rbegin()); - // Matrix B (second input) - // Non float input 1 => with data repacking + const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0))); auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1))); const auto& brgemm_node = as_type_ptr(expr->get_node()); @@ -187,10 +205,8 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression OV_CPU_JIT_EMITTER_ASSERT(!repacking_buffer_shape.empty(), "Repacking buffer shape mustn't be empty"); LDB = DIM_CAST(repacking_buffer_shape.back()); } - const auto N = DIM_CAST(*get_projected_input_subtensor(input_pds[1]).rbegin()); - // Matrix C (output) - const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0))); - config.update(M, N, K, LDA, LDB, LDC); + + config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC); } void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_args* args) { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp index c87a7e93f3b3f7..9f66f40962e581 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp @@ -96,7 +96,9 @@ class BrgemmKernelExecutor : public CPUKernelExecutor compile_kernel(const BrgemmKernelConfig& c) const override; - void update_config(const ov::snippets::lowered::ExpressionPtr& expr, BrgemmKernelConfig& config) const override; + void update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIR& linear_ir, + BrgemmKernelConfig& config) const override; }; #define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmKernelExecutor::call_args, field)