diff --git a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp
index af1e2b60e1d70a..9af247cd52ecab 100644
--- a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp
+++ b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp
@@ -24,9 +24,10 @@ namespace pass {
 class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
 public:
     MHAParallelWAOptimizer() = default;
-    MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, RuntimeConfigurator* configurator);
+    MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator);

     bool run(const lowered::LinearIR& linear_ir) override;
+    bool applicable() const override { return !m_loops_to_split.empty(); }

 private:
     static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir);
diff --git a/src/common/snippets/include/snippets/lowered/pass/runtime_optimizer.hpp b/src/common/snippets/include/snippets/lowered/pass/runtime_optimizer.hpp
index 99522628e23c07..ed37a1c6c58bca 100644
--- a/src/common/snippets/include/snippets/lowered/pass/runtime_optimizer.hpp
+++ b/src/common/snippets/include/snippets/lowered/pass/runtime_optimizer.hpp
@@ -20,9 +20,30 @@ namespace pass {
 class RuntimeOptimizer : public ConstPass {
 public:
     RuntimeOptimizer() = default;
-    RuntimeOptimizer(RuntimeConfigurator* configurator) : m_configurator(configurator) {}
+    RuntimeOptimizer(const RuntimeConfigurator* configurator) : m_configurator(configurator) {
+        OPENVINO_ASSERT(configurator, "RuntimeConfigurator mustn't be nullptr");
+    }
+    /**
+     * @brief Defines if this pass is applicable. If it is not applicable, its registration in the pass pipeline can be skipped.
+     */
+    virtual bool applicable() const = 0;
+
+    /**
+     * @brief Creates an instance of the specified pass type and checks if it is applicable.
+     * If the pass is applicable, it is registered in the provided pipeline.
+     * @param pipeline The pipeline in which the pass should be registered.
+     * @param args The arguments to be forwarded to the pass constructor.
+     */
+    template <typename T, class... Args, typename = std::enable_if<std::is_base_of<RuntimeOptimizer, T>::value>>
+    static void register_if_applicable(PassPipeline& pipeline, Args&&... args) {
+        auto pass = std::make_shared<T>(std::forward<Args>(args)...);
+        if (pass->applicable()) {
+            pipeline.register_pass(pass);
+        }
+    }
+
 protected:
-    RuntimeConfigurator* m_configurator = nullptr;
+    const RuntimeConfigurator* m_configurator = nullptr;
 };
 } // namespace pass
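For illustration, a minimal sketch of how this API is meant to be used. The OptimizerStub class and its trivial has-work flag are hypothetical, shown only to make the applicable()/register_if_applicable contract concrete:

// Hypothetical subclass: the constructor runs the (potentially expensive)
// analysis once; applicable() then reports whether any work was found.
class OptimizerStub : public RuntimeOptimizer {
public:
    OptimizerStub(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator)
        : RuntimeOptimizer(configurator),
          m_has_work(linear_ir->is_dynamic()) {}  // placeholder for a real analysis

    bool run(const lowered::LinearIR& linear_ir) override { return m_has_work; }
    bool applicable() const override { return m_has_work; }

private:
    bool m_has_work = false;
};

// If applicable() is false, the pass object is discarded and the pipeline
// never pays a per-inference cost for it:
// RuntimeOptimizer::register_if_applicable<OptimizerStub>(pipeline, linear_ir, this);

This is the pattern the optimizers in this patch follow: the emptiness checks that used to sit at the top of each run() move into applicable(), so registration is skipped when there is nothing to do.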
diff --git a/src/common/snippets/include/snippets/runtime_configurator.hpp b/src/common/snippets/include/snippets/runtime_configurator.hpp
index 7edb916d8154b0..866e98843fcd50 100644
--- a/src/common/snippets/include/snippets/runtime_configurator.hpp
+++ b/src/common/snippets/include/snippets/runtime_configurator.hpp
@@ -133,7 +133,7 @@ class RuntimeConfigurator {
      * @brief Update tensor rank based on master shape
      * @param master_shape Master shape
      */
-    virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape);
+    virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const;

 protected:
     /**
diff --git a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp
index 7c4c3085679d6b..2f57d6422cf11d 100644
--- a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp
+++ b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp
@@ -19,7 +19,7 @@ using namespace ov::snippets::pass;

 const size_t MHAParallelWAOptimizer::m_dim_M_idx = 1;

-MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, RuntimeConfigurator* configurator)
+MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator)
     : lowered::pass::RuntimeOptimizer(configurator) {
     if (linear_ir->get_config().m_enable_domain_optimization || !linear_ir->is_dynamic())
         return;
@@ -47,9 +47,6 @@ MHAParallelWAOptimizer::MHAParallelWAOptimizer(const lowered::LinearIRCPtr& line
 bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::MHAParallelWAOptimizer")
-    if (m_loops_to_split.empty())
-        return false;
-
     const auto& config = m_configurator->get_config();
     size_t new_batch_dim, new_kernel_dim;
     if (!SplitDimensionM::split(config->master_shape, m_concurrency, new_batch_dim, new_kernel_dim))
diff --git a/src/common/snippets/src/runtime_configurator.cpp b/src/common/snippets/src/runtime_configurator.cpp
index 5a19c61767a22c..41cfdd7d6df381 100644
--- a/src/common/snippets/src/runtime_configurator.cpp
+++ b/src/common/snippets/src/runtime_configurator.cpp
@@ -17,6 +17,7 @@ namespace snippets {

 using namespace ov::snippets::pass;
 using namespace ov::snippets::lowered;
+using namespace ov::snippets::lowered::pass;

 #ifdef SNIPPETS_DEBUG_CAPS
 std::string RuntimeConfig::to_string() const {
@@ -65,7 +66,7 @@ void RuntimeConfigurator::initialization(const lowered::LinearIRCPtr& linear_ir)
     m_config->tile_rank = linear_ir->get_config().m_loop_depth;

     if (linear_ir->is_dynamic())
-        m_intermediate_optimizers.register_pass<MHAParallelWAOptimizer>(linear_ir, this);
+        RuntimeOptimizer::register_if_applicable<MHAParallelWAOptimizer>(m_intermediate_optimizers, linear_ir, this);
 }

 void RuntimeConfigurator::update(const lowered::LinearIRCPtr& linear_ir) {
@@ -86,7 +87,7 @@ void RuntimeConfigurator::update(const lowered::LinearIRCPtr& linear_ir) {
     m_config->latest_shapes = std::move(m_config->io_shapes);
 }

-void RuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) {
+void RuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) const {
     m_config->tensor_rank = master_shape.size();
 }
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp
index d4be7235131ead..283b5bf621b85f 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp
@@ -13,6 +13,7 @@
 #endif

 namespace ov {
 namespace intel_cpu {
+using namespace ov::snippets::lowered::pass;

 const size_t CPURuntimeConfigurator::rank6D = 6;

@@ -44,8 +45,8 @@ void CPURuntimeConfigurator::initialization(const ov::snippets::lowered::LinearI
     RuntimeConfigurator::initialization(linear_ir);
 #ifndef OPENVINO_ARCH_ARM64
     if (linear_ir->is_dynamic())
-        m_intermediate_optimizers.register_pass<BrgemmCopyBLoopPortsAdjuster>(linear_ir, this);
-    m_final_optimizers.register_pass<BrgemmExternalRepackingAdjuster>(linear_ir, this);
+        RuntimeOptimizer::register_if_applicable<BrgemmCopyBLoopPortsAdjuster>(m_intermediate_optimizers, linear_ir, this);
+    RuntimeOptimizer::register_if_applicable<BrgemmExternalRepackingAdjuster>(m_final_optimizers, linear_ir, this);
 #endif
 }

@@ -72,7 +73,7 @@ void CPURuntimeConfigurator::update(const ov::snippets::lowered::LinearIRCPtr& l
     m_config->latest_shapes = std::move(m_config->io_shapes);
 }

-void CPURuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) {
+void CPURuntimeConfigurator::update_tensor_rank(const ov::snippets::VectorDims& master_shape) const {
     m_config->tensor_rank = std::max(master_shape.size(), rank6D);
 }
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp
index f36c3b28de1fe1..42ce35a3c66c2b 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp
@@ -36,7 +36,7 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator {
     void update_loop_args(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const;
 protected:
     void update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;
-    void update_tensor_rank(const ov::snippets::VectorDims& master_shape) override;
+    void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const override;
     void init_tensor_rank(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const override;
     void initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;
diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
index aaa12d303bb232..eb58b58c545acb 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -650,7 +650,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
     SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::snippets::pass::PropagatePrecision,
                                            ov::intel_cpu::pass::BrgemmToBrgemmCPU);
     SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::intel_cpu::pass::BrgemmToBrgemmCPU,
-                                           ov::intel_cpu::pass::MoveBrgemmRepackingOut);
+                                           ov::intel_cpu::pass::EliminateBrgemmCopyB);
     SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineEnd, ov::intel_cpu::pass::RemoveConverts);
     SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineEnd, ov::intel_cpu::pass::MulAddToFMA);

@@ -992,14 +992,17 @@ void Subgraph::SubgraphExecutor::parallel_forNd(const std::function<void(jit_snippets_call_args*, size_t)>& caller) {
-void Subgraph::SubgraphExecutor::execute(dnnl::stream strm, std::vector<MemoryPtr>& inMemPtrs, std::vector<MemoryPtr>& outMemPtrs) {
-    if (m_in_requested_descs.empty())
+void Subgraph::SubgraphExecutor::execute(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
+    if (!m_in_requested_descs.empty()) {
+        auto reorderedInMemPtrs = exec_in_reorders(strm, inMemPtrs);
+        exec_impl(reorderedInMemPtrs, outMemPtrs);
+    } else {
         exec_impl(inMemPtrs, outMemPtrs);
-    else
-        reorder_execute(strm, inMemPtrs, outMemPtrs);
+    }
 }

-void Subgraph::SubgraphExecutor::reorder_execute(dnnl::stream strm, std::vector<MemoryPtr> inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) {
+std::vector<MemoryPtr> Subgraph::SubgraphExecutor::exec_in_reorders(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs) {
+    auto reordered_in_ptrs = inMemPtrs;
     size_t offset = m_internal_buffer_size;
     for (const auto& requested_descs_elem : m_in_requested_descs) {
         const auto in_idx = requested_descs_elem.first;
@@ -1007,11 +1010,11 @@ void Subgraph::SubgraphExecutor::reorder_execute(dnnl::stream strm, std::vector<
         const auto& requested_desc = requested_descs_elem.second;

         const void* data_ptr = m_buffer_scratchpad->getDataAs<uint8_t>() + offset;
         const auto scratch_mem = std::make_shared<Memory>(strm.get_engine(), requested_desc, data_ptr, false);
-        scratch_mem->load(*inMemPtrs[in_idx]);
-        inMemPtrs[in_idx] = scratch_mem;
+        scratch_mem->load(*reordered_in_ptrs[in_idx]);
+        reordered_in_ptrs[in_idx] = scratch_mem;
         offset += requested_desc->getCurrentMemSize();
     }
-    exec_impl(inMemPtrs, outMemPtrs);
+    return reordered_in_ptrs;
 }

 }   // namespace node
diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.h b/src/plugins/intel_cpu/src/nodes/subgraph.h
index 0cc5258f3d18e7..cf907349bda25b 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.h
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.h
@@ -129,7 +129,7 @@ class Subgraph::SubgraphExecutor {
                      const BufferScratchpadAllocator& allocator);
     virtual ~SubgraphExecutor() = default;

-    void execute(dnnl::stream strm, std::vector<MemoryPtr>& inMemPtrs, std::vector<MemoryPtr>& outMemPtrs);
+    void execute(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);

 protected:
     virtual void exec_impl(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) = 0;
@@ -169,7 +169,7 @@
 #endif

 private:
-    void reorder_execute(dnnl::stream strm, std::vector<MemoryPtr> inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs);
+    std::vector<MemoryPtr> exec_in_reorders(dnnl::stream strm, const std::vector<MemoryPtr>& inMemPtrs);

     std::unordered_map<size_t, MemoryDescPtr> m_in_requested_descs = {};
 };
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp
index 2982fd7767486f..6a4fc83d409355 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp
@@ -95,7 +95,7 @@ ov::snippets::lowered::ExpressionPtr get_copy_b_expr(const ov::snippets::lowered
     } else if (ov::is_type<snippets::lowered::BufferExpression>(b_input_expr)) {
         OPENVINO_ASSERT(b_input_expr->get_input_count() >= 1, "BufferExpression on brgemm's B input must have at least one input");
         const auto input_buffer_expr = b_input_expr->get_input_port_connector(0)->get_source().get_expr();
-        if (ov::is_type<BrgemmCopyB>(b_input_expr->get_node())) {
+        if (ov::is_type<BrgemmCopyB>(input_buffer_expr->get_node())) {
             return input_buffer_expr;
         }
     }
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
index d15a76c5e4f15d..0d8e3f5fb6fc9b 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
@@ -18,7 +18,7 @@ enum class BRGEMM_TYPE {
     STAND_ALONE,         // No extra requirements, used for f32|f32
     WITH_AMX,            // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
     WITH_COMPENSATIONS,  // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
-    REPACKING_ONLY,      // low precision or some specific f32 cases - needs BrgemmCopyB on second input for data repacking
+    REPACKING_ONLY,      // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on second input for data repacking
 };

 dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx);
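To make the enum semantics concrete, here is a hedged sketch of the precision/flag dispatch that the comments above describe. The function name and exact conditions are illustrative only; the real selection logic lives alongside these declarations in brgemm_utils and also accounts for ISA and shape specifics:

// Illustrative sketch: maps input precision and flags to a BRGEMM_TYPE
// according to the enum comments above.
BRGEMM_TYPE select_brgemm_type_sketch(const ov::element::Type& in0,
                                      bool is_amx_system,
                                      bool transpose_b) {
    const bool low_precision = in0 == ov::element::i8 || in0 == ov::element::u8 || in0 == ov::element::bf16;
    if (low_precision && is_amx_system)
        return BRGEMM_TYPE::WITH_AMX;            // i8|i8 or bf16|bf16 on an AMX system
    if (in0 == ov::element::i8)
        return BRGEMM_TYPE::WITH_COMPENSATIONS;  // signed i8 without AMX needs compensations
    if (low_precision || transpose_b)
        return BRGEMM_TYPE::REPACKING_ONLY;      // u8|i8, bf16|bf16 (non-AMX), or transpose_b=true
    return BRGEMM_TYPE::STAND_ALONE;             // plain f32, no BrgemmCopyB needed
}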
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp
index 8d734e288514bf..7dfe711a5a5c67 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp
@@ -65,17 +65,12 @@ bool pass::AdjustBrgemmCopyBLoopPorts::run(const snippets::lowered::LinearIR& li
     bool modified = false;

-    auto get_repacking_loop_idces = [](const snippets::lowered::ExpressionPtr& parent_expr) {
+    auto get_repacking_loop_idces = [](const snippets::lowered::ExpressionPtr& brgemm_expr) {
         // Repacking may be extracted outside the snippets kernel. In this case, brgemm parent expression is a parameter.
-        if (is_type<ov::op::v0::Parameter>(parent_expr->get_node()))
+        if (is_type<ov::op::v0::Parameter>(brgemm_expr->get_input_port_connector(1)->get_source().get_expr()->get_node()))
             return std::vector<size_t>{};
-
-        OPENVINO_ASSERT(is_type<snippets::lowered::BufferExpression>(parent_expr),
-                        "In case of repacking brgemm expr must have BufferExpression on B input");
-        const auto buffer_parent_ports = parent_expr->get_input_port(0).get_connected_ports();
-        OPENVINO_ASSERT(buffer_parent_ports.size() == 1,
-                        "Parent of brgemm repacking buffer must be connected only to the buffer");
-        const auto& repacking_expr = buffer_parent_ports.begin()->get_expr();
+        const auto repacking_expr = brgemm_utils::repacking::get_copy_b_expr(brgemm_expr);
+        OPENVINO_ASSERT(repacking_expr, "BrgemmCopyB expression is not found");
         return repacking_expr->get_loop_ids();
     };

@@ -83,30 +78,22 @@ bool pass::AdjustBrgemmCopyBLoopPorts::run(const snippets::lowered::LinearIR& li
         const auto brgemm = ov::as_type_ptr<BrgemmCPU>(expr->get_node());
         if (!brgemm || !brgemm_utils::with_repacking(brgemm->get_type()))
             continue;
-        const auto& parent_expr = expr->get_input_port_connector(1)->get_source().get_expr();
-        const auto& repacking_loop_ids = get_repacking_loop_idces(parent_expr);
-        for (const auto& target_port : parent_expr->get_output_port(0).get_connected_ports()) {
-            const auto& port_node = target_port.get_expr()->get_node();
-            if (!is_type<BrgemmCPU>(port_node)) {
-                OPENVINO_ASSERT(is_type<snippets::op::LoopEnd>(port_node),
-                                "Invalid grandchild of BrgemmCopyB");
-                continue;
-            }
-            const auto &brgemm_loop_ids = target_port.get_expr()->get_loop_ids();
-            // Continue if there is no blocking loop
-            if (brgemm_loop_ids.empty() && repacking_loop_ids.empty())
-                continue;
-            OPENVINO_ASSERT(brgemm_loop_ids.size() > repacking_loop_ids.size(), "Invalid BrgemmCopyB loop configuration");
-            const auto &loop_manager = linear_ir.get_loop_manager();
-            for (auto i = repacking_loop_ids.size(); i < brgemm_loop_ids.size(); i++) {
-                const auto &loop = loop_manager->get_loop_info(brgemm_loop_ids[i]);
-                auto uni_loop = ov::as_type_ptr<snippets::lowered::UnifiedLoopInfo>(loop);
-                if (!uni_loop)
-                    uni_loop = ov::as_type_ptr<snippets::lowered::ExpandedLoopInfo>(loop)->get_unified_loop_info();
-                if (!m_affected_loops.count(uni_loop) && update_loop_info(uni_loop)) {
-                    m_affected_loops.insert(uni_loop);
-                    modified = true;
-                }
+        const auto& brgemm_loop_ids = expr->get_loop_ids();
+        const auto& repacking_loop_ids = get_repacking_loop_idces(expr);
+        // Continue if there is no blocking loop
+        if (brgemm_loop_ids.empty() && repacking_loop_ids.empty())
+            continue;
+
+        OPENVINO_ASSERT(brgemm_loop_ids.size() > repacking_loop_ids.size(), "Invalid BrgemmCopyB loop configuration");
+        const auto &loop_manager = linear_ir.get_loop_manager();
+        for (auto i = repacking_loop_ids.size(); i < brgemm_loop_ids.size(); i++) {
+            const auto &loop = loop_manager->get_loop_info(brgemm_loop_ids[i]);
+            auto uni_loop = ov::as_type_ptr<snippets::lowered::UnifiedLoopInfo>(loop);
+            if (!uni_loop)
+                uni_loop = ov::as_type_ptr<snippets::lowered::ExpandedLoopInfo>(loop)->get_unified_loop_info();
+            if (!m_affected_loops.count(uni_loop) && update_loop_info(uni_loop)) {
+                m_affected_loops.insert(uni_loop);
+                modified = true;
             }
         }
     }
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.cpp
index 089d91aba809fb..509f9ecf149c8e 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.cpp
@@ -12,7 +12,7 @@ namespace ov {
 namespace intel_cpu {

 BrgemmCopyBLoopPortsAdjuster::BrgemmCopyBLoopPortsAdjuster(const ov::snippets::lowered::LinearIRCPtr& linear_ir,
-                                                           CPURuntimeConfigurator* configurator)
+                                                           const CPURuntimeConfigurator* configurator)
     : ov::snippets::lowered::pass::RuntimeOptimizer(configurator) {
     const auto& pass = std::make_shared<pass::AdjustBrgemmCopyBLoopPorts>();
     pass->run(*linear_ir);
@@ -29,9 +29,6 @@ BrgemmCopyBLoopPortsAdjuster::BrgemmCopyBLoopPortsAdjuster(const ov::snippets::l
 bool BrgemmCopyBLoopPortsAdjuster::run(const snippets::lowered::LinearIR& linear_ir) {
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmCopyBLoopPortsAdjuster")
-    if (m_affected_uni2exp_map.empty())
-        return false;
-
     for (const auto& p : m_affected_uni2exp_map) {
         const auto& uni_loop = p.first;
         const auto& exp_loops = p.second;
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.hpp
index c33cb0d502f19f..7b9f30ac96e4b1 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.hpp
@@ -19,9 +19,10 @@ namespace intel_cpu {
 class BrgemmCopyBLoopPortsAdjuster : public ov::snippets::lowered::pass::RuntimeOptimizer {
 public:
     BrgemmCopyBLoopPortsAdjuster() = default;
-    BrgemmCopyBLoopPortsAdjuster(const ov::snippets::lowered::LinearIRCPtr& linear_ir, CPURuntimeConfigurator* configurator);
+    BrgemmCopyBLoopPortsAdjuster(const ov::snippets::lowered::LinearIRCPtr& linear_ir, const CPURuntimeConfigurator* configurator);

     bool run(const snippets::lowered::LinearIR& linear_ir) override;
+    bool applicable() const override { return !m_affected_uni2exp_map.empty(); }

 private:
     std::unordered_map<snippets::lowered::UnifiedLoopInfoPtr, std::vector<snippets::lowered::ExpandedLoopInfoPtr>> m_affected_uni2exp_map;
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp
@@ ... @@
-        copy_b_expr->get_input_port_descriptor(0)->set_subtensor({get_full_dim_value(), get_full_dim_value()});
-        copy_b_expr->get_output_port_descriptor(0)->set_subtensor({get_full_dim_value(), get_full_dim_value()});
-        if (with_compensations(type)) {
-            const ov::snippets::VectorDims compensations_subtensor{1, get_full_dim_value()};
-            OPENVINO_ASSERT(brgemm_expr->get_input_count() == 3, "Brgemm must have 3 inputs in case of compensations.");
-            brgemm_expr->get_input_port_descriptor(2)->set_subtensor(compensations_subtensor);
-            copy_b_expr->get_output_port_descriptor(1)->set_subtensor(compensations_subtensor);
-        }
+        const ov::snippets::VectorDims full_subtensor(2, get_full_dim_value());
+        copy_b_expr->get_input_port_descriptor(0)->set_subtensor(full_subtensor);
+        copy_b_expr->get_output_port_descriptor(0)->set_subtensor(full_subtensor);
     }

     if (with_amx(type)) {
         move_new_memory_buffer(linear_ir, brgemm_it);
@@ -102,8 +97,12 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir,
     const auto& loop_manager = linear_ir.get_loop_manager();

     if (with_compensations(type)) {
+        const ov::snippets::VectorDims compensations_subtensor{1, get_full_dim_value()};
         OPENVINO_ASSERT(brgemm_expr->get_input_count() == 3, "Brgemm must have 3 inputs in case of compensations.");
+        OPENVINO_ASSERT(copy_b_expr, "BrgemmCopyB must be present in case of compensations.");
         const auto& compens_port = brgemm_expr->get_input_port(2);
+        compens_port.get_descriptor_ptr()->set_subtensor(compensations_subtensor);
+        copy_b_expr->get_output_port_descriptor(1)->set_subtensor(compensations_subtensor);

         const auto& loop_ids = brgemm_expr->get_loop_ids();
         size_t i = 0;
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
index 327d82761ad566..e98c8ebbecf49b 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
@@ -15,7 +15,7 @@ namespace ov {
 namespace intel_cpu {

 BrgemmExternalRepackingAdjuster::BrgemmExternalRepackingAdjuster(const ov::snippets::lowered::LinearIRCPtr& linear_ir,
-                                                                 CPURuntimeConfigurator* configurator)
+                                                                 const CPURuntimeConfigurator* configurator)
     : snippets::lowered::pass::RuntimeOptimizer(configurator) {
     const auto& params = linear_ir->get_parameters();
     for (size_t i = 0; i < params.size(); ++i) {
@@ -37,9 +37,6 @@ BrgemmExternalRepackingAdjuster::BrgemmExternalRepackingAdjuster(const ov::snipp
 bool BrgemmExternalRepackingAdjuster::run(const snippets::lowered::LinearIR& linear_ir) {
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmExternalRepackingAdjuster")
-    if (m_param_idces_with_external_repacking.empty())
-        return false;
-
     const auto& cpu_config = ov::as_type_ptr<CPURuntimeConfig>(m_configurator->get_config());
     auto& optimal_descs = cpu_config->m_in_requested_descs;
     for (const auto& i : m_param_idces_with_external_repacking) {
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp
index fb22beaca63ae1..f102af8f23fe5b 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp
@@ -19,9 +19,10 @@ namespace intel_cpu {
 class BrgemmExternalRepackingAdjuster : public ov::snippets::lowered::pass::RuntimeOptimizer {
 public:
     BrgemmExternalRepackingAdjuster() = default;
-    BrgemmExternalRepackingAdjuster(const ov::snippets::lowered::LinearIRCPtr& linear_ir, CPURuntimeConfigurator* configurator);
+    BrgemmExternalRepackingAdjuster(const ov::snippets::lowered::LinearIRCPtr& linear_ir, const CPURuntimeConfigurator* configurator);

     bool run(const snippets::lowered::LinearIR& linear_ir) override;
+    bool applicable() const override { return !m_param_idces_with_external_repacking.empty(); }

 private:
     std::set<size_t> m_param_idces_with_external_repacking;
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.cpp
index a6973492f7d95c..003ca67722dd3c 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.cpp
@@ -15,18 +15,18 @@
 namespace ov {
 namespace intel_cpu {

-pass::MoveBrgemmRepackingOut::MoveBrgemmRepackingOut() {
-    MATCHER_SCOPE(MoveBrgemmRepackingOut);
+pass::EliminateBrgemmCopyB::EliminateBrgemmCopyB() {
+    MATCHER_SCOPE(EliminateBrgemmCopyB);
     auto m_param = ov::pass::pattern::wrap_type<ov::op::v0::Parameter>();
     auto m_rank_norm = ov::pass::pattern::optional<ov::snippets::op::RankNormalization>(m_param);
     auto m_copy_b = ov::pass::pattern::wrap_type<BrgemmCopyB>({m_param});

     auto callback = [=](ov::pass::pattern::Matcher& m) {
-        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::MoveBrgemmRepackingOut")
+        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::EliminateBrgemmCopyB")
         const auto& pattern_map = m.get_pattern_value_map();
         const auto& copy_b_out = pattern_map.at(m_copy_b);
         const auto copy_b_node = ov::as_type_ptr<BrgemmCopyB>(copy_b_out.get_node_shared_ptr());
-        OPENVINO_ASSERT(copy_b_node, "BrgemmCopyB node is null in MoveBrgemmRepackingOut transformation");
+        OPENVINO_ASSERT(copy_b_node, "BrgemmCopyB node is null in EliminateBrgemmCopyB transformation");

         const auto& in_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(copy_b_node->input(0));
         const auto& layout = in_desc->get_layout();
@@ -34,7 +34,7 @@ pass::MoveBrgemmRepackingOut::MoveBrgemmRepackingOut() {
         // 1. Ticket 157340: support external repacking for copyB with compensations
         // 2. Ticket 157339: support external repacking for non-planar layout
         if (!ov::snippets::utils::is_planar_layout(layout) ||
-            copy_b_node->get_src_element_type() == ov::element::i8 || transformation_callback(copy_b_node))
+            brgemm_utils::with_compensations(copy_b_node->get_type()) || transformation_callback(copy_b_node))
             return false;
         return ov::replace_output_update_name(copy_b_out, copy_b_node->input_value(0));
     };
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.hpp
index c82193c93f1d4b..2cdeae53fab026 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/move_brgemm_repacking_out.hpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2022 Intel Corporation
+// Copyright (C) 2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //

@@ -10,10 +10,17 @@
 namespace ov {
 namespace intel_cpu {
 namespace pass {

-class MoveBrgemmRepackingOut: public ov::pass::MatcherPass {
+/**
+ * @interface EliminateBrgemmCopyB
+ * @brief EliminateBrgemmCopyB identifies BrgemmCopyB nodes which can be inferred outside the Subgraph.
+ * If this is possible, the CopyB node is removed, and the external repacking is configured
+ * at later pipeline stages by the RuntimeConfigurator.
+ *
+ * @ingroup snippets
+ */
+class EliminateBrgemmCopyB: public ov::pass::MatcherPass {
 public:
-    OPENVINO_RTTI("MoveBrgemmRepackingOut", "0");
-    MoveBrgemmRepackingOut();
+    OPENVINO_RTTI("EliminateBrgemmCopyB", "0");
+    EliminateBrgemmCopyB();
 };
diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index 05dfb6a377ec91..e67fbc238a8e10 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -954,8 +954,9 @@ void Transformations::MainSnippets(void) {
     // [122706] Some 3D MHA Patterns have perf regressions when Transpose op is tokenized
     std::set<size_t> mha_supported_transpose_ranks = { 4 };

-    // Note: this is a temporary WA, avoiding matmul B input tokenization in the cases when CPU .
-    // It will be removed when plugin specific SubgraphPass will be implemented.
+    // If preliminary repacking is needed, it is executed outside the snippets kernel for performance reasons,
+    // so tokenization of op sequences on the MatMul B input is disabled.
+    // Ticket 157743: this logic should be moved to a CPU-specific SubgraphPass.
     auto mha_tokenize_mm_b_input_callback = [this](const std::shared_ptr<const ov::Node>& node) {
         const auto& input_type_0 = node->get_input_element_type(0);
         const auto& input_type_1 = node->get_input_element_type(1);