diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 0df8a5c58ea6cb..ac36de99c8a750 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -76,7 +76,7 @@ class LinearIR { const container& get_ops() const { return m_expressions; } const io_container& get_IO_ops() const { return m_io_expressions; } - Config get_config() const { return m_config; } + const Config& get_config() const { return m_config; } void set_loop_depth(size_t loop_depth) { m_config.m_loop_depth = loop_depth; } const ExpressionPtr& get_expr_by_node(const std::shared_ptr& n) const; diff --git a/src/common/snippets/include/snippets/lowered/loop_manager.hpp b/src/common/snippets/include/snippets/lowered/loop_manager.hpp index ce8d6dd0b156a8..c8a24021e2994c 100644 --- a/src/common/snippets/include/snippets/lowered/loop_manager.hpp +++ b/src/common/snippets/include/snippets/lowered/loop_manager.hpp @@ -42,15 +42,15 @@ class LinearIR::LoopManager { class LoopInfo { public: enum {UNDEFINED_DIM_IDX = std::numeric_limits::max()}; + // This enum is used for loop specific iterations handlers enumeration + enum {FIRST_ITER, MAIN_BODY, LAST_ITER}; LoopInfo() = default; LoopInfo(size_t work_amount, size_t increment, const std::vector& entries, - const std::vector& exits, - bool outer_splited_loop = false); + const std::vector& exits); LoopInfo(size_t work_amount, size_t increment, const std::vector& entries, - const std::vector& exits, - bool outer_splited_loop = false); + const std::vector& exits); std::shared_ptr clone_with_new_expr(const ExressionMap& expr_map) const; @@ -60,7 +60,7 @@ class LinearIR::LoopManager { size_t get_increment() const; const std::vector& get_entry_points() const; const std::vector& get_exit_points() const; - bool get_outer_splited_loop() const; + const std::vector& get_handlers() const; // Sets dim_idx to all entry 
and exit points void set_dim_idx(size_t dim_idx); @@ -68,10 +68,8 @@ class LinearIR::LoopManager { void set_increment(size_t increment); void set_entry_points(std::vector entry_points); void set_exit_points(std::vector exit_points); - void set_outer_splited_loop(bool outer_splited_loop); - - enum {FIRST_ITER, MAIN_BODY, LAST_ITER}; - std::vector handlers; + void set_handlers(std::vector handlers); + void set_default_handlers(); private: size_t m_work_amount = 0; @@ -82,8 +80,7 @@ class LinearIR::LoopManager { // Note: Scalars aren't entry expressions but can be before first entry expr in Linear IR std::vector m_entry_points = {}; std::vector m_exit_points = {}; - // True if this Loop is outer Loop for nested Loops that splits the same dimension - bool m_outer_splited_loop = false; + std::vector m_handlers = {}; }; using LoopInfoPtr = std::shared_ptr; @@ -112,16 +109,14 @@ class LinearIR::LoopManager { const std::vector& entries, const std::vector& exits, bool set_default_handlers = true) { - if (increment > work_amount) - increment = work_amount; - const auto loop_info = std::make_shared(work_amount, increment, entries, exits); + const auto loop_info = std::make_shared(work_amount, std::min(increment, work_amount), entries, exits); loop_info->set_dim_idx(dim_idx); const auto loop_id = this->add_loop_info(loop_info); for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) { insert_loop_id(*expr_it, loop_id); } if (set_default_handlers) { - set_default_loop_handlers(loop_info); + loop_info->set_default_handlers(); } return loop_id; } @@ -142,7 +137,7 @@ class LinearIR::LoopManager { insert_loop_id(*expr_it, loop_id); } if (set_default_handlers) { - set_default_loop_handlers(loop_info); + loop_info->set_default_handlers(); } return loop_id; } @@ -209,7 +204,6 @@ class LinearIR::LoopManager { size_t loop_id, bool loop_ops_inserted = false); LoopPort get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id); - static void 
set_default_loop_handlers(const LoopInfoPtr& loop_info); private: static void get_io_loop_ports(LinearIR::constExprIt loop_begin_pos, @@ -221,8 +215,8 @@ class LinearIR::LoopManager { std::vector& entry_points, size_t loop_id); static std::vector fuse_loop_handlers( - std::vector& lhs, - std::vector& rhs); + const std::vector& lhs, + const std::vector& rhs); /* ===== The methods for work with Loop IDs of Expression ===== */ // Notes: diff --git a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp index 2111749b4b27d8..5575052e804955 100644 --- a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp @@ -6,12 +6,19 @@ #include "snippets/lowered/linear_ir.hpp" #include "snippets/lowered/pass/pass.hpp" -#include "snippets/op/loop.hpp" namespace ov { namespace snippets { namespace lowered { namespace pass { +/** + * @interface UpdateMemoryAccessOps + * @brief The pass changes counts of all MemoryAccess ops in the Loop + * @attention The pass skips inner loops + * @attention The pass ignores memory access ports which have count == 1 + * @param m_count - count which must be set + * @ingroup snippets + */ class UpdateMemoryAccessOps : public pass::RangedPass { public: UpdateMemoryAccessOps(size_t count); @@ -22,6 +29,12 @@ class UpdateMemoryAccessOps : public pass::RangedPass { size_t m_count; }; +/** + * @interface SetFillOffset + * @brief The pass changes offset of all Fill ops in the Loop + * @param m_offset - offset which must be set + * @ingroup snippets + */ class SetFillOffset : public pass::RangedPass { public: SetFillOffset(size_t offset); @@ -32,6 +45,12 @@ class SetFillOffset : public pass::RangedPass { size_t m_offset; }; +/** + * @interface TransformInnerSplitLoop + * @brief The pass updates finalization offsets, work amount and increment of inner Loop basing on tail_size of the current Loop 
+ * @param m_tail_size - tail_size of the current Loop + * @ingroup snippets + */ class TransformInnerSplitLoop : public pass::RangedPass { public: TransformInnerSplitLoop(size_t tail_size); diff --git a/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp b/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp index 0df580be2ada07..01ce37b6b3b8da 100644 --- a/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp @@ -11,7 +11,14 @@ namespace ov { namespace snippets { namespace lowered { namespace pass { - +/** + * @interface UpdateSubtensors + * @brief The pass updates subtensors of all operations in Loop based on tail size. + * Firstly, the pass updates subtensors of all Loop entry points. + * After that, shape inference infrastructure is used to update subtensors of all ops in Loop body + * @param m_tail_size - tail size which the subtensors are updated for + * @ingroup snippets + */ class UpdateSubtensors : public pass::RangedPass { public: UpdateSubtensors(size_t tail_size); diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index 6c1725d6504fba..67c6172511a9ab 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ b/src/common/snippets/src/lowered/loop_manager.cpp @@ -42,31 +42,35 @@ std::shared_ptr LoopPort::clone_with_new_expr(const ExpressionPtr& new LinearIR::LoopManager::LoopInfo::LoopInfo(size_t work_amount, size_t increment, const std::vector& entries, - const std::vector& exits, - bool outer_splited_loop) + const std::vector& exits) : m_work_amount(work_amount), m_increment(increment), m_entry_points(entries), - m_exit_points(exits), - m_outer_splited_loop(outer_splited_loop) { - handlers.resize(3); + m_exit_points(exits) { + // Note: loop info always contains at least 3 sets of handlers: + // 1. For first loop iteration + // 2. For main loop body + // 3. 
For last loop iteration + m_handlers.resize(3); } LinearIR::LoopManager::LoopInfo::LoopInfo(size_t work_amount, size_t increment, const std::vector& entries, - const std::vector& exits, - bool outer_splited_loop) + const std::vector& exits) : m_work_amount(work_amount), - m_increment(increment), - m_outer_splited_loop(outer_splited_loop) { + m_increment(increment) { m_entry_points.reserve(entries.size()); m_exit_points.reserve(exits.size()); for (const auto& port : entries) m_entry_points.emplace_back(port); for (const auto& port : exits) m_exit_points.emplace_back(port); - handlers.resize(3); + // Note: loop info always contains at least 3 sets of handlers: + // 1. For first loop iteration + // 2. For main loop body + // 3. For last loop iteration + m_handlers.resize(3); } std::shared_ptr LoopInfo::clone_with_new_expr(const ExressionMap& expr_map) const { @@ -84,8 +88,8 @@ std::shared_ptr LoopInfo::clone_with_new_expr(const ExressionMap& expr const auto& new_entry_points = clone_loop_ports(m_entry_points); const auto& new_exit_points = clone_loop_ports(m_exit_points); - auto new_info = std::make_shared(m_work_amount, m_increment, new_entry_points, new_exit_points, m_outer_splited_loop); - new_info->handlers = handlers; + auto new_info = std::make_shared(m_work_amount, m_increment, new_entry_points, new_exit_points); + new_info->set_handlers(m_handlers); return new_info; } @@ -105,8 +109,8 @@ const std::vector& LoopInfo::get_exit_points() const { return m_exit_points; } -bool LoopInfo::get_outer_splited_loop() const { - return m_outer_splited_loop; +const std::vector& LoopInfo::get_handlers() const { + return m_handlers; } size_t LinearIR::LoopManager::LoopInfo::get_dim_idx() const { @@ -144,11 +148,19 @@ void LoopInfo::set_entry_points(std::vector entry_points) { } void LoopInfo::set_exit_points(std::vector exit_points) { - m_exit_points = std::move(exit_points);; + m_exit_points = std::move(exit_points); +} + +void LoopInfo::set_handlers(std::vector handlers) { + 
m_handlers = std::move(handlers); } -void LoopInfo::set_outer_splited_loop(bool outer_splited_loop) { - m_outer_splited_loop = outer_splited_loop; +void LoopInfo::set_default_handlers() { + const auto tail_size = get_work_amount() % get_increment(); + if (tail_size != 0) { + m_handlers[LoopInfo::LAST_ITER].register_pass(tail_size); + m_handlers[LoopInfo::LAST_ITER].register_pass(tail_size); + } } bool operator==(const LinearIR::LoopManager::LoopPort& lhs, const LinearIR::LoopManager::LoopPort& rhs) { @@ -287,14 +299,6 @@ LinearIR::LoopManager::LoopPort LinearIR::LoopManager::get_loop_port_by_expr_por : get_loop_port(loop_info->get_exit_points()); } -void LinearIR::LoopManager::set_default_loop_handlers(const LoopInfoPtr& loop_info) { - const auto tail_size = loop_info->get_work_amount() % loop_info->get_increment(); - if (tail_size != 0) { - loop_info->handlers[LoopInfo::LAST_ITER].register_pass(tail_size); - loop_info->handlers[LoopInfo::LAST_ITER].register_pass(tail_size); - } -} - void LinearIR::LoopManager::get_io_loop_ports(LinearIR::constExprIt loop_begin_pos, LinearIR::constExprIt loop_end_pos, std::vector &entries, @@ -444,14 +448,11 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target, loop_info->set_entry_points(new_entries); loop_info->set_exit_points(new_exits); - loop_info->handlers = fuse_loop_handlers(loop_info_upper->handlers, loop_info_lower->handlers); + loop_info->set_handlers(fuse_loop_handlers(loop_info_upper->get_handlers(), loop_info_lower->get_handlers())); // Since fusion can be called for broadcastable loops (one of the loops has work_amount = increment = 1), // maximum value is set to the fused loop loop_info->set_work_amount(std::max(loop_info_upper->get_work_amount(), loop_info_lower->get_work_amount())); loop_info->set_increment(std::max(loop_info_upper->get_increment(), loop_info_lower->get_increment())); - // If one of the Loops is outer for nested loops that splits the same dimension, - // after fusion 
new common Loop saves this status - loop_info->set_outer_splited_loop(loop_info_upper->get_outer_splited_loop() || loop_info_lower->get_outer_splited_loop()); const auto& from = fuse_into_upper ? loop_id_lower : loop_id_upper; const auto& to = fuse_into_upper ? loop_id_upper : loop_id_lower; @@ -464,15 +465,13 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target, } std::vector LinearIR::LoopManager::fuse_loop_handlers( - std::vector& from, - std::vector& to) { - const auto min_size = std::min(from.size(), to.size()); - std::vector merged_handlers; - merged_handlers.resize(min_size); - for (size_t i = 0; i < min_size; ++i) { - merged_handlers[i] = from[i]; + const std::vector& lhs, + const std::vector& rhs) { + OPENVINO_ASSERT(lhs.size() == rhs.size(), "fuse_loop_handlers supports only handlers vectors with equal sizes."); + auto merged_handlers = lhs; + for (size_t i = 0; i < lhs.size(); ++i) { const auto& res_passes = merged_handlers[i].get_passes(); - for (const auto& pass : to[i].get_passes()) { + for (const auto& pass : rhs[i].get_passes()) { auto pred = [&pass](const std::shared_ptr& p) { return p->get_type_info() == pass->get_type_info(); }; @@ -481,10 +480,6 @@ std::vector LinearIR::LoopManager::fuse_loop_handle } } } - auto& handlers_with_larger_size = from.size() > to.size() ? 
from : to; - for (size_t i = min_size; i < handlers_with_larger_size.size(); ++i) { - merged_handlers.emplace_back(std::move(handlers_with_larger_size[i])); - } return merged_handlers; } diff --git a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp index 39f636be5a138a..c7cf6b67abd8ea 100644 --- a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp @@ -78,7 +78,7 @@ bool AllocateBuffers::run(lowered::LinearIR& linear_ir, lowered::LinearIR::const pipeline.register_pass(); pipeline.run(linear_ir); } else { - InitBuffersDefault(m_buffer_scratchpad_size).run(linear_ir, linear_ir.begin(), linear_ir.end()); + InitBuffersDefault(m_buffer_scratchpad_size).run(linear_ir, linear_ir.cbegin(), linear_ir.cend()); } return m_buffer_scratchpad_size > 0; diff --git a/src/common/snippets/src/lowered/pass/fuse_loops.cpp b/src/common/snippets/src/lowered/pass/fuse_loops.cpp index aa856c5f38feba..da9a01cd168bbb 100644 --- a/src/common/snippets/src/lowered/pass/fuse_loops.cpp +++ b/src/common/snippets/src/lowered/pass/fuse_loops.cpp @@ -48,6 +48,8 @@ bool FuseLoops::can_be_fused(const LoopInfoPtr& loop_current, const LoopInfoPtr& const auto target_work_amount = loop_target->get_work_amount(); const auto current_increment = loop_current->get_increment(); const auto target_increment = loop_target->get_increment(); + const auto& current_handlers = loop_current->get_handlers(); + const auto& target_handlers = loop_target->get_handlers(); // Loop fusion is supported only if Loops have equal/broadcastable increments and work amounts. 
// Note: For example, Broadcastable work amounts are possible in the following case: // Relu_0 [16x1] Relu_1 [16x128] @@ -58,15 +60,16 @@ bool FuseLoops::can_be_fused(const LoopInfoPtr& loop_current, const LoopInfoPtr& // - Relu_1 and Add with work amount `128` and increment `vector size` // We can fuse them into one Loop with work amount `128` and increment `vector size` + const bool handlers_sizes_match = current_handlers.size() == target_handlers.size(); // WA: we can't fuse 2 loops if one of them has first iteration handler but second hasn't, // because in this case Main/Tail body handlers of the loop wo first iter handler must be reset with new parameters // (e.g. tail size). This logic is not implemented for now, so fusion for such loops is skipped. - const bool first_iter_handlers_match = loop_current->handlers[LoopManager::LoopInfo::FIRST_ITER].empty() == - loop_target->handlers[LoopManager::LoopInfo::FIRST_ITER].empty(); + const bool first_iter_handlers_match = current_handlers[LoopManager::LoopInfo::FIRST_ITER].empty() == + target_handlers[LoopManager::LoopInfo::FIRST_ITER].empty(); const bool equal_parameters = current_work_amount == target_work_amount && current_increment == target_increment; const bool current_bcastable = current_work_amount == 1 && current_increment == 1; const bool target_bcastable = target_work_amount == 1 && target_increment == 1; - return first_iter_handlers_match && (equal_parameters || current_bcastable || target_bcastable); + return handlers_sizes_match && first_iter_handlers_match && (equal_parameters || current_bcastable || target_bcastable); } void FuseLoops::move(LinearIR& linear_ir, const LinearIR::LoopManagerPtr& loop_manager, size_t loop_id, diff --git a/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp b/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp index 6d3b8230758708..60472a629a625f 100644 --- a/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp +++ 
b/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp @@ -78,7 +78,7 @@ bool InsertSpecificIterations::run(LinearIR& linear_ir, lowered::LinearIR::const const auto& loop_info = loop_manager->get_loop_info(loop_end->get_id()); const auto work_amount = loop_info->get_work_amount(); const auto increment = loop_info->get_increment(); - auto& handlers = loop_info->handlers; + const auto& handlers = loop_info->get_handlers(); const auto main_body_begin_it = linear_ir.find(linear_ir.get_expr_by_node(loop_end->get_loop_begin())); const auto main_body_end_it = linear_ir.find(linear_ir.get_expr_by_node(loop_end)); @@ -100,10 +100,11 @@ bool InsertSpecificIterations::run(LinearIR& linear_ir, lowered::LinearIR::const auto copy_and_run_specific_handlers = [&](const PassPipeline& handlers) { const auto& cloned_body = copy_loop(linear_ir, loop_end->get_id()); - linear_ir.insert(main_body_begin_it, cloned_body.begin(), cloned_body.end()); - const auto& loop_end_it = std::prev(cloned_body.end()); - handlers.run(linear_ir, cloned_body.begin(), loop_end_it); - return ov::as_type_ptr(loop_end_it->get()->get_node()); + lowered::LinearIR::constExprIt start = linear_ir.insert(main_body_begin_it, cloned_body.begin(), cloned_body.end()); + const auto cloned_loop_end = *std::prev(cloned_body.end()); + auto end = linear_ir.find_after(start, cloned_loop_end); + handlers.run(linear_ir, start, end); + return ov::as_type_ptr(cloned_loop_end->get_node()); }; const bool specific_first_iteration = !handlers[LoopInfo::FIRST_ITER].empty(); diff --git a/src/common/snippets/src/lowered/pass/iter_handler.cpp b/src/common/snippets/src/lowered/pass/iter_handler.cpp index cc6351dba168cd..1f8440e4fa9e9f 100644 --- a/src/common/snippets/src/lowered/pass/iter_handler.cpp +++ b/src/common/snippets/src/lowered/pass/iter_handler.cpp @@ -94,7 +94,7 @@ bool TransformInnerSplitLoop::run(LinearIR& linear_ir, LinearIR::constExprIt beg const auto inner_loop_begin_it = std::find(begin, it, 
linear_ir.get_expr_by_node(inner_loop_begin)); const auto inner_loop_end_it = std::next(end); OPENVINO_ASSERT(inner_loop_begin_it != it, "LoopBegin has not been found!"); - const auto& last_iter_handlers = inner_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER]; + const auto& last_iter_handlers = inner_loop_info->get_handlers()[LinearIR::LoopManager::LoopInfo::LAST_ITER]; last_iter_handlers.run(linear_ir, inner_loop_begin_it, inner_loop_end_it); modified = true; } diff --git a/src/common/snippets/src/lowered/pass/pass.cpp b/src/common/snippets/src/lowered/pass/pass.cpp index 27588a03d431fa..7bc69732d34f56 100644 --- a/src/common/snippets/src/lowered/pass/pass.cpp +++ b/src/common/snippets/src/lowered/pass/pass.cpp @@ -27,7 +27,7 @@ void PassPipeline::register_pass(const std::shared_ptr& pass) { } void PassPipeline::run(LinearIR& linear_ir) const { - run(linear_ir, linear_ir.begin(), linear_ir.end()); + run(linear_ir, linear_ir.cbegin(), linear_ir.cend()); } void PassPipeline::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) const { diff --git a/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp b/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp index ffcc9a0fbd5ebf..3f9c25a7b1e559 100644 --- a/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp +++ b/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp @@ -76,7 +76,9 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir, lowered::LinearIR::constExpr const auto& reduce_max_loop_info = loop_manager->get_loop_info(reduce_max_loop_id); const auto tail_size = inner_work_amount % inner_increment; if (tail_size != 0) { - reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass(tail_size); + auto handlers = reduce_max_loop_info->get_handlers(); + handlers[LoopInfo::LAST_ITER].register_pass(tail_size); + reduce_max_loop_info->set_handlers(handlers); } const auto broadcast_horizon_max = 
push_node(std::make_shared(horizon_max.second, broadcasted_dim)); const auto vector_buffer_sum = push_node(std::make_shared()); @@ -100,7 +102,9 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir, lowered::LinearIR::constExpr (*sum.first)->get_output_port(0)}); const auto& reduce_sum_loop_info = loop_manager->get_loop_info(reduce_sum_loop_id); if (tail_size != 0) { - reduce_sum_loop_info->handlers[LoopInfo::LAST_ITER].register_pass(tail_size); + auto handlers = reduce_sum_loop_info->get_handlers(); + handlers[LoopInfo::LAST_ITER].register_pass(tail_size); + reduce_sum_loop_info->set_handlers(handlers); } // Divide is expensive operation, so we decompose it into 1 / x * y, where 1 / x is executed outside loop diff --git a/src/common/snippets/src/lowered/pass/split_loops.cpp b/src/common/snippets/src/lowered/pass/split_loops.cpp index c3b06cf84c4a35..21718d801e77df 100644 --- a/src/common/snippets/src/lowered/pass/split_loops.cpp +++ b/src/common/snippets/src/lowered/pass/split_loops.cpp @@ -24,7 +24,7 @@ SplitLoops::SplitLoops() : RangedPass() {} bool SplitLoops::can_be_split(const LoopInfoPtr& loop_to_split, const LoopInfoPtr& loop_to_fuse) { const auto current_dim_idx = loop_to_split->get_dim_idx(); const auto parent_dim_idx = loop_to_fuse->get_dim_idx(); - const auto& handlers = loop_to_split->handlers; + const auto& handlers = loop_to_split->get_handlers(); const bool equal_dim_idxes = current_dim_idx != LoopInfo::UNDEFINED_DIM_IDX && current_dim_idx == parent_dim_idx; const bool only_main_body = handlers[LoopInfo::FIRST_ITER].empty() && handlers[LoopInfo::FIRST_ITER].empty(); return loop_to_split->get_work_amount() == loop_to_fuse->get_work_amount() && @@ -87,14 +87,14 @@ bool SplitLoops::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, loop_to_split->get_entry_points(), loop_to_split->get_exit_points()); const auto& new_loop_info = loop_manager->get_loop_info(split_loop_id); - new_loop_info->set_outer_splited_loop(true); - 
new_loop_info->handlers = loop_to_split->handlers; const auto work_amount = loop_to_fuse->get_work_amount(); const auto increment = loop_to_fuse->get_increment(); const auto tail_size = work_amount % increment; + auto new_handlers = loop_to_split->get_handlers(); if (tail_size != 0) { - new_loop_info->handlers[LoopInfo::LAST_ITER].register_pass(tail_size); + new_handlers[LoopInfo::LAST_ITER].register_pass(tail_size); } + new_loop_info->set_handlers(new_handlers); break; } } diff --git a/src/common/snippets/src/lowered/pass/validate_loops.cpp b/src/common/snippets/src/lowered/pass/validate_loops.cpp index 99698a6b4329bd..43afdc12e63551 100644 --- a/src/common/snippets/src/lowered/pass/validate_loops.cpp +++ b/src/common/snippets/src/lowered/pass/validate_loops.cpp @@ -63,8 +63,6 @@ bool ValidateLoops::run(LinearIR& linear_ir) { "Incorrect Loop ID configuration: the Loops with splitted dimension should be successively nested"); OPENVINO_ASSERT(loop_manager->get_loop_info(loop_ids[i - 1])->get_increment() == loop_manager->get_loop_info(id)->get_work_amount(), "Incorrect Loop ID configuration: the Loops with splitted dimension should be successively nested"); - OPENVINO_ASSERT(loop_manager->get_loop_info(loop_ids[i - 1])->get_outer_splited_loop(), - "Incorrect Loop ID configuration: the outer Loop with splitted dimension should have `outer_splited_loop=True`"); } dim_indexes.push_back(dim_idx); } diff --git a/src/common/snippets/tests/src/lowered/pass/loop.cpp b/src/common/snippets/tests/src/lowered/pass/loop.cpp index f5bcc910464841..345c1f0e67876c 100644 --- a/src/common/snippets/tests/src/lowered/pass/loop.cpp +++ b/src/common/snippets/tests/src/lowered/pass/loop.cpp @@ -48,10 +48,12 @@ static void init_linear_ir(const std::vector& in_shapes, Linea loop_manager->mark_loop(expr_it, std::next(expr_it), blocked_wa, blocked_inc, 1, loop_entry_points, loop_exit_points); const auto loop_id = loop_manager->mark_loop(expr_it, std::next(expr_it), outer_wa, outer_inc, 1, 
loop_entry_points, loop_exit_points); const auto& outer_loop_info = loop_manager->get_loop_info(loop_id); - outer_loop_info->set_outer_splited_loop(true); const auto outer_tail_size = outer_wa % outer_inc; - if (outer_tail_size != 0) - outer_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass(outer_tail_size); + if (outer_tail_size != 0) { + auto handlers = outer_loop_info->get_handlers(); + handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER].register_pass(outer_tail_size); + outer_loop_info->set_handlers(handlers); + } } static void apply_transformations(LinearIR& linear_ir, const std::shared_ptr& config) { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp index a153a8842b2170..be17818ae02b04 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp @@ -152,7 +152,9 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits); const auto loop_info = loop_manager->get_loop_info(id); - loop_info->handlers[LoopInfo::FIRST_ITER].register_pass(0.f); + auto handlers = loop_info->get_handlers(); + handlers[LoopInfo::FIRST_ITER].register_pass(0.f); + loop_info->set_handlers(handlers); }; apply_k_blocking(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp index b7a17fa57d3464..7d29e6b5570a7f 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp +++ 
b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp @@ -9,6 +9,12 @@ namespace ov { namespace intel_cpu { namespace pass { +/** + * @interface SetBrgemmBeta + * @brief The pass updates all CPUBrgemm nodes with a new beta value + * @param m_beta - beta which must be set + * @ingroup snippets + */ class SetBrgemmBeta : public snippets::lowered::pass::RangedPass { public: SetBrgemmBeta(float beta);