Skip to content

Commit

Permalink
Add LeakyReluFusion transformation (#6816)
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusztabaka authored Jul 30, 2021
1 parent 518ec79 commit c0c2f2d
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API LeakyReluFusion;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief LeakyReluFusion transformation replaces following graph:
* Multiply->Maximum to LeakyRelu
*/

class ngraph::pass::LeakyReluFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
LeakyReluFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "transformations/common_optimizations/swish_fusion.hpp"
#include "transformations/common_optimizations/normalize_l2_fusion.hpp"
#include "transformations/common_optimizations/pull_transpose_through_fq.hpp"
#include "transformations/common_optimizations/leaky_relu_fusion.hpp"
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
#include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp"
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
Expand Down Expand Up @@ -133,6 +134,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
common_fusions->add_matcher<ngraph::pass::DilatedConvolutionConverter>();
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");

manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/leaky_relu_fusion.hpp"
#include "transformations/utils/utils.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"


NGRAPH_RTTI_DEFINITION(ngraph::pass::LeakyReluFusion, "LeakyReluFusion", 0);

ngraph::pass::LeakyReluFusion::LeakyReluFusion() {
MATCHER_SCOPE(LeakyReluFusion);
auto data_pattern = ngraph::pattern::any_input();
auto alpha_pattern = ngraph::pattern::any_input(pattern::has_static_shape());
auto multiply_pattern = ngraph::pattern::wrap_type<opset8::Multiply>({data_pattern, alpha_pattern}, pattern::consumers_count(1));
auto max_pattern = ngraph::pattern::wrap_type<opset8::Maximum>({data_pattern, multiply_pattern});

ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_value_map();
auto data = pattern_map.at(data_pattern);
const auto & original_alpha_pattern = pattern_map.at(alpha_pattern);

if (shape_size(original_alpha_pattern.get_shape()) != 1)
return false;

auto leaky_relu = register_new_node<ngraph::opset8::PRelu>(data, original_alpha_pattern);
auto maximum = pattern_map.at(max_pattern);
leaky_relu->set_friendly_name(maximum.get_node()->get_friendly_name());

copy_runtime_info({
pattern_map.at(multiply_pattern).get_node_shared_ptr(),
maximum.get_node_shared_ptr()
},
leaky_relu);
replace_node(maximum.get_node_shared_ptr(), leaky_relu);

return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(max_pattern, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>
#include <queue>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <transformations/common_optimizations/leaky_relu_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"


using namespace testing;
using namespace ngraph;

TEST(TransformationTests, LeakyReluFusionConstant) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
auto max = std::make_shared<opset8::Maximum>(data, multiply);
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data});

pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::LeakyReluFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data});
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, LeakyReluFusionScalar) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1});
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
auto max = std::make_shared<opset8::Maximum>(data, multiply);
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data});

pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::LeakyReluFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1});
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data});
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, LeakyReluFusionParameter) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
auto max = std::make_shared<opset8::Maximum>(data, multiply);
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data, alpha});

pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::LeakyReluFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2});
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data, alpha});
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

0 comments on commit c0c2f2d

Please sign in to comment.