forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LeakyReluFusion transformation (openvinotoolkit#6816)
- Loading branch information
1 parent
6a3f114
commit 96dff5c
Showing
4 changed files
with
188 additions
and
0 deletions.
There are no files selected for viewing
32 changes: 32 additions & 0 deletions
32
...ne/src/transformations/include/transformations/common_optimizations/leaky_relu_fusion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
#include <memory> | ||
|
||
#include <transformations_visibility.hpp> | ||
|
||
#include <ngraph/pass/graph_rewrite.hpp> | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API LeakyReluFusion; | ||
|
||
} // namespace pass | ||
} // namespace ngraph | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief LeakyReluFusion transformation replaces following graph: | ||
* Multiply->Maximum to LeakyRelu | ||
*/ | ||
|
||
class ngraph::pass::LeakyReluFusion: public ngraph::pass::MatcherPass { | ||
public: | ||
NGRAPH_RTTI_DECLARATION; | ||
LeakyReluFusion(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
...engine/src/transformations/src/transformations/common_optimizations/leaky_relu_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/leaky_relu_fusion.hpp" | ||
#include "transformations/utils/utils.hpp" | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include <ngraph/opsets/opset8.hpp> | ||
#include <ngraph/rt_info.hpp> | ||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
#include "itt.hpp" | ||
|
||
|
||
NGRAPH_RTTI_DEFINITION(ngraph::pass::LeakyReluFusion, "LeakyReluFusion", 0); | ||
|
||
ngraph::pass::LeakyReluFusion::LeakyReluFusion() { | ||
MATCHER_SCOPE(LeakyReluFusion); | ||
auto data_pattern = ngraph::pattern::any_input(); | ||
auto alpha_pattern = ngraph::pattern::any_input(pattern::has_static_shape()); | ||
auto multiply_pattern = ngraph::pattern::wrap_type<opset8::Multiply>({data_pattern, alpha_pattern}, pattern::consumers_count(1)); | ||
auto max_pattern = ngraph::pattern::wrap_type<opset8::Maximum>({data_pattern, multiply_pattern}); | ||
|
||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) { | ||
auto pattern_map = m.get_pattern_value_map(); | ||
auto data = pattern_map.at(data_pattern); | ||
const auto & original_alpha_pattern = pattern_map.at(alpha_pattern); | ||
|
||
if (shape_size(original_alpha_pattern.get_shape()) != 1) | ||
return false; | ||
|
||
auto leaky_relu = register_new_node<ngraph::opset8::PRelu>(data, original_alpha_pattern); | ||
auto maximum = pattern_map.at(max_pattern); | ||
leaky_relu->set_friendly_name(maximum.get_node()->get_friendly_name()); | ||
|
||
copy_runtime_info({ | ||
pattern_map.at(multiply_pattern).get_node_shared_ptr(), | ||
maximum.get_node_shared_ptr() | ||
}, | ||
leaky_relu); | ||
replace_node(maximum.get_node_shared_ptr(), leaky_relu); | ||
|
||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ngraph::pattern::Matcher>(max_pattern, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
104 changes: 104 additions & 0 deletions
104
inference-engine/tests/functional/inference_engine/transformations/leaky_relu_fusion.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <string> | ||
#include <memory> | ||
#include <queue> | ||
|
||
#include <ngraph/function.hpp> | ||
#include <ngraph/opsets/opset8.hpp> | ||
#include <transformations/common_optimizations/leaky_relu_fusion.hpp> | ||
#include <transformations/init_node_info.hpp> | ||
#include <transformations/utils/utils.hpp> | ||
#include <ngraph/pass/manager.hpp> | ||
#include <ngraph/pass/constant_folding.hpp> | ||
|
||
#include "common_test_utils/ngraph_test_utils.hpp" | ||
|
||
|
||
using namespace testing; | ||
using namespace ngraph; | ||
|
||
TEST(TransformationTests, LeakyReluFusionConstant) { | ||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr); | ||
{ | ||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2}); | ||
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1}); | ||
auto multiply = std::make_shared<opset8::Multiply>(data, alpha); | ||
auto max = std::make_shared<opset8::Maximum>(data, multiply); | ||
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data}); | ||
|
||
pass::Manager m; | ||
m.register_pass<pass::InitNodeInfo>(); | ||
m.register_pass<pass::LeakyReluFusion>(); | ||
m.run_passes(f); | ||
ASSERT_NO_THROW(check_rt_info(f)); | ||
} | ||
|
||
{ | ||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2}); | ||
auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1}); | ||
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha); | ||
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data}); | ||
} | ||
|
||
auto res = compare_functions(f, f_ref); | ||
ASSERT_TRUE(res.first) << res.second; | ||
} | ||
|
||
TEST(TransformationTests, LeakyReluFusionScalar) { | ||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr); | ||
{ | ||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2}); | ||
auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1}); | ||
auto multiply = std::make_shared<opset8::Multiply>(data, alpha); | ||
auto max = std::make_shared<opset8::Maximum>(data, multiply); | ||
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data}); | ||
|
||
pass::Manager m; | ||
m.register_pass<pass::InitNodeInfo>(); | ||
m.register_pass<pass::LeakyReluFusion>(); | ||
m.run_passes(f); | ||
ASSERT_NO_THROW(check_rt_info(f)); | ||
} | ||
|
||
{ | ||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2}); | ||
auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1}); | ||
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha); | ||
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data}); | ||
} | ||
|
||
auto res = compare_functions(f, f_ref); | ||
ASSERT_TRUE(res.first) << res.second; | ||
} | ||
|
||
TEST(TransformationTests, LeakyReluFusionParameter) { | ||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr); | ||
{ | ||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2}); | ||
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{}); | ||
auto multiply = std::make_shared<opset8::Multiply>(data, alpha); | ||
auto max = std::make_shared<opset8::Maximum>(data, multiply); | ||
f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data, alpha}); | ||
|
||
pass::Manager m; | ||
m.register_pass<pass::InitNodeInfo>(); | ||
m.register_pass<pass::LeakyReluFusion>(); | ||
m.run_passes(f); | ||
ASSERT_NO_THROW(check_rt_info(f)); | ||
} | ||
|
||
{ | ||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 2}); | ||
auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{}); | ||
auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha); | ||
f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data, alpha}); | ||
} | ||
|
||
auto res = compare_functions(f, f_ref); | ||
ASSERT_TRUE(res.first) << res.second; | ||
} |