From efe9dbd653756553d361b2a1e07b03844fa9cd1e Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Tue, 13 Jun 2023 10:22:30 +0400 Subject: [PATCH 1/4] [Snippets] Fixed Buffer insertion position --- .../snippets/src/lowered/pass/insert_buffers.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 0eb75a33749abf..1b1d65a504cb86 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -225,12 +225,26 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt } } + // potential_consumers is unsorted by linear IR set. + // We have to find first expr in Linear IR from the set to insert Buffer before *all* consumers + OPENVINO_ASSERT(!potential_consumers.empty(), "Buffer should have one consumer at least"); + auto consumer_expr = potential_consumers.begin()->get_expr(); + if (potential_consumers.size() > 1) { + std::set consumers; + for (const auto& port : potential_consumers) + consumers.insert(port.get_expr()); + const auto it = std::find_if(linear_ir.cbegin(), linear_ir.cend(), + [&consumers](const ExpressionPtr& expr) { return consumers.count(expr) > 0; }); + OPENVINO_ASSERT(it != linear_ir.cend(), "Consumer of Buffer has not been found in Linear IR"); + consumer_expr = *it; + } + // We should insert Buffer between first different Loops. // Example: Current expr Loop identifies: 3, 2, 1 // Target consumers Loop identifies: 3, 4, 6 // Need to insert after 2nd Loops // Note: All potential consumers must have the same count of first equal Loop identifies and the same count of different last identifies - const auto pos = insertion_position(linear_ir, loop_manager, expr, (*potential_consumers.begin()).get_expr()); + const auto pos = insertion_position(linear_ir, loop_manager, expr, consumer_expr); const auto allocation_shape = compute_allocation_shape(loop_manager, buffer_loop_ids, From 179dc4c94b78426e872ea5d91519a9ed47e22e0f Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Tue, 13 Jun 2023 10:26:26 +0400 Subject: [PATCH 2/4] [Snippets] Fixed unsafe insertion iterators: added asserts --- src/common/snippets/src/lowered/pass/insert_buffers.cpp | 4 +++- src/common/snippets/src/lowered/pass/insert_load_store.cpp | 5 ++++- .../src/lowered/pass/load_movebroadcast_to_broadcastload.cpp | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 1b1d65a504cb86..2de97447dea436 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -221,7 +221,9 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt const auto buffer_consumers_inputs = buffer_out->get_consumers(); linear_ir.replace_input(buffer_consumers_inputs, output_connector); potential_consumers.insert(buffer_consumers_inputs.begin(), buffer_consumers_inputs.end()); - linear_ir.erase(std::find(linear_ir.begin(), linear_ir.end(), buffer)); + const auto buffer_pos = std::find(linear_ir.cbegin(), linear_ir.cend(), buffer); + OPENVINO_ASSERT(buffer_pos != linear_ir.cend(), "Buffer has not been found in Linear IR"); + linear_ir.erase(buffer_pos); } } diff --git a/src/common/snippets/src/lowered/pass/insert_load_store.cpp b/src/common/snippets/src/lowered/pass/insert_load_store.cpp index 4899cc13c3635d..ab4b362484a080 100644 --- a/src/common/snippets/src/lowered/pass/insert_load_store.cpp +++ b/src/common/snippets/src/lowered/pass/insert_load_store.cpp @@ -49,7 +49,9 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr const auto load = std::make_shared(data_node->output(0), get_count(data_expr->get_output_port_descriptor(0))); PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone()); const auto load_expr = linear_ir.create_expression(load, {output_connector}); - linear_ir.insert(std::find(data_expr_it, linear_ir.cend(), consumer_expr), load_expr); + const auto insertion_pos = std::find(data_expr_it, linear_ir.cend(), consumer_expr); + OPENVINO_ASSERT(insertion_pos != linear_ir.cend(), "Consumer should be after data producer in Linear IR"); + linear_ir.insert(insertion_pos, load_expr); linear_ir.replace_input(consumer_input, load_expr->get_output_port_connector(0)); // Copy Loop identifies load_expr->set_loop_ids(loop_ids); @@ -81,6 +83,7 @@ bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExp PortDescriptorUtils::set_port_descriptor_ptr(store->output(0), parent_output.get_descriptor_ptr()->clone()); const auto store_expr = linear_ir.create_expression(store, {input_connector}); const auto& reverse_insertion_pos = std::find(std::reverse_iterator(data_expr_it), linear_ir.crend(), parent_expr); + OPENVINO_ASSERT(reverse_insertion_pos != linear_ir.crend(), "Consumer should be after data producer in Linear IR"); const auto& insertion_pos = reverse_insertion_pos.base(); linear_ir.insert(insertion_pos, store_expr); linear_ir.replace_input(data_expr->get_input_port(0), store_expr->get_output_port_connector(0)); diff --git a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp index f0384911ac8f3e..f3977e8ed8ed1b 100644 --- a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp @@ -56,7 +56,9 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) { const auto mv_expr_it = expr_it; const auto insertion_pos = std::next(expr_it); expr_it = linear_ir.insert(insertion_pos, broadcastload_expr); - linear_ir.erase(std::find(linear_ir.begin(), mv_expr_it, parent_expr)); + const auto load_it = std::find(linear_ir.begin(), mv_expr_it, parent_expr); + OPENVINO_ASSERT(load_it != mv_expr_it, "Failed fuse Load and MoveBroadcast: Load should be before MoveBroadcast in Linear IR"); + linear_ir.erase(load_it); linear_ir.erase(mv_expr_it); linear_ir.replace_input(move_consumers, broadcastload_expr->get_output_port_connector(0)); modified |= true; From d455a6cf9198b60efded52a4f6e36f2755350c7c Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Wed, 21 Jun 2023 10:14:45 +0400 Subject: [PATCH 3/4] [Snippets] Applied Ivan comments --- .../include/snippets/lowered/linear_ir.hpp | 37 ++++++++++-- .../snippets/lowered/pass/insert_buffers.hpp | 2 +- src/common/snippets/src/lowered/linear_ir.cpp | 56 +++++++++++++++++++ .../snippets/src/lowered/loop_manager.cpp | 8 +-- .../src/lowered/pass/insert_buffers.cpp | 24 ++++---- .../src/lowered/pass/insert_load_store.cpp | 8 +-- .../load_movebroadcast_to_broadcastload.cpp | 4 +- 7 files changed, 106 insertions(+), 33 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 3ea2464829fa0c..60000a95fc520c 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -31,6 +31,8 @@ class LinearIR { using io_container = std::list>; using exprIt = container::iterator; using constExprIt = container::const_iterator; + using exprReverseIt = container::reverse_iterator; + using constExprReverseIt = container::const_reverse_iterator; LinearIR() = default; explicit LinearIR(const std::shared_ptr& m, Config config = {}); @@ -69,10 +71,10 @@ class LinearIR { constExprIt end() const noexcept {return cend();} constExprIt cbegin() const noexcept {return m_expressions.cbegin();} constExprIt cend() const noexcept {return m_expressions.cend();} - container::reverse_iterator rbegin() noexcept {return m_expressions.rbegin();} - container::reverse_iterator rend() noexcept {return m_expressions.rend();} - container::const_reverse_iterator crbegin() const noexcept {return m_expressions.crbegin();} - container::const_reverse_iterator crend() const noexcept {return m_expressions.crend();} + exprReverseIt rbegin() noexcept {return m_expressions.rbegin();} + exprReverseIt rend() noexcept {return m_expressions.rend();} + constExprReverseIt crbegin() const noexcept {return m_expressions.crbegin();} + constExprReverseIt crend() const noexcept {return m_expressions.crend();} exprIt insert(constExprIt pos, const ov::NodeVector& nodes); exprIt insert(constExprIt pos, const std::shared_ptr& n); @@ -84,6 +86,21 @@ class LinearIR { exprIt erase(exprIt pos); exprIt erase(constExprIt pos); + template + iterator find(iterator begin, iterator end, const ExpressionPtr& target); + template + const_iterator find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const; + exprIt find(const ExpressionPtr& target); + constExprIt find(const ExpressionPtr& target) const; + template + iterator find_before(iterator it, const ExpressionPtr& target); + template + const_iterator find_before(const_iterator it, const ExpressionPtr& target) const; + template + iterator find_after(iterator it, const ExpressionPtr& target); + template + const_iterator find_after(const_iterator it, const ExpressionPtr& target) const; + void init_emitters(const std::shared_ptr& target); void serialize(const std::string& xml, const std::string& bin); @@ -107,6 +124,18 @@ class LinearIR { LoopManagerPtr m_loop_manager = nullptr; }; +template +iterator LinearIR::find(iterator begin, iterator end, const ExpressionPtr& target) { + auto found = std::find(begin, end, target); + OPENVINO_ASSERT(found != end, "Expression has not been found"); + return found; +} +template +const_iterator LinearIR::find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const { + auto found = std::find(begin, end, target); + OPENVINO_ASSERT(found != end, "Expression has not been found"); + return found; +} } // namespace lowered } // namespace snippets } // namespace ov diff --git a/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp b/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp index 2e4b8b4489d80c..004ea711288ab2 100644 --- a/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/insert_buffers.hpp @@ -28,7 +28,7 @@ class InsertBuffers : public Pass { bool run(LinearIR& linear_ir) override; private: - void insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager, + void insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager, const std::vector& loop_entries, const std::vector& loop_exits); diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 245238f328460f..f90ae20e2f90fa 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -260,6 +260,62 @@ void LinearIR::move(LinearIR::constExprIt from, LinearIR::constExprIt to) { m_expressions.splice(to, m_expressions, from); } +LinearIR::exprIt LinearIR::find(const ExpressionPtr& target) { + return find(begin(), end(), target); +} +LinearIR::constExprIt LinearIR::find(const ExpressionPtr& target) const { + return find(cbegin(), cend(), target); +} +template<> +LinearIR::exprIt LinearIR::find_before(LinearIR::exprIt it, const ExpressionPtr& target) { + return find(begin(), it, target); +} +template<> +LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) { + return find(cbegin(), it, target); +} +template<> +LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) const { + return find(cbegin(), it, target); +} +template<> +LinearIR::exprReverseIt LinearIR::find_before(LinearIR::exprReverseIt it, const ExpressionPtr& target) { + return find(rbegin(), it, target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) { + return find(crbegin(), it, target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const { + return find(crbegin(), it, target); +} +template<> +LinearIR::exprIt LinearIR::find_after(LinearIR::exprIt it, const ExpressionPtr& target) { + return find(it, end(), target); +} +template<> +LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) { + return find(it, cend(), target); +} +template<> +LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) const { + return find(it, cend(), target); +} +template<> +LinearIR::exprReverseIt LinearIR::find_after(LinearIR::exprReverseIt it, const ExpressionPtr& target) { + return find(it, rend(), target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) { + return find(it, crend(), target); +} +template<> +LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const { + return find(it, crend(), target); +} + + }// namespace lowered }// namespace snippets }// namespace ov diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index 928b846742f290..90c003b33d8173 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ b/src/common/snippets/src/lowered/loop_manager.cpp @@ -87,8 +87,7 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir, OPENVINO_ASSERT(!entries.empty(), "Loop must have entry points"); OPENVINO_ASSERT(!exits.empty(), "Loop must have entry points"); const auto& entry_expr = entries.front().expr_port->get_expr(); - loop_begin_pos = std::find(linear_ir.begin(), linear_ir.end(), entry_expr); - OPENVINO_ASSERT(loop_begin_pos != linear_ir.end(), "Loop begin hasn't been found!"); + loop_begin_pos = linear_ir.find(entry_expr); // Some operations in Loop can be before first entry points: Scalars, VectorBuffer. // We should iterate by them till the expr is in the corresponding Loop @@ -104,12 +103,11 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir, const auto loop_end = loop_begin->get_loop_end(); OPENVINO_ASSERT(loop_end->get_id() == loop_id, "Failed explicit loop bounds getting: Loop bounds with correct ID have not been found"); loop_begin_pos = std::prev(loop_begin_pos); - loop_end_pos = std::find(loop_begin_pos, linear_ir.end(), linear_ir.get_expr_by_node(loop_end)); + loop_end_pos = linear_ir.find_after(loop_begin_pos, linear_ir.get_expr_by_node(loop_end)); } else { // At the moment all Loops must have exit points const auto& exit_expr = exits.back().expr_port->get_expr(); - loop_end_pos = std::next(std::find(loop_begin_pos, linear_ir.end(), exit_expr)); - OPENVINO_ASSERT(loop_end_pos != linear_ir.end(), "Loop end hasn't been found!"); + loop_end_pos = std::next(linear_ir.find_after(loop_begin_pos, exit_expr)); } } diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 2de97447dea436..708f4520654bf6 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -83,15 +83,11 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i const auto down_loops = down_expr->get_loop_ids(); // If upper expression is out of Loop, we can insert Buffer implicitly after him if (up_loops.empty()) { - const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), up_expr); - OPENVINO_ASSERT(it != linear_ir.cend(), "Upper expression hasn't been found to insert Buffer after him!"); - return std::next(it); + return std::next(linear_ir.find(up_expr)); } // If lower expression is out of Loop, we can insert Buffer implicitly before him if (down_loops.empty()) { - const auto it = std::find(linear_ir.cbegin(), linear_ir.cend(), down_expr); - OPENVINO_ASSERT(it != linear_ir.cend(), "Lower expression hasn't been found to insert Buffer after him!"); - return it; + return linear_ir.find(down_expr); } const auto up_loop_count = up_loops.size(); @@ -120,7 +116,7 @@ LinearIR::constExprIt InsertBuffers::insertion_position(const LinearIR& linear_i OPENVINO_THROW("Incorrect configuration for Buffer insertion!"); } -void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager, +void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::constExprIt& expr_it, const LinearIR::LoopManagerPtr& loop_manager, const std::vector& loop_entries, const std::vector& loop_exits) { for (const auto& entry_point : loop_entries) { @@ -221,21 +217,20 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt const auto buffer_consumers_inputs = buffer_out->get_consumers(); linear_ir.replace_input(buffer_consumers_inputs, output_connector); potential_consumers.insert(buffer_consumers_inputs.begin(), buffer_consumers_inputs.end()); - const auto buffer_pos = std::find(linear_ir.cbegin(), linear_ir.cend(), buffer); - OPENVINO_ASSERT(buffer_pos != linear_ir.cend(), "Buffer has not been found in Linear IR"); - linear_ir.erase(buffer_pos); + linear_ir.erase(linear_ir.find_after(expr_it, buffer)); } } // potential_consumers is unsorted by linear IR set. // We have to find first expr in Linear IR from the set to insert Buffer before *all* consumers + // [113536]: Remove this logic with `std::find` using, when expression numeration will be supported OPENVINO_ASSERT(!potential_consumers.empty(), "Buffer should have one consumer at least"); auto consumer_expr = potential_consumers.begin()->get_expr(); if (potential_consumers.size() > 1) { std::set consumers; for (const auto& port : potential_consumers) consumers.insert(port.get_expr()); - const auto it = std::find_if(linear_ir.cbegin(), linear_ir.cend(), + const auto it = std::find_if(expr_it, linear_ir.cend(), [&consumers](const ExpressionPtr& expr) { return consumers.count(expr) > 0; }); OPENVINO_ASSERT(it != linear_ir.cend(), "Consumer of Buffer has not been found in Linear IR"); consumer_expr = *it; @@ -282,10 +277,11 @@ bool InsertBuffers::run(LinearIR& linear_ir) { const auto loop_info = loop_data.second; const auto loop_entries = loop_info->entry_points; const auto loop_exits = loop_info->exit_points; - insertion(linear_ir, loop_manager, loop_entries, loop_exits); + // using begin() as expr_it because we work with LoopInfo, not expressions in Linear IR + insertion(linear_ir, linear_ir.cbegin(), loop_manager, loop_entries, loop_exits); } - for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { + for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) { const auto expr = *expr_it; const auto node = (*expr_it)->get_node(); const auto ma = ov::as_type_ptr(node); @@ -302,7 +298,7 @@ bool InsertBuffers::run(LinearIR& linear_ir) { loop_exits[p.first] = expr->get_output_port(p.first); } - insertion(linear_ir, loop_manager, loop_entries, loop_exits); + insertion(linear_ir, expr_it, loop_manager, loop_entries, loop_exits); } return true; diff --git a/src/common/snippets/src/lowered/pass/insert_load_store.cpp b/src/common/snippets/src/lowered/pass/insert_load_store.cpp index ab4b362484a080..40f802a649a9e9 100644 --- a/src/common/snippets/src/lowered/pass/insert_load_store.cpp +++ b/src/common/snippets/src/lowered/pass/insert_load_store.cpp @@ -49,9 +49,7 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr const auto load = std::make_shared(data_node->output(0), get_count(data_expr->get_output_port_descriptor(0))); PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone()); const auto load_expr = linear_ir.create_expression(load, {output_connector}); - const auto insertion_pos = std::find(data_expr_it, linear_ir.cend(), consumer_expr); - OPENVINO_ASSERT(insertion_pos != linear_ir.cend(), "Consumer should be after data producer in Linear IR"); - linear_ir.insert(insertion_pos, load_expr); + linear_ir.insert(linear_ir.find_after(data_expr_it, consumer_expr), load_expr); linear_ir.replace_input(consumer_input, load_expr->get_output_port_connector(0)); // Copy Loop identifies load_expr->set_loop_ids(loop_ids); @@ -82,9 +80,7 @@ bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExp const auto store = std::make_shared(parent->output(port), get_count(data_expr->get_input_port_descriptor(0))); PortDescriptorUtils::set_port_descriptor_ptr(store->output(0), parent_output.get_descriptor_ptr()->clone()); const auto store_expr = linear_ir.create_expression(store, {input_connector}); - const auto& reverse_insertion_pos = std::find(std::reverse_iterator(data_expr_it), linear_ir.crend(), parent_expr); - OPENVINO_ASSERT(reverse_insertion_pos != linear_ir.crend(), "Consumer should be after data producer in Linear IR"); - const auto& insertion_pos = reverse_insertion_pos.base(); + const auto& insertion_pos = linear_ir.find_after(std::reverse_iterator(data_expr_it), parent_expr).base(); linear_ir.insert(insertion_pos, store_expr); linear_ir.replace_input(data_expr->get_input_port(0), store_expr->get_output_port_connector(0)); // Copy Loop identifies diff --git a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp index f3977e8ed8ed1b..df156f5775e698 100644 --- a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp @@ -56,9 +56,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) { const auto mv_expr_it = expr_it; const auto insertion_pos = std::next(expr_it); expr_it = linear_ir.insert(insertion_pos, broadcastload_expr); - const auto load_it = std::find(linear_ir.begin(), mv_expr_it, parent_expr); - OPENVINO_ASSERT(load_it != mv_expr_it, "Failed fuse Load and MoveBroadcast: Load should be before MoveBroadcast in Linear IR"); - linear_ir.erase(load_it); + linear_ir.erase(linear_ir.find_before(mv_expr_it, parent_expr)); linear_ir.erase(mv_expr_it); linear_ir.replace_input(move_consumers, broadcastload_expr->get_output_port_connector(0)); modified |= true; From 32cd75cf38535955469e4b46fb19c6ba7717b04e Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Tue, 4 Jul 2023 09:13:07 +0400 Subject: [PATCH 4/4] [Snippets] Left only const iterator methods --- .../include/snippets/lowered/linear_ir.hpp | 23 +++--------- src/common/snippets/src/lowered/linear_ir.cpp | 35 ------------------- .../src/lowered/pass/insert_tail_loop.cpp | 5 ++- .../load_movebroadcast_to_broadcastload.cpp | 2 +- 4 files changed, 8 insertions(+), 57 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 60000a95fc520c..511894a030eeb3 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -86,20 +86,13 @@ class LinearIR { exprIt erase(exprIt pos); exprIt erase(constExprIt pos); - template - iterator find(iterator begin, iterator end, const ExpressionPtr& target); - template - const_iterator find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const; - exprIt find(const ExpressionPtr& target); constExprIt find(const ExpressionPtr& target) const; template - iterator find_before(iterator it, const ExpressionPtr& target); - template - const_iterator find_before(const_iterator it, const ExpressionPtr& target) const; + iterator find(iterator begin, iterator end, const ExpressionPtr& target) const; + template + iterator find_before(iterator it, const ExpressionPtr& target) const; template - iterator find_after(iterator it, const ExpressionPtr& target); - template - const_iterator find_after(const_iterator it, const ExpressionPtr& target) const; + iterator find_after(iterator it, const ExpressionPtr& target) const; void init_emitters(const std::shared_ptr& target); void serialize(const std::string& xml, const std::string& bin); @@ -125,13 +118,7 @@ class LinearIR { }; template -iterator LinearIR::find(iterator begin, iterator end, const ExpressionPtr& target) { - auto found = std::find(begin, end, target); - OPENVINO_ASSERT(found != end, "Expression has not been found"); - return found; -} -template -const_iterator LinearIR::find(const_iterator begin, const_iterator end, const ExpressionPtr& target) const { +iterator LinearIR::find(iterator begin, iterator end, const ExpressionPtr& target) const { auto found = std::find(begin, end, target); OPENVINO_ASSERT(found != end, "Expression has not been found"); return found; diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index f90ae20e2f90fa..6246ddef8838a4 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -260,57 +260,22 @@ void LinearIR::move(LinearIR::constExprIt from, LinearIR::constExprIt to) { m_expressions.splice(to, m_expressions, from); } -LinearIR::exprIt LinearIR::find(const ExpressionPtr& target) { - return find(begin(), end(), target); -} LinearIR::constExprIt LinearIR::find(const ExpressionPtr& target) const { return find(cbegin(), cend(), target); } template<> -LinearIR::exprIt LinearIR::find_before(LinearIR::exprIt it, const ExpressionPtr& target) { - return find(begin(), it, target); -} -template<> -LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) { - return find(cbegin(), it, target); -} -template<> LinearIR::constExprIt LinearIR::find_before(LinearIR::constExprIt it, const ExpressionPtr& target) const { return find(cbegin(), it, target); } template<> -LinearIR::exprReverseIt LinearIR::find_before(LinearIR::exprReverseIt it, const ExpressionPtr& target) { - return find(rbegin(), it, target); -} -template<> -LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) { - return find(crbegin(), it, target); -} -template<> LinearIR::constExprReverseIt LinearIR::find_before(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const { return find(crbegin(), it, target); } template<> -LinearIR::exprIt LinearIR::find_after(LinearIR::exprIt it, const ExpressionPtr& target) { - return find(it, end(), target); -} -template<> -LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) { - return find(it, cend(), target); -} -template<> LinearIR::constExprIt LinearIR::find_after(LinearIR::constExprIt it, const ExpressionPtr& target) const { return find(it, cend(), target); } template<> -LinearIR::exprReverseIt LinearIR::find_after(LinearIR::exprReverseIt it, const ExpressionPtr& target) { - return find(it, rend(), target); -} -template<> -LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) { - return find(it, crend(), target); -} -template<> LinearIR::constExprReverseIt LinearIR::find_after(LinearIR::constExprReverseIt it, const ExpressionPtr& target) const { return find(it, crend(), target); } diff --git a/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp b/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp index da5ec911a7a650..d71d5ba3c24364 100644 --- a/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp +++ b/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp @@ -98,8 +98,7 @@ void InsertTailLoop::tail_transformations(LinearIR& linear_ir, // Skip inner Loops const auto loop_begin = ov::as_type_ptr(expr_it->get()->get_node()); if (loop_begin) { - expr_it = std::find(expr_it, tail_end, linear_ir.get_expr_by_node(loop_begin->get_loop_end())); - OPENVINO_ASSERT(expr_it != tail_end, "LoopEnd has not been found"); + expr_it = linear_ir.find(expr_it, tail_end, linear_ir.get_expr_by_node(loop_begin->get_loop_end())); continue; } // We should fill vector regs by float_min and zero to have @@ -198,7 +197,7 @@ bool InsertTailLoop::run(LinearIR& linear_ir) { // finalization offsets which are supported by LoopEnd. if (need_tail) { const auto loop_begin = loop_end->get_loop_begin(); - const auto begin_it = std::find(linear_ir.begin(), linear_ir.end(), linear_ir.get_expr_by_node(loop_begin)); + const auto begin_it = linear_ir.find(linear_ir.get_expr_by_node(loop_begin)); LinearIR::constExprIt tail_begin, tail_end; const auto tail_loop_end = create_tail_loop(linear_ir, begin_it, std::next(expr_it), tail_begin, tail_end, loop_end, need_vector_loop, tail_size, tail_finalization_offsets); diff --git a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp index df156f5775e698..cd4d57cfd2c941 100644 --- a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp @@ -19,7 +19,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) { const auto& loop_manager = linear_ir.get_loop_manager(); bool modified = false; - for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { + for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) { const auto& expr = *expr_it; const auto& op = expr->get_node(); // Match on MoveBroadcast because MoveBroadcast is rare node in bodies