diff --git a/src/common/snippets/include/snippets/lowered/expression.hpp b/src/common/snippets/include/snippets/lowered/expression.hpp index 0b619370ab47a5..6731e369ae0921 100644 --- a/src/common/snippets/include/snippets/lowered/expression.hpp +++ b/src/common/snippets/include/snippets/lowered/expression.hpp @@ -19,7 +19,7 @@ namespace lowered { class LinearIR; using ExpressionPtr = std::shared_ptr; -using ExressionMap = std::unordered_map; +using ExpressionMap = std::unordered_map; class Expression : public std::enable_shared_from_this { friend class LinearIR; friend class ExpressionPort; @@ -63,7 +63,7 @@ class Expression : public std::enable_shared_from_this { void set_loop_ids(const std::vector& loops); virtual ExpressionPtr clone_with_new_inputs(const std::vector& new_inputs, const std::shared_ptr& new_node) const; - ExpressionPtr clone_with_new_inputs(const ExressionMap& expr_map, const std::shared_ptr& new_node) const; + ExpressionPtr clone_with_new_inputs(const ExpressionMap& expr_map, const std::shared_ptr& new_node) const; protected: Expression(const Expression& other); diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index f1649cd82aed3c..c9c5e6963a2924 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -70,7 +70,7 @@ class LinearIR { std::shared_ptr clone() const; static LinearIR::container deep_copy_range(LinearIR::container::const_iterator begin, LinearIR::container::const_iterator end, - ExressionMap& expression_map); + ExpressionMap& expression_map); const container& get_ops() const { return m_expressions; } const io_container& get_IO_ops() const { return m_io_expressions; } diff --git a/src/common/snippets/include/snippets/lowered/loop_manager.hpp b/src/common/snippets/include/snippets/lowered/loop_manager.hpp index 76dc002fc7b36e..87112de371a4a8 100644 --- a/src/common/snippets/include/snippets/lowered/loop_manager.hpp +++ b/src/common/snippets/include/snippets/lowered/loop_manager.hpp @@ -89,7 +89,7 @@ class LinearIR::LoopManager { const std::vector& exits, const SpecificIterationHandlers& handlers = SpecificIterationHandlers()); - std::shared_ptr clone_with_new_expr(const ExressionMap& expr_map) const; + std::shared_ptr clone_with_new_expr(const ExpressionMap& expr_map) const; // Returns dimension index if dimension indices for all entry and exit points are equal, and UNDEFINED_DIM_IDX otherwise size_t get_dim_idx() const; @@ -125,7 +125,7 @@ class LinearIR::LoopManager { }; using LoopInfoPtr = std::shared_ptr; - std::shared_ptr clone_with_new_expr(const ExressionMap& expr_map) const; + std::shared_ptr clone_with_new_expr(const ExpressionMap& expr_map) const; size_t add_loop_info(const LoopInfoPtr& loop); void remove_loop_info(size_t index); LoopInfoPtr get_loop_info(size_t index) const; diff --git a/src/common/snippets/include/snippets/lowered/pass/insert_specific_iterations.hpp b/src/common/snippets/include/snippets/lowered/pass/insert_specific_iterations.hpp index 842f54b0cf75cf..15d2703d3f8e6d 100644 --- a/src/common/snippets/include/snippets/lowered/pass/insert_specific_iterations.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/insert_specific_iterations.hpp @@ -22,7 +22,16 @@ class InsertSpecificIterations : public RangedPass { OPENVINO_RTTI("InsertSpecificIterations", "RangedPass") bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override; - static LinearIR::container copy_loop(const LinearIR& linear_ir, const size_t loop_id); + /** + * @brief Makes a copy of a loop body with id 'loop_id' and inserts it to the LinearIR before the 'insert_pos' position + * @param linear_ir LinearIR which should be modified + * @param loop_id id of the loop which should be copied + * @param insert_pos position before which the loop body copy should be inserted + * @return iterator which points on the LoopBegin copy + */ + static LinearIR::constExprIt insert_copy_loop(LinearIR& linear_ir, + const size_t loop_id, + const LinearIR::constExprIt& insert_pos); }; } // namespace pass diff --git a/src/common/snippets/src/lowered/expression.cpp b/src/common/snippets/src/lowered/expression.cpp index f33f3aeef95fc3..5c2a190dbf66a0 100644 --- a/src/common/snippets/src/lowered/expression.cpp +++ b/src/common/snippets/src/lowered/expression.cpp @@ -156,7 +156,7 @@ ExpressionPtr Expression::clone_with_new_inputs(const std::vector& new_node) const { std::vector new_inputs; new_inputs.reserve(m_input_port_connectors.size()); diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index a29d8d2045a6f1..b79292cd085299 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -47,7 +47,7 @@ std::shared_ptr LinearIR::clone() const { auto cloned = std::make_shared(); cloned->m_config = m_config; - ExressionMap expression_map; + ExpressionMap expression_map; cloned->m_expressions = deep_copy_range(m_expressions.cbegin(), m_expressions.cend(), expression_map); for (const auto& expr : cloned->m_expressions) { cloned->m_node2expression_map[expr->get_node()] = expr; @@ -106,7 +106,7 @@ ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr& m) { LinearIR::container LinearIR::deep_copy_range(LinearIR::container::const_iterator begin, LinearIR::container::const_iterator end, - ExressionMap& expression_map) { + ExpressionMap& expression_map) { OPENVINO_ASSERT(expression_map.empty(), "deep_copy_range expects empty expression_map as an input"); LinearIR::container result; NodeVector original_nodes; diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index bab26cb6ee00c2..76dd2627019f48 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ b/src/common/snippets/src/lowered/loop_manager.cpp @@ -102,7 +102,7 @@ LoopInfo::LoopInfo(size_t work_amount, m_exit_points.emplace_back(port); } -std::shared_ptr LoopInfo::clone_with_new_expr(const ExressionMap& expr_map) const { +std::shared_ptr LoopInfo::clone_with_new_expr(const ExpressionMap& expr_map) const { auto clone_loop_ports = [&expr_map](const std::vector& port_points) { std::vector cloned_port_points; cloned_port_points.reserve(port_points.size()); @@ -197,7 +197,7 @@ bool operator<(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopM (lhs.is_incremented == rhs.is_incremented && lhs.dim_idx < rhs.dim_idx))); } -std::shared_ptr LoopManager::clone_with_new_expr(const ExressionMap& expr_map) const { +std::shared_ptr LoopManager::clone_with_new_expr(const ExpressionMap& expr_map) const { auto new_loop_manager = std::make_shared(); for (const auto& id_info : m_map) new_loop_manager->m_map.insert({id_info.first, id_info.second->clone_with_new_expr(expr_map)}); diff --git a/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp b/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp index ac0573e830c380..783569b97cac73 100644 --- a/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp +++ b/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp @@ -15,12 +15,14 @@ namespace snippets { namespace lowered { namespace pass { -LinearIR::container InsertSpecificIterations::copy_loop(const LinearIR& linear_ir, const size_t loop_id) { +LinearIR::constExprIt InsertSpecificIterations::insert_copy_loop(LinearIR& linear_ir, const size_t loop_id, const LinearIR::constExprIt& insert_pos) { const auto& loop_manager = linear_ir.get_loop_manager(); LinearIR::constExprIt loop_begin_pos, loop_end_pos; loop_manager->get_loop_bounds(linear_ir, loop_id, loop_begin_pos, loop_end_pos, true); - ExressionMap expression_map; + ExpressionMap expression_map; const auto& loop_copy_range = LinearIR::deep_copy_range(loop_begin_pos, std::next(loop_end_pos), expression_map); + const auto new_loop_begin_pos = linear_ir.insert(insert_pos, loop_copy_range.begin(), loop_copy_range.end()); + const auto new_loop_end_pos = insert_pos; const auto original_loop_info = loop_manager->get_loop_info(loop_id); std::vector new_entry_points, new_exit_points; @@ -44,8 +46,6 @@ LinearIR::container InsertSpecificIterations::copy_loop(const LinearIR& linear_i loop_manager->update_loops_port(outer_loop_ids, expr->get_output_port(i), {expr->get_output_port(i), new_expr->get_output_port(i)}, false); } - const auto new_loop_begin_pos = loop_copy_range.begin(); - const auto new_loop_end_pos = loop_copy_range.end(); const auto new_id = loop_manager->replace_with_new_loop(linear_ir, std::next(new_loop_begin_pos), std::prev(new_loop_end_pos), @@ -57,7 +57,7 @@ LinearIR::container InsertSpecificIterations::copy_loop(const LinearIR& linear_i const auto loop_end = ov::as_type_ptr(std::prev(new_loop_end_pos)->get()->get_node()); OPENVINO_ASSERT(loop_end, "Cloned Loop does not contain LoopEnd op at the expected place."); loop_end->set_id(new_id); - return loop_copy_range; + return new_loop_begin_pos; } using LoopInfo = LinearIR::LoopManager::LoopInfo; @@ -100,13 +100,16 @@ bool InsertSpecificIterations::run(LinearIR& linear_ir, lowered::LinearIR::const }; auto copy_and_run_specific_handlers = [&](const PassPipeline& handlers) { - const auto& cloned_body = copy_loop(linear_ir, loop_end->get_id()); - lowered::LinearIR::constExprIt start = linear_ir.insert(main_loop_begin_it, cloned_body.begin(), cloned_body.end()); - const auto cloned_loop_end = *std::prev(cloned_body.end()); - auto end = linear_ir.find_after(start, cloned_loop_end); + const auto new_loop_begin_pos = insert_copy_loop(linear_ir, loop_end->get_id(), main_loop_begin_it); + const auto new_loop_begin = ov::as_type_ptr(new_loop_begin_pos->get()->get_node()); + OPENVINO_ASSERT(new_loop_begin, "Cloned Loop does not contain LoopBegin op at the expected place."); + const auto new_loop_end = new_loop_begin->get_loop_end(); + const auto new_loop_end_pos = linear_ir.find_after(new_loop_begin_pos, linear_ir.get_expr_by_node(new_loop_end)); + OPENVINO_ASSERT(new_loop_end, "Cloned Loop does not contain LoopEnd op at the expected place."); + // Note: handlers must be run on the range started with the first operation in the loop body. - handlers.run(linear_ir, std::next(start), end); - return ov::as_type_ptr(cloned_loop_end->get_node()); + handlers.run(linear_ir, std::next(new_loop_begin_pos), new_loop_end_pos); + return new_loop_end; }; const bool specific_first_iteration = !handlers.get_first_iter_handelrs().empty();