Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move pass pattern to ov #7255

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4388883
Moved ngraph::Node to ov namespace
ilyachur Aug 25, 2021
06303f4
Fixed code style
ilyachur Aug 25, 2021
fc5c23f
Fixed VPU
ilyachur Aug 25, 2021
bd87e41
Fixed GNA
ilyachur Aug 25, 2021
ec2c2e3
Fixed tests
ilyachur Aug 25, 2021
5766e70
Merge remote-tracking branch 'upstream/master' into move_node_to_ov
ilyachur Aug 25, 2021
16076af
Merge remote-tracking branch 'upstream/master' into move_node_to_ov
ilyachur Aug 26, 2021
9f00f52
Added aliases for backward compatibility
ilyachur Aug 26, 2021
250368f
Fix clDNN
ilyachur Aug 26, 2021
85ab414
Try to fix build
ilyachur Aug 26, 2021
6363551
Fixed comment
ilyachur Aug 26, 2021
5f54eb8
Renamed RTTI macros
ilyachur Aug 26, 2021
b7c7803
Merge remote-tracking branch 'upstream/master' into move_node_to_ov
ilyachur Aug 27, 2021
20e8fc9
Merge remote-tracking branch 'upstream/master' into move_node_to_ov
ilyachur Aug 29, 2021
5b43470
Add new headers
ilyachur Aug 26, 2021
bb48740
Fixed ngraph build
ilyachur Aug 26, 2021
53361eb
Fixed unit tests
ilyachur Aug 27, 2021
70ecf71
Try to fix Serialize
ilyachur Aug 27, 2021
858fad3
Merge remote-tracking branch 'upstream/master' into move_pass_pattern…
ilyachur Aug 30, 2021
0faed49
Merge remote-tracking branch 'upstream/master' into move_pass_pattern…
ilyachur Aug 30, 2021
d190601
Merge remote-tracking branch 'upstream/master' into move_pass_pattern…
ilyachur Aug 31, 2021
0d8bae5
Merge remote-tracking branch 'upstream/master' into move_pass_pattern…
ilyachur Aug 31, 2021
c2f7e97
Merge remote-tracking branch 'upstream/master' into move_pass_pattern…
ilyachur Sep 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ endif()
# resolving dependencies for the project
message (STATUS "PROJECT ............................... " ${PROJECT_NAME})
message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR})
message (STATUS "OpenVINO_SOURCE_DIR .... .......... " ${OpenVINO_SOURCE_DIR})
message (STATUS "OpenVINO_SOURCE_DIR ................... " ${OpenVINO_SOURCE_DIR})
message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR})
message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID})
message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,34 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
f.validate_nodes_and_infer_types();
}
}

std::string valid_xml_path(const std::string &path) {
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");

const char *const extension = ".xml";
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
NGRAPH_CHECK(has_xml_extension,
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
path + "\"");
return path;
}

std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
if (!binPath.empty()) {
return binPath;
}
assert(xmlPath.size() > 4); // should be check by valid_xml_path
std::string bestPath = xmlPath;
const char *const extension = "bin";
const auto ext_size = std::strlen(extension);
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
return bestPath;
}

} // namespace

namespace ngraph {

// ! [function_pass:serialize_cpp]
// serialize.cpp
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
Expand Down Expand Up @@ -868,33 +894,6 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
return false;
}

namespace {

std::string valid_xml_path(const std::string &path) {
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");

const char *const extension = ".xml";
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
NGRAPH_CHECK(has_xml_extension,
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
path + "\"");
return path;
}

std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
if (!binPath.empty()) {
return binPath;
}
assert(xmlPath.size() > 4); // should be check by valid_xml_path
std::string bestPath = xmlPath;
const char *const extension = "bin";
const auto ext_size = std::strlen(extension);
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
return bestPath;
}

} // namespace

pass::Serialize::Serialize(std::ostream& xmlFile,
std::ostream& binFile,
pass::Serialize::Version version,
Expand All @@ -921,3 +920,4 @@ pass::Serialize::Serialize(const std::string& xmlPath,
{
}
// ! [function_pass:serialize_cpp]
} // namespace ngraph
18 changes: 2 additions & 16 deletions ngraph/core/include/ngraph/pass/constant_folding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,10 @@
#pragma once

#include "ngraph/pass/pass.hpp"
#include "openvino/pass/constant_folding.hpp"

namespace ngraph {
namespace pass {
/**
* @brief Constant folding iterates over the function and tries to evaluate nodes
* with constant inputs. Such nodes are then replaced with new Constants containing
* the result of a folded operation.
*/
class NGRAPH_API ConstantFolding : public FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;

private:
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
/// \brief Folds pre-calculated output tensor values to constants in case lower and
/// upper estimations are equal. Traverses graph backwards starting from the results.
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
};
using ov::pass::ConstantFolding;
} // namespace pass
} // namespace ngraph
9 changes: 3 additions & 6 deletions ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pass/graph_rewrite.hpp"
#include "openvino/pass/convert_fp32_to_fp16.hpp"

namespace ngraph {
namespace pass {
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
using ov::pass::ConvertFP32ToFP16;
} // namespace pass
} // namespace ngraph
241 changes: 9 additions & 232 deletions ngraph/core/include/ngraph/pass/graph_rewrite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,240 +10,17 @@

#include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "openvino/pass/graph_rewrite.hpp"

namespace ngraph {
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
using recurrent_graph_rewrite_callback = std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
using ov::graph_rewrite_callback;
using ov::handler_callback;
using ov::matcher_pass_callback;
using ov::recurrent_graph_rewrite_callback;
namespace pass {
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
/// pattern and
/// action that is applied if pattern is matched.
///
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
/// and
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
/// within
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
/// execution
/// queue. That means that operations that were created inside transformation callback can
/// be added
/// for matching. To register node use \sa register_new_node method. GraphRewrite
/// automatically
/// takes registered nodes and put them to execution queue. If multiple nodes were register
/// make
/// sure that they were registered in topological order.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.

class NGRAPH_API MatcherPass : public ngraph::pass::PassBase {
public:
NGRAPH_RTTI_DECLARATION;

MatcherPass() = default;

MatcherPass(const MatcherPass&) = delete;
MatcherPass& operator=(const MatcherPass&) = delete;

explicit MatcherPass(const std::string& name,
const std::shared_ptr<pattern::Matcher>& m,
const handler_callback& handler,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
: PassBase(),
m_handler(handler),
m_matcher(m) {
set_name(name);
set_property(property, true);
}

bool apply(std::shared_ptr<ngraph::Node> node);

template <typename T, class... Args>
std::shared_ptr<T> register_new_node(Args&&... args) {
auto node = std::make_shared<T>(std::forward<Args>(args)...);
m_new_nodes.push_back(node);
return node;
}

template <typename T>
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
m_new_nodes.push_back(node);
return node;
}

const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() {
return m_new_nodes;
}
void clear_new_nodes() {
m_new_nodes.clear();
}
std::shared_ptr<pattern::Matcher> get_matcher() {
return m_matcher;
}

protected:
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);

private:
handler_callback m_handler;
std::shared_ptr<pattern::Matcher> m_matcher;
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
};

/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
/// in
/// efficient way
///
/// Graph rewrite pass is used for matcher passes execution on Function.
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
/// class.
/// As a default algorithm graph rewrite pass traverse Function in topological order and
/// applies
/// registered matcher passes for each node. But if all registered matcher passes have type
/// based
/// root node in Matcher pattern then efficient mechanism is used to execute them.
/// Matcher pattern root is type based if it's operation from opset or
/// pattern::op::WrapType.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.

class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;

GraphRewrite() = default;

explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
m_matchers.push_back(pass);
}

/// \brief Register given transformation class type to GraphRewrite execution list
/// All registered transformations will be executed in a single graph traversal.
/// Example below show the basic usage of pass::GraphRewrite
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T,
bool Enabled = true,
class... Args,
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
std::shared_ptr<T> add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
if (!Enabled && !pass_config->is_enabled<T>()) {
pass_config->disable<T>();
}
m_matchers.push_back(pass);
return pass;
}

/// \brief Register passes from GraphRewrite class that contains sequence of matcher
/// passes registered in its ctor.
/// For example:
///
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
/// public:
/// NGRAPH_RTTI_DECLARATION;
/// Fusions() {
/// add_matcher<ngraph::pass::AddFusion>();
/// add_matcher<ngraph::pass::MulFusion>();
/// }
/// };
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<LinFusions>();
/// anchor->add_matcher<OtherFusions>();
/// anchor->set_name("CommonFusions");
/// manager.run_passes(f);
///
/// In this case all matcher passes from LinFusions pass will be united with other
/// registered matchers.
template <typename T,
class... Args,
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
void add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();

for (auto& matcher : pass->m_matchers) {
pass->set_pass_config(pass_config);
m_matchers.push_back(matcher);
}
}

NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property);

NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ngraph::graph_rewrite_callback& callback);

bool run_on_function(std::shared_ptr<ngraph::Function> f) override;

void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;

protected:
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);

bool m_enable_shape_inference = false;

std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};

class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;

BackwardGraphRewrite() = default;

explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}

bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};

class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass {
public:
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}

void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);

// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback);

bool run_on_function(std::shared_ptr<ngraph::Function> f) override;

private:
size_t m_num_iters;

std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
using ov::pass::BackwardGraphRewrite;
using ov::pass::GraphRewrite;
using ov::pass::MatcherPass;
using ov::pass::RecurrentGraphRewrite;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep it only in legacy API.

} // namespace pass
} // namespace ngraph
Loading