From fce49e6d80fbdccd1bc04bc5acc0f07370b9ce72 Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova <alexandra.sidorova@intel.com>
Date: Wed, 9 Feb 2022 21:11:49 +0300
Subject: [PATCH] [Transformations] Added interchangeable reshape elimination
 (#9691)

* [Transformations] Added interchangeable reshape elimination

* Applied comments #2

* Returned Reshape in the condition

* Applied comments #3

* Applied comments #4

* Added a comment in the plugin explaining why the transformation is registered there
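
The "interchangeable" case is a chain of shape-manipulation nodes whose
final output shape matches the shape fed into the chain, making the whole
chain a no-op. A minimal sketch of the eliminated pattern (it mirrors the
new ReshapeSequenceFusionEliminate test added below):

    before: Parameter{1,2,3} -> Relu -> Reshape{2,3} -> Reshape{1,2,3} -> Result
    after:  Parameter{1,2,3} -> Relu -> Result

NopElimination gets the same shape check and additionally bails out when
the first node of the pair has more than one consumer. When shape-based
elimination is disabled (use_shape_for_elimination = false), the sequence
is still fused into the last Reshape instead of being removed.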
---
 .../reshape_sequence_fusion.hpp               |  4 +-
 .../moc_transformations.cpp                   |  2 +-
 .../common_optimizations/nop_elimination.cpp  | 27 ++++---
 .../reshape_sequence_fusion.cpp               | 20 ++++-
 .../convert_to_cpu_specific_opset.hpp         |  3 +
 .../common_optimizations/nop_elimination.cpp  | 74 +++++++++++++++++--
 .../reshape_sequence_fusion.cpp               | 18 +++++
 7 files changed, 124 insertions(+), 24 deletions(-)
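
A minimal usage sketch of the new constructor argument (the manager setup
and the `model` variable are illustrative, not part of this patch):

    #include <ngraph/pass/manager.hpp>
    #include <transformations/common_optimizations/reshape_sequence_fusion.hpp>

    void fuse_without_elimination(std::shared_ptr<ngraph::Function> model) {
        ngraph::pass::Manager manager;
        // pass `false` to keep the last Reshape of each chain and only fuse,
        // e.g. when static output shapes must not be relied upon
        manager.register_pass<ngraph::pass::ReshapeSequenceFusion>(false);
        manager.run_passes(model);
    }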

diff --git a/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp
index 273e134c86ae6f..4d54950fed49b2 100644
--- a/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp
+++ b/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp
@@ -18,11 +18,11 @@ class TRANSFORMATIONS_API ReshapeSequenceFusion;
 
 /**
  * @ingroup ie_transformation_common_api
- * @brief ReshpaeSequenceFusion fuses sequence of Reshape operation into single Reshape
+ * @brief ReshapeSequenceFusion fuses a sequence of Reshape operations into a single Reshape or eliminates the whole sequence when it is redundant
  */
 
 class ngraph::pass::ReshapeSequenceFusion: public ngraph::pass::MatcherPass {
 public:
     NGRAPH_RTTI_DECLARATION;
-    ReshapeSequenceFusion();
+    ReshapeSequenceFusion(bool use_shape_for_elimination = true);
 };
diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
index d4a760f845069c..0b8258c3255a46 100644
--- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp
@@ -153,7 +153,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
     common_fusions->add_matcher<ngraph::pass::DivideFusion>();
     common_fusions->add_matcher<ngraph::pass::SubtractFusion>();
     common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
-    common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>();
+    common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>(m_use_shapes);
     common_fusions->set_name("ngraph::pass::CommonFusions");
 
     manager.register_pass<ngraph::pass::BinarizeWeights>();
diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp
index 3d6d8c52f03385..22fb21b074a8e0 100644
--- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp
@@ -102,16 +102,25 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
     if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
         ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
         ov::as_type_ptr<opset3::Reshape>(input_node)) {
+        if (input_node->get_output_target_inputs(0).size() != 1)  // the preceding node has other consumers, do not touch it
+            return false;
+
         auto shape = node->get_output_shape(0);
-        std::vector<int64_t> vi;
-        vi.assign(shape.begin(), shape.end());
-        auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
-        auto new_reshape =
-            make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
-        new_reshape->set_friendly_name(node->get_friendly_name());
-        copy_runtime_info({input_node, node}, new_reshape);
-        replace_node(node, new_reshape);
-        return true;
+
+        // the pair restores the original shape, so it is a no-op and can be removed entirely
+        if (input_node->get_input_partial_shape(0).is_static() && input_node->get_input_shape(0) == shape) {
+            return replace_output_update_name(node->output(0), input_node->input_value(0));
+        } else {
+            std::vector<int64_t> vi;
+            vi.assign(shape.begin(), shape.end());
+            auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
+            auto new_reshape =
+                    make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
+            new_reshape->set_friendly_name(node->get_friendly_name());
+            copy_runtime_info({input_node, node}, new_reshape);
+            replace_node(node, new_reshape);
+            return true;
+        }
     }
 
     return false;
diff --git a/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp
index 6516e14eca6016..f95adb43f1011d 100644
--- a/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp
@@ -55,7 +55,7 @@ bool has_valid_pattern(const ov::Output<ov::Node>& node_out) {
 }
 } // namespace
 
-ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
+ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion(bool use_shape_for_elimination) {
     MATCHER_SCOPE(ReshapeSequenceFusion);
     auto reshape_input = pattern::any_input();
     auto reshape_a_pattern = pattern::wrap_type<opset8::Constant>();
@@ -87,9 +87,21 @@ ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
             input = node->input_value(0);
         }
 
-        reshape->input(0).replace_source_output(input);
-        copy_runtime_info(nodes, reshape);
-        return false;
+        // eliminate the whole sequence when it is redundant; otherwise fuse all Reshapes into the last one
+        bool replaced = false;
+        if (use_shape_for_elimination && input.get_partial_shape().is_static() && reshape->get_output_partial_shape(0).is_static() &&
+            input.get_shape() == reshape->get_output_shape(0)) {
+            // shapes are static and match: the whole Reshape sequence is a no-op, drop it
+            replaced = replace_output_update_name(reshape->output(0), input);
+        }
+
+        if (!replaced) {
+            reshape->input(0).replace_source_output(input);
+            copy_runtime_info(nodes, reshape);
+            return false; // the root node was only re-wired, not replaced
+        }
+
+        return true;
     };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_b, matcher_name);
diff --git a/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp b/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp
index e3da16039f88aa..75a7809321b824 100644
--- a/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp
+++ b/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp
@@ -18,6 +18,7 @@
 #include "transformations/convert_precision.hpp"
 #include "transformations/utils/utils.hpp"
 #include "rnn_sequences_optimization.hpp"
+#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"
 
 namespace MKLDNNPlugin {
 
@@ -34,6 +35,8 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
     if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(nGraphFunc)) {
         manager.register_pass<ReshapeFullyConnectedFusion>();
     }
+    // after the "MoveEltwiseUpThroughDataMov" transformation there can be Reshape sequences that should be eliminated or fused
+    manager.register_pass<ngraph::pass::ReshapeSequenceFusion>();
     manager.register_pass<ngraph::pass::ConstantFolding>();
     manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});
 
diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp
index d8244a4239b8e7..b38ab845172c8d 100644
--- a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp
+++ b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp
@@ -140,17 +140,48 @@ TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
     pass_manager.register_pass<pass::NopElimination>();
     pass_manager.run_passes(f);
 
-    bool reshape_is_missing = true;
+    bool nodes_are_missing = true;
     for (auto node : f->get_ops()) {
-        if (node->get_friendly_name() == "reshape") {
-            reshape_is_missing = false;
-            ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
-            auto original_names = ngraph::getFusedNamesVector(node);
-            sort(original_names.begin(), original_names.end());
-            ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
+        if (node->get_friendly_name() == "reshape" || node->get_friendly_name() == "squeeze") {
+            nodes_are_missing = false;
         }
     }
-    ASSERT_FALSE(reshape_is_missing);
+    ASSERT_TRUE(nodes_are_missing);
+}
+
+TEST(nop_elimination, squeeze_unsqueeze_elimination) {
+    std::shared_ptr<Function> f;
+    {
+        auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
+
+        auto relu = std::make_shared<opset4::Relu>(arg);
+        relu->set_friendly_name("relu");
+
+        auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
+        auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
+        squeeze->set_friendly_name("squeeze");
+
+        auto unsqueeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
+        auto unsqueeze = std::make_shared<opset4::Unsqueeze>(squeeze, unsqueeze_axes);
+        unsqueeze->set_friendly_name("unsqueeze");
+
+        auto abs = std::make_shared<opset4::Abs>(unsqueeze);
+
+        f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
+    }
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::InitNodeInfo>();
+    pass_manager.register_pass<pass::NopElimination>();
+    pass_manager.run_passes(f);
+
+    bool nodes_are_missing = true;
+    for (auto node : f->get_ops()) {
+        if (node->get_friendly_name() == "squeeze" || node->get_friendly_name() == "unsqueeze") {
+            nodes_are_missing = false;
+        }
+    }
+    ASSERT_TRUE(nodes_are_missing);
 }
 
 TEST(nop_elimination, reshape_elimination_v1_dynamic) {
@@ -165,6 +196,33 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) {
     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 1);
 }
 
+TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) {
+    std::shared_ptr<Function> f;
+    {
+        auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
+
+        auto reshape_1_shape = opset4::Constant::create(element::i64, Shape{2}, {128, 3});
+        auto reshape_1 = std::make_shared<opset4::Reshape>(arg, reshape_1_shape, false);
+        reshape_1->set_friendly_name("reshape_1");
+
+        auto reshape_2_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
+        auto reshape_2 = std::make_shared<opset4::Reshape>(reshape_1, reshape_2_shape, false);
+        reshape_2->set_friendly_name("reshape_2");
+
+        auto relu = std::make_shared<opset4::Relu>(reshape_1);
+        relu->set_friendly_name("relu");
+
+        f = std::make_shared<Function>(NodeVector{reshape_2, relu}, ParameterVector{arg});
+    }
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::InitNodeInfo>();
+    pass_manager.register_pass<pass::NopElimination>();
+    pass_manager.run_passes(f);
+
+    ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 2);
+}
+
 TEST(nop_elimination, concat_elimination_single_node) {
     int64_t a = 0;
     auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp
index 01ced7b3ac47ba..9c2c271b9bf03d 100644
--- a/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp
+++ b/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp
@@ -305,3 +305,21 @@ TEST_F(TransformationTestsF, ReshapeSequenceFusionNeg5_special_zero_false) {
         manager.register_pass<pass::ReshapeSequenceFusion>();
     }
 }
+
+TEST_F(TransformationTestsF, ReshapeSequenceFusionEliminate) {
+    {
+        auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
+        auto relu = std::make_shared<opset6::Relu>(data);
+        auto a = reshape(relu, {2, 3});
+        auto b = reshape(a, {1, 2, 3});
+        function = std::make_shared<Function>(OutputVector{b}, ParameterVector{data});
+
+        manager.register_pass<pass::ReshapeSequenceFusion>();
+    }
+
+    {
+        auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
+        auto relu = std::make_shared<opset6::Relu>(data);
+        function_ref = std::make_shared<Function>(OutputVector{relu}, ParameterVector{data});
+    }
+}