From 67a46b3bb90361f6b73f02d32ec9408bcd84524d Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Mon, 22 Jan 2024 17:39:10 +0100 Subject: [PATCH] Introduced a virtual method 'can_be_merged' for lowered passes --- .../include/snippets/lowered/loop_manager.hpp | 10 ++++-- .../snippets/lowered/pass/iter_handler.hpp | 3 ++ .../include/snippets/lowered/pass/pass.hpp | 12 +++++++ .../snippets/lowered/pass/pass_config.hpp | 3 ++ .../lowered/pass/propagate_subtensors.hpp | 1 + .../snippets/src/lowered/loop_manager.cpp | 31 +++++-------------- .../src/lowered/pass/iter_handler.cpp | 15 +++++++++ src/common/snippets/src/lowered/pass/pass.cpp | 20 ++++++++++++ .../snippets/src/lowered/pass/pass_config.cpp | 8 +++++ .../src/lowered/pass/propagate_subtensors.cpp | 5 +++ .../lowered/pass/softmax_decomposition.cpp | 4 +-- .../snippets/tests/src/lowered/pass/loop.cpp | 4 +-- .../x64/pass/lowered/brgemm_blocking.cpp | 2 +- .../x64/pass/lowered/cpu_iter_handlers.cpp | 5 +++ .../x64/pass/lowered/cpu_iter_handlers.hpp | 1 + 15 files changed, 92 insertions(+), 32 deletions(-) diff --git a/src/common/snippets/include/snippets/lowered/loop_manager.hpp b/src/common/snippets/include/snippets/lowered/loop_manager.hpp index 5cdb7fd2f989d1..7b90a5e0d68f74 100644 --- a/src/common/snippets/include/snippets/lowered/loop_manager.hpp +++ b/src/common/snippets/include/snippets/lowered/loop_manager.hpp @@ -54,6 +54,7 @@ class LinearIR::LoopManager { const lowered::pass::PassPipeline& get_first_iter_handelrs() const; const lowered::pass::PassPipeline& get_main_iter_handelrs() const; const lowered::pass::PassPipeline& get_last_iter_handelrs() const; + SpecificIterationHandlers merge(const SpecificIterationHandlers& other) const; template void register_handler(Args&&... args) { @@ -72,8 +73,6 @@ class LinearIR::LoopManager { } } - static SpecificIterationHandlers merge_loop_handlers(const SpecificIterationHandlers& lhs, const SpecificIterationHandlers& rhs); - private: lowered::pass::PassPipeline m_first_iter_handlers; lowered::pass::PassPipeline m_main_body_handlers; @@ -98,7 +97,7 @@ class LinearIR::LoopManager { size_t get_increment() const; const std::vector& get_entry_points() const; const std::vector& get_exit_points() const; - SpecificIterationHandlers& get_handlers(); + const SpecificIterationHandlers& get_handlers() const; // Sets dim_idx to all entry and exit points void set_dim_idx(size_t dim_idx); @@ -108,6 +107,11 @@ class LinearIR::LoopManager { void set_exit_points(std::vector exit_points); void set_handlers(SpecificIterationHandlers handlers); + template + void register_handler(Args&&... args) { + m_handlers.register_handler(args...); + } + private: size_t m_work_amount = 0; size_t m_increment = 0; diff --git a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp index df79ffb7a5e874..a3255d61edbf40 100644 --- a/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp @@ -24,6 +24,7 @@ class UpdateMemoryAccessCounts : public pass::RangedPass { UpdateMemoryAccessCounts(size_t count); OPENVINO_RTTI("UpdateMemoryAccessCounts", "RangedPass") bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override; + bool can_be_merged(const std::shared_ptr& other) override; private: size_t m_count; @@ -40,6 +41,7 @@ class SetFillOffset : public pass::RangedPass { SetFillOffset(size_t offset); OPENVINO_RTTI("SetFillOffset", "RangedPass") bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override; + bool can_be_merged(const std::shared_ptr& other) override; private: size_t m_offset; @@ -56,6 +58,7 @@ class TransformInnerSplitLoop : public pass::RangedPass { TransformInnerSplitLoop(size_t tail_size); OPENVINO_RTTI("TransformInnerSplitLoop", "RangedPass") bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override; + bool can_be_merged(const std::shared_ptr& other) override; private: size_t m_tail_size; diff --git a/src/common/snippets/include/snippets/lowered/pass/pass.hpp b/src/common/snippets/include/snippets/lowered/pass/pass.hpp index ce17fda4b199ee..9c52e818c883ed 100644 --- a/src/common/snippets/include/snippets/lowered/pass/pass.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/pass.hpp @@ -39,6 +39,15 @@ class PassBase { const char* get_type_name() const { return get_type_info().name; } + + /** + * @brief Checks if the current pass can be merged with another one (e.g. during 2 pass pipelines fusion) + * @param other Pointer on the another pass. + * @return bool value indicating whether the passes can be merged or not + */ + virtual bool can_be_merged(const std::shared_ptr& other) { + return false; + } }; /** @@ -81,6 +90,7 @@ class PassPipeline { PassPipeline(const std::shared_ptr& pass_config); const std::vector>& get_passes() const { return m_passes; } + const std::shared_ptr& get_pass_config() const { return m_pass_config; } bool empty() const { return m_passes.empty(); } void register_pass(const snippets::pass::PassPosition& position, const std::shared_ptr& pass); @@ -104,6 +114,8 @@ class PassPipeline { void run(lowered::LinearIR& linear_ir) const; void run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const; + static PassPipeline merge_pipelines(const PassPipeline& lhs, const PassPipeline& rhs); + private: std::shared_ptr m_pass_config; std::vector> m_passes; diff --git a/src/common/snippets/include/snippets/lowered/pass/pass_config.hpp b/src/common/snippets/include/snippets/lowered/pass/pass_config.hpp index 03fe2b3dd6d65d..90a45cc0eba708 100644 --- a/src/common/snippets/include/snippets/lowered/pass/pass_config.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/pass_config.hpp @@ -48,6 +48,9 @@ class PassConfig { return is_enabled(T::get_type_info_static()); } + friend bool operator==(const PassConfig& lhs, const PassConfig& rhs); + friend bool operator!=(const PassConfig& lhs, const PassConfig& rhs); + private: std::unordered_set m_disabled; std::unordered_set m_enabled; diff --git a/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp b/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp index 596036af2fb5bd..bc01d640f9b824 100644 --- a/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/propagate_subtensors.hpp @@ -24,6 +24,7 @@ class UpdateSubtensors : public pass::RangedPass { UpdateSubtensors(size_t tail_size); OPENVINO_RTTI("UpdateSubtensors", "RangedPass") bool run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override; + bool can_be_merged(const std::shared_ptr& other) override; private: size_t m_tail_size; diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index f13999a384f66a..e94858187c2fd6 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ b/src/common/snippets/src/lowered/loop_manager.cpp @@ -66,28 +66,11 @@ const lowered::pass::PassPipeline& LoopInfo::SpecificIterationHandlers::get_last return m_last_iter_handlers; } -LoopInfo::SpecificIterationHandlers LoopInfo::SpecificIterationHandlers::merge_loop_handlers( - const LoopInfo::SpecificIterationHandlers& lhs, - const LoopInfo::SpecificIterationHandlers& rhs) { - auto merge_handlers_pipelines = [](const lowered::pass::PassPipeline& lhs, const lowered::pass::PassPipeline& rhs) { - auto merged_pipeline = lhs; - const auto& merged_passes = merged_pipeline.get_passes(); - for (const auto& pass : rhs.get_passes()) { - auto pred = [&pass](const std::shared_ptr& p) { - return p->get_type_info() == pass->get_type_info(); - }; - if (std::find_if(merged_passes.begin(), merged_passes.end(), pred) == merged_passes.end()) { - merged_pipeline.register_pass(pass); - } - } - return merged_pipeline; - }; - - LoopInfo::SpecificIterationHandlers merged_handlers( - merge_handlers_pipelines(lhs.get_first_iter_handelrs(), rhs.get_first_iter_handelrs()), - merge_handlers_pipelines(lhs.get_main_iter_handelrs(), rhs.get_main_iter_handelrs()), - merge_handlers_pipelines(lhs.get_last_iter_handelrs(), rhs.get_last_iter_handelrs())); - return merged_handlers; +LoopInfo::SpecificIterationHandlers LoopInfo::SpecificIterationHandlers::merge(const LoopInfo::SpecificIterationHandlers& other) const { + return LoopInfo::SpecificIterationHandlers( + lowered::pass::PassPipeline::merge_pipelines(m_first_iter_handlers, other.get_first_iter_handelrs()), + lowered::pass::PassPipeline::merge_pipelines(m_main_body_handlers, other.get_main_iter_handelrs()), + lowered::pass::PassPipeline::merge_pipelines(m_last_iter_handlers, other.get_last_iter_handelrs())); } LoopInfo::LoopInfo(size_t work_amount, @@ -151,7 +134,7 @@ const std::vector& LoopInfo::get_exit_points() const { return m_exit_points; } -LoopInfo::SpecificIterationHandlers& LoopInfo::get_handlers() { +const LoopInfo::SpecificIterationHandlers& LoopInfo::get_handlers() const { return m_handlers; } @@ -482,7 +465,7 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target, loop_info->set_entry_points(new_entries); loop_info->set_exit_points(new_exits); - loop_info->set_handlers(LoopInfo::SpecificIterationHandlers::merge_loop_handlers(loop_info_upper->get_handlers(), loop_info_lower->get_handlers())); + loop_info->set_handlers(loop_info_upper->get_handlers().merge(loop_info_lower->get_handlers())); // Since fusion can be called for broadcastable loops (one of the loops has work_amount = increment = 1), // maximum value is set to the fused loop loop_info->set_work_amount(std::max(loop_info_upper->get_work_amount(), loop_info_lower->get_work_amount())); diff --git a/src/common/snippets/src/lowered/pass/iter_handler.cpp b/src/common/snippets/src/lowered/pass/iter_handler.cpp index 18c2ed745d72fc..a29a39db87d8ad 100644 --- a/src/common/snippets/src/lowered/pass/iter_handler.cpp +++ b/src/common/snippets/src/lowered/pass/iter_handler.cpp @@ -45,6 +45,11 @@ bool UpdateMemoryAccessCounts::run(LinearIR& linear_ir, LinearIR::constExprIt be return true; } +bool UpdateMemoryAccessCounts::can_be_merged(const std::shared_ptr& other) { + const auto casted_pass = ov::as_type_ptr(other); + return casted_pass && m_count == casted_pass->m_count; +} + SetFillOffset::SetFillOffset(size_t offset) : RangedPass(), m_offset(offset) {} bool SetFillOffset::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) { @@ -57,6 +62,11 @@ bool SetFillOffset::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linear return true; } +bool SetFillOffset::can_be_merged(const std::shared_ptr& other) { + const auto casted_pass = ov::as_type_ptr(other); + return casted_pass && m_offset == casted_pass->m_offset; +} + TransformInnerSplitLoop::TransformInnerSplitLoop(size_t tail_size) : RangedPass(), m_tail_size(tail_size) {} bool TransformInnerSplitLoop::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) { @@ -101,6 +111,11 @@ bool TransformInnerSplitLoop::run(LinearIR& linear_ir, LinearIR::constExprIt beg return modified; } +bool TransformInnerSplitLoop::can_be_merged(const std::shared_ptr& other) { + const auto casted_pass = ov::as_type_ptr(other); + return casted_pass && m_tail_size == casted_pass->m_tail_size; +} + } // namespace pass } // namespace lowered } // namespace snippets diff --git a/src/common/snippets/src/lowered/pass/pass.cpp b/src/common/snippets/src/lowered/pass/pass.cpp index 7bc69732d34f56..e215e14b9014ee 100644 --- a/src/common/snippets/src/lowered/pass/pass.cpp +++ b/src/common/snippets/src/lowered/pass/pass.cpp @@ -51,6 +51,26 @@ void PassPipeline::register_positioned_passes(const std::vector> passes_map; + for (const auto& pass : lhs_passes) { + passes_map[pass->get_type_info()] = pass; + } + + auto merged_pipeline = lhs; + for (const auto& pass : rhs.get_passes()) { + auto lhs_pass_it = passes_map.find(pass->get_type_info()); + if (lhs_pass_it == passes_map.end()) { + merged_pipeline.register_pass(pass); + } else { + OPENVINO_ASSERT(lhs_pass_it->second->can_be_merged(pass), "2 passes with type info ", pass->get_type_info(), " can't be merged."); + } + } + return merged_pipeline; +} + } // namespace pass } // namespace lowered } // namespace snippets diff --git a/src/common/snippets/src/lowered/pass/pass_config.cpp b/src/common/snippets/src/lowered/pass/pass_config.cpp index ae73f88c55805a..6d4888e81c7420 100644 --- a/src/common/snippets/src/lowered/pass/pass_config.cpp +++ b/src/common/snippets/src/lowered/pass/pass_config.cpp @@ -28,6 +28,14 @@ bool PassConfig::is_enabled(const DiscreteTypeInfo& type_info) const { return m_enabled.count(type_info); } +bool operator==(const PassConfig& lhs, const PassConfig& rhs) { + return lhs.m_disabled == rhs.m_disabled && lhs.m_enabled == rhs.m_enabled; +} + +bool operator!=(const PassConfig& lhs, const PassConfig& rhs) { + return !(lhs == rhs); +} + } // namespace pass } // namespace lowered } // namespace snippets diff --git a/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp b/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp index 1b576c6c24a47d..7c27c20312d1c6 100644 --- a/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp +++ b/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp @@ -147,6 +147,11 @@ bool UpdateSubtensors::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Lin return true; } +bool UpdateSubtensors::can_be_merged(const std::shared_ptr& other) { + const auto casted_pass = ov::as_type_ptr(other); + return casted_pass && m_tail_size == casted_pass->m_tail_size; +} + } // namespace pass } // namespace lowered } // namespace snippets diff --git a/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp b/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp index e89e619ab55480..7fb9d0784ace71 100644 --- a/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp +++ b/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp @@ -77,7 +77,7 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir, lowered::LinearIR::constExpr const auto& reduce_max_loop_info = loop_manager->get_loop_info(reduce_max_loop_id); const auto tail_size = inner_work_amount % inner_increment; if (tail_size != 0) { - reduce_max_loop_info->get_handlers().register_handler(tail_size); + reduce_max_loop_info->register_handler(tail_size); } const auto broadcast_horizon_max = push_node(std::make_shared(horizon_max.second, broadcasted_dim)); const auto vector_buffer_sum = push_node(std::make_shared()); @@ -101,7 +101,7 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir, lowered::LinearIR::constExpr (*sum.first)->get_output_port(0)}); const auto& reduce_sum_loop_info = loop_manager->get_loop_info(reduce_sum_loop_id); if (tail_size != 0) { - reduce_sum_loop_info->get_handlers().register_handler(tail_size); + reduce_sum_loop_info->register_handler(tail_size); } // Divide is expensive operation, so we decompose it into 1 / x * y, where 1 / x is executed outside loop diff --git a/src/common/snippets/tests/src/lowered/pass/loop.cpp b/src/common/snippets/tests/src/lowered/pass/loop.cpp index ccb320247cadbd..209ecb4592368a 100644 --- a/src/common/snippets/tests/src/lowered/pass/loop.cpp +++ b/src/common/snippets/tests/src/lowered/pass/loop.cpp @@ -50,8 +50,8 @@ static void init_linear_ir(const std::vector& in_shapes, Linea const auto& outer_loop_info = loop_manager->get_loop_info(loop_id); const auto outer_tail_size = outer_wa % outer_inc; if (outer_tail_size != 0) { - outer_loop_info->get_handlers().register_handler(outer_tail_size); + outer_loop_info->register_handler(outer_tail_size); } } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp index 6db8a5888f2d3a..8936fc9bb8ce75 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp @@ -152,7 +152,7 @@ bool BrgemmBlocking::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linea std::vector exits{LoopPort(brgemm_expr->get_output_port(0), false)}; const auto id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits); const auto loop_info = loop_manager->get_loop_info(id); - loop_info->get_handlers().register_handler(0.f); + loop_info->register_handler(0.f); }; apply_k_blocking(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.cpp index 41eda2273157d4..8bdbaf2380b625 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.cpp @@ -24,6 +24,11 @@ bool SetBrgemmBeta::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linear } return true; } + +bool SetBrgemmBeta::can_be_merged(const std::shared_ptr& other) { + const auto casted_pass = ov::as_type_ptr(other); + return casted_pass && m_beta == casted_pass->m_beta; +} } // namespace pass } // namespace intel_cpu } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp index f647430791a0d8..abe8f0b4e86adf 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/cpu_iter_handlers.hpp @@ -22,6 +22,7 @@ class SetBrgemmBeta : public snippets::lowered::pass::RangedPass { bool run(snippets::lowered::LinearIR& linear_ir, snippets::lowered::LinearIR::constExprIt begin, snippets::lowered::LinearIR::constExprIt end) override; + bool can_be_merged(const std::shared_ptr& other) override; private: float m_beta;