Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NormalizeL2 transformation #1892

Merged
merged 14 commits into from
Aug 28, 2020
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <utility>
#include <memory>

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

// Forward declarations of the NormalizeL2 fusion passes. The classes are
// defined below, outside the namespace block, using fully qualified names.
class TRANSFORMATIONS_API NormalizeL2Fusion;
class TRANSFORMATIONS_API NormalizeL2FusionWithMax;
class TRANSFORMATIONS_API NormalizeL2FusionWithAdd;

} // namespace pass
} // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief NormalizeL2Fusion transformation replaces various sub-graphs with a NormalizeL2 op.
 *
 * This is a GraphRewrite that only aggregates the concrete matcher passes:
 * NormalizeL2FusionWithMax and NormalizeL2FusionWithAdd.
 */
class ngraph::pass::NormalizeL2Fusion: public ngraph::pass::GraphRewrite {
public:
// Registers both fusion matchers (the "max" and the "add" epsilon variants).
NormalizeL2Fusion() {
add_matcher<ngraph::pass::NormalizeL2FusionWithMax>();
add_matcher<ngraph::pass::NormalizeL2FusionWithAdd>();
}
};

/**
 * @ingroup ie_transformation_common_api
 * @brief NormalizeL2FusionWithMax transformation replaces a sub-graph
 * x / max(sqrt(sum(x[j0, ..., jN]**2)), eps) with a NormalizeL2 op.
 */
class ngraph::pass::NormalizeL2FusionWithMax: public ngraph::pass::MatcherPass {
public:
// Builds the pattern and registers the matcher; defined in the .cpp file.
NormalizeL2FusionWithMax();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief NormalizeL2FusionWithAdd transformation replaces a sub-graph
 * x / add(sqrt(sum(x[j0, ..., jN]**2)), eps) with a NormalizeL2 op.
 */
class ngraph::pass::NormalizeL2FusionWithAdd: public ngraph::pass::MatcherPass {
public:
// Builds the pattern and registers the matcher; defined in the .cpp file.
NormalizeL2FusionWithAdd();
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ngraph/op/util/op_annotations.hpp>
#include <ngraph/op/constant.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>

namespace ngraph {
namespace op {
Expand Down Expand Up @@ -53,6 +54,38 @@ inline std::string create_ie_output_name(const ngraph::Output<ngraph::Node>& out
return out_name;
}

/**
 * @brief Checks whether a Constant holds exactly one element equal to `value`.
 *
 * @param constant  Constant node to inspect; a null pointer yields false.
 * @param value     Expected value, compared after casting the constant data to T.
 * @param epsilon   Allowed absolute deviation; only used when the constant's
 *                  element type is f16, f32, f64 or bf16. Other element types
 *                  are compared exactly.
 * @return true if the constant is a scalar / single-element tensor whose value
 *         matches `value`, false otherwise.
 */
template <typename T>
bool has_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant,
                        const T value,
                        T epsilon = std::numeric_limits<T>::epsilon()) {
    if (!constant) {
        return false;
    }

    // Only scalars and one-element tensors can be compared against a single value.
    const auto& shape = constant->get_shape();
    if (!is_scalar(shape) && shape_size(shape) != 1) {
        return false;
    }

    const auto& type = constant->get_element_type();
    const bool is_floating_point = type == ngraph::element::f16 ||
                                   type == ngraph::element::f32 ||
                                   type == ngraph::element::f64 ||
                                   type == ngraph::element::bf16;

    const T actual = constant->cast_vector<T>()[0];
    if (is_floating_point) {
        // Expressed as "not greater than" so that a NaN difference is still
        // treated as a match, exactly as the tolerance check did before.
        return !(std::fabs(actual - value) > epsilon);
    }
    return !(actual != value);
}

TRANSFORMATIONS_API bool get_single_value(const std::shared_ptr<op::Constant> & const_node, float & value);

TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> normalize_constant(const std::shared_ptr<op::Constant> & constant,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "transformations/softplus_fusion.hpp"
#include "transformations/swish_fusion.hpp"
#include "transformations/hswish_fusion.hpp"
#include "transformations/normalize_l2_fusion.hpp"
#include "transformations/convert_quantize_dequantize.hpp"

#include <ngraph/pass/manager.hpp>
Expand Down Expand Up @@ -47,6 +48,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();

manager.set_callback(m_transformation_callback);
manager.run_passes(f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,14 @@
//

#include "transformations/hswish_fusion.hpp"
#include "transformations/utils/utils.hpp"

#include <memory>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

// Returns true when `constant` is a non-null f32/f16 Constant that holds exactly
// one element whose value lies within `epsilon` of `value`.
bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant,
                          const float value,
                          float epsilon = std::numeric_limits<float>::epsilon()) {
    if (!constant) {
        return false;
    }

    const auto& type = constant->get_element_type();
    const bool is_supported_type = type == ngraph::element::f32 || type == ngraph::element::f16;
    if (!is_supported_type) {
        return false;
    }

    const auto values = constant->cast_vector<float>();
    // "not greater than" keeps the original outcome for a NaN difference.
    return values.size() == 1 && !(std::fabs(values[0] - value) > epsilon);
}

ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
// Replaces a sub-graph (x * (min(Relu(x + 3), 6)) / 6 with a HSwish op.
auto input = ngraph::pattern::any_input();
Expand All @@ -47,9 +31,9 @@ ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());

bool valid_constant_values = check_constant_value(add_const_value, 3.0)
&& check_constant_value(min_const_value, 6.0)
&& check_constant_value(div_const_value, 6.0);
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0)
&& op::util::has_constant_value<float>(min_const_value, 6.0)
&& op::util::has_constant_value<float>(div_const_value, 6.0);

if (!valid_constant_values) {
return false;
Expand Down Expand Up @@ -96,9 +80,9 @@ ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());

bool valid_constant_values = check_constant_value(add_const_value, 3.0)
&& check_constant_value(min_const_value, 6.0)
&& check_constant_value(mul_const_value, (1.0/6.0), 0.0001);
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0f)
&& op::util::has_constant_value<float>(min_const_value, 6.0f)
&& op::util::has_constant_value<float>(mul_const_value, (1.0f/6.0f), 0.0001f);

if (!valid_constant_values) {
return false;
Expand Down Expand Up @@ -148,10 +132,10 @@ ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());

bool valid_constant_values = check_constant_value(add_const_value, 3.0)
&& check_constant_value(max_const_value, 0.0)
&& check_constant_value(min_const_value, 6.0)
&& check_constant_value(div_const_value, 6.0);
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0f)
&& op::util::has_constant_value<float>(max_const_value, 0.0f)
&& op::util::has_constant_value<float>(min_const_value, 6.0f)
&& op::util::has_constant_value<float>(div_const_value, 6.0f);

if (!valid_constant_values) {
return false;
Expand Down Expand Up @@ -196,8 +180,8 @@ ngraph::pass::HSwishFusionWithClamp::HSwishFusionWithClamp() {
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());

bool valid_constant_values = check_constant_value(add_const_value, 3.0)
&& check_constant_value(mul_const_value, (1.0/6.0), 0.0001);
bool valid_constant_values = op::util::has_constant_value(add_const_value, 3.0)
&& op::util::has_constant_value(mul_const_value, (1.0/6.0), 0.0001);

if (!valid_constant_values) {
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/normalize_l2_fusion.hpp"
#include "transformations/utils/utils.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

// Registers a matcher that fuses the sub-graph
//     x / max(sqrt(reduce_sum(x ^ 2, axes)), eps)
// into a single opset4::NormalizeL2 with EpsMode::MAX.
ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
    auto input = ngraph::pattern::any_input();

    auto exp = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
    auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
    auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
    auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
    auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);

    ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();

        const auto data_input = pattern_to_output.at(input);
        const auto exp_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(exp).get_node_shared_ptr());
        const auto axes_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(axes).get_node_shared_ptr());
        const auto eps_attr = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(eps_const).get_node_shared_ptr());

        if (!exp_input || !axes_input || !eps_attr) {
            return false;
        }

        // The fused op computes an L2 norm, so the Power exponent has to be exactly 2.
        const bool is_square_pow = op::util::has_constant_value<float>(exp_input, 2.0f);
        if (!is_square_pow) {
            return false;
        }

        // eps becomes a scalar attribute of NormalizeL2: require exactly one element.
        // This also rejects an empty constant, for which reading element [0] would be
        // out of bounds.
        if (shape_size(eps_attr->get_shape()) != 1) {
            return false;
        }
        const auto eps_attr_value = eps_attr->cast_vector<float>()[0];

        auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::MAX);

        normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
                                   pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
                                   pattern_to_output.at(sqrt).get_node_shared_ptr(),
                                   pattern_to_output.at(sqrt_max_eps).get_node_shared_ptr(),
                                   pattern_to_output.at(divide).get_node_shared_ptr()
                                   },
                                   normalize_l2);
        ngraph::replace_node(m.get_match_root(), normalize_l2);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "NormalizeL2FusionWithMax");
    register_matcher(m, matcher_pass_callback);
}

// Registers a matcher that fuses the sub-graph
//     x / add(sqrt(reduce_sum(x ^ 2, axes)), eps)
// into a single opset4::NormalizeL2 with EpsMode::ADD.
ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
    auto input = ngraph::pattern::any_input();

    auto exp = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
    auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
    auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
    auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
    // NOTE(review): a reviewer suggested that this pass and NormalizeL2FusionWithMax
    // could be joined into one pass by checking in the callback whether the matched
    // node is a Maximum (sqrt_max_eps) or an Add (sqrt_add_eps).
    auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);

    ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();

        const auto data_input = pattern_to_output.at(input);
        const auto exp_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(exp).get_node_shared_ptr());
        const auto axes_input = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(axes).get_node_shared_ptr());
        const auto eps_attr = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(eps_const).get_node_shared_ptr());

        if (!exp_input || !axes_input || !eps_attr) {
            return false;
        }

        // The fused op computes an L2 norm, so the Power exponent has to be exactly 2.
        // Compare as float (same check as NormalizeL2FusionWithMax): an integer cast
        // would truncate e.g. 2.5 to 2 and wrongly accept it.
        const bool is_square_pow = op::util::has_constant_value<float>(exp_input, 2.0f);
        if (!is_square_pow) {
            return false;
        }

        // eps becomes a scalar attribute of NormalizeL2: require exactly one element.
        // This also rejects an empty constant, for which reading element [0] would be
        // out of bounds.
        if (shape_size(eps_attr->get_shape()) != 1) {
            return false;
        }
        // Take epsilon from the matched eps constant itself.
        const auto eps_attr_value = eps_attr->cast_vector<float>()[0];

        auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::ADD);

        normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
                                   pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
                                   pattern_to_output.at(sqrt).get_node_shared_ptr(),
                                   pattern_to_output.at(sqrt_add_eps).get_node_shared_ptr(),
                                   pattern_to_output.at(divide).get_node_shared_ptr()
                                   },
                                   normalize_l2);
        ngraph::replace_node(m.get_match_root(), normalize_l2);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "NormalizeL2FusionWithAdd");
    register_matcher(m, matcher_pass_callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,7 @@
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

// Returns true when `constant` is a non-null f32/f16 Constant that holds exactly
// one element and that element equals 1.0.
bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
    if (!constant) {
        return false;
    }

    const auto& type = constant->get_element_type();
    const bool is_supported_type = type == ngraph::element::f32 || type == ngraph::element::f16;
    if (!is_supported_type) {
        return false;
    }

    const auto values = constant->cast_vector<float>();
    return values.size() == 1 && !(values[0] != 1.0);
}
#include "transformations/utils/utils.hpp"

bool check_beta_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
// check that the constant for beta contains only one distinct element
Expand Down Expand Up @@ -124,7 +110,7 @@ ngraph::pass::SwishFusionWithBeta::SwishFusionWithBeta() {
auto exp_input = pattern_to_output.at(input);

auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
if (!check_constant_value(constant)) {
if (!op::util::has_constant_value<float>(constant, 1.0f)) {
return false;
}

Expand Down Expand Up @@ -161,7 +147,7 @@ ngraph::pass::SwishFusionWithoutBeta::SwishFusionWithoutBeta() {
auto exp_input = pattern_to_output.at(input);

auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
if (!check_constant_value(constant)) {
if (!op::util::has_constant_value<float>(constant, 1.0f)) {
return false;
}

Expand Down
Loading