Skip to content

Commit

Permalink
Fix NormalizeL2Fusion and allow LpNormalization to be fused to NormalizeL2 (#5664)
Browse files Browse the repository at this point in the history

* Fix NormalizeL2Fusion and allow LpNormalization to be fused to NormalizeL2

* apply code format

* use cast_vector<uint64_t>

* use MKLDNNNormalizeL2Node::isSupportedOperation in normalizeL2FusionCallback
  • Loading branch information
mateusztabaka authored Jul 22, 2021
1 parent aecb2b9 commit 1e1e3bf
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 74 deletions.
9 changes: 9 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/common_optimizations/softmax_fusion.hpp>
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
Expand Down Expand Up @@ -87,6 +88,7 @@

#include "nodes/mkldnn_mvn_node.h"
#include "nodes/mkldnn_fake_quantize_node.h"
#include "nodes/mkldnn_normalize_node.h"
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"

#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
Expand Down Expand Up @@ -277,6 +279,13 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
return node->input_value(0).get_partial_shape().rank().get_length() > 5;
});

// Callback shared by the NormalizeL2 fusion passes. It yields true when
// MKLDNNNormalizeL2Node::isSupportedOperation rejects the candidate node; the
// fusion passes bail out (return false) when their transformation callback
// yields true, so unsupported NormalizeL2 nodes are never created.
auto normalizeL2FusionCallback = [](const_node_ptr &node) -> bool {
    // isSupportedOperation requires an output string; its content is not used here.
    std::string unusedReason;
    const bool isSupported = MKLDNNNormalizeL2Node::isSupportedOperation(node, unusedReason);
    return !isSupported;
};
pass_config->set_callback<ngraph::pass::NormalizeL2FusionWithAdd>(normalizeL2FusionCallback);
pass_config->set_callback<ngraph::pass::NormalizeL2FusionWithMax>(normalizeL2FusionCallback);

// List of enabled/disabled transformations
pass_config->disable<ngraph::pass::ConvertGELU>();
pass_config->disable<ngraph::pass::ConvertShuffleChannels3>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ MKLDNNNormalizeL2Node::MKLDNNNormalizeL2Node(const std::shared_ptr<ngraph::Node>
}
}

bool MKLDNNNormalizeL2Node::isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept {
bool MKLDNNNormalizeL2Node::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
const auto norm = std::dynamic_pointer_cast<const ngraph::op::v0::NormalizeL2>(op);
if (!norm) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class MKLDNNNormalizeL2Node : public MKLDNNNode {
return false;
}

static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
bool canFuse(const MKLDNNNodePtr& node) const override;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
Expand All @@ -52,12 +52,14 @@ ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
const auto eps_attr_value = eps_attr->cast_vector<float>()[0];

auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::MAX);
if (transformation_callback(normalize_l2))
return false;

normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
pattern_to_output.at(sqrt).get_node_shared_ptr(),
pattern_to_output.at(sqrt_max_eps).get_node_shared_ptr(),
pattern_to_output.at(max).get_node_shared_ptr(),
pattern_to_output.at(divide).get_node_shared_ptr()
},
normalize_l2);
Expand All @@ -79,10 +81,10 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
Expand All @@ -106,12 +108,14 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
const auto eps_attr_value = op::util::has_constant_value<float>(exp_input, 2.0f);

auto normalize_l2 = std::make_shared<ngraph::opset4::NormalizeL2>(data_input, axes_input, eps_attr_value, op::EpsMode::ADD);
if (transformation_callback(normalize_l2))
return false;

normalize_l2->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(pow).get_node_shared_ptr(),
pattern_to_output.at(reduce_sum).get_node_shared_ptr(),
pattern_to_output.at(sqrt).get_node_shared_ptr(),
pattern_to_output.at(sqrt_add_eps).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(divide).get_node_shared_ptr()
},
normalize_l2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMax) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});

Expand Down Expand Up @@ -62,10 +62,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});

Expand All @@ -81,10 +81,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}
Expand All @@ -101,10 +101,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});

Expand All @@ -120,10 +120,10 @@ TEST(TransformationTests, NormalizeL2FusionWithMaxIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
auto max = std::make_shared<ngraph::opset4::Maximum>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(max);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}
Expand All @@ -141,10 +141,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAdd) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {eps_value});
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});

Expand Down Expand Up @@ -176,10 +176,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});

Expand All @@ -196,10 +196,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectExp) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {eps_value});
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}
Expand All @@ -216,10 +216,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});

Expand All @@ -235,10 +235,10 @@ TEST(TransformationTests, NormalizeL2FusionWithAddIncorrectEpsValueShape) {
auto pow = std::make_shared<ngraph::opset4::Power>(input, exp);
auto axes_const = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto reduce_sum = std::make_shared<ngraph::opset4::ReduceSum>(pow, axes_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
auto eps_const = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{2}, {1, 2});
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
auto add = std::make_shared<ngraph::opset4::Add>(reduce_sum, eps_const);
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(add);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{input});
}
Expand Down
12 changes: 9 additions & 3 deletions ngraph/core/builder/include/ngraph/builder/norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ namespace ngraph
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] keep_dims If true, the reduced axes are kept in the output; if false, they are removed.
///
/// \return L-0 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l0_norm(const Output<Node>& value,
const Output<Node>& reduction_axes);
const Output<Node>& reduction_axes,
bool keep_dims = false);

/// \brief Calculates L-1 norm of a value.
///
Expand All @@ -45,12 +47,14 @@ namespace ngraph
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
/// \param[in] keep_dims If true, the reduced axes are kept in the output; if false, they are removed.
///
/// \return L-1 norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> l1_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
float bias = 0.f);
float bias = 0.f,
bool keep_dims = false);

/// \brief Calculates L-2 norm of input tensor.
///
Expand All @@ -77,13 +81,15 @@ namespace ngraph
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] p_norm The p norm to calculate.
/// \param[in] bias The bias added to the calculated sum.
/// \param[in] keep_dims If true, the reduced axes are kept in the output; if false, they are removed.
///
/// \return L-p norm of value. The output sub-graph is composed of v1 ops.
///
std::shared_ptr<Node> lp_norm(const Output<Node>& value,
const Output<Node>& reduction_axes,
std::size_t p_norm = 2,
float bias = 0.f);
float bias = 0.f,
bool keep_dims = false);
} // namespace opset1
} // namespace builder
} // namespace ngraph
Loading

0 comments on commit 1e1e3bf

Please sign in to comment.