diff --git a/src/common/snippets/src/lowered/port_descriptor.cpp b/src/common/snippets/src/lowered/port_descriptor.cpp index c474fe4471fe8f..ca3ee158d50638 100644 --- a/src/common/snippets/src/lowered/port_descriptor.cpp +++ b/src/common/snippets/src/lowered/port_descriptor.cpp @@ -9,7 +9,8 @@ namespace ov { namespace snippets { namespace lowered { -size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX; +// FULL_DIM must differ from SIZE_MAX, because SIZE_MAX is the dynamic value +size_t PortDescriptor::ServiceDimensions::FULL_DIM = SIZE_MAX - 1; PortDescriptor::PortDescriptor(const ov::Input& in, VectorDims subtensor_shape, std::vector layout) : PortDescriptor(ov::Input(in.get_node(), in.get_index()), std::move(subtensor_shape), std::move(layout)) {} diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp index c4616707c367cf..5e4cb94d5bdd4f 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp @@ -9,6 +9,7 @@ #include "snippets/lowered/linear_ir.hpp" #include "snippets/lowered/loop_manager.hpp" #include "snippets/lowered/pass/pass.hpp" +#include "snippets/lowered/pass/propagate_subtensors.hpp" #include "snippets/snippets_isa.hpp" #include "snippets/utils.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" @@ -107,12 +108,20 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea const auto block_size_n = snippets::utils::is_dynamic_value(n) ? brgemm->get_n_block_size() : std::min(brgemm->get_n_block_size(), n); const auto block_size_k = snippets::utils::is_dynamic_value(k) ? 
brgemm->get_k_block_size() : std::min(brgemm->get_k_block_size(), k); - *++in_0_subtensor.rbegin() = block_size_m; - *++out_subtensor.rbegin() = block_size_m; - *in_1_subtensor.rbegin() = block_size_n; - *out_subtensor.rbegin() = block_size_n; - *in_0_subtensor.rbegin() = block_size_k; - *++in_1_subtensor.rbegin() = block_size_k; + // If block_size is dynamic, Brgemm will process the full tensor: + // subtensor[i] is left as FULL_DIM, its default value + if (!snippets::utils::is_dynamic_value(block_size_m)) { + *++in_0_subtensor.rbegin() = block_size_m; + *++out_subtensor.rbegin() = block_size_m; + } + if (!snippets::utils::is_dynamic_value(block_size_n)) { + *in_1_subtensor.rbegin() = block_size_n; + *out_subtensor.rbegin() = block_size_n; + } + if (!snippets::utils::is_dynamic_value(block_size_k)) { + *in_0_subtensor.rbegin() = block_size_k; + *++in_1_subtensor.rbegin() = block_size_k; + } brgemm_expr->get_input_port_descriptor(0)->set_subtensor(in_0_subtensor); brgemm_expr->get_input_port_descriptor(1)->set_subtensor(in_1_subtensor); @@ -141,6 +150,15 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea } } + auto get_default_handlers = [](size_t work_amount, size_t block_size) { + SpecificIterationHandlers handlers; + const auto tail_size = snippets::utils::is_dynamic_value(work_amount) ? 
snippets::utils::get_dynamic_value() : work_amount % block_size; + if (tail_size != 0) + handlers.register_pass(tail_size); + handlers.register_pass(true); + return handlers; + }; + auto mark_m_blocking = [&](bool include_repacking) { const auto loop_begin_it = get_loop_begin_pos(linear_ir, expr_it, include_repacking); const auto loop_end_it = std::next(expr_it); @@ -154,9 +172,8 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea entries.emplace_back(brgemm_expr->get_input_port(2), false); const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; - const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits); - const auto& loop_info = loop_manager->get_loop_info(id); - loop_info->register_pass_to_handler(true); + const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits, false); + loop_manager->get_loop_info(id)->set_handlers(get_default_handlers(m, block_size_m)); }; auto mark_n_blocking = [&]() { @@ -168,9 +185,8 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea LoopPort(brgemm_cpu && brgemm_cpu->is_with_data_repacking() ? copy_b_expr->get_input_port(0) : brgemm_expr->get_input_port(1), true)}; const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; - const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, n, block_size_n, 0, entries, exits); - const auto& loop_info = loop_manager->get_loop_info(id); - loop_info->register_pass_to_handler(true); + const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, n, block_size_n, 0, entries, exits, false); + loop_manager->get_loop_info(id)->set_handlers(get_default_handlers(n, block_size_n)); }; auto mark_k_blocking = [&]() { @@ -182,10 +198,11 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea LoopPort(brgemm_cpu && brgemm_cpu->is_with_data_repacking() ? 
copy_b_expr->get_input_port(0) : brgemm_expr->get_input_port(1), true, 1)}; const std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; - const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits); - const auto& loop_info = loop_manager->get_loop_info(id); - loop_info->register_pass_to_handler(0.f); - loop_info->register_pass_to_handler(true); + auto handlers = get_default_handlers(k, block_size_k); + handlers.register_pass(0.f); + + const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits, false); + loop_manager->get_loop_info(id)->set_handlers(handlers); }; const bool k_blocking = block_size_k != k;