Skip to content

Commit

Permalink
Fix mish transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Jul 29, 2020
1 parent fbb96c9 commit 0b3e23c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ void ngraph::pass::MishFusion::mish_fusion() {
auto mul = std::make_shared<ngraph::opset4::Multiply>(input0, tanh);

ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {

auto mul = std::dynamic_pointer_cast<ngraph::opset4::Multiply> (m.get_match_root());
if (!mul) {
return false;
Expand All @@ -40,7 +41,7 @@ void ngraph::pass::MishFusion::mish_fusion() {
return false;
}

auto exp = std::dynamic_pointer_cast<ngraph::opset4::Add> (add->input_value(0).get_node_shared_ptr());
auto exp = std::dynamic_pointer_cast<ngraph::opset4::Exp> (add->input_value(0).get_node_shared_ptr());
if (!exp) {
return false;
}
Expand All @@ -49,7 +50,7 @@ void ngraph::pass::MishFusion::mish_fusion() {

mish->set_friendly_name(exp->get_friendly_name());
ngraph::copy_runtime_info({mul, tanh, log, add, exp}, mish);
ngraph::replace_node(exp, mish);
ngraph::replace_node(mul, mish);
return true;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/mish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
Expand All @@ -31,8 +33,12 @@ TEST(TransformationTests, MishFusing) {

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input0});

ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::MishFusion().run_on_function(f);
ngraph::pass::Manager manager;
//manager.register_pass<ngraph::pass::VisualizeTree>("/home/imironov/projects/dpd_vcp_dl/svg_debug/before.svg");
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::MishFusion>();
//manager.register_pass<ngraph::pass::VisualizeTree>("/home/imironov/projects/dpd_vcp_dl/svg_debug/after.svg");
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

Expand Down

0 comments on commit 0b3e23c

Please sign in to comment.