diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/transpose_sinking.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/transpose_sinking.cpp index 9a52445bf76f02..21211a7be462cb 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/transpose_sinking.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/transpose_sinking.cpp @@ -212,8 +212,9 @@ ngraph::pass::TransposeFuse::TransposeFuse() { auto new_order = ngraph::opset7::Constant::create(element::i64, {order2.size()}, order2); auto new_transpose = register_new_node(input, new_order); + new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name()); ngraph::copy_runtime_info({ transpose1, transpose2 }, new_transpose); - ngraph::replace_node(transpose2, new_transpose); + ngraph::replace_node(m.get_match_root(), new_transpose); } return true; diff --git a/inference-engine/tests/functional/inference_engine/transformations/transpose_sinking_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/transpose_sinking_test.cpp index 9d0c6fdf715dea..64537727295ae1 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/transpose_sinking_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/transpose_sinking_test.cpp @@ -239,10 +239,13 @@ TEST(TransformationTests, TransposeFuses) { auto input = std::make_shared(ngraph::element::f32, ngraph::Shape{ 1, 2, 640, 20, 2, 2 }); auto tr1_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 5, 1, 2, 3, 4 }); auto transpose1 = std::make_shared(input, tr1_order); + transpose1->set_friendly_name("transpose1"); auto tr2_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 1, 3, 4, 2, 5 }); auto transpose2 = std::make_shared(transpose1, tr2_order); + transpose2->set_friendly_name("transpose2"); auto add_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1 }); auto add = std::make_shared(transpose2, add_const); + add->set_friendly_name("add"); f = std::make_shared(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input }); @@ -257,12 +260,15 @@ TEST(TransformationTests, TransposeFuses) { auto input = std::make_shared(ngraph::element::f32, ngraph::Shape{ 1, 2, 640, 20, 2, 2 }); auto tr_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 5, 2, 3, 1, 4 }); auto transpose = std::make_shared(input, tr_order); + transpose->set_friendly_name("transpose2"); auto add_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1 }); auto add = std::make_shared(transpose, add_const); + add->set_friendly_name("add"); f_ref = std::make_shared(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input }); } - auto res = compare_functions(f, f_ref); - ASSERT_TRUE(res.first) << res.second; + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES); + const FunctionsComparator::Result res = func_comparator(f, f_ref); + ASSERT_TRUE(res.valid) << res.message; }