[NNVM] Bugfix operator fusion for residual block with layout transform (#1760)

* Bugfix operator fusion for residual block with layout transform

* add a test case

* update error message
masahi authored and tqchen committed Sep 25, 2018
1 parent 1022ad7 commit 1de91d0
Showing 3 changed files with 55 additions and 4 deletions.
nnvm/src/compiler/graph_compile.cc: 2 additions & 1 deletion
@@ -109,13 +109,14 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) {
       inputs.push_back(it->second);
     }
     // Find master idx in the subgraph.
-    int sub_master_idx = 0;
+    int sub_master_idx = -1;
     for (uint32_t i = 0; i < subidx.num_nodes(); i++) {
       if (subidx[i].source->op() == idx[master].source->op()) {
        sub_master_idx = i;
        break;
      }
    }
+    CHECK_NE(sub_master_idx, -1) << "A master node not found in the subgraph.";
    fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx);
    for (LoweredFunc f : fe.compiled_func->funcs) {
      if (!func_set.count(f.get())) {
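
The graph_compile.cc change is the standard sentinel-plus-check pattern: sub_master_idx used to default to 0, so if the master op was ever absent from the subgraph, the compiler silently lowered against node 0 instead of failing. A minimal sketch of the same pattern in plain Python (hypothetical names, not the NNVM API):

def find_sub_master_idx(subgraph_ops, master_op):
    """Index of master_op in subgraph_ops; fail loudly if absent."""
    sub_master_idx = -1  # sentinel: not found yet
    for i, op in enumerate(subgraph_ops):
        if op == master_op:
            sub_master_idx = i
            break
    # Mirrors the CHECK_NE above: with the old default of 0, a missing
    # master op silently mapped to the subgraph's first node.
    assert sub_master_idx != -1, "A master node not found in the subgraph."
    return sub_master_idx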
nnvm/src/compiler/graph_fuse.cc: 14 additions & 3 deletions
@@ -136,11 +136,15 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
 
   // Point to the group root id of each node.
   GroupVec group_vec(idx.num_nodes(), -1);
+  std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
   for (uint32_t i = idx.num_nodes(); i != 0; --i) {
     uint32_t nid = i - 1;
     const auto& inode = idx[nid];
+    bool is_root = false;
     if (group_vec[nid] == -1) {
       group_vec[nid] = nid;
+      node_ids_per_group[nid].push_back(nid);
+      is_root = true;
     }
 
     // Check if injective op and out_ewise_fusable op (e.g. conv2d) are in the same group.
@@ -156,7 +160,15 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
       }
     }
     // Change the master node from out_ewise_fusable op to itself
-    if (parent_injective && parent_out_ewise) master_vec[nid] = nid;
+    if (parent_injective && parent_out_ewise) {
+      master_vec[nid] = nid;
+      if (!is_root) {
+        // Children nodes in the same group might be pointing to a master node in a different group.
+        for (uint32_t j : node_ids_per_group[group_vec[nid]]) {
+          master_vec[j] = nid;
+        }
+      }
+    }
 
     // Propagate the group id.
     for (const auto& e : inode.inputs) {
@@ -172,6 +184,7 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
         CHECK(group_vec[e.node_id] == -1 ||
               group_vec[e.node_id] == group_vec[nid]);
         group_vec[e.node_id] = group_vec[nid];
+        node_ids_per_group[group_vec[nid]].push_back(e.node_id);
       }
     }
   }
@@ -223,12 +236,10 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
    */
   if (opt_level >= 1) {
     std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
-    std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
     for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
       const auto& inode = idx[nid];
       if (inode.source->is_variable()) continue;
       CHECK_NE(group_vec[nid], -1);
-      node_ids_per_group[group_vec[nid]].push_back(nid);
       if (inode.inputs.size() != 1) continue;
       const uint32_t parent_nid = inode.inputs[0].node_id;
       // if parent node has more than one child, record each child's group id.
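
The graph_fuse.cc change fixes the root cause. Nodes are visited in reverse topological order, so by the time an injective node is promoted to master of its own group, child nodes may already have been placed in that group with their master pointer still aiming at an out_ewise_fusable op (e.g. conv2d) in a different group; the new node_ids_per_group bookkeeping lets the pass repoint all of them at once. A toy model in plain Python (hypothetical node ids, not the NNVM structures):

def repoint_masters(group_vec, master_vec, node_ids_per_group, nid):
    # nid becomes the master of its entire group, not only of itself.
    master_vec[nid] = nid
    for j in node_ids_per_group[group_vec[nid]]:
        master_vec[j] = nid  # fix children pointing into another group

# Nodes 2 and 3 share group 3, but their master pointers still reference
# node 1, a conv2d that was fused into a different group.
group_vec = {2: 3, 3: 3}
master_vec = {2: 1, 3: 1}
node_ids_per_group = {3: [3, 2]}
repoint_masters(group_vec, master_vec, node_ids_per_group, 3)
assert master_vec == {2: 3, 3: 3}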
nnvm/tests/python/compiler/test_op_fusion.py: 39 additions & 0 deletions
@@ -143,6 +143,44 @@ def test_concatenate_conv2d():
     np.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5)
 
 
+def test_residual_block_layout_transform():
+    ch = 16
+    size = 32
+    data = sym.Variable(name="data")
+    conv1 = sym.conv2d(data=data, kernel_size=(3, 3), channels=ch, padding=(1, 1), use_bias=False, name="conv1")
+    layout_transform1 = sym.__layout_transform__(data=conv1, src_layout="NCHW", dst_layout="NCHW8c")
+    layout_transform2 = sym.__layout_transform__(data=layout_transform1, src_layout="NCHW8c", dst_layout="NCHW")
+    conv2 = sym.conv2d(data=conv1, kernel_size=(3, 3), channels=ch, padding=(1, 1), use_bias=False, name="conv2")
+    elemwise_sum = sym.elemwise_add(layout_transform2, conv2)
+    out = sym.relu(elemwise_sum)
+
+    dtype = "float32"
+    dshape = (1, ch, size, size)
+    kshape = (ch, ch, 3, 3)
+    oshape = (1, ch, size, size)
+    shape_dict = {"data": dshape}
+
+    target = "llvm"  # test on llvm only, since the graph involves the NCHW8c layout
+    ctx = tvm.context(target, 0)
+    graph, lib, _ = nnvm.compiler.build(out, target, shape_dict)
+    # 6 nodes: data, conv1 weight, conv1, fused (layout transforms + elemwise add + relu), conv2 weight, conv2
+    assert graph.index.num_nodes == 6
+
+    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+    kernel1 = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+    kernel2 = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+    m = graph_runtime.create(graph, lib, ctx)
+    m.run(data=data, conv1_weight=kernel1, conv2_weight=kernel2)
+    out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+
+    conv1 = topi.testing.conv2d_nchw_python(
+        data.asnumpy(), kernel1.asnumpy(), (1, 1), 'SAME')
+    conv2 = topi.testing.conv2d_nchw_python(
+        conv1, kernel2.asnumpy(), (1, 1), 'SAME')
+    ref = np.maximum(conv1 + conv2, 0)
+    np.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5)
+
+
 def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
     with nnvm.compiler.build_config(opt_level=opt_level):
         graph, lib, params = nnvm.compiler.build(sym, target, shape={"data":data.shape}, params=params)
@@ -191,3 +229,4 @@ def get_sym(out_channel):
     test_fuse_conv2d_elu()
     test_injective_conv2d()
     test_concatenate_conv2d()
+    test_residual_block_layout_transform()
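
Two things are worth spelling out about the test. Numerically, the NCHW to NCHW8c and back round trip is an identity, so the reference computation can ignore the layout transforms entirely; the real regression check is the node-count assertion, which pins both layout transforms, the elemwise_add, and the relu into a single fused op. The reference reduces to this sketch (reusing the same topi.testing helper the test imports):

import numpy as np
import topi.testing

def residual_ref(data, kernel1, kernel2):
    # Left branch: conv1 round-tripped through layout transforms (identity).
    conv1 = topi.testing.conv2d_nchw_python(data, kernel1, (1, 1), 'SAME')
    # Right branch: conv2 applied directly to conv1's output.
    conv2 = topi.testing.conv2d_nchw_python(conv1, kernel2, (1, 1), 'SAME')
    return np.maximum(conv1 + conv2, 0)  # elemwise_add followed by relu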
