Skip to content

Commit

Permalink
Move pass pattern to ov (#7255)
Browse files Browse the repository at this point in the history
* Moved ngraph::Node to ov namespace

* Fixed code style

* Fixed VPU

* Fixed GNA

* Fixed tests

* Added aliases for backward compatibility

* Fix clDNN

* Try to fix build

* Fixed comment

* Renamed RTTI macros

* Add new headers

* Fixed ngraph build

* Fixed unit tests

* Try to fix Serialize
  • Loading branch information
ilyachur authored Sep 2, 2021
1 parent 07f7061 commit 9eca6ba
Show file tree
Hide file tree
Showing 62 changed files with 2,028 additions and 1,556 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ endif()
# resolving dependencies for the project
message (STATUS "PROJECT ............................... " ${PROJECT_NAME})
message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR})
message (STATUS "OpenVINO_SOURCE_DIR .... .......... " ${OpenVINO_SOURCE_DIR})
message (STATUS "OpenVINO_SOURCE_DIR ................... " ${OpenVINO_SOURCE_DIR})
message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR})
message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID})
message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,34 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
f.validate_nodes_and_infer_types();
}
}

std::string valid_xml_path(const std::string &path) {
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");

const char *const extension = ".xml";
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
NGRAPH_CHECK(has_xml_extension,
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
path + "\"");
return path;
}

std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
if (!binPath.empty()) {
return binPath;
}
assert(xmlPath.size() > 4); // should be check by valid_xml_path
std::string bestPath = xmlPath;
const char *const extension = "bin";
const auto ext_size = std::strlen(extension);
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
return bestPath;
}

} // namespace

namespace ngraph {

// ! [function_pass:serialize_cpp]
// serialize.cpp
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
Expand Down Expand Up @@ -868,33 +894,6 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
return false;
}

namespace {

std::string valid_xml_path(const std::string &path) {
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");

const char *const extension = ".xml";
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
NGRAPH_CHECK(has_xml_extension,
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
path + "\"");
return path;
}

std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
if (!binPath.empty()) {
return binPath;
}
assert(xmlPath.size() > 4); // should be check by valid_xml_path
std::string bestPath = xmlPath;
const char *const extension = "bin";
const auto ext_size = std::strlen(extension);
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
return bestPath;
}

} // namespace

pass::Serialize::Serialize(std::ostream& xmlFile,
std::ostream& binFile,
pass::Serialize::Version version,
Expand All @@ -921,3 +920,4 @@ pass::Serialize::Serialize(const std::string& xmlPath,
{
}
// ! [function_pass:serialize_cpp]
} // namespace ngraph
18 changes: 2 additions & 16 deletions ngraph/core/include/ngraph/pass/constant_folding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,10 @@
#pragma once

#include "ngraph/pass/pass.hpp"
#include "openvino/pass/constant_folding.hpp"

namespace ngraph {
namespace pass {
/**
* @brief Constant folding iterates over the function and tries to evaluate nodes
* with constant inputs. Such nodes are then replaced with new Constants containing
* the result of a folded operation.
*/
class NGRAPH_API ConstantFolding : public FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;

private:
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
/// \brief Folds pre-calculated output tensor values to constants in case lower and
/// upper estimations are equal. Traverses graph backwards starting from the results.
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
};
using ov::pass::ConstantFolding;
} // namespace pass
} // namespace ngraph
9 changes: 3 additions & 6 deletions ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pass/graph_rewrite.hpp"
#include "openvino/pass/convert_fp32_to_fp16.hpp"

namespace ngraph {
namespace pass {
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
using ov::pass::ConvertFP32ToFP16;
} // namespace pass
} // namespace ngraph
241 changes: 9 additions & 232 deletions ngraph/core/include/ngraph/pass/graph_rewrite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,240 +10,17 @@

#include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "openvino/pass/graph_rewrite.hpp"

namespace ngraph {
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
using recurrent_graph_rewrite_callback = std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
using ov::graph_rewrite_callback;
using ov::handler_callback;
using ov::matcher_pass_callback;
using ov::recurrent_graph_rewrite_callback;
namespace pass {
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
/// pattern and
/// action that is applied if pattern is matched.
///
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
/// and
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
/// within
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
/// execution
/// queue. That means that operations that were created inside transformation callback can
/// be added
/// for matching. To register node use \sa register_new_node method. GraphRewrite
/// automatically
/// takes registered nodes and put them to execution queue. If multiple nodes were register
/// make
/// sure that they were registered in topological order.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.

class NGRAPH_API MatcherPass : public ngraph::pass::PassBase {
public:
NGRAPH_RTTI_DECLARATION;

MatcherPass() = default;

MatcherPass(const MatcherPass&) = delete;
MatcherPass& operator=(const MatcherPass&) = delete;

explicit MatcherPass(const std::string& name,
const std::shared_ptr<pattern::Matcher>& m,
const handler_callback& handler,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
: PassBase(),
m_handler(handler),
m_matcher(m) {
set_name(name);
set_property(property, true);
}

bool apply(std::shared_ptr<ngraph::Node> node);

template <typename T, class... Args>
std::shared_ptr<T> register_new_node(Args&&... args) {
auto node = std::make_shared<T>(std::forward<Args>(args)...);
m_new_nodes.push_back(node);
return node;
}

template <typename T>
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
m_new_nodes.push_back(node);
return node;
}

const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() {
return m_new_nodes;
}
void clear_new_nodes() {
m_new_nodes.clear();
}
std::shared_ptr<pattern::Matcher> get_matcher() {
return m_matcher;
}

protected:
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);

private:
handler_callback m_handler;
std::shared_ptr<pattern::Matcher> m_matcher;
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
};

/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
/// in
/// efficient way
///
/// Graph rewrite pass is used for matcher passes execution on Function.
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
/// class.
/// As a default algorithm graph rewrite pass traverse Function in topological order and
/// applies
/// registered matcher passes for each node. But if all registered matcher passes have type
/// based
/// root node in Matcher pattern then efficient mechanism is used to execute them.
/// Matcher pattern root is type based if it's operation from opset or
/// pattern::op::WrapType.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.

class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;

GraphRewrite() = default;

explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
m_matchers.push_back(pass);
}

/// \brief Register given transformation class type to GraphRewrite execution list
/// All registered transformations will be executed in a single graph traversal.
/// Example below show the basic usage of pass::GraphRewrite
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T,
bool Enabled = true,
class... Args,
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
std::shared_ptr<T> add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
if (!Enabled && !pass_config->is_enabled<T>()) {
pass_config->disable<T>();
}
m_matchers.push_back(pass);
return pass;
}

/// \brief Register passes from GraphRewrite class that contains sequence of matcher
/// passes registered in its ctor.
/// For example:
///
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
/// public:
/// NGRAPH_RTTI_DECLARATION;
/// Fusions() {
/// add_matcher<ngraph::pass::AddFusion>();
/// add_matcher<ngraph::pass::MulFusion>();
/// }
/// };
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<LinFusions>();
/// anchor->add_matcher<OtherFusions>();
/// anchor->set_name("CommonFusions");
/// manager.run_passes(f);
///
/// In this case all matcher passes from LinFusions pass will be united with other
/// registered matchers.
template <typename T,
class... Args,
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
void add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();

for (auto& matcher : pass->m_matchers) {
pass->set_pass_config(pass_config);
m_matchers.push_back(matcher);
}
}

NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property);

NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ngraph::graph_rewrite_callback& callback);

bool run_on_function(std::shared_ptr<ngraph::Function> f) override;

void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;

protected:
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);

bool m_enable_shape_inference = false;

std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};

class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;

BackwardGraphRewrite() = default;

explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}

bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};

class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass {
public:
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}

void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);

// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback);

bool run_on_function(std::shared_ptr<ngraph::Function> f) override;

private:
size_t m_num_iters;

std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
using ov::pass::BackwardGraphRewrite;
using ov::pass::GraphRewrite;
using ov::pass::MatcherPass;
using ov::pass::RecurrentGraphRewrite;
} // namespace pass
} // namespace ngraph
Loading

0 comments on commit 9eca6ba

Please sign in to comment.