Skip to content

Commit

Permalink
Add MVN fusion transformation (#4262)
Browse files Browse the repository at this point in the history
* Add MVN fusion transformation

* Apply suggestions from code review

Co-authored-by: Gleb Kazantaev <[email protected]>

* Apply review feedback

* Fix build

* Apply review feedback

Co-authored-by: Gleb Kazantaev <[email protected]>
  • Loading branch information
mvafin and GlebKazantaev authored Feb 18, 2021
1 parent ec9b589 commit 1bfe79c
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,7 @@ bool MKLDNNMVNNode::checkAxesSuitability(const std::shared_ptr<const ngraph::Nod
if (mvn != nullptr && node->get_input_size() == 2) {
if (auto axesNode = dynamic_cast<ngraph::op::v0::Constant*>(mvn->get_input_node_ptr(1))) {
auto axesVal = axesNode->cast_vector<int>();
std::sort(axesVal.begin(), axesVal.end());
auto& mvnShape = mvn->get_output_shape(0);
if (mvnShape.size() == 1) {
if (axesVal.size() == 1 && axesVal[0] == 0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MVNFusion;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief MVNFusion transformation replaces group of
* operations: (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) to MVN op.
*/
class ngraph::pass::MVNFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
MVNFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "transformations/common_optimizations/clamp_fusion.hpp"
#include "transformations/common_optimizations/pad_fusion.hpp"
#include "transformations/common_optimizations/eliminate_unsqueeze_gather.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
Expand Down Expand Up @@ -96,6 +97,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
common_fusions->add_matcher<ngraph::pass::NormalizeL2Fusion>();
common_fusions->add_matcher<ngraph::pass::ClampFusion>();
common_fusions->add_matcher<ngraph::pass::PadFusion>();
common_fusions->add_matcher<ngraph::pass::MVNFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");

manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "itt.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
#include "transformations/utils/utils.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset6.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusion, "MVNFusion", 0);

ngraph::pass::MVNFusion::MVNFusion() {
MATCHER_SCOPE(MVNFusion);
auto single_consumer = pattern::consumers_count(1);

// Detect MVN decomposition pattern:
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
auto x = pattern::any_input();

// (x - ReduceMean(x, axes))
// `------mean1-------'
auto mean1_axes = pattern::wrap_type<opset6::Constant>();
auto mean1 = pattern::wrap_type<opset6::ReduceMean>({ x, mean1_axes }, single_consumer);

// (x - ReduceMean(x, axes))
// `-sub1------------------'
auto sub1 = pattern::wrap_type<opset6::Subtract>({ x, mean1 }, single_consumer);

// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `---mean2----------'
auto mean2_axes = pattern::wrap_type<opset6::Constant>();
auto mean2 = pattern::wrap_type<opset6::ReduceMean>({ x, mean2_axes }, single_consumer);

// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `-sub2------------------'
auto sub2 = pattern::wrap_type<opset6::Subtract>({ x, mean2 }, single_consumer);

// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `---------------------power--'
auto const_2 = pattern::wrap_type<opset6::Constant>();
auto power = pattern::wrap_type<opset6::Power>({ sub2, const_2 }, single_consumer);

// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `---mean3--------------------------------'
auto mean3_axes = pattern::wrap_type<opset6::Constant>();
auto mean3 = pattern::wrap_type<opset6::ReduceMean>({ power, mean3_axes }, single_consumer);

// Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2))
// `--Power--------------------------------------'
auto const_0_5 = pattern::wrap_type<ngraph::opset6::Constant>();
auto power_sqrt = pattern::wrap_type<opset6::Power>({ mean3, const_0_5 }, single_consumer);
// TODO: use Or to accept opset6::Sqrt operation also.

// (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
// `-----------------------------------------------Add---'
auto eps = pattern::wrap_type<opset6::Constant>();
auto add_eps = pattern::wrap_type<opset6::Add>({ power_sqrt, eps }, single_consumer);

// Final Divide
auto const_neg_1 = pattern::wrap_type<opset6::Constant>();
auto power_div = pattern::wrap_type<opset6::Power>({ add_eps, const_neg_1 }, single_consumer);
auto div = pattern::wrap_type<opset6::Multiply>({ sub1, power_div });

// TODO: use Or to accept opset6::Divide operation. Also as root operation has multiple types
// we must handle it in GraphRewrite engine to perform efficient matching.
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(x);

auto const_2_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(const_2).get_node_shared_ptr());
auto const_0_5_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(const_0_5).get_node_shared_ptr());
auto const_neg_1_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(const_neg_1).get_node_shared_ptr());

if (!const_2_node || !const_0_5_node || !const_neg_1_node) {
return false;
}

auto const_eps_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(eps).get_node_shared_ptr());
float eps_value;

bool valid_constant_values = op::util::has_constant_value<float>(const_2_node, 2.0)
&& op::util::has_constant_value<float>(const_0_5_node, 0.5)
&& op::util::has_constant_value<float>(const_neg_1_node, -1.0)
&& op::util::get_single_value(const_eps_node, eps_value);

if (!valid_constant_values) {
return false;
}

auto axes_1_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean1_axes).get_node_shared_ptr());
auto axes_2_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean2_axes).get_node_shared_ptr());
auto axes_3_node = std::dynamic_pointer_cast<ngraph::opset6::Constant>(pattern_to_output.at(mean3_axes).get_node_shared_ptr());

if (!axes_1_node || !axes_2_node || !axes_3_node) {
return false;
}

auto axes_1_value = axes_1_node->cast_vector<int64_t>();
auto axes_2_value = axes_2_node->cast_vector<int64_t>();
auto axes_3_value = axes_3_node->cast_vector<int64_t>();

if (!(axes_1_value == axes_2_value && axes_2_value == axes_3_value)) {
return false;
}

auto mvn = std::make_shared<ngraph::opset6::MVN>(exp_input, axes_1_node, true, eps_value, op::MVNEpsMode::OUTSIDE_SQRT);

mvn->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(mean1).get_node_shared_ptr(),
pattern_to_output.at(sub1).get_node_shared_ptr(),
pattern_to_output.at(mean2).get_node_shared_ptr(),
pattern_to_output.at(sub2).get_node_shared_ptr(),
pattern_to_output.at(power).get_node_shared_ptr(),
pattern_to_output.at(mean3).get_node_shared_ptr(),
pattern_to_output.at(power_sqrt).get_node_shared_ptr(),
pattern_to_output.at(add_eps).get_node_shared_ptr(),
pattern_to_output.at(power_div).get_node_shared_ptr(),
pattern_to_output.at(div).get_node_shared_ptr() }, mvn);
ngraph::replace_node(m.get_match_root(), mvn);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
register_matcher(m, matcher_pass_callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/mvn_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

using namespace testing;

TEST(TransformationTests, MVNFusionTest) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto mean1_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean1 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean1_axes);
auto sub1 = std::make_shared<ngraph::opset6::Subtract>(input, mean1);
auto mean2_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean2 = std::make_shared<ngraph::opset6::ReduceMean>(input, mean2_axes);
auto sub2 = std::make_shared<ngraph::opset6::Subtract>(input, mean2);
auto const_2 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 2 });
auto power_sqr = std::make_shared<ngraph::opset6::Power>(sub2, const_2);
auto mean3_axes = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mean3 = std::make_shared<ngraph::opset6::ReduceMean>(power_sqr, mean3_axes);
auto const_0_5 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 0.5 });
auto power_sqrt = std::make_shared<ngraph::opset6::Power>(mean3, const_0_5);
auto eps = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
auto add_eps = std::make_shared<ngraph::opset6::Add>(power_sqrt, eps);
auto const_neg_1 = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
auto power_div = std::make_shared<ngraph::opset6::Power>(add_eps, const_neg_1);
auto div = std::make_shared<ngraph::opset6::Multiply>(sub1, power_div);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ div }, ngraph::ParameterVector{ input });

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MVNFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224, 224 });
auto axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 1, 2, 3 });
auto mvn = std::make_shared<ngraph::opset6::MVN>(input, axes, true, 1e-9, ngraph::op::MVNEpsMode::OUTSIDE_SQRT);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mvn }, ngraph::ParameterVector{ input });
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

0 comments on commit 1bfe79c

Please sign in to comment.