diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp index ee34b2f9076abe..1fc393252638d5 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp @@ -24,6 +24,18 @@ using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; BrgemmBlocking::BrgemmBlocking() : Pass() {} +void BrgemmBlocking::move_amx_scratchpad_buffer(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it) { + const auto& brgemm_expr = brgemm_it->get(); + const auto wsp_expr = brgemm_expr->get_input_port_connector(2)->get_source().get_expr(); + const auto wsp_buffer = ov::as_type_ptr(wsp_expr->get_node()); + OPENVINO_ASSERT(wsp_buffer && wsp_buffer->is_new_memory(), "Incorrect Scratchpad buffer for Brgemm AMX"); + // If scratchpad with temp memory is not explicitly before Brgemm, need to move to there. + if (wsp_expr != *std::prev(brgemm_it)) { + const auto wsp_it = linear_ir.find(wsp_expr); + linear_ir.move(wsp_it, brgemm_it); + } +} + bool BrgemmBlocking::run(LinearIR& linear_ir) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmBlocking") if (linear_ir.empty()) @@ -75,12 +87,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { *(in_0_subtensor.rbegin() + 1) = block_size_m; *(out_subtensor.rbegin() + 1) = block_size_m; + auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it); std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true), LoopPort(brgemm_expr->get_input_port(1), false)}; - if (brgemm->is_with_scratchpad()) + if (brgemm->is_with_compensations()) { entries.emplace_back(brgemm_expr->get_input_port(2), false); + } else if (brgemm->is_amx()) { + move_amx_scratchpad_buffer(linear_ir, expr_it); + loop_begin_it = std::prev(expr_it); + } std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; - loop_manager->mark_loop(expr_it, std::next(expr_it), m, block_size_m, 1, entries, exits); + loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits); } }; @@ -94,15 +111,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { *in_1_subtensor.rbegin() = block_size_n; *out_subtensor.rbegin() = block_size_n; + auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it); std::vector entries{LoopPort(brgemm_expr->get_input_port(0), false), LoopPort(brgemm_expr->get_input_port(1), true)}; - if (brgemm->is_with_scratchpad()) { - // The second input of Brgemm for AMX case is scratch buffer so it mustn't be incremented - const bool is_incremented = brgemm->is_with_compensations() ? true : false; - entries.emplace_back(brgemm_expr->get_input_port(2), is_incremented); + if (brgemm->is_with_compensations()) { + entries.emplace_back(brgemm_expr->get_input_port(2), true); + } else if (brgemm->is_amx()) { + move_amx_scratchpad_buffer(linear_ir, expr_it); + loop_begin_it = std::prev(expr_it); } std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; - loop_manager->mark_loop(expr_it, std::next(expr_it), n, block_size_n, 0, entries, exits); + loop_manager->mark_loop(loop_begin_it, loop_end_it, n, block_size_n, 0, entries, exits); } }; @@ -117,12 +136,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { *in_0_subtensor.rbegin() = block_size_k; *(in_1_subtensor.rbegin() + 1) = block_size_k; + auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it); std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true, 0), LoopPort(brgemm_expr->get_input_port(1), true, 1)}; - if (brgemm->is_with_scratchpad()) + if (brgemm->is_with_compensations()) { entries.emplace_back(brgemm_expr->get_input_port(2), false, 1); + } else if (brgemm->is_amx()) { + move_amx_scratchpad_buffer(linear_ir, expr_it); + loop_begin_it = std::prev(expr_it); + } std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; - auto loop_id = loop_manager->mark_loop(expr_it, std::next(expr_it), k, block_size_k, entries, exits); + auto loop_id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits); const auto loop_info = loop_manager->get_loop_info(loop_id); auto first_iter_handler = [](LinearIR& linear_ir, LinearIR::constExprIt loop_end_it) { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp index a720664076d9db..d6a5bd90ba961b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp @@ -21,6 +21,9 @@ class BrgemmBlocking : public snippets::lowered::pass::Pass { OPENVINO_RTTI("BrgemmBlocking", "Pass") BrgemmBlocking(); bool run(snippets::lowered::LinearIR& linear_ir) override; + +private: + static void move_amx_scratchpad_buffer(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it); }; } // namespace pass