diff --git a/src/core/include/openvino/pass/manager.hpp b/src/core/include/openvino/pass/manager.hpp index 8ca9ce354eeb5c..f84ece9d442fba 100644 --- a/src/core/include/openvino/pass/manager.hpp +++ b/src/core/include/openvino/pass/manager.hpp @@ -66,7 +66,7 @@ class OPENVINO_API Manager { /// /// \return Returns true if the model was changed by transformations, /// false otherwise. - bool run_passes(std::shared_ptr model); + bool run_passes(const std::shared_ptr& model); void set_pass_visualization(bool new_state) { m_visualize = new_state; diff --git a/src/core/src/pass/manager.cpp b/src/core/src/pass/manager.cpp index a6bee008ca99d4..246616bcea2d74 100644 --- a/src/core/src/pass/manager.cpp +++ b/src/core/src/pass/manager.cpp @@ -29,42 +29,18 @@ PerfCounters& perf_counters() { static PerfCounters counters; return counters; } -} // namespace -} // namespace pass -} // namespace ov - -#endif // ENABLE_PROFILING_ITT - -namespace { -bool getenv_visualize_tracing() { - return ov::util::getenv_bool("OV_ENABLE_VISUALIZE_TRACING"); -} -} // namespace - -ov::pass::Manager::Manager() : m_pass_config(std::make_shared()), m_visualize(getenv_visualize_tracing()) {} - -ov::pass::Manager::~Manager() = default; - -ov::pass::Manager::Manager(std::shared_ptr pass_config) - : m_pass_config(std::move(pass_config)), - m_visualize(getenv_visualize_tracing()) {} - -void ov::pass::Manager::set_per_pass_validation(bool new_state) { - m_per_pass_validation = new_state; -} -namespace { class stopwatch { public: void start() { - if (m_active == false) { + if (!m_active) { m_active = true; m_start_time = m_clock.now(); } } void stop() { - if (m_active == true) { + if (m_active) { auto end_time = m_clock.now(); m_last_time = end_time - m_start_time; m_active = false; @@ -89,85 +65,127 @@ class stopwatch { bool m_active = false; std::chrono::nanoseconds m_last_time = std::chrono::high_resolution_clock::duration::zero(); }; + +class Profiler { +public: + Profiler(bool visualize, bool profile_pass_enable) + : m_visualize(visualize), + m_profile_pass_enable(profile_pass_enable) {} + + void start_timer(const std::string& name) { + if (m_profile_pass_enable) { + stopwatches[name] = stopwatch(); + stopwatches[name].start(); + } + } + + void stop_timer(const std::string& name, const std::string& msg) { + if (m_profile_pass_enable) { + auto& stopwatch = stopwatches.at(name); + stopwatch.stop(); + cout << msg << setw(7) << stopwatch.get_milliseconds() << "ms" + << "\n"; + } + } + + void visualize(const shared_ptr& model, const std::string& pass_name) { + if (m_visualize) { + // visualizations and serializations will be named after the outermost function + const size_t num_digits_in_pass_index = 3; + std::string index_str = std::to_string(m_index++); + index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str; + auto base_filename = model->get_name() + std::string("_") + index_str + std::string("_") + pass_name; + + auto file_ext = "svg"; + pass::VisualizeTree vt(base_filename + std::string(".") + file_ext); + vt.run_on_model(model); + } + } + +private: + size_t m_index = 0; + std::unordered_map stopwatches; + + bool m_visualize; + bool m_profile_pass_enable; +}; + } // namespace +} // namespace pass +} // namespace ov -bool ov::pass::Manager::run_passes(shared_ptr func) { - OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes"); +#endif // ENABLE_PROFILING_ITT + +namespace { +bool getenv_visualize_tracing() { + return ov::util::getenv_bool("OV_ENABLE_VISUALIZE_TRACING"); +} +} // namespace + +ov::pass::Manager::Manager() : m_pass_config(std::make_shared()), m_visualize(getenv_visualize_tracing()) {} + +ov::pass::Manager::~Manager() = default; + +ov::pass::Manager::Manager(std::shared_ptr pass_config) + : m_pass_config(std::move(pass_config)), + m_visualize(getenv_visualize_tracing()) {} + +void ov::pass::Manager::set_per_pass_validation(bool new_state) { + m_per_pass_validation = new_state; +} - static bool profile_enabled = ov::util::getenv_bool("OV_PROFILE_PASS_ENABLE"); +bool ov::pass::Manager::run_passes(const shared_ptr& model) { + OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes"); + Profiler profiler(m_visualize, ov::util::getenv_bool("OV_PROFILE_PASS_ENABLE")); - size_t index = 0; - stopwatch pass_timer; - stopwatch overall_timer; - overall_timer.start(); bool pass_applied = false; - bool function_changed = false; + bool model_changed = false; bool needs_validate = false; - for (auto& pass : m_pass_list) { + const std::string passes_name = "Passes"; + + profiler.start_timer(passes_name); + for (const auto& pass : m_pass_list) { + const auto& pass_name = pass->get_name(); + if (m_pass_config->is_disabled(pass->get_type_info())) { - OPENVINO_DEBUG << "Pass " << pass->get_name() << " is disabled"; + OPENVINO_DEBUG << "Pass " << pass_name << " is disabled"; + continue; + } + + // This checks if we need to skip the graph transformation when the graph pass relies on + // static shape but the model state is dynamic. + if (pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && model->is_dynamic()) { + OPENVINO_DEBUG << "Pass " << pass_name << " requires static shape but the " + << "model is dynamic. Skipping this transformation"; continue; } OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::ov_pass, ov::pass::perf_counters()[pass->get_type_info()]); - pass_timer.start(); + profiler.start_timer(pass_name); if (auto matcher_pass = dynamic_pointer_cast(pass)) { - // This checks is to skip the graph transformation when the graph pass relies on - // static shape but the function state is dynamic. - if (matcher_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) { - OPENVINO_DEBUG << "Pass " << pass->get_name() << " requires static shape but the " - << "model is dynamic. Skipping this transformation"; - continue; - } - // GraphRewrite is a temporary container for MatcherPass to make execution - // on on entire ov::Model - pass_applied = GraphRewrite(matcher_pass).run_on_model(func); - } else if (auto function_pass = dynamic_pointer_cast(pass)) { - // This checks is to skip the graph transformation when the graph pass relies on - // static shape but the function state is dynamic. - if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) { - OPENVINO_DEBUG << "Pass " << pass->get_name() << " requires static shape but the " - << "model is dynamic. Skipping this transformation"; - continue; - } - + // GraphRewrite is a temporary container for MatcherPass to make execution on entire ov::Model + pass_applied = GraphRewrite(matcher_pass).run_on_model(model); + } else if (auto model_pass = dynamic_pointer_cast(pass)) { if (dynamic_pointer_cast(pass)) { if (needs_validate) { - function_pass->run_on_model(func); needs_validate = false; + pass_applied = model_pass->run_on_model(model); } - } else { - pass_applied = function_pass->run_on_model(func); + continue; } + pass_applied = model_pass->run_on_model(model); } - if (m_visualize) { - // visualizations and serializations will be named after the outermost function - const size_t num_digits_in_pass_index = 3; - std::string index_str = std::to_string(index); - index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str; - auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name(); + profiler.stop_timer(pass_name, std::string((pass_applied ? " + " : " ") + pass->get_name())); + profiler.visualize(model, pass_name); - if (m_visualize) { - auto file_ext = "svg"; - pass::VisualizeTree vt(base_filename + std::string(".") + file_ext); - vt.run_on_model(func); - } - } - index++; - pass_timer.stop(); - if (profile_enabled) { - cout << setw(7) << pass_timer.get_milliseconds() << "ms" << (pass_applied ? " + " : " ") - << pass->get_name() << "\n"; - } - function_changed = function_changed || pass_applied; - needs_validate = pass_applied; - } - if (profile_enabled) { - cout << "passes done in " << overall_timer.get_milliseconds() << "ms\n"; + model_changed = model_changed || pass_applied; + needs_validate = needs_validate || pass_applied; } - return function_changed; + profiler.stop_timer(passes_name, "All passes done in "); + + return model_changed; } diff --git a/src/core/tests/pass_manager.cpp b/src/core/tests/pass_manager.cpp index fe3a0e1232c814..29bece4c72116e 100644 --- a/src/core/tests/pass_manager.cpp +++ b/src/core/tests/pass_manager.cpp @@ -16,8 +16,10 @@ #include "openvino/op/matmul.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/parameter.hpp" +#include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/manager.hpp" #include "openvino/pass/pass.hpp" +#include "openvino/pass/validate.hpp" using namespace ov; using namespace std; @@ -75,6 +77,98 @@ bool validate_list(const std::vector>& nodes) { return rc; } +class TestMatcherPassTrue : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TestMatcherPassTrue"); + TestMatcherPassTrue() : MatcherPass() { + auto any_input = ov::pass::pattern::any_input(); + ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { + return true; + }; + + auto m = std::make_shared(any_input, "TestMatcherPassTrue"); + this->register_matcher(m, callback); + } +}; + +class TestMatcherPassFalse : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TestMatcherPassFalse"); + TestMatcherPassFalse() : MatcherPass() { + auto any_input = ov::pass::pattern::any_input(); + ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { + return false; + }; + + auto m = std::make_shared(any_input, "TestMatcherPassFalse"); + this->register_matcher(m, callback); + } +}; + +class TestModelPassTrue : public pass::ModelPass { +public: + OPENVINO_RTTI("TestModelPassTrue"); + + bool run_on_model(const std::shared_ptr& f) override { + return true; + } +}; + +class TestModelPassFalse : public pass::ModelPass { +public: + OPENVINO_RTTI("TestModelPassFalse"); + + bool run_on_model(const std::shared_ptr& f) override { + return false; + } +}; + +class TestValidate : public pass::Validate { +public: + OPENVINO_RTTI("TestValidate"); + + bool run_on_model(const std::shared_ptr& f) override { + m_applied = true; + return pass::Validate::run_on_model(f); + } + + bool is_applied() const { + return m_applied; + } + +private: + bool m_applied = false; +}; + +class TestValidate2 : public TestValidate {}; + +class TestManager : public pass::Manager { +public: + bool is_validation_applied() { + bool applied = false; + bool is_init = true; + for (const auto& pass : m_pass_list) { + auto validate_2 = std::dynamic_pointer_cast(pass); + auto validate = std::dynamic_pointer_cast(pass); + if (validate && !validate_2) { + if (is_init) { + is_init = false; + applied = validate->is_applied(); + } + applied = applied && validate->is_applied(); + } + } + return applied; + } + + bool is_2nd_validation_applied() { + for (const auto& pass : m_pass_list) { + if (auto validate = std::dynamic_pointer_cast(pass)) { + return validate->is_applied(); + } + } + } +}; } // namespace TEST(pass_manager, add) { @@ -90,3 +184,82 @@ TEST(pass_manager, add) { EXPECT_EQ(node_count, sorted.size()); EXPECT_TRUE(validate_list(sorted)); } + +TEST(pass_manager, passes_not_applied) { + TestManager pass_manager; + pass_manager.set_per_pass_validation(false); + + auto graph = make_test_graph(); + + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + const auto res = pass_manager.run_passes(graph); + + EXPECT_FALSE(res); + EXPECT_FALSE(pass_manager.is_validation_applied()); +} + +TEST(pass_manager, model_pass_applied) { + TestManager pass_manager; + pass_manager.set_per_pass_validation(false); + + auto graph = make_test_graph(); + + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + const auto res = pass_manager.run_passes(graph); + + EXPECT_TRUE(res); + EXPECT_TRUE(pass_manager.is_validation_applied()); +} + +TEST(pass_manager, matcher_pass_applied) { + TestManager pass_manager; + pass_manager.set_per_pass_validation(false); + + auto graph = make_test_graph(); + + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + const auto res = pass_manager.run_passes(graph); + + EXPECT_TRUE(res); + EXPECT_TRUE(pass_manager.is_validation_applied()); +} + +TEST(pass_manager, two_validations) { + TestManager pass_manager; + pass_manager.set_per_pass_validation(false); + + auto graph = make_test_graph(); + + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + const auto res = pass_manager.run_passes(graph); + + EXPECT_TRUE(res); + EXPECT_TRUE(pass_manager.is_validation_applied()); + EXPECT_FALSE(pass_manager.is_2nd_validation_applied()); +} \ No newline at end of file