From 7e856c37001a39fac4090e15fc3f6b890a4685d4 Mon Sep 17 00:00:00 2001 From: iliya mironov Date: Thu, 6 Aug 2020 15:55:12 +0300 Subject: [PATCH] Add mish fusion transformation (#1399) * Add mish fusion transformation * Add mish op to python api --- .../include/transformations/mish_fusion.hpp | 32 ++++++++++++ .../common_optimizations.cpp | 2 + .../src/transformations/mish_fusion.cpp | 40 ++++++++++++++ .../transformations/mish_fusion_test.cpp | 52 +++++++++++++++++++ .../extensions/front/tf/activation_ext.py | 12 ++++- .../extensions/ops/activation_ops.py | 10 ++++ ngraph/python/src/ngraph/__init__.py | 1 + ngraph/python/src/ngraph/opset4/__init__.py | 1 + ngraph/python/src/ngraph/opset4/ops.py | 10 ++++ ngraph/test/op_eval/mish.cpp | 2 +- 10 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 inference-engine/src/transformations/include/transformations/mish_fusion.hpp create mode 100644 inference-engine/src/transformations/src/transformations/mish_fusion.cpp create mode 100644 inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp diff --git a/inference-engine/src/transformations/include/transformations/mish_fusion.hpp b/inference-engine/src/transformations/include/transformations/mish_fusion.hpp new file mode 100644 index 00000000000000..0bcfe2e68c86fe --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/mish_fusion.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include +#include +#include "ngraph/pattern/matcher.hpp" + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API MishFusion; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief MishFusion transformation replaces group of + * operations: x * tanh(log(exp(x) + 1)) to Mish op. + */ +class ngraph::pass::MishFusion: public ngraph::pass::MatcherPass { +public: + MishFusion(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index ebba2881f9126d..78cee297537a69 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -11,6 +11,7 @@ #include "transformations/remove_filtering_boxes_by_size.hpp" #include "transformations/init_node_info.hpp" #include "transformations/itt.hpp" +#include "transformations/mish_fusion.hpp" #include #include @@ -32,6 +33,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); manager.register_pass(); // partially depends on CF manager.register_pass(); + manager.register_pass(); manager.set_callback(m_transformation_callback); manager.run_passes(f); diff --git a/inference-engine/src/transformations/src/transformations/mish_fusion.cpp b/inference-engine/src/transformations/src/transformations/mish_fusion.cpp new file mode 100644 index 00000000000000..62a1e76b8f2c18 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/mish_fusion.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/mish_fusion.hpp" + +#include +#include + +#include +#include +#include + +ngraph::pass::MishFusion::MishFusion() { + auto input = ngraph::pattern::any_input(); + auto exp = std::make_shared(input); + auto add = std::make_shared(exp, ngraph::pattern::wrap_type()); + auto log = std::make_shared(add); + auto tanh = std::make_shared(log); + auto mul = std::make_shared(input, tanh); + + ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) { + auto & pattern_to_output = m.get_pattern_value_map(); + auto exp_input = pattern_to_output.at(input); + + auto mish = std::make_shared(exp_input); + + mish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(tanh).get_node_shared_ptr(), + pattern_to_output.at(log).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(exp).get_node_shared_ptr()}, mish); + ngraph::replace_node(m.get_match_root(), mish); + return true; + }; + + auto m = std::make_shared(mul, "MishFusion"); + register_matcher(m, matcher_pass_callback); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp new file mode 100644 index 00000000000000..a540cec12d26d3 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/mish_fusion_test.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +TEST(TransformationTests, MishFusing) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input0 = std::make_shared(ngraph::element::f64, ngraph::Shape{3, 1, 2}); + auto exp = std::make_shared(input0); + auto input_const = ngraph::opset4::Constant::create(ngraph::element::f64, ngraph::Shape{1}, {-1}); + auto add = std::make_shared(exp, input_const); + auto log = std::make_shared(add); + auto tanh = std::make_shared(log); + auto mul = std::make_shared(input0, tanh); + + f = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{input0}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); + auto mish = std::make_shared(data); + + f_ref = std::make_shared(ngraph::NodeVector{mish}, ngraph::ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/model-optimizer/extensions/front/tf/activation_ext.py b/model-optimizer/extensions/front/tf/activation_ext.py index 4da9556a880a7d..83492f8754edad 100644 --- a/model-optimizer/extensions/front/tf/activation_ext.py +++ b/model-optimizer/extensions/front/tf/activation_ext.py @@ -14,7 +14,7 @@ limitations under the License. """ from extensions.ops.activation_ops import Abs, Elu, Erf, Exp, ReLU, LeakyReLU, LogicalNot, ReLU6, Sigmoid, \ - Sin, Sinh, Cos, Cosh, Tan, Tanh, Ceiling, Atanh, Acosh, Asinh + Sin, Sinh, Cos, Cosh, Tan, Tanh, Ceiling, Atanh, Acosh, Asinh, Mish from mo.front.extractor import FrontExtractorOp @@ -210,3 +210,13 @@ class CeilExtractor(FrontExtractorOp): def extract(cls, node): Ceiling.update_node_stat(node) return cls.enabled + + +class MishExtractor(FrontExtractorOp): + op = 'Mish' + enabled = True + + @classmethod + def extract(cls, node): + Mish.update_node_stat(node) + return cls.enabled diff --git a/model-optimizer/extensions/ops/activation_ops.py b/model-optimizer/extensions/ops/activation_ops.py index 9127903541d2c8..145dfdf595a139 100644 --- a/model-optimizer/extensions/ops/activation_ops.py +++ b/model-optimizer/extensions/ops/activation_ops.py @@ -240,3 +240,13 @@ def __init__(self, graph: Graph, attrs: dict): 'infer': None } super().__init__(graph, mandatory_props, attrs) + + +class Mish(Activation): + op = 'Mish' + operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0))) + + def __init__(self, graph: Graph, attrs: dict): + sp_attrs = {'version': 'opset4'} + sp_attrs.update(attrs) + super().__init__(graph, sp_attrs) diff --git a/ngraph/python/src/ngraph/__init__.py b/ngraph/python/src/ngraph/__init__.py index 0945ff9a31f37f..08cf718072be8e 100644 --- a/ngraph/python/src/ngraph/__init__.py +++ b/ngraph/python/src/ngraph/__init__.py @@ -95,6 +95,7 @@ from ngraph.opset4 import max_pool from ngraph.opset4 import maximum from ngraph.opset4 import minimum +from ngraph.opset4 import mish from ngraph.opset4 import mod from ngraph.opset4 import multiply from ngraph.opset4 import mvn diff --git a/ngraph/python/src/ngraph/opset4/__init__.py b/ngraph/python/src/ngraph/opset4/__init__.py index c7f5d2b010455f..eac33dd75751d1 100644 --- a/ngraph/python/src/ngraph/opset4/__init__.py +++ b/ngraph/python/src/ngraph/opset4/__init__.py @@ -84,6 +84,7 @@ from ngraph.opset1.ops import max_pool from ngraph.opset1.ops import maximum from ngraph.opset1.ops import minimum +from ngraph.opset4.ops import mish from ngraph.opset1.ops import mod from ngraph.opset1.ops import multiply from ngraph.opset2.ops import mvn diff --git a/ngraph/python/src/ngraph/opset4/ops.py b/ngraph/python/src/ngraph/opset4/ops.py index f9888229da880d..b91f4e7c632871 100644 --- a/ngraph/python/src/ngraph/opset4/ops.py +++ b/ngraph/python/src/ngraph/opset4/ops.py @@ -137,3 +137,13 @@ def non_max_suppression( } return _get_node_factory_opset4().create("NonMaxSuppression", inputs, attributes) + + +@nameable_op +def mish(data: NodeInput, name: Optional[str] = None,) -> Node: + """Return a node which performs Mish. + + :param data: Tensor with input data floating point type. + :return: The new node which performs Mish + """ + return _get_node_factory_opset4().create("Mish", as_nodes(data), {}) diff --git a/ngraph/test/op_eval/mish.cpp b/ngraph/test/op_eval/mish.cpp index e04d8fac7bbf9a..acc81f0e95f17d 100644 --- a/ngraph/test/op_eval/mish.cpp +++ b/ngraph/test/op_eval/mish.cpp @@ -46,6 +46,6 @@ TEST(op_eval, mish_0D) EXPECT_EQ(result->get_element_type(), element::f32); EXPECT_EQ(result->get_shape(), (Shape{})); auto result_data = read_vector(result); - EXPECT_NEAR(result_data[0], expected_result[i][0], 0.3); + EXPECT_NEAR(result_data[0], expected_result[i][0], 0.000001); } }