NOTE(review): patch adding the ngraph MVNFusion transformation (public header,
implementation, unit test), registering it in the CommonFusions matcher set of
CommonOptimizations, and sorting the axes values in
MKLDNNMVNNode::checkAxesSuitability. Everything ahead of the first `diff --git`
line is commentary and is not part of any hunk.

The patch was received with its line breaks collapsed and with every `<...>`
span stripped (template arguments and `#include <...>` targets). Line breaks and
indentation have been restored here. In ADDED (`+`) lines the stripped spans were
reconstructed from the surrounding code and comments; treat them as assumptions
to confirm against the upstream commit - in particular the `#include <...>`
targets, `cast_vector<int64_t>`, the op type of the root `div` node (Multiply),
and the two `register_pass<...>` arguments in the test. Context lines of the two
modified files (mkldnn_mvn_node.cpp, common_optimizations.cpp) are left exactly
as received, including the hunk headings that swallowed their first context
line, so those hunks will not apply until the context is restored from the
target files.

Change relative to the received patch: in the test, `mean3_axes` is now created
as element::i32 (it was element::f32) so that all three ReduceMean axes
constants and the reference MVN axes share the same integral element type. The
`index` blob id recorded for that file therefore no longer matches its content.

diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
index 5200b1386e5839..04d7cfaefebf44 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
@@ -1352,6 +1352,7 @@ bool MKLDNNMVNNode::checkAxesSuitability(const std::shared_ptrget_input_size() == 2) {
             if (auto axesNode = dynamic_cast(mvn->get_input_node_ptr(1))) {
                 auto axesVal = axesNode->cast_vector();
+                std::sort(axesVal.begin(), axesVal.end());
                 auto& mvnShape = mvn->get_output_shape(0);
                 if (mvnShape.size() == 1) {
                     if (axesVal.size() == 1 && axesVal[0] == 0)
diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/mvn_fusion.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/mvn_fusion.hpp
new file mode 100644
index 00000000000000..3de28fe67f9c52
--- /dev/null
+++ b/inference-engine/src/transformations/include/transformations/common_optimizations/mvn_fusion.hpp
@@ -0,0 +1,33 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/ngraph.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+#include "ngraph/pattern/matcher.hpp"
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API MVNFusion;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief MVNFusion transformation replaces a group of
+ * operations: (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) with a single MVN op.
+ */
+class ngraph::pass::MVNFusion : public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    MVNFusion();
+};
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
index cba231b2aa9c56..d3a6ed7b9acb61 100644
--- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
@@ -30,6 +30,7 @@
 #include "transformations/common_optimizations/clamp_fusion.hpp"
 #include "transformations/common_optimizations/pad_fusion.hpp"
 #include "transformations/common_optimizations/eliminate_unsqueeze_gather.hpp"
+#include "transformations/common_optimizations/mvn_fusion.hpp"
 #include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
 #include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
 #include "transformations/op_conversions/convert_divide.hpp"
@@ -96,6 +97,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher();
     common_fusions->add_matcher();
     common_fusions->add_matcher();
+    common_fusions->add_matcher<ngraph::pass::MVNFusion>();
     common_fusions->set_name("ngraph::pass::CommonFusions");
     manager.register_pass();
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/mvn_fusion.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/mvn_fusion.cpp
new file mode 100644
index 00000000000000..9990232bfdce03
--- /dev/null
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/mvn_fusion.cpp
@@ -0,0 +1,131 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "itt.hpp"
+#include "transformations/common_optimizations/mvn_fusion.hpp"
+#include "transformations/utils/utils.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset6.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusion, "MVNFusion", 0);
+
+ngraph::pass::MVNFusion::MVNFusion() {
+    MATCHER_SCOPE(MVNFusion);
+    auto single_consumer = pattern::consumers_count(1);
+
+    // Detect MVN decomposition pattern:
+    // (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
+    auto x = pattern::any_input();
+
+    // (x - ReduceMean(x, axes))
+    //     `------mean1-------'
+    auto mean1_axes = pattern::wrap_type<opset6::Constant>();
+    auto mean1 = pattern::wrap_type<opset6::ReduceMean>({ x, mean1_axes }, single_consumer);
+
+    // (x - ReduceMean(x, axes))
+    // `-sub1------------------'
+    auto sub1 = pattern::wrap_type<opset6::Subtract>({ x, mean1 }, single_consumer);
+
+    // Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
+    //                      `---mean2----------'
+    auto mean2_axes = pattern::wrap_type<opset6::Constant>();
+    auto mean2 = pattern::wrap_type<opset6::ReduceMean>({ x, mean2_axes }, single_consumer);
+
+    // Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
+    //                 `-sub2------------------'
+    auto sub2 = pattern::wrap_type<opset6::Subtract>({ x, mean2 }, single_consumer);
+
+    // Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
+    //                 `---------------------power--'
+    auto const_2 = pattern::wrap_type<opset6::Constant>();
+    auto power = pattern::wrap_type<opset6::Power>({ sub2, const_2 }, single_consumer);
+
+    // Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
+    //      `---mean3--------------------------------'
+    auto mean3_axes = pattern::wrap_type<opset6::Constant>();
+    auto mean3 = pattern::wrap_type<opset6::ReduceMean>({ power, mean3_axes }, single_consumer);
+
+    // Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
+    // `--Power--------------------------------------'
+    auto const_0_5 = pattern::wrap_type<opset6::Constant>();
+    auto power_sqrt = pattern::wrap_type<opset6::Power>({ mean3, const_0_5 }, single_consumer);
+    // TODO: use Or to accept opset6::Sqrt operation also.
+
+    // (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
+    // `-----------------------------------------------Add---'
+    auto eps = pattern::wrap_type<opset6::Constant>();
+    auto add_eps = pattern::wrap_type<opset6::Add>({ power_sqrt, eps }, single_consumer);
+
+    // Final Divide, matched as sub1 combined with (Sqrt(...) + eps) ^ -1
+    auto const_neg_1 = pattern::wrap_type<opset6::Constant>();
+    auto power_div = pattern::wrap_type<opset6::Power>({ add_eps, const_neg_1 }, single_consumer);
+    auto div = pattern::wrap_type<opset6::Multiply>({ sub1, power_div });
+
+    // TODO: use Or to accept the opset6::Divide operation. Also, once the root operation can have
+    // multiple types, we must handle it in the GraphRewrite engine to perform efficient matching.
+    ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
+        auto& pattern_to_output = m.get_pattern_value_map();
+        auto exp_input = pattern_to_output.at(x);
+
+        auto const_2_node = std::dynamic_pointer_cast<opset6::Constant>(pattern_to_output.at(const_2).get_node_shared_ptr());
+        auto const_0_5_node = std::dynamic_pointer_cast<opset6::Constant>(pattern_to_output.at(const_0_5).get_node_shared_ptr());
+        auto const_neg_1_node = std::dynamic_pointer_cast<opset6::Constant>(pattern_to_output.at(const_neg_1).get_node_shared_ptr());
+
+        if (!const_2_node || !const_0_5_node || !const_neg_1_node) {
+            return false;
+        }
+
+        auto const_eps_node = std::dynamic_pointer_cast<opset6::Constant>(pattern_to_output.at(eps).get_node_shared_ptr());
+        float eps_value;
+
+        bool valid_constant_values = op::util::has_constant_value(const_2_node, 2.0)
+            && op::util::has_constant_value(const_0_5_node, 0.5)
+            && op::util::has_constant_value(const_neg_1_node, -1.0)
+            && op::util::get_single_value(const_eps_node, eps_value);
+
+        if (!valid_constant_values) {
+            return false;
+        }
+
+        auto axes_1_node = std::dynamic_pointer_cast<opset6::Constant>(pattern_to_output.at(mean1_axes).get_node_shared_ptr());
+        auto axes_2_node = std::dynamic_pointer_cast<opset6::Constant>(pattern_to_output.at(mean2_axes).get_node_shared_ptr());
+        auto axes_3_node = std::dynamic_pointer_cast<opset6::Constant>(pattern_to_output.at(mean3_axes).get_node_shared_ptr());
+
+        if (!axes_1_node || !axes_2_node || !axes_3_node) {
+            return false;
+        }
+
+        auto axes_1_value = axes_1_node->cast_vector<int64_t>();
+        auto axes_2_value = axes_2_node->cast_vector<int64_t>();
+        auto axes_3_value = axes_3_node->cast_vector<int64_t>();
+
+        if (!(axes_1_value == axes_2_value && axes_2_value == axes_3_value)) {
+            return false;
+        }
+
+        auto mvn = std::make_shared<opset6::MVN>(exp_input, axes_1_node, true, eps_value, op::MVNEpsMode::OUTSIDE_SQRT);
+
+        mvn->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ngraph::copy_runtime_info({ pattern_to_output.at(mean1).get_node_shared_ptr(),
+                                    pattern_to_output.at(sub1).get_node_shared_ptr(),
+                                    pattern_to_output.at(mean2).get_node_shared_ptr(),
+                                    pattern_to_output.at(sub2).get_node_shared_ptr(),
+                                    pattern_to_output.at(power).get_node_shared_ptr(),
+                                    pattern_to_output.at(mean3).get_node_shared_ptr(),
+                                    pattern_to_output.at(power_sqrt).get_node_shared_ptr(),
+                                    pattern_to_output.at(add_eps).get_node_shared_ptr(),
+                                    pattern_to_output.at(power_div).get_node_shared_ptr(),
+                                    pattern_to_output.at(div).get_node_shared_ptr() }, mvn);
+        ngraph::replace_node(m.get_match_root(), mvn);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
+    register_matcher(m, matcher_pass_callback);
+}
diff --git a/inference-engine/tests/functional/inference_engine/transformations/mvn_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/mvn_fusion_test.cpp
new file mode 100644
index 00000000000000..8129c1c0aa440e
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/transformations/mvn_fusion_test.cpp
@@ -0,0 +1,62 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset6.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/common_optimizations/mvn_fusion.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, MVNFusionTest) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
+        auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
+        auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
+        auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
+        auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
+        auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean2_axes);
+        auto sub2 = std::make_shared<ngraph::opset6::Subtract>(input, mean2);
+        auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
+        auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub2, const_2);
+        auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
+        auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
+        auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
+        auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
+        auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
+        auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
+        auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
+        auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
+        auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::MVNFusion>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
+        auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
+        auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}