From 5c920079c4f2fed066435f769a2653d948a307f7 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Fri, 28 Jan 2022 18:56:49 +0300 Subject: [PATCH] applied comments #3 --- .../reshape_sequence_fusion.hpp | 2 +- .../common_optimizations/nop_elimination.cpp | 6 ++--- .../reshape_sequence_fusion.cpp | 14 ++++++---- .../common_optimizations/nop_elimination.cpp | 27 +++++++++++++++++++ 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp index 34a0b6315a857b..4d54950fed49b2 100644 --- a/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp @@ -24,5 +24,5 @@ class TRANSFORMATIONS_API ReshapeSequenceFusion; class ngraph::pass::ReshapeSequenceFusion: public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; - ReshapeSequenceFusion(); + ReshapeSequenceFusion(bool use_shape_for_elimination = true); }; diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index a9277430a8b852..16c58752f2e846 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -100,12 +100,12 @@ static bool eliminate_reshape_v1(const std::shared_ptr& node) { // eliminate redundant reshape, squeeze, or unsqueeze auto input_node = input.get_node_shared_ptr(); - if (input_node->get_output_size() != 1) - return false; - if (ov::as_type_ptr(input_node) || ov::as_type_ptr(input_node) || ov::as_type_ptr(input_node)) { + if (input_node->get_output_target_inputs(0).size() != 1) + return false; + auto shape = node->get_output_shape(0); // remove interchangeable nodes diff --git a/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp index 545224b9dd5647..45a1a6f9c20568 100644 --- a/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp @@ -26,7 +26,7 @@ bool has_valid_pattern(const std::shared_ptr & node) { } } -ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() { +ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion(bool use_shape_for_elimination) { MATCHER_SCOPE(ReshapeSequenceFusion); auto reshape_input = pattern::any_input(); auto reshape_a_pattern = pattern::wrap_type(); @@ -59,10 +59,14 @@ ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() { } // remove redundant reshapes - if (input.get_node_shared_ptr()->get_output_partial_shape(0).is_static() && reshape->get_output_partial_shape(0).is_static() && - input.get_node_shared_ptr()->get_output_shape(0) == reshape->get_output_shape(0)) { - return replace_output_update_name(reshape->output(0), input); - } else { + bool replaced = false; + if (use_shape_for_elimination && input.get_partial_shape().is_static() && reshape->get_output_partial_shape(0).is_static() && + input.get_shape() == reshape->get_output_shape(0)) { + // in case if elimination is not allowed we still can eliminate all transposes except last one + replaced = replace_output_update_name(reshape->output(0), input); + } + + if (!replaced) { reshape->input(0).replace_source_output(input); copy_runtime_info(nodes, reshape); } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp index a981e779e7a08e..b38ab845172c8d 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp @@ -196,6 +196,33 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) { ASSERT_TRUE(count_ops_of_type(f) == 1); } +TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) { + std::shared_ptr f; + { + auto arg = std::make_shared(element::f32, PartialShape{8, 16, 1, 3}); + + auto reshape_1_shape = opset4::Constant::create(element::i64, Shape{2}, {128, 3}); + auto reshape_1 = std::make_shared(arg, reshape_1_shape, false); + reshape_1->set_friendly_name("reshape_1"); + + auto reshape_2_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3}); + auto reshape_2 = std::make_shared(reshape_1, reshape_2_shape, false); + reshape_2->set_friendly_name("reshape_2"); + + auto relu = std::make_shared(reshape_1); + relu->set_friendly_name("relu"); + + f = std::make_shared(NodeVector{reshape_2, relu}, ParameterVector{arg}); + } + + pass::Manager pass_manager; + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.run_passes(f); + + ASSERT_TRUE(count_ops_of_type(f) == 2); +} + TEST(nop_elimination, concat_elimination_single_node) { int64_t a = 0; auto A = make_shared(element::f32, Shape{2, 3});