From 8d938067d9796f25a24626ae6ff7523493c63cf9 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Mon, 11 Sep 2023 13:20:21 +0200 Subject: [PATCH] [LIR] Added custom first iteration handler for K blocking --- .../include/snippets/lowered/loop_manager.hpp | 20 +++ .../lowered/pass/insert_tail_loop.hpp | 3 +- .../snippets/src/lowered/loop_manager.cpp | 4 + .../src/lowered/pass/insert_tail_loop.cpp | 124 +++++++++++------- src/common/snippets/src/op/loop.cpp | 1 + .../x64/pass/lowered/brgemm_blocking.cpp | 35 ++++- .../snippets/matmul.cpp | 1 + 7 files changed, 137 insertions(+), 51 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/loop_manager.hpp b/src/common/snippets/include/snippets/lowered/loop_manager.hpp index c9e08cd6f3c177..4827a64a064709 100644 --- a/src/common/snippets/include/snippets/lowered/loop_manager.hpp +++ b/src/common/snippets/include/snippets/lowered/loop_manager.hpp @@ -53,6 +53,10 @@ class LinearIR::LoopManager { // Returns dimension index if dimension indices for all entry and exit points are equal, and SIZE_MAX otherwise size_t get_dim_idx() const; + using FirstIterHandler = std::function; + void set_first_iter_handler(FirstIterHandler handler); + FirstIterHandler fst_iter_handler = nullptr; + size_t work_amount = 0; size_t increment = 0; // The order of entry and exit expressions is important: @@ -115,6 +119,22 @@ class LinearIR::LoopManager { return loop_id; } + template + size_t mark_loop_with_old_loop_replacement(LinearIR::constExprIt loop_begin_pos, + LinearIR::constExprIt loop_end_pos, + size_t work_amount, + size_t increment, + const std::vector& entries, + const std::vector& exits, + const size_t old_id) { + const auto loop_info = std::make_shared(work_amount, increment, entries, exits); + const auto loop_id = this->add_loop_info(loop_info); + for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) { + replace_loop_id(*expr_it, old_id, loop_id); + } + return loop_id; + } + void fuse_loops(const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper = true); void fuse_loops(LinearIR::constExprIt loop_begin_target, LinearIR::constExprIt loop_end_target, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper = true); diff --git a/src/common/snippets/include/snippets/lowered/pass/insert_tail_loop.hpp b/src/common/snippets/include/snippets/lowered/pass/insert_tail_loop.hpp index 8801d4c7130ec4..78b1db9cba524c 100644 --- a/src/common/snippets/include/snippets/lowered/pass/insert_tail_loop.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/insert_tail_loop.hpp @@ -23,6 +23,8 @@ class InsertTailLoop : public Pass { public: OPENVINO_RTTI("InsertTailLoop", "Pass") bool run(LinearIR& linear_ir) override; + static bool optimize_single_evaluation(const std::shared_ptr& loop); + static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id); private: static std::shared_ptr create_tail_loop(LinearIR& linear_ir, @@ -37,7 +39,6 @@ class InsertTailLoop : public Pass { LinearIR::constExprIt tail_begin, LinearIR::constExprIt tail_end, size_t tail_size); - static bool optimize_single_evaluation(const std::shared_ptr& loop); }; } // namespace pass diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index e7a14a677921e0..0183268e5f8720 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ b/src/common/snippets/src/lowered/loop_manager.cpp @@ -41,6 +41,10 @@ size_t LinearIR::LoopManager::LoopInfo::get_dim_idx() const { } } +void LinearIR::LoopManager::LoopInfo::set_first_iter_handler(FirstIterHandler handler) { + fst_iter_handler = std::move(handler); +} + bool operator==(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopManager::LoopPort& rhs) { if (&lhs == &rhs) return true; diff --git a/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp b/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp index 9112593962cd65..8e38e379971b48 100644 --- a/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp +++ b/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp @@ -14,21 +14,60 @@ namespace ov { namespace snippets { namespace lowered { namespace pass { +LinearIR::container InsertTailLoop::copy_loop(const LinearIR& linear_ir, const size_t loop_id) { + const auto& loop_manager = linear_ir.get_loop_manager(); + const auto original_loop_info = loop_manager->get_loop_info(loop_id); + auto new_entry_points = original_loop_info->entry_points; + auto new_exit_points = original_loop_info->exit_points; -namespace { -void replace_ports_with_tail_ports(const ExpressionPtr& expr, - const ExpressionPtr& tail_expr, - std::vector& ports) { - auto find_if_predicate = [&](const LinearIR::LoopManager::LoopPort& port) { - return port.expr_port->get_expr()->get_node() == expr->get_node(); + auto update_loop_ports = [](const ExpressionPtr& expr, + const ExpressionPtr& tail_expr, + std::vector& ports) { + auto find_if_predicate = [&](const LinearIR::LoopManager::LoopPort& port) { + return port.expr_port->get_expr()->get_node() == expr->get_node(); + }; + auto pos = std::find_if(ports.begin(), ports.end(), find_if_predicate); + while (pos != ports.end()) { + pos->expr_port = std::make_shared(tail_expr, pos->expr_port->get_type(), pos->expr_port->get_index()); + pos = std::find_if(pos, ports.end(), find_if_predicate); + } }; - auto pos = std::find_if(ports.begin(), ports.end(), find_if_predicate); - while (pos != ports.end()) { - pos->expr_port = std::make_shared(tail_expr, pos->expr_port->get_type(), pos->expr_port->get_index()); - pos = std::find_if(pos, ports.end(), find_if_predicate); - } + + auto update_loop_info = [&](const ExpressionPtr& expr, const ExpressionPtr& new_expr) { + const auto node = expr->get_node(); + // Loop begin/end ops can't be loop ports + if (ov::is_type(node)) + return; + // Clone loop ports from original loop info to tail loop info + update_loop_ports(expr, new_expr, new_entry_points); + update_loop_ports(expr, new_expr, new_exit_points); + + // Update loop info of all inner loops with new loop ports + const auto loop_ids = expr->get_loop_ids(); + auto cur_id_pos = std::find(loop_ids.begin(), loop_ids.end(), loop_id); + std::vector inner_loop_ids(loop_ids.begin(), cur_id_pos); + for (size_t i = 0; i < expr->get_input_count(); ++i) + loop_manager->update_loops_port(inner_loop_ids, expr->get_input_port(i), {expr->get_input_port(i), new_expr->get_input_port(i)}, true); + for (size_t i = 0; i < expr->get_output_count(); ++i) + loop_manager->update_loops_port(inner_loop_ids, expr->get_output_port(i), {expr->get_output_port(i), new_expr->get_output_port(i)}, false); + }; + + LinearIR::constExprIt loop_begin_pos, loop_end_pos; + loop_manager->get_loop_bounds(linear_ir, loop_id, loop_begin_pos, loop_end_pos, true); + const auto loop_copy_range = LinearIR::deep_copy_range(loop_begin_pos, std::next(loop_end_pos), update_loop_info); + const auto new_loop_begin_pos = loop_copy_range.begin(); + const auto new_loop_end_pos = loop_copy_range.end(); + const auto new_id = loop_manager->mark_loop_with_old_loop_replacement(std::next(new_loop_begin_pos), + std::prev(new_loop_end_pos), + original_loop_info->work_amount, + original_loop_info->increment, + new_entry_points, + new_exit_points, + loop_id); + const auto loop_end = ov::as_type_ptr(std::prev(new_loop_end_pos)->get()->get_node()); + loop_end->set_id(new_id); + return loop_copy_range; } -} // namespace std::shared_ptr InsertTailLoop::create_tail_loop(LinearIR& linear_ir, LinearIR::constExprIt vector_begin, @@ -47,41 +86,16 @@ std::shared_ptr InsertTailLoop::create_tail_loop(LinearIR& linear_i const auto& original_loop_info = loop_manager->get_loop_info(original_loop_id); auto tail_loop_info = original_loop_info; if (need_vector_loop) { - auto tail_entry_points = original_loop_info->entry_points; - auto tail_exit_points = original_loop_info->exit_points; - - auto update_loop_info = [&](const ExpressionPtr& expr, const ExpressionPtr& new_expr) { - const auto node = expr->get_node(); - // Loop begin/end ops can't be loop ports - if (ov::is_type(node)) - return; - // Clone loop ports from original loop info to tail loop info - replace_ports_with_tail_ports(expr, new_expr, tail_entry_points); - replace_ports_with_tail_ports(expr, new_expr, tail_exit_points); + const auto new_loop_range = copy_loop(linear_ir, original_loop_id); + const auto loop_end = ov::as_type_ptr(std::prev(new_loop_range.end())->get()->get_node()); + loop_end->set_work_amount(tail_size); + loop_end->set_increment(tail_size); + tail_loop_info = loop_manager->get_loop_info(loop_end->get_id()); + tail_loop_info->work_amount = tail_size; + tail_loop_info->increment = tail_size; - // Update loop info of all inner loops with new loop ports - const auto loop_ids = expr->get_loop_ids(); - auto cur_id_pos = std::find(loop_ids.begin(), loop_ids.end(), original_loop_id); - std::vector inner_loop_ids(loop_ids.begin(), cur_id_pos); - for (size_t i = 0; i < expr->get_input_count(); ++i) - loop_manager->update_loops_port(inner_loop_ids, expr->get_input_port(i), {expr->get_input_port(i), new_expr->get_input_port(i)}, true); - for (size_t i = 0; i < expr->get_output_count(); ++i) - loop_manager->update_loops_port(inner_loop_ids, expr->get_output_port(i), {expr->get_output_port(i), new_expr->get_output_port(i)}, false); - }; - auto vector_loop_deep_copy = LinearIR::deep_copy_range(vector_begin, vector_end, update_loop_info); - tail_begin = linear_ir.insert(vector_end, vector_loop_deep_copy.begin(), vector_loop_deep_copy.end()); + tail_begin = linear_ir.insert(vector_end, new_loop_range.begin(), new_loop_range.end()); tail_end = vector_end; - - const auto new_id = loop_manager->mark_loop(std::next(tail_begin), - std::prev(tail_end), - tail_size, - tail_size, - tail_entry_points, - tail_exit_points); - const auto loop_begin = ov::as_type_ptr(tail_begin->get()->get_node()); - const auto loop_end = loop_begin->get_loop_end(); - loop_end->set_id(new_id); - tail_loop_info = loop_manager->get_loop_info(new_id); } else { tail_begin = vector_begin; tail_end = vector_end; @@ -140,6 +154,9 @@ std::shared_ptr InsertTailLoop::create_tail_loop(LinearIR& linear_i tail_loop_end->set_work_amount(tail_size); tail_loop_end->set_finalization_offsets(tail_finalization_offsets); tail_loop_end->has_outer_loop = vector_loop_end->has_outer_loop; + const auto new_vector_loop_wa = original_loop_info->work_amount - tail_size; + original_loop_info->work_amount = new_vector_loop_wa; + vector_loop_end->set_work_amount(new_vector_loop_wa); return tail_loop_end; } @@ -171,21 +188,23 @@ void InsertTailLoop::tail_transformations(LinearIR& linear_ir, // correct math calculations for ReduceMax and ReduceSum in scalar case. // Note: We find Maximum and Add ops because HorizonMax and HorizonSum are outside Loop, // so they are missed in - auto op = (*expr_it)->get_node(); + const auto& expr = *expr_it; + const auto op = expr->get_node(); if (config.m_need_fill_tail_register && (ov::is_type(op) || ov::is_type(op))) { for (size_t i = 0; i < op->inputs().size(); ++i) { if (auto fill = insertFill(op->input(i))) { - const auto& input = expr_it->get()->get_input_port_connector(i); + const auto& input = expr->get_input_port_connector(i); const auto consumers = input->get_consumers(); auto fill_expr = linear_ir.create_expression(fill, {input}); linear_ir.insert(expr_it, fill_expr); linear_ir.replace_input(consumers, fill_expr->get_output_port_connector(0)); // in_reg == out_reg since we want to modify vector reg inplace - const auto reg = expr_it->get()->get_input_port_descriptor(0)->get_reg(); + const auto reg = expr->get_input_port_descriptor(0)->get_reg(); fill_expr->get_input_port_descriptor(0)->set_reg(reg); fill_expr->get_output_port_descriptor(0)->set_reg(reg); + fill_expr->set_loop_ids(expr->get_loop_ids()); } } } else if (const auto memory_access = std::dynamic_pointer_cast(op)) { @@ -239,9 +258,16 @@ bool InsertTailLoop::run(LinearIR& linear_ir) { if (!loop_end) continue; + const auto loop_info = loop_manager->get_loop_info(loop_end->get_id()); + if (loop_info->fst_iter_handler != nullptr) { + modified |= loop_info->fst_iter_handler(linear_ir, expr_it); + } + + if (loop_end->get_evaluate_once() == true) + continue; + const auto work_amount = loop_end->get_work_amount(); const auto increment = loop_end->get_increment(); - const auto loop_info = loop_manager->get_loop_info(loop_end->get_id()); const auto tail_size = work_amount % increment; const auto need_tail = tail_size != 0; const auto need_vector_loop = work_amount >= increment; diff --git a/src/common/snippets/src/op/loop.cpp b/src/common/snippets/src/op/loop.cpp index be42004275634d..ac31459eb55e6b 100644 --- a/src/common/snippets/src/op/loop.cpp +++ b/src/common/snippets/src/op/loop.cpp @@ -198,6 +198,7 @@ bool LoopEnd::visit_attributes(AttributeVisitor &visitor) { visitor.on_attribute("input_num", m_input_num); visitor.on_attribute("output_num", m_output_num); visitor.on_attribute("id", m_id); + visitor.on_attribute("evaluate_once", m_evaluate_once); return true; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp index 8de4a0c09e1711..701e37ad4961c3 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp @@ -9,6 +9,7 @@ #include "snippets/itt.hpp" #include "snippets/lowered/linear_ir.hpp" #include "snippets/lowered/loop_manager.hpp" +#include "snippets/lowered/pass/insert_tail_loop.hpp" #include "snippets/snippets_isa.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" @@ -124,7 +125,39 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { entries.emplace_back(brgemm_expr->get_input_port(2), true, 1); std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; auto loop_id = loop_manager->mark_loop(expr_it, std::next(expr_it), k, block_size_k, entries, exits); - loop_manager->get_loop_info(loop_id)->brgemm_k_blocking_loop = true; + const auto loop_info = loop_manager->get_loop_info(loop_id); + loop_info->brgemm_k_blocking_loop = true; + + auto first_iter_handler = [](LinearIR& linear_ir, LinearIR::constExprIt expr_it) { + const auto loop_end = ov::as_type_ptr(expr_it->get()->get_node()); + const auto loop_id = loop_end->get_id(); + const auto& loop_manager = linear_ir.get_loop_manager(); + const auto& loop_info = loop_manager->get_loop_info(loop_id); + const auto work_amount = loop_info->work_amount; + const auto increment = loop_info->increment; + if (!loop_info->brgemm_k_blocking_loop || work_amount <= increment) + return false; + + auto new_loop_range = snippets::lowered::pass::InsertTailLoop::copy_loop(linear_ir, loop_id); + const auto new_loop_end = ov::as_type_ptr(std::prev(new_loop_range.end())->get()->get_node()); + new_loop_end->set_work_amount(work_amount - increment); + auto new_loop_info = loop_manager->get_loop_info(new_loop_end->get_id()); + new_loop_info->work_amount = work_amount - increment; + for (const auto expr : new_loop_range) { + if (const auto brgemm = ov::as_type_ptr(expr->get_node())) { + brgemm->set_beta(1.f); + } + } + + linear_ir.insert(std::next(expr_it), new_loop_range.begin(), new_loop_range.end()); + + loop_info->work_amount = increment; + loop_end->set_work_amount(increment); + loop_end->set_finalization_offsets(std::vector(loop_end->get_finalization_offsets().size(), 0)); + snippets::lowered::pass::InsertTailLoop::optimize_single_evaluation(loop_end); + return true; + }; + loop_info->set_first_iter_handler(first_iter_handler); } }; diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index 630cf4374ac448..9dd40aafdf23c6 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -20,6 +20,7 @@ std::vector> input_shapes{ {{1, 1, 32, 23}, {1, 1, 23, 68}}, {{1, 16, 384, 64}, {1, 16, 64, 384}}, {{1, 1, 100, 700}, {1, 1, 700, 100}}, + {{1, 1, 100, 2500}, {1, 1, 2500, 100}}, }; static inline std::vector> quantized_precisions() {