diff --git a/src/common/snippets/include/snippets/lowered/pass/brgemm_blocking.hpp b/src/common/snippets/include/snippets/lowered/pass/brgemm_blocking.hpp
index b8214f22024522..258e6fa84dc686 100644
--- a/src/common/snippets/include/snippets/lowered/pass/brgemm_blocking.hpp
+++ b/src/common/snippets/include/snippets/lowered/pass/brgemm_blocking.hpp
@@ -48,26 +48,30 @@ class BrgemmBlockingBase {
     static bool blocking_loop_exists(const snippets::lowered::LoopManagerPtr& loop_manager,
                                      const ov::snippets::lowered::ExpressionPtr& brgemm_expr);
 
-    static void mark_m_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
-                                snippets::lowered::LinearIR::constExprIt loop_begin,
-                                snippets::lowered::LinearIR::constExprIt loop_end,
-                                const std::vector<snippets::lowered::LoopPort>& entries,
-                                const std::vector<snippets::lowered::LoopPort>& exits,
-                                size_t block_size_m);
+    void mark_m_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
+                         snippets::lowered::LinearIR::constExprIt loop_begin,
+                         snippets::lowered::LinearIR::constExprIt loop_end,
+                         const std::vector<snippets::lowered::LoopPort>& entries,
+                         const std::vector<snippets::lowered::LoopPort>& exits,
+                         size_t block_size_m);
 
-    static void mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
-                                snippets::lowered::LinearIR::constExprIt loop_begin,
-                                snippets::lowered::LinearIR::constExprIt loop_end,
-                                const std::vector<snippets::lowered::LoopPort>& entries,
-                                const std::vector<snippets::lowered::LoopPort>& exits,
-                                size_t block_size_n);
+    void mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
+                         snippets::lowered::LinearIR::constExprIt loop_begin,
+                         snippets::lowered::LinearIR::constExprIt loop_end,
+                         const std::vector<snippets::lowered::LoopPort>& entries,
+                         const std::vector<snippets::lowered::LoopPort>& exits,
+                         size_t block_size_n);
 
-    static void mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
-                                snippets::lowered::LinearIR::constExprIt loop_begin,
-                                snippets::lowered::LinearIR::constExprIt loop_end,
-                                const std::vector<snippets::lowered::LoopPort>& entries,
-                                const std::vector<snippets::lowered::LoopPort>& exits,
-                                size_t block_size_k);
+    void mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
+                         snippets::lowered::LinearIR::constExprIt loop_begin,
+                         snippets::lowered::LinearIR::constExprIt loop_end,
+                         const std::vector<snippets::lowered::LoopPort>& entries,
+                         const std::vector<snippets::lowered::LoopPort>& exits,
+                         size_t block_size_k);
+
+    virtual SpecificIterationHandlers get_m_loop_handlers(size_t work_amount, size_t block_size) const;
+    virtual SpecificIterationHandlers get_n_loop_handlers(size_t work_amount, size_t block_size) const;
+    virtual SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const;
 };
 
 /**
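The new `get_*_loop_handlers()` virtuals are the extension point of this refactoring: the base class returns the default blocking handlers, and a backend can append its own per-iteration passes instead of the removed hard-wired `SetBrgemmBeta`. A minimal sketch of the pattern, assuming simplified includes and namespace spellings — all class and pass names here are hypothetical (the concrete instances in this PR are `BrgemmCPUBlocking::DummyPass` and `BrgemmTPPBlocking::SetBrgemmBeta`, shown further down):

```cpp
#include "snippets/lowered/pass/brgemm_blocking.hpp"

using namespace ov::snippets::lowered;

// Hypothetical pass that must run only on the first iteration of the K blocking loop.
class MyFirstIterPass : public pass::RangedPass {
public:
    OPENVINO_RTTI("MyFirstIterPass", "RangedPass")
    bool run(LinearIR&, LinearIR::constExprIt, LinearIR::constExprIt) override {
        return true;  // backend-specific work would go here
    }
    std::shared_ptr<pass::PassBase> merge(const std::shared_ptr<pass::PassBase>& other) override {
        // Mergeable only with another instance of the same pass
        return !other || ov::is_type<MyFirstIterPass>(other) ? std::make_shared<MyFirstIterPass>() : nullptr;
    }
};

// Hypothetical backend blocking pass: reuse the default handlers, then append.
class MyBrgemmBlocking : public pass::BrgemmBlockingBase {
    SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const override {
        auto handlers = BrgemmBlockingBase::get_k_loop_handlers(work_amount, block_size);
        handlers.register_pass<SpecificLoopIterType::FIRST_ITER, MyFirstIterPass>();
        return handlers;
    }
    // get_blocking_params()/mark_blocking_loops() overrides omitted for brevity.
};
```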
diff --git a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp
index 2c537a5aaa7165..b7eb4e7176f3c1 100644
--- a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp
+++ b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp
@@ -80,24 +80,6 @@ class SetEvaluateOnce : public snippets::lowered::pass::RangedPass {
     std::shared_ptr<pass::PassBase> merge(const std::shared_ptr<pass::PassBase>& other) override;
 };
 
-/**
- * @interface SetBrgemmBeta
- * @brief The pass updates all CPUBrgemm nodes with a new beta value
- * @param m_beta - beta which must be set
- * @ingroup snippets
- */
-class SetBrgemmBeta : public snippets::lowered::pass::RangedPass {
-public:
-    SetBrgemmBeta(float beta);
-    OPENVINO_RTTI("SetBrgemmBeta", "RangedPass")
-    bool run(snippets::lowered::LinearIR& linear_ir,
-             snippets::lowered::LinearIR::constExprIt begin,
-             snippets::lowered::LinearIR::constExprIt end) override;
-    std::shared_ptr<pass::PassBase> merge(const std::shared_ptr<pass::PassBase>& other) override;
-
-private:
-    float m_beta = 0;
-};
 } // namespace pass
 } // namespace lowered
 } // namespace snippets
diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp
index e48c2a482ed846..3bfdbddd817400 100644
--- a/src/common/snippets/include/snippets/op/brgemm.hpp
+++ b/src/common/snippets/include/snippets/op/brgemm.hpp
@@ -32,9 +32,6 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
     size_t get_offset_b() const { return get_input_offset(1); }
     size_t get_offset_c() const { return get_output_offset(0); }
 
-    float get_beta() const { return m_beta; }
-    void set_beta(float beta) { m_beta = beta; }
-
     static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);
 
     void validate_and_infer_types() override;
@@ -48,7 +45,6 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
     std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Output<ov::Node>>& inputs) const;
     ov::PartialShape infer_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
     ov::PartialShape get_planar_output_shape(const ov::PartialShape& output_shape) const;
-    float m_beta = 0.f;
 
 private:
     void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
diff --git a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
index d89439b9a479ef..a7336c14454319 100644
--- a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
+++ b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
@@ -55,7 +55,7 @@ void BrgemmBlockingBase::mark_m_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
     const auto planar_dims = ov::snippets::utils::get_planar_vdims(*entries[0].expr_port);
     const auto m = *++planar_dims.rbegin();
     const auto id = loop_manager->mark_loop(loop_begin, loop_end, m, block_size_m, 1, entries, exits, false);
-    loop_manager->get_loop_info(id)->set_handlers(get_default_blocking_loop_handlers(m, block_size_m));
+    loop_manager->get_loop_info(id)->set_handlers(get_m_loop_handlers(m, block_size_m));
 }
 
 void BrgemmBlockingBase::mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
@@ -67,7 +67,7 @@ void BrgemmBlockingBase::mark_n_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
     const auto planar_dims = ov::snippets::utils::get_planar_vdims(*entries[1].expr_port);
     const auto n = *planar_dims.rbegin();
     const auto id = loop_manager->mark_loop(loop_begin, loop_end, n, block_size_n, 0, entries, exits, false);
-    loop_manager->get_loop_info(id)->set_handlers(get_default_blocking_loop_handlers(n, block_size_n));
+    loop_manager->get_loop_info(id)->set_handlers(get_n_loop_handlers(n, block_size_n));
 }
 
 void BrgemmBlockingBase::mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
@@ -79,10 +79,17 @@ void BrgemmBlockingBase::mark_k_blocking(const snippets::lowered::LoopManagerPtr& loop_manager,
     const auto planar_dims = ov::snippets::utils::get_planar_vdims(*entries[0].expr_port);
     const auto k = *planar_dims.rbegin();
     const auto id = loop_manager->mark_loop(loop_begin, loop_end, k, block_size_k, entries, exits, false);
-    auto handlers = get_default_blocking_loop_handlers(k, block_size_k);
-    handlers.register_pass<SpecificLoopIterType::FIRST_ITER, SetBrgemmBeta>(0.f);
-    loop_manager->get_loop_info(id)->set_handlers(handlers);
+    loop_manager->get_loop_info(id)->set_handlers(get_k_loop_handlers(k, block_size_k));
+}
+
+SpecificIterationHandlers BrgemmBlockingBase::get_m_loop_handlers(size_t work_amount, size_t block_size) const {
+    return get_default_blocking_loop_handlers(work_amount, block_size);
+}
+SpecificIterationHandlers BrgemmBlockingBase::get_n_loop_handlers(size_t work_amount, size_t block_size) const {
+    return get_default_blocking_loop_handlers(work_amount, block_size);
+}
+SpecificIterationHandlers BrgemmBlockingBase::get_k_loop_handlers(size_t work_amount, size_t block_size) const {
+    return get_default_blocking_loop_handlers(work_amount, block_size);
 }
@@ -102,21 +109,19 @@ std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
     // Ticket: 113745
     // TODO: extend block size selection heuristics
     auto get_block_size_m = [](const size_t M) -> size_t {
-        if (snippets::utils::is_dynamic_value(M))
-            return 32;
-        return M <= 32 ? get_full_dim_value() : 32;
+        if (!snippets::utils::is_dynamic_value(M) && M <= 32)
+            return get_full_dim_value();
+        return 32;
     };
     auto get_block_size_n = [](const size_t N) -> size_t {
-        // N blocking is disabled in dynamism by default
-        if (ov::snippets::utils::is_dynamic_value(N) || N <= 64)
+        if (!snippets::utils::is_dynamic_value(N) && N <= 64)
             return get_full_dim_value();
         return 64;
     };
     auto get_block_size_k = [](const size_t K) -> size_t {
-        // K blocking is disabled in dynamism by default
-        if (ov::snippets::utils::is_dynamic_value(K) || K <= 512)
-            return get_full_dim_value();
-        return K > 1024 ? 1024 : 512;
+        if (ov::snippets::utils::is_dynamic_value(K))
+            return 512;
+        return K > 1024 ? 1024 : K > 512 ? 512 : get_full_dim_value();
     };
     return std::make_tuple(get_block_size_m(m), get_block_size_n(n), get_block_size_k(k));
 }
@@ -137,8 +142,6 @@ bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear_ir,
                                             LoopPort(brgemm_expr->get_input_port(1), true, 1)};
         const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
         mark_k_blocking(loop_manager, brgemm_it, std::next(brgemm_it), entries, exits, k_block);
-    } else {
-        ov::as_type_ptr<op::Brgemm>(brgemm_expr->get_node())->set_beta(0.f);
     }
     if (!ov::snippets::utils::is_full_dim_value(n_block)) {
         const std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), false),
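The behavioral change in the rewritten heuristics is that dynamic dimensions no longer disable blocking: dynamic N now gets a 64 block and dynamic K a 512 block, while all static thresholds stay as before. A self-contained mirror of the three lambdas with worked values — `FULL_DIM`/`DYNAMIC` are simplified stand-ins for the real `snippets::utils` markers, not the library's actual encoding:

```cpp
#include <cassert>
#include <cstddef>
#include <limits>

constexpr size_t FULL_DIM = std::numeric_limits<size_t>::max();  // stand-in for get_full_dim_value()
constexpr size_t DYNAMIC  = FULL_DIM - 1;                        // stand-in for is_dynamic_value()

size_t block_size_m(size_t M) { return (M != DYNAMIC && M <= 32) ? FULL_DIM : 32; }
size_t block_size_n(size_t N) { return (N != DYNAMIC && N <= 64) ? FULL_DIM : 64; }
size_t block_size_k(size_t K) {
    if (K == DYNAMIC) return 512;                      // K blocking is now enabled for dynamic K
    return K > 1024 ? 1024 : K > 512 ? 512 : FULL_DIM; // static K keeps the old thresholds
}

int main() {
    assert(block_size_m(70) == 32);        // M > 32  -> blocked by 32
    assert(block_size_n(550) == 64);       // N > 64  -> blocked by 64
    assert(block_size_n(DYNAMIC) == 64);   // dynamic N -> blocked (previously full dim)
    assert(block_size_k(600) == 512);      // 512 < K <= 1024 -> blocked by 512
    assert(block_size_k(2000) == 1024);    // K > 1024 -> blocked by 1024
    assert(block_size_k(35) == FULL_DIM);  // small static K -> no blocking
    assert(block_size_k(DYNAMIC) == 512);  // dynamic K -> blocked (previously full dim)
}
```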
diff --git a/src/common/snippets/src/lowered/pass/iter_handler.cpp b/src/common/snippets/src/lowered/pass/iter_handler.cpp
index 8462476b07c766..a3ee577338a691 100644
--- a/src/common/snippets/src/lowered/pass/iter_handler.cpp
+++ b/src/common/snippets/src/lowered/pass/iter_handler.cpp
@@ -154,28 +154,6 @@ std::shared_ptr<pass::PassBase> SetEvaluateOnce::merge(const std::shared_ptr<pass::PassBase>& other) {
     return !other || ov::is_type<SetEvaluateOnce>(other) ? std::make_shared<SetEvaluateOnce>() : nullptr;
 }
 
-SetBrgemmBeta::SetBrgemmBeta(float beta) : snippets::lowered::pass::RangedPass(), m_beta(beta) {}
-
-bool SetBrgemmBeta::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
-    for (auto expr_it = begin; expr_it != end; ++expr_it) {
-        const auto& expr = expr_it->get();
-        if (const auto brgemm = ov::as_type_ptr<op::Brgemm>(expr->get_node())) {
-            brgemm->set_beta(m_beta);
-        }
-    }
-    return true;
-}
-
-std::shared_ptr<pass::PassBase> SetBrgemmBeta::merge(const std::shared_ptr<pass::PassBase>& other) {
-    const auto merged_pass = std::make_shared<SetBrgemmBeta>(m_beta);
-    if (other == nullptr)
-        return merged_pass;
-    const auto casted_pass = ov::as_type_ptr<SetBrgemmBeta>(other);
-    if (!casted_pass || m_beta != casted_pass->m_beta)
-        return nullptr;
-    return merged_pass;
-}
-
 } // namespace pass
 } // namespace lowered
 } // namespace snippets
diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp
index 449457125292af..72fc692fff5d70 100644
--- a/src/common/snippets/src/op/brgemm.cpp
+++ b/src/common/snippets/src/op/brgemm.cpp
@@ -80,7 +80,6 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args) const {
 }
 
 bool Brgemm::visit_attributes(AttributeVisitor& visitor) {
-    visitor.on_attribute("beta", m_beta);
     return MemoryAccess::visit_attributes(visitor);
 }
diff --git a/src/common/snippets/src/runtime_configurator.cpp b/src/common/snippets/src/runtime_configurator.cpp
index 6f8945649c2b94..ec1db44f074766 100644
--- a/src/common/snippets/src/runtime_configurator.cpp
+++ b/src/common/snippets/src/runtime_configurator.cpp
@@ -177,6 +177,8 @@ void RuntimeConfigurator::update_loop_info(const lowered::LinearIRPtr& linear_ir) const {
         // If the specific iteration is not needed, we skip loop evaluation - set zero as work amount is enough
         if (!lowered::pass::InsertSpecificIterations::is_decomposed_loop_needed(current_unified_loop_info, decomposed_loop_type, current_work_amount)) {
             expanded_loop_info->set_work_amount(0);
+            if (expanded_loop_info->is_evaluate_once())
+                expanded_loop_info->set_increment(0);
             continue;
         }
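A note on why the increment is zeroed as well (an illustration, not part of the patch): the Brgemm kernel executor below derives the blocked M/N/K from loop increments, and an evaluate-once loop presumably has no runtime work-amount check, so a skipped decomposed loop must not leak its stale increment. A zeroed dimension then makes the kernel config empty (see `BrgemmKernelConfig::update` further down), and the Brgemm call is skipped. A simplified sketch of the rule, with hypothetical names:

```cpp
#include <cstddef>

// Simplified mirror of the new rule in RuntimeConfigurator::update_loop_info().
struct ExpandedLoop {
    size_t work_amount = 0;
    size_t increment = 0;
    bool evaluate_once = false;
};

void skip_decomposed_loop(ExpandedLoop& loop) {
    loop.work_amount = 0;
    // An evaluate-once loop is not guarded by a runtime work-amount check, so
    // its increment must be zeroed too; downstream, update_config() reads the
    // blocked dimension from this increment and nullifies the kernel config.
    if (loop.evaluate_once)
        loop.increment = 0;
}
```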
"snippets/lowered/loop_manager.hpp" +#include "snippets/lowered/pass/insert_specific_iterations.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" @@ -20,13 +21,13 @@ using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; namespace { -size_t init_hash(dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, float beta, bool is_with_amx, +size_t init_hash(dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, bool is_with_amx, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t isa) { size_t seed = 0; #define HASH(X) seed = hash_combine(seed, X) HASH(dt_in0); HASH(dt_in1); HASH(is_with_amx); HASH(is_with_comp); - HASH(beta); HASH(isa); + HASH(isa); #undef HASH return seed; } @@ -35,9 +36,9 @@ size_t init_hash(dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, float beta, b namespace ov { namespace intel_cpu { BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, - float beta, bool is_with_amx, bool is_with_comp, + bool is_with_amx, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) : - m_static_params(std::make_shared(in0_dtype, in1_dtype, beta, + m_static_params(std::make_shared(in0_dtype, in1_dtype, is_with_amx, is_with_comp, primitive_isa)) { m_hash = compute_hash(); @@ -49,28 +50,30 @@ bool BrgemmKernelConfig::is_completed() const { bool BrgemmKernelConfig::operator==(const BrgemmKernelConfig& rhs) const { #define EQ(X) X == rhs.X - return EQ(m_hash) && + return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) && (EQ(m_static_params.get()) || *m_static_params == *(rhs.m_static_params)); #undef EQ } -void BrgemmKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC) { +void BrgemmKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta) { // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) // To process this case, we have to make this Config as empty (nullify runtime parameters) if (utils::one_of(0, M, N, K)) { m_M = 0; m_N = 0; m_K = 0; m_LDA = 0; m_LDB = 0; m_LDC = 0; + m_beta = 0; } else { m_M = M; m_N = N; m_K = K; m_LDA = LDA; m_LDB = LDB; m_LDC = LDC; + m_beta = beta; } m_hash = compute_hash(); } bool BrgemmKernelConfig::is_empty() const { - return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC); + return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta); } BrgemmKernelConfig::operator amx_tile_config_t() const { @@ -80,19 +83,17 @@ BrgemmKernelConfig::operator amx_tile_config_t() const { } BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, - float beta, bool is_with_amx, bool is_with_comp, + bool is_with_amx, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) : dt_in0(DTYPE_CAST(in0_dtype)), dt_in1(DTYPE_CAST(in1_dtype)), - beta(beta), is_with_amx(is_with_amx), is_with_comp(is_with_comp), + is_with_amx(is_with_amx), is_with_comp(is_with_comp), isa(primitive_isa), - hash(init_hash(dt_in0, dt_in1, beta, is_with_amx, is_with_comp, isa)) { + hash(init_hash(dt_in0, dt_in1, is_with_amx, is_with_comp, isa)) { } bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { #define EQ(X) X == rhs.X - return EQ(hash) && - EQ(dt_in0) && EQ(dt_in1) && EQ(beta) && - EQ(is_with_amx) && EQ(is_with_comp) && EQ(isa); + return EQ(hash) && EQ(dt_in0) 
-    return EQ(hash) &&
-           EQ(dt_in0) && EQ(dt_in1) && EQ(beta) &&
-           EQ(is_with_amx) && EQ(is_with_comp) && EQ(isa);
+    return EQ(hash) && EQ(dt_in0) && EQ(dt_in1) && EQ(is_with_amx) && EQ(is_with_comp) && EQ(isa);
 #undef EQ
 }
size_t BrgemmKernelConfig::compute_hash() const {
@@ -100,6 +101,7 @@ size_t BrgemmKernelConfig::compute_hash() const {
 #define HASH(X) seed = hash_combine(seed, X)
     HASH(m_M); HASH(m_N); HASH(m_K);
     HASH(m_LDA); HASH(m_LDB); HASH(m_LDC);
+    HASH(m_beta);
 #undef HASH
     return seed;
 }
@@ -110,7 +112,7 @@ std::string BrgemmKernelConfig::StaticParams::to_string() const {
     std::stringstream ss;
     PRINT(dt_in0); PRINT(dt_in1);
     PRINT(is_with_amx); PRINT(is_with_comp);
-    PRINT(beta); PRINT(isa);
+    PRINT(isa);
     return ss.str();
 }
 
@@ -119,6 +121,7 @@ std::string BrgemmKernelConfig::to_string() const {
     ss << m_static_params->to_string() << "\n";
     PRINT(m_M); PRINT(m_N); PRINT(m_K);
     PRINT(m_LDA); PRINT(m_LDB); PRINT(m_LDC);
+    PRINT(m_beta);
     return ss.str();
 }
 #undef PRINT
@@ -156,6 +159,30 @@ std::shared_ptr<BrgemmCompiledKernel> BrgemmKernelExecutor::compile_kernel(const BrgemmKernelConfig& config) const {
     return compiled_kernel;
 }
 
+float BrgemmKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id,
+                                     const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info) {
+    // Find all Expanded loops with the same Unified loop information -> they were decomposed from this Unified Loop.
+    // Note that LoopInfo objects are normalized and sorted (due to the NormalizedLoopIDs pass),
+    // so the Loops executed earlier have Loop IDs less than the current Loop ID.
+    // - If there is a Loop that is executed (work_amount > 0) and evaluated before the current one -> the current Brgemm should have `beta = 1`.
+    // - If there is no such Loop -> the current Brgemm should have `beta = 0`.
+    if (loop_id > 0) {
+        const auto& current_unified_loop_info = current_expanded_loop_info->get_unified_loop_info();
+        // Check the previous Loops
+        --loop_id;
+        while (loop_id >= 0) {
+            const auto& expanded_loop_info = loop_manager->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_id);
+            if (expanded_loop_info->get_unified_loop_info() != current_unified_loop_info)
+                return 0;
+            if (expanded_loop_info->get_work_amount() > 0) {
+                // There is a previously executed Brgemm (with `beta = 0`) -> the current Brgemm should have `beta = 1`
+                return 1;
+            }
+            --loop_id;
+        }
+    }
+    return 0;
+}
+
 void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
                                          const ov::snippets::lowered::LinearIRPtr& linear_ir,
                                          BrgemmKernelConfig& config) const {
@@ -169,32 +196,80 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
     auto in0_subtensor = input_pds[0]->get_subtensor();
     auto in1_subtensor = input_pds[1]->get_subtensor();
 
+    // We need to update M, K, N:
+    // 1. If the original value in the subtensor is `FULL_DIM`, the Brgemm block
+    //    processes the full tensor along this dimension -> take the dimension from the shape
+    // 2. Otherwise, the Brgemm block processes only a part of the tensor along this dimension
+    //    (there is blocking by this dimension) -> take the increment of the corresponding blocking Loop
     auto M = *++in0_subtensor.rbegin();
     auto K = *in0_subtensor.rbegin();
     auto N = *in1_subtensor.rbegin();
 
+    size_t loop_idx = 0;
+    const auto& loop_ids = expr->get_loop_ids();
+    const auto& loop_manager = linear_ir->get_loop_manager();
+    auto get_loop_info = [&](){
+        OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed");
+        return loop_manager->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_ids[loop_idx++]);
+    };
+
+    /* ------- Dimension M ----------*/
     if (ov::snippets::utils::is_full_dim_value(M)) {
         M = *++in0_shape.rbegin();
     } else {
-        const auto& loop_ids = expr->get_loop_ids();
-        OPENVINO_ASSERT(!loop_ids.empty(), "Loop by dimension M is missed");
-        // TODO [146125]: Loop by M is first one in `loop_ids`
-        const auto& expanded_loop_info = linear_ir->get_loop_manager()->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_ids.front());
-        M = expanded_loop_info->get_increment();
+        const auto& current_expanded_loop_info = get_loop_info();
+        const auto& in_ports = current_expanded_loop_info->get_input_ports();
+        const auto& out_ports = current_expanded_loop_info->get_output_ports();
+        // Quick validation check: Should we check that port is really Brgemm port?
+        // If BrgemmCopyB is in the Loop by M, the first input port will be the BrgemmCopyB one with `incremented = false`;
+        // to avoid extra checks, we validate only the first input port
+        OPENVINO_ASSERT(in_ports.size() > 1 && in_ports.front().is_incremented && in_ports.front().dim_idx == 1 &&
+                        out_ports.size() == 1 && out_ports.front().is_incremented && out_ports.front().dim_idx == 1,
+                        "Incorrect Loop by Brgemm dimension M");
+        M = current_expanded_loop_info->get_increment();
         input_pds[0]->set_subtensor_dim(1, M);
         output_pds[0]->set_subtensor_dim(1, M);
     }
 
-    if (ov::snippets::utils::is_full_dim_value(K)) {
-        K = *in0_shape.rbegin();
-    } else if (ov::snippets::utils::is_dynamic_value(K)) {
-        OPENVINO_THROW("Dynamic K is not supported");
-    }
-
+    /* ------- Dimension N ----------*/
     if (ov::snippets::utils::is_full_dim_value(N)) {
         N = *in1_shape.rbegin();
-    } else if (ov::snippets::utils::is_dynamic_value(N)) {
-        OPENVINO_THROW("Dynamic N is not supported");
+    } else {
+        const auto& current_expanded_loop_info = get_loop_info();
+        const auto& in_ports = current_expanded_loop_info->get_input_ports();
+        const auto& out_ports = current_expanded_loop_info->get_output_ports();
+        // Quick validation check: Should we check that port is really Brgemm port?
+        OPENVINO_ASSERT(in_ports.size() == 2 && !in_ports.front().is_incremented &&
+                        in_ports.back().is_incremented && in_ports.back().dim_idx == 0 &&
+                        out_ports.size() == 1 && out_ports.front().is_incremented && out_ports.front().dim_idx == 0,
+                        "Incorrect Loop by Brgemm dimension N");
+        N = current_expanded_loop_info->get_increment();
+        input_pds[1]->set_subtensor_dim(0, N);
+        output_pds[0]->set_subtensor_dim(0, N);
     }
 
+    /* ------- Dimension K ----------*/
+    // 1. If the Brgemm block processes the full dimension K -> `beta = 0`
+    // 2. If the Brgemm block processes only a part of the dimension K (there is blocking), we need to find
+    //    the first executed Brgemm block among the Loops that iterate over dimension K (work_amount > 0):
+    //    the first of them gets `beta = 0`, all the others get `beta = 1`
+    float beta = 0;
+    if (ov::snippets::utils::is_full_dim_value(K)) {
+        K = *in0_shape.rbegin();
+    } else {
+        const auto& current_expanded_loop_info = get_loop_info();
+        const auto& in_ports = current_expanded_loop_info->get_input_ports();
+        const auto& out_ports = current_expanded_loop_info->get_output_ports();
+        // Quick validation check: Should we check that port is really Brgemm port?
+        OPENVINO_ASSERT(in_ports.size() == 2 && in_ports.front().is_incremented && in_ports.front().dim_idx == 0 &&
+                        in_ports.back().is_incremented && in_ports.back().dim_idx == 1 &&
+                        out_ports.size() == 1 && !out_ports.front().is_incremented,
+                        "Incorrect Loop by Brgemm dimension K");
+        K = current_expanded_loop_info->get_increment();
+        input_pds[0]->set_subtensor_dim(0, K);
+        input_pds[1]->set_subtensor_dim(1, K);
+        if (K > 0)
+            beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
     }
 
     const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
@@ -206,7 +281,7 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
     if (with_repacking(brgemm_node->get_type()))
         LDB = brgemm_utils::repacking::compute_out_leading_dim(N, brgemm_node->get_input_element_type(1));
 
-    config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC);
+    config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
 }
 
 void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_args* args) {
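The beta rule in dnnl terms: `beta = 0` makes the kernel overwrite C, `beta = 1` makes it accumulate (C = beta * C + A * B), so only the first actually-executed block along K may overwrite. A standalone mirror of `get_beta()` with worked values — the decomposition layout is hypothetical but follows the FIRST_ITER/MAIN_BODY/LAST_ITER scheme used by snippets:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mirror of BrgemmKernelExecutor::get_beta(): `loops` holds the work amounts of
// the decomposed loop copies that share one unified K loop (in execution order);
// `idx` is the copy whose Brgemm call we are configuring. The real function also
// stops at a copy belonging to a different unified loop, which is out of scope here.
float beta_for(const std::vector<size_t>& loops, size_t idx) {
    for (size_t i = idx; i-- > 0;) {
        if (loops[i] > 0)
            return 1.f;  // an earlier copy already wrote into C -> accumulate
    }
    return 0.f;          // first executed copy -> overwrite C
}

int main() {
    // K = 1300, block = 512: FIRST_ITER(512) + MAIN_BODY(512) + LAST_ITER(276)
    assert(beta_for({512, 512, 276}, 0) == 0.f);
    assert(beta_for({512, 512, 276}, 1) == 1.f);
    assert(beta_for({512, 512, 276}, 2) == 1.f);
    // K = 300: only the first copy executes, the others get work_amount = 0
    assert(beta_for({300, 0, 0}, 0) == 0.f);
}
```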
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp
index 4dd52e21ca2dfd..b673c61d6d0aef 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp
@@ -9,11 +9,14 @@
 #include "emitters/snippets/cpu_kernel_executor_table.hpp"
 #include <cpu/x64/brgemm/brgemm.hpp>
 
+#include "snippets/lowered/loop_manager.hpp"
+#include "snippets/lowered/loop_info.hpp"
+
 namespace ov {
 namespace intel_cpu {
 
 struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig {
 public:
-    BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, float beta,
+    BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype,
                        bool is_with_amx, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa);
     BrgemmKernelConfig() = delete;
     bool is_completed() const override;
@@ -23,7 +26,7 @@ struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig {
     std::unique_ptr<GenericConfig> get_clone_ptr() const override {
        return std::unique_ptr<BrgemmKernelConfig>(new BrgemmKernelConfig(*this));
     }
-    void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC);
+    void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta);
     bool is_empty() const;
 
     dnnl_data_type_t get_dt_in0() const { return m_static_params->dt_in0; }
@@ -32,7 +35,7 @@ struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig {
     dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return m_static_params->isa; }
     bool is_with_amx() const { return m_static_params->is_with_amx; }
     bool is_with_comp() const { return m_static_params->is_with_comp; }
-    float get_beta() const { return m_static_params->beta; }
+    float get_beta() const { return m_beta; }
 
     dnnl_dim_t get_M() const { return m_M; }
     dnnl_dim_t get_N() const { return m_N; }
@@ -53,10 +56,9 @@ struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig {
 private:
     struct StaticParams {
-        StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, float beta,
+        StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype,
                      bool is_with_amx, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa);
         const dnnl_data_type_t dt_in0 {dnnl_f32}, dt_in1 {dnnl_f32};
-        const float beta {0};
         const bool is_with_amx {false};
         const bool is_with_comp {false};
         const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::isa_undef};
@@ -70,6 +72,7 @@ struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig {
     size_t compute_hash() const;
 
     std::shared_ptr<StaticParams> m_static_params;
     dnnl_dim_t m_M {0}, m_N {0}, m_K {0}, m_LDA {0}, m_LDB {0}, m_LDC {0};
+    float m_beta {0};
     size_t m_hash {SIZE_MAX};
 };
 
@@ -99,6 +102,9 @@ class BrgemmKernelExecutor : public CPUKernelExecutor<BrgemmKernelConfig, BrgemmCompiledKernel> {
     void update_config(const ov::snippets::lowered::ExpressionPtr& expr,
                        const ov::snippets::lowered::LinearIRPtr& linear_ir,
                        BrgemmKernelConfig& config) const override;
+
+    static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id,
+                          const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info);
 };
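Summary of the config split after this change: dtypes, AMX/compensation flags, and ISA stay in `StaticParams` (hashed once at construction), while beta joins M/N/K/LDA/LDB/LDC as mutable runtime state that is re-hashed on every `update()` and zeroed whenever any of M/N/K is 0. A hedged usage sketch — the concrete dimension, stride, and ISA values are illustrative only:

```cpp
using namespace ov;
using namespace dnnl::impl::cpu::x64;

// Static part is fixed when the emitter is created...
intel_cpu::BrgemmKernelConfig config(element::f32, element::f32,
                                     /*is_with_amx=*/false, /*is_with_comp=*/false,
                                     /*primitive_isa=*/avx512_core);
// ...while shapes, strides and now also beta are refreshed at runtime:
config.update(/*M=*/32, /*N=*/64, /*K=*/512, /*LDA=*/512, /*LDB=*/64, /*LDC=*/64, /*beta=*/0.f);
```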
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp
 BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, BRGEMM_TYPE type,
                      const size_t offset_a, const size_t offset_b, const size_t offset_c,
-                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
-                     const float beta)
+                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
     : Brgemm(), m_type(type) {
     // We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call
     set_arguments({A, B});
@@ -27,13 +26,11 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, BRGEMM_TYPE type,
     set_input_port_descriptor({0, offset_b}, 1);
     set_output_port_descriptor({0, offset_c}, 0);
     custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
-    set_beta(beta);
 }
 
 BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, BRGEMM_TYPE type,
                      const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c,
-                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
-                     const float beta)
+                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
     : Brgemm(), m_type(type) {
     set_arguments({A, B, scratch});
     set_output_size(1);
@@ -43,33 +40,28 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, BRGEMM_TYPE type,
     set_output_port_descriptor({0, offset_c}, 0);
     set_input_port_descriptor({0, offset_scratch}, 2);
     custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
-    set_beta(beta);
 }
 
 BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, BRGEMM_TYPE type,
                      const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
-                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
-                     const float beta)
+                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
     : Brgemm(), m_type(type) {
     set_arguments({A, B});
     set_output_size(1);
     m_input_ports = {{0, desc_a}, {1, desc_b}};
     m_output_ports = {{0, desc_c}};
     custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
-    set_beta(beta);
 }
 
 BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, BRGEMM_TYPE type,
                      const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c,
-                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
-                     const float beta)
+                     std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
     : Brgemm(), m_type(type) {
     set_arguments({A, B, scratch});
     set_output_size(1);
     m_input_ports = {{0, desc_a}, {1, desc_b}, {2, desc_scratch}};
     m_output_ports = {{0, desc_c}};
     custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
-    set_beta(beta);
 }
 
 void BrgemmCPU::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c) {
@@ -129,15 +121,13 @@ std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_args) const {
                                           get_input_port_descriptor(0), get_input_port_descriptor(1), get_output_port_descriptor(0),
                                           snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
                                           snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(),
-                                          snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(),
-                                          m_beta);
+                                          snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
     } else {
         return std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), new_args.at(2), m_type,
                                            get_input_port_descriptor(0), get_input_port_descriptor(1), get_input_port_descriptor(2), get_output_port_descriptor(0),
                                            snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
                                            snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(),
-                                           snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(),
-                                           m_beta);
+                                           snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
     }
 }
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp
index 9dc3a82bb2e1fa..a646ffc792fd6d 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp
@@ -26,20 +26,16 @@ class BrgemmCPU : public snippets::op::Brgemm {
     BrgemmCPU(const Output<Node>& A, const Output<Node>& B, BRGEMM_TYPE type,
               const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0,
-              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
-              const float beta = 1.f);
+              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
     BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, BRGEMM_TYPE type,
               const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0,
-              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
-              const float beta = 1.f);
+              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
     BrgemmCPU(const Output<Node>& A, const Output<Node>& B, BRGEMM_TYPE type,
               const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
-              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
-              const float beta = 1.f);
+              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
     BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, BRGEMM_TYPE type,
               const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c,
-              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
-              const float beta = 1.f);
+              std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
     BrgemmCPU() = default;
 
     void validate_and_infer_types() override;
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp
index a473cb45853ab7..0aa14a2a47749d 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp
@@ -24,6 +24,13 @@ using namespace ov::intel_cpu::brgemm_utils;
 using namespace ov::snippets::lowered;
 using namespace ov::snippets::utils;
 
+bool BrgemmCPUBlocking::DummyPass::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
+    return true;
+}
+std::shared_ptr<snippets::lowered::pass::PassBase> BrgemmCPUBlocking::DummyPass::merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) {
+    return !other || ov::is_type<DummyPass>(other) ? std::make_shared<DummyPass>() : nullptr;
+}
+
 LinearIR::constExprIt BrgemmCPUBlocking::move_new_memory_buffer(LinearIR& linear_ir, const LinearIR::constExprIt& brgemm_it) {
     const auto& brgemm_expr = brgemm_it->get();
     const auto wsp_expr = brgemm_expr->get_input_port_connector(2)->get_source().get_expr();
@@ -59,6 +66,12 @@ std::tuple<size_t, size_t, size_t> BrgemmCPUBlocking::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
     return blocking_params;
 }
 
+SpecificIterationHandlers BrgemmCPUBlocking::get_k_loop_handlers(size_t work_amount, size_t block_size) const {
+    SpecificIterationHandlers handlers = ov::snippets::lowered::pass::BrgemmBlockingBase::get_k_loop_handlers(work_amount, block_size);
+    handlers.register_pass<snippets::lowered::SpecificLoopIterType::FIRST_ITER, DummyPass>();
+    return handlers;
+}
+
 bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir,
                                             const LinearIR::constExprIt& brgemm_it,
                                             size_t m_block,
@@ -92,8 +105,6 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir,
                                             LoopPort(copy_b_expr->get_input_port(0), true, 1)};
         const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
         mark_k_blocking(loop_manager, loop_begin, std::next(brgemm_it), entries, exits, k_block);
-    } else {
-        brgemm->set_beta(0.f);
     }
     if (!is_full_dim_value(n_block)) {
         const auto loop_begin = get_loop_begin_pos(linear_ir, brgemm_it, copy_b_expr);
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp
index 466ce5d8a76148..fe3c4e5727bf5a 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp
@@ -20,6 +20,22 @@ class BrgemmCPUBlocking : public ov::snippets::lowered::pass::BrgemmBlocking<BrgemmCPU> {
+    class DummyPass : public snippets::lowered::pass::RangedPass {
+    public:
+        DummyPass() = default;
+        bool run(snippets::lowered::LinearIR& linear_ir,
+                 snippets::lowered::LinearIR::constExprIt begin,
+                 snippets::lowered::LinearIR::constExprIt end) override;
+        std::shared_ptr<snippets::lowered::pass::PassBase> merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) override;
+    };
+
 private:
     static snippets::lowered::LinearIR::constExprIt move_new_memory_buffer(snippets::lowered::LinearIR& linear_ir,
                                                                            const snippets::lowered::LinearIR::constExprIt& brgemm_it);
@@ -28,6 +44,8 @@ class BrgemmCPUBlocking : public ov::snippets::lowered::pass::BrgemmBlocking<BrgemmCPU> {
     std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) override;
 
+    SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const override;
+
     bool mark_blocking_loops(snippets::lowered::LinearIR& linear_ir,
                              const snippets::lowered::LinearIR::constExprIt& brgemm_it,
diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.cpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.cpp
index f009096289e716..d9f0bc947db958 100644
--- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.cpp
+++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.cpp
@@ -50,6 +50,7 @@ std::shared_ptr<Node> BrgemmTPP::clone_with_new_inputs(const OutputVector& new_args) const {
 }
 
 bool BrgemmTPP::visit_attributes(AttributeVisitor& visitor) {
+    visitor.on_attribute("beta", m_beta);
     TensorProcessingPrimitive::visit_attributes(visitor);
     return Brgemm::visit_attributes(visitor);
 }
diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.hpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.hpp
index 0eb01d158e1c89..c9199c3c7f82df 100644
--- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.hpp
+++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.hpp
@@ -35,6 +35,12 @@ class BrgemmTPP : virtual public modifier::TensorProcessingPrimitive, public snippets::op::Brgemm {
     std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
 
     bool visit_attributes(AttributeVisitor& visitor) override;
+
+    float get_beta() const { return m_beta; }
+    void set_beta(float beta) { m_beta = beta; }
+
+private:
+    float m_beta = 0.f;
 };
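Unlike the CPU Brgemm path, where beta is now a runtime value inside `BrgemmKernelConfig`, the TPP flavor keeps beta as a compile-time attribute that moves from the common `snippets::op::Brgemm` onto `BrgemmTPP` itself. A hedged sketch of the resulting flow — the constructor arguments are elided, and the emitter read-back at the end is an assumption, not shown in this patch:

```cpp
// BrgemmTPP now carries its own beta:
auto brgemm = std::make_shared<ov::intel_cpu::tpp::op::BrgemmTPP>(/* A, B, type, ... */);
brgemm->set_beta(1.f);            // copies of the K loop that accumulate into C
brgemm->set_beta(0.f);            // what SetBrgemmBeta (next file) applies on the FIRST_ITER copy
float beta = brgemm->get_beta();  // presumably consumed when the TPP kernel is generated
```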
diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.cpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.cpp
index 857cca8d3728c7..7d6f095c28ab98 100644
--- a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.cpp
+++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.cpp
@@ -18,6 +18,20 @@ namespace tpp {
 namespace pass {
 using namespace ov::snippets::utils;
 
+bool BrgemmTPPBlocking::SetBrgemmBeta::run(ov::snippets::lowered::LinearIR& linear_ir,
+                                           ov::snippets::lowered::LinearIR::constExprIt begin,
+                                           ov::snippets::lowered::LinearIR::constExprIt end) {
+    for (auto expr_it = begin; expr_it != end; ++expr_it) {
+        if (const auto brgemm = ov::as_type_ptr<tpp::op::BrgemmTPP>(expr_it->get()->get_node()))
+            brgemm->set_beta(0);
+    }
+    return true;
+}
+
+std::shared_ptr<ov::snippets::lowered::pass::PassBase> BrgemmTPPBlocking::SetBrgemmBeta::merge(const std::shared_ptr<ov::snippets::lowered::pass::PassBase>& other) {
+    return !other || ov::is_type<SetBrgemmBeta>(other) ? std::make_shared<SetBrgemmBeta>() : nullptr;
+}
+
 std::tuple<size_t, size_t, size_t> BrgemmTPPBlocking::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
     const auto& in_0_desc = brgemm_expr->get_input_port_descriptor(0);
     const auto& in_1_desc = brgemm_expr->get_input_port_descriptor(1);
@@ -38,6 +52,12 @@ std::tuple<size_t, size_t, size_t> BrgemmTPPBlocking::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
     const auto block_size_k = k > 1024 ? 1024 : k > 512 ? 512 : k;
     return std::make_tuple(block_size_m, block_size_n, block_size_k);
 }
+
+ov::snippets::lowered::SpecificIterationHandlers BrgemmTPPBlocking::get_k_loop_handlers(size_t work_amount, size_t block_size) const {
+    ov::snippets::lowered::SpecificIterationHandlers handlers =
+        ov::snippets::lowered::pass::BrgemmBlockingBase::get_k_loop_handlers(work_amount, block_size);
+    handlers.register_pass<ov::snippets::lowered::SpecificLoopIterType::FIRST_ITER, SetBrgemmBeta>();
+    return handlers;
+}
 } // namespace pass
 } // namespace tpp
 } // namespace intel_cpu
diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp
index ecaf4602e83c37..cba6d5f88adc8a 100644
--- a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp
+++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp
@@ -21,8 +21,25 @@ class BrgemmTPPBlocking : public ov::snippets::lowered::pass::BrgemmBlocking<ov::intel_cpu::tpp::op::BrgemmTPP> {
+    class SetBrgemmBeta : public ov::snippets::lowered::pass::RangedPass {
+    public:
+        SetBrgemmBeta() = default;
+        bool run(ov::snippets::lowered::LinearIR& linear_ir,
+                 ov::snippets::lowered::LinearIR::constExprIt begin,
+                 ov::snippets::lowered::LinearIR::constExprIt end) override;
+        std::shared_ptr<ov::snippets::lowered::pass::PassBase> merge(const std::shared_ptr<ov::snippets::lowered::pass::PassBase>& other) override;
+    };
+
 private:
     std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) override;
+    ov::snippets::lowered::SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const override;
 };
 
 }  // namespace pass
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
index db428abb8167d9..9249e5bbaa6bc7 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
@@ -89,16 +89,14 @@ std::vector<std::vector<ov::test::InputShape>> input_shapes_dynamic{
     },
     // Only K dimension is dynamic
     {
-        {PartialShape{2, 2, 70, -1}, {{2, 2, 70, 128}, {2, 2, 70, 10}, {2, 2, 70, 33},
-                                      {2, 2, 70, 35}, {2, 2, 70, 100}}},
-        {PartialShape{2, 2, -1, 70}, {{2, 2, 128, 70}, {2, 2, 10, 70}, {2, 2, 33, 70},
-                                      {2, 2, 35, 70}, {2, 2, 100, 70}}}
+        {PartialShape{2, 2, 70, -1}, {{2, 2, 70, 512}, {2, 2, 70, 10}, {2, 2, 70, 33}, {2, 2, 70, 2000}, {2, 2, 70, 35}, {2, 2, 70, 600}}},
+        {PartialShape{2, 2, -1, 70}, {{2, 2, 512, 70}, {2, 2, 10, 70}, {2, 2, 33, 70}, {2, 2, 2000, 70}, {2, 2, 35, 70}, {2, 2, 600, 70}}}
     },
     // Only N dimension is dynamic
     {
         {PartialShape{}, {{2, 2, 65, 550}}},
         {PartialShape{2, 2, 550, -1}, {{2, 2, 550, 70}, {2, 2, 550, 12}, {2, 2, 550, 70},
-                                       {2, 2, 550, 12}, {2, 2, 550, 10}}}
+                                       {2, 2, 550, 12}, {2, 2, 550, 10}, {2, 2, 550, 64}}}
     },
 };
diff --git a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp
index 052065c7e127d6..dfae3be2a73946 100644
--- a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp
+++ b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp
@@ -25,6 +25,16 @@ using BRGEMM_TYPE = intel_cpu::brgemm_utils::BRGEMM_TYPE;
 
 namespace {
+SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) {
+    auto handlers = BrgemmBlockingBase::get_default_blocking_loop_handlers(work_amount, block_size);
+#ifdef SNIPPETS_LIBXSMM_TPP
+    handlers.register_pass<SpecificLoopIterType::FIRST_ITER, tpp::pass::BrgemmTPPBlocking::SetBrgemmBeta>();
+#else
+    handlers.register_pass<SpecificLoopIterType::FIRST_ITER, intel_cpu::pass::BrgemmCPUBlocking::DummyPass>();
+#endif
+    return handlers;
+}
+
@@ -39,8 +49,7 @@ void create_brgemm_loop_infos(const LinearIRPtr& linear_ir,
                                        std::vector<LoopPort>{LoopPort(brgemm_expr->get_input_port(0)),
                                                              LoopPort(brgemm_expr->get_input_port(1), true, 1)},
                                        std::vector<LoopPort>{LoopPort(brgemm_expr->get_output_port(0), false)},
-                                       BrgemmBlockingBase::get_default_blocking_loop_handlers(k, k_block));
-        loop_info->register_pass_to_handler<SpecificLoopIterType::FIRST_ITER, SetBrgemmBeta>(0.f);
+                                       get_k_loop_handlers(k, k_block));
         linear_ir->get_loop_manager()->add_loop_info(loop_info);
     }
     if (n_block) {
@@ -78,8 +87,7 @@ void create_brgemm_with_copy_b_loop_infos(const LinearIRPtr& linear_ir,
                                        std::vector<LoopPort>{LoopPort(brgemm_expr->get_input_port(0)),
                                                              LoopPort(copy_b_expr->get_input_port(0), true, 1)},
                                        std::vector<LoopPort>{LoopPort(brgemm_expr->get_output_port(0), false)},
-                                       BrgemmBlockingBase::get_default_blocking_loop_handlers(k, k_block));
-        loop_info->register_pass_to_handler<SpecificLoopIterType::FIRST_ITER, SetBrgemmBeta>(0.f);
+                                       get_k_loop_handlers(k, k_block));
         linear_ir->get_loop_manager()->add_loop_info(loop_info);
     }
     if (n_block) {
@@ -173,7 +181,6 @@ TEST_F(BrgemmCPUBlockingTest, BlockingIsNotNeeded) {
     auto data_a = linear_ir_ref->push_node<ov::op::v0::Parameter>(precision, input_shape_a);
     auto data_b = linear_ir_ref->push_node<ov::op::v0::Parameter>(precision, input_shape_b);
     auto brgemm = linear_ir_ref->push_node<BrgemmCPU>(data_a.second, data_b.second, BRGEMM_TYPE::STAND_ALONE);
-    brgemm.second->set_beta(0.f);
     const auto full_subtensor = VectorDims(2, ov::snippets::utils::get_full_dim_value());
     init_expr_descriptors(*brgemm.first, std::vector<VectorDims>(3, full_subtensor));
     auto result = linear_ir_ref->push_node<ov::op::v0::Result>(brgemm.second);
@@ -211,7 +218,6 @@ TEST_F(BrgemmCPUBlockingTest, WithDataRepacking) {
         init_expr_descriptors(brgemm_expr, {{m_blk, full_dim}, {full_dim, full_dim}, {m_blk, full_dim}});
         create_brgemm_with_copy_b_loop_infos(linear_ir_ref, brgemm_expr, copy_b_expr, m, m_blk);
         brgemm_expr->set_loop_ids({0});
-        brgemm.second->set_beta(0.f);
         auto result = linear_ir_ref->push_node<ov::op::v0::Result>(brgemm.second);
     }
 }
@@ -247,7 +253,6 @@ TEST_F(BrgemmCPUBlockingTest, WithCompensations) {
         init_expr_descriptors(brgemm_expr, {{m_blk, full_dim}, {full_dim, full_dim}, {1, full_dim}, {m_blk, full_dim}});
         create_brgemm_loop_infos(linear_ir_ref, brgemm_expr, m, m_blk);
         brgemm_expr->set_loop_ids({0});
-        brgemm.second->set_beta(0.f);
         auto result = linear_ir_ref->push_node<ov::op::v0::Result>(brgemm.second);
     }
 }
@@ -285,7 +290,6 @@ TEST_F(BrgemmCPUBlockingTest, AMX) {
         init_expr_descriptors(brgemm_expr, {{m_blk, full_dim}, {full_dim, full_dim}, get_default_subtensor(), {m_blk, full_dim}});
         create_brgemm_with_copy_b_loop_infos(linear_ir_ref, brgemm_expr, copy_b_expr, m, m_blk);
         brgemm_expr->set_loop_ids({0});
-        brgemm.second->set_beta(0.f);
         auto result = linear_ir_ref->push_node<ov::op::v0::Result>(brgemm.second);
     }
 }