From 1de91d07539e2c62d00a405ceccd8cd620872316 Mon Sep 17 00:00:00 2001
From: masahi
Date: Tue, 25 Sep 2018 12:10:11 +0900
Subject: [PATCH] [NNVM] Bugfix operator fusion for residual block with layout transform (#1760)

When an injective op whose parents include both an injective op and an
out_ewise_fusable op (e.g. conv2d) becomes the master node of its fusion
group, nodes already assigned to that group could still point at a master
node in a different group; GraphCompile then failed to find the master
inside the fused subgraph. Fix this by recording the members of each group
and re-pointing them when the master changes.

* Bugfix operator fusion for residual block with layout transform

* add a test case

* update error message
---
 nnvm/src/compiler/graph_compile.cc           |  3 +-
 nnvm/src/compiler/graph_fuse.cc              | 17 +++++++--
 nnvm/tests/python/compiler/test_op_fusion.py | 39 ++++++++++++++++++++
 3 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/nnvm/src/compiler/graph_compile.cc b/nnvm/src/compiler/graph_compile.cc
index e51730c09d66..3316f3932e27 100644
--- a/nnvm/src/compiler/graph_compile.cc
+++ b/nnvm/src/compiler/graph_compile.cc
@@ -109,13 +109,14 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) {
       inputs.push_back(it->second);
     }
     // Find master idx in the subgraph.
-    int sub_master_idx = 0;
+    int sub_master_idx = -1;
     for (uint32_t i = 0; i < subidx.num_nodes(); i++) {
       if (subidx[i].source->op() == idx[master].source->op()) {
         sub_master_idx = i;
         break;
       }
     }
+    CHECK_NE(sub_master_idx, -1) << "Master node not found in the subgraph.";
     fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx);
     for (LoweredFunc f : fe.compiled_func->funcs) {
       if (!func_set.count(f.get())) {
diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index c9ea58affb2c..4d724ae66c35 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -136,11 +136,15 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {

   // Point to the group root id of each node.
   GroupVec group_vec(idx.num_nodes(), -1);
+  std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
   for (uint32_t i = idx.num_nodes(); i != 0; --i) {
     uint32_t nid = i - 1;
     const auto& inode = idx[nid];
+    bool is_root = false;
     if (group_vec[nid] == -1) {
       group_vec[nid] = nid;
+      node_ids_per_group[nid].push_back(nid);
+      is_root = true;
     }

     // Check if injective op and out_ewise_fusable op (e.g. conv2d) are in the same group.
@@ -156,7 +160,15 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
       }
     }
     // Change the master node from out_ewise_fusable op to itself
-    if (parent_injective && parent_out_ewise) master_vec[nid] = nid;
+    if (parent_injective && parent_out_ewise) {
+      master_vec[nid] = nid;
+      if (!is_root) {
+        // Child nodes in the same group might be pointing to a master node in a different group.
+        for (uint32_t j : node_ids_per_group[group_vec[nid]]) {
+          master_vec[j] = nid;
+        }
+      }
+    }

     // Propagate the group id.
     for (const auto& e : inode.inputs) {
@@ -172,6 +184,7 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
         CHECK(group_vec[e.node_id] == -1||
               group_vec[e.node_id] == group_vec[nid]);
         group_vec[e.node_id] = group_vec[nid];
+        node_ids_per_group[group_vec[nid]].push_back(e.node_id);
       }
     }
   }
@@ -223,12 +236,10 @@
    */
   if (opt_level >= 1) {
     std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
-    std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
     for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
       const auto& inode = idx[nid];
       if (inode.source->is_variable()) continue;
       CHECK_NE(group_vec[nid], -1);
-      node_ids_per_group[group_vec[nid]].push_back(nid);
       if (inode.inputs.size() != 1) continue;
       const uint32_t parent_nid = inode.inputs[0].node_id;
       // if parent node has more than one child, record each child's group id.
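
[Reviewer note, not part of the applied patch] A minimal Python sketch of
the bookkeeping the graph_fuse.cc change performs. The helper name and the
standalone lists are hypothetical stand-ins for NNVM's group_vec,
master_vec, and node_ids_per_group vectors:

    def promote_to_master(nid, is_root, group_vec, master_vec,
                          node_ids_per_group):
        """Make node `nid` the master of its fusion group.

        Before the fix, only master_vec[nid] was updated; nodes already
        recorded in the group could keep a master id belonging to a
        different group, and GraphCompile then failed to locate the
        master inside the fused subgraph.
        """
        master_vec[nid] = nid
        if not is_root:
            # nid joined an existing group: re-point the members seen so far.
            for j in node_ids_per_group[group_vec[nid]]:
                master_vec[j] = nid
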
diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index 0c81ac890d55..288f112f1063 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -143,6 +143,44 @@ def test_concatenate_conv2d():
     np.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5)


+def test_residual_block_layout_transform():
+    ch = 16
+    size = 32
+    data = sym.Variable(name="data")
+    conv1 = sym.conv2d(data=data, kernel_size=(3, 3), channels=ch, padding=(1, 1), use_bias=False, name="conv1")
+    layout_transform1 = sym.__layout_transform__(data=conv1, src_layout="NCHW", dst_layout="NCHW8c")
+    layout_transform2 = sym.__layout_transform__(data=layout_transform1, src_layout="NCHW8c", dst_layout="NCHW")
+    conv2 = sym.conv2d(data=conv1, kernel_size=(3, 3), channels=ch, padding=(1, 1), use_bias=False, name="conv2")
+    elemwise_sum = sym.elemwise_add(layout_transform2, conv2)
+    out = sym.relu(elemwise_sum)
+
+    dtype = "float32"
+    dshape = (1, ch, size, size)
+    kshape = (ch, ch, 3, 3)
+    oshape = (1, ch, size, size)
+    shape_dict = {"data": dshape}
+
+    target = "llvm"  # only test on llvm since it involves the NCHW8c layout
+    ctx = tvm.context(target, 0)
+    graph, lib, _ = nnvm.compiler.build(out, target, shape_dict)
+    # Expected nodes: data, conv1 weight, conv1, fused (layout transforms + elemwise add + relu), conv2 weight, conv2
+    assert graph.index.num_nodes == 6
+
+    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+    kernel1 = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+    kernel2 = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+    m = graph_runtime.create(graph, lib, ctx)
+    m.run(data=data, conv1_weight=kernel1, conv2_weight=kernel2)
+    out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+
+    conv1 = topi.testing.conv2d_nchw_python(
+        data.asnumpy(), kernel1.asnumpy(), (1, 1), 'SAME')
+    conv2 = topi.testing.conv2d_nchw_python(
+        conv1, kernel2.asnumpy(), (1, 1), 'SAME')
+    ref = np.maximum(conv1 + conv2, 0)
+    np.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5)
+
+
 def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
     with nnvm.compiler.build_config(opt_level=opt_level):
         graph, lib, params = nnvm.compiler.build(sym, target, shape={"data":data.shape}, params=params)
@@ -191,3 +229,4 @@ def get_sym(out_channel):
     test_fuse_conv2d_elu()
     test_injective_conv2d()
     test_concatenate_conv2d()
+    test_residual_block_layout_transform()
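
[Reviewer note, not part of the applied patch] To exercise the new
regression test end to end, assuming a built TVM/NNVM checkout with the
nnvm, tvm, and topi Python packages importable and the repository root as
the working directory, one option is:

    # Executes the test module's __main__ block, which calls every test,
    # including test_residual_block_layout_transform added above.
    import runpy
    runpy.run_path("nnvm/tests/python/compiler/test_op_fusion.py",
                   run_name="__main__")

Running `python nnvm/tests/python/compiler/test_op_fusion.py` from a shell
is equivalent.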