Skip to content

Commit

Permalink
Refactoring mish transformation according to review
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Jul 30, 2020
1 parent 0bdb0e2 commit a011a73
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2020 Intel Corporation
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -9,7 +9,7 @@

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {
Expand All @@ -19,12 +19,7 @@ class TRANSFORMATIONS_API MishFusion;
} // namespace pass
} // namespace ngraph

class ngraph::pass::MishFusion: public ngraph::pass::GraphRewrite {
class ngraph::pass::MishFusion: public ngraph::pass::MatcherPass {
public:
MishFusion() : GraphRewrite() {
mish_fusion();
}

private:
void mish_fusion();
MishFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,32 @@

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>

void ngraph::pass::MishFusion::mish_fusion() {
auto input0 = std::make_shared<pattern::op::Label>(element::f64, Shape{1, 1, 1, 1});
auto exp = std::make_shared<ngraph::opset4::Exp>(input0);
auto input_const = op::Constant::create(element::f64, Shape{1}, {-1});
auto add = std::make_shared<ngraph::opset4::Add>(exp, input_const);
auto log = std::make_shared<ngraph::opset4::Log>(add);
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input0, tanh);

ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {

auto mul = std::dynamic_pointer_cast<ngraph::opset4::Multiply> (m.get_match_root());
if (!mul) {
return false;
}

auto tanh = std::dynamic_pointer_cast<ngraph::opset4::Tanh> (mul->input_value(1).get_node_shared_ptr());
if (!tanh) {
return false;
}

auto log = std::dynamic_pointer_cast<ngraph::opset4::Log> (tanh->input_value(0).get_node_shared_ptr());
if (!log) {
return false;
}

auto add = std::dynamic_pointer_cast<ngraph::opset4::Add> (log->input_value(0).get_node_shared_ptr());
if (!add) {
return false;
}

auto exp = std::dynamic_pointer_cast<ngraph::opset4::Exp> (add->input_value(0).get_node_shared_ptr());
if (!exp) {
return false;
}

auto mish = std::make_shared<ngraph::opset4::Mish>(exp->input(0).get_source_output());

mish->set_friendly_name(exp->get_friendly_name());
ngraph::copy_runtime_info({mul, tanh, log, add, exp}, mish);
ngraph::replace_node(mul, mish);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MishFusion");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
#include <ngraph/pattern/op/wrap_type.hpp>

ngraph::pass::MishFusion::MishFusion() {
auto input = ngraph::pattern::any_input();
auto exp = std::make_shared<ngraph::opset4::Exp>(input);
auto add = std::make_shared<ngraph::opset4::Add>(exp, ngraph::pattern::wrap_type<ngraph::opset4::Constant>());
auto log = std::make_shared<ngraph::opset4::Log>(add);
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, tanh);

ngraph::graph_rewrite_callback callback = [=](ngraph::pattern::Matcher& m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(input);

auto mish = std::make_shared<ngraph::opset4::Mish>(exp_input);

mish->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(tanh).get_node_shared_ptr(),
pattern_to_output.at(log).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(exp).get_node_shared_ptr()}, mish);
ngraph::replace_node(m.get_match_root(), mish);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MishFusion");
register_matcher(m, callback);
}

0 comments on commit a011a73

Please sign in to comment.