From 5c0b6085d478c354a6d23a0e4125534f3262c1d6 Mon Sep 17 00:00:00 2001
From: Thierry Moreau
Date: Sun, 8 Dec 2019 22:08:21 -0800
Subject: [PATCH] [VTA] Bringing group convolution support (#4421)

* group conv operator support for VTA

* autotvm tuning script for group conv2d

* lint fix

* lint fix

* lint fix

* addressing comments
---
 vta/python/vta/top/__init__.py                     |   1 +
 vta/python/vta/top/vta_group_conv2d.py             | 199 +++++++++++++++
 vta/scripts/tune_group_conv2d.py                   | 155 +++++++++++
 .../test_benchmark_topi_group_conv2d.py            | 240 ++++++++++++++++++
 4 files changed, 595 insertions(+)
 create mode 100644 vta/python/vta/top/vta_group_conv2d.py
 create mode 100644 vta/scripts/tune_group_conv2d.py
 create mode 100644 vta/tests/python/integration/test_benchmark_topi_group_conv2d.py

diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py
index 09d3101bb3a7..7fdf27f8e01a 100644
--- a/vta/python/vta/top/__init__.py
+++ b/vta/python/vta/top/__init__.py
@@ -22,5 +22,6 @@
 from . import op
 from . import vta_conv2d
 from . import vta_conv2d_transpose
+from . import vta_group_conv2d
 from . import vta_dense
 from . import util
diff --git a/vta/python/vta/top/vta_group_conv2d.py b/vta/python/vta/top/vta_group_conv2d.py
new file mode 100644
index 000000000000..e54637f2c204
--- /dev/null
+++ b/vta/python/vta/top/vta_group_conv2d.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Group conv2D operator declaration and schedule registration for VTA.""" + +import numpy as np + +import tvm +from tvm import autotvm +import topi + +from ..environment import get_env + +@autotvm.register_topi_compute(topi.nn.group_conv2d_nchw, 'vta', 'direct') +def packed_group_conv2d(cfg, + data, + kernel, + strides, + padding, + dilation, + group, + out_dtype): + """ Packed group conv2d nchw function.""" + assert dilation == (1, 1) + + if padding[0]: + pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data") + else: + pad_data = data + assert len(data.shape) == 6 + assert len(kernel.shape) == 6 + assert data.dtype == "int8", data.dtype + assert kernel.dtype == "int8", kernel.dtype + assert out_dtype == "int32", out_dtype + + oheight = topi.util.get_const_int((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1) + owidth = topi.util.get_const_int((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1) + oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4]) + + ishape = topi.util.get_const_tuple(data.shape) + kshape = topi.util.get_const_tuple(kernel.shape) + assert group * kshape[1] == ishape[1] + assert kshape[0] % group == 0 + d_i = tvm.reduce_axis((0, kshape[2]), name='d_i') + d_j = tvm.reduce_axis((0, kshape[3]), name='d_j') + k_o = tvm.reduce_axis((0, kshape[1]), name='k_o') + k_i = tvm.reduce_axis((0, kshape[-1]), name='k_i') + hstride, wstride = strides + out = tvm.compute( + oshape, + lambda b_o, c_o, i, j, b_i, c_i: tvm.sum( + pad_data[b_o, c_o // (kshape[0] // group) * kshape[1] + k_o, i * hstride + d_i, + j * wstride + d_j, b_i, k_i].astype(out_dtype) * + kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype), + axis=[k_o, d_i, d_j, k_i]), + name="res", tag="packed_group_conv2d") + + cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) * + kshape[2] * kshape[3] * ishape[1] * kshape[-1]) + + return out + + +@autotvm.register_topi_schedule(topi.generic.schedule_group_conv2d_nchw, 'vta', 'direct') +def schedule_packed_group_conv2d(cfg, outs): + """ Schedule the packed conv2d. 
+ """ + assert len(outs) == 1 + output = outs[0] + const_ops = [] + ewise_inputs = [] + ewise_ops = [] + conv2d_res = [] + assert output.dtype == "int8" + assert output.op.input_tensors[0].dtype == "int32" + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + if not op.axis: + const_ops.append(op) + else: + ewise_ops.append(op) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.PlaceholderOp): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + assert op.tag == "packed_group_conv2d" + conv2d_res.append(op) + + _traverse(output.op) + assert len(conv2d_res) == 1 + conv2d_stage = conv2d_res[0].output(0) + s = tvm.create_schedule(output.op) + + ##### space definition begin ##### + b, c_o, x_i, x_j, _, _ = s[conv2d_stage].op.axis + c_i, _, _, _ = s[conv2d_stage].op.reduce_axis + cfg.define_split('tile_b', b, num_outputs=2) + cfg.define_split('tile_h', x_i, num_outputs=2) + cfg.define_split('tile_w', x_j, num_outputs=2) + cfg.define_split('tile_ci', c_i, num_outputs=2) + cfg.define_split('tile_co', c_o, num_outputs=2) + cfg.define_knob('oc_nthread', [1, 2]) + cfg.define_knob('h_nthread', [1, 2]) + ###### space definition end ###### + + data, kernel = conv2d_stage.op.input_tensors + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + temp = data.op.input_tensors[0] + pad_data = data + data = temp + else: + pad_data = None + + env = get_env() + + # setup pad + if pad_data is not None: + cdata = pad_data + s[pad_data].set_scope(env.inp_scope) + else: + cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) + ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) + s[conv2d_stage].set_scope(env.acc_scope) + + # cache read input + cache_read_ewise = [] + for consumer, tensor in ewise_inputs: + cache_read_ewise.append( + s.cache_read(tensor, env.acc_scope, [consumer])) + + # set ewise scope + for op in ewise_ops: + s[op].set_scope(env.acc_scope) + s[op].pragma(s[op].op.axis[0], env.alu) + + for op in const_ops: + s[op].compute_inline() + + # tile + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis + x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co) + x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i) + x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j) + s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) + store_pt = x_j0 + + # set all compute scopes + s[conv2d_stage].compute_at(s[output], store_pt) + for op in ewise_ops: + s[op].compute_at(s[output], store_pt) + + for tensor in cache_read_ewise: + s[tensor].compute_at(s[output], store_pt) + s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy) + + # virtual threading along output channel axes + if cfg['oc_nthread'].val > 1: + _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + # virtual threading along spatial rows + if cfg['h_nthread'].val > 1: + _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis + k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis + s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) + + k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) + + # Use VTA instructions + s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy) + 
+    s[ckernel].pragma(s[ckernel].op.axis[0], env.dma_copy)
+    s[conv2d_stage].tensorize(x_bi, env.gemm)
+    s[output].pragma(x_co1, env.dma_copy)
+
+    return s
diff --git a/vta/scripts/tune_group_conv2d.py b/vta/scripts/tune_group_conv2d.py
new file mode 100644
index 000000000000..6a542ddd3916
--- /dev/null
+++ b/vta/scripts/tune_group_conv2d.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tuning a single group conv2d operator"""
+
+from collections import namedtuple
+import logging
+import os
+
+import tvm
+from tvm import autotvm
+from tvm.contrib.util import get_lower_ir
+import topi
+import vta
+import vta.testing
+
+env = vta.get_env()
+
+Workload = namedtuple("GroupConv2DWorkload",
+                      ['batch', 'height', 'width', 'in_filter', 'out_filter', 'groups',
+                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+
+# Mobilenet (grouped variant) workloads
+mobilenet_wkls = [
+    ('mobilenet.D1', Workload(env.BATCH, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1)),
+    ('mobilenet.D2', Workload(env.BATCH, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2)),
+    ('mobilenet.D3', Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1)),
+    ('mobilenet.D4', Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2)),
+    ('mobilenet.D5', Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1)),
+    ('mobilenet.D6', Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2)),
+    ('mobilenet.D7', Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1)),
+    ('mobilenet.D8', Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2)),
+    ('mobilenet.D9', Workload(env.BATCH, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1)),
+]
+
+@tvm.tag_scope(tag=topi.tag.ELEMWISE)
+def my_clip(x, a_min, a_max):
+    """Unlike topi's current clip, put min and max into two stages."""
+    const_min = tvm.const(a_min, x.dtype)
+    const_max = tvm.const(a_max, x.dtype)
+    x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
+    x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
+    return x
+
+def group_conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, groups):
+
+    CI_G = CI // groups
+    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
+    kernel_shape = (CO//env.BLOCK_OUT, CI_G//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
+    bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
+
+    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
+    kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
+    bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
+
+    with tvm.target.vta():
+        res = topi.nn.group_conv2d_nchw(
+            data,
+            kernel,
+            strides,
+            padding,
+            dilation,
+            groups,
+            env.acc_dtype)
+        res = topi.right_shift(res, env.WGT_WIDTH)
+        res = topi.add(res, bias)
+        res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
+        res = topi.cast(res, env.out_dtype)
+
+    if tvm.target.current_target().device_name == 'vta':
+        s = topi.generic.schedule_group_conv2d_nchw([res])
+    else:
+        s = tvm.create_schedule([res.op])
+
+    return s, [data, kernel, bias, res]
+
+if __name__ == '__main__':
+
+    # Logging config (for printing tuning log to the screen)
+    logging.basicConfig()
+
+    # Tuning log files
+    log_file = "%s.group_conv2d.log" % (env.TARGET)
+    # create tmp log file
+    tmp_log_file = log_file + ".tmp"
+    if os.path.exists(log_file):
+        os.remove(log_file)
+
+    # Get tracker info from env
+    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
+    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
+    if not tracker_host or not tracker_port:
+        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
+        exit()
+
+    for idx, (wl_name, wl) in enumerate(mobilenet_wkls):
+        prefix = "[Task %2d/%2d] " % (idx, len(mobilenet_wkls))
+
+        # Read in workload parameters
+        N = wl.batch
+        CI = wl.in_filter
+        H = wl.height
+        W = wl.width
+        CO = wl.out_filter
+        KH = wl.hkernel
+        KW = wl.wkernel
+        strides = (wl.hstride, wl.wstride)
+        padding = (wl.hpad, wl.wpad)
+        dilation = (1, 1)
+        groups = wl.groups
+
+        # Create task
+        task = autotvm.task.create(
+            group_conv2d,
+            args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, groups),
+            target=tvm.target.vta(),
+            target_host=env.target_host,
+            template_key='direct')
+        print(task.config_space)
+
+        # Tune
+        measure_option = autotvm.measure_option(
+            builder=autotvm.LocalBuilder(),
+            runner=autotvm.RPCRunner(
+                env.TARGET, host=tracker_host, port=int(tracker_port),
+                number=5, timeout=60,
+                check_correctness=True))
+
+        # Run Tuner
+        tuner = autotvm.tuner.RandomTuner(task)
+        tuner.tune(
+            n_trial=len(task.config_space),
+            early_stopping=None,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file)])
+
+    # Pick best records to a cache file
+    autotvm.record.pick_best(tmp_log_file, log_file)
+    os.remove(tmp_log_file)
diff --git a/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py
new file mode 100644
index 000000000000..975d5b9aaaf9
--- /dev/null
+++ b/vta/tests/python/integration/test_benchmark_topi_group_conv2d.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +"""Testing topi group conv2d operator for VTA""" + +import json +import os + +import numpy as np +from collections import namedtuple + +import tvm +from tvm import autotvm +from tvm.contrib import util +import topi +import topi.testing +import vta +from vta import program_fpga, reconfig_runtime +import vta.testing +from vta.testing import simulator + + +Workload = namedtuple("GroupConv2DWorkload", + ['batch', 'height', 'width', 'in_filter', 'out_filter', 'groups', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +# Get batch info from env +env = vta.get_env() + +# Mobilenet (grouped variant) workloads +mobilenet_wkls = [ + ('mobilenet.D1', Workload(env.BATCH, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D2', Workload(env.BATCH, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D3', Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D4', Workload(env.BATCH, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D5', Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D6', Workload(env.BATCH, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D7', Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D8', Workload(env.BATCH, 14, 14, 512, 512, 32, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D9', Workload(env.BATCH, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1)), +] + +# FIXME: we need a custom clip operator to circumvent a pattern detection limitation +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +def run_group_conv2d(env, remote, wl, target, + check_correctness=True, print_ir=False, + samples=4): + + # Workload assertions + assert wl.hpad == wl.wpad + + # Perform packing only if we are targeting the accelerator + if "arm_cpu" in target.keys: + data_pack = False + layout = "NCHW" + elif "vta" in target.keys: + data_pack = True + layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN) + + # Derive shapes depending upon packing + CI_G = wl.in_filter // wl.groups + a_shape = (wl.batch, wl.in_filter, wl.height, wl.width) + w_shape = (wl.out_filter, CI_G, wl.hkernel, wl.wkernel) + b_shape = (wl.batch, wl.out_filter, 1, 1) + if data_pack: + data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN, + wl.height, wl.width, env.BATCH, env.BLOCK_IN) + kernel_shape = (wl.out_filter//env.BLOCK_OUT, CI_G//env.BLOCK_IN, + wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) + bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT, + 1, 1, env.BATCH, env.BLOCK_OUT) + else: + data_shape = a_shape + kernel_shape = w_shape + bias_shape = b_shape + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) + # Define base computation schedule + with target: + res = topi.nn.group_conv2d_nchw( + data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1), + wl.groups, env.acc_dtype) + res = topi.right_shift(res, 8) + res = topi.add(res, bias) + res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) + res = topi.cast(res, env.out_dtype) + # Derive base schedule + s = topi.generic.schedule_group_conv2d_nchw([res]) + if 
+            print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
+
+    # Derive number of ops
+    fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
+    fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
+    num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * \
+        wl.out_filter * wl.in_filter // wl.groups
+
+    def get_ref_data():
+        # derive min max for act, wgt, and bias types (max non inclusive)
+        a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
+        w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
+        b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2)
+        a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
+        w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
+        b_np = np.random.randint(b_min, b_max, size=b_shape).astype(env.acc_dtype)
+        r_np = topi.testing.conv2d_nchw_python(
+            a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype),
+            (wl.hstride, wl.wstride), wl.hpad, wl.groups).astype(env.acc_dtype)
+        return a_np, w_np, b_np, r_np
+
+    # Data in original format
+    data_np, kernel_np, bias_np, res_ref = get_ref_data()
+    if data_pack:
+        data_np = data_np.reshape(
+            wl.batch//env.BATCH, env.BATCH,
+            wl.in_filter//env.BLOCK_IN, env.BLOCK_IN,
+            wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
+        kernel_np = kernel_np.reshape(
+            wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT,
+            CI_G//env.BLOCK_IN, env.BLOCK_IN,
+            wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
+        bias_np = bias_np.reshape(
+            wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT,
+            1, 1, env.BATCH, env.BLOCK_OUT)
+
+    # Build
+    if "vta" in target.keys:
+        mod = vta.build(s, [data, kernel, bias, res],
+                        target=target,
+                        target_host=env.target_host,
+                        name="conv2d")
+    else:
+        mod = tvm.build(s, [data, kernel, bias, res],
+                        target=target,
+                        target_host=env.target_host,
+                        name="conv2d")
+    temp = util.tempdir()
+    mod.save(temp.relpath("conv2d.o"))
+    remote.upload(temp.relpath("conv2d.o"))
+    f = remote.load_module("conv2d.o")
+    ctx = remote.context(str(target))
+
+    res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype)
+    data_arr = tvm.nd.array(data_np, ctx)
+    kernel_arr = tvm.nd.array(kernel_np, ctx)
+    bias_arr = tvm.nd.array(bias_np, ctx)
+    res_arr = tvm.nd.array(res_np, ctx)
+    time_f = f.time_evaluator("conv2d", ctx, number=samples)
+
+    # In vta sim mode, collect simulator runtime statistics
+    stats = {}
+    cost = None
+    if env.TARGET in ["sim", "tsim"]:
+        # Check if we're in local RPC mode (allows us to rebuild the
+        # runtime on the fly when varying the VTA designs)
+        local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
+        if local_rpc:
+            if env.TARGET == "sim":
+                remote.get_function("vta.simulator.profiler_clear")()
+            else:
+                remote.get_function("vta.tsim.profiler_clear")()
+            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
+            if env.TARGET == "sim":
+                stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
+            else:
+                stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
+        else:
+            simulator.clear_stats()
+            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
+            stats = simulator.stats()
+    else:
+        cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
+
+    # Check correctness
+    correct = False
+    if check_correctness:
+        res_orig = res_arr.asnumpy()
+        if data_pack:
+            res_orig = res_orig.transpose(
+                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width)
+            bias_np = bias_np.transpose(
+                (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, 1, 1)
+        res_ref = res_ref >> env.WGT_WIDTH
+        res_ref += bias_np
+        res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1)
+        res_ref = res_ref.astype(env.out_dtype)
+        correct = np.allclose(res_orig, res_ref)
+
+    gops = (num_ops / cost.mean) / float(10 ** 9)
+    status = "PASSED" if correct else "FAILED"
+    if "arm_cpu" in target.keys:
+        device = "CPU"
+    elif "vta" in target.keys:
+        device = "VTA"
+    print("%s GROUP CONV2D TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops))
+
+    return correct, cost, stats
+
+def test_conv2d(device="vta"):
+    def _run(env, remote):
+        if device == "vta":
+            target = env.target
+            if env.TARGET not in ["sim", "tsim"]:
+                assert tvm.module.enabled("rpc")
+                program_fpga(remote, bitstream=None)
+                reconfig_runtime(remote)
+        elif device == "arm_cpu":
+            target = env.target_vta_cpu
+        with autotvm.tophub.context(target):  # load pre-tuned schedule parameters
+            for _, wl in mobilenet_wkls:
+                print(wl)
+                run_group_conv2d(env, remote, wl, target)
+    vta.testing.run(_run)
+
+if __name__ == "__main__":
+    test_conv2d(device="arm_cpu")
+    test_conv2d(device="vta")
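
As a quick illustration of the data packing exercised by the test above, the NCHW to packed NCHW{n}n{c}c transform that run_group_conv2d applies to its inputs can be reproduced in plain NumPy. The sketch below is illustrative only: it assumes BATCH = 1 and BLOCK_IN = 16, whereas the patch reads these factors from vta.get_env().

    import numpy as np

    # Assumed packing factors (illustrative; the patch takes them from the VTA env).
    BATCH, BLOCK_IN = 1, 16
    batch, in_filter, height, width = 1, 32, 14, 14

    a_np = np.random.randint(-128, 128, size=(batch, in_filter, height, width)).astype("int8")
    # Split the batch and channel axes into (outer, inner) factors, then move the
    # inner factors to the two innermost positions, mirroring the reshape/transpose
    # performed on data_np in run_group_conv2d.
    packed = a_np.reshape(batch // BATCH, BATCH,
                          in_filter // BLOCK_IN, BLOCK_IN,
                          height, width).transpose((0, 2, 4, 5, 1, 3))
    assert packed.shape == (batch // BATCH, in_filter // BLOCK_IN, height, width, BATCH, BLOCK_IN)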