diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index dfebc368eaf3f..1689a82428ef1 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -98,9 +98,14 @@ def compute_conv2d(attrs, inputs, _):
     if groups == 1:
         out = topi.nn.conv2d(inputs[0], kernel, strides, padding, layout)
     elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
+        # NCHW
         out = topi.nn.depthwise_conv2d_nchw(inputs[0], kernel, strides, padding)
+    elif groups == get_const_int(inputs[0].shape[3]) and groups == channels:
+        # NHWC
+        out = topi.nn.depthwise_conv2d_nhwc(inputs[0], kernel, strides, padding)
     else:
         raise ValueError("not support arbitrary group number for now")
+
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
         expand_axis = 1 if layout == "NCHW" else 0
@@ -112,13 +117,19 @@ def compute_conv2d(attrs, inputs, _):
 def schedule_conv2d(attrs, outs, target):
     """Schedule definition of conv2d"""
     groups = attrs.get_int("groups")
+    channels = attrs.get_int("channels")
     layout = attrs["layout"]
     with tvm.target.create(target):
         if groups == 1 and layout == "NCHW":
             return topi.generic.schedule_conv2d_nchw(outs)
         elif groups == 1 and layout == "NHWC":
             return topi.generic.schedule_conv2d_nhwc(outs)
-        return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NCHW":
+            return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+        elif groups == channels and layout == "NHWC":
+            return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+        else:
+            raise ValueError("No compatible schedule")
 
 @reg.register_alter_op_layout("conv2d")
 def alter_conv2d_layout(attrs, inputs, tinfos):
diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc
index d6f855fb98780..e6576baf893b1 100644
--- a/nnvm/src/top/nn/convolution.cc
+++ b/nnvm/src/top/nn/convolution.cc
@@ -80,9 +80,9 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
 
   wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
 
-  // Depthwise
-  // NCHW : Expects weights in CNHW - Conversion is handled in frontend.
-  // NHWC : Original format (HWCN)
+  // Depthwise expects weights in
+  // NCHW : CNHW(IOHW) - Conversion is handled in frontend.
+  // NHWC : HWCN(HWIO) - Original format
   if (param.layout == "NHWC") {
     wshape[kernel_layout.indexof('I')] *= param.groups;
   } else {
diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py
index f752b7d442718..76cf96b3eb836 100644
--- a/nnvm/tests/python/compiler/test_top_level2.py
+++ b/nnvm/tests/python/compiler/test_top_level2.py
@@ -58,7 +58,7 @@ def test_dilated_conv2d():
     np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
-def test_grouped_conv2d():
+def test_grouped_conv2d_nchw():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
                    name="y", padding=(1,1))
@@ -80,6 +80,28 @@ def test_grouped_conv2d():
         c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
+def test_grouped_conv2d_nhwc():
+    x = sym.Variable("x")
+    y = sym.conv2d(x, channels=32, kernel_size=(3,3), groups=32,
+                   name="y", padding=(1,1), layout="NHWC", kernel_layout='HWIO')
+    dtype = "float32"
+    dshape = (1, 18, 18, 32)
+    kshape = (3, 3, 32, 1)
+    oshape = (1, 18, 18, 32)
+    shape_dict = {"x": dshape}
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = graph_runtime.create(graph, lib, ctx)
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        bias = tvm.nd.array(np.random.uniform(size=kshape[2]).astype(dtype))
+        m.run(x=data, y_weight=kernel, y_bias=bias)
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        c_np = topi.testing.depthwise_conv2d_python_nhwc(
+            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        c_np = c_np + bias.asnumpy().reshape(1, 1, kshape[2])
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
 
 def test_conv2d_transpose():
     x = sym.Variable("x")
@@ -232,7 +254,8 @@ def test_upsampling():
 if __name__ == "__main__":
     test_conv2d()
     test_dilated_conv2d()
-    test_grouped_conv2d()
+    test_grouped_conv2d_nchw()
+    test_grouped_conv2d_nhwc()
     test_conv2d_transpose()
     test_max_pool2d()
     test_avg_pool2d()
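For reference, the NHWC depthwise path exercised by test_grouped_conv2d_nhwc treats the HWIO kernel of shape (KH, KW, C, 1) as one (KH, KW) filter per input channel. The sketch below is an illustrative NumPy reference under those assumptions (stride 1, 'SAME' padding, channel multiplier 1; the helper name depthwise_conv2d_nhwc_ref is hypothetical and this is not the TOPI implementation or topi.testing.depthwise_conv2d_python_nhwc itself):

import numpy as np

def depthwise_conv2d_nhwc_ref(data, kernel):
    # data: NHWC input; kernel: HWIO with I == C and channel multiplier O == 1 (assumed)
    n, h, w, c = data.shape
    kh, kw, kc, multiplier = kernel.shape
    assert kc == c and multiplier == 1
    ph, pw = kh // 2, kw // 2                      # 'SAME' padding for stride 1
    padded = np.zeros((n, h + 2 * ph, w + 2 * pw, c), dtype=data.dtype)
    padded[:, ph:ph + h, pw:pw + w, :] = data
    out = np.zeros_like(data)
    for i in range(kh):
        for j in range(kw):
            # every input channel is filtered only by its own (kh, kw) kernel slice
            out += padded[:, i:i + h, j:j + w, :] * kernel[i, j, :, 0]
    return out

data = np.random.uniform(size=(1, 18, 18, 32)).astype("float32")   # dshape from the test
kernel = np.random.uniform(size=(3, 3, 32, 1)).astype("float32")   # kshape from the test
print(depthwise_conv2d_nhwc_ref(data, kernel).shape)               # (1, 18, 18, 32)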