Skip to content

Commit

Permalink
Fixed Validate handling in pass::Manager
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jun 19, 2024
1 parent d17b405 commit 790d063
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/core/include/openvino/pass/manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class OPENVINO_API Manager {
///
/// \return Returns true if the model was changed by transformations,
/// false otherwise.
bool run_passes(std::shared_ptr<Model> model);
bool run_passes(const std::shared_ptr<Model>& model);

void set_pass_visualization(bool new_state) {
m_visualize = new_state;
Expand Down
184 changes: 101 additions & 83 deletions src/core/src/pass/manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,18 @@ PerfCounters& perf_counters() {
static PerfCounters counters;
return counters;
}
} // namespace
} // namespace pass
} // namespace ov

#endif // ENABLE_PROFILING_ITT

namespace {
bool getenv_visualize_tracing() {
return ov::util::getenv_bool("OV_ENABLE_VISUALIZE_TRACING");
}
} // namespace

ov::pass::Manager::Manager() : m_pass_config(std::make_shared<PassConfig>()), m_visualize(getenv_visualize_tracing()) {}

ov::pass::Manager::~Manager() = default;

ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config)
: m_pass_config(std::move(pass_config)),
m_visualize(getenv_visualize_tracing()) {}

void ov::pass::Manager::set_per_pass_validation(bool new_state) {
m_per_pass_validation = new_state;
}

namespace {
class stopwatch {
public:
void start() {
if (m_active == false) {
if (!m_active) {
m_active = true;
m_start_time = m_clock.now();
}
}

void stop() {
if (m_active == true) {
if (m_active) {
auto end_time = m_clock.now();
m_last_time = end_time - m_start_time;
m_active = false;
Expand All @@ -89,85 +65,127 @@ class stopwatch {
bool m_active = false;
std::chrono::nanoseconds m_last_time = std::chrono::high_resolution_clock::duration::zero();
};

class Profiler {
public:
Profiler(bool visualize, bool profile_pass_enable)
: m_visualize(visualize),
m_profile_pass_enable(profile_pass_enable) {}

void start_timer(const std::string& name) {
if (m_profile_pass_enable) {
stopwatches[name] = stopwatch();
stopwatches[name].start();
}
}

void stop_timer(const std::string& name, const std::string& msg) {
if (m_profile_pass_enable) {
auto& stopwatch = stopwatches.at(name);
stopwatch.stop();
cout << msg << setw(7) << stopwatch.get_milliseconds() << "ms"
<< "\n";
}
}

void visualize(const shared_ptr<ov::Model>& model, const std::string& pass_name) {
if (m_visualize) {
// visualizations and serializations will be named after the outermost function
const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(m_index++);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto base_filename = model->get_name() + std::string("_") + index_str + std::string("_") + pass_name;

auto file_ext = "svg";
pass::VisualizeTree vt(base_filename + std::string(".") + file_ext);
vt.run_on_model(model);
}
}

private:
size_t m_index = 0;
std::unordered_map<std::string, stopwatch> stopwatches;

bool m_visualize;
bool m_profile_pass_enable;
};

} // namespace
} // namespace pass
} // namespace ov

bool ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes");
#endif // ENABLE_PROFILING_ITT

namespace {
bool getenv_visualize_tracing() {
return ov::util::getenv_bool("OV_ENABLE_VISUALIZE_TRACING");
}
} // namespace

ov::pass::Manager::Manager() : m_pass_config(std::make_shared<PassConfig>()), m_visualize(getenv_visualize_tracing()) {}

ov::pass::Manager::~Manager() = default;

ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config)
: m_pass_config(std::move(pass_config)),
m_visualize(getenv_visualize_tracing()) {}

void ov::pass::Manager::set_per_pass_validation(bool new_state) {
m_per_pass_validation = new_state;
}

static bool profile_enabled = ov::util::getenv_bool("OV_PROFILE_PASS_ENABLE");
bool ov::pass::Manager::run_passes(const shared_ptr<ov::Model>& model) {
OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes");
Profiler profiler(m_visualize, ov::util::getenv_bool("OV_PROFILE_PASS_ENABLE"));

size_t index = 0;
stopwatch pass_timer;
stopwatch overall_timer;
overall_timer.start();
bool pass_applied = false;
bool function_changed = false;
bool model_changed = false;
bool needs_validate = false;
for (auto& pass : m_pass_list) {
const std::string passes_name = "Passes";

profiler.start_timer(passes_name);
for (const auto& pass : m_pass_list) {
const auto& pass_name = pass->get_name();

if (m_pass_config->is_disabled(pass->get_type_info())) {
OPENVINO_DEBUG << "Pass " << pass->get_name() << " is disabled";
OPENVINO_DEBUG << "Pass " << pass_name << " is disabled";
continue;
}

// This checks if we need to skip the graph transformation when the graph pass relies on
// static shape but the model state is dynamic.
if (pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && model->is_dynamic()) {
OPENVINO_DEBUG << "Pass " << pass_name << " requires static shape but the "
<< "model is dynamic. Skipping this transformation";
continue;
}

OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::ov_pass, ov::pass::perf_counters()[pass->get_type_info()]);

pass_timer.start();
profiler.start_timer(pass_name);

if (auto matcher_pass = dynamic_pointer_cast<MatcherPass>(pass)) {
// This checks is to skip the graph transformation when the graph pass relies on
// static shape but the function state is dynamic.
if (matcher_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
OPENVINO_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
<< "model is dynamic. Skipping this transformation";
continue;
}
// GraphRewrite is a temporary container for MatcherPass to make execution
// on on entire ov::Model
pass_applied = GraphRewrite(matcher_pass).run_on_model(func);
} else if (auto function_pass = dynamic_pointer_cast<ModelPass>(pass)) {
// This checks is to skip the graph transformation when the graph pass relies on
// static shape but the function state is dynamic.
if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
OPENVINO_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
<< "model is dynamic. Skipping this transformation";
continue;
}

// GraphRewrite is a temporary container for MatcherPass to make execution on entire ov::Model
pass_applied = GraphRewrite(matcher_pass).run_on_model(model);
} else if (auto model_pass = dynamic_pointer_cast<ModelPass>(pass)) {
if (dynamic_pointer_cast<Validate>(pass)) {
if (needs_validate) {
function_pass->run_on_model(func);
needs_validate = false;
pass_applied = model_pass->run_on_model(model);
}
} else {
pass_applied = function_pass->run_on_model(func);
continue;
}
pass_applied = model_pass->run_on_model(model);
}

if (m_visualize) {
// visualizations and serializations will be named after the outermost function
const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name();
profiler.stop_timer(pass_name, std::string((pass_applied ? " + " : " ") + pass->get_name()));
profiler.visualize(model, pass_name);

if (m_visualize) {
auto file_ext = "svg";
pass::VisualizeTree vt(base_filename + std::string(".") + file_ext);
vt.run_on_model(func);
}
}
index++;
pass_timer.stop();
if (profile_enabled) {
cout << setw(7) << pass_timer.get_milliseconds() << "ms" << (pass_applied ? " + " : " ")
<< pass->get_name() << "\n";
}
function_changed = function_changed || pass_applied;
needs_validate = pass_applied;
}
if (profile_enabled) {
cout << "passes done in " << overall_timer.get_milliseconds() << "ms\n";
model_changed = model_changed || pass_applied;
needs_validate = needs_validate || pass_applied;
}

return function_changed;
profiler.stop_timer(passes_name, "All passes done in ");

return model_changed;
}
Loading

0 comments on commit 790d063

Please sign in to comment.