Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NormalizeL2 transformation #1892

Merged
merged 14 commits into from
Aug 28, 2020
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <utility>
#include <memory>

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API NormalizeL2Fusion;
class TRANSFORMATIONS_API NormalizeL2FusionWithMax;
class TRANSFORMATIONS_API NormalizeL2FusionWithAdd;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief NormalizeL2Fusion transformation replaces various sub-graphs with a NormalizeL2 op.
*/
class ngraph::pass::NormalizeL2Fusion: public ngraph::pass::GraphRewrite {
public:
NormalizeL2Fusion() {
add_matcher<ngraph::pass::NormalizeL2FusionWithMax>();
add_matcher<ngraph::pass::NormalizeL2FusionWithAdd>();
}
};

/**
* @ingroup ie_transformation_common_api
* @brief NormalizeL2FusionWithMax transformation replaces a sub-graph
* x/(max(sqrt(sum(x[j0, ..., jN]**2), eps)) with a NormalizeL2 op.
*/
class ngraph::pass::NormalizeL2FusionWithMax: public ngraph::pass::MatcherPass {
public:
NormalizeL2FusionWithMax();
};

/**
* @ingroup ie_transformation_common_api
* @brief NormalizeL2FusionWithAdd transformation replaces a sub-graph
* x/(add(sqrt(sum(x[j0, ..., jN]**2), eps)) with a NormalizeL2 op.
*/
class ngraph::pass::NormalizeL2FusionWithAdd: public ngraph::pass::MatcherPass {
public:
NormalizeL2FusionWithAdd();
};
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "transformations/mish_fusion.hpp"
#include "transformations/swish_fusion.hpp"
#include "transformations/hswish_fusion.hpp"
#include "transformations/normalize_l2_fusion.hpp"

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
Expand All @@ -40,6 +41,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();

manager.set_callback(m_transformation_callback);
manager.run_passes(f);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/normalize_l2_fusion.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
auto input = ngraph::pattern::any_input();

auto exp = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);

ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();

const auto data_input = pattern_to_output.at(input);
const auto exp_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(exp).get_node_shared_ptr());
const bool is_square_pow = shape_size(exp_input->get_shape()) <= 1 && exp_input->cast_vector<int64_t>()[0] == 2;
lazarevevgeny marked this conversation as resolved.
Show resolved Hide resolved
if (!is_square_pow) {
return false;
}
const auto axes_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(axes).get_node_shared_ptr());
lazarevevgeny marked this conversation as resolved.
Show resolved Hide resolved
const auto eps_attr = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(eps_const).get_node_shared_ptr());
if (shape_size(eps_attr->get_shape()) > 1) {
lazarevevgeny marked this conversation as resolved.
Show resolved Hide resolved
return false;
}
const auto eps_attr_value = eps_attr->cast_vector<float>()[0];

auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::MAX);

normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
lazarevevgeny marked this conversation as resolved.
Show resolved Hide resolved
pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
pattern_to_output.at(sqrt).get_node_shared_ptr(),
pattern_to_output.at(sqrt_max_eps).get_node_shared_ptr(),
pattern_to_output.at(divide).get_node_shared_ptr()
},
normalize_l2);
ngraph::replace_node(m.get_match_root(), normalize_l2);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "NormalizeL2FusionWithMax");
register_matcher(m, matcher_pass_callback);
}

ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
auto input = ngraph::pattern::any_input();

auto exp = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can join these two passes if check sqrt_max_eps or sqrt_add_eps types in the callback.
@GlebKazantaev what do you think about it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);

ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();

const auto data_input = pattern_to_output.at(input);
const auto exp_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(exp).get_node_shared_ptr());
const bool is_square_pow = shape_size(exp_input->get_shape()) <= 1 && exp_input->cast_vector<int64_t>()[0] == 2;
if (!is_square_pow) {
return false;
}
const auto axes_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(axes).get_node_shared_ptr());
lazarevevgeny marked this conversation as resolved.
Show resolved Hide resolved
const auto eps_attr = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(eps_const).get_node_shared_ptr());
if (shape_size(eps_attr->get_shape()) > 1) {
return false;
}
const auto eps_attr_value = eps_attr->cast_vector<float>()[0];

auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::ADD);

normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
pattern_to_output.at(sqrt).get_node_shared_ptr(),
pattern_to_output.at(sqrt_add_eps).get_node_shared_ptr(),
pattern_to_output.at(divide).get_node_shared_ptr()
},
normalize_l2);
ngraph::replace_node(m.get_match_root(), normalize_l2);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "NormalizeL2FusionWithMax");
register_matcher(m, matcher_pass_callback);
}
Loading