Skip to content

Commit

Permalink
Enable TransposeSyncinc and BroadcastElementwiseFusion in MOC Backend (
Browse files Browse the repository at this point in the history
  • Loading branch information
Gleb Kazantaev authored and andrei-cv committed Aug 30, 2021
1 parent ba5d231 commit 84860c7
Showing 1 changed file with 11 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
#include <transformations/common_optimizations/binarize_weights.hpp>
#include <transformations/common_optimizations/conv_to_binary_conv.hpp>
#include <transformations/common_optimizations/eliminate_unsqueeze_gather.hpp>
#include <transformations/common_optimizations/split_squeeze_concat_fusion.hpp>
#include <transformations/common_optimizations/transpose_sinking.hpp>
#include <transformations/common_optimizations/broadcast_elementwise_fusion.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);

Expand All @@ -45,12 +48,20 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();

auto transpose_sinking = manager.register_pass<ngraph::pass::GraphRewrite>();
transpose_sinking->add_matcher<ngraph::pass::TransposeSinking>();
// SplitSqueezeConcatFusion should work in same GraphRewrite as TransposesSinking,
// because it replaces pattern that may contain Transposes which must be optimized before
// the transformation and it also inserts Transpose that can be optimized by TransposeSinking
transpose_sinking->add_matcher<ngraph::pass::SplitSqueezeConcatFusion>();

auto eliminations = manager.register_pass<ngraph::pass::GraphRewrite>();
eliminations->add_matcher<ngraph::pass::EliminateUnsqueezeGather>();
eliminations->set_name("ngraph::pass::CommonEliminations");

auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
common_fusions->add_matcher<ngraph::pass::ConvertScatterElementsToScatter>();
common_fusions->add_matcher<ngraph::pass::BroadcastElementwiseFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
Expand Down

0 comments on commit 84860c7

Please sign in to comment.