Skip to content

Commit

Permalink
Fix node name issue introduced by openvinotoolkit#5854 (openvinotoolk…
Browse files Browse the repository at this point in the history
…it#6709)

* Fix node name issue introduced by openvinotoolkit#5854

* Compare names in TransposeFuse tests
  • Loading branch information
mvafin authored and rnugmanx committed Aug 26, 2021
1 parent e45194f commit b3d14fd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,9 @@ ngraph::pass::TransposeFuse::TransposeFuse() {
auto new_order = ngraph::opset7::Constant::create(element::i64, {order2.size()}, order2);
auto new_transpose = register_new_node<ngraph::opset7::Transpose>(input, new_order);

new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ transpose1, transpose2 }, new_transpose);
ngraph::replace_node(transpose2, new_transpose);
ngraph::replace_node(m.get_match_root(), new_transpose);
}

return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,13 @@ TEST(TransformationTests, TransposeFuses) {
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 640, 20, 2, 2 });
auto tr1_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 5, 1, 2, 3, 4 });
auto transpose1 = std::make_shared<ngraph::opset6::Transpose>(input, tr1_order);
transpose1->set_friendly_name("transpose1");
auto tr2_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 1, 3, 4, 2, 5 });
auto transpose2 = std::make_shared<ngraph::opset6::Transpose>(transpose1, tr2_order);
transpose2->set_friendly_name("transpose2");
auto add_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1 });
auto add = std::make_shared<ngraph::opset6::Add>(transpose2, add_const);
add->set_friendly_name("add");

f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });

Expand All @@ -257,12 +260,15 @@ TEST(TransformationTests, TransposeFuses) {
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 640, 20, 2, 2 });
auto tr_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 5, 2, 3, 1, 4 });
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, tr_order);
transpose->set_friendly_name("transpose2");
auto add_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1 });
auto add = std::make_shared<ngraph::opset6::Add>(transpose, add_const);
add->set_friendly_name("add");

f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
}

auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
const FunctionsComparator::Result res = func_comparator(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

0 comments on commit b3d14fd

Please sign in to comment.