Skip to content

Commit

Permalink
[NNVM][CONVOLUTION] depthwise convolution support for tensorflow.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Jun 14, 2018
1 parent 300414a commit 8dcbf3c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
13 changes: 12 additions & 1 deletion nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,14 @@ def compute_conv2d(attrs, inputs, _):
if groups == 1:
out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
# NCHW
out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
elif groups == get_const_int(inputs[0].shape[3]) and groups == channels:
# NHWC
out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding)
else:
raise ValueError("not support arbitrary group number for now")

if attrs.get_bool("use_bias"):
bias = inputs[2]
expand_axis = 1 if layout == "NCHW" else 0
Expand All @@ -112,13 +117,19 @@ def compute_conv2d(attrs, inputs, _):
def schedule_conv2d(attrs, outs, target):
    """Schedule definition of conv2d.

    Dispatches to a generic TOPI schedule based on the layout
    ("NCHW" vs "NHWC") and the group count: ``groups == 1`` is a
    regular convolution, ``groups == channels`` is a depthwise
    convolution (one filter per input channel).

    Parameters
    ----------
    attrs : attribute dict
        Operator attributes; reads "groups", "channels" and "layout".
    outs : list of tensors
        Outputs of the conv2d compute declaration to be scheduled.
    target : str
        Compilation target name (e.g. "llvm", "cuda").

    Returns
    -------
    schedule
        The generic schedule for the requested conv2d variant.

    Raises
    ------
    ValueError
        If the (groups, layout) combination has no known schedule.
    """
    groups = attrs.get_int("groups")
    channels = attrs.get_int("channels")
    layout = attrs["layout"]
    with tvm.target.create(target):
        if groups == 1 and layout == "NCHW":
            return topi.generic.schedule_conv2d_nchw(outs)
        elif groups == 1 and layout == "NHWC":
            return topi.generic.schedule_conv2d_nhwc(outs)
        elif groups == channels and layout == "NCHW":
            # Depthwise convolution, NCHW layout.
            return topi.generic.schedule_depthwise_conv2d_nchw(outs)
        elif groups == channels and layout == "NHWC":
            # Depthwise convolution, NHWC layout.
            return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
        else:
            raise ValueError("No compatible schedule")

@reg.register_alter_op_layout("conv2d")
def alter_conv2d_layout(attrs, inputs, tinfos):
Expand Down
6 changes: 3 additions & 3 deletions nnvm/src/top/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,

wshape = ConvertLayout(wshape, kOIHW, kernel_layout);

// Depthwise
// NCHW : Expects weights in CNHW - Conversion is handled in frontend.
// NHWC : Original format (HWCN)
// Depthwise Expects weights in
// NCHW : CNHW(IOHW) - Conversion is handled in frontend.
// NHWC : HWCN(HWIO) - Original format
if (param.layout == "NHWC") {
wshape[kernel_layout.indexof('I')] *= param.groups;
} else {
Expand Down
27 changes: 25 additions & 2 deletions nnvm/tests/python/compiler/test_top_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_dilated_conv2d():
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


def test_grouped_conv2d():
def test_grouped_conv2d_nchw():
x = sym.Variable("x")
y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
name="y", padding=(1,1))
Expand All @@ -80,6 +80,28 @@ def test_grouped_conv2d():
c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)

def test_grouped_conv2d_nhwc():
    """Check depthwise (groups == channels) conv2d in NHWC layout
    against the TOPI python reference implementation."""
    data_var = sym.Variable("x")
    net = sym.conv2d(data_var, channels=32, kernel_size=(3,3), groups=32,
                     name="y", padding=(1,1), layout="NHWC",
                     kernel_layout='HWIO')
    dtype = "float32"
    dshape = (1, 18, 18, 32)   # NHWC input
    kshape = (3, 3, 32, 1)     # HWIO depthwise kernel
    oshape = (1, 18, 18, 32)
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(net, target, {"x": dshape})
        module = graph_runtime.create(graph, lib, ctx)
        data_nd = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel_nd = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        bias_nd = tvm.nd.array(np.random.uniform(size=kshape[2]).astype(dtype))
        module.run(x=data_nd, y_weight=kernel_nd, y_bias=bias_nd)
        out = module.get_output(0, tvm.nd.empty(oshape, dtype))
        # Reference result: depthwise conv in NHWC plus broadcast bias.
        expected = topi.testing.depthwise_conv2d_python_nhwc(
            data_nd.asnumpy(), kernel_nd.asnumpy(), (1,1), 'SAME')
        expected = expected + bias_nd.asnumpy().reshape(1, 1, kshape[2])
        np.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-5)


def test_conv2d_transpose():
x = sym.Variable("x")
Expand Down Expand Up @@ -232,7 +254,8 @@ def test_upsampling():
if __name__ == "__main__":
test_conv2d()
test_dilated_conv2d()
test_grouped_conv2d()
test_grouped_conv2d_nchw()
test_grouped_conv2d_nhwc()
test_conv2d_transpose()
test_max_pool2d()
test_avg_pool2d()
Expand Down

0 comments on commit 8dcbf3c

Please sign in to comment.