WrapType Improvements (#4040)
* Extended WrapType to consume multiple types; Added variadic wrap_type support

* Updated transformations to use wrap_type

* Fix BatchNormDecomposition

* Added tests
Gleb Kazantaev authored Feb 2, 2021
1 parent 3a86b3a commit cca0d56
Showing 12 changed files with 131 additions and 115 deletions.
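The core change: pattern::wrap_type is now variadic, so a single pattern node can match several opset versions of an operation. A minimal sketch of the new usage (illustrative, not code from the diff; the helper function is hypothetical):

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

using namespace ngraph;

// One pattern node matching BatchNormInference from either opset.
// Before this commit, this required two separate MatcherPass classes.
std::shared_ptr<Node> make_bn_pattern() {
    return pattern::wrap_type<opset1::BatchNormInference,
                              opset5::BatchNormInference>();
}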
@@ -19,7 +19,6 @@ namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API BatchNormDecomposition;
class TRANSFORMATIONS_API BatchNormV5Decomposition;

} // namespace pass
} // namespace ngraph
@@ -29,9 +28,3 @@ class ngraph::pass::BatchNormDecomposition: public ngraph::pass::MatcherPass {
NGRAPH_RTTI_DECLARATION;
BatchNormDecomposition();
};

class ngraph::pass::BatchNormV5Decomposition: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
BatchNormV5Decomposition();
};
@@ -83,7 +83,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
common_fusions->add_matcher<ngraph::pass::ConvertScatterElementsToScatter>();
common_fusions->add_matcher<ngraph::pass::DepthToSpaceFusion>();
//common_fusions->add_matcher<ngraph::pass::MishFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
@@ -115,7 +114,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::BatchNormV5Decomposition>();
decomp->set_name("ngraph::pass::CommonDecompositions");

// CF is required after all decompositions
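With BatchNormV5Decomposition folded into BatchNormDecomposition, a single registration now covers both opsets. A sketch of standalone use outside CommonOptimizations (the include path is an assumption based on this repository's layout):

#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/batch_norm_decomposition.hpp> // path assumed

void decompose_batch_norm(std::shared_ptr<ngraph::Function> f) {
    ngraph::pass::Manager manager;
    // One pass now decomposes both opset1 and opset5 BatchNormInference.
    manager.register_pass<ngraph::pass::BatchNormDecomposition>();
    manager.run_passes(f);
}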
@@ -164,7 +164,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvAddFusion, "ConvAddFusion", 0);
ngraph::pass::ConvAddFusion::ConvAddFusion() {
MATCHER_SCOPE(ConvAddFusion);
auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(pattern::consumers_count(1));
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, pattern::any_input()});

matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
return conv_callback<op::ConvolutionIE>(m);
@@ -179,7 +179,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvMultiplyFusion, "ConvMultiplyFusion", 0
ngraph::pass::ConvMultiplyFusion::ConvMultiplyFusion() {
MATCHER_SCOPE(ConvMultiplyFusion);
auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(pattern::consumers_count(1));
auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, std::make_shared<pattern::op::Label>()});
auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, pattern::any_input()});

matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
return conv_callback<op::ConvolutionIE>(m);
@@ -194,7 +194,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::DeconvAddFusion, "DeconvAddFusion", 0);
ngraph::pass::DeconvAddFusion::DeconvAddFusion() {
MATCHER_SCOPE(DeconvAddFusion);
auto conv = ngraph::pattern::wrap_type<op::DeconvolutionIE>(pattern::consumers_count(1));
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, pattern::any_input()});

matcher_pass_callback callback = [](ngraph::pattern::Matcher &m){
return conv_callback<op::DeconvolutionIE>(m);
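pattern::any_input() used throughout this file is shorthand for an unconstrained Label, so the rewritten patterns match exactly what they did before. Equivalence sketch (header location assumed):

#include <memory>
#include <ngraph/pattern/op/label.hpp>

using namespace ngraph;

void sketch() {
    // Old spelling: an explicit Label node that matches any output.
    auto old_style = std::make_shared<pattern::op::Label>();
    // New spelling: same unconstrained match, shorter to read.
    auto new_style = pattern::any_input();
}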
@@ -19,7 +19,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecompositi

ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
MATCHER_SCOPE(BatchNormDecomposition);
auto bn = pattern::wrap_type<opset1::BatchNormInference>({
auto bn = pattern::wrap_type<opset1::BatchNormInference, opset5::BatchNormInference>({
pattern::any_input(pattern::has_static_rank()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
@@ -28,20 +28,30 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
});

ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
auto m_bn = dynamic_pointer_cast<opset1::BatchNormInference>(m.get_match_root());
if (!m_bn) {
auto m_bn = m.get_match_root();
Output<Node> m_input, m_gamma, m_beta, m_mean, m_var;
double eps;
if (auto m_bn_v1 = dynamic_pointer_cast<opset1::BatchNormInference>(m_bn)) {
m_gamma = m_bn_v1->input_value(0);
m_beta = m_bn_v1->input_value(1);
m_input = m_bn_v1->input_value(2);
m_mean = m_bn_v1->input_value(3);
m_var = m_bn_v1->input_value(4);
eps = m_bn_v1->get_eps_value();
} else if (auto m_bn_v5 = dynamic_pointer_cast<opset5::BatchNormInference>(m_bn)) {
m_input = m_bn_v5->input_value(0);
m_gamma = m_bn_v5->input_value(1);
m_beta = m_bn_v5->input_value(2);
m_mean = m_bn_v5->input_value(3);
m_var = m_bn_v5->input_value(4);
eps = m_bn_v5->get_eps_value();
} else {
return false;
}

auto m_gamma = m_bn->input_value(0);
auto m_beta = m_bn->input_value(1);
auto m_input = m_bn->input_value(2);
auto m_mean = m_bn->input_value(3);
auto m_var = m_bn->input_value(4);

const auto& input_type = m_input.get_element_type();
// scale_add = variance + eps
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {eps}));
// scale = sqrt(variance + eps)
auto scale = make_shared<opset5::Sqrt>(scale_add);
// Divide `gamma` by `sqrt(variance + eps)`
@@ -79,67 +89,3 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
this->register_matcher(m, callback);
}

NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);

// TODO: this pass will be unified with BatchNormDecomposition pass
ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
MATCHER_SCOPE(BatchNormV5Decomposition);
auto bn = pattern::wrap_type<opset5::BatchNormInference>({
pattern::any_input(pattern::has_static_rank()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape())
});

ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
if (!m_bn) {
return false;
}

auto m_input = m_bn->input_value(0);
auto m_gamma = m_bn->input_value(1);
auto m_beta = m_bn->input_value(2);
auto m_mean = m_bn->input_value(3);
auto m_var = m_bn->input_value(4);

const auto& input_type = m_input.get_element_type();
// scale_add = variance + eps
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
// scale = sqrt(variance + eps)
auto scale = make_shared<opset5::Sqrt>(scale_add);
// Divide `gamma` by `sqrt(variance + eps)`
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);

int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;

// TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
Shape input_aligned_shape = m_gamma.get_shape();
for (int64_t i = 0; i < dims_to_add; ++i)
input_aligned_shape.push_back(1);
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);

auto gamma_div_scale_aligned = make_shared<opset5::Reshape>(gamma_div_scale, new_shape, true);
auto beta_aligned = make_shared<opset5::Reshape>(m_beta, new_shape, true);
auto mean_aligned = make_shared<opset5::Reshape>(m_mean, new_shape, true);

// input_sub_mean = input - mean
auto input_sub_mean = register_new_node<opset5::Subtract>(m_input, mean_aligned);
// Multiply `input - mean` and `gamma / sqrt(variance + eps)`
auto mul = std::make_shared<opset5::Multiply>(input_sub_mean, gamma_div_scale_aligned);
// Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta`
auto add = std::make_shared<opset5::Add>(mul, beta_aligned);

add->set_friendly_name(m_bn->get_friendly_name());

copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned,
beta_aligned, input_sub_mean, mul, add});

replace_node(m_bn, add);

return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, matcher_name);
this->register_matcher(m, callback);
}
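The unified callback differs from the two passes it replaces only in how it pulls the five inputs: opset1 orders them (gamma, beta, input, mean, variance) while opset5 uses (input, gamma, beta, mean, variance). Both branches feed the same decomposition, y = (x - mean) * gamma / sqrt(variance + eps) + beta. A quick numeric sanity check of that identity (a hand-written example, not commit code):

#include <cmath>
#include <cstdio>

int main() {
    double x = 2.0, mean = 1.0, var = 4.0, eps = 1e-5;
    double gamma = 3.0, beta = 0.5;
    double scale = std::sqrt(var + eps);              // scale = sqrt(variance + eps)
    double y = (x - mean) * (gamma / scale) + beta;   // decomposed form
    double y_ref = gamma * (x - mean) / scale + beta; // textbook BatchNorm formula
    std::printf("%.4f %.4f\n", y, y_ref);             // both print ~2.0000
    return 0;
}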
@@ -10,13 +10,13 @@

#include <ngraph/ngraph.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGELU, "ConvertGELU", 0);

ngraph::pass::ConvertGELU::ConvertGELU() {
MATCHER_SCOPE(ConvertGELU);
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto gelu = std::make_shared<ngraph::opset2::Gelu>(input);
auto gelu = pattern::wrap_type<ngraph::opset2::Gelu>();

ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto gelu = std::dynamic_pointer_cast<ngraph::opset2::Gelu>(m.get_match_root());
@@ -11,15 +11,15 @@
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

using namespace ngraph;

NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShuffleChannels3, "ConvertShuffleChannels3", 0);

ngraph::pass::ConvertShuffleChannels3::ConvertShuffleChannels3() {
MATCHER_SCOPE(ConvertShuffleChannels3);
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto shuffle_channels = std::make_shared<::opset3::ShuffleChannels>(input);
auto shuffle_channels = pattern::wrap_type<opset3::ShuffleChannels>();

ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root());
@@ -24,8 +24,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToGRUSequence, "Conver

ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
MATCHER_SCOPE(ConvertTensorIteratorToLSTMSequence);
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();

ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
if (!ti || transformation_callback(ti))
@@ -201,8 +201,8 @@ ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSe

ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
MATCHER_SCOPE(ConvertTensorIteratorToRNNSequence);
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();

ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
if (!ti || transformation_callback(ti))
@@ -357,8 +357,8 @@ ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequ

ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
MATCHER_SCOPE(ConvertTensorIteratorToGRUSequence);
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();

ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
if (!ti || transformation_callback(ti))
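The old TensorIterator patterns spelled "any node of class T" as a typed Label with a has_class<T>() predicate; wrap_type<T>() states the same constraint directly and drops the dummy element type and shape. Equivalence sketch (header locations assumed):

#include <memory>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/op/label.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

using namespace ngraph;

void sketch() {
    // Old: untyped Label restricted by a class predicate.
    auto old_ti = std::make_shared<pattern::op::Label>(
        element::f32, Shape{}, pattern::has_class<opset5::TensorIterator>());
    // New: the type constraint lives in the pattern node itself.
    auto new_ti = pattern::wrap_type<opset5::TensorIterator>();
}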
@@ -18,10 +18,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::LSTMCellDecomposition, "LSTMCellDecompositi

ngraph::pass::LSTMCellDecomposition::LSTMCellDecomposition() {
MATCHER_SCOPE(LSTMCellDecomposition);
auto is_supported_lstm_cell = [](const std::shared_ptr<Node>& n) {
return pattern::has_class<ngraph::opset1::LSTMCell>()(n) || pattern::has_class<ngraph::opset4::LSTMCell>()(n);
};
auto any_lstm = std::make_shared<pattern::op::Label>(element::f32, Shape{}, is_supported_lstm_cell);
auto any_lstm = pattern::wrap_type<opset1::LSTMCell, opset4::LSTMCell>();

ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(m.get_match_root());
if (!lstm_cell || transformation_callback(lstm_cell)) {
33 changes: 23 additions & 10 deletions ngraph/core/include/ngraph/pattern/op/wrap_type.hpp
@@ -36,7 +36,17 @@ namespace ngraph
[](const Output<Node>& output) { return true; },
const OutputVector& input_values = {})
: Pattern(input_values, pred)
, m_wrapped_type(wrapped_type)
, m_wrapped_types({wrapped_type})
{
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}

explicit WrapType(std::vector<NodeTypeInfo> wrapped_types,
const ValuePredicate& pred =
[](const Output<Node>& output) { return true; },
const OutputVector& input_values = {})
: Pattern(input_values, pred)
, m_wrapped_types(std::move(wrapped_types))
{
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}
@@ -45,30 +55,33 @@
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;

NodeTypeInfo get_wrapped_type() const { return m_wrapped_type; }
NodeTypeInfo get_wrapped_type() const;

const std::vector<NodeTypeInfo>& get_wrapped_types() const;

private:
NodeTypeInfo m_wrapped_type;
std::vector<NodeTypeInfo> m_wrapped_types;
};
}

template <class T>
template <class... Args>
std::shared_ptr<Node> wrap_type(const OutputVector& inputs,
const pattern::op::ValuePredicate& pred)
{
static_assert(std::is_base_of<Node, T>::value, "Unexpected template type");
return std::make_shared<op::WrapType>(T::type_info, pred, inputs);
std::vector<DiscreteTypeInfo> info{Args::type_info...};
return std::make_shared<op::WrapType>(info, pred, inputs);
}

template <class T>
template <class... Args>
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {})
{
return wrap_type<T>(inputs, [](const Output<Node>& output) { return true; });
return wrap_type<Args...>(inputs, [](const Output<Node>& output) { return true; });
}

template <class T>
template <class... Args>
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred)
{
return wrap_type<T>({}, pred);
return wrap_type<Args...>({}, pred);
}
}
}
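All three wrap_type overloads are now parameterized on a type pack, each forwarding Args::type_info into the multi-type WrapType constructor. A usage sketch covering each overload (operation choices are arbitrary):

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

using namespace ngraph;

void sketch() {
    // Predicate-only overload: a Relu or Sigmoid with exactly one consumer.
    auto act = pattern::wrap_type<opset1::Relu, opset1::Sigmoid>(
        pattern::consumers_count(1));
    // Inputs overload: an Add fed by the activation and anything else.
    auto add = pattern::wrap_type<opset1::Add>({act, pattern::any_input()});
    // Default overload: any Multiply, no input or predicate constraints.
    auto mul = pattern::wrap_type<opset1::Multiply>();
}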
11 changes: 8 additions & 3 deletions ngraph/core/src/pass/graph_rewrite.cpp
@@ -109,20 +109,25 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
// its type
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
// and default algorithm is used.
NodeTypeInfo root_type_info = root->get_type_info();
if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root))
{
if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p))
{
root_type_info = any_type->get_wrapped_type();
for (const auto& root_type_info : any_type->get_wrapped_types())
{
type_to_matcher[root_type_info].push_back(matcher_index);
}
}
else
{
all_roots_has_type = false;
break;
}
}
type_to_matcher[root_type_info].push_back(matcher_index);
else
{
type_to_matcher[root->get_type_info()].push_back(matcher_index);
}

// TODO: traverse parents for root_type_info in order to register complete list of matchers
// including ones triggered by parent type info.
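The registration loop above means a WrapType over N types files the same matcher index under N keys, so GraphRewrite's root-type dispatch keeps working for multi-type patterns. A toy standalone sketch of that idea (string keys stand in for NodeTypeInfo):

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    std::unordered_map<std::string, std::vector<size_t>> type_to_matcher;
    size_t matcher_index = 7; // position of the pass in the rewrite list
    // A pattern wrapping two types registers the one matcher twice:
    for (const char* type : {"BatchNormInference.v1", "BatchNormInference.v5"}) {
        type_to_matcher[type].push_back(matcher_index);
    }
    // A node of either type now finds the same matcher on lookup.
    std::printf("%zu\n", type_to_matcher["BatchNormInference.v5"].front()); // 7
    return 0;
}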
21 changes: 20 additions & 1 deletion ngraph/core/src/pattern/op/wrap_type.cpp
@@ -31,7 +31,12 @@ bool pattern::op::WrapType::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
if (graph_value.get_node_shared_ptr()->get_type_info().is_castable(get_wrapped_type()) &&
if (std::any_of(m_wrapped_types.begin(),
m_wrapped_types.end(),
[&](const NodeTypeInfo& type_info) {
return graph_value.get_node_shared_ptr()->get_type_info().is_castable(
type_info);
}) &&
m_predicate(graph_value))
{
auto& pattern_map = matcher->get_pattern_value_map();
@@ -44,3 +49,17 @@ bool pattern::op::WrapType::match_value(Matcher* matcher,
}
return false;
}

NodeTypeInfo pattern::op::WrapType::get_wrapped_type() const
{
if (m_wrapped_types.size() > 1)
{
throw ngraph::ngraph_error("get_wrapped_type() called on WrapType with more than one type");
}
return m_wrapped_types.at(0);
}

const std::vector<NodeTypeInfo>& pattern::op::WrapType::get_wrapped_types() const
{
return m_wrapped_types;
}
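get_wrapped_type() keeps its old single-type contract and now throws if the wrapper holds more than one type; get_wrapped_types() is the safe accessor for the general case. A sketch (the dynamic cast is assumed to succeed, since wrap_type always builds a WrapType):

#include <iostream>
#include <memory>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

using namespace ngraph;

int main() {
    auto node = pattern::wrap_type<opset1::Relu, opset1::Sigmoid>();
    auto wrap = std::dynamic_pointer_cast<pattern::op::WrapType>(node);
    std::cout << wrap->get_wrapped_types().size() << "\n"; // 2: safe for any arity
    try {
        wrap->get_wrapped_type(); // single-type accessor: throws for multi-type
    } catch (const ngraph_error& e) {
        std::cout << "expected: " << e.what() << "\n";
    }
    return 0;
}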