diff --git a/inference-engine/src/transformations/include/transformations/mish_fusion.hpp b/inference-engine/src/transformations/include/transformations/mish_fusion.hpp index 24c46e57661e2f..8c8f9344dddc4c 100644 --- a/inference-engine/src/transformations/include/transformations/mish_fusion.hpp +++ b/inference-engine/src/transformations/include/transformations/mish_fusion.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2020 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -9,7 +9,7 @@ #include -#include +#include "ngraph/pattern/matcher.hpp" namespace ngraph { namespace pass { @@ -19,12 +19,7 @@ class TRANSFORMATIONS_API MishFusion; } // namespace pass } // namespace ngraph -class ngraph::pass::MishFusion: public ngraph::pass::GraphRewrite { +class ngraph::pass::MishFusion: public ngraph::pass::MatcherPass { public: - MishFusion() : GraphRewrite() { - mish_fusion(); - } - -private: - void mish_fusion(); + MishFusion(); }; diff --git a/inference-engine/src/transformations/src/transformations/mish_fusion.cpp b/inference-engine/src/transformations/src/transformations/mish_fusion.cpp index 5fc95fbaa27e1d..4a389c859eb38a 100644 --- a/inference-engine/src/transformations/src/transformations/mish_fusion.cpp +++ b/inference-engine/src/transformations/src/transformations/mish_fusion.cpp @@ -9,51 +9,32 @@ #include #include - -void ngraph::pass::MishFusion::mish_fusion() { - auto input0 = std::make_shared(element::f64, Shape{1, 1, 1, 1}); - auto exp = std::make_shared(input0); - auto input_const = op::Constant::create(element::f64, Shape{1}, {-1}); - auto add = std::make_shared(exp, input_const); - auto log = std::make_shared(add); - auto tanh = std::make_shared(log); - auto mul = std::make_shared(input0, tanh); - - ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) { - - auto mul = std::dynamic_pointer_cast (m.get_match_root()); - if (!mul) { - return false; - } - - auto tanh = std::dynamic_pointer_cast (mul->input_value(1).get_node_shared_ptr()); - if (!tanh) { - return false; - } - - auto log = std::dynamic_pointer_cast (tanh->input_value(0).get_node_shared_ptr()); - if (!log) { - return false; - } - - auto add = std::dynamic_pointer_cast (log->input_value(0).get_node_shared_ptr()); - if (!add) { - return false; - } - - auto exp = std::dynamic_pointer_cast (add->input_value(0).get_node_shared_ptr()); - if (!exp) { - return false; - } - - auto mish = std::make_shared(exp->input(0).get_source_output()); - - mish->set_friendly_name(exp->get_friendly_name()); - ngraph::copy_runtime_info({mul, tanh, log, add, exp}, mish); - ngraph::replace_node(mul, mish); - return true; - }; - - auto m = std::make_shared(mul, "MishFusion"); - this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE); +#include + +ngraph::pass::MishFusion::MishFusion() { + auto input = ngraph::pattern::any_input(); + auto exp = std::make_shared(input); + auto add = std::make_shared(exp, ngraph::pattern::wrap_type()); + auto log = std::make_shared(add); + auto tanh = std::make_shared(log); + auto mul = std::make_shared(input, tanh); + + ngraph::graph_rewrite_callback callback = [=](ngraph::pattern::Matcher& m) { + auto & pattern_to_output = m.get_pattern_value_map(); + auto exp_input = pattern_to_output.at(input); + + auto mish = std::make_shared(exp_input); + + mish->set_friendly_name(m.get_match_root()->get_friendly_name()); + ngraph::copy_runtime_info({pattern_to_output.at(mul).get_node_shared_ptr(), + pattern_to_output.at(tanh).get_node_shared_ptr(), + pattern_to_output.at(log).get_node_shared_ptr(), + pattern_to_output.at(add).get_node_shared_ptr(), + pattern_to_output.at(exp).get_node_shared_ptr()}, mish); + ngraph::replace_node(m.get_match_root(), mish); + return true; + }; + + auto m = std::make_shared(mul, "MishFusion"); + register_matcher(m, callback); }