diff --git a/nnvm/python/nnvm/testing/utils.py b/nnvm/python/nnvm/testing/utils.py index d6c03fc1b745..9b228d595d6a 100644 --- a/nnvm/python/nnvm/testing/utils.py +++ b/nnvm/python/nnvm/testing/utils.py @@ -39,7 +39,8 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224), params : dict of str to NDArray The parameters. """ - image_shape = (3, 224, 224) + if image_shape is None: + image_shape = (3, 224, 224) data_shape = (batch_size,) + image_shape params = {} g = graph.create(net) diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 9fe945d1eb6c..6efde0c573cd 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -108,7 +108,7 @@ def compute_conv2d(attrs, inputs, _): assert layout == "NCHW", "only support nchw for now" assert dilation == (1, 1), "not support dilate now" if groups == 1: - out = topi.nn.conv2d_nchw(inputs[0], inputs[1], strides, padding) + out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding) elif groups == get_const_int(inputs[0].shape[1]) and groups == channels: out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding) else: @@ -128,6 +128,12 @@ def schedule_conv2d(attrs, outs, target): return topi.cuda.schedule_conv2d_nchw(outs) return topi.cuda.schedule_depthwise_conv2d_nchw(outs) # naive schedule + + if tvm.target.current_target() == tvm.target.rasp(): + if groups == 1: + return topi.rasp.schedule_conv2d(outs) + return topi.rasp.schedule_depthwise_conv2d(outs) + return tvm.create_schedule([x.op for x in outs]) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)