From dacd7b68d1019413ccb8121aebd68f5c25a2de94 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 28 Jun 2019 09:28:20 -0700 Subject: [PATCH] [VTA][Relay] Relay Compilation + AutoTVM compatible operator libraries for VTA (#3135) --- include/vta/driver.h | 2 +- python/vta/build_module.py | 1 + python/vta/environment.py | 7 + python/vta/exec/rpc_server.py | 3 + python/vta/pkg_config.py | 2 - python/vta/testing/simulator.py | 13 + python/vta/testing/util.py | 28 +- python/vta/top/__init__.py | 11 +- python/vta/top/arm_conv2d.py | 37 -- python/vta/top/bitpack.py | 90 ++++ python/vta/top/graphpack.py | 308 ++++++++++++ python/vta/top/nnvm_bitpack.py | 86 ++++ python/vta/top/nnvm_graphpack.py | 223 +++++++++ python/vta/top/nnvm_op.py | 130 +++++ python/vta/top/op.py | 144 ++++++ python/vta/top/vta_conv2d.py | 461 +++-------------- python/vta/top/vta_dense.py | 170 +++++++ scripts/tune_conv2d.py | 130 +++++ scripts/tune_dense.py | 105 ++++ scripts/tune_resnet.py | 310 ++++++++++++ scripts/tune_resnet_nnvm.py | 256 ++++++++++ src/runtime.cc | 2 - src/sim/sim_driver.cc | 42 +- .../integration/test_benchmark_topi_conv2d.py | 433 +++++++--------- .../integration/test_benchmark_topi_dense.py | 185 +++++++ tutorials/README.txt | 1 + tutorials/autotvm/README.txt | 3 + tutorials/autotvm/tune_relay_vta.py | 468 ++++++++++++++++++ tutorials/frontend/README.txt | 4 + tutorials/frontend/deploy_resnet_on_vta.py | 262 ++++++++++ tutorials/optimize/README.txt | 2 + tutorials/{ => optimize}/convolution_opt.py | 0 .../{ => optimize}/matrix_multiply_opt.py | 0 tutorials/resnet.py | 330 ------------ 34 files changed, 3228 insertions(+), 1021 deletions(-) delete mode 100644 python/vta/top/arm_conv2d.py create mode 100644 python/vta/top/bitpack.py create mode 100644 python/vta/top/graphpack.py create mode 100644 python/vta/top/nnvm_bitpack.py create mode 100644 python/vta/top/nnvm_graphpack.py create mode 100644 python/vta/top/nnvm_op.py create mode 100644 python/vta/top/op.py create mode 100644 python/vta/top/vta_dense.py create mode 100644 scripts/tune_conv2d.py create mode 100644 scripts/tune_dense.py create mode 100644 scripts/tune_resnet.py create mode 100644 scripts/tune_resnet_nnvm.py create mode 100644 tests/python/integration/test_benchmark_topi_dense.py create mode 100644 tutorials/autotvm/README.txt create mode 100644 tutorials/autotvm/tune_relay_vta.py create mode 100644 tutorials/frontend/README.txt create mode 100644 tutorials/frontend/deploy_resnet_on_vta.py create mode 100644 tutorials/optimize/README.txt rename tutorials/{ => optimize}/convolution_opt.py (100%) rename tutorials/{ => optimize}/matrix_multiply_opt.py (100%) delete mode 100644 tutorials/resnet.py diff --git a/include/vta/driver.h b/include/vta/driver.h index eca9e4da9799..2d8e9c2c3d84 100644 --- a/include/vta/driver.h +++ b/include/vta/driver.h @@ -42,7 +42,7 @@ extern "C" { /*! \brief Physically contiguous buffer size limit */ #ifndef VTA_MAX_XFER -#define VTA_MAX_XFER (1<<22) +#define VTA_MAX_XFER (1<<25) #endif /*! PAGE SIZE */ diff --git a/python/vta/build_module.py b/python/vta/build_module.py index 471dc90746de..dbd2e4b45fd6 100644 --- a/python/vta/build_module.py +++ b/python/vta/build_module.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# pylint: disable=unused-argument """VTA specific buildin for runtime.""" from __future__ import absolute_import as _abs diff --git a/python/vta/environment.py b/python/vta/environment.py index 4c2200d04727..093b0ec5c386 100644 --- a/python/vta/environment.py +++ b/python/vta/environment.py @@ -234,6 +234,10 @@ def gemm(self): """GEMM intrinsic""" return self.dev.gemm + @property + def target(self): + return tvm.target.vta(model=self.TARGET) + @property def target_host(self): """The target host""" @@ -243,6 +247,9 @@ def target_host(self): return "llvm" raise ValueError("Unknown target %s" % self.TARGET) + @property + def target_vta_cpu(self): + return tvm.target.arm_cpu(model=self.TARGET) def get_env(): """Get the current VTA Environment. diff --git a/python/vta/exec/rpc_server.py b/python/vta/exec/rpc_server.py index 8caa48a56104..0ac97a2ab07e 100644 --- a/python/vta/exec/rpc_server.py +++ b/python/vta/exec/rpc_server.py @@ -66,6 +66,9 @@ def ext_dev_callback(): @tvm.register_func("tvm.contrib.vta.init", override=True) def program_fpga(file_name): + from pynq import xlnk + # Reset xilinx driver + xlnk.Xlnk().xlnk_reset() path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name) env = get_env() program_bitstream.bitstream_program(env.TARGET, path) diff --git a/python/vta/pkg_config.py b/python/vta/pkg_config.py index 2c30414ace1a..3977d5aa2e43 100644 --- a/python/vta/pkg_config.py +++ b/python/vta/pkg_config.py @@ -77,8 +77,6 @@ def __init__(self, cfg, proj_root): if self.target == "pynq": self.ldflags = [ "-L/usr/lib", - "-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/", - "-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/", "-l:libcma.so"] else: self.ldflags = [] diff --git a/python/vta/testing/simulator.py b/python/vta/testing/simulator.py index dbeba84f6d4a..2d6cfe305756 100644 --- a/python/vta/testing/simulator.py +++ b/python/vta/testing/simulator.py @@ -84,4 +84,17 @@ def tsim_cycles(): """ return tvm.get_global_func("tvm.vta.tsim.cycles")() +# debug flag to skip execution. +DEBUG_SKIP_EXEC = 1 + +def debug_mode(flag): + """Set debug mode + Paramaters + ---------- + flag : int + The debug flag, 0 means clear all flags. + """ + tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag) + + LIBS = _load_lib() diff --git a/python/vta/testing/util.py b/python/vta/testing/util.py index 06c700cd7119..30760409733c 100644 --- a/python/vta/testing/util.py +++ b/python/vta/testing/util.py @@ -18,7 +18,7 @@ from __future__ import absolute_import as _abs import os -from tvm import rpc +from tvm import rpc, autotvm from ..environment import get_env from . import simulator @@ -42,7 +42,7 @@ def run(run_func): # the port it's listening to, e.g. 
9090 local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) if local_rpc: - remote = rpc.connect("localhost", local_rpc) + remote = rpc.connect("127.0.0.1", local_rpc) run_func(env, remote) else: # Make sure simulation library exists @@ -54,12 +54,22 @@ def run(run_func): elif env.TARGET == "pynq": - # Run on PYNQ if env variable exists - host = os.environ.get("VTA_PYNQ_RPC_HOST", None) - port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None)) - if host and port: - remote = rpc.connect(host, port) + tracket_host = os.environ.get("TVM_TRACKER_HOST", None) + tracket_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None) + pynq_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None)) + # Run device from fleet node if env variables are defined + if tracket_host and tracket_port: + remote = autotvm.measure.request_remote(env.TARGET, + tracket_host, + tracket_port, + timeout=10000) run_func(env, remote) else: - raise RuntimeError( - "Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables") + # Next, run on PYNQ if env variables are defined + if pynq_host and pynq_port: + remote = rpc.connect(pynq_host, pynq_port) + run_func(env, remote) + else: + raise RuntimeError( + "Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables") diff --git a/python/vta/top/__init__.py b/python/vta/top/__init__.py index 614ed2347181..3b5132ebf0ef 100644 --- a/python/vta/top/__init__.py +++ b/python/vta/top/__init__.py @@ -1,5 +1,12 @@ """TVM TOPI connector, eventually most of these should go to TVM repo""" -from .vta_conv2d import packed_conv2d, schedule_packed_conv2d +from . import bitpack +from .graphpack import graph_pack +from . import op from . import vta_conv2d -from . import arm_conv2d +from . import vta_dense + +# NNVM is deprecated for VTA +# from . import nnvm_bitpack +# from .nnvm_graphpack import nnvm_graph_pack +# from . import nnvm_op diff --git a/python/vta/top/arm_conv2d.py b/python/vta/top/arm_conv2d.py deleted file mode 100644 index 6e34917c0b71..000000000000 --- a/python/vta/top/arm_conv2d.py +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Reuse conv2d schedule from ARM CPU""" - -import tvm - -from topi.nn import conv2d, conv2d_alter_layout -from topi import generic - -@conv2d.register(["vtacpu", "vta"]) -def compute(*args, **kwargs): - with tvm.target.arm_cpu("vtacpu"): - return conv2d(*args, **kwargs) - -@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"]) -def schedule(*args, **kwargs): - with tvm.target.arm_cpu("vtacpu"): - return generic.schedule_conv2d_nchw(*args, **kwargs) - -@conv2d_alter_layout.register(["vtacpu", "vta"]) -def alter(*args, **kwargs): - with tvm.target.arm_cpu("vtacpu"): - return conv2d_alter_layout(*args, **kwargs) diff --git a/python/vta/top/bitpack.py b/python/vta/top/bitpack.py new file mode 100644 index 000000000000..d4748faad6a7 --- /dev/null +++ b/python/vta/top/bitpack.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=ungrouped-imports + +"""Bit packing operators""" +from __future__ import absolute_import as _abs + +import tvm +from topi import util + +from tvm.relay.op.op import register_compute, register_schedule +from tvm.relay.op.op import register_pattern, OpPattern +from tvm.relay.op.op import schedule_injective + +def bitpack(data, bits, pack_type="int8", name="bitpack"): + """Packs lowest dimension into format needed by VTA + + Parameters + ---------- + pack_axis : int + index of the axis to pack in data + bit_axis : int + index of axis to place bit axis in resulting packed data + + Returns + ------- + packed : Tensor + The packed tensor. 
+ """ + shape_vec = list(data.shape) + if pack_type == 'int8': + data_width = 8 + elif pack_type == 'int16': + data_width = 16 + elif pack_type == 'int32': + data_width = 32 + else: + raise RuntimeError("Unknown pack type %s" % pack_type) + assert data_width % bits == 0 + lanes = data_width // bits + + # Data must be in multiples of the data_width + assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size" + shape_vec[-1] = shape_vec[-1] // lanes + oshape = tuple(shape_vec) + + def _bitpack(*indices): + ret = None + mask = tvm.const((1 << bits) - 1, pack_type) + for k in range(lanes): + idx = list(indices) + idx[-1] = idx[-1] * lanes + k + elem = data(*idx).astype(pack_type) + if k == 0: + ret = elem & mask + else: + val = (elem & mask) << tvm.const(k * bits, pack_type) + ret = ret | val + return ret + + return tvm.compute( + oshape, _bitpack, name=name, tag='bitpack') + + +@register_compute("bitpack", level=15) +def compute_bitpack(attrs, inputs): + lanes = attrs.lanes + dtype = inputs[0].dtype + assert dtype == "int8" + width = 8 + assert width % lanes == 0 + bits = 8 // lanes + return bitpack(inputs[0], bits, dtype) + +register_schedule("bitpack", schedule_injective) +register_pattern("bitpack", OpPattern.INJECTIVE) diff --git a/python/vta/top/graphpack.py b/python/vta/top/graphpack.py new file mode 100644 index 000000000000..6f901833ea15 --- /dev/null +++ b/python/vta/top/graphpack.py @@ -0,0 +1,308 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""A Relay implementation of graph packing.""" + +from tvm import relay +from tvm.relay import op +from tvm.relay import ExprMutator + +def _to_shape(shape): + return tuple(int(sh) for sh in shape) + +def _pack_batch_channel(data, dshape, bfactor, cfactor): + """Pack the data channel dimension. + """ + assert int(dshape[0]) % bfactor == 0 + assert int(dshape[1]) % cfactor == 0 + data = op.reshape(data, + newshape=(int(dshape[0]) // bfactor, bfactor, + int(dshape[1]) // cfactor, cfactor, + int(dshape[2]), int(dshape[3]))) + data = op.transpose( + data, axes=(0, 2, 4, 5, 1, 3)) + return data + + +def _unpack_batch_channel(data, old_shape): + """Unpack the data channel dimension. + """ + data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3)) + data = op.reshape(data, newshape=old_shape) + return data + + +def _pack_weight(data, dshape, cfactor): + """Pack the weight into packed format. 
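+
+    The kernel is reshaped from OIHW into
+    (O // cfactor, I // cfactor, H, W, cfactor, cfactor) blocks.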
+ """ + assert len(dshape) == 4 + assert int(dshape[0]) % cfactor == 0 + assert int(dshape[1]) % cfactor == 0 + data = op.reshape(data, + newshape=(int(dshape[0]) // cfactor, cfactor, + int(dshape[1]) // cfactor, cfactor, + int(dshape[2]), int(dshape[3]))) + data = op.transpose( + data, axes=(0, 2, 4, 5, 1, 3)) + return data + + +def _pack_weight_conv2d_transpose(data, dshape, cfactor): + """Pack the weight into packed format. + """ + dshape = _to_shape(dshape) + assert len(dshape) == 4 + assert dshape[0] % cfactor == 0 + assert dshape[1] % cfactor == 0 + data = op.reshape(data, + newshape=(dshape[0] // cfactor, cfactor, + dshape[1] // cfactor, cfactor, + dshape[2], dshape[3])) + data = op.transpose( + data, axes=(2, 0, 4, 5, 3, 1)) + return data + + +def _pack_bias(data, dshape, dtype, bfactor, cfactor): + """Pack the bias parameter. + """ + dshape = _to_shape(dshape) + assert len(dshape) == 3 + assert dshape[0] % cfactor == 0 + data = op.reshape(data, + newshape=(dshape[0] // cfactor, + cfactor, dshape[1], + dshape[2], 1)) + data = op.transpose( + data, axes=(0, 2, 3, 4, 1)) + + # broadcast batch dimension to bfactor + data = op.broadcast_to( + data, + shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor)) + return data + + +def _get_shape(node): + """Get the shape of a node. + """ + return _to_shape(node.checked_type.shape) + +class ExprPack(ExprMutator): + """Visitor to perform graph packing on an AST. + """ + def __init__(self, bfactor, cfactor, weight_bits): + self.bfactor = bfactor + self.cfactor = cfactor + self.weight_bits = weight_bits + self.start_pack = False + # Cache Operator the algorithm matches against. + self.bitpack_start = op.op.get('annotation.bitpack_start') + self.bitpack_end = op.op.get('annotation.bitpack_end') + self.conv2d = op.op.get("nn.conv2d") + self.conv2d_transpose = op.op.get("nn.conv2d_transpose") + self.add = op.op.get("add") + self.bias_add = op.op.get("nn.bias_add") + self.number_of_conv2d = 0 + super().__init__() + + def visit_call(self, call): + # First visit the children. + oshape = _get_shape(call) + odtype = call.checked_type.dtype + input_types = [arg.checked_type for arg in call.args] + args = [self.visit(arg) for arg in call.args] + + # Start and stop cases. 
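+        # bitpack_start turns packing on and packs the input into the
+        # (N//bfactor, C//cfactor, H, W, bfactor, cfactor) layout;
+        # bitpack_end turns packing off and restores the original shape.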
+ if call.op == self.bitpack_start: + assert not self.start_pack + self.start_pack = True + return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor) + elif call.op == self.bitpack_end: + if self.start_pack: + self.start_pack = False + data = args[0] + data_shape = _get_shape(call.args[0]) + return _unpack_batch_channel(data, data_shape) + else: + pass + if self.start_pack: + # Operator cases + if call.op == self.conv2d and odtype == 'int32': + self.number_of_conv2d += 1 + assert 8 % self.weight_bits == 0 + w_lanes = 8 // self.weight_bits + data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor) + kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor) + data, weight = args + data_shape = _to_shape(input_types[0].shape) + kernel_shape = _to_shape(input_types[1].shape) + kernel = _pack_weight(weight, kernel_shape, self.cfactor) + # insert bit packing when necessary + if w_lanes != 1: + assert 8 % w_lanes == 0 + kernel = op.bitpack(kernel, lanes=w_lanes) + conv2d = op.nn.conv2d( + data, + kernel, + strides=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + groups=call.attrs.groups, + channels=call.attrs.channels, + kernel_size=call.attrs.kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + out_dtype=call.attrs.out_dtype) + return conv2d + elif call.op == self.conv2d_transpose and odtype == 'int32': + self.number_of_conv2d += 1 + assert 8 % self.weight_bits == 0 + w_lanes = 8 // self.weight_bits + if self.start_pack: + data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor) + kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor) + data, weight = args + data_shape = _to_shape(input_types[0].shape) + kernel_shape = _to_shape(input_types[1].shape) + kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor) + conv2d = op.nn.conv2d_transpose( + data, + kernel, + strides=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + groups=call.attrs.groups, + channels=call.attrs.channels, + kernel_size=call.attrs.kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + output_padding=call.attrs.output_padding, + out_dtype=call.attrs.out_dtype) + return conv2d + elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape): + pass + elif call.op == self.add and len(input_types[1].shape) == 3: + data, bias = args + bias = _pack_bias(bias, + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) + return relay.Call(self.add, [data, bias]) + elif self.start_pack and call.op == self.bias_add: + data, bias = args + bias = _pack_bias(bias, + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) + return relay.Call(self.add, [data, bias]) + elif self.start_pack and call.op == op.op.get('cast') and \ + input_types[0].dtype == 'int32': + cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs) + return relay.Call(op.op.get('copy'), [cast]) + + return relay.Call( + self.visit(call.op), + args, + call.attrs) + +class BT(Exception): + pass +def get_subgraph(expr, start_name, stop_name): + """ We assume stop_name only appears once for simplicity. + This constraint will be lifted in the future. + bitpack_start and bitpack_end are both inclusive. 
+ """ + bitpack_start = op.op.get('annotation.bitpack_start') + bitpack_end = op.op.get('annotation.bitpack_end') + anf = relay.ir_pass.to_a_normal_form(expr) + def _recursion(anf, start_found, stop_found): + """ Helper to obtain the subgraph. + """ + if isinstance(anf, relay.expr.Function): + return relay.expr.Function(anf.params, + _recursion(anf.body, start_found, stop_found), + anf.ret_type, anf.type_params, anf.attrs) + elif isinstance(anf, relay.expr.Let): + value = anf.value + if isinstance(value, relay.expr.Call): + if isinstance(value.op, relay.op.Op): + if value.op.name == start_name and not start_found: + value = relay.expr.Call(bitpack_start, [value]) + start_found = True + elif value.op.name == stop_name: + raise BT() + try: + return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found)) + except BT: + assert start_found + assert not stop_found + stop_found = True + value = relay.expr.Call(bitpack_end, [value]) + # todo: check anf.body has no more stop_name beside that one + return relay.expr.Let(anf.var, value, anf.body) + else: + assert start_found + assert stop_found + return anf + annotated = _recursion(anf, False, False) + return relay.ir_pass.infer_type(relay.ir_pass.to_graph_normal_form(annotated)) + +def graph_pack(expr, + bfactor, + cfactor, + weight_bits, + start_name="nn.max_pool2d", + stop_name="nn.global_avg_pool2d"): + """Pack the graph into batch&channel packed format. + + Parameters + ---------- + expr : relay.Expr + The input program. + + bfactor : int + The packing factor in batch + + cfactor : int + The packing factor in channel + + weight_bits: int + The bit-width of the weights. + + start_name: str, optional + Start packing from certain known node. + + stop_name: str, optional + Stop packing from certain known node. + + Returns + ------- + expr : Expr + The transformed expression. + """ + assert isinstance(expr, relay.Function) + expr = get_subgraph(expr, start_name, stop_name) + expr = relay.ir_pass.infer_type(expr) + packer = ExprPack( + bfactor, cfactor, + weight_bits) + expr = packer.visit(expr) + assert not packer.start_pack + return relay.ir_pass.infer_type(expr) diff --git a/python/vta/top/nnvm_bitpack.py b/python/vta/top/nnvm_bitpack.py new file mode 100644 index 000000000000..0dc241330339 --- /dev/null +++ b/python/vta/top/nnvm_bitpack.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-argument +"""Bit packing operators""" +from __future__ import absolute_import as _abs + +import tvm +from topi import util + +from nnvm.top import registry as reg, OpPattern +from nnvm.top.tensor import _fschedule_broadcast + +def bitpack(data, bits, pack_type="int8", name="bitpack"): + """Packs lowest dimension into format needed by VTA + Parameters + ---------- + pack_axis : int + index of the axis to pack in data + bit_axis : int + index of axis to place bit axis in resulting packed data + Returns + ------- + packed : Tensor + The packed tensor. + """ + shape_vec = list(data.shape) + if pack_type == 'int8': + data_width = 8 + elif pack_type == 'int16': + data_width = 16 + elif pack_type == 'int32': + data_width = 32 + else: + raise RuntimeError("Unknown pack type %s" % pack_type) + assert data_width % bits == 0 + lanes = data_width // bits + + # Data must be in multiples of the data_width + assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size" + shape_vec[-1] = shape_vec[-1] // lanes + oshape = tuple(shape_vec) + + def _bitpack(*indices): + ret = None + mask = tvm.const((1 << bits) - 1, pack_type) + for k in range(lanes): + idx = list(indices) + idx[-1] = idx[-1] * lanes + k + elem = data(*idx).astype(pack_type) + if k == 0: + ret = elem & mask + else: + val = (elem & mask) << tvm.const(k * bits, pack_type) + ret = ret | val + return ret + + return tvm.compute( + oshape, _bitpack, name=name, tag='bitpack') + + +@reg.register_compute("bitpack", level=15) +def compute_bitpack(attrs, inputs, out): + lanes = attrs.get_int("lanes") + dtype = inputs[0].dtype + assert dtype == "int8" + width = 8 + assert width % lanes == 0 + bits = 8 // lanes + return bitpack(inputs[0], bits, dtype) + +reg.register_schedule("bitpack", _fschedule_broadcast) +reg.register_pattern("bitpack", OpPattern.INJECTIVE) diff --git a/python/vta/top/nnvm_graphpack.py b/python/vta/top/nnvm_graphpack.py new file mode 100644 index 000000000000..427001ffa5ed --- /dev/null +++ b/python/vta/top/nnvm_graphpack.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""An NNVM implementation of graph packing.""" + +import nnvm +from nnvm.compiler import graph_attr, graph_util + +def _pack_batch_channel(data, dshape, bfactor, cfactor): + """Pack the data channel dimension. + """ + assert dshape[0] % bfactor == 0 + assert dshape[1] % cfactor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0] // bfactor, bfactor, + dshape[1] // cfactor, cfactor, + dshape[2], dshape[3])) + data = nnvm.sym.transpose( + data, axes=(0, 2, 4, 5, 1, 3)) + return data + + +def _unpack_batch_channel(data, old_shape): + """Unpack the data channel dimension. 
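+
+    This reverses _pack_batch_channel and restores the original, unpacked shape.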
+ """ + data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3)) + data = nnvm.sym.reshape(data, shape=old_shape) + return data + + +def _pack_weight(data, dshape, cfactor): + """Pack the weight into packed format. + """ + assert len(dshape) == 4 + assert dshape[0] % cfactor == 0 + assert dshape[1] % cfactor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0] // cfactor, cfactor, + dshape[1] // cfactor, cfactor, + dshape[2], dshape[3])) + data = nnvm.sym.transpose( + data, axes=(0, 2, 4, 5, 1, 3)) + return data + + +def _pack_weight_conv2d_transpose(data, dshape, cfactor): + """Pack the weight into packed format. + """ + assert len(dshape) == 4 + assert dshape[0] % cfactor == 0 + assert dshape[1] % cfactor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0] // cfactor, cfactor, + dshape[1] // cfactor, cfactor, + dshape[2], dshape[3])) + data = nnvm.sym.transpose( + data, axes=(2, 0, 4, 5, 3, 1)) + return data + + +def _pack_bias(data, dshape, bfactor, cfactor): + """Pack the bias parameter. + """ + assert len(dshape) == 3 + assert dshape[0] % cfactor == 0 + data = nnvm.sym.reshape(data, + shape=(dshape[0] // cfactor, + cfactor, dshape[1], + dshape[2], 1)) + data = nnvm.sym.transpose( + data, axes=(0, 2, 3, 4, 1)) + # broadcast batch dimension to bfactor + data = nnvm.sym.broadcast_to( + data, + shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor)) + return data + + +def _get_shape(sym, shape_dict): + """Get the shape of a node. + """ + return graph_util.infer_shape( + nnvm.graph.create(sym), **shape_dict)[1][0] + + +def nnvm_graph_pack(graph, + shape_dict, + bfactor, + cfactor, + weight_bits, + start_name="max_pool2d0", + stop_name="global_avg_pool2d0"): + """Pack the graph into batch&channel packed format. + + Parameters + ---------- + graph : Graph + The input graph. + + shape_dict : dict of str to shape + The input shape. + + bfactor : int + The packing factor in batch + + cfactor : int + The packing factor in channel + + start_name: str, optional + Start packing from certain known node. + + start_name: str, optional + Stop packing from certain known node. + + Returns + ------- + graph : Graph + The transformed graph. 
+ """ + graph = graph_attr.set_shape_inputs(graph, shape_dict) + graph = graph.apply("InferShape") + shape = graph.json_attr("shape") + gidx = graph.index + node_map = {} + dset = set() + start_pack = False + + for nid, node in enumerate(gidx.nodes): + children = [node_map[e[0]] for e in node["inputs"]] + ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]] + oshape = shape[gidx.entry_id(nid, 0)] + attrs = node.get("attrs", {}) + node_name = node["name"] + op_name = node["op"] + get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)( + *c, name=n_n, **a) + if op_name == "null": + new_node = nnvm.symbol.Variable(node_name) + if start_name and node_name == start_name: + start_pack = True + new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor) + if start_pack and "_begin_state_" in node_name: # RNN -> CNN, pack + new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor) + elif node_name == start_name: + assert not start_pack + start_pack = True + new_node = get_clone(children, op_name, node_name, attrs) + new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor) + elif node_name == stop_name: + if start_pack: + start_pack = False + children[0] = _unpack_batch_channel(children[0], ishape[0]) + new_node = getattr(nnvm.symbol, op_name)( + *children, name=node_name, **attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name == "conv2d" and attrs.get("out_dtype", None) == "int32": + assert 8 % weight_bits == 0 + w_lanes = 8 // weight_bits + if start_pack: + attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor) + attrs["kernel_layout"] = "OIHW%do%di%dp" % (cfactor, cfactor, w_lanes) + data, weight = children + weight = _pack_weight(weight, ishape[1], cfactor) + # insert bit packing when necessary + if w_lanes != 1: + assert 8 % w_lanes == 0 + weight = nnvm.sym.bitpack(weight, lanes=w_lanes) + new_node = nnvm.sym.conv2d( + data, weight, name=node_name, **attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name == "conv2d_transpose" and attrs.get("out_dtype", None) == "int32": + assert 8 % weight_bits == 0 + w_lanes = 8 // weight_bits + if start_pack: + attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor) + attrs["kernel_layout"] = "IOHW%di%do%dp" % (cfactor, cfactor, w_lanes) + data, weight = children + weight = _pack_weight_conv2d_transpose(weight, ishape[1], cfactor) + new_node = nnvm.sym.conv2d_transpose( + data, weight, name=node_name, **attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name.startswith("broadcast_") and tuple(ishape[0]) == tuple(ishape[1]): + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name.startswith("broadcast") and len(ishape[1]) == 3: + if start_pack: + children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor) + new_node = getattr(nnvm.symbol, op_name)( + *children, name=node_name, **attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + elif op_name.startswith("elementwise_add"): + new_node = get_clone(children, op_name, node_name, attrs) + else: + new_node = get_clone(children, op_name, node_name, attrs) + dset.add(op_name) + node_map[nid] = new_node + + assert len(graph.index.output_entries) == 1 + ret = node_map[graph.index.output_entries[0][0]] + if start_pack: + oshape = shape[graph.index.output_entries[0][0]] + ret = _unpack_batch_channel(ret, oshape) + graph = nnvm.graph.create(ret) + graph = graph_attr.set_shape_inputs(graph, shape_dict) + graph = graph.apply("InferShape") 
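+    # Re-run shape inference on the repacked graph before returning it.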
+ return graph diff --git a/python/vta/top/nnvm_op.py b/python/vta/top/nnvm_op.py new file mode 100644 index 000000000000..a38b2172671b --- /dev/null +++ b/python/vta/top/nnvm_op.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for supporting packed_conv2d + ewise variant of nnvm.""" +from __future__ import absolute_import as _abs + +import logging + +import tvm +import topi + +from nnvm.top import registry as reg, OpPattern +from nnvm.top import nn as _nn + +from .vta_conv2d import is_packed_layout +from ..environment import get_env + +@tvm.register_func("nnvm.compiler.build_target", override=True) +def _build(funcs, target, target_host): + tvm_t = tvm.target.create(target) + if tvm_t.device_name == "vta": + return tvm.build(funcs, target="ext_dev", target_host=target_host) + if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu": + return tvm.build(funcs, target=target_host) + return tvm.build(funcs, target=target) + +@tvm.register_func("nnvm.compiler.lower", override=True) +def _lower(sch, inputs, func_name, graph): + import traceback + # pylint: disable=broad-except + try: + f = tvm.lower(sch, inputs, name=func_name) + if "quantized_conv2d" in func_name: + logging.info(graph.ir(join_entry_attrs=["shape"])) + except Exception: + msg = traceback.format_exc() + msg += "Error during compile graph\n" + msg += "--------------------------\n" + msg += graph.ir(join_entry_attrs=["shape"]) + raise RuntimeError(msg) + return f if isinstance( + f, (tvm.container.Array, tuple, list)) else [f] + +# override to force partition at copy +reg.register_pattern("copy", OpPattern.INJECTIVE, level=15) + +@reg.register_compute("clip", level=15) +def compute_clip(attrs, inputs, _): + """ Clip operator. 
""" + x = inputs[0] + a_min = attrs.get_float("a_min") + a_max = attrs.get_float("a_max") + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + with tvm.tag_scope(topi.tag.ELEMWISE): + x = tvm.compute( + x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute( + x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +@reg.register_compute("conv2d", level=15) +def compute_conv2d(attrs, inputs, out): + """ Compute definition of conv2d """ + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int("groups") + layout = attrs["layout"] + out_dtype = attrs['out_dtype'] + + assert dilation == (1, 1), "not support dilate now" + if is_packed_layout(layout): + if groups == 1: + assert groups == 1 + env = get_env() + assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" + assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now" + inputs = list(inputs) + assert inputs[1].dtype == "int8" + return topi.nn.conv2d(inputs[0], inputs[1], strides, + padding, dilation, layout, out_dtype) + return topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, + padding, dilation, groups, out_dtype) + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.compute_conv2d(attrs, inputs, out) + +@reg.register_schedule("conv2d", level=15) +def schedule_conv2d(attrs, outs, target): + """ Schedule definition of conv2d """ + layout = attrs["layout"] + groups = attrs.get_int('groups') + + if is_packed_layout(layout): + target = tvm.target.create(target) + if target.device_name == "vta": + if groups == 1: + return topi.generic.schedule_conv2d_nchw(outs) + return topi.generic.schedule_group_conv2d_nchw(outs) + elif str(target).startswith("llvm"): + return tvm.create_schedule([x.op for x in outs]) + else: + raise RuntimeError("not support target %s" % target) + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target()) + +@reg.register_alter_op_layout("conv2d", level=15) +def alter_conv2d_layout(attrs, inputs, out): + layout = attrs['layout'] + if is_packed_layout(layout): + return None + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.alter_conv2d_layout(attrs, inputs, out) diff --git a/python/vta/top/op.py b/python/vta/top/op.py new file mode 100644 index 000000000000..96eaa8fb9905 --- /dev/null +++ b/python/vta/top/op.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-argument, ungrouped-imports +"""Namespace for supporting packed_conv2d + ewise variant of nnvm.""" +from __future__ import absolute_import as _abs + +import tvm +import topi + +from tvm.relay.op import op as reg +from tvm.relay.op.op import OpPattern +from tvm.relay.op.nn import _nn + +from .vta_conv2d import is_packed_layout +from ..environment import get_env + +# override to force partition at copy +reg.register_pattern("copy", OpPattern.INJECTIVE, level=15) + +@reg.register_compute("clip", level=15) +def compute_clip(attrs, inputs, output_type, target): + """ Clip operator. """ + x = inputs[0] + a_min = attrs.a_min + a_max = attrs.a_max + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + with tvm.tag_scope(topi.tag.ELEMWISE): + x = tvm.compute( + x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute( + x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return [x] + + +@reg.register_compute("nn.conv2d", level=15) +def compute_conv2d(attrs, inputs, output_type, target): + """ Compute definition of conv2d """ + padding = topi.util.get_const_tuple(attrs.padding) + strides = topi.util.get_const_tuple(attrs.strides) + dilation = tuple([int(d) for d in attrs.dilation]) + groups = attrs.groups + layout = attrs.data_layout + out_dtype = attrs.out_dtype + + if target.device_name == "vta": + assert dilation == (1, 1), "support for dilation limited to (1, 1)" + if is_packed_layout(layout): + if groups == 1: + assert groups == 1 + env = get_env() + assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" + assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now" + inputs = list(inputs) + assert inputs[1].dtype == "int8" + return [topi.nn.conv2d(inputs[0], + inputs[1], + strides, + padding, + dilation, + layout, + out_dtype)] + return [topi.nn.group_conv2d_nchw(inputs[0], + inputs[1], + strides, + padding, + dilation, + groups, + out_dtype)] + # If it's not packed, run on ARM CPU + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.compute_conv2d(attrs, inputs, output_type, target) + + # If VTA is not the target, default to _nn def + return _nn.compute_conv2d(attrs, inputs, output_type, target) + + +@reg.register_schedule("nn.conv2d", level=15) +def schedule_conv2d(attrs, outs, target): + """ Schedule definition of conv2d """ + groups = attrs.groups + layout = attrs.data_layout + + if target.device_name == "vta": + if is_packed_layout(layout): + target = tvm.target.create(target) + assert target.device_name == "vta" + if groups == 1: + return topi.generic.schedule_conv2d_nchw(outs) + return topi.generic.schedule_group_conv2d_nchw(outs) + # If it's not packed, run on ARM CPU + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target()) + + # If VTA is not the target, default to _nn def + return _nn.schedule_conv2d(attrs, outs, target) + + +@reg.register_compute("nn.dense", level=15) +def compute_dense(attrs, inputs, out_type, target): + """Compute definition of dense""" + out_dtype = attrs.out_dtype + out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype + + if target.device_name == "vta": + if inputs[0].shape == 4: # this implies the layout is packed + target = tvm.target.create(target) + return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)] + # If it's not packed, run on ARM CPU + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.compute_dense(attrs, inputs, 
out_type, target) + + # If VTA is not the target, default to _nn def + return _nn.compute_dense(attrs, inputs, out_type, target) + + +@reg.register_schedule("nn.dense", level=15) +def schedule_dense(attrs, outs, target): + """Schedule definition of dense""" + if target.device_name == "vta": + if outs[0].shape == 4: # this implies the layout is packed + target = tvm.target.create(target) + assert target.device_name == "vta" + return topi.generic.schedule_dense(outs) + # If it's not packed, run on ARM CPU + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.schedule_dense(attrs, outs, tvm.target.current_target()) + + # If VTA is not the target, default to _nn def + return _nn.schedule_dense(attrs, outs, target) diff --git a/python/vta/top/vta_conv2d.py b/python/vta/top/vta_conv2d.py index ef4f2017381a..c455f535d93c 100644 --- a/python/vta/top/vta_conv2d.py +++ b/python/vta/top/vta_conv2d.py @@ -14,182 +14,49 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Namespace for supporting packed_conv2d + ewise variant of nnvm.""" -from __future__ import absolute_import as _abs +"""Conv2D operator declaration and schedule registration for VTA.""" -from collections import namedtuple - -import logging +import numpy as np import tvm +from tvm import autotvm import topi -from nnvm.top import registry as reg, OpPattern -from nnvm.top import nn as _nn from ..environment import get_env +def is_packed_layout(layout): + """Check if layout is packed layout""" + if layout == "NCHW": + return False + if "n" in layout and "c" in layout: + return True + return False -Workload = namedtuple("Conv2DWorkload", - ['batch', 'height', 'width', 'in_filter', 'out_filter', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) - -def find_schedules(layer, vt_only=False, best_only=False): - """ Returns a schedule for a given a layer. - - Parameters - ---------- - layer : Workload - Convolutional layer description. - vt_only : Boolean - Produce a schedule plan with virtual threading. - best_only : Boolean - Return the "best" schedule plan. - - Returns - ------- - fil_sched : list - List of valid schedules. 
- - """ - # pylint: disable=too-many-nested-blocks - env = get_env() - - # Helper function to get factors - def _find_factors(n): - factors = [] - for f in range(1, n + 1): - if n % f == 0: - factors.append(f) - return factors - - def _get_data_movement_byte(schedule, layer): - """ Estimate data movement in bytes for the schedule plan - """ - env = get_env() - b_f = schedule.b_factor - h_f = schedule.h_factor - w_f = schedule.w_factor - ci_f = schedule.ic_factor - co_f = schedule.oc_factor - # Derive data movement - inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH - wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH - out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH - input_tile_elems = b_f * \ - ((h_f - 1) * layer.hstride + layer.hkernel) * \ - ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f - weight_tile_elems = layer.hkernel * layer.wkernel * ci_f - output_tile_elems = b_f * h_f * w_f * co_f - # Derive tiling factors - b_factor = layer.batch // (b_f * env.BATCH) - h_factor = (layer.height // layer.hstride) // h_f - w_factor = (layer.width // layer.wstride) // w_f - ci_factor = layer.in_filter // (ci_f * env.BLOCK_IN) - co_factor = layer.out_filter // (co_f * env.BLOCK_OUT) - # Compute input transaction count - input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor - weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor - output_xfers = b_factor * h_factor * w_factor * co_factor - # Compute total transfer sizes - input_xfer_byte = input_tile_elems * input_xfers * inp_elem_sizeb // 8 - weight_xfer_byte = weight_tile_elems * weight_xfers * wgt_elem_sizeb // 8 - output_xfer_byte = output_tile_elems * output_xfers * out_elem_sizeb // 8 - total_xfer_byte = input_xfer_byte + weight_xfer_byte + output_xfer_byte - return total_xfer_byte - - # Scheduling exploration - batch_factors = _find_factors(layer.batch // env.BATCH) - height_factors = _find_factors(layer.height // layer.hstride) - width_factors = _find_factors(layer.width // layer.wstride) - cin_factors = _find_factors(layer.in_filter // env.BLOCK_IN) - cout_factors = _find_factors(layer.out_filter // env.BLOCK_OUT) - ht_factors = [1, 2] - cot_factors = [1, 2] - - # Explore schedules - schedules = [] - for b_f in batch_factors: - for h_f in height_factors: - for w_f in width_factors: - for ci_f in cin_factors: - for co_f in cout_factors: - # FIXME: 2D load pattern matching imposes restrictions on schedule - valid = (w_f == layer.width // layer.wstride) or \ - (w_f != layer.width // layer.wstride and co_f == 1) and \ - ci_f == 1 - if valid: - schedules.append([b_f, h_f, w_f, ci_f, co_f]) - - # Filter the schedules that wouldn't work in the available BRAM sizes - inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH - wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH - out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH - inp_brams_sizeb = env.INP_BUFF_SIZE * 8 - wgt_brams_sizeb = env.WGT_BUFF_SIZE * 8 - out_brams_sizeb = env.OUT_BUFF_SIZE * 8 - fil_sched = [] - xfer_size = [] - for sched in schedules: - b_f, h_f, w_f, ci_f, co_f = sched - for h_t in ht_factors: - for co_t in cot_factors: - # Make sure to filter cases where we apply threading on two axes - # or cases where the threading factors for h and co are not - # factors of h and co - if (h_t == 2 and co_t == 2) or (h_f % h_t != 0) or (co_f % co_t != 0): - continue - # Adjust tile sizes if threading is applied - h_f //= h_t - co_f //= co_t - # Derive tile sizes - input_tile_elems = b_f * \ - ((h_f - 
1) * layer.hstride + layer.hkernel) * \ - ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f - weight_tile_elems = layer.hkernel * layer.wkernel * ci_f * co_f - output_tile_elems = b_f * h_f * w_f * co_f - - # Derive valid schedule filter - valid = True - # If in vitrual-threaded mode, only allow for threaded plans - valid &= (vt_only and (h_t == 2 or co_t == 2)) or not vt_only - # Check that we don't exceed input/weight/output capacity - valid &= input_tile_elems * inp_elem_sizeb <= inp_brams_sizeb // (co_t * h_t) - valid &= weight_tile_elems * wgt_elem_sizeb <= wgt_brams_sizeb - valid &= output_tile_elems * out_elem_sizeb <= out_brams_sizeb // (co_t * h_t) - # Make sure that we don't write to the same acc location within 2 consecutive cycles - valid &= h_f > 2 and w_f > 2 - # TODO: check that we don't exceed instruction or micro-op count - - if valid: - schedule = Schedule(b_factor=b_f, oc_factor=co_f, ic_factor=ci_f, h_factor=h_f, - w_factor=w_f, oc_nthread=co_t, h_nthread=h_t) - fil_sched.append(schedule) - xfer_size.append(_get_data_movement_byte(schedule, layer)) - - if best_only and xfer_size: - return [fil_sched[xfer_size.index(min(xfer_size))]] - return fil_sched +@autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct') +def _declaration_conv2d(cfg, + data, + kernel, + strides, + padding, + dilation, + layout, + out_dtype): + """ Packed conv2d function.""" + if not is_packed_layout(layout): + raise topi.InvalidShapeError() + assert dilation == (1, 1) -def packed_conv2d(data, - kernel, - padding, - strides, - out_dtype="int32"): - """ Packed conv2d function. - """ if padding[0]: pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data") else: pad_data = data assert len(data.shape) == 6 assert len(kernel.shape) == 6 - oheight = topi.util.simplify((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1) - owidth = topi.util.simplify((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1) + oheight = topi.util.get_const_int((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1) + owidth = topi.util.get_const_int((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1) oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4]) ishape = topi.util.get_const_tuple(data.shape) kshape = topi.util.get_const_tuple(kernel.shape) - assert data.dtype == "int8", data.dtype - assert kernel.dtype == "int8", kernel.dtype d_i = tvm.reduce_axis((0, kshape[2]), name='d_i') d_j = tvm.reduce_axis((0, kshape[3]), name='d_j') k_o = tvm.reduce_axis((0, ishape[1]), name='k_o') @@ -201,167 +68,55 @@ def packed_conv2d(data, pad_data[b_o, k_o, i*hstride+d_i, j*wstride+d_j, b_i, k_i].astype(out_dtype) * kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype), axis=[k_o, d_i, d_j, k_i]), - name="res", tag="packed_conv2d") - return res - -@tvm.register_func("nnvm.compiler.build_target", override=True) -def _build(funcs, target, target_host): - tvm_t = tvm.target.create(target) - if tvm_t.device_name == "vta": - return tvm.build(funcs, target="ext_dev", target_host=target_host) - if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu": - return tvm.build(funcs, target=target_host) - return tvm.build(funcs, target=target) - - -@tvm.register_func("nnvm.compiler.lower", override=True) -def _lower(sch, inputs, func_name, graph): - import traceback - # pylint: disable=broad-except - try: - f = tvm.lower(sch, inputs, name=func_name) - if "quantized_conv2d" in func_name: - logging.info(graph.ir(join_entry_attrs=["shape"])) - except Exception: - msg 
= traceback.format_exc() - msg += "Error during compile graph\n" - msg += "--------------------------\n" - msg += graph.ir(join_entry_attrs=["shape"]) - raise RuntimeError(msg) - return f if isinstance( - f, (tvm.container.Array, tuple, list)) else [f] - - -@reg.register_compute("clip", level=15) -def compute_clip(attrs, inputs, _): - """ Clip operator. - """ - x = inputs[0] - a_min = attrs.get_float("a_min") - a_max = attrs.get_float("a_max") - const_min = tvm.const(a_min, x.dtype) - const_max = tvm.const(a_max, x.dtype) - with tvm.tag_scope(topi.tag.ELEMWISE): - x = tvm.compute( - x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") - x = tvm.compute( - x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") - return x - -# override to force partition at copy -reg.register_pattern("copy", OpPattern.INJECTIVE, level=15) - -def is_packed_layout(layout): - """Check if layout is packed layout""" - if layout == "NCHW": - return False - if "n" in layout and "c" in layout: - return True - return False - -@reg.register_alter_op_layout("conv2d", level=15) -def alter_conv2d_layout(attrs, inputs, out): - layout = attrs['layout'] - if is_packed_layout(layout): - return None - return _nn.alter_conv2d_layout(attrs, inputs, out) - - -@reg.register_compute("conv2d", level=15) -def compute_conv2d(attrs, inputs, out): - """ 2D convolution algorithm. - """ - padding = attrs.get_int_tuple("padding") - strides = attrs.get_int_tuple("strides") - dilation = attrs.get_int_tuple("dilation") - groups = attrs.get_int("groups") - layout = attrs["layout"] - out_dtype = attrs['out_dtype'] - assert dilation == (1, 1), "not support dilate now" - if is_packed_layout(layout): - assert groups == 1 - return packed_conv2d(inputs[0], inputs[1], - padding, strides, out_dtype=out_dtype) - return _nn.compute_conv2d(attrs, inputs, out) - - -@reg.register_schedule("conv2d", level=15) -def schedule_conv2d(attrs, outs, target): - """ 2D convolution schedule. - """ - layout = attrs["layout"] - - if is_packed_layout(layout): - target = tvm.target.create(target) - if target.device_name == "vta": - return schedule_packed_conv2d(outs) - if str(target).startswith("llvm"): - return tvm.create_schedule([x.op for x in outs]) - raise RuntimeError("not support target %s" % target) - return _nn.schedule_conv2d(attrs, outs, target) + name="res", tag="conv2d_dense") + cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) * + kshape[2] * kshape[3] * ishape[1] * ishape[-1]) -def _get_workload(data, pad_data, kernel, output): - """ Get the workload structure. 
- """ - o_shape = topi.util.get_const_tuple(output.shape) - d_shape = topi.util.get_const_tuple(data.shape) - k_shape = topi.util.get_const_tuple(kernel.shape) - o_b, o_c, o_h, o_w, ob_blk, o_blk = o_shape - i_b, i_c, i_h, i_w, ib_blk, i_blk = d_shape - k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape - # For now we need to assume that input channel blocking is the same - # as the output channel blocking - assert o_blk == i_blk - assert ob_blk == ib_blk - # Make sure that dimensions match - assert o_b == i_b - assert o_blk == ko_blk - assert i_blk == ki_blk - assert k_o == o_c - assert k_i == i_c - # Scale the channel size - i_c *= i_blk - o_c *= o_blk - if pad_data is not None: - p_shape = topi.util.get_const_tuple(pad_data.shape) - h_pad = (p_shape[2] - d_shape[2]) // 2 - w_pad = (p_shape[3] - d_shape[3]) // 2 - else: - h_pad, w_pad = 0, 0 - h_str = (i_h + h_pad*2 - k_h) // (o_h - 1) - w_str = (i_w + w_pad*2 - k_w) // (o_w - 1) - return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str) - -_WL2PLAN = {} + return res -def schedule_packed_conv2d(outs): - """ Schedule the packed conv2d. - """ +@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_nchw, 'vta', 'direct') +def _schedule_conv2d(cfg, outs): assert len(outs) == 1 output = outs[0] + const_ops = [] ewise_inputs = [] ewise_ops = [] conv2d_res = [] - assert output.dtype == "int8" - assert output.op.input_tensors[0].dtype == "int32" + assert "int" in output.op.input_tensors[0].dtype def _traverse(op): if topi.tag.is_broadcast(op.tag): if not op.same_as(output.op): - ewise_ops.append(op) + if not op.axis: + const_ops.append(op) + else: + ewise_ops.append(op) for tensor in op.input_tensors: if isinstance(tensor.op, tvm.tensor.PlaceholderOp): ewise_inputs.append((op, tensor)) else: _traverse(tensor.op) else: - assert op.tag == "packed_conv2d" + assert op.tag == "conv2d_dense" conv2d_res.append(op) _traverse(output.op) assert len(conv2d_res) == 1 conv2d_stage = conv2d_res[0].output(0) + s = tvm.create_schedule(output.op) + + ##### space definition begin ##### + b, c_o, x_i, x_j, _, _ = s[conv2d_stage].op.axis + c_i, _, _, _ = s[conv2d_stage].op.reduce_axis + cfg.define_split('tile_b', b, num_outputs=2) + cfg.define_split('tile_h', x_i, num_outputs=2) + cfg.define_split('tile_w', x_j, num_outputs=2) + cfg.define_split('tile_ci', c_i, num_outputs=2) + cfg.define_split('tile_co', c_o, num_outputs=2) + cfg.define_knob('oc_nthread', [1, 2]) + cfg.define_knob('h_nthread', [1, 2]) + ###### space definition end ###### data, kernel = conv2d_stage.op.input_tensors if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: @@ -370,21 +125,8 @@ def _traverse(op): data = temp else: pad_data = None - wrkld = _get_workload(data, pad_data, kernel, output) - if wrkld in _WL2PLAN: - plan = _WL2PLAN[wrkld] - else: - plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] - logging.info("Trying to find plan for %s", wrkld) - env = get_env() - - load_inp = load_wgt = load_out = store_out = env.dma_copy - alu = env.alu - gemm = env.gemm - # schedule1 - oshape = topi.util.get_const_tuple(output.shape) - s = tvm.create_schedule(output.op) + env = get_env() # setup pad if pad_data is not None: @@ -394,27 +136,26 @@ def _traverse(op): cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) s[conv2d_stage].set_scope(env.acc_scope) + # cache read input cache_read_ewise = [] - for consumer, tensor in ewise_inputs: cache_read_ewise.append( s.cache_read(tensor, 
env.acc_scope, [consumer])) + # set ewise scope for op in ewise_ops: s[op].set_scope(env.acc_scope) - s[op].pragma(s[op].op.axis[0], alu) + s[op].pragma(s[op].op.axis[0], env.alu) - # tile - oc_factor = (plan.oc_factor if plan.oc_factor - else plan.out_filter // env.BLOCK_OUT) - h_factor = (plan.h_factor if plan.h_factor else oshape[2]) - w_factor = (plan.w_factor if plan.w_factor else oshape[3]) + for op in const_ops: + s[op].compute_inline() + # tile x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis - x_co0, x_co1 = s[output].split(x_co, factor=oc_factor) - x_i0, x_i1 = s[output].split(x_i, factor=h_factor) - x_j0, x_j1 = s[output].split(x_j, factor=w_factor) + x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co) + x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i) + x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j) s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) store_pt = x_j0 @@ -425,17 +166,17 @@ def _traverse(op): for tensor in cache_read_ewise: s[tensor].compute_at(s[output], store_pt) - s[tensor].pragma(s[tensor].op.axis[0], load_out) + s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy) # virtual threading along output channel axes - if plan.oc_nthread > 1: - _, v_t = s[output].split(x_co0, factor=plan.oc_nthread) + if cfg['oc_nthread'].val > 1: + _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val) s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) # virtual threading along spatial rows - if plan.h_nthread > 1: - _, v_t = s[output].split(x_i0, factor=plan.h_nthread) + if cfg['h_nthread'].val > 1: + _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val) s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) @@ -443,68 +184,14 @@ def _traverse(op): k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) - if plan.ic_factor: - k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor) - s[cdata].compute_at(s[conv2d_stage], k_o) - s[ckernel].compute_at(s[conv2d_stage], k_o) + k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) # Use VTA instructions - s[cdata].pragma(s[cdata].op.axis[0], load_inp) - s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt) - s[conv2d_stage].tensorize(x_bi, gemm) - s[output].pragma(x_co1, store_out) - return s - -class Conv2DSchedule(object): - """ 2D convolution schedule object. 
- """ - def __init__(self, - b_factor=1, - oc_factor=1, - ic_factor=1, - h_factor=1, - w_factor=0, - oc_nthread=0, - h_nthread=0, - debug_sync=False): - self.b_factor = b_factor - self.oc_factor = oc_factor - self.ic_factor = ic_factor - self.h_factor = h_factor - self.w_factor = w_factor - self.oc_nthread = oc_nthread - self.h_nthread = h_nthread - self.debug_sync = debug_sync - - def __str__(self): - return "{}.{}.{}.{}.{}.{}.{}".format( - self.b_factor, self.oc_factor, self.ic_factor, - self.h_factor, self.w_factor, - self.oc_nthread, self.h_nthread) + s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy) + s[ckernel].pragma(s[ckernel].op.axis[0], env.dma_copy) + s[conv2d_stage].tensorize(x_bi, env.gemm) + s[output].pragma(x_co1, env.dma_copy) -Schedule = Conv2DSchedule - -# Layer description of the ResNet18 -RESNET = { - 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2), - 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), -} - -for idx in RESNET: - f_schedules = find_schedules(RESNET[idx], vt_only=True, best_only=True) - if f_schedules: - scheds = f_schedules[0] - _WL2PLAN[RESNET[idx]] = scheds - else: - logging.warning("No valid schedule was found for the workload on current vta configuration") - break + return s diff --git a/python/vta/top/vta_dense.py b/python/vta/top/vta_dense.py new file mode 100644 index 000000000000..9d6c19c5af20 --- /dev/null +++ b/python/vta/top/vta_dense.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-argument +"""Dense operator declaration and schedule registration for VTA.""" + +import numpy as np +import tvm +from tvm import autotvm +import topi + +from ..environment import get_env + +def is_packed_layout(layout): + """Check if layout is packed layout""" + if layout == "NCHW": + return False + if "n" in layout and "c" in layout: + return True + return False + +@autotvm.register_topi_compute(topi.nn.dense, 'vta', 'direct') +def _declaration_dense(cfg, + data, + weight, + bias=None, + out_dtype=None): + """Dense function declaration.""" + + # Make sure that the dense operator is packed + if len(data.shape) != 4 or len(weight.shape) != 4: + raise topi.InvalidShapeError() + + # Derive shapes + ishape = topi.util.get_const_tuple(data.shape) + wshape = topi.util.get_const_tuple(weight.shape) + oshape = (data.shape[0], weight.shape[0], data.shape[2], weight.shape[2]) + + # Reduction axes (input channel) + assert ishape[1] == wshape[1] + assert ishape[3] == wshape[3] + k_o = tvm.reduce_axis((0, ishape[1]), name='k_o') + k_i = tvm.reduce_axis((0, ishape[3]), name='k_i') + res = tvm.compute( + oshape, + lambda b_o, c_o, b_i, c_i: tvm.sum( + data[b_o, k_o, b_i, k_i].astype(out_dtype) * + weight[c_o, k_o, c_i, k_i].astype(out_dtype), + axis=[k_o, k_i]), + name="res", tag="dense_pack") + + cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) * + ishape[1] * ishape[3]) + + return res + +@autotvm.register_topi_schedule(topi.generic.schedule_dense, 'vta', 'direct') +def _schedule_dense(cfg, outs): + """Packed dense schedule.""" + + assert len(outs) == 1 + output = outs[0] + const_ops = [] + ewise_inputs = [] + ewise_ops = [] + dense_res = [] + assert "int" in output.op.input_tensors[0].dtype + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + if not op.axis: + const_ops.append(op) + else: + ewise_ops.append(op) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.PlaceholderOp): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + assert op.tag == "dense_pack" + dense_res.append(op) + + _traverse(output.op) + assert len(dense_res) == 1 + dense_stage = dense_res[0].output(0) + s = tvm.create_schedule(output.op) + + ##### space definition begin ##### + b, c_o, _, _ = s[dense_stage].op.axis + c_i, _ = s[dense_stage].op.reduce_axis + cfg.define_split('tile_b', b, num_outputs=2) + cfg.define_split('tile_ci', c_i, num_outputs=2) + cfg.define_split('tile_co', c_o, num_outputs=2) + cfg.define_knob('oc_nthread', [1, 2]) + ###### space definition end ###### + + data, weight = dense_stage.op.input_tensors + + env = get_env() + + cdata = s.cache_read(data, env.inp_scope, [dense_stage]) + cweight = s.cache_read(weight, env.wgt_scope, [dense_stage]) + s[dense_stage].set_scope(env.acc_scope) + + # cache read input + cache_read_ewise = [] + for consumer, tensor in ewise_inputs: + cache_read_ewise.append( + s.cache_read(tensor, env.acc_scope, [consumer])) + + # set ewise scope + for op in ewise_ops: + s[op].set_scope(env.acc_scope) + s[op].pragma(s[op].op.axis[0], env.alu) + + for op in const_ops: + s[op].compute_inline() + + # apply tiling for SRAM reuse + x_b, x_c, _, _ = s[output].op.axis + x_bo, x_bi = cfg['tile_b'].apply(s, output, x_b) + x_co, x_ci = cfg['tile_co'].apply(s, output, x_c) + s[output].reorder(x_bo, x_co, x_bi, x_ci) + store_pt = x_co + + # set all compute scopes + s[dense_stage].compute_at(s[output], store_pt) + for op in ewise_ops: + s[op].compute_at(s[output], store_pt) + + for 
tensor in cache_read_ewise: + s[tensor].compute_at(s[output], store_pt) + s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy) + + # virtual threading along output channel axes + if cfg['oc_nthread'].val > 1: + _, v_t = s[output].split(x_co, factor=cfg['oc_nthread'].val) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + x_bo, x_co, x_bi, _ = s[dense_stage].op.axis + k_o, _ = s[dense_stage].op.reduce_axis + s[dense_stage].reorder(x_bo, k_o, x_co) + + k_o, _ = cfg['tile_ci'].apply(s, dense_stage, k_o) + s[cdata].compute_at(s[dense_stage], k_o) + s[cweight].compute_at(s[dense_stage], k_o) + + # Use VTA instructions + s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy) + s[cweight].pragma(s[cweight].op.axis[0], env.dma_copy) + s[dense_stage].tensorize(x_bi, env.gemm) + s[output].pragma(x_ci, env.dma_copy) + + return s diff --git a/scripts/tune_conv2d.py b/scripts/tune_conv2d.py new file mode 100644 index 000000000000..f55c7e985716 --- /dev/null +++ b/scripts/tune_conv2d.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
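The packed dense declaration above expects both operands in VTA's blocked layout rather than plain (batch, in_feat) and (out_feat, in_feat) matrices. A small numpy sketch of that repacking, assuming the default BATCH=1 / BLOCK_IN=16 configuration (the same reshape/transpose appears in the dense integration test later in this patch):

import numpy as np

BATCH, BLOCK_IN = 1, 16                      # assumption: default vta_config values
batch_size, in_feat = 16, 512                # matches the resnet-18.dense workload below

a = np.random.randint(-128, 128, size=(batch_size, in_feat)).astype("int8")
a_packed = a.reshape(batch_size // BATCH, BATCH,
                     in_feat // BLOCK_IN, BLOCK_IN).transpose((0, 2, 1, 3))
print(a.shape, "->", a_packed.shape)         # (16, 512) -> (16, 32, 1, 16)

The weight matrix is packed the same way into (out_feat//BLOCK_OUT, in_feat//BLOCK_IN, BLOCK_OUT, BLOCK_IN), which is why _declaration_dense asserts that the channel block dimensions of data and weight agree.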
+ +"""Tuning a single conv2d operator""" + +from collections import namedtuple +import logging +import os + +import tvm +from tvm import autotvm +from tvm.contrib.util import get_lower_ir +import topi +import vta +import vta.testing + +env = vta.get_env() + +Workload = namedtuple("Conv2DWorkload", + ['batch', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +resnet_wkls = [ + # Workloads of resnet18 on imagenet + # ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)), + ('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)), + # ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet + ('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)), + ('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)), + ('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)), +] + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype): + data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN) + kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN) + bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) + + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + + with tvm.target.vta(): + res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides, dilation=dilation, + layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), out_dtype='int32') + res = topi.add(res, bias) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = tvm.create_schedule([res.op]) + + return s, [data, kernel, bias, res] + +if __name__ == '__main__': + + # Logging config (for printing tuning log to the screen) + logging.basicConfig() + logging.getLogger('autotvm').setLevel(logging.DEBUG) + + # Get tracker info from env + tracket_host = os.environ.get("TVM_TRACKER_HOST", None) + tracket_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + if not tracket_host or not tracket_port: + print("Set your AutoTVM tracker node host and port variables to run the autotuner") + exit() + + for wl_name, wl in resnet_wkls: + + # Workload parameters + N = wl.batch + CI = wl.in_filter + H = wl.height + W = wl.width + CO = wl.out_filter + KH = wl.hkernel + KW = wl.wkernel + strides = (wl.hstride, wl.wstride) + padding = (wl.hpad, wl.wpad) + dilation = (1, 1) + in_dtype = 'int8' + 
out_dtype = 'int32' + + task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype), + target=tvm.target.vta(), target_host=env.target_host, template_key='direct') + print(task.config_space) + + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner(env.TARGET, tracket_host, tracket_port, number=4, repeat=3, timeout=10000, + check_correctness=True)) + + tuner = autotvm.tuner.RandomTuner(task) + tuner.tune(n_trial=len(task.config_space), + measure_option=measure_option, + callbacks=[autotvm.callback.log_to_file('conv2d.log')]) + + print("\nBest tuner config:") + print(tuner.best_config) diff --git a/scripts/tune_dense.py b/scripts/tune_dense.py new file mode 100644 index 000000000000..237ca2754512 --- /dev/null +++ b/scripts/tune_dense.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tuning a single dense operator""" + +from collections import namedtuple +import logging +import os + +import tvm +from tvm import autotvm +from tvm.contrib.util import get_lower_ir +import topi +import vta +import vta.testing + +env = vta.get_env() + +Workload = namedtuple("DenseWorkload", + ['batch', 'in_filter', 'out_filter']) + +resnet_wkls = [ + # Workloads of resnet18 on imagenet + ('resnet-18.dense', Workload(16, 512, 1024)), +] + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +def dense(N, CI, CO): + data_shape = (N//env.BATCH, CI//env.BLOCK_IN, env.BATCH, env.BLOCK_IN) + kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, env.BLOCK_OUT, env.BLOCK_IN) + + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + + with tvm.target.vta(): + res = topi.nn.dense(data, kernel, None, 'int32') + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_dense([res]) + else: + s = tvm.create_schedule([res.op]) + + return s, [data, kernel, res] + +if __name__ == '__main__': + + # Logging config (for printing tuning log to the screen) + logging.basicConfig() + logging.getLogger('autotvm').setLevel(logging.DEBUG) + + # Get tracker info from env + tracket_host = os.environ.get("TVM_TRACKER_HOST", None) + tracket_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + if not 
tracket_host or not tracket_port: + print("Set your AutoTVM tracker node host and port variables to run the autotuner") + exit() + + for wl_name, wl in resnet_wkls: + + # Workload parameters + N = wl.batch + CI = wl.in_filter + CO = wl.out_filter + + task = autotvm.task.create(dense, args=(N, CI, CO), + target=tvm.target.vta(), target_host=env.target_host, template_key='direct') + print(task.config_space) + + measure_option = autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner(env.TARGET, tracket_host, tracket_port, number=4, repeat=3, timeout=10000, + check_correctness=True)) + + tuner = autotvm.tuner.RandomTuner(task) + tuner.tune(n_trial=len(task.config_space), + measure_option=measure_option, + callbacks=[autotvm.callback.log_to_file('dense.log')]) + + print("\nBest tuner config:") + print(tuner.best_config) diff --git a/scripts/tune_resnet.py b/scripts/tune_resnet.py new file mode 100644 index 000000000000..21aa96cd350f --- /dev/null +++ b/scripts/tune_resnet.py @@ -0,0 +1,310 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
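Both tuning scripts above read TVM_TRACKER_HOST / TVM_TRACKER_PORT and hand them to autotvm.RPCRunner under the device key env.TARGET. A quick, hedged way to sanity-check that setup before launching a long tuning run (assumes the two environment variables are already exported and an RPC tracker is running):

import os
from tvm import rpc

host = os.environ["TVM_TRACKER_HOST"]
port = int(os.environ["TVM_TRACKER_PORT"])
tracker = rpc.connect_tracker(host, port)
# Should list at least one free device registered under the key used as
# env.TARGET (e.g. "pynq" or "sim"); otherwise RPCRunner will wait forever.
print(tracker.text_summary())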
+ +"""Perform ResNet autoTVM tuning on VTA using Relay.""" + +import argparse, os, time +from mxnet.gluon.model_zoo import vision +import numpy as np +from PIL import Image + +import topi +import tvm +from tvm import rpc, autotvm, relay +from tvm.autotvm.measure.measure_methods import request_remote +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_runtime, util, download +from tvm.contrib.debugger import debug_runtime +import vta +from vta.testing import simulator +from vta.top import graph_pack +from tvm.autotvm.task import extract_from_program + +def parse_arguments(): + + parser = argparse.ArgumentParser(description='Train a model for image classification.') + parser.add_argument('--model', type=str, default='resnet18_v1', choices=['resnet18_v1'], + help='Input model name.') + parser.add_argument('--start-name', type=str, default='nn.max_pool2d', + help='The name of the node where packing starts') + parser.add_argument('--stop-name', type=str, default='nn.global_avg_pool2d', + help='The name of the node where packing stops') + parser.add_argument('--debug-profile', action='store_true', + help='Show layer-wise time cost profiling results') + parser.add_argument('--device', default='vta', choices=['vta', 'arm_cpu'], + help='Select device target') + parser.add_argument('--measurements', type=int, default=1, + help='Number of measurements during AutoTVM search') + parser.add_argument('--tuner', type=str, default="random", + help='AutoTVM search strategy') + parser.add_argument('--log-filename', type=str, default="resnet-18.log", + help='AutoTVM log file name') + + return parser.parse_args() + + +def register_vta_tuning_tasks(): + from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args + + @tvm.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + # init autotvm env to register VTA operator + TaskExtractEnv() + + @autotvm.task.register("topi_nn_conv2d", override=True) + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.conv2d(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = tvm.create_schedule([res.op]) + return s, [A, W, res] + + @autotvm.task.register("topi_nn_dense", override=True) + def _topi_nn_dense(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.dense(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_dense([res]) + else: + s = tvm.create_schedule([res.op]) + + return s, [A, W, res] + + +def compile_network(opt, env, target): + + # Populate the shape and data type dictionary + dtype_dict = {"data": 'float32'} + shape_dict = {"data": (env.BATCH, 3, 224, 224)} + + # Get off the 
shelf gluon model, and convert to relay + gluon_model = vision.get_model(opt.model, pretrained=True) + mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) + + # Update shape and type dictionary + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Perform quantization in Relay + with relay.quantize.qconfig(global_scale=8.0, + skip_conv_layers=[0]): + relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params) + + # Perform graph packing and constant folding for VTA target + if target.device_name == "vta": + assert env.BLOCK_IN == env.BLOCK_OUT + relay_prog = graph_pack( + relay_prog, + env.BATCH, + env.BLOCK_OUT, + env.WGT_WIDTH, + start_name=opt.start_name, + stop_name=opt.stop_name) + relay_prog = relay.ir_pass.fold_constant(relay_prog) + + return relay_prog, params + + +def tune_tasks(tasks, + measure_option, + tuner='xgb', + n_trial=1000, + early_stopping=None, + log_filename='tuning.log', + use_transfer_learning=True, + try_winograd=True): + + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) + + # create tuner + if tuner == 'xgb' or tuner == 'xgb-rank': + tuner_obj = XGBTuner(tsk, loss_type='rank') + elif tuner == 'ga': + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == 'random': + tuner_obj = RandomTuner(tsk) + elif tuner == 'gridsearch': + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + n_trial_ = min(n_trial, len(tsk.config_space)) + tuner_obj.tune(n_trial_, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(n_trial_, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file)]) + + # pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + +if __name__ == '__main__': + + opt = parse_arguments() + + # Make sure that TVM was compiled with RPC=1 + assert tvm.module.enabled("rpc") + + # Read in VTA environment + env = vta.get_env() + + # Get remote from fleet node + tracker_host = os.environ.get("TVM_TRACKER_HOST", None) + tracker_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + if not tracker_host or not tracker_port: + print("Set your AutoTVM tracker node host and port variables to run the autotuner") + exit() + + # Get remote + if env.TARGET != "sim": + + # Measure build start time + reconfig_start = time.time() + + # Get remote from fleet node + remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000) + + # Reconfigure the JIT runtime and FPGA. + # You can program the FPGA with your own custom bitstream + # by passing the path to the bitstream file instead of None. + vta.reconfig_runtime(remote) + vta.program_fpga(remote, bitstream=None) + + # Report on reconfiguration time + reconfig_time = time.time() - reconfig_start + print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time)) + + # In simulation mode, host the RPC server locally. 
+ else: + remote = rpc.LocalSession() + + # VTA target and execution context + target = env.target if opt.device == "vta" else env.target_vta_cpu + ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0) + + # Compile Relay program + print("Initial compile...") + relay_prog, params = compile_network(opt, env, target) + + # Register VTA tuning tasks + register_vta_tuning_tasks() + + # Perform task extraction on Relay program + print("Extracting tasks...") + tasks = extract_from_program(func=relay_prog, + params=params, + ops=(tvm.relay.op.nn.conv2d,), + target=target, + target_host=env.target_host) + + # Perform Autotuning + print("Tuning...") + tuning_opt = { + 'log_filename': opt.log_filename, + 'tuner': opt.tuner, + 'n_trial': 1e9, + 'early_stopping': None, + 'measure_option': autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port, + number=4, min_repeat_ms=150, repeat=opt.measurements, timeout=60, + check_correctness=True)) + } + tune_tasks(tasks, **tuning_opt) + + # Compile kernels with history best records + with autotvm.tophub.context(target, extra_files=[opt.log_filename]): + + # Compile network + print("Compiling network with best tuning parameters...") + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + if target.device_name != "vta": + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + else: + with vta.build_config(): + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + + # Export library + temp = util.tempdir() + lib.save(temp.relpath("graphlib.o")) + remote.upload(temp.relpath("graphlib.o")) + lib = remote.load_module("graphlib.o") + + # If detailed runtime info is needed build with debug runtime + if opt.debug_profile: + m = debug_runtime.create(graph, lib, ctx) + else: + m = graph_runtime.create(graph, lib, ctx) + + # Set the network parameters and synthetic input + image = tvm.nd.array( + (np.random.uniform(size=(1, 3, 224, 224))).astype('float32')) + m.set_input(**params) + m.set_input('data', image) + + # Perform inference + timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements) + tcost = timer() + prof_res = np.array(tcost.results) * 1000 # convert to millisecond + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) + + # Display profile information + if opt.debug_profile: + m.run() diff --git a/scripts/tune_resnet_nnvm.py b/scripts/tune_resnet_nnvm.py new file mode 100644 index 000000000000..22a4dd5dfc78 --- /dev/null +++ b/scripts/tune_resnet_nnvm.py @@ -0,0 +1,256 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Perform ResNet autoTVM tuning on VTA using NNVM.""" + +import argparse +import os +import time +import numpy as np + +import tvm +from tvm import rpc, autotvm +from tvm.autotvm.measure.measure_methods import request_remote +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_runtime, util +from tvm.contrib.download import download + +import topi +import nnvm.compiler +import vta +import vta.testing + +env = vta.get_env() + +def register_vta_tuning_tasks(): + from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args + + @tvm.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + # init autotvm env to register VTA operator + TaskExtractEnv() + + @autotvm.task.register("topi_nn_conv2d", override=True) + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.conv2d(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = tvm.create_schedule([res.op]) + return s, [A, W, res] + + + +def generate_graph(sym, params, target, target_host): + # Populate the shape and data type dictionary + shape_dict = {"data": (1, 3, 224, 224)} + dtype_dict = {"data": 'float32'} + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Apply NNVM graph optimization passes + sym = vta.graph.clean_cast(sym) + sym = vta.graph.clean_conv_fuse(sym) + assert env.BLOCK_IN == env.BLOCK_OUT + sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT) + + # Compile NNVM graph + with nnvm.compiler.build_config(opt_level=3): + with vta.build_config(): + graph, lib, params = nnvm.compiler.build( + sym, target, shape_dict, dtype_dict, + params=params, target_host=target_host) + + return graph, lib, params + + +def extract_tasks(sym, params, target, target_host): + # Populate the shape and data type dictionary + shape_dict = {"data": (1, 3, 224, 224)} + dtype_dict = {"data": 'float32'} + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Apply NNVM graph optimization passes + sym = vta.graph.clean_cast(sym) + sym = vta.graph.clean_conv_fuse(sym) + assert env.BLOCK_IN == env.BLOCK_OUT + sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT) + + with vta.build_config(): + tasks = autotvm.task.extract_from_graph(graph=sym, shape=shape_dict, dtype=dtype_dict, target=target, + params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host) + return tasks + + +def download_model(): + url = "https://github.com/uwsaml/web-data/raw/master/vta/models/" + categ_fn = 'synset.txt' + graph_fn = 'resnet18_qt8.json' + params_fn = 'resnet18_qt8.params' + data_dir = '_data' + if not os.path.exists(data_dir): + os.makedirs(data_dir) + + for file in 
[categ_fn, graph_fn, params_fn]: + if not os.path.isfile(file): + download(os.path.join(url, file), os.path.join(data_dir, file)) + + sym = nnvm.graph.load_json(open(os.path.join(data_dir, graph_fn)).read()) + params = nnvm.compiler.load_param_dict(open(os.path.join(data_dir, params_fn), 'rb').read()) + + return sym, params + + +def tune_tasks(tasks, + measure_option, + tuner='xgb', + n_trial=1000, + early_stopping=None, + log_filename='tuning.log', + use_transfer_learning=True, + try_winograd=True): + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) + + # create tuner + if tuner == 'xgb' or tuner == 'xgb-rank': + tuner_obj = XGBTuner(tsk, loss_type='rank') + elif tuner == 'ga': + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == 'random': + tuner_obj = RandomTuner(tsk) + elif tuner == 'gridsearch': + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + n_trial_ = min(n_trial, len(tsk.config_space)) + tuner_obj.tune(n_trial_, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(n_trial_, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file)]) + + # pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + +if __name__ == '__main__': + + # Get tracker info from env + tracker_host = os.environ.get("TVM_TRACKER_HOST", None) + tracker_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + if not tracker_host or not tracker_port: + print("Set your AutoTVM tracker node host and port variables to run the autotuner") + exit() + + # Download model + sym, params = download_model() + + # Register VTA tuning tasks + register_vta_tuning_tasks() + + # Extract tasks + print("Extracting tasks...") + target = tvm.target.vta() + target_host = env.target_host + tasks = extract_tasks(sym, params, target, target_host) + + # Perform Autotuning + print("Tuning...") + tuning_opt = { + 'log_filename': 'resnet-18.log', + + 'tuner': 'random', + 'n_trial': 1e9, + 'early_stopping': None, + + 'measure_option': autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port, + number=4, repeat=3, timeout=60, + check_correctness=True)) + } + tune_tasks(tasks, **tuning_opt) + + # compile kernels with history best records + with autotvm.tophub.context(target, extra_files=[tuning_opt['log_filename']]): + + # ResNet parameters + input_shape = (1, 3, 224, 224) + dtype = 'float32'\ + + # Compile network + print("Compiling network with best tuning parameters...") + graph, lib, params = generate_graph(sym, params, target, target_host) + input_shape = (1, 3, 224, 224) + dtype = 'float32' + + # Export library + tmp = util.tempdir() + filename = "net.tar" + lib.export_library(tmp.relpath(filename)) + + # Upload module to device + print("Upload...") + remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000) + remote.upload(tmp.relpath(filename)) + rlib = remote.load_module(filename) + + # Upload parameters to device + ctx = remote.context(str(target), 0) + rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()} + 
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype)) + module = graph_runtime.create(graph, rlib, ctx) + module.set_input('data', data_tvm) + module.set_input(**rparams) + + # Evaluate + print("Evaluate inference time cost...") + ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=3) + prof_res = np.array(ftimer().results) * 1000 # convert to millisecond + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) + diff --git a/src/runtime.cc b/src/runtime.cc index 06b34743955f..f44e3cab8a82 100644 --- a/src/runtime.cc +++ b/src/runtime.cc @@ -908,12 +908,10 @@ class CommandQueue { insn_queue_.InitSpace(); device_ = VTADeviceAlloc(); CHECK(device_ != nullptr); - printf("Initialize VTACommandHandle...\n"); } ~CommandQueue() { VTADeviceFree(device_); - printf("Close VTACommandhandle...\n"); } uint32_t GetElemBytes(uint32_t memory_id) { diff --git a/src/sim/sim_driver.cc b/src/sim/sim_driver.cc index 5f9f6b637599..0691195f140e 100644 --- a/src/sim/sim_driver.cc +++ b/src/sim/sim_driver.cc @@ -35,6 +35,11 @@ namespace vta { namespace sim { +/*! \brief debug flag for skipping computation */ +enum DebugFlagMask { + kSkipExec = 1 +}; + /*! * \brief Helper class to pack and unpack bits * Applies truncation when pack to low level bits. @@ -253,8 +258,12 @@ class SRAM { return &(data_[index]); } // Execute the load instruction on this SRAM - void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) { + void Load(const VTAMemInsn* op, + DRAM* dram, + uint64_t* load_counter, + bool skip_exec) { load_counter[0] += (op->x_size * op->y_size) * kElemBytes; + if (skip_exec) return; DType* sram_ptr = data_ + op->sram_base; uint8_t* dram_ptr = static_cast(dram->GetAddr( op->dram_base * kElemBytes)); @@ -325,6 +334,8 @@ class Profiler { uint64_t gemm_counter{0}; /*! \brief instr counter for ALU ops */ uint64_t alu_counter{0}; + /*! \brief set debug mode */ + int64_t debug_flag{0}; /*! \brief clear the profiler */ void Clear() { inp_load_nbytes = 0; @@ -335,6 +346,10 @@ class Profiler { gemm_counter = 0; alu_counter = 0; } + /*! \return Whether we should skip execution. */ + bool SkipExec() const { + return (debug_flag & DebugFlagMask::kSkipExec) != 0; + } std::string AsJSON() { std::ostringstream os; @@ -398,13 +413,15 @@ class Device { void RunLoad(const VTAMemInsn* op) { if (op->x_size == 0) return; if (op->memory_type == VTA_MEM_ID_INP) { - inp_.Load(op, dram_, &(prof_->inp_load_nbytes)); + inp_.Load(op, dram_, &(prof_->inp_load_nbytes), prof_->SkipExec()); } else if (op->memory_type == VTA_MEM_ID_WGT) { - wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes)); + wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes), prof_->SkipExec()); } else if (op->memory_type == VTA_MEM_ID_ACC) { - acc_.Load(op, dram_, &(prof_->acc_load_nbytes)); + acc_.Load(op, dram_, &(prof_->acc_load_nbytes), prof_->SkipExec()); } else if (op->memory_type == VTA_MEM_ID_UOP) { - uop_.Load(op, dram_, &(prof_->uop_load_nbytes)); + // always load in uop, since uop is stateful + // subsequent non-debug mode exec can depend on it. 
+ uop_.Load(op, dram_, &(prof_->uop_load_nbytes), false); } else { LOG(FATAL) << "Unknown memory_type=" << op->memory_type; } @@ -416,7 +433,9 @@ class Device { op->memory_type == VTA_MEM_ID_UOP) { prof_->out_store_nbytes += ( op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8); - acc_.TruncStore(op, dram_); + if (!prof_->SkipExec()) { + acc_.TruncStore(op, dram_); + } } else { LOG(FATAL) << "Store do not support memory_type=" << op->memory_type; @@ -425,7 +444,8 @@ class Device { void RunGEMM(const VTAGemInsn* op) { if (!op->reset_reg) { - prof_->gemm_counter += op->iter_out * op->iter_in; + prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn); + if (prof_->SkipExec()) return; for (uint32_t y = 0; y < op->iter_out; ++y) { for (uint32_t x = 0; x < op->iter_in; ++x) { for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) { @@ -459,6 +479,7 @@ class Device { } } } else { + if (prof_->SkipExec()) return; // reset for (uint32_t y = 0; y < op->iter_out; ++y) { for (uint32_t x = 0; x < op->iter_in; ++x) { @@ -477,7 +498,6 @@ class Device { } void RunALU(const VTAAluInsn* op) { - prof_->alu_counter += op->iter_out * op->iter_in; if (op->use_imm) { RunALU_(op); } else { @@ -520,6 +540,8 @@ class Device { template void RunALULoop(const VTAAluInsn* op, F func) { + prof_->alu_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn); + if (prof_->SkipExec()) return; for (int y = 0; y < op->iter_out; ++y) { for (int x = 0; x < op->iter_in; ++x) { for (int k = op->uop_bgn; k < op->uop_end; ++k) { @@ -566,6 +588,10 @@ TVM_REGISTER_GLOBAL("vta.simulator.profiler_status") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = Profiler::ThreadLocal()->AsJSON(); }); +TVM_REGISTER_GLOBAL("vta.simulator.profiler_debug_mode") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Profiler::ThreadLocal()->debug_flag = args[0]; + }); } // namespace sim } // namespace vta diff --git a/tests/python/integration/test_benchmark_topi_conv2d.py b/tests/python/integration/test_benchmark_topi_conv2d.py index 8a03cb020260..2aec47118e44 100644 --- a/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/tests/python/integration/test_benchmark_topi_conv2d.py @@ -14,7 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
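The sim_driver.cc changes above add a profiler debug mode: when the kSkipExec bit (value 1) is set, load, store, GEMM and ALU instructions only update the profiler counters instead of executing, while the stateful micro-op buffer is still loaded. A minimal sketch of driving it from Python through the globals registered above, assuming a build configured with TARGET set to "sim":

import tvm
import vta.testing.simulator        # loads the simulator library and its globals

set_debug = tvm.get_global_func("vta.simulator.profiler_debug_mode")
set_debug(1)                        # bit 1 == kSkipExec: count, but skip compute
# ... run a VTA module here to accumulate instruction/byte counters ...
print(tvm.get_global_func("vta.simulator.profiler_status")())   # JSON counters
set_debug(0)                        # back to normal functional execution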
-"""Testing if we can generate code in topi style""" + +"""Testing topi conv2d operator for VTA""" + +import os +import json +from collections import namedtuple + +import numpy as np import tvm from tvm import autotvm @@ -23,11 +30,32 @@ import topi import topi.testing import vta +from vta import program_fpga, reconfig_runtime import vta.testing -import numpy as np - -Workload = vta.top.vta_conv2d.Workload - +from vta.testing import simulator + +Workload = namedtuple("Conv2DWorkload", + ['batch', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +# ResNet18 workloads +resnet_wkls = [ + # Workloads of resnet18 on imagenet + # ('resnet-18.C1', Workload(1, 224, 224, 3, 64, 7, 7, 3, 3, 2, 2)), + ('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)), + # ('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this layer does not appear in ResNet + ('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)), + ('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)), + ('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)), +] + +# FIXME: we need a custom clip operator to circumvent a pattern detection limitation @tvm.tag_scope(tag=topi.tag.ELEMWISE) def my_clip(x, a_min, a_max): """Unlike topi's current clip, put min and max into two stages.""" @@ -37,249 +65,168 @@ def my_clip(x, a_min, a_max): x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") return x -def test_cpu_conv2d(): - def run_cpu_conv2d(env, remote, key, batch_size, wl, profile=True): - data_shape = (batch_size, wl.in_filter, wl.height, wl.width) - kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) - - fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 - fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 - data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) - kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) - res_conv = topi.nn.conv2d( - data, kernel, padding=(wl.hpad, wl.wpad), - strides=(wl.hstride, wl.wstride), - dilation=(1, 1), - out_dtype="int32") - res = topi.right_shift(res_conv, 8) - res = my_clip(res, 0, 127) - res = topi.cast(res, "int8") - - # To compute number of ops, use a x2 factor for FMA - num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter - - a_shape = (batch_size, wl.in_filter, wl.height, wl.width) - w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) - stride = (wl.hstride, wl.wstride) - data_dtype = data.dtype - kernel_dtype = kernel.dtype - acc_dtype = env.acc_dtype - assert wl.hpad == wl.wpad - padding = wl.hpad - - @memoize("vta.tests.test_benchmark_topi.conv2d.cpu.verify_nhwc") - def get_ref_data(): - a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) - w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) - a_np = np.abs(a_np) - w_np = np.abs(w_np) - b_np = topi.testing.conv2d_nchw_python( - a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) - 
return a_np, w_np, b_np - - - def verify(s, check_correctness): - mod = tvm.build(s, [data, kernel, res], - target_host=env.target_host, - name="conv2d") - temp = util.tempdir() - mod.save(temp.relpath("conv2d.o")) - remote.upload(temp.relpath("conv2d.o")) - f = remote.load_module("conv2d.o") - # verify - ctx = remote.cpu(0) - # Data in original format - data_orig, kernel_orig, res_ref = get_ref_data() - res_shape = topi.util.get_const_tuple(res.shape) - res_np = np.zeros(res_shape).astype(res.dtype) - data_arr = tvm.nd.array(data_orig, ctx) - kernel_arr = tvm.nd.array(kernel_orig, ctx) - res_arr = tvm.nd.array(res_np, ctx) - time_f = f.time_evaluator("conv2d", ctx, number=5) - cost = time_f(data_arr, kernel_arr, res_arr) - res_unpack = res_arr.asnumpy() - if check_correctness: - assert wl.hpad == wl.wpad - stride = (wl.hstride, wl.wstride) - padding = wl.hpad - res_ref = res_ref >> 8 - res_ref = np.clip(res_ref, 0, 127).astype("int8") - tvm.testing.assert_allclose(res_unpack, res_ref) - return cost - - def conv_normal(print_ir): - print("----- CONV2D CPU End-to-End Test-------") - s = topi.generic.schedule_conv2d_nchw([res]) - if print_ir: - print(tvm.lower(s, [data, kernel, res], simple_mode=True)) - cost = verify(s, True) - gops = (num_ops / cost.mean) / float(10 ** 9) - print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) - - conv_normal(False) - - def _run(env, remote): - # ResNet18 workloads - resnet = { - # Workloads of resnet18 on imagenet - 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2), - 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - } - batch_size = 1 - for i in range(1, len(resnet)): - wl = resnet[i] - key = "resnet-cfg[%d]" % i - print("key=%s" % key) - print(wl) - with tvm.target.create("llvm -device=vtacpu"): - run_cpu_conv2d(env, remote, key, batch_size, wl) - - # load pre-tuned operator parameters for ARM CPU - autotvm.tophub.check_backend('vta') - with autotvm.tophub.context('llvm -device=vtacpu'): - vta.testing.run(_run) - - -def test_vta_conv2d(): - def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True): - data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN, - wl.height, wl.width, env.BATCH, env.BLOCK_IN) +def run_conv2d(env, remote, wl, target, + check_correctness=True, print_ir=False, + samples=4): + + # Workload assertions + assert wl.hpad == wl.wpad + + # Perform packing only if we are targeting the accelerator + if "arm_cpu" in target.keys: + data_pack = False + layout = "NCHW" + elif "vta" in target.keys: + data_pack = True + layout = "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN) + + # Derive shapes depending upon packing + a_shape = (wl.batch, wl.in_filter, wl.height, wl.width) + w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) + b_shape = (wl.batch, wl.out_filter, 1, 1) + if data_pack: + data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN, + wl.height, wl.width, env.BATCH, env.BLOCK_IN) kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN, wl.hkernel, 
wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) - bias_shape = (1, wl.out_filter//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) - - fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 - fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 - data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) - kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) - bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype) - - res_conv = vta.top.packed_conv2d( - data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride)) - res = topi.right_shift(res_conv, 8) + bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT, + 1, 1, env.BATCH, env.BLOCK_OUT) + else: + data_shape = a_shape + kernel_shape = w_shape + bias_shape = b_shape + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) + + # Define base computation schedule + with target: + res = topi.nn.conv2d( + data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1), + layout, env.acc_dtype) + res = topi.right_shift(res, 8) res = topi.add(res, bias) - res = my_clip(res, 0, 127) - res = topi.cast(res, "int8") - - # To compute number of ops, use a x2 factor for FMA - num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter - - a_shape = (batch_size, wl.in_filter, wl.height, wl.width) - w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) - stride = (wl.hstride, wl.wstride) - data_dtype = data.dtype - kernel_dtype = kernel.dtype - acc_dtype = env.acc_dtype - assert wl.hpad == wl.wpad - padding = wl.hpad - - @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") - def get_ref_data(): - a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) - w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) - a_np = np.abs(a_np) - w_np = np.abs(w_np) - b_np = topi.testing.conv2d_nchw_python( - a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) - return a_np, w_np, b_np - - def verify(s, check_correctness): - mod = vta.build(s, [data, kernel, bias, res], "ext_dev", - env.target_host, name="conv2d") - temp = util.tempdir() - - mod.save(temp.relpath("conv2d.o")) - remote.upload(temp.relpath("conv2d.o")) - f = remote.load_module("conv2d.o") - # verify - ctx = remote.ext_dev(0) - # Data in original format - data_orig, kernel_orig, res_ref = get_ref_data() - bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32") - bias_orig = np.abs(bias_orig) - - data_packed = data_orig.reshape( - batch_size//env.BATCH, env.BATCH, - wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, - wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) - kernel_packed = kernel_orig.reshape( - wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, - wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, - wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) - bias_packed = bias_orig.reshape( - 1, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) - res_shape = topi.util.get_const_tuple(res.shape) - - res_np = np.zeros(res_shape).astype(res.dtype) - data_arr = tvm.nd.array(data_packed, ctx) - kernel_arr = tvm.nd.array(kernel_packed, ctx) - bias_arr = tvm.nd.array(bias_packed, ctx) - res_arr = tvm.nd.array(res_np, ctx) - time_f = f.time_evaluator("conv2d", ctx, number=5) + res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 
1) + res = topi.cast(res, env.out_dtype) + # Derive base schedule + s = topi.generic.schedule_conv2d_nchw([res]) + if print_ir: + print(vta.lower(s, [data, kernel, bias, res], simple_mode=True)) + + # Derive number of ops + fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 + fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 + num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + + # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") + def get_ref_data(): + # derive min max for act, wgt, and bias types (max non inclusive) + a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1)) + w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1)) + b_min, b_max = 0 - 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2), 1 << (env.INP_WIDTH + env.WGT_WIDTH - 2) + a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype) + w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype) + b_np = np.random.randint(b_min, b_max, size=b_shape).astype(env.acc_dtype) + r_np = topi.testing.conv2d_nchw_python( + a_np.astype(env.acc_dtype), w_np.astype(env.acc_dtype), (wl.hstride, wl.wstride), wl.hpad).astype(env.acc_dtype) + return a_np, w_np, b_np, r_np + + # Data in original format + data_np, kernel_np, bias_np, res_ref = get_ref_data() + if data_pack: + data_np = data_np.reshape( + wl.batch//env.BATCH, env.BATCH, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) + kernel_np = kernel_np.reshape( + wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) + bias_np = bias_np.reshape( + wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT, + 1, 1, env.BATCH, env.BLOCK_OUT) + + # Build + if "vta" in target.keys: + mod = vta.build(s, [data, kernel, bias, res], + target=target, + target_host=env.target_host, + name="conv2d") + else: + mod = tvm.build(s, [data, kernel, bias, res], + target=target, + target_host=env.target_host, + name="conv2d") + temp = util.tempdir() + mod.save(temp.relpath("conv2d.o")) + remote.upload(temp.relpath("conv2d.o")) + f = remote.load_module("conv2d.o") + ctx = remote.context(str(target)) + + res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype) + data_arr = tvm.nd.array(data_np, ctx) + kernel_arr = tvm.nd.array(kernel_np, ctx) + bias_arr = tvm.nd.array(bias_np, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d", ctx, number=samples) + + # In vta sim mode, collect simulator runtime statistics + stats = {} + cost = None + if env.TARGET == "sim": + # Check if we're in local RPC mode (allows us to rebuild the + # runtime on the fly when varying the VTA designs) + local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) + if local_rpc: + remote.get_function("vta.simulator.profiler_clear")() cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) - res_unpack = res_arr.asnumpy().transpose( - (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) - if check_correctness: - assert wl.hpad == wl.wpad - stride = (wl.hstride, wl.wstride) - padding = wl.hpad - res_ref = res_ref >> 8 - res_ref += bias_orig.reshape(wl.out_filter, 1, 1) - res_ref = np.clip(res_ref, 0, 127).astype("int8") - tvm.testing.assert_allclose(res_unpack, res_ref) - return cost - - def conv_normal(print_ir): - print("----- CONV2D End-to-End Test-------") - with 
vta.build_config(): - s = vta.top.schedule_packed_conv2d([res]) - if print_ir: - print(vta.lower(s, [data, kernel, bias, res], simple_mode=True)) - cost = verify(s, True) - gops = (num_ops / cost.mean) / float(10 ** 9) - print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) - - conv_normal(False) - + stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) + else: + simulator.clear_stats() + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + stats = simulator.stats() + else: + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + + # Check correctness + correct = False + if check_correctness: + res_orig = res_arr.asnumpy() + if data_pack: + res_orig = res_orig.transpose( + (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width) + res_ref = res_ref >> 8 + res_ref += bias_np.reshape(wl.out_filter, 1, 1) + res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1) + res_ref = res_ref.astype(env.out_dtype) + correct = np.allclose(res_orig, res_ref) + + gops = (num_ops / cost.mean) / float(10 ** 9) + status = "PASSED" if correct else "FAILED" + if "arm_cpu" in target.keys: + device = "CPU" + elif "vta" in target.keys: + device = "VTA" + print("%s CONV2D TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops)) + + return correct, cost, stats + +def test_conv2d(device="vta"): def _run(env, remote): - # ResNet18 workloads - resnet = { - # Workloads of resnet18 on imagenet - 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2), - 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - } - - batch_size = 1 - for i in range(0, len(resnet)): - wl = resnet[i] - key = "resnet-cfg[%d]" % i - print("key=%s" % key) - print(wl) - run_vta_conv2d(env, remote, key, batch_size, wl) - + if device == "vta": + target = env.target + if env.TARGET != "sim": + assert tvm.module.enabled("rpc") + program_fpga(remote, bitstream=None) + reconfig_runtime(remote) + elif device == "arm_cpu": + target = env.target_vta_cpu + with autotvm.tophub.context(target): # load pre-tuned schedule parameters + for _, wl in resnet_wkls: + print(wl) + run_conv2d(env, remote, wl, target) vta.testing.run(_run) - if __name__ == "__main__": - test_cpu_conv2d() - test_vta_conv2d() + test_conv2d(device="arm_cpu") + test_conv2d(device="vta") diff --git a/tests/python/integration/test_benchmark_topi_dense.py b/tests/python/integration/test_benchmark_topi_dense.py new file mode 100644 index 000000000000..12fbc45c1c4b --- /dev/null +++ b/tests/python/integration/test_benchmark_topi_dense.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Testing topi gemm operator for VTA""" + +import os +import json +from collections import namedtuple + +import numpy as np + +import tvm +from tvm import autotvm +from tvm.contrib import util +from tvm.contrib.pickle_memoize import memoize +import topi +import topi.testing +import vta +from vta import program_fpga, reconfig_runtime +import vta.testing +from vta.testing import simulator + +# FIXME: we need a custom clip operator to circumvent a pattern detection limitation +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +def run_gemm(env, remote, target, + batch_size, in_feat, out_feat, + check_correctness=True, print_ir=True, + samples=4): + + # Perform packing only if we are targeting the accelerator + if "arm_cpu" in target.keys: + data_pack = False + elif "vta" in target.keys: + data_pack = True + + # Derive shapes depending upon packing + a_shape = (batch_size, in_feat) + w_shape = (out_feat, in_feat) + if data_pack: + data_shape = (batch_size//env.BATCH, in_feat//env.BLOCK_IN, + env.BATCH, env.BLOCK_IN) + kernel_shape = (out_feat//env.BLOCK_OUT, in_feat//env.BLOCK_IN, + env.BLOCK_OUT, env.BLOCK_IN) + else: + data_shape = a_shape + kernel_shape = w_shape + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + + # Define base computation schedule + with target: + res = topi.nn.dense( + data, kernel, out_dtype=env.acc_dtype) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1) + res = topi.cast(res, env.out_dtype) + # Derive base schedule + s = topi.generic.schedule_dense([res]) + if print_ir: + print(vta.lower(s, [data, kernel, res], simple_mode=True)) + + # Derive number of ops + num_ops = 2 * batch_size * in_feat * out_feat + + # @memoize("vta.tests.test_benchmark_topi.dense.verify") + def get_ref_data(): + # derive min max for act, wgt types (max non inclusive) + a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1)) + w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1)) + a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype) + w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype) + + r_np = np.dot(a_np.astype(env.acc_dtype), w_np.T.astype(env.acc_dtype)).astype(env.acc_dtype) + return a_np, w_np, r_np + + # Data in original format + data_np, kernel_np, res_ref = get_ref_data() + if data_pack: + data_np = data_np.reshape( + batch_size//env.BATCH, env.BATCH, + in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3)) + kernel_np = kernel_np.reshape( + out_feat//env.BLOCK_OUT, env.BLOCK_OUT, + in_feat//env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3)) + + # Build + if "vta" in target.keys: + mod = vta.build(s, [data, kernel, res], 
+ target=target, + target_host=env.target_host, + name="dense") + else: + mod = tvm.build(s, [data, kernel, res], + target=target, + target_host=env.target_host, + name="dense") + temp = util.tempdir() + mod.save(temp.relpath("dense.o")) + remote.upload(temp.relpath("dense.o")) + f = remote.load_module("dense.o") + ctx = remote.context(str(target)) + + res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype) + data_arr = tvm.nd.array(data_np, ctx) + kernel_arr = tvm.nd.array(kernel_np, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("dense", ctx, number=samples) + + # In vta sim mode, collect simulator runtime statistics + stats = {} + cost = None + if env.TARGET == "sim": + # Check if we're in local RPC mode (allows us to rebuild the + # runtime on the fly when varying the VTA designs) + local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) + if local_rpc: + remote.get_function("vta.simulator.profiler_clear")() + cost = time_f(data_arr, kernel_arr, res_arr) + stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) + else: + simulator.clear_stats() + cost = time_f(data_arr, kernel_arr, res_arr) + stats = simulator.stats() + else: + cost = time_f(data_arr, kernel_arr, res_arr) + + # Check correctness + correct = False + if check_correctness: + res_orig = res_arr.asnumpy() + if data_pack: + res_orig = res_orig.reshape(batch_size, out_feat) + res_ref = res_ref >> 8 + res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH - 1) - 1) + res_ref = res_ref.astype(env.out_dtype) + correct = np.allclose(res_orig, res_ref) + + gops = (num_ops / cost.mean) / float(10 ** 9) + status = "PASSED" if correct else "FAILED" + if "arm_cpu" in target.keys: + device = "CPU" + elif "vta" in target.keys: + device = "VTA" + print("%s DENSE TEST %s: Time cost = %g sec/op, %g GOPS" % (device, status, cost.mean, gops)) + + return correct, cost, stats + +def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128): + def _run(env, remote): + if device == "vta": + target = env.target + if env.TARGET != "sim": + assert tvm.module.enabled("rpc") + program_fpga(remote, bitstream=None) + reconfig_runtime(remote) + elif device == "arm_cpu": + target = env.target_vta_cpu + with autotvm.tophub.context(target): # load pre-tuned schedule parameters + run_gemm(env, remote, target, batch, in_feat, out_feat) + vta.testing.run(_run) + +if __name__ == "__main__": + test_gemm("vta", 16, 512, 1008) diff --git a/tutorials/README.txt b/tutorials/README.txt index 1ba48b0b1fad..3d3858b111ba 100644 --- a/tutorials/README.txt +++ b/tutorials/README.txt @@ -1,2 +1,3 @@ VTA Tutorials ============= +This page contains tutorials about VTA and how to use TVM/Relay to target VTA. diff --git a/tutorials/autotvm/README.txt b/tutorials/autotvm/README.txt new file mode 100644 index 000000000000..c511381dd57d --- /dev/null +++ b/tutorials/autotvm/README.txt @@ -0,0 +1,3 @@ +Auto tuning +------------- + diff --git a/tutorials/autotvm/tune_relay_vta.py b/tutorials/autotvm/tune_relay_vta.py new file mode 100644 index 000000000000..bdeb6c5d03e2 --- /dev/null +++ b/tutorials/autotvm/tune_relay_vta.py @@ -0,0 +1,468 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-tuning a convolutional network on VTA
+==========================================
+**Author**: `Lianmin Zheng `_, `Thierry Moreau `_
+
+Auto-tuning for a specific accelerator design is critical for getting the best
+performance for any given operator. This tutorial showcases how to tune a
+whole convolutional network on VTA.
+
+The operator implementation for VTA in TVM is written in template form.
+The template has many tunable knobs (tile factor, virtual threads, etc).
+We will tune all convolution operators in the neural network. After tuning,
+we produce a log file which stores the best schedule parameters for all tuned
+operators. When the TVM compiler compiles these operators, it will query this
+log file to get the best knob parameters.
+
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the autotvm package in tvm, we need to install some extra dependencies.
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user psutil xgboost tornado mxnet requests pillow
+#
+# To make TVM run faster during tuning, it is recommended to use Cython
+# as the FFI of TVM. In the root directory of TVM, execute
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user cython
+#   sudo make cython3
+#
+# Now return to the Python code. Import packages.
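+
+######################################################################
+# To preview where this tutorial is headed: once tuning has produced a log
+# file, compilation only needs to be wrapped in an AutoTVM dispatch context
+# so that the tuned schedules are picked up. The snippet below is a minimal
+# sketch of that final step (``compile_network`` is defined later in this
+# script, and ``vta.resnet18_v1.log`` is a placeholder log file name), not
+# the full flow used at the end of this tutorial:
+#
+# .. code-block:: python
+#
+#   import vta
+#   from tvm import autotvm, relay
+#
+#   env = vta.get_env()
+#   # Quantize and pack the network for VTA (helper defined below)
+#   relay_prog, params = compile_network(env, env.target, "resnet18_v1",
+#                                         "nn.max_pool2d", "nn.global_avg_pool2d")
+#   # Pick tuned schedules from the log while building
+#   with autotvm.apply_history_best("vta.resnet18_v1.log"):
+#       with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+#           with vta.build_config():
+#               graph, lib, params = relay.build(
+#                   relay_prog, target=env.target,
+#                   params=params, target_host=env.target_host)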
+ +import os +from mxnet.gluon.model_zoo import vision +import numpy as np +from PIL import Image + +import topi +import tvm +from tvm import rpc, autotvm, relay +from tvm.contrib import graph_runtime, util, download +from tvm.autotvm.measure.measure_methods import request_remote +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner + +import vta +from vta.testing import simulator +from vta.top import graph_pack + +################################################################# +# Compile network +# --------------- +# Perform vta-specific compilation with Relay from a Gluon model + +def compile_network(env, target, model, start_pack, stop_pack): + + # Populate the shape and data type dictionary + dtype_dict = {"data": 'float32'} + shape_dict = {"data": (env.BATCH, 3, 224, 224)} + + # Get off the shelf gluon model, and convert to relay + gluon_model = vision.get_model(model, pretrained=True) + mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) + + # Update shape and type dictionary + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Perform quantization in Relay + with relay.quantize.qconfig(global_scale=8.0, + skip_conv_layers=[0]): + relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params) + + # Perform graph packing and constant folding for VTA target + if target.device_name == "vta": + assert env.BLOCK_IN == env.BLOCK_OUT + relay_prog = graph_pack( + relay_prog, + env.BATCH, + env.BLOCK_OUT, + env.WGT_WIDTH, + start_name=start_pack, + stop_name=stop_pack) + relay_prog = relay.ir_pass.fold_constant(relay_prog) + + return relay_prog, params + + +################################################################# +# Start RPC Tracker +# ----------------- +# TVM uses an RPC session to communicate with Pynq boards. +# During tuning, the tuner will send the generated code to the board and +# measure the speed of code on the board. +# +# To scale up tuning, TVM uses an RPC Tracker to manage multiple devices. +# The RPC Tracker is a centralized master node. We can register all devices to +# the tracker. For example, if we have 10 Pynq boards, we can register all of them +# to the tracker, and run 10 measurements in parallel, accelerating the tuning process. +# +# To start an RPC tracker, run this command on the host machine. The tracker is +# required during the whole tuning process, so we need to open a new terminal for +# this command: +# +# .. code-block:: bash +# +# python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190 +# +# The expected output is: +# +# .. code-block:: bash +# +# INFO:RPCTracker:bind to 0.0.0.0:9190 + +################################################################# +# Register devices to RPC Tracker +# ----------------------------------- +# Now we can register our devices to the tracker. The first step is to +# build the TVM runtime for the Pynq devices. +# +# Follow `this section `_ +# to build the TVM runtime on the device. Then register the device to the tracker with: +# +# .. code-block:: bash +# +# python -m tvm.exec.rpc_server --tracker=[HOST_IP]:9190 --key=pynq +# +# (replace :code:`[HOST_IP]` with the IP address of your host machine) +# +# After registering devices, we can confirm it by querying the rpc_tracker: +# +# .. code-block:: bash +# +# python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190 +# +# For example, if we have 6 Pynq boards and 11 Raspberry Pi 3B, +# the output can be +# +# .. 
code-block:: bash
+#
+#    Queue Status
+#    ----------------------------------
+#    key          total  free  pending
+#    ----------------------------------
+#    pynq         6      6     0
+#    rpi3b        11     11    0
+#    ----------------------------------
+#
+# You can register multiple devices to the tracker to accelerate tuning.
+
+###########################################
+# Set Tuning Options
+# ------------------
+# Before tuning, we should apply some configurations.
+# Here we use a Pynq-Z1 board as an example.
+
+# Tracker host and port can be set by your environment
+tracker_host = os.environ.get("TVM_TRACKER_HOST", '0.0.0.0')
+tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+
+# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+# Name of Gluon model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+network = "resnet18_v1"
+start_pack="nn.max_pool2d"
+stop_pack="nn.global_avg_pool2d"
+
+# Tuning option
+log_file = "%s.%s.log" % (device, network)
+tuning_option = {
+    'log_filename': log_file,
+
+    'tuner': 'random',
+    'n_trial': 1000,
+    'early_stopping': None,
+
+    'measure_option': autotvm.measure_option(
+        builder=autotvm.LocalBuilder(),
+        runner=autotvm.RPCRunner(
+            env.TARGET, host=tracker_host, port=tracker_port,
+            number=5,
+            timeout=60,
+            check_correctness=True
+        ),
+    ),
+}
+
+####################################################################
+#
+# .. note:: How to set tuning options
+#
+#   In general, the default values provided here work well.
+#   If you have enough time budget, you can set :code:`n_trial` and
+#   :code:`early_stopping` to larger values, which makes the tuning run longer.
+#   If your device is under-powered or your conv2d operators are large, consider
+#   setting a longer timeout.
+#
+
+###################################################################
+# Begin Tuning
+# ------------
+# Now we can extract tuning tasks from the network and begin tuning.
+# Here, we provide a simple utility function to tune a list of tasks.
+# This function is just an initial implementation which tunes them in sequential order.
+# We will introduce a more sophisticated tuning scheduler in the future.
+#
+# Given that the tuning will be done on Pynq FPGA boards, make sure that
+# the ``TARGET`` entry in the ``vta_config.json`` file is set to ``pynq``.
+
+# You can skip the implementation of this function for this tutorial.
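+
+######################################################################
+# The essence of ``tune_tasks`` is the per-task loop sketched below
+# (simplified: the trial count and the ``my_tuning.log`` file name are
+# illustrative, ``tuning_option`` comes from the options above, and
+# ``tasks`` is the task list extracted further down in this tutorial).
+# The full implementation follows.
+#
+# .. code-block:: python
+#
+#   from tvm import autotvm
+#
+#   for task in tasks:
+#       # One tuner per task; RandomTuner keeps the sketch simple
+#       tuner = autotvm.tuner.RandomTuner(task)
+#       tuner.tune(n_trial=min(1000, len(task.config_space)),
+#                  measure_option=tuning_option['measure_option'],
+#                  callbacks=[autotvm.callback.log_to_file("my_tuning.log")])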
+def tune_tasks(tasks, + measure_option, + tuner='xgb', + n_trial=1000, + early_stopping=None, + log_filename='tuning.log', + use_transfer_learning=True): + + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) + + # create tuner + if tuner == 'xgb' or tuner == 'xgb-rank': + tuner_obj = XGBTuner(tsk, loss_type='rank') + elif tuner == 'xgb_knob': + tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob') + elif tuner == 'ga': + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == 'random': + tuner_obj = RandomTuner(tsk) + elif tuner == 'gridsearch': + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)), + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(n_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file)]) + + # pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + + + +######################################################################## +# Register VTA-specific tuning tasks + +def register_vta_tuning_tasks(): + from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args + + @tvm.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + # init autotvm env to register VTA operator + TaskExtractEnv() + + @autotvm.task.register("topi_nn_conv2d", override=True) + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.conv2d(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = tvm.create_schedule([res.op]) + return s, [A, W, res] + + +######################################################################## +# Finally, we launch tuning jobs and evaluate the end-to-end performance. + +def tune_and_evaluate(tuning_opt): + + if env.TARGET != "sim": + # Get remote from fleet node + remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000) + # Reconfigure the JIT runtime and FPGA. + vta.reconfig_runtime(remote) + vta.program_fpga(remote, bitstream=None) + else: + # In simulation mode, host the RPC server locally. 
+ remote = rpc.LocalSession() + + # Register VTA tuning tasks + register_vta_tuning_tasks() + + # Perform task extraction on Relay program + print("Extract tasks...") + relay_prog, params = compile_network(env, target, network, start_pack, stop_pack) + tasks = autotvm.task.extract_from_program(func=relay_prog, + params=params, + ops=(tvm.relay.op.nn.conv2d,), + target=target, + target_host=env.target_host) + + # We should have extracted 10 convolution tasks + assert len(tasks) == 10 + print("Extracted {} conv2d tasks:".format(len(tasks))) + for tsk in tasks: + print("\t{}".format(tsk)) + + # We do not run the tuning in our webpage server since it takes too long. + # Comment the following line to run it by yourself. + return + + # run tuning tasks + print("Tuning...") + tune_tasks(tasks, **tuning_opt) + + # compile kernels with history best records + with autotvm.tophub.context(target, extra_files=[log_file]): + # Compile network + print("Compile...") + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + if target.device_name != "vta": + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + else: + with vta.build_config(): + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + + # Export library + print("Upload...") + temp = util.tempdir() + lib.save(temp.relpath("graphlib.o")) + remote.upload(temp.relpath("graphlib.o")) + lib = remote.load_module("graphlib.o") + + # Generate the graph runtime + ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) + m = graph_runtime.create(graph, lib, ctx) + + # upload parameters to device + image = tvm.nd.array( + (np.random.uniform(size=(1, 3, 224, 224))).astype('float32')) + m.set_input(**params) + m.set_input('data', image) + + # evaluate + print("Evaluate inference time cost...") + timer = m.module.time_evaluator("run", ctx, number=1, repeat=10) + tcost = timer() + prof_res = np.array(tcost.results) * 1000 # convert to millisecond + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) + +# Run the tuning and evaluate the results +tune_and_evaluate(tuning_option) + +###################################################################### +# Sample Output +# ------------- +# The tuning needs to compile many programs and extract feature from them. +# So a high performance CPU is recommended. +# One sample output is listed below. +# It takes about 2 hours on a 16T CPU, and 6 Pynq boards. +# +# .. code-block:: bash +# +# Extract tasks... 
+# [Warning] Invalid shape during AutoTVM task creation +# Extracted 10 conv2d tasks: +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 16, 14, 14, 1, 16), 'int8'), ('TENSOR', (32, 16, 1, 1, 16, 16), 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 16, 14, 14, 1, 16, 'int8'), (32, 16, 1, 1, 16, 16, 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 8, 28, 28, 1, 16), 'int8'), ('TENSOR', (16, 8, 1, 1, 16, 16), 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 8, 28, 28, 1, 16, 'int8'), (16, 8, 1, 1, 16, 16, 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 4, 56, 56, 1, 16), 'int8'), ('TENSOR', (8, 4, 1, 1, 16, 16), 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 4, 56, 56, 1, 16, 'int8'), (8, 4, 1, 1, 16, 16, 'int8'), (2, 2), (0, 0), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 4, 56, 56, 1, 16), 'int8'), ('TENSOR', (4, 4, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 4, 56, 56, 1, 16, 'int8'), (4, 4, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 8, 28, 28, 1, 16), 'int8'), ('TENSOR', (8, 8, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 8, 28, 28, 1, 16, 'int8'), (8, 8, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 4, 56, 56, 1, 16), 'int8'), ('TENSOR', (8, 4, 3, 3, 16, 16), 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 4, 56, 56, 1, 16, 'int8'), (8, 4, 3, 3, 16, 16, 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 16, 14, 14, 1, 16), 'int8'), ('TENSOR', (16, 16, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 16, 14, 14, 1, 16, 'int8'), (16, 16, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 8, 28, 28, 1, 16), 'int8'), ('TENSOR', (16, 8, 3, 3, 16, 16), 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 8, 28, 28, 1, 16, 'int8'), (16, 8, 3, 3, 16, 16, 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 32, 7, 7, 1, 16), 'int8'), ('TENSOR', (32, 32, 3, 3, 16, 16), 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 32, 7, 7, 1, 16, 'int8'), (32, 32, 3, 3, 16, 16, 'int8'), (1, 1), (1, 1), (1, 1), 'NCHW1n16c', 'int32')) +# Task(func_name=topi_nn_conv2d, args=(('TENSOR', (1, 16, 14, 14, 1, 16), 'int8'), ('TENSOR', (32, 16, 3, 3, 16, 16), 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32'), kwargs={}, workload=('conv2d', (1, 16, 14, 14, 1, 16, 'int8'), (32, 16, 3, 3, 16, 16, 'int8'), (2, 2), (1, 1), (1, 1), 'NCHW1n16c', 'int32')) +# Tuning... +# [Task 1/10] Current/Best: 0.72/ 23.24 GFLOPS | Progress: (480/1000) | 640.31 s Done. +# [Task 2/10] Current/Best: 0.00/ 27.69 GFLOPS | Progress: (576/1000) | 810.09 s Done. +# [Task 3/10] Current/Best: 0.00/ 22.97 GFLOPS | Progress: (1000/1000) | 1125.37 s Done. 
+# [Task  4/10]  Current/Best:    0.00/  31.26 GFLOPS | Progress: (1000/1000) | 1025.52 s Done.
+# [Task  5/10]  Current/Best:    0.00/  15.15 GFLOPS | Progress: (1000/1000) | 1236.58 s Done.
+# [Task  6/10]  Current/Best:    0.00/  22.74 GFLOPS | Progress: (1000/1000) | 906.60 s Done.
+# [Task  7/10]  Current/Best:    0.00/  15.27 GFLOPS | Progress: (1000/1000) | 1056.25 s Done.
+# [Task  8/10]  Current/Best:    0.00/   2.18 GFLOPS | Progress: (1000/1000) | 2275.29 s Done.
+# [Task  9/10]  Current/Best:    2.23/   3.99 GFLOPS | Progress: (1000/1000) | 2527.25 s Done.
+# [Task 10/10]  Current/Best:    1.56/   6.32 GFLOPS | Progress: (480/1000) | 1304.84 s Done.
+# Compile...
+# Upload...
+# Evaluate inference time cost...
+# Mean inference time (std dev): 621.79 ms (0.14 ms)

###################################################################### +#
+# .. note:: **Experiencing Difficulties?**
+#
+#   The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
+#   then there must be something wrong.
+#
+#   First, make sure you set the correct configuration of your device.
+#   Then, you can print debug information by adding these lines at the beginning
+#   of the script. It will print every measurement result, where you can find useful
+#   error messages.
+#
+#   .. code-block:: python
+#
+#      import logging
+#      logging.getLogger('autotvm').setLevel(logging.DEBUG)
+#
+#   Finally, always feel free to ask our community for help on https://discuss.tvm.ai
diff --git a/tutorials/frontend/README.txt b/tutorials/frontend/README.txt
new file mode 100644
index 000000000000..319506d21f8f
--- /dev/null
+++ b/tutorials/frontend/README.txt
@@ -0,0 +1,4 @@
+.. _tutorial-frontend:
+
+Compile Deep Learning Models
+----------------------------
diff --git a/tutorials/frontend/deploy_resnet_on_vta.py b/tutorials/frontend/deploy_resnet_on_vta.py
new file mode 100644
index 000000000000..271630e69558
--- /dev/null
+++ b/tutorials/frontend/deploy_resnet_on_vta.py
@@ -0,0 +1,262 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Deploy Pretrained ResNet Model from MxNet on VTA
+================================================
+**Author**: `Thierry Moreau `_
+
+This tutorial provides an end-to-end demo of how to run ResNet-18 inference
+on the VTA accelerator design to perform ImageNet classification tasks.
+It showcases Relay as a front end compiler that can perform quantization (VTA
+only supports int8/32 inference) as well as graph packing (in order to enable
+tensorization in the core) to massage the compute graph for the hardware target.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the autotvm package in tvm, we need to install some extra dependencies.
+# (change "3" to "2" if you use python2):
+#
+# .. 
code-block:: bash
+#
+#   pip3 install --user mxnet requests pillow
+#
+# Now return to the Python code. Import packages.
+
+from __future__ import absolute_import, print_function
+
+import argparse, json, os, requests, time
+from io import BytesIO
+from os.path import join, isfile
+from PIL import Image
+
+from mxnet.gluon.model_zoo import vision
+import numpy as np
+from matplotlib import pyplot as plt
+
+import tvm
+from tvm import rpc, autotvm, relay
+from tvm.contrib import graph_runtime, util, download
+from tvm.contrib.debugger import debug_runtime
+
+import vta
+from vta.testing import simulator
+from vta.top import graph_pack
+
+# Make sure that TVM was compiled with RPC=1
+assert tvm.module.enabled("rpc")
+
+
+######################################################################
+# Define the platform and model targets
+# -------------------------------------
+# Execute on CPU vs. VTA, and define the model.
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+# Name of Gluon model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+model = "resnet18_v1"
+start_pack="nn.max_pool2d"
+stop_pack="nn.global_avg_pool2d"
+
+######################################################################
+# Obtain an execution remote
+# ---------------------------------
+# When target is 'pynq', reconfigure FPGA and runtime.
+# Otherwise, if target is 'sim', execute locally.
+
+if env.TARGET != "sim":
+
+    # Get remote from tracker node if environment variable is set.
+    # To set up the tracker, you'll need to follow the "Auto-tuning
+    # a convolutional network for VTA" tutorial.
+    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
+    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
+    # Otherwise if you have a device you want to program directly from
+    # the host, make sure you've set the variables below to the IP of
+    # your board.
+    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
+    device_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
+    if not tracker_host or not tracker_port:
+        remote = rpc.connect(device_host, device_port)
+    else:
+        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000)
+
+    # Reconfigure the JIT runtime and FPGA.
+    # You can program the FPGA with your own custom bitstream
+    # by passing the path to the bitstream file instead of None.
+    reconfig_start = time.time()
+    vta.reconfig_runtime(remote)
+    vta.program_fpga(remote, bitstream=None)
+    reconfig_time = time.time() - reconfig_start
+    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
+
+# In simulation mode, host the RPC server locally.
+else:
+    remote = rpc.LocalSession()
+
+# Get execution context from remote
+ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
+
+######################################################################
+# Build the inference graph runtime
+# ---------------------------------
+# Grab ResNet-18 model from Gluon model zoo and compile with Relay.
+# The compilation steps are:
+# 1) Front end translation from MxNet into Relay module. 
+# 2) Apply 8-bit quantization: here we skip the first conv layer, +# and dense layer which will both be executed in fp32 on the CPU. +# 3) Perform graph packing to alter the data layout for tensorization. +# 4) Perform constant folding to reduce number of operators (e.g. eliminate +# batch norm multiply). +# 5) Perform relay build to object file. +# 6) Load the object file onto remote (FPGA device). +# 7) Generate graph runtime, `m`. + +# Load pre-configured AutoTVM schedules +with autotvm.tophub.context(target): + + # Populate the shape and data type dictionary for ResNet input + dtype_dict = {"data": 'float32'} + shape_dict = {"data": (env.BATCH, 3, 224, 224)} + + # Get off the shelf gluon model, and convert to relay + gluon_model = vision.get_model(model, pretrained=True) + + # Measure build start time + build_start = time.time() + + # Start front end compilation + mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) + + # Update shape and type dictionary + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Perform quantization in Relay + with relay.quantize.qconfig(global_scale=8.0, + skip_conv_layers=[0]): + relay_prog = relay.quantize.quantize(mod[mod.entry_func], params=params) + + # Perform graph packing and constant folding for VTA target + if target.device_name == "vta": + assert env.BLOCK_IN == env.BLOCK_OUT + relay_prog = graph_pack( + relay_prog, + env.BATCH, + env.BLOCK_OUT, + env.WGT_WIDTH, + start_name=start_pack, + stop_name=stop_pack) + relay_prog = relay.ir_pass.fold_constant(relay_prog) + + # Compile Relay program with AlterOpLayout disabled + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + if target.device_name != "vta": + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + else: + with vta.build_config(): + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + + # Measure Relay build time + build_time = time.time() - build_start + print(model + " inference graph built in {0:.2f}s!".format(build_time)) + + # Send the inference library over to the remote RPC server + temp = util.tempdir() + lib.save(temp.relpath("graphlib.o")) + remote.upload(temp.relpath("graphlib.o")) + lib = remote.load_module("graphlib.o") + + # Graph runtime + m = graph_runtime.create(graph, lib, ctx) + +###################################################################### +# Perform ResNet-18 inference +# --------------------------- +# We run classification on an image sample from ImageNet +# We just need to download the categories files, `synset.txt` +# and an input test image. 
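+
+######################################################################
+# As an optional aside, the ``debug_runtime`` imported earlier can be used
+# in place of ``graph_runtime`` to obtain a per-operator time breakdown.
+# This is only a sketch (it reuses the ``graph``, ``lib``, ``ctx`` and
+# ``params`` built above and feeds a random input); it is not required for
+# the rest of this tutorial:
+#
+# .. code-block:: python
+#
+#   m_debug = debug_runtime.create(graph, lib, ctx)
+#   m_debug.set_input(**params)
+#   m_debug.set_input('data', np.random.uniform(
+#       size=(env.BATCH, 3, 224, 224)).astype('float32'))
+#   m_debug.run()  # reports per-operator execution times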
+
+# Download ImageNet categories
+categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
+categ_fn = "synset.txt"
+download.download(join(categ_url, categ_fn), categ_fn)
+synset = eval(open(categ_fn).read())
+
+# Download test image
+image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
+response = requests.get(image_url)
+
+# Prepare test image for inference
+image = Image.open(BytesIO(response.content)).resize((224, 224))
+plt.imshow(image)
+plt.show()
+image = np.array(image) - np.array([123., 117., 104.])
+image /= np.array([58.395, 57.12, 57.375])
+image = image.transpose((2, 0, 1))
+image = image[np.newaxis, :]
+image = np.repeat(image, env.BATCH, axis=0)
+
+# Set the network parameters and inputs
+m.set_input(**params)
+m.set_input('data', image)
+
+# Perform inference: we run the module 4 times,
+# and repeat 3 times to get error bounds
+timer = m.module.time_evaluator("run", ctx, number=4, repeat=3)
+tcost = timer()
+
+# Get classification results
+tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
+top_categories = np.argsort(tvm_output.asnumpy()[0])
+
+# Report top-5 classification results
+std = np.std(tcost.results) * 1000 / env.BATCH
+mean = tcost.mean * 1000 / env.BATCH
+print("%s prediction" % model)
+print("             #1:", synset[top_categories[-1]])
+print("             #2:", synset[top_categories[-2]])
+print("             #3:", synset[top_categories[-3]])
+print("             #4:", synset[top_categories[-4]])
+print("             #5:", synset[top_categories[-5]])
+print("Performed inference in %.2fms/sample (std = %.2f)" % (mean, std))
+
+# This just checks that one of the 5 top categories
+# is one variety of cat; this is by no means an accurate
+# assessment of how quantization affects classification
+# accuracy but is meant to catch changes to the
+# quantization pass that would affect accuracy in the CI.
+cat_detected = False
+for k in top_categories[-5:]:
+    if "cat" in synset[k]:
+        cat_detected = True
+assert(cat_detected)
diff --git a/tutorials/optimize/README.txt b/tutorials/optimize/README.txt
new file mode 100644
index 000000000000..b051548c5351
--- /dev/null
+++ b/tutorials/optimize/README.txt
@@ -0,0 +1,2 @@
+Optimize Tensor Operators
+-------------------------
diff --git a/tutorials/convolution_opt.py b/tutorials/optimize/convolution_opt.py
similarity index 100%
rename from tutorials/convolution_opt.py
rename to tutorials/optimize/convolution_opt.py
diff --git a/tutorials/matrix_multiply_opt.py b/tutorials/optimize/matrix_multiply_opt.py
similarity index 100%
rename from tutorials/matrix_multiply_opt.py
rename to tutorials/optimize/matrix_multiply_opt.py
diff --git a/tutorials/resnet.py b/tutorials/resnet.py
deleted file mode 100644
index df3bb0607284..000000000000
--- a/tutorials/resnet.py
+++ /dev/null
@@ -1,330 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -""" -ResNet Inference Example -======================== -**Author**: `Thierry Moreau `_ - -This tutorial provides an end-to-end demo, on how to run ResNet-18 inference -onto the VTA accelerator design to perform ImageNet classification tasks. - -""" - -###################################################################### -# Import Libraries -# ---------------- -# We start by importing the tvm, vta, nnvm libraries to run this example. - -from __future__ import absolute_import, print_function - -import os -import time -from io import BytesIO - -import numpy as np -import requests -from matplotlib import pyplot as plt -from PIL import Image - -import tvm -from tvm import rpc, autotvm -from tvm.contrib import graph_runtime, util -from tvm.contrib.download import download -import nnvm.compiler -import vta -import vta.testing - -# Load VTA parameters from the vta/config/vta_config.json file -env = vta.get_env() - -# Helper to crop an image to a square (224, 224) -# Takes in an Image object, returns an Image object -def thumbnailify(image, pad=15): - w, h = image.size - crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad) - image = image.crop(crop) - image = image.resize((224, 224)) - return image - -# Helper function to read in image -# Takes in Image object, returns an ND array -def process_image(image): - # Convert to neural network input format - image = np.array(image) - np.array([123., 117., 104.]) - image /= np.array([58.395, 57.12, 57.375]) - image = image.transpose((2, 0, 1)) - image = image[np.newaxis, :] - - return tvm.nd.array(image.astype("float32")) - -# Classification helper function -# Takes in the graph runtime, and an image, and returns top result and time -def classify(m, image): - m.set_input('data', image) - timer = m.module.time_evaluator("run", ctx, number=1) - tcost = timer() - tvm_output = m.get_output(0) - top = np.argmax(tvm_output.asnumpy()[0]) - tcost = "t={0:.2f}s".format(tcost.mean) - return tcost + " {}".format(synset[top]) - -# Helper function to compile the NNVM graph -# Takes in a path to a graph file, params file, and device target -# Returns the NNVM graph object, a compiled library object, and the params dict -def generate_graph(graph_fn, params_fn, device="vta"): - # Measure build start time - build_start = time.time() - - # Derive the TVM target - target = tvm.target.create("llvm -device={}".format(device)) - - # Derive the LLVM compiler flags - # When targetting the Pynq, cross-compile to ARMv7 ISA - if env.TARGET == "sim": - target_host = "llvm" - elif env.TARGET == "pynq": - target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon" - - # Load the ResNet-18 graph and parameters - sym = nnvm.graph.load_json(open(graph_fn).read()) - params = nnvm.compiler.load_param_dict(open(params_fn, 'rb').read()) - - # Populate the shape and data type dictionary - shape_dict = {"data": (1, 3, 224, 224)} - dtype_dict = {"data": 'float32'} - shape_dict.update({k: v.shape for k, v in params.items()}) - dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) - - # Apply NNVM graph optimization passes - sym = vta.graph.clean_cast(sym) - sym = vta.graph.clean_conv_fuse(sym) - if target.device_name == "vta": - assert env.BLOCK_IN == env.BLOCK_OUT - sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT) - - # Compile NNVM graph - with nnvm.compiler.build_config(opt_level=3): - if target.device_name != "vta": - graph, lib, params = 
nnvm.compiler.build( - sym, target, shape_dict, dtype_dict, - params=params, target_host=target_host) - else: - with vta.build_config(): - graph, lib, params = nnvm.compiler.build( - sym, target, shape_dict, dtype_dict, - params=params, target_host=target_host) - - # Save the compiled inference graph library - assert tvm.module.enabled("rpc") - temp = util.tempdir() - lib.save(temp.relpath("graphlib.o")) - - # Send the inference library over to the remote RPC server - remote.upload(temp.relpath("graphlib.o")) - lib = remote.load_module("graphlib.o") - - # Measure build time - build_time = time.time() - build_start - print("ResNet-18 inference graph built in {0:.2f}s!".format(build_time)) - - return graph, lib, params - - -###################################################################### -# Download ResNet Model -# -------------------------------------------- -# Download the necessary files to run ResNet-18. -# - -# Obtain ResNet model and download them into _data dir -url = "https://github.com/uwsaml/web-data/raw/master/vta/models/" -categ_fn = 'synset.txt' -graph_fn = 'resnet18_qt8.json' -params_fn = 'resnet18_qt8.params' - -# Create data dir -data_dir = "_data/" -if not os.path.exists(data_dir): - os.makedirs(data_dir) - -# Download files -for file in [categ_fn, graph_fn, params_fn]: - download(os.path.join(url, file), os.path.join(data_dir, file)) - -# Read in ImageNet Categories -synset = eval(open(os.path.join(data_dir, categ_fn)).read()) - -# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA -autotvm.tophub.check_backend('vta') - - -###################################################################### -# Setup the Pynq Board's RPC Server -# --------------------------------- -# Build the RPC server's VTA runtime and program the Pynq FPGA. - -# Measure build start time -reconfig_start = time.time() - -# We read the Pynq RPC host IP address and port number from the OS environment -host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99") -port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091")) - -# We configure both the bitstream and the runtime system on the Pynq -# to match the VTA configuration specified by the vta_config.json file. -if env.TARGET == "pynq": - # Make sure that TVM was compiled with RPC=1 - assert tvm.module.enabled("rpc") - remote = rpc.connect(host, port) - - # Reconfigure the JIT runtime - vta.reconfig_runtime(remote) - - # Program the FPGA with a pre-compiled VTA bitstream. - # You can program the FPGA with your own custom bitstream - # by passing the path to the bitstream file instead of None. - vta.program_fpga(remote, bitstream=None) - - # Report on reconfiguration time - reconfig_time = time.time() - reconfig_start - print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time)) - -# In simulation mode, host the RPC server locally. -elif env.TARGET == "sim": - remote = rpc.LocalSession() - - -###################################################################### -# Build the ResNet Runtime -# ------------------------ -# Build the ResNet graph runtime, and configure the parameters. - -# Set ``device=vtacpu`` to run inference on the CPU -# or ``device=vta`` to run inference on the FPGA. 
-device = "vta" - -# Device context -ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) - -# Build the graph runtime -graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn), - os.path.join(data_dir, params_fn), - device) -m = graph_runtime.create(graph, lib, ctx) - -# Set the parameters -m.set_input(**params) - -###################################################################### -# Run ResNet-18 inference on a sample image -# ----------------------------------------- -# Perform image classification on test image. -# You can change the test image URL to any image of your choosing. - -# Read in test image -image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg' -# Read in test image -response = requests.get(image_url) -image = Image.open(BytesIO(response.content)).resize((224, 224)) -# Show Image -plt.imshow(image) -plt.show() -# Set the input -image = process_image(image) -m.set_input('data', image) - -# Perform inference -timer = m.module.time_evaluator("run", ctx, number=1) -tcost = timer() - -# Get classification results -tvm_output = m.get_output(0) -top_categories = np.argsort(tvm_output.asnumpy()[0]) - -# Report top-5 classification results -print("ResNet-18 Prediction #1:", synset[top_categories[-1]]) -print(" #2:", synset[top_categories[-2]]) -print(" #3:", synset[top_categories[-3]]) -print(" #4:", synset[top_categories[-4]]) -print(" #5:", synset[top_categories[-5]]) -print("Performed inference in {0:.2f}s".format(tcost.mean)) - - -###################################################################### -# Run a Youtube Video Image Classifier -# ------------------------------------ -# Perform image classification on test stream on 1 frame every 48 frames. -# Comment the `if False:` out to run the demo - -# Early exit - remove for Demo -if False: - - import cv2 - import pafy - from IPython.display import clear_output - - # Helper to crop an image to a square (224, 224) - # Takes in an Image object, returns an Image object - def thumbnailify(image, pad=15): - w, h = image.size - crop = ((w-h)//2+pad, pad, h+(w-h)//2-pad, h-pad) - image = image.crop(crop) - image = image.resize((224, 224)) - return image - - # 16:16 inches - plt.rcParams['figure.figsize'] = [16, 16] - - # Stream the video in - url = "https://www.youtube.com/watch?v=PJlmYh27MHg&t=2s" - video = pafy.new(url) - best = video.getbest(preftype="mp4") - cap = cv2.VideoCapture(best.url) - - # Process one frame out of every 48 for variety - count = 0 - guess = "" - while(count<2400): - - # Capture frame-by-frame - ret, frame = cap.read() - - # Process one every 48 frames - if count % 48 == 1: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame = Image.fromarray(frame) - # Crop and resize - thumb = np.array(thumbnailify(frame)) - image = process_image(thumb) - guess = classify(m, image) - - # Insert guess in frame - frame = cv2.rectangle(thumb,(0,0),(200,0),(0,0,0),50) - cv2.putText(frame, guess, (5,15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (256,256,256), 1, cv2.LINE_AA) - - plt.imshow(thumb) - plt.axis('off') - plt.show() - if cv2.waitKey(1) & 0xFF == ord('q'): - break - clear_output(wait=True) - - count += 1 - - # When everything done, release the capture - cap.release() - cv2.destroyAllWindows()