diff --git a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp index c4f3296adf165c..1a4fe60ce5f9ea 100644 --- a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp @@ -38,6 +38,9 @@ class DefaultTailLoopHandler : public SpecificIterHandler { bool copy_body_modification(LinearIR::constExprIt begin, LinearIR::constExprIt end) const override; bool need_to_modify_main_loop(const std::shared_ptr& loop_end) const override; bool need_to_copy_loop(const std::shared_ptr& loop_end) const override; + +private: + void update_memory_access_ops(LinearIR::constExprIt begin, LinearIR::constExprIt end, size_t tail_size) const; }; } // namespace lowered } // namespace snippets diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index f264fe35c9a94c..01709c30f7f761 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -27,16 +27,22 @@ void Generator::generate(lowered::LinearIR& linear_ir, LoweringResult& result, c std::function& op)> reg_type_mapper = [&](const std::shared_ptr& op) -> opRegType { return get_op_reg_type(op); }; + lowered::pass::PassPipeline pre_pipeline; + pre_pipeline.register_pass(reg_type_mapper); + pre_pipeline.run(linear_ir); + + // auto clone = *linear_ir.clone(); + // lowered::pass::PassPipeline reference_pipeline; + // reference_pipeline.register_pass(); + // reference_pipeline.run(clone); + // clone.serialize("/home/vgolubev/models/specific_iteration_reference.xml", ""); + lowered::pass::PassPipeline lowered_pipeline; - lowered_pipeline.register_pass(reg_type_mapper); - // lowered_pipeline.register_pass(); lowered_pipeline.register_pass(); - lowered_pipeline.register_pass(); - lowered_pipeline.register_pass(); - std::cout << "before\n"; + // lowered_pipeline.register_pass(); + // lowered_pipeline.register_pass(); lowered_pipeline.run(linear_ir); - std::cout << "after\n"; - linear_ir.serialize("specific_iteration.xml", ""); + linear_ir.serialize("/home/vgolubev/models/specific_iteration.xml", ""); linear_ir.init_emitters(target); OV_ITT_TASK_NEXT(GENERATE, "::EmitCode") diff --git a/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp b/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp index 96172e3d0c75a7..e0f46886778df4 100644 --- a/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp +++ b/src/common/snippets/src/lowered/pass/insert_specific_iterations.cpp @@ -79,7 +79,7 @@ bool InsertSpecificIterations::run(LinearIR& linear_ir) { // We can't skip tail loop because we need to handle inner loops of the tail loop auto call_handlers = [&](const std::vector& handlers, SpecificIterHandler::Mode mode, - const bool need_to_copy) { + bool need_to_copy) { LinearIR::container copied_body; if (need_to_copy) { copied_body = copy_loop(linear_ir, loop_end->get_id()); @@ -100,11 +100,11 @@ bool InsertSpecificIterations::run(LinearIR& linear_ir) { return handler->need_to_copy_loop(loop_end); }; // TODO: handlers must be taken from loop info - std::vector first_iter_handlers{}; + // std::vector first_iter_handlers{}; std::vector last_iter_handlers{std::make_shared(linear_ir)}; - const bool copy_first_iter = std::any_of(first_iter_handlers.begin(), first_iter_handlers.end(), need_copy); + // const bool copy_first_iter = std::any_of(first_iter_handlers.begin(), first_iter_handlers.end(), need_copy); const bool copy_last_iter = std::any_of(last_iter_handlers.begin(), last_iter_handlers.end(), need_copy); - call_handlers(first_iter_handlers, SpecificIterHandler::FIRST_ITERATION, copy_first_iter); + // call_handlers(first_iter_handlers, SpecificIterHandler::FIRST_ITERATION, copy_first_iter); call_handlers(last_iter_handlers, SpecificIterHandler::LAST_ITERATION, copy_last_iter); std::cout << "Specific iter handlers called for node " << loop_end << std::endl; } diff --git a/src/common/snippets/src/lowered/pass/iter_handler.cpp b/src/common/snippets/src/lowered/pass/iter_handler.cpp index 579800aa1aa329..0b386e03d75d31 100644 --- a/src/common/snippets/src/lowered/pass/iter_handler.cpp +++ b/src/common/snippets/src/lowered/pass/iter_handler.cpp @@ -18,13 +18,13 @@ namespace lowered { bool DefaultTailLoopHandler::need_to_modify_main_loop(const std::shared_ptr& loop_end) const { const auto work_amount = loop_end->get_work_amount(); const auto increment = loop_end->get_increment(); - return work_amount >= increment && work_amount % increment != 0; + return work_amount % increment != 0; } bool DefaultTailLoopHandler::need_to_copy_loop(const std::shared_ptr& loop_end) const { const auto work_amount = loop_end->get_work_amount(); const auto increment = loop_end->get_increment(); - return work_amount >= increment; + return work_amount > increment && work_amount % increment != 0; } bool DefaultTailLoopHandler::main_body_modification(LinearIR::constExprIt begin, LinearIR::constExprIt end) const { @@ -38,7 +38,7 @@ bool DefaultTailLoopHandler::main_body_modification(LinearIR::constExprIt begin, const auto& loop_manager = get_linear_ir().get_loop_manager(); const auto& loop_info = loop_manager->get_loop_info(loop_end->get_id()); - if (work_amount >= increment) { + if (work_amount > increment) { loop_end->set_work_amount(work_amount - tail_size); loop_end->set_finalization_offsets(std::vector(loop_end->get_finalization_offsets().size(), 0)); loop_info->set_work_amount(work_amount - tail_size); @@ -47,7 +47,9 @@ bool DefaultTailLoopHandler::main_body_modification(LinearIR::constExprIt begin, loop_end->set_increment(tail_size); loop_info->set_work_amount(tail_size); loop_info->set_increment(tail_size); + update_memory_access_ops(begin, end, tail_size); } + return true; } bool DefaultTailLoopHandler::copy_body_modification(LinearIR::constExprIt begin, LinearIR::constExprIt end) const { @@ -66,6 +68,11 @@ bool DefaultTailLoopHandler::copy_body_modification(LinearIR::constExprIt begin, loop_info->set_work_amount(tail_size); loop_info->set_increment(tail_size); + update_memory_access_ops(begin, end, tail_size); + return true; +} + +void DefaultTailLoopHandler::update_memory_access_ops(LinearIR::constExprIt begin, LinearIR::constExprIt end, size_t tail_size) const { for (auto expr_it = std::next(begin); expr_it != end; expr_it++) { // Skip inner Loops const auto loop_begin = ov::as_type_ptr(expr_it->get()->get_node());