diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index 68c40d33dfd7..077f5dd21850 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -554,6 +554,7 @@ void FindOutputEntries(nnvm::Graph* g, */ void CutGraphInputs(const std::vector &input_entries, std::vector *orig_entries, + std::vector *all_entries, const bool skip_var = false) { // map for creating unique var nodes for deduplicating entries from the same node std::unordered_map name_count_map; @@ -584,7 +585,9 @@ void CutGraphInputs(const std::vector &input_entries, // store the node in the map name_count_map.emplace(var_name, e_); - } + } + all_entries->push_back(*e); + // lookup the name of the node and set it as the input dependency *e = name_count_map[var_name]; } @@ -618,8 +621,11 @@ void CreateSubgraphNode(nnvm::Graph* g, #endif std::vector input_entries; FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries); + // deduplicated array of inputs to connect to subgraph std::vector orig_input_entries; - CutGraphInputs(input_entries, &orig_input_entries, false); + // all original input connections, used to reattach subgraph inputs + std::vector all_input_entries; + CutGraphInputs(input_entries, &orig_input_entries, &all_input_entries, false); #if DEBUG_SUBGRAPH PrintNodeEntries(input_entries); LOG(INFO) << "Searching for output entries..."; @@ -661,7 +667,7 @@ void CreateSubgraphNode(nnvm::Graph* g, } } } else { - ReattachGraphInputs(input_entries, &orig_input_entries); + ReattachGraphInputs(input_entries, &all_input_entries); } #if DEBUG_SUBGRAPH if (n)