Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mish fusion transformation #1399

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MishFusion;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief MishFusion transformation replaces group of
* operations: x * tanh(log(exp(x) + 1)) to Mish op.
*/
class ngraph::pass::MishFusion: public ngraph::pass::MatcherPass {
public:
MishFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "transformations/remove_filtering_boxes_by_size.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/itt.hpp"
#include "transformations/mish_fusion.hpp"

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/nop_elimination.hpp>
Expand All @@ -32,6 +33,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.register_pass<ngraph::pass::MishFusion>();

manager.set_callback(m_transformation_callback);
manager.run_passes(f);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/mish_fusion.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

ngraph::pass::MishFusion::MishFusion() {
iimironov marked this conversation as resolved.
Show resolved Hide resolved
auto input = ngraph::pattern::any_input();
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
auto add = std::make_shared<ngraph::opset4::Add>(exp, ngraph::pattern::wrap_type<ngraph::opset4::Constant>());
auto log = std::make_shared<ngraph::opset4::Log>(add);
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, tanh);

ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(input);

auto mish = std::make_shared<ngraph::opset4::Mish>(exp_input);

mish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(tanh).get_node_shared_ptr(),
pattern_to_output.at(log).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(exp).get_node_shared_ptr()}, mish);
ngraph::replace_node(m.get_match_root(), mish);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MishFusion");
register_matcher(m, matcher_pass_callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>
#include <queue>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/mish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

using namespace testing;

TEST(TransformationTests, MishFusing) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input0 = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f64, ngraph::Shape{3, 1, 2});
auto exp = std::make_shared<ngraph::opset4::Exp>(input0);
auto input_const = ngraph::opset4::Constant::create(ngraph::element::f64, ngraph::Shape{1}, {-1});
auto add = std::make_shared<ngraph::opset4::Add>(exp, input_const);
auto log = std::make_shared<ngraph::opset4::Log>(add);
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input0, tanh);

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input0});

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MishFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto mish = std::make_shared<ngraph::opset4::Mish>(data);

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mish}, ngraph::ParameterVector{data});
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
12 changes: 11 additions & 1 deletion model-optimizer/extensions/front/tf/activation_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""
from extensions.ops.activation_ops import Abs, Elu, Erf, Exp, ReLU, LeakyReLU, LogicalNot, ReLU6, Sigmoid, \
Sin, Sinh, Cos, Cosh, Tan, Tanh, Ceiling, Atanh, Acosh, Asinh
Sin, Sinh, Cos, Cosh, Tan, Tanh, Ceiling, Atanh, Acosh, Asinh, Mish
from mo.front.extractor import FrontExtractorOp


Expand Down Expand Up @@ -210,3 +210,13 @@ class CeilExtractor(FrontExtractorOp):
def extract(cls, node):
Ceiling.update_node_stat(node)
return cls.enabled


class MishExtractor(FrontExtractorOp):
op = 'Mish'
enabled = True

@classmethod
def extract(cls, node):
Mish.update_node_stat(node)
return cls.enabled
10 changes: 10 additions & 0 deletions model-optimizer/extensions/ops/activation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,13 @@ def __init__(self, graph: Graph, attrs: dict):
'infer': None
}
super().__init__(graph, mandatory_props, attrs)


class Mish(Activation):
op = 'Mish'
operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0)))

def __init__(self, graph: Graph, attrs: dict):
sp_attrs = {'version': 'opset4'}
sp_attrs.update(attrs)
super().__init__(graph, sp_attrs)
1 change: 1 addition & 0 deletions ngraph/python/src/ngraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from ngraph.opset4 import max_pool
from ngraph.opset4 import maximum
from ngraph.opset4 import minimum
from ngraph.opset4 import mish
from ngraph.opset4 import mod
from ngraph.opset4 import multiply
from ngraph.opset4 import mvn
Expand Down
1 change: 1 addition & 0 deletions ngraph/python/src/ngraph/opset4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from ngraph.opset1.ops import max_pool
from ngraph.opset1.ops import maximum
from ngraph.opset1.ops import minimum
from ngraph.opset4.ops import mish
from ngraph.opset1.ops import mod
from ngraph.opset1.ops import multiply
from ngraph.opset2.ops import mvn
Expand Down
10 changes: 10 additions & 0 deletions ngraph/python/src/ngraph/opset4/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,13 @@ def non_max_suppression(
}

return _get_node_factory_opset4().create("NonMaxSuppression", inputs, attributes)


@nameable_op
def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
"""Return a node which performs Mish.

:param data: Tensor with input data floating point type.
:return: The new node which performs Mish
"""
return _get_node_factory_opset4().create("Mish", as_nodes(data), {})
2 changes: 1 addition & 1 deletion ngraph/test/op_eval/mish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ TEST(op_eval, mish_0D)
EXPECT_EQ(result->get_element_type(), element::f32);
EXPECT_EQ(result->get_shape(), (Shape{}));
auto result_data = read_vector<float>(result);
EXPECT_NEAR(result_data[0], expected_result[i][0], 0.3);
EXPECT_NEAR(result_data[0], expected_result[i][0], 0.000001);
}
}