diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index f65312be1a29..4999d93d1861 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -146,6 +146,7 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
       bool parent_out_ewise = false;
       bool parent_injective = false;
       for (const auto& e : inode.inputs) {
+        if (fuse_vec[e.node_id] != FuseRule::kFuseToMaster) continue;
         TOpPattern pt = pattern_vec[e.node_id];
         if (pt == kOutEWiseFusable) {
           parent_out_ewise = true;
diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py
index 5f4da3865a45..0c81ac890d55 100644
--- a/nnvm/tests/python/compiler/test_op_fusion.py
+++ b/nnvm/tests/python/compiler/test_op_fusion.py
@@ -110,6 +110,39 @@ def test_injective_conv2d():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
+def test_concatenate_conv2d():
+    """Concatenate feeding both a 1x1 conv2d and elemwise_add: check node count and output."""
+    in_ch, hw = 3, 8
+    out_ch = in_ch * 2
+    dtype = "float32"
+    dshape = (1, in_ch, hw, hw)
+    kshape = (out_ch, out_ch, 1, 1)
+    oshape = (1, out_ch, hw, hw)
+    shape_dict = {"data": dshape}
+
+    # The concatenate output is consumed by both the convolution and the add.
+    data = sym.Variable(name="data")
+    concat = sym.concatenate(data, data, axis=1)
+    conv = sym.conv2d(data=concat, kernel_size=(1, 1), channels=out_ch,
+                      use_bias=False, name="conv")
+    net = sym.elemwise_add(concat, conv)
+
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(net, target, shape_dict)
+        # Expected nodes after fusion: data, conv weight, conv op, concat.
+        assert graph.index.num_nodes == 4
+
+        data_nd = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel_nd = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        module = graph_runtime.create(graph, lib, ctx)
+        module.run(data=data_nd, conv_weight=kernel_nd)
+        out = module.get_output(0, tvm.nd.empty(oshape, dtype))
+        # NumPy reference for the same computation.
+        concat_np = np.concatenate((data_nd.asnumpy(), data_nd.asnumpy()), axis=1)
+        conv_np = topi.testing.conv2d_nchw_python(concat_np, kernel_nd.asnumpy(), (1, 1), 'SAME')
+        np.testing.assert_allclose(out.asnumpy(), concat_np + conv_np, rtol=1e-5)
+
+ def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2): with nnvm.compiler.build_config(opt_level=opt_level): graph, lib, params = nnvm.compiler.build(sym, target, shape={"data":data.shape}, params=params) @@ -157,3 +190,4 @@ def get_sym(out_channel): test_conv_ewise_injective() test_fuse_conv2d_elu() test_injective_conv2d() + test_concatenate_conv2d()