Skip to content

Commit

Permalink
Fixed unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyachur committed Aug 27, 2021
1 parent e6fe3f2 commit 755212d
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 48 deletions.
2 changes: 2 additions & 0 deletions ngraph/core/include/ngraph/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ using ov::pass::Manager;
using ov::pass::PassBase;
using ov::pass::PassProperty;
using ov::pass::PassPropertyMask;
NGRAPH_DEPRECATED("This variable is deprecated and will be removed soon.")
const PassPropertyMask all_pass_property_off;

class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase {
public:
Expand Down
1 change: 0 additions & 1 deletion ngraph/core/include/openvino/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ enum class PassProperty : uint32_t {
};

using PassPropertyMask = ngraph::EnumMask<PassProperty>;
const PassPropertyMask all_pass_property_off;

class OPENVINO_API PassBase {
friend class Manager;
Expand Down
47 changes: 23 additions & 24 deletions ngraph/core/src/pass/graph_rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
#include "ngraph/op/util/sub_graph_base.hpp"
#include "perf_counters.hpp"

using namespace std;
using namespace ngraph;

/* GraphRewrite algorithm:
* GraphRewrite processes an input graph in an topological order(i.e. args before users)
* Given the following graph: Abs2
Expand Down Expand Up @@ -70,25 +67,26 @@ PerfCounters& perf_counters_graph_rewrite() {
} // namespace pass
} // namespace ov

bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
bool ov::pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
// Initialize execution queue with nodes in topological order
deque<std::weak_ptr<Node>> nodes_to_run;
std::deque<std::weak_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops()) {
nodes_to_run.emplace_front(node);
}
return apply_matcher_passes(f, std::move(nodes_to_run));
}

bool pass::GraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
bool ov::pass::GraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
// Initialize execution queue with nodes in topological order
deque<std::weak_ptr<Node>> nodes_to_run;
std::deque<std::weak_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops()) {
nodes_to_run.emplace_back(node);
}
return apply_matcher_passes(f, std::move(nodes_to_run));
}

bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std::weak_ptr<Node>> nodes_to_run) {
bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Function> f,
std::deque<std::weak_ptr<Node>> nodes_to_run) {
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::GraphRewrite::run_on_function");

bool rewritten = false;
Expand All @@ -111,16 +109,16 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
auto root = matcher->get_pattern_value().get_node_shared_ptr();
// pattern::op::AnyOutput operation automatically appends for multi output operations inside
// Matcher and to gen actual root node we need to take it's parent.
if (auto any_type = dynamic_pointer_cast<pattern::op::AnyOutput>(root)) {
if (auto any_type = std::dynamic_pointer_cast<pattern::op::AnyOutput>(root)) {
root = any_type->input_value(0).get_node_shared_ptr();
}

// if root is an operation from opset or has pattern::op::WrapType type then we can extract
// it's type
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
// and default algorithm is used.
if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root)) {
if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p)) {
if (auto p = std::dynamic_pointer_cast<pattern::op::Pattern>(root)) {
if (auto any_type = std::dynamic_pointer_cast<pattern::op::WrapType>(p)) {
for (const auto& root_type_info : any_type->get_wrapped_types()) {
type_to_matcher[root_type_info].push_back(matcher_index);
}
Expand Down Expand Up @@ -180,7 +178,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
continue;

// Recursive apply Matchers for sub-graph based nodes
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
run_on_function(sub_graph);
}
Expand Down Expand Up @@ -236,9 +234,9 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
return rewritten;
}

void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback,
const PassPropertyMask& property) {
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback,
const PassPropertyMask& property) {
m_matchers.push_back(std::make_shared<MatcherPass>(
m->get_name(),
m,
Expand All @@ -258,15 +256,16 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
property));
}

void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, const graph_rewrite_callback& callback) {
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback) {
NGRAPH_SUPPRESS_DEPRECATED_START
// TODO: before deprecate this function, by default expect the
// callback require static shape.
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
NGRAPH_SUPPRESS_DEPRECATED_END
}

void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
void ov::pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
auto pass_config = get_pass_config();
// We have to preserve disabled passes because in case when we register matchers inside
// GraphRewrite c-tor we work with local PassConfig instance.
Expand All @@ -293,9 +292,9 @@ void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs)
}
}

void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ov::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property) {
void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ov::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property) {
m_matchers.push_back(std::make_shared<MatcherPass>(
"Recurrent matcher",
nullptr,
Expand All @@ -310,20 +309,20 @@ void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::Rec
property));
}

void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ov::recurrent_graph_rewrite_callback& callback) {
void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ov::recurrent_graph_rewrite_callback& callback) {
// TODO: before deprecate this function, by default expect the
// callback require static shape.
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
}

bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) {
bool ov::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr<Function> f) {
bool changed = false;
size_t i = 0;

// This check is very expensive and is only needed for experimental features, so we will hide
// it behind an environment variable for now. TODO: Find a less expensive way to handle this.
static bool s_rerun_dynamic_check = getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
static bool s_rerun_dynamic_check = ngraph::getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");

auto run_matchers = [&]() -> bool {
bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
Expand Down
2 changes: 1 addition & 1 deletion ngraph/core/src/pass/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using namespace std;

OPENVINO_RTTI_DEFINITION(ov::pass::FunctionPass, "ov::pass::FunctionPass", 0);

ov::pass::PassBase::PassBase() : m_property{all_pass_property_off}, m_pass_config(std::make_shared<PassConfig>()) {}
ov::pass::PassBase::PassBase() : m_property(), m_pass_config(std::make_shared<PassConfig>()) {}

bool ov::pass::PassBase::get_property(const PassPropertyMask& prop) const {
return m_property.is_set(prop);
Expand Down
36 changes: 14 additions & 22 deletions ngraph/test/runtime/pass/dyn_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "dyn_elimination.hpp"

#include <numeric>

#include "dyn_elimination.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/range.hpp"
Expand All @@ -19,38 +20,30 @@ NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;

pass::DynElimination::DynElimination()
: GraphRewrite()
{
pass::DynElimination::DynElimination() : GraphRewrite() {
construct_range();
}

template <typename T>
std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg)
{
const std::shared_ptr<op::Constant>& step_arg) {
std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>();
std::vector<T> step_vec = step_arg->get_vector<T>();

NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);

runtime::reference::range<T>(
start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape_size(shape), elements.data());

return make_shared<op::Constant>(et, shape, elements);
}

void pass::DynElimination::construct_range()
{
auto start_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto stop_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto step_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
void pass::DynElimination::construct_range() {
auto start_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto stop_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto step_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());

auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);

Expand All @@ -70,12 +63,11 @@ void pass::DynElimination::construct_range()
std::shared_ptr<op::Constant> replacement;

#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
# pragma GCC diagnostic push
# pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (et)
{
switch (et) {
case element::Type_t::bf16:
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
break;
Expand Down Expand Up @@ -122,7 +114,7 @@ void pass::DynElimination::construct_range()
break;
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
# pragma GCC diagnostic pop
#endif

replace_node(range_node, replacement);
Expand Down

0 comments on commit 755212d

Please sign in to comment.