From 4e960a3d1880dc552d645bf2cbed39f37d829972 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Wed, 21 Jun 2023 10:14:45 +0400 Subject: [PATCH] [Snippets] Applied Ivan comments --- .../include/snippets/lowered/linear_ir.hpp | 37 ++++++++++-- .../snippets/lowered/pass/insert_buffers.hpp | 2 +- src/common/snippets/src/lowered/linear_ir.cpp | 56 +++++++++++++++++++ .../snippets/src/lowered/loop_manager.cpp | 8 +-- .../src/lowered/pass/insert_buffers.cpp | 24 ++++---- .../src/lowered/pass/insert_load_store.cpp | 8 +-- .../load_movebroadcast_to_broadcastload.cpp | 4 +- 7 files changed, 106 insertions(+), 33 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 3ea2464829fa0c..60000a95fc520c 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -31,6 +31,8 @@ class LinearIR { using io_container = std::list>; using exprIt = container::iterator; using constExprIt = container::const_iterator; + using exprReverseIt = container::reverse_iterator; + using constExprReverseIt = container::const_reverse_iterator; LinearIR() = default; explicit LinearIR(const std::shared_ptr& m, Config config = {}); @@ -69,10 +71,10 @@ class LinearIR { constExprIt end() const noexcept {return cend();} constExprIt cbegin() const noexcept {return m_expressions.cbegin();} constExprIt cend() const noexcept {return m_expressions.cend();} - container::reverse_iterator rbegin() noexcept {return m_expressions.rbegin();} - container::reverse_iterator rend() noexcept {return m_expressions.rend();} - container::const_reverse_iterator crbegin() const noexcept {return m_expressions.crbegin();} - container::const_reverse_iterator crend() const noexcept {return m_expressions.crend();} + exprReverseIt rbegin() noexcept {return m_expressions.rbegin();} + exprReverseIt rend() noexcept {return m_expressions.rend();} + constExprReverseIt crbegin() const noexcept {return m_expressions.crbegin();} + constExprReverseIt crend() const noexcept {return m_expressions.crend();} exprIt insert(constExprIt pos, const ov::NodeVector& nodes); exprIt insert(constExprIt pos, const std::shared_ptr& n); @@ -84,6 +86,21 @@ class LinearIR { exprIt erase(exprIt pos); exprIt erase(constExprIt pos); + template + iterator find(iterator begin, iterator end, const ExpressionPtr& target); + template + const_iterator find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const; + exprIt find(const ExpressionPtr& target); + constExprIt find(const ExpressionPtr& target) const; + template + iterator find_before(iterator it, const ExpressionPtr& target); + template + const_iterator find_before(const_iterator it, const ExpressionPtr& target) const; + template + iterator find_after(iterator it, const ExpressionPtr& target); + template + const_iterator find_after(const_iterator it, const ExpressionPtr& target) const; + void init_emitters(const std::shared_ptr& target); void serialize(const std::string& xml, const std::string& bin); @@ -107,6 +124,18 @@ class LinearIR { LoopManagerPtr m_loop_manager = nullptr; }; +template +iterator LinearIR::find(iterator begin, iterator end, const ExpressionPtr& target) { + auto found = std::find(begin, end, target); + OPENVINO_ASSERT(found != end, "Expression has not been found"); + return found; +} +template +const_iterator LinearIR::find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const { + auto found = std::find(begin, end, target); + OPENVINO_ASSERT(found != end, "Expression has not been found"); + return found; +} } // namespace lowered } // namespace snippets } // namespace ov diff --git a/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp b/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp index 2e4b8b4489d80c..004ea711288ab2 100644 --- a/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp @@ -28,7 +28,7 @@ class InsertBuffers : public Pass { bool run(LinearIR& linear_ir) override; private: - void insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager, + void insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager, const std::vector& loop_entries, const std::vector& loop_exits); diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 245238f328460f..f90ae20e2f90fa 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -260,6 +260,62 @@ void LinearIR::move(LinearIR::constExprIt from, LinearIR::constExprIt to) { m_expressions.splice(to, m_expressions, from); } +LinearIR::exprIt LinearIR::find(const ExpressionPtr& target) { + return find(begin(), end(), target); +} +LinearIR::constExprIt LinearIR::find(const ExpressionPtr& target) const { + return find(cbegin(), cend(), target); +} +template<> +LinearIR::exprIt LinearIR::find_before(LinearIR::exprIt it, const ExpressionPtr& target) { + return find(begin(), it, target); +} +template<> +LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) { + return find(cbegin(), it, target); +} +template<> +LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) const { + return find(cbegin(), it, target); +} +template<> +LinearIR::exprReverseIt LinearIR::find_before(LinearIR::exprReverseIt it, const ExpressionPtr& target) { + return find(rbegin(), it, target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) { + return find(crbegin(), it, target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const { + return find(crbegin(), it, target); +} +template<> +LinearIR::exprIt LinearIR::find_after(LinearIR::exprIt it, const ExpressionPtr& target) { + return find(it, end(), target); +} +template<> +LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) { + return find(it, cend(), target); +} +template<> +LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) const { + return find(it, cend(), target); +} +template<> +LinearIR::exprReverseIt LinearIR::find_after(LinearIR::exprReverseIt it, const ExpressionPtr& target) { + return find(it, rend(), target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) { + return find(it, crend(), target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const { + return find(it, crend(), target); +} + + }// namespace lowered }// namespace snippets }// namespace ov diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index 928b846742f290..90c003b33d8173 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ b/src/common/snippets/src/lowered/loop_manager.cpp @@ -87,8 +87,7 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir, OPENVINO_ASSERT(!entries.empty(), "Loop must have entry points"); OPENVINO_ASSERT(!exits.empty(), "Loop must have entry points"); const auto& entry_expr = entries.front().expr_port->get_expr(); - loop_begin_pos = std::find(linear_ir.begin(), linear_ir.end(), entry_expr); - OPENVINO_ASSERT(loop_begin_pos != linear_ir.end(), "Loop begin hasn't been found!"); + loop_begin_pos = linear_ir.find(entry_expr); // Some operations in Loop can be before first entry points: Scalars, VectorBuffer. // We should iterate by them till the expr is in the corresponding Loop @@ -104,12 +103,11 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir, const auto loop_end = loop_begin->get_loop_end(); OPENVINO_ASSERT(loop_end->get_id() == loop_id, "Failed explicit loop bounds getting: Loop bounds with correct ID have not been found"); loop_begin_pos = std::prev(loop_begin_pos); - loop_end_pos = std::find(loop_begin_pos, linear_ir.end(), linear_ir.get_expr_by_node(loop_end)); + loop_end_pos = linear_ir.find_after(loop_begin_pos, linear_ir.get_expr_by_node(loop_end)); } else { // At the moment all Loops must have exit points const auto& exit_expr = exits.back().expr_port->get_expr(); - loop_end_pos = std::next(std::find(loop_begin_pos, linear_ir.end(), exit_expr)); - OPENVINO_ASSERT(loop_end_pos != linear_ir.end(), "Loop end hasn't been found!"); + loop_end_pos = std::next(linear_ir.find_after(loop_begin_pos, exit_expr)); } } diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 2de97447dea436..708f4520654bf6 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -83,15 +83,11 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i const auto down_loops = down_expr->get_loop_ids(); // If upper expression is out of Loop, we can insert Buffer implicitly after him if (up_loops.empty()) { - const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), up_expr); - OPENVINO_ASSERT(it != linear_ir.cend(), "Upper expression hasn't been found to insert Buffer after him!"); - return std::next(it); + return std::next(linear_ir.find(up_expr)); } // If lower expression is out of Loop, we can insert Buffer implicitly before him if (down_loops.empty()) { - const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), down_expr); - OPENVINO_ASSERT(it != linear_ir.cend(), "Lower expression hasn't been found to insert Buffer after him!"); - return it; + return linear_ir.find(down_expr); } const auto up_loop_count = up_loops.size(); @@ -120,7 +116,7 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i OPENVINO_THROW("Incorrect configuration for Buffer insertion!"); } -void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager, +void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager, const std::vector& loop_entries, const std::vector& loop_exits) { for (const auto& entry_point : loop_entries) { @@ -221,21 +217,20 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt const auto buffer_consumers_inputs = buffer_out->get_consumers(); linear_ir.replace_input(buffer_consumers_inputs, output_connector); potential_consumers.insert(buffer_consumers_inputs.begin(), buffer_consumers_inputs.end()); - const auto buffer_pos = std::find(linear_ir.cbegin(), linear_ir.cend(), buffer); - OPENVINO_ASSERT(buffer_pos != linear_ir.cend(), "Buffer has not been found in Linear IR"); - linear_ir.erase(buffer_pos); + linear_ir.erase(linear_ir.find_after(expr_it, buffer)); } } // potential_consumers is unsorted by linear IR set. // We have to find first expr in Linear IR from the set to insert Buffer before *all* consumers + // [113536]: Remove this logic with `std::find` using, when expression numeration will be supported OPENVINO_ASSERT(!potential_consumers.empty(), "Buffer should have one consumer at least"); auto consumer_expr = potential_consumers.begin()->get_expr(); if (potential_consumers.size() > 1) { std::set consumers; for (const auto& port : potential_consumers) consumers.insert(port.get_expr()); - const auto it = std::find_if(linear_ir.cbegin(), linear_ir.cend(), + const auto it = std::find_if(expr_it, linear_ir.cend(), [&consumers](const ExpressionPtr& expr) { return consumers.count(expr) > 0; }); OPENVINO_ASSERT(it != linear_ir.cend(), "Consumer of Buffer has not been found in Linear IR"); consumer_expr = *it; @@ -282,10 +277,11 @@ bool InsertBuffers::run(LinearIR& linear_ir) { const auto loop_info = loop_data.second; const auto loop_entries = loop_info->entry_points; const auto loop_exits = loop_info->exit_points; - insertion(linear_ir, loop_manager, loop_entries, loop_exits); + // using begin() as expr_it because we work with LoopInfo, not expressions in Linear IR + insertion(linear_ir, linear_ir.cbegin(), loop_manager, loop_entries, loop_exits); } - for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { + for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) { const auto expr = *expr_it; const auto node = (*expr_it)->get_node(); const auto ma = ov::as_type_ptr(node); @@ -302,7 +298,7 @@ bool InsertBuffers::run(LinearIR& linear_ir) { loop_exits[p.first] = expr->get_output_port(p.first); } - insertion(linear_ir, loop_manager, loop_entries, loop_exits); + insertion(linear_ir, expr_it, loop_manager, loop_entries, loop_exits); } return true; diff --git a/src/common/snippets/src/lowered/pass/insert_load_store.cpp b/src/common/snippets/src/lowered/pass/insert_load_store.cpp index ab4b362484a080..40f802a649a9e9 100644 --- a/src/common/snippets/src/lowered/pass/insert_load_store.cpp +++ b/src/common/snippets/src/lowered/pass/insert_load_store.cpp @@ -49,9 +49,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr const auto load = std::make_shared(data_node->output(0), get_count(data_expr->get_output_port_descriptor(0))); PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone()); const auto load_expr = linear_ir.create_expression(load, {output_connector}); - const auto insertion_pos = std::find(data_expr_it, linear_ir.cend(), consumer_expr); - OPENVINO_ASSERT(insertion_pos != linear_ir.cend(), "Consumer should be after data producer in Linear IR"); - linear_ir.insert(insertion_pos, load_expr); + linear_ir.insert(linear_ir.find_after(data_expr_it, consumer_expr), load_expr); linear_ir.replace_input(consumer_input, load_expr->get_output_port_connector(0)); // Copy Loop identifies load_expr->set_loop_ids(loop_ids); @@ -82,9 +80,7 @@ bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExp const auto store = std::make_shared(parent->output(port), get_count(data_expr->get_input_port_descriptor(0))); PortDescriptorUtils::set_port_descriptor_ptr(store->output(0), parent_output.get_descriptor_ptr()->clone()); const auto store_expr = linear_ir.create_expression(store, {input_connector}); - const auto& reverse_insertion_pos = std::find(std::reverse_iterator(data_expr_it), linear_ir.crend(), parent_expr); - OPENVINO_ASSERT(reverse_insertion_pos != linear_ir.crend(), "Consumer should be after data producer in Linear IR"); - const auto& insertion_pos = reverse_insertion_pos.base(); + const auto& insertion_pos = linear_ir.find_after(std::reverse_iterator(data_expr_it), parent_expr).base(); linear_ir.insert(insertion_pos, store_expr); linear_ir.replace_input(data_expr->get_input_port(0), store_expr->get_output_port_connector(0)); // Copy Loop identifies diff --git a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp index f3977e8ed8ed1b..df156f5775e698 100644 --- a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp @@ -56,9 +56,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) { const auto mv_expr_it = expr_it; const auto insertion_pos = std::next(expr_it); expr_it = linear_ir.insert(insertion_pos, broadcastload_expr); - const auto load_it = std::find(linear_ir.begin(), mv_expr_it, parent_expr); - OPENVINO_ASSERT(load_it != mv_expr_it, "Failed fuse Load and MoveBroadcast: Load should be before MoveBroadcast in Linear IR"); - linear_ir.erase(load_it); + linear_ir.erase(linear_ir.find_before(mv_expr_it, parent_expr)); linear_ir.erase(mv_expr_it); linear_ir.replace_input(move_consumers, broadcastload_expr->get_output_port_connector(0)); modified |= true;