Skip to content

Commit

Permalink
Add mish fusion transformation (#1399)
Browse files Browse the repository at this point in the history
* Add mish fusion transformation

* Add mish op to python api
  • Loading branch information
iimironov authored Aug 6, 2020
1 parent ab869da commit 7e856c3
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MishFusion;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief MishFusion transformation replaces group of
* operations: x * tanh(log(exp(x) + 1)) to Mish op.
*/
class ngraph::pass::MishFusion: public ngraph::pass::MatcherPass {
public:
MishFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "transformations/remove_filtering_boxes_by_size.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/itt.hpp"
#include "transformations/mish_fusion.hpp"

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/nop_elimination.hpp>
Expand All @@ -32,6 +33,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.register_pass<ngraph::pass::MishFusion>();

manager.set_callback(m_transformation_callback);
manager.run_passes(f);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/mish_fusion.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

ngraph::pass::MishFusion::MishFusion() {
auto input = ngraph::pattern::any_input();
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
auto add = std::make_shared<ngraph::opset4::Add>(exp, ngraph::pattern::wrap_type<ngraph::opset4::Constant>());
auto log = std::make_shared<ngraph::opset4::Log>(add);
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, tanh);

ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(input);

auto mish = std::make_shared<ngraph::opset4::Mish>(exp_input);

mish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(tanh).get_node_shared_ptr(),
pattern_to_output.at(log).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(exp).get_node_shared_ptr()}, mish);
ngraph::replace_node(m.get_match_root(), mish);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MishFusion");
register_matcher(m, matcher_pass_callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>
#include <queue>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/mish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

using namespace testing;

TEST(TransformationTests, MishFusing) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input0 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f64, ngraph::Shape{3, 1, 2});
auto exp = std::make_shared<ngraph::opset4::Exp>(input0);
auto input_const = ngraph::opset4::Constant::create(ngraph::element::f64, ngraph::Shape{1}, {-1});
auto add = std::make_shared<ngraph::opset4::Add>(exp, input_const);
auto log = std::make_shared<ngraph::opset4::Log>(add);
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input0, tanh);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input0});

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MishFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto mish = std::make_shared<ngraph::opset4::Mish>(data);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mish}, ngraph::ParameterVector{data});
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
12 changes: 11 additions & 1 deletion model-optimizer/extensions/front/tf/activation_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""
from extensions.ops.activation_ops import Abs, Elu, Erf, Exp, ReLU, LeakyReLU, LogicalNot, ReLU6, Sigmoid, \
Sin, Sinh, Cos, Cosh, Tan, Tanh, Ceiling, Atanh, Acosh, Asinh
Sin, Sinh, Cos, Cosh, Tan, Tanh, Ceiling, Atanh, Acosh, Asinh, Mish
from mo.front.extractor import FrontExtractorOp


Expand Down Expand Up @@ -210,3 +210,13 @@ class CeilExtractor(FrontExtractorOp):
def extract(cls, node):
Ceiling.update_node_stat(node)
return cls.enabled


class MishExtractor(FrontExtractorOp):
op = 'Mish'
enabled = True

@classmethod
def extract(cls, node):
Mish.update_node_stat(node)
return cls.enabled
10 changes: 10 additions & 0 deletions model-optimizer/extensions/ops/activation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,13 @@ def __init__(self, graph: Graph, attrs: dict):
'infer': None
}
super().__init__(graph, mandatory_props, attrs)


class Mish(Activation):
op = 'Mish'
operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0)))

def __init__(self, graph: Graph, attrs: dict):
sp_attrs = {'version': 'opset4'}
sp_attrs.update(attrs)
super().__init__(graph, sp_attrs)
1 change: 1 addition & 0 deletions ngraph/python/src/ngraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from ngraph.opset4 import max_pool
from ngraph.opset4 import maximum
from ngraph.opset4 import minimum
from ngraph.opset4 import mish
from ngraph.opset4 import mod
from ngraph.opset4 import multiply
from ngraph.opset4 import mvn
Expand Down
1 change: 1 addition & 0 deletions ngraph/python/src/ngraph/opset4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from ngraph.opset1.ops import max_pool
from ngraph.opset1.ops import maximum
from ngraph.opset1.ops import minimum
from ngraph.opset4.ops import mish
from ngraph.opset1.ops import mod
from ngraph.opset1.ops import multiply
from ngraph.opset2.ops import mvn
Expand Down
10 changes: 10 additions & 0 deletions ngraph/python/src/ngraph/opset4/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,13 @@ def non_max_suppression(
}

return _get_node_factory_opset4().create("NonMaxSuppression", inputs, attributes)


@nameable_op
def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
"""Return a node which performs Mish.
:param data: Tensor with input data floating point type.
:return: The new node which performs Mish
"""
return _get_node_factory_opset4().create("Mish", as_nodes(data), {})
2 changes: 1 addition & 1 deletion ngraph/test/op_eval/mish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ TEST(op_eval, mish_0D)
EXPECT_EQ(result->get_element_type(), element::f32);
EXPECT_EQ(result->get_shape(), (Shape{}));
auto result_data = read_vector<float>(result);
EXPECT_NEAR(result_data[0], expected_result[i][0], 0.3);
EXPECT_NEAR(result_data[0], expected_result[i][0], 0.000001);
}
}

0 comments on commit 7e856c3

Please sign in to comment.