From f8db4af6c434e72a826288be6f0fe148c04b9f8b Mon Sep 17 00:00:00 2001 From: ziheng Date: Fri, 29 Sep 2017 18:20:24 -0700 Subject: [PATCH] [OP] Conv2d and Depthwise Conv2d for Raspberry Pi (#49) * [TUTORIAL] ImageNet Inference on Raspberry Pi * Update tvm --- nnvm/python/nnvm/testing/utils.py | 3 ++- nnvm/python/nnvm/top/nn.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/nnvm/python/nnvm/testing/utils.py b/nnvm/python/nnvm/testing/utils.py index d6c03fc1b745..9b228d595d6a 100644 --- a/nnvm/python/nnvm/testing/utils.py +++ b/nnvm/python/nnvm/testing/utils.py @@ -39,7 +39,8 @@ def create_workload(net, batch_size, image_shape=(3, 224, 224), params : dict of str to NDArray The parameters. """ - image_shape = (3, 224, 224) + if image_shape is None: + image_shape = (3, 224, 224) data_shape = (batch_size,) + image_shape params = {} g = graph.create(net) diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 9fe945d1eb6c..6efde0c573cd 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -108,7 +108,7 @@ def compute_conv2d(attrs, inputs, _): assert layout == "NCHW", "only support nchw for now" assert dilation == (1, 1), "not support dilate now" if groups == 1: - out = topi.nn.conv2d_nchw(inputs[0], inputs[1], strides, padding) + out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding) elif groups == get_const_int(inputs[0].shape[1]) and groups == channels: out = topi.nn.depthwise_conv2d_nchw(inputs[0], inputs[1], strides, padding) else: @@ -128,6 +128,12 @@ def schedule_conv2d(attrs, outs, target): return topi.cuda.schedule_conv2d_nchw(outs) return topi.cuda.schedule_depthwise_conv2d_nchw(outs) # naive schedule + + if tvm.target.current_target() == tvm.target.rasp(): + if groups == 1: + return topi.rasp.schedule_conv2d(outs) + return topi.rasp.schedule_depthwise_conv2d(outs) + return tvm.create_schedule([x.op for x in outs]) reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)