Skip to content

Commit

Permalink
PassManager refactoring in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jul 16, 2024
1 parent d3f2f43 commit 140b2fd
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 241 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@

bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(CommonOptimizations);
ov::pass::Manager manager(get_pass_config());
ov::pass::Manager manager(get_pass_config(), "CommonOptimizations");
manager.set_per_pass_validation(false);

using namespace ov::pass;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
f->validate_nodes_and_infer_types();
}

ov::pass::Manager manager(get_pass_config());
ov::pass::Manager manager(get_pass_config(), "MOC");
manager.set_per_pass_validation(false);
using namespace ov::pass;
REGISTER_PASS(manager, InitNodeInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ bool pass::SimplifyShapeOfSubGraph::run_on_model(const std::shared_ptr<Model>& f

REGISTER_PASS(manager, PrepareShapeOpsForEliminationAroundBE)
REGISTER_PASS(manager, AbsSinking)
// FIXME: manager runs Validate based on the last pass, when fixed the following line must be deleted
REGISTER_PASS(manager, Validate)
REGISTER_PASS(manager, SharedOpOptimization)
REGISTER_PASS(manager, EliminateGatherUnsqueeze) // should run after SharedOpOptimization
REGISTER_PASS(manager, NopElimination, m_use_shapes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ ov::pass::LabelResolvingThroughSelect::LabelResolvingThroughSelect() {
}

ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
m_manager = std::make_shared<pass::Manager>();
m_manager = std::make_shared<pass::Manager>("Symbolic");
m_manager->set_per_pass_validation(false);

#define REGISTER_SYMBOLIC(region, ...) m_manager->register_pass<region>(__VA_ARGS__);
Expand Down
10 changes: 9 additions & 1 deletion src/core/include/openvino/pass/manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ class OPENVINO_API Manager {
Manager();
virtual ~Manager();

//// \brief Construct Manager with a provided name.
explicit Manager(std::string name);

//// \brief Construct Manager with shared PassConfig instance
explicit Manager(std::shared_ptr<PassConfig> pass_config);
explicit Manager(std::shared_ptr<PassConfig> pass_config, std::string name = "PassManager");

/// \brief Register given transformation class type to execution list
/// Example below show the basic usage of pass::Manager
Expand Down Expand Up @@ -99,6 +102,11 @@ class OPENVINO_API Manager {
std::vector<std::shared_ptr<PassBase>> m_pass_list;
bool m_visualize = false;
bool m_per_pass_validation = true;
std::string m_name = "PassManager";

private:
bool run_pass(const std::shared_ptr<PassBase>& pass, const std::shared_ptr<Model>& model,
bool needs_validate);
};
} // namespace pass
} // namespace ov
114 changes: 63 additions & 51 deletions src/core/src/pass/manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

#include "itt.hpp"
#include "openvino/pass/graph_rewrite.hpp"
Expand All @@ -29,7 +30,6 @@ PerfCounters& perf_counters() {
static PerfCounters counters;
return counters;
}

} // namespace
} // namespace pass
} // namespace ov
Expand All @@ -55,7 +55,7 @@ class stopwatch {
}

void stop() {
if (m_active) {
if (m_active == true) {
auto end_time = m_clock.now();
m_last_time = end_time - m_start_time;
m_active = false;
Expand All @@ -81,25 +81,37 @@ class stopwatch {
std::chrono::nanoseconds m_last_time = std::chrono::high_resolution_clock::duration::zero();
};

int nesting_lvl = 0;

class Profiler {
public:
Profiler(bool visualize, bool profile_pass_enable)
Profiler(bool visualize, bool profile_pass_enable, std::string manager_name)
: m_visualize(visualize),
m_profile_pass_enable(profile_pass_enable) {}
m_profile_pass_enable(profile_pass_enable),
m_manager_name(std::move(manager_name))
{}

void start_timer(const std::string& name) {
if (m_profile_pass_enable) {
stopwatches[name] = stopwatch();
stopwatches[name].start();
nesting_lvl++;
}
}

void stop_timer(const std::string& name, const std::string& msg) {
void stop_timer(const std::string& name) {
if (m_profile_pass_enable) {
auto& stopwatch = stopwatches.at(name);
stopwatch.stop();
cout << msg << setw(7) << stopwatch.get_milliseconds() << "ms"
<< "\n";

nesting_lvl--;
if (name == m_manager_name) {
std::cout << std::string(nesting_lvl, '\t') << std::string(60, '_') << std::endl;
}
cout << (name == m_manager_name ? right : left) << std::string(nesting_lvl, '\t') << setw(60) << name << right << setw(8) << stopwatch.get_milliseconds() << "ms" << "\n";
if (nesting_lvl == 0) {
std::cout << "\n\n";
}
}
}

Expand All @@ -123,6 +135,7 @@ class Profiler {

bool m_visualize;
bool m_profile_pass_enable;
std::string m_manager_name;
};

} // namespace
Expand All @@ -131,67 +144,66 @@ ov::pass::Manager::Manager() : m_pass_config(std::make_shared<PassConfig>()), m_

ov::pass::Manager::~Manager() = default;

ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config)
: m_pass_config(std::move(pass_config)),
m_visualize(getenv_visualize_tracing()) {}
ov::pass::Manager::Manager(std::string name) :
m_pass_config(std::make_shared<PassConfig>()),
m_visualize(getenv_visualize_tracing()),
m_name(std::move(name)) {
}

ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config, std::string name) :
m_pass_config(std::move(pass_config)),
m_visualize(getenv_visualize_tracing()),
m_name(std::move(name)) {
}

void ov::pass::Manager::set_per_pass_validation(bool new_state) {
m_per_pass_validation = new_state;
}

bool ov::pass::Manager::run_passes(const shared_ptr<ov::Model>& model) {
OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes");
Profiler profiler(m_visualize, getenv_enable_profiling());
Profiler profiler(m_visualize, true, m_name);

bool pass_applied = false;
bool model_changed = false;
bool needs_validate = false;

const std::string passes_name = "Passes";
profiler.start_timer(passes_name);
bool pass_changed_model = false;

profiler.start_timer(m_name);
for (const auto& pass : m_pass_list) {
const auto& pass_name = pass->get_name();

if (m_pass_config->is_disabled(pass->get_type_info())) {
OPENVINO_DEBUG("Pass ", pass_name, " is disabled");
continue;
}

// This checks if we need to skip the graph transformation when the graph pass relies on
// static shape but the model state is dynamic.
if (pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && model->is_dynamic()) {
OPENVINO_DEBUG("Pass ", pass_name, " requires static shape but the "
,"model is dynamic. Skipping this transformation");
continue;
}
profiler.start_timer(pass->get_name());

OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::ov_pass, ov::pass::perf_counters()[pass->get_type_info()]);
pass_changed_model = run_pass(pass, model, pass_changed_model);

profiler.start_timer(pass_name);
profiler.stop_timer(pass->get_name());
model_changed = model_changed || pass_changed_model;
profiler.visualize(model, pass->get_name());
}
profiler.stop_timer(m_name);

if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass)) {
// GraphRewrite is a temporary container for MatcherPass to make execution on entire ov::Model
pass_applied = GraphRewrite(matcher_pass).run_on_model(model);
} else if (auto model_pass = dynamic_pointer_cast<ModelPass>(pass)) {
if (dynamic_pointer_cast<Validate>(pass)) {
if (needs_validate) {
needs_validate = false;
pass_applied = model_pass->run_on_model(model);
}
continue;
}
pass_applied = model_pass->run_on_model(model);
}
return model_changed;
}

profiler.stop_timer(pass_name, std::string((pass_applied ? " + " : " ") + pass->get_name()));
profiler.visualize(model, pass_name);
bool ov::pass::Manager::run_pass(const std::shared_ptr<PassBase>& pass, const std::shared_ptr<Model>& model,
bool needs_validate) {
if (m_pass_config->is_disabled(pass->get_type_info())) {
OPENVINO_DEBUG("Pass ", pass->get_name(), " is disabled.");
return false;
}

model_changed = model_changed || pass_applied;
needs_validate = needs_validate || pass_applied;
// This checks if we need to skip the graph transformation when the graph pass relies on
// static shape but the model state is dynamic.
if (pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && model->is_dynamic()) {
OPENVINO_DEBUG("Pass ", pass->get_name(), " requires static shape but the "
,"model is dynamic. Skipping this transformation.");
return false;
}

profiler.stop_timer(passes_name, "All passes done in ");
OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::ov_pass, ov::pass::perf_counters()[pass->get_type_info()]);

return model_changed;
if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass)) {
// GraphRewrite is a temporary container for MatcherPass to make execution on entire ov::Model
return GraphRewrite(matcher_pass).run_on_model(model);
} else if (auto model_pass = dynamic_pointer_cast<ModelPass>(pass)) {
return model_pass->run_on_model(model);
}
return false;
}
Loading

0 comments on commit 140b2fd

Please sign in to comment.