[NNVM] Bug fix: Prevent fusing convolution with injective op (#1608)
masahi authored and tqchen committed Aug 17, 2018
1 parent acc2151 commit 6cd5a8f
Showing 2 changed files with 64 additions and 1 deletion.
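
In short, the fix targets graphs where a convolution (pattern kOutEWiseFusable) and an injective op (e.g. a reshape of a pooled tensor) feed the same elementwise expression, which the fusion pass could previously place in a single group. A minimal Python sketch of the triggering pattern, mirroring the new test added below (variable names and shapes are illustrative):

    from nnvm import symbol as sym

    channels = 16
    data = sym.Variable(name="data")
    # The pooled tensor goes through a reshape, an injective op.
    pool = sym.global_avg_pool2d(data=data)
    scale = sym.reshape(pool, shape=[1, channels, 1, 1])
    # The convolution is an out-ewise-fusable op.
    conv = sym.conv2d(data=data, kernel_size=(3, 3), channels=channels, padding=(1, 1),
                      layout="NCHW", kernel_layout="OIHW", use_bias=False, name="conv")
    # Both branches meet in the same elementwise expression.
    net = scale * data + conv
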
31 changes: 30 additions & 1 deletion nnvm/src/compiler/graph_fuse.cc
@@ -63,12 +63,16 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
     // Check if we can fuse to the master.
     int chosen_master = -1;
     bool ewise = inode.source->num_outputs() == 1;
+    bool mark_as_injective = false;
     for (const auto& e : inode.inputs) {
       if (fuse_vec[e.node_id] == FuseRule::kUknown) {
         TOpPattern ipt = pattern_vec[e.node_id];
         if (ipt != kElemWise) ewise = false;
-        if (ipt <= kInjective) {
+        if (ipt <= kBroadcast) {
           fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
+        } else if (ipt == kInjective) {
+          fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
+          mark_as_injective = true;
         } else if (ipt == kOutEWiseFusable &&
                    chosen_master == -1 &&
                    shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
@@ -87,6 +91,8 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
     master_vec[nid] = chosen_master;
     if (chosen_master != -1) {
       pt = kOutEWiseFusable;
+    } else if (mark_as_injective) {
+      pt = kInjective;
     } else {
       pt = ewise ? kElemWise : kBroadcast;
     }
@@ -135,8 +141,31 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
     if (group_vec[nid] == -1) {
       group_vec[nid] = nid;
     }
+
+    // Check if injective op and out_ewise_fusable op (e.g. conv2d) are in the same group.
+    bool parent_out_ewise = false;
+    bool parent_injective = false;
+    for (const auto& e : inode.inputs) {
+      TOpPattern pt = pattern_vec[e.node_id];
+      if (pt == kOutEWiseFusable) {
+        parent_out_ewise = true;
+      } else if (pt == kInjective) {
+        parent_injective = true;
+      }
+    }
+    // Change the master node from out_ewise_fusable op to itself
+    if (parent_injective && parent_out_ewise) master_vec[nid] = nid;
+
     // Propagate the group id.
     for (const auto& e : inode.inputs) {
+      TOpPattern pt = pattern_vec[e.node_id];
+      if (parent_out_ewise && parent_injective) {
+        if (pt == kOutEWiseFusable) {
+          continue; // Do not fuse out_ewise_fusable op
+        } else if (pt == kInjective) {
+          master_vec[e.node_id] = nid;
+        }
+      }
       if (fuse_vec[e.node_id] == FuseRule::kFuseToMaster) {
         CHECK(group_vec[e.node_id] == -1||
               group_vec[e.node_id] == group_vec[nid]);
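
The graph_fuse.cc change has two parts. In GraphFindFusibleGroups, a node that fuses an injective input is now re-marked as kInjective (via mark_as_injective) instead of falling back to kElemWise/kBroadcast. Then, during grouping, a node whose parents include both a kOutEWiseFusable op and a kInjective op becomes its own master: the kOutEWiseFusable parent is not pulled into the group, and the injective parent's master is redirected to the node. The effect can be observed from Python much as the new test below does; a sketch, continuing from the `net` and `channels` names in the sketch above and assuming an "llvm" target is available:

    import nnvm.compiler

    dshape = (1, channels, 56, 56)
    graph, lib, _ = nnvm.compiler.build(net, "llvm", {"data": dshape})
    # With the fix, the conv2d keeps its own fusion group; in the test's
    # configuration the compiled graph retains 5 nodes: the data input, the
    # global_avg_pool/reshape chain, the conv weight, the conv2d, and the
    # fused elementwise combination.
    print(graph.index.num_nodes)
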
34 changes: 34 additions & 0 deletions nnvm/tests/python/compiler/test_op_fusion.py
@@ -77,6 +77,39 @@ def test_injective_reduce_injective():
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
+def test_injective_conv2d():
+    channels = 16
+    data = sym.Variable(name="data")
+    pool = sym.global_avg_pool2d(data=data)
+    weight = sym.reshape(pool, shape=[1, channels, 1, 1])
+    residual = sym.conv2d(data=data, kernel_size=(3,3), channels=channels, padding=(1, 1),
+                          layout="NCHW", kernel_layout="OIHW", use_bias=False, name="conv")
+    net = weight * data + residual
+    size = 56
+    dtype="float32"
+    dshape = (1, channels, size, size)
+    kshape = (channels, channels, 3, 3)
+    oshape = dshape
+    shape_dict = {"data": dshape}
+
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(net, target, shape_dict)
+        # data, global_avg_pool, conv weight, conv op, fused elemwise add
+        assert graph.index.num_nodes == 5
+
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(data=data, conv_weight=kernel)
+        # get output
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        residual = topi.testing.conv2d_nchw_python(
+            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        weight = np.mean(data.asnumpy(), axis=(2, 3))
+        c_np = weight[:, :, np.newaxis, np.newaxis] * data.asnumpy() + residual
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
+
 def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
     with nnvm.compiler.build_config(opt_level=opt_level):
         graph, lib, params = nnvm.compiler.build(sym, target, shape={"data":data.shape}, params=params)
@@ -123,3 +156,4 @@ def get_sym(out_channel):
     test_ewise_injective()
     test_conv_ewise_injective()
     test_fuse_conv2d_elu()
+    test_injective_conv2d()
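
As a usage note, the kernels that fusion produced can also be listed from the compiled graph; a sketch, assuming `graph` comes from an nnvm.compiler.build call such as the one above and that graph.json() yields the usual graph-runtime layout with per-node "op" and "name" fields:

    import json

    nodes = json.loads(graph.json())["nodes"]
    for node in nodes:
        # Variable inputs typically appear with op "null"; fused kernels with op "tvm_op".
        print(node["op"], node["name"])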
