diff --git a/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp index 7845d835cd54b9..531bacc8ded626 100644 --- a/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp +++ b/inference-engine/src/transformations/include/transformations/op_conversions/batch_norm_decomposition.hpp @@ -19,7 +19,6 @@ namespace ngraph { namespace pass { class TRANSFORMATIONS_API BatchNormDecomposition; -class TRANSFORMATIONS_API BatchNormV5Decomposition; } // namespace pass } // namespace ngraph @@ -29,9 +28,3 @@ class ngraph::pass::BatchNormDecomposition: public ngraph::pass::MatcherPass { NGRAPH_RTTI_DECLARATION; BatchNormDecomposition(); }; - -class ngraph::pass::BatchNormV5Decomposition: public ngraph::pass::MatcherPass { -public: - NGRAPH_RTTI_DECLARATION; - BatchNormV5Decomposition(); -}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index d1a25a7f00fb27..26c084f8c16002 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -83,7 +83,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); common_fusions->add_matcher(); common_fusions->add_matcher(); - //common_fusions->add_matcher(); common_fusions->add_matcher(); common_fusions->add_matcher(); common_fusions->add_matcher(); @@ -115,7 +114,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher(); decomp->add_matcher(); decomp->add_matcher(); - decomp->add_matcher(); decomp->set_name("ngraph::pass::CommonDecompositions"); // CF is required after all decompositions diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/conv_bias_fusion.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/conv_bias_fusion.cpp index e4225bdcb15367..f6f05dc08b1bca 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/conv_bias_fusion.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/conv_bias_fusion.cpp @@ -164,7 +164,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvAddFusion, "ConvAddFusion", 0); ngraph::pass::ConvAddFusion::ConvAddFusion() { MATCHER_SCOPE(ConvAddFusion); auto conv = ngraph::pattern::wrap_type(pattern::consumers_count(1)); - auto add = ngraph::pattern::wrap_type({conv, std::make_shared()}); + auto add = ngraph::pattern::wrap_type({conv, pattern::any_input()}); matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) { return conv_callback(m); @@ -179,7 +179,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvMultiplyFusion, "ConvMultiplyFusion", 0 ngraph::pass::ConvMultiplyFusion::ConvMultiplyFusion() { MATCHER_SCOPE(ConvMultiplyFusion); auto conv = ngraph::pattern::wrap_type(pattern::consumers_count(1)); - auto add = ngraph::pattern::wrap_type({conv, std::make_shared()}); + auto add = ngraph::pattern::wrap_type({conv, pattern::any_input()}); matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) { return conv_callback(m); @@ -194,7 +194,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::DeconvAddFusion, "DeconvAddFusion", 0); ngraph::pass::DeconvAddFusion::DeconvAddFusion() { MATCHER_SCOPE(DeconvAddFusion); auto conv = ngraph::pattern::wrap_type(pattern::consumers_count(1)); - auto add = ngraph::pattern::wrap_type({conv, std::make_shared()}); + auto add = ngraph::pattern::wrap_type({conv, pattern::any_input()}); matcher_pass_callback callback = [](ngraph::pattern::Matcher &m){ return conv_callback(m); diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp index 8907487b8b66a9..a7efe7960b948b 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/batch_norm_decomposition.cpp @@ -19,7 +19,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposi ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() { MATCHER_SCOPE(BatchNormDecomposition); - auto bn = pattern::wrap_type({ + auto bn = pattern::wrap_type({ pattern::any_input(pattern::has_static_rank()), pattern::any_input(pattern::has_static_shape()), pattern::any_input(pattern::has_static_shape()), @@ -28,20 +28,30 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() { }); ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) { - auto m_bn = dynamic_pointer_cast(m.get_match_root()); - if (!m_bn) { + auto m_bn = m.get_match_root(); + Output m_input, m_gamma, m_beta, m_mean, m_var; + double eps; + if (auto m_bn_v1 = dynamic_pointer_cast(m_bn)) { + m_gamma = m_bn_v1->input_value(0); + m_beta = m_bn_v1->input_value(1); + m_input = m_bn_v1->input_value(2); + m_mean = m_bn_v1->input_value(3); + m_var = m_bn_v1->input_value(4); + eps = m_bn_v1->get_eps_value(); + } else if (auto m_bn_v5 = dynamic_pointer_cast(m_bn)) { + m_input = m_bn_v5->input_value(0); + m_gamma = m_bn_v5->input_value(1); + m_beta = m_bn_v5->input_value(2); + m_mean = m_bn_v5->input_value(3); + m_var = m_bn_v5->input_value(4); + eps = m_bn_v5->get_eps_value(); + } else { return false; } - auto m_gamma = m_bn->input_value(0); - auto m_beta = m_bn->input_value(1); - auto m_input = m_bn->input_value(2); - auto m_mean = m_bn->input_value(3); - auto m_var = m_bn->input_value(4); - const auto& input_type = m_input.get_element_type(); // scale_add = variance + eps - auto scale_add = make_shared(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()})); + auto scale_add = make_shared(m_var, opset5::Constant::create(input_type, Shape{}, {eps})); // scale = sqrt(variance + eps) auto scale = make_shared(scale_add); // Divide `gamma` by `sqrt(variance + eps)` @@ -79,67 +89,3 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() { this->register_matcher(m, callback); } -NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5); - -// TODO: this pass will be unified with BatchNormDecomposition pass -ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() { - MATCHER_SCOPE(BatchNormV5Decomposition); - auto bn = pattern::wrap_type({ - pattern::any_input(pattern::has_static_rank()), - pattern::any_input(pattern::has_static_shape()), - pattern::any_input(pattern::has_static_shape()), - pattern::any_input(pattern::has_static_shape()), - pattern::any_input(pattern::has_static_shape()) - }); - - ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) { - auto m_bn = dynamic_pointer_cast(m.get_match_root()); - if (!m_bn) { - return false; - } - - auto m_input = m_bn->input_value(0); - auto m_gamma = m_bn->input_value(1); - auto m_beta = m_bn->input_value(2); - auto m_mean = m_bn->input_value(3); - auto m_var = m_bn->input_value(4); - - const auto& input_type = m_input.get_element_type(); - // scale_add = variance + eps - auto scale_add = make_shared(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()})); - // scale = sqrt(variance + eps) - auto scale = make_shared(scale_add); - // Divide `gamma` by `sqrt(variance + eps)` - auto gamma_div_scale = std::make_shared(m_gamma, scale); - - int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2; - - // TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf - Shape input_aligned_shape = m_gamma.get_shape(); - for (int64_t i = 0; i < dims_to_add; ++i) - input_aligned_shape.push_back(1); - auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape); - - auto gamma_div_scale_aligned = make_shared(gamma_div_scale, new_shape, true); - auto beta_aligned = make_shared(m_beta, new_shape, true); - auto mean_aligned = make_shared(m_mean, new_shape, true); - - // input_sub_mean = input - mean - auto input_sub_mean = register_new_node(m_input, mean_aligned); - // Multiply `input - mean` and `gamma / sqrt(variance + eps)` - auto mul = std::make_shared(input_sub_mean, gamma_div_scale_aligned); - // Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta` - auto add = std::make_shared(mul, beta_aligned); - - add->set_friendly_name(m_bn->get_friendly_name()); - - copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned, - beta_aligned, input_sub_mean, mul, add}); - - replace_node(m_bn, add); - - return true; - }; - auto m = std::make_shared(bn, matcher_name); - this->register_matcher(m, callback); -} diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/convert_gelu.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/convert_gelu.cpp index 2ee04f4f654c36..a3c683d1537a9d 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/convert_gelu.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/convert_gelu.cpp @@ -10,13 +10,13 @@ #include #include +#include NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGELU, "ConvertGELU", 0); ngraph::pass::ConvertGELU::ConvertGELU() { MATCHER_SCOPE(ConvertGELU); - auto input = std::make_shared(element::f32, Shape{}); - auto gelu = std::make_shared(input); + auto gelu = pattern::wrap_type(); ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { auto gelu = std::dynamic_pointer_cast(m.get_match_root()); diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/convert_shuffle_channels3.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/convert_shuffle_channels3.cpp index 4b4d9f901fd661..ec80a640a5dae1 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/convert_shuffle_channels3.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/convert_shuffle_channels3.cpp @@ -11,6 +11,7 @@ #include #include #include +#include using namespace ngraph; @@ -18,8 +19,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShuffleChannels3, "ConvertShuffleCha ngraph::pass::ConvertShuffleChannels3::ConvertShuffleChannels3() { MATCHER_SCOPE(ConvertShuffleChannels3); - auto input = std::make_shared(element::f32, Shape{1, 1, 1, 1}); - auto shuffle_channels = std::make_shared<::opset3::ShuffleChannels>(input); + auto shuffle_channels = pattern::wrap_type(); ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) { auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root()); diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp index 4a1d8b8b28e4bb..2a9c72b59cb990 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp @@ -24,8 +24,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToGRUSequence, "Conver ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() { MATCHER_SCOPE(ConvertTensorIteratorToLSTMSequence); - auto tensor_iterator = std::make_shared(ngraph::element::f32, - ngraph::Shape{}, ngraph::pattern::has_class()); + auto tensor_iterator = pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) { auto ti = std::dynamic_pointer_cast(m.get_match_root()); if (!ti || transformation_callback(ti)) @@ -201,8 +201,8 @@ ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSe ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() { MATCHER_SCOPE(ConvertTensorIteratorToRNNSequence); - auto tensor_iterator = std::make_shared(ngraph::element::f32, - ngraph::Shape{}, ngraph::pattern::has_class()); + auto tensor_iterator = pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) { auto ti = std::dynamic_pointer_cast(m.get_match_root()); if (!ti || transformation_callback(ti)) @@ -357,8 +357,8 @@ ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequ ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() { MATCHER_SCOPE(ConvertTensorIteratorToGRUSequence); - auto tensor_iterator = std::make_shared(ngraph::element::f32, - ngraph::Shape{}, ngraph::pattern::has_class()); + auto tensor_iterator = pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) { auto ti = std::dynamic_pointer_cast(m.get_match_root()); if (!ti || transformation_callback(ti)) diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/lstm_cell_decomposition.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/lstm_cell_decomposition.cpp index 1f30d6662cf3a9..28e7d2c429d333 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/lstm_cell_decomposition.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/lstm_cell_decomposition.cpp @@ -18,10 +18,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::LSTMCellDecomposition, "LSTMCellDecompositi ngraph::pass::LSTMCellDecomposition::LSTMCellDecomposition() { MATCHER_SCOPE(LSTMCellDecomposition); - auto is_supported_lstm_cell = [](const std::shared_ptr& n) { - return pattern::has_class()(n) || pattern::has_class()(n); - }; - auto any_lstm = std::make_shared(element::f32, Shape{}, is_supported_lstm_cell); + auto any_lstm = pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) { auto lstm_cell = std::dynamic_pointer_cast(m.get_match_root()); if (!lstm_cell || transformation_callback(lstm_cell)) { diff --git a/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp b/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp index 3c95c7e4300ce3..77307b7ee670c4 100644 --- a/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp +++ b/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp @@ -36,7 +36,17 @@ namespace ngraph [](const Output& output) { return true; }, const OutputVector& input_values = {}) : Pattern(input_values, pred) - , m_wrapped_type(wrapped_type) + , m_wrapped_types({wrapped_type}) + { + set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); + } + + explicit WrapType(std::vector wrapped_types, + const ValuePredicate& pred = + [](const Output& output) { return true; }, + const OutputVector& input_values = {}) + : Pattern(input_values, pred) + , m_wrapped_types(std::move(wrapped_types)) { set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); } @@ -45,30 +55,33 @@ namespace ngraph const Output& pattern_value, const Output& graph_value) override; - NodeTypeInfo get_wrapped_type() const { return m_wrapped_type; } + NodeTypeInfo get_wrapped_type() const; + + const std::vector& get_wrapped_types() const; + private: - NodeTypeInfo m_wrapped_type; + std::vector m_wrapped_types; }; } - template + template std::shared_ptr wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) { - static_assert(std::is_base_of::value, "Unexpected template type"); - return std::make_shared(T::type_info, pred, inputs); + std::vector info{Args::type_info...}; + return std::make_shared(info, pred, inputs); } - template + template std::shared_ptr wrap_type(const OutputVector& inputs = {}) { - return wrap_type(inputs, [](const Output& output) { return true; }); + return wrap_type(inputs, [](const Output& output) { return true; }); } - template + template std::shared_ptr wrap_type(const pattern::op::ValuePredicate& pred) { - return wrap_type({}, pred); + return wrap_type({}, pred); } } } diff --git a/ngraph/core/src/pass/graph_rewrite.cpp b/ngraph/core/src/pass/graph_rewrite.cpp index d1980024e47060..85a9189241b33f 100644 --- a/ngraph/core/src/pass/graph_rewrite.cpp +++ b/ngraph/core/src/pass/graph_rewrite.cpp @@ -109,12 +109,14 @@ bool pass::GraphRewrite::run_on_function(shared_ptr f) // it's type // and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown // and default algorithm is used. - NodeTypeInfo root_type_info = root->get_type_info(); if (auto p = dynamic_pointer_cast(root)) { if (auto any_type = dynamic_pointer_cast(p)) { - root_type_info = any_type->get_wrapped_type(); + for (const auto& root_type_info : any_type->get_wrapped_types()) + { + type_to_matcher[root_type_info].push_back(matcher_index); + } } else { @@ -122,7 +124,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr f) break; } } - type_to_matcher[root_type_info].push_back(matcher_index); + else + { + type_to_matcher[root->get_type_info()].push_back(matcher_index); + } // TODO: traverse parents for root_type_info in order to register complete list of matchers // including ones triggered by parent type info. diff --git a/ngraph/core/src/pattern/op/wrap_type.cpp b/ngraph/core/src/pattern/op/wrap_type.cpp index 74ca1b61bdc5e1..b76403c032950f 100644 --- a/ngraph/core/src/pattern/op/wrap_type.cpp +++ b/ngraph/core/src/pattern/op/wrap_type.cpp @@ -31,7 +31,12 @@ bool pattern::op::WrapType::match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) { - if (graph_value.get_node_shared_ptr()->get_type_info().is_castable(get_wrapped_type()) && + if (std::any_of(m_wrapped_types.begin(), + m_wrapped_types.end(), + [&](const NodeTypeInfo& type_info) { + return graph_value.get_node_shared_ptr()->get_type_info().is_castable( + type_info); + }) && m_predicate(graph_value)) { auto& pattern_map = matcher->get_pattern_value_map(); @@ -44,3 +49,17 @@ bool pattern::op::WrapType::match_value(Matcher* matcher, } return false; } + +NodeTypeInfo pattern::op::WrapType::get_wrapped_type() const +{ + if (m_wrapped_types.size() > 1) + { + throw ngraph::ngraph_error("get_wrapped_type() called on WrapType with more than one type"); + } + return m_wrapped_types.at(0); +} + +const std::vector& pattern::op::WrapType::get_wrapped_types() const +{ + return m_wrapped_types; +} \ No newline at end of file diff --git a/ngraph/test/pattern.cpp b/ngraph/test/pattern.cpp index 5d3772069a7ea3..733078815274a8 100644 --- a/ngraph/test/pattern.cpp +++ b/ngraph/test/pattern.cpp @@ -810,7 +810,7 @@ TEST(pattern, is_contained_match) ASSERT_FALSE(n.is_contained_match()); } -TEST(pattern, wrap_type) +TEST(pattern, wrap_type_single_op) { auto a = make_shared(element::f32, Shape{1, 3, 64, 64}); auto b = make_shared(a); @@ -852,3 +852,47 @@ TEST(pattern, wrap_type) ASSERT_TRUE(matcher->match(static_pointer_cast(mul2))); } } + +TEST(pattern, wrap_type_multi_op) +{ + auto a = make_shared(element::f32, Shape{1, 3, 64, 64}); + auto b = make_shared(a); + auto c = make_shared(a); + auto mul = make_shared(a, op::Constant::create(element::f32, Shape{}, {1})); + auto add = make_shared(op::Constant::create(element::f32, Shape{}, {1}), a); + + { + auto m = pattern::wrap_type(); + auto matcher = std::make_shared(m, "MulAddMatcher"); + ASSERT_TRUE(matcher->match(mul->output(0))); + ASSERT_EQ(matcher->get_matched_nodes().size(), 1); + ASSERT_EQ(matcher->get_matched_nodes()[0], mul); + ASSERT_EQ(matcher->get_pattern_map().count(m), 1); + + ASSERT_TRUE(matcher->match(add->output(0))); + ASSERT_EQ(matcher->get_matched_nodes().size(), 1); + ASSERT_EQ(matcher->get_matched_nodes()[0], add); + ASSERT_EQ(matcher->get_pattern_map().count(m), 1); + + ASSERT_FALSE(matcher->match(static_pointer_cast(a))); + ASSERT_FALSE(matcher->match(static_pointer_cast(b))); + ASSERT_FALSE(matcher->match(static_pointer_cast(c))); + } + { + auto m = pattern::wrap_type(); + auto matcher = std::make_shared(m, "ElementwiseMatcher"); + ASSERT_TRUE(matcher->match(mul->output(0))); + ASSERT_EQ(matcher->get_matched_nodes().size(), 1); + ASSERT_EQ(matcher->get_matched_nodes()[0], mul); + ASSERT_EQ(matcher->get_pattern_map().count(m), 1); + + ASSERT_TRUE(matcher->match(add->output(0))); + ASSERT_EQ(matcher->get_matched_nodes().size(), 1); + ASSERT_EQ(matcher->get_matched_nodes()[0], add); + ASSERT_EQ(matcher->get_pattern_map().count(m), 1); + + ASSERT_FALSE(matcher->match(static_pointer_cast(a))); + ASSERT_FALSE(matcher->match(static_pointer_cast(b))); + ASSERT_FALSE(matcher->match(static_pointer_cast(c))); + } +}