From 0af71310bbadf73e52a1c0130b7fabb9c3ebe2a0 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 5 Jul 2018 09:12:46 -0700 Subject: [PATCH] [TOPI] Fix the CPU op perf (#56) --- vta/Makefile | 1 - vta/docs/.gitignore | 1 + .../resnet18/pynq/imagenet_predict.py | 190 ------------ vta/python/vta/top/arm_conv2d.py | 290 +----------------- vta/python/vta/top/vta_conv2d.py | 2 +- .../integration/test_benchmark_topi_conv2d.py | 112 ++++++- vta/tutorials/resnet.py | 6 +- 7 files changed, 121 insertions(+), 481 deletions(-) delete mode 100644 vta/examples/resnet18/pynq/imagenet_predict.py diff --git a/vta/Makefile b/vta/Makefile index 60e93c0bd347..c43c49bddaca 100644 --- a/vta/Makefile +++ b/vta/Makefile @@ -63,7 +63,6 @@ doc: clean: $(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o - $(RM) -rf cat.jpg quantize_graph.json quantize_params.pkl synset.txt -include build/*.d diff --git a/vta/docs/.gitignore b/vta/docs/.gitignore index 845005a76aa8..eee93ee4003a 100644 --- a/vta/docs/.gitignore +++ b/vta/docs/.gitignore @@ -2,3 +2,4 @@ doxygen modules tutorials _build +gen_modules \ No newline at end of file diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py deleted file mode 100644 index db31616003cc..000000000000 --- a/vta/examples/resnet18/pynq/imagenet_predict.py +++ /dev/null @@ -1,190 +0,0 @@ -# some standard imports -import nnvm -import tvm -import vta -import vta.testing -import os -import numpy as np -import pickle -import json -import logging - -from PIL import Image -from nnvm.compiler import graph_attr -from tvm.contrib import graph_runtime, rpc, util -from tvm.contrib.download import download - -bfactor = 1 -cfactor = 16 -verbose = False -# only run fpga component, mark non-conv ops as nop -debug_fpga_only = False - -# Obtain model files (they're too large to check-in) -# Download them into _data dir -data_dir = "_data/" -url = "https://homes.cs.washington.edu/~moreau/media/vta/" -TEST_FILE = 'cat.jpg' -CATEG_FILE = 'synset.txt' -RESNET_GRAPH_FILE = 'resnet18_qt8.json' -RESNET_PARAMS_FILE = 'resnet18_qt8.params' -# Create data dir -if not os.path.exists(data_dir): - os.makedirs(data_dir) -# Download files -for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE]: - if not os.path.isfile(file): - download(os.path.join(url, file), os.path.join(data_dir, file)) - -if verbose: - logging.basicConfig(level=logging.DEBUG) - -# Change to -device=vtacpu to run cpu only inference. -target = tvm.target.create("llvm -device=vta") -target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon" - -if vta.get_env().TARGET == "sim": - target_host = "llvm" - -synset = eval(open(os.path.join(data_dir, CATEG_FILE)).read()) -image = Image.open(os.path.join(data_dir, TEST_FILE)).resize((224, 224)) - -def transform_image(image): - image = np.array(image) - np.array([123., 117., 104.]) - image /= np.array([58.395, 57.12, 57.375]) - image = image.transpose((2, 0, 1)) - image = image[np.newaxis, :] - return image - -def mark_nop(graph, conv_layer=-1, skip_conv_layer=()): - """Helper function to mark certain op as nop - - Useful to debug performance issues. - """ - jgraph = json.loads(graph.json()) - counter = 0 - for nid, node in enumerate(jgraph["nodes"]): - op_name = node["op"] - if op_name != "tvm_op": - continue - attrs = node["attrs"] - node_name = node["name"] - func_name = attrs["func_name"] - if func_name.find("conv2d") != -1: - if conv_layer >= 0: - if counter != conv_layer: - attrs["func_name"] = "__nop" - if counter in skip_conv_layer: - attrs["func_name"] = "__nop" - counter += 1 - else: - if conv_layer >= 0: - attrs["func_name"] = "__nop" - attrs["func_name"] = "__nop" - if attrs["func_name"] != "__nop": - print("Run function %s"% func_name) - graph = nnvm.graph.load_json(json.dumps(jgraph)) - return graph - -x = transform_image(image) -print('x', x.shape) - -###################################################################### -# now compile the graph -import nnvm.compiler -np.random.seed(0) -sym = nnvm.graph.load_json( - open(os.path.join(data_dir, RESNET_GRAPH_FILE)).read()) -params = nnvm.compiler.load_param_dict( - open(os.path.join(data_dir, RESNET_PARAMS_FILE), 'rb').read()) - -shape_dict = {"data": x.shape} -dtype_dict = {"data": 'float32'} -shape_dict.update({k: v.shape for k, v in params.items()}) -dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) - -graph = nnvm.graph.create(sym) -graph_attr.set_shape_inputs(sym, shape_dict) -graph_attr.set_dtype_inputs(sym, dtype_dict) -graph = graph.apply("InferShape").apply("InferType") - -dtype = "float32" -sym = vta.graph.clean_cast(sym) -sym = vta.graph.clean_conv_fuse(sym) - -if target.device_name == "vta": - sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor) - -with nnvm.compiler.build_config(opt_level=3): - if target.device_name != "vta": - graph, lib, params = nnvm.compiler.build( - sym, target_host, shape_dict, dtype_dict, - params=params) - else: - with vta.build_config(): - graph, lib, params = nnvm.compiler.build( - sym, target, shape_dict, dtype_dict, - params=params, target_host=target_host) - - -assert tvm.module.enabled("rpc") -temp = util.tempdir() -lib.save(temp.relpath("graphlib.o")) - -if vta.get_env().TARGET == "sim": - remote = rpc.LocalSession() - print("local session") -else: - host = os.environ.get("VTA_PYNQ_RPC_HOST", None) - assert host - port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091") - port = int(port) - remote = rpc.connect(host, port) - # Program FPGA, and build runtime if necessary - # Overwrite bitstream with a path to your own if you built it yourself - vta.reconfig_runtime(remote) - vta.program_fpga(remote, bitstream=None) - -remote.upload(temp.relpath("graphlib.o")) -lib = remote.load_module("graphlib.o") -ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0) - -print("Build complete...") - -def run_e2e(graph): - """Running end to end example - """ - if debug_fpga_only: - graph = mark_nop(graph, skip_conv_layer=(0,)) - m = graph_runtime.create(graph, lib, ctx) - # set inputs - m.set_input('data', tvm.nd.array(x.astype("float32"))) - m.set_input(**params) - # execute - timer = m.module.time_evaluator("run", ctx, number=10) - tcost = timer() - # get outputs - tvm_output = m.get_output( - 0,tvm.nd.empty((1000,), dtype, remote.cpu(0))) - - top = list(reversed(np.argsort(tvm_output.asnumpy()))) - for i in range(5): - print('TVM prediction top-%d: %s' % (i, synset[top[i]])) - print("t-cost=%g" % tcost.mean) - - -def run_layer(old_graph, layer_begin, layer_end): - """Run a certain layer.""" - for layer_id in range(layer_begin, layer_end): - print("run resnet[%d]..."% (layer_id)) - graph = mark_nop(old_graph, layer_id) - m = graph_runtime.create(graph, lib, ctx) - # set inputs - m.set_input('data', tvm.nd.array(x.astype("float32"))) - m.set_input(**params) - # execute - timer = m.module.time_evaluator("run", ctx, number=1) - tcost = timer() - print("resnet[%d]: %g\n"% (layer_id, tcost.mean)) - -run_e2e(graph) diff --git a/vta/python/vta/top/arm_conv2d.py b/vta/python/vta/top/arm_conv2d.py index c959f1ee9967..d8f749b3b908 100644 --- a/vta/python/vta/top/arm_conv2d.py +++ b/vta/python/vta/top/arm_conv2d.py @@ -5,13 +5,9 @@ """ from __future__ import absolute_import as _abs -import tvm -from topi import tag from topi.nn.conv2d import conv2d, _get_schedule from topi.nn.conv2d import SpatialPack, Im2ColPack, Workload -from topi.nn.conv2d import _SCH_TO_DECL_FUNC -from topi.nn.conv2d import _get_workload -from topi.nn.util import infer_pad, infer_stride +from topi.rasp import conv2d as _rasp_conv2d from topi import generic _WORKLOADS = [ @@ -52,284 +48,8 @@ def _schedule_conv2d(wkl): sch = _SCHEDULES[idx] return sch +conv2d.register(["vtacpu", "vta"], _rasp_conv2d._declaration_conv2d) -@conv2d.register(["vtacpu", "vta"]) -def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype): - assert layout == 'NCHW', "only support NCHW convolution on vtacpu" - assert data.shape[0].value == 1, "only support batch size=1 convolution on vtacpu" - wkl = _get_workload(data, kernel, stride, padding, out_dtype) - sch = _get_schedule(wkl) - return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype) - - -def _schedule_spatial_conv2d(s, data, data_pad, data_vec, - kernel, kernel_vec, - conv_out, output, last): - # no stride and padding info here - padding = infer_pad(data, data_pad) - if data_pad is None: - stride = infer_stride(data, kernel, output) - else: - stride = infer_stride(data_pad, kernel, output) - wkl = _get_workload(data, kernel, stride, padding, output.dtype) - sch = _get_schedule(wkl) - - H, W = wkl.height, wkl.width - CI, CO = wkl.in_filter, wkl.out_filter - HK, WK = wkl.hkernel, wkl.wkernel - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - - HCAT, WCAT = HK-1, WK-1 - DOPAD = (HPAD != 0 and WPAD != 0) - - VH = sch.vh - VW = sch.vw - VC = sch.vc - UNROLL = sch.unroll - - A, B, C = data, kernel, last - A0, A1 = data_pad, data_vec - B0 = kernel_vec - C0, C1 = conv_out, output - - CC = s.cache_write(C0, "global") - - _, co, oh, ow, vh, vw, vc = s[C0].op.axis - if UNROLL: - s[C0].unroll(vw) - s[C0].vectorize(vc) - - s[CC].compute_at(s[C0], ow) - _, co, oh, ow, vh, vw, vc = s[CC].op.axis - ci, dh, dw = s[CC].op.reduce_axis - s[CC].reorder(ci, dh, vh, dw, vw, vc) - - if UNROLL: - s[CC].unroll(vw) - s[CC].vectorize(vc) - - ##### Schedule A - if DOPAD: - s[A0].compute_inline() - - _, h, _, _, _, _ = s[A1].op.axis - if sch.ba == 1: - oaxis = h - paxis = h - else: - oh, ih = s[A1].split(h, sch.ba) - oaxis = oh - paxis = ih - - s[A1].parallel(paxis) - s[A1].pragma(oaxis, "parallel_launch_point") - s[A1].pragma(paxis, "parallel_stride_pattern") - s[A1].pragma(oaxis, "parallel_barrier_when_finish") - - - ##### Schedule B - co, _, _, _, _ = s[B0].op.axis - if sch.bc == 1: - oaxis = co - paxis = co - else: - oco, ico = s[B0].split(co, sch.bc) - oaxis = oco - paxis = ico - - s[B0].parallel(paxis) - s[B0].pragma(oaxis, "parallel_launch_point") - s[B0].pragma(paxis, "parallel_stride_pattern") - s[B0].pragma(oaxis, "parallel_barrier_when_finish") - - - ##### Schedule C - n, co, h, w = s[C].op.axis - co, vc = s[C].split(co, VC) - oh, ow, vh, vw = s[C].tile(h, w, VH, VW) - s[C].reorder(n, co, oh, ow, vh, vw, vc) - if C != C1: - s[C1].compute_inline() - s[C0].compute_at(s[C], ow) - - if sch.bc == 1: - oaxis = co - paxis = co - else: - oco, ico = s[C].split(co, sch.bc) - oaxis = oco - paxis = ico - - s[C].parallel(paxis) - s[C].pragma(oaxis, "parallel_launch_point") - s[C].pragma(paxis, "parallel_stride_pattern") - s[C].pragma(oaxis, "parallel_barrier_when_finish") - - return s - -def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, - kernel, kernel_vec, - conv_out, output, last): - # no stride and padding info here - padding = infer_pad(data, data_pad) - if data_pad is None: - stride = infer_stride(data, kernel, output) - else: - stride = infer_stride(data_pad, kernel, output) - wkl = _get_workload(data, kernel, stride, padding, output.dtype) - - sch = _get_schedule(wkl) - - H, W = wkl.height, wkl.width - CI = wkl.in_filter - CO = wkl.out_filter - HK, WK = wkl.hkernel, wkl.wkernel - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - - HCAT, WCAT = HK-1, WK-1 - DOPAD = (HPAD != 0 and WPAD != 0) - - P = sch.vp - Q = sch.vq - UNROLL = sch.unroll - - A, B, C = data, kernel, last - A0, A1, A2 = data_pad, data_col, data_vec - B0 = kernel_vec - C0, C1 = conv_out, output - - CC = s.cache_write(C0, "global") - AA = s.cache_read(A2, "global", [CC]) - BB = s.cache_read(B0, "global", [CC]) - - - ##### Schedule CC - _, co, im, vim, vco = s[C0].op.axis - s[C0].unroll(vim) - s[C0].vectorize(vco) - - s[CC].compute_at(s[C0], im) - _, co, im, vim, vco = s[CC].op.axis - ci, hk, wk = s[CC].op.reduce_axis - s[CC].reorder(ci, hk, wk, vim, vco) - s[CC].unroll(vim) - s[CC].vectorize(vco) - # s[CC].unroll(ccr) - - ### Schedule C - _, co, h, w = s[C].op.axis - im = s[C].fuse(h, w) - im, vim = s[C].split(im, P) - co, vco = s[C].split(co, Q) - s[C].reorder(co, im, vim, vco) - - if sch.bc == 1: - oaxis = co - paxis = co - else: - oco, ico = s[C].split(co, sch.bc) - oaxis = oco - paxis = ico - - s[C].parallel(paxis) - s[C].pragma(oaxis, "parallel_launch_point") - s[C].pragma(paxis, "parallel_stride_pattern") - s[C].pragma(oaxis, "parallel_barrier_when_finish") - if C1 != C: - s[C1].compute_inline() - - s[C0].compute_at(s[C], paxis) - - ##### Schedule A - if DOPAD: - s[A0].compute_inline() - s[A1].compute_inline() - s[AA].compute_at(s[CC], wk) - s[AA].unroll(AA.op.axis[4]) - - _, im, _, _, _, _ = s[A2].op.axis - if sch.ba == 1: - oaxis = im - paxis = im - else: - oim, iim = s[A2].split(im, sch.ba) - oaxis = oim - paxis = iim - - s[A2].parallel(paxis) - s[A2].pragma(oaxis, "parallel_launch_point") - s[A2].pragma(paxis, "parallel_stride_pattern") - s[A2].pragma(oaxis, "parallel_barrier_when_finish") - - - ##### Schedule B - s[BB].compute_at(s[CC], wk) - s[BB].vectorize(BB.op.axis[4]) - - co, _, _, _, _ = s[B0].op.axis - if sch.bc == 1: - oaxis = co - paxis = co - else: - oco, ico = s[B0].split(co, sch.bc) - oaxis = oco - paxis = ico - - s[B0].parallel(paxis) - s[B0].pragma(oaxis, "parallel_launch_point") - s[B0].pragma(paxis, "parallel_stride_pattern") - s[B0].pragma(oaxis, "parallel_barrier_when_finish") - - return s - -@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"]) -def schedule_conv2d(outs): - """Create schedule for tensors""" - s = tvm.create_schedule([x.op for x in outs]) - - def traverse(op): - """Traverse operators from computation graph""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(op.tag): - if op not in s.outputs: - s[op].compute_inline() - for tensor in op.input_tensors: - if tensor.op.input_tensors: - traverse(tensor.op) - - if 'spatial_conv_output' in op.tag: - output = op.output(0) - conv_out = op.input_tensors[0] - kernel_vec = conv_out.op.input_tensors[1] - kernel = kernel_vec.op.input_tensors[0] - data_vec = conv_out.op.input_tensors[0] - data = data_vec.op.input_tensors[0] - data_pad = None - if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: - data_pad = data - data = data_pad.op.input_tensors[0] - - _schedule_spatial_conv2d(s, data, data_pad, data_vec, - kernel, kernel_vec, - conv_out, output, outs[0]) - - if 'im2col_conv_output' in op.tag: - output = op.output(0) - conv_out = op.input_tensors[0] - kernel_vec = conv_out.op.input_tensors[1] - kernel = kernel_vec.op.input_tensors[0] - data_vec = conv_out.op.input_tensors[0] - data_col = data_vec.op.input_tensors[0] - data = data_col.op.input_tensors[0] - data_pad = None - if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: - data_pad = data - data = data_pad.op.input_tensors[0] - _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec, - kernel, kernel_vec, - conv_out, output, outs[0]) - - traverse(outs[0].op) - return s +generic.schedule_conv2d_nchw.register( + ["vtacpu", "vta"], + _rasp_conv2d.schedule_conv2d_nchw) diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 258d2b20bb85..ab3b4290bb2a 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -192,7 +192,7 @@ def _build(funcs, target, target_host): tvm_t = tvm.target.create(target) if tvm_t.device_name == "vta": return tvm.build(funcs, target="ext_dev", target_host=target_host) - elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vta-cpu": + elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu": return tvm.build(funcs, target=target_host) return tvm.build(funcs, target=target) diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py index 24d9968dc31a..1ee5cdb01167 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py @@ -20,6 +20,114 @@ def my_clip(x, a_min, a_max): x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") return x +def test_cpu_conv2d(): + def run_cpu_conv2d(env, remote, key, batch_size, wl, profile=True): + data_shape = (batch_size, wl.in_filter, wl.height, wl.width) + kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) + + fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 + fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + res_conv = topi.nn.conv2d( + data, kernel, padding=(wl.hpad, wl.wpad), + strides=(wl.hstride, wl.wstride), + out_dtype="int32") + res = topi.right_shift(res_conv, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + # To compute number of ops, use a x2 factor for FMA + num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + + a_shape = (batch_size, wl.in_filter, wl.height, wl.width) + w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) + stride = (wl.hstride, wl.wstride) + data_dtype = data.dtype + kernel_dtype = kernel.dtype + acc_dtype = env.acc_dtype + assert wl.hpad == wl.wpad + padding = wl.hpad + + @memoize("vta.tests.test_benchmark_topi.conv2d.cpu.verify_nhwc") + def get_ref_data(): + a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) + w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) + a_np = np.abs(a_np) + w_np = np.abs(w_np) + b_np = topi.testing.conv2d_nchw_python( + a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) + return a_np, w_np, b_np + + + def verify(s, check_correctness): + mod = tvm.build(s, [data, kernel, res], + "llvm -device=vtacpu", + env.target_host, + name="conv2d") + temp = util.tempdir() + mod.save(temp.relpath("conv2d.o")) + remote.upload(temp.relpath("conv2d.o")) + f = remote.load_module("conv2d.o") + # verify + ctx = remote.cpu(0) + # Data in original format + data_orig, kernel_orig, res_ref = get_ref_data() + res_shape = topi.util.get_const_tuple(res.shape) + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_orig, ctx) + kernel_arr = tvm.nd.array(kernel_orig, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d", ctx, number=5) + cost = time_f(data_arr, kernel_arr, res_arr) + res_unpack = res_arr.asnumpy() + if check_correctness: + assert wl.hpad == wl.wpad + stride = (wl.hstride, wl.wstride) + padding = wl.hpad + res_ref = res_ref >> 8 + res_ref = np.clip(res_ref, 0, 127).astype("int8") + np.testing.assert_allclose(res_unpack, res_ref) + return cost + + def conv_normal(print_ir): + print("----- CONV2D CPU End-to-End Test-------") + s = topi.generic.schedule_conv2d_nchw([res]) + if print_ir: + print(tvm.lower(s, [data, kernel, res], simple_mode=True)) + cost = verify(s, True) + gops = (num_ops / cost.mean) / float(10 ** 9) + print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + + conv_normal(False) + + def _run(env, remote): + # ResNet18 workloads + resnet = { + # Workloads of resnet18 on imagenet + 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2), + 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + } + batch_size = 1 + for i in range(1, len(resnet)): + wl = resnet[i] + key = "resnet-cfg[%d]" % i + print("key=%s" % key) + print(wl) + with tvm.target.create("llvm -device=vtacpu"): + run_cpu_conv2d(env, remote, key, batch_size, wl) + vta.testing.run(_run) + def test_vta_conv2d(): def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True): @@ -54,7 +162,7 @@ def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True): assert wl.hpad == wl.wpad padding = wl.hpad - @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc") + @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") def get_ref_data(): a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) @@ -153,4 +261,6 @@ def _run(env, remote): if __name__ == "__main__": + test_cpu_conv2d() + exit(0) test_vta_conv2d() diff --git a/vta/tutorials/resnet.py b/vta/tutorials/resnet.py index 72ed5d7d2184..ebe580ce184a 100644 --- a/vta/tutorials/resnet.py +++ b/vta/tutorials/resnet.py @@ -116,8 +116,8 @@ def generate_graph(graph_fn, params_fn, device="vta"): with nnvm.compiler.build_config(opt_level=3): if target.device_name != "vta": graph, lib, params = nnvm.compiler.build( - sym, target_host, shape_dict, dtype_dict, - params=params) + sym, target, shape_dict, dtype_dict, + params=params, target_host=target_host) else: with vta.build_config(): graph, lib, params = nnvm.compiler.build( @@ -323,4 +323,4 @@ def thumbnailify(image, pad=15): # When everything done, release the capture cap.release() - cv2.destroyAllWindows() \ No newline at end of file + cv2.destroyAllWindows()