diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/leaky_relu_fusion.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/leaky_relu_fusion.hpp new file mode 100644 index 00000000000000..79e203485fa383 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/common_optimizations/leaky_relu_fusion.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API LeakyReluFusion; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief LeakyReluFusion transformation replaces following graph: + * Multiply->Maximum to LeakyRelu + */ + +class ngraph::pass::LeakyReluFusion: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + LeakyReluFusion(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 415ecb11610901..44b2f5d7f40be7 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -21,6 +21,7 @@ #include "transformations/common_optimizations/swish_fusion.hpp" #include "transformations/common_optimizations/normalize_l2_fusion.hpp" #include "transformations/common_optimizations/pull_transpose_through_fq.hpp" +#include "transformations/common_optimizations/leaky_relu_fusion.hpp" #include "transformations/common_optimizations/lin_op_sequence_fusion.hpp" #include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp" #include "transformations/common_optimizations/hsigmoid_fusion.hpp" @@ -133,6 +134,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher(); common_fusions->add_matcher(); common_fusions->add_matcher(); + common_fusions->add_matcher(); common_fusions->set_name("ngraph::pass::CommonFusions"); manager.register_pass(); diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/leaky_relu_fusion.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/leaky_relu_fusion.cpp new file mode 100644 index 00000000000000..388d2f171041f3 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/leaky_relu_fusion.cpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/leaky_relu_fusion.hpp" +#include "transformations/utils/utils.hpp" + +#include +#include + +#include +#include +#include +#include "itt.hpp" + + +NGRAPH_RTTI_DEFINITION(ngraph::pass::LeakyReluFusion, "LeakyReluFusion", 0); + +ngraph::pass::LeakyReluFusion::LeakyReluFusion() { + MATCHER_SCOPE(LeakyReluFusion); + auto data_pattern = ngraph::pattern::any_input(); + auto alpha_pattern = ngraph::pattern::any_input(pattern::has_static_shape()); + auto multiply_pattern = ngraph::pattern::wrap_type({data_pattern, alpha_pattern}, pattern::consumers_count(1)); + auto max_pattern = ngraph::pattern::wrap_type({data_pattern, multiply_pattern}); + + ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) { + auto pattern_map = m.get_pattern_value_map(); + auto data = pattern_map.at(data_pattern); + const auto & original_alpha_pattern = pattern_map.at(alpha_pattern); + + if (shape_size(original_alpha_pattern.get_shape()) != 1) + return false; + + auto leaky_relu = register_new_node(data, original_alpha_pattern); + auto maximum = pattern_map.at(max_pattern); + leaky_relu->set_friendly_name(maximum.get_node()->get_friendly_name()); + + copy_runtime_info({ + pattern_map.at(multiply_pattern).get_node_shared_ptr(), + maximum.get_node_shared_ptr() + }, + leaky_relu); + replace_node(maximum.get_node_shared_ptr(), leaky_relu); + + return true; + }; + + auto m = std::make_shared(max_pattern, matcher_name); + this->register_matcher(m, callback); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/leaky_relu_fusion.cpp b/inference-engine/tests/functional/inference_engine/transformations/leaky_relu_fusion.cpp new file mode 100644 index 00000000000000..dec4de41c13c5f --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/leaky_relu_fusion.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + + +using namespace testing; +using namespace ngraph; + +TEST(TransformationTests, LeakyReluFusionConstant) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(element::f32, Shape{2, 2}); + auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1}); + auto multiply = std::make_shared(data, alpha); + auto max = std::make_shared(data, multiply); + f = std::make_shared(NodeVector{max}, ParameterVector{data}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(element::f32, Shape{2, 2}); + auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1}); + auto leaky_relu = std::make_shared(data, alpha); + f_ref = std::make_shared(NodeVector{leaky_relu}, ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, LeakyReluFusionScalar) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(element::f32, Shape{2, 2}); + auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1}); + auto multiply = std::make_shared(data, alpha); + auto max = std::make_shared(data, multiply); + f = std::make_shared(NodeVector{max}, ParameterVector{data}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(element::f32, Shape{2, 2}); + auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1}); + auto leaky_relu = std::make_shared(data, alpha); + f_ref = std::make_shared(NodeVector{leaky_relu}, ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, LeakyReluFusionParameter) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(element::f32, Shape{2, 2}); + auto alpha = std::make_shared(element::f32, Shape{}); + auto multiply = std::make_shared(data, alpha); + auto max = std::make_shared(data, multiply); + f = std::make_shared(NodeVector{max}, ParameterVector{data, alpha}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(element::f32, Shape{2, 2}); + auto alpha = std::make_shared(element::f32, Shape{}); + auto leaky_relu = std::make_shared(data, alpha); + f_ref = std::make_shared(NodeVector{leaky_relu}, ParameterVector{data, alpha}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +}