diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc
index 38038f2a4618..72dd2da90bf3 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -537,33 +537,40 @@ void FindOutputEntries(nnvm::Graph* g,
  */
 void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
                     std::vector<nnvm::NodeEntry> *orig_entries,
+                    std::vector<nnvm::NodeEntry> *unique_orig_entries,
+                    std::vector<nnvm::NodeEntry*> *unique_input_entries,
                     const bool skip_var = false) {
   orig_entries->resize(input_entries.size());
   // map for creating unique var nodes for deduplicating entries from the same node
-  std::unordered_map<std::string, int> name_count_map;
+  std::unordered_map<std::string, nnvm::NodeEntry> name_count_map;
   for (size_t i = 0; i < input_entries.size(); ++i) {
     nnvm::NodeEntry *e = input_entries[i];
     // If the node is a variable itself, we may want to skip the node.
     if (e->node->is_variable() && skip_var) {
       continue;
     }
-
+    // save all original entries
     orig_entries->at(i) = *e;
+    // get unique name for this entry
     nnvm::Symbol sym;
     sym.outputs.push_back(*e);
     const auto output_names = sym.ListOutputNames();
     CHECK_EQ(output_names.size(), 1U);
     const std::string& var_name = output_names[0];
+    // check if this entry is a duplicate
     auto it = name_count_map.find(var_name);
     if (name_count_map.end() == it) {
-      name_count_map.emplace(var_name, 0);
+      // first use of this node as input to subgraph
+      unique_orig_entries->push_back(*e);
+      unique_input_entries->push_back(e);
+      nnvm::ObjectPtr n = nnvm::CreateVariableNode(var_name + std::to_string(0));
+      *e = nnvm::NodeEntry{n, 0, 0};
+      // store node for re-use
+      name_count_map.emplace(var_name, *e);
     } else {
-      ++(it->second);
+      // other use of same node as input to subgraph
+      *e = it->second;
     }
-    nnvm::ObjectPtr n = nnvm::CreateVariableNode(
-        var_name + std::to_string(name_count_map[var_name]));
-
-    *e = nnvm::NodeEntry{n, 0, 0};
   }
 }
 
@@ -593,10 +600,13 @@ void CreateSubgraphNode(nnvm::Graph* g,
 #if DEBUG_SUBGRAPH
   LOG(INFO) << "Searching for input entries...";
 #endif
-  std::vector<nnvm::NodeEntry*> input_entries;
+  std::vector<nnvm::NodeEntry*> input_entries;  // nodes that produce inputs to subgraph nodes
   FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
-  std::vector<nnvm::NodeEntry> orig_input_entries;
-  CutGraphInputs(input_entries, &orig_input_entries, false);
+  std::vector<nnvm::NodeEntry> orig_input_entries;     // original input entries (dupes)
+  std::vector<nnvm::NodeEntry> unique_orig_entries;    // unique original input entries
+  std::vector<nnvm::NodeEntry*> unique_input_entries;  // unique modified subgraph inputs
+  CutGraphInputs(input_entries, &orig_input_entries, &unique_orig_entries,
+                 &unique_input_entries, false);
 #if DEBUG_SUBGRAPH
   PrintNodeEntries(input_entries);
   LOG(INFO) << "Searching for output entries...";
@@ -605,20 +615,31 @@ void CreateSubgraphNode(nnvm::Graph* g,
   FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
 
   // Create a subgraph for the subgraph node
+  // entries are in topological order, with duplicates being neighbors
   nnvm::Symbol sym;
+  size_t idx = 0;
+  nnvm::NodeEntryEqual node_equal;
   sym.outputs.resize(output_entries.size());
   for (size_t i = 0; i < output_entries.size(); ++i) {
-    sym.outputs[i] = *output_entries[i];
+    if (i == 0) {  // add first entry
+      sym.outputs[idx] = *output_entries[i];
+    } else if (!node_equal(sym.outputs[idx], *output_entries[i])) {  // compare to see if diff
+      // add new entries
+      idx++;
+      sym.outputs[idx] = *output_entries[i];
+    }  // else skip over dupe entries
   }
+  sym.outputs.resize(idx+1);
+
   const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
-  subg_prop->InitSubgraphInputs(&input_entries, &orig_input_entries);
+  subg_prop->InitSubgraphInputs(&unique_input_entries, &unique_orig_entries);
   nnvm::ObjectPtr n = subg_prop->CreateSubgraphNode(sym, subgraph_selector, subgraph_id);
   // CreateSubgraphNode returns NULL if subgraph property determines that subgraph is sub-optimal
   // In that case, subgraph node is not created and graph is not modified
   if (n) {
     // Connect the external nodes to the subgraph node.
     subg_prop->ConnectSubgraphOutputs(n, &output_entries);
-    subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
+    subg_prop->ConnectSubgraphInputs(n, &unique_input_entries, &unique_orig_entries);
 
     const auto& indexed_graph = g->indexed_graph();
     for (size_t i = 0; i < n->inputs.size(); ++i) {
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
index 7fadfca2ea97..ae3075c2d080 100644
--- a/src/operator/subgraph/subgraph_property.h
+++ b/src/operator/subgraph/subgraph_property.h
@@ -342,8 +342,18 @@ class SubgraphProperty {
    */
   virtual void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
                                       std::vector<nnvm::NodeEntry*>* output_entries) const {
+    // Collapse output_entries pointing to same NodeEntry
+    // Outputs are ordered, duplicates are neighbors
+    nnvm::NodeEntryEqual node_equal;
+    nnvm::NodeEntry prevNodeEntry;
+    uint32_t idx = 0;
     for (size_t i = 0; i < output_entries->size(); ++i) {
-      *output_entries->at(i) = nnvm::NodeEntry{subgraph_node, static_cast<uint32_t>(i), 0};
+      // increment the output idx for each unique output of the subgraph
+      if (i != 0 && !node_equal(prevNodeEntry, *output_entries->at(i)))
+        idx++;
+      prevNodeEntry = *output_entries->at(i);  // make a copy so we can compare before modifying
+      // change output entry to point to subgraph instead of original node
+      *output_entries->at(i) = nnvm::NodeEntry{subgraph_node, idx, 0};
     }
   }
   /*!
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index 9771a18618d8..2974838f3838 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -87,6 +87,18 @@ def network_structure_7():
     ret = ret1 + ret2
     return (ret, ['data'], [(1,)])
 
+def network_structure_8():
+    # in this graph, two nodes in the subgraph consume the same input, and
+    # two nodes outside the subgraph consume a single output from the subgraph
+    data = mx.sym.Variable('data', shape=(1,))
+    sin1 = mx.sym.sin(data)
+    sin2 = mx.sym.sin(data)
+    plus = sin1 + sin2
+    ret1 = mx.sym.cos(plus)
+    ret2 = mx.sym.cos(plus)
+    ret = ret1 - ret2
+    return (ret, ['data'], [(1,)])
+
 def get_graphs():
     return [
             (network_structure_1(), ['Convolution']),
@@ -104,7 +116,8 @@ def get_graphs():
             (network_structure_6(), [mx.sym.sin.__name__]),
             (network_structure_6(), [mx.sym.Convolution.__name__]),
             (network_structure_6(), [mx.sym.sin.__name__, mx.sym.Convolution.__name__]),
-            (network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus'])
+            (network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus']),
+            (network_structure_8(), ['sin', 'elemwise_add'])
            ]
 
 @pytest.mark.parametrize('subgraph_backend', ['default', 'default_v2'])
@@ -158,7 +171,6 @@ def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None)
         exe.forward()
         return exe
     sym, _, _ = sym
-    original_exec = get_executor(sym)
     with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
         check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend),
                                                      mx_uint(len(op_names)),
@@ -407,7 +419,7 @@ def test_subgraph_backend_gluon(sym, subgraph_backend, op_names, tmpdir):
 # Test Gluon HybridBlocks for graph partitioning a network created by HybridSequential.
 @pytest.mark.serial
 def test_subgraph_backend_gluon_ext1(tmpdir):
-    def get_net(): 
+    def get_net():
         net = nn.HybridSequential()  # Here we use the class HybridSequential.
         net.add(nn.Dense(256, activation='relu'),
                 nn.Dense(128, activation='relu'),
@@ -476,3 +488,23 @@ def hybrid_forward(self, F, x):
     for i in range(len(outputs1)):
         assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(),
                             np.zeros(shape=(1,)))
+
+if __name__ == "__main__":
+    import datetime
+    tmpdir = datetime.datetime.now().strftime('mylogfile_%H_%M_%S_%f_%d_%m_%Y.log')
+    os.mkdir(tmpdir)
+    subgraph_backends = ['default', 'default_v2']
+    graphs = get_graphs()
+    for subgraph_backend in subgraph_backends:
+        for sym, op_names in graphs:
+            test_subgraph_exe1(sym, subgraph_backend, op_names)
+            test_subgraph_exe2(sym, subgraph_backend, op_names)
+            test_subgraph_exe3(sym, subgraph_backend, op_names)
+            test_subgraph_exe4(sym, subgraph_backend, op_names)
+            test_subgraph_exe5(sym, subgraph_backend, op_names)
+            test_subgraph_exe6(sym, subgraph_backend, op_names)
+            test_subgraph_exe7(sym, subgraph_backend, op_names)
+            test_subgraph_exe8(sym, subgraph_backend, op_names)
+            test_subgraph_backend_gluon(sym, subgraph_backend, op_names, tmpdir)
+    test_subgraph_backend_gluon_ext1(tmpdir)
+    test_subgraph_backend_gluon_ext2(tmpdir)
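
Note: the input/output deduplication above relies on one invariant, namely that entries are
topologically ordered so duplicates of the same NodeEntry sit next to each other, and each run
of equal neighbors is collapsed onto a single subgraph input/output index. As a minimal
standalone sketch of that rule (not part of the patch; collapse_adjacent_duplicates is a
hypothetical helper name used only for illustration):

    def collapse_adjacent_duplicates(entries):
        """Assign one index per run of equal adjacent entries, mirroring the
        idx bookkeeping in ConnectSubgraphOutputs and the output dedup loop."""
        indices = []
        idx = 0
        for i, entry in enumerate(entries):
            # increment the index only when a new unique entry starts
            if i != 0 and entry != entries[i - 1]:
                idx += 1
            indices.append(idx)
        return indices

    # two consumers of the same subgraph output share index 0; the next unique output gets 1
    assert collapse_adjacent_duplicates(['plus', 'plus', 'other']) == [0, 0, 1]

This is the situation network_structure_8 exercises: both sin nodes read the single 'data'
variable (one deduplicated subgraph input), and both cos nodes read the single 'plus' output
(one deduplicated subgraph output with two external consumers).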