From 15ecf4971d3d5222067d24b430705c5f833c5f9c Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 2 Jul 2020 09:18:38 +0200 Subject: [PATCH] [IE CLDNN] Don't force expected reorder layout & improve i64->i32 fallback (#1088) --- .../graph_optimizer/add_required_reorders.cpp | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/inference-engine/thirdparty/clDNN/src/graph_optimizer/add_required_reorders.cpp b/inference-engine/thirdparty/clDNN/src/graph_optimizer/add_required_reorders.cpp index 7a6b50df19605a..f9e2e47d12a47b 100644 --- a/inference-engine/thirdparty/clDNN/src/graph_optimizer/add_required_reorders.cpp +++ b/inference-engine/thirdparty/clDNN/src/graph_optimizer/add_required_reorders.cpp @@ -44,9 +44,6 @@ void add_required_reorders::add_reorder(program_impl& p, program_node* node, pro auto new_reorder = std::make_shared(node->id() + "_reorder_" + usr->id(), node->id(), reorder_layout); auto& new_reorder_node = p.get_or_create(new_reorder); - // make sure that new_reorder_node has correct layout - new_reorder_node.set_output_layout(reorder_layout, false); - // ToDo: add a method to program_impl class which adds an intermediate node given a node and its user auto it = std::find(usr->get_dependencies().begin(), usr->get_dependencies().end(), node); if (it == usr->get_dependencies().end()) { @@ -98,7 +95,6 @@ void add_required_reorders::run(program_impl& p) { usr->set_output_layout(current_layout, false); if (usr->type()->does_possible_implementation_exist(p.get_engine(), *usr)) { correct_layout_selected = true; - break; } else { current_layout = original_layout; current_layout.data_type = data_types::i32; @@ -106,9 +102,27 @@ void add_required_reorders::run(program_impl& p) { usr->set_output_layout(current_layout, false); if (usr->type()->does_possible_implementation_exist(p.get_engine(), *usr)) { correct_layout_selected = true; - break; } } + + if (correct_layout_selected) { + // change output_data_type field in usr to i32 + if ((static_cast(usr->get_primitive()->output_data_type) == true) && + (*(usr->get_primitive()->output_data_type) == data_types::i64)) { + std::const_pointer_cast(usr->get_primitive())->output_data_type = data_types::i32; + } + // add reorders between usr int32 output and inputs of its users + auto next_usr_itr = usr->get_users().begin(); + while (next_usr_itr != usr->get_users().end()) { + auto next_usr = *next_usr_itr++; + if (!next_usr->is_type()) { + if ((next_usr->get_output_layout() != usr->get_output_layout())) { + add_reorder(p, usr, next_usr); + } + } + } + break; + } } } @@ -185,7 +199,13 @@ void add_required_reorders::run(program_impl& p) { " kernel which satisfies output format dependecies."); } - // add reorders between usr int32 outputs and inputs of its users + // change output_data_type field in usr to i32 + if ((static_cast(usr->get_primitive()->output_data_type) == true) && + (*(usr->get_primitive()->output_data_type) == data_types::i64)) { + std::const_pointer_cast(usr->get_primitive())->output_data_type = data_types::i32; + } + + // add reorders between usr int32 output and inputs of its users auto next_usr_itr = usr->get_users().begin(); while (next_usr_itr != usr->get_users().end()) { auto next_usr = *next_usr_itr++;