From 36d3a41e3dc710a46ac7c7567bc0d64775f93900 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Sun, 22 Jul 2018 21:56:11 -0700 Subject: [PATCH] [TOPI] Bitserial low-precision convolution (#1332) --- python/tvm/intrin.py | 25 ++ src/codegen/llvm/llvm_module.cc | 9 + tests/python/unittest/test_codegen_llvm.py | 11 + topi/python/topi/generic/nn.py | 35 ++ topi/python/topi/nn/__init__.py | 1 + topi/python/topi/nn/bitserial_conv2d.py | 341 ++++++++++++++++ topi/python/topi/rasp/__init__.py | 1 + topi/python/topi/rasp/bitserial_conv2d.py | 365 ++++++++++++++++++ topi/python/topi/x86/__init__.py | 1 + topi/python/topi/x86/bitserial_conv2d.py | 316 +++++++++++++++ .../python/test_topi_bitserial_conv2d.py | 112 ++++++ .../python/test_topi_bitserial_conv2d_rasp.py | 56 +++ 12 files changed, 1273 insertions(+) create mode 100644 topi/python/topi/nn/bitserial_conv2d.py create mode 100644 topi/python/topi/rasp/bitserial_conv2d.py create mode 100644 topi/python/topi/x86/bitserial_conv2d.py create mode 100644 topi/tests/python/test_topi_bitserial_conv2d.py create mode 100644 topi/tests/python/test_topi_bitserial_conv2d_rasp.py diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 422f2d682d2b..30da873b5dcf 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -154,6 +154,31 @@ def call_extern(dtype, func_name, *args): dtype, func_name, convert(args), _Call.Extern, None, 0) +def call_llvm_intrin(dtype, name, *args): + """Build expression by calling an llvm intrinsic function + + Parameters + ---------- + dtype : str + The data type of the result. + + name : str + The name of the llvm intrinsic function. + + args : list + Positional arguments. + + Returns + ------- + call : Expr + The call expression. + """ + import tvm + llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) + assert llvm_id != 0, "%s is not an LLVM intrinsic" % name + return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args) + + def exp(x): """Take exponetial of input x. 
diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index 2bae52b194f5..99740b0dbdca 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -282,6 +282,15 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::shared_ptr ctx_; }; +unsigned LookupLLVMIntrinsic(const std::string& name) { + return llvm::Function::lookupIntrinsicID(name); +} + +TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id") +.set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = static_cast(LookupLLVMIntrinsic(args[0])); + }); + TVM_REGISTER_API("codegen.build_llvm") .set_body([](TVMArgs args, TVMRetValue* rv) { std::shared_ptr n = std::make_shared(); diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index f05fad10d273..e07f4aa8f40c 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -17,6 +17,16 @@ def test_llvm_intrin(): func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True) fcode = tvm.build(func, None, "llvm") +def test_llvm_lookup_intrin(): + ib = tvm.ir_builder.create() + m = tvm.var("m") + A = ib.pointer("uint8x8", name="A") + x = tvm.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.const(1, 'uint32'), A) + ib.emit(x) + body = ib.get() + func = tvm.ir_pass.MakeAPI(body, "ctpop", [A], 1, True) + fcode = tvm.build(func, None, "llvm") + def test_llvm_add_pipeline(): nn = 1024 n = tvm.convert(nn) @@ -324,3 +334,4 @@ def test_alignment(): test_llvm_flip_pipeline() test_llvm_madd_pipeline() test_llvm_temp_space() + test_llvm_lookup_intrin() diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 08220c15bbf0..fe76b9715d59 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -143,6 +143,41 @@ def schedule_depthwise_conv2d_nhwc(outs): """ return _default_schedule(outs, False) +@tvm.target.generic_func +def schedule_bitserial_conv2d_nchw(outs): + """Schedule for bitserial_conv2d_nchw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of bitserial_conv2d_nchw + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + +@tvm.target.generic_func +def schedule_bitserial_conv2d_nhwc(outs): + """Schedule for bitserial_conv2d_nhwc + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of bitserial_conv2d_nhwc + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + @tvm.target.override_native_generic_func("schedule_reduce") def schedule_reduce(outs): diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index 7b6ee4a86836..690379135e06 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -16,4 +16,5 @@ from .bnn import * from .upsampling import * from .local_response_norm import * +from .bitserial_conv2d import * from .l2_normalize import * diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py new file mode 100644 index 000000000000..ca2efb0820c1 --- /dev/null +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -0,0 +1,341 @@ +# pylint: disable=invalid-name, unused-variable, too-many-locals, too-many-arguments, unused-argument +"""Bitserial Conv2D operators""" +from __future__ import absolute_import as _abs +from collections import namedtuple +import numpy as np +import tvm +from topi.transform import concatenate +from .pad import pad +from .util import get_pad_tuple +from ..util import get_const_tuple, get_const_int + +# workload description of conv2d +Workload = namedtuple('Workload', + ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +SpatialPackNCHW = namedtuple('SpatialPack', + ['vh', 'vw', 'vc', 'ba', 'bc']) + +SpatialPackNHWC = namedtuple('SpatialPack', + ['vh', 'vw', 'vc', 'ba', 'bc']) + +_WORKLOADS = [ + # workloads of resnet18 on imagenet + # input_size, input_size, ic, oc, kh, kw, pad, pad, stride, stride + Workload('uint32', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + Workload('uint32', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + Workload('uint32', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + Workload('uint32', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + Workload('uint32', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + Workload('uint32', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + Workload('uint32', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + Workload('uint32', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + Workload('uint32', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + Workload('uint32', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + Workload('uint32', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + + # workload of alexnet on cifar10 + Workload('int32', 'int32', 27, 27, 96, 192, 5, 5, 2, 2, 1, 1), + Workload('int32', 'int32', 13, 13, 192, 384, 3, 3, 1, 1, 1, 1), + Workload('int32', 'int32', 13, 13, 384, 384, 3, 3, 1, 1, 1, 1), + Workload('int32', 'int32', 13, 13, 384, 256, 3, 3, 1, 1, 1, 1), +] + +@tvm.target.generic_func +def bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits, + layout='NCHW', pack_dtype='uint32', out_dtype='int32', dorefa=True): + """Bitserial Conv2D operator. 
+ + Parameters + ---------- + input : tvm.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] or + [batch, in_height, in_width, in_channel] + + filter : tvm.Tensor + 4-D with shape [num_filter, in_channel, filter_height, filter_width] or + [filter_height, filter_width, in_channel, num_filter] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of two ints + padding size, or [pad_height, pad_width] + + layout : str + layout of data + + activation_bits: int + number of bits used for activations/input elements + + weight_bits: int + number of bits used for weight elements + + out_dtype: str + return type of convolution + + pack_dtype: str + bit packing type + + dorefa: bool + perform the bitserial dot-product using 2 popcounts (required for DoReFa-Net) + + Returns + ------- + output : tvm.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] or + [batch, out_height, out_width, out_channel] + """ + # search platform specific declaration first + # default declaration + if layout == 'NCHW': + return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits, + pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) + elif layout == 'NHWC': + return spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, + pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) + raise ValueError("Layout {} is not supported yet".format(layout)) + +def _get_workload(data, kernel, stride, padding, out_dtype, layout): + """ Get the workload structure. """ + assert layout == "NCHW" or layout == "NHWC", \ + "Only support layouts NCHW and NHWC" + if layout == "NCHW": + _, CI, IH, IW = [x.value for x in data.shape] + CO, _, KH, KW = [x.value for x in kernel.shape] + else: # NHWC + IH, IW = data.shape[1].value, data.shape[2].value + KH, KW, CI, CO = [x for x in get_const_tuple(kernel.shape)] + + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + + return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) + +@tvm.target.generic_func +def _get_schedule(wkl, layout): + # pylint: disable=unreachable + """ Get the platform specific schedule. """ + target = tvm.target.current_target() + raise RuntimeError( + "No schedule for current target:{}".format(target)) + # This return has no use, merely to suppress pylint warning + return wkl + +def spatial_pack_nchw(data, kernel, stride, padding, in_bits, weight_bits, + pack_dtype, out_dtype, dorefa=False): + """ Compute convolution with pack on spatial axes. 
""" + assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" + data_q = bitpack(data, in_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) + kernel_q = bitpack(kernel, weight_bits, pack_axis=1, bit_axis=0, pack_type=pack_dtype) + IB, _, CI, H, W = data_q.shape + KB, CO, _, KH, KW = kernel_q.shape + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + HCAT, WCAT = KH-1, KW-1 + + wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NCHW") + sch = _get_schedule(wkl, "NCHW") + VH = sch.vh + VW = sch.vw + VC = sch.vc + + TH = H + 2*HPAD + TW = W + 2*WPAD + OH = (H + 2*HPAD - KH) // HSTR + 1 + OW = (W + 2*WPAD - KW) // WSTR + 1 + + dshape = (IB, 1, CI, H, W) + dpshape = (IB, 1, CI, TH, TW) + dvshape = (1, TH//(VH*HSTR), TW//(VW*WSTR), CI, VH*HSTR+HCAT, VW*WSTR+WCAT, IB) + + kshape = (KB, CO, CI, KH, KW) + kvshape = (CO//VC, CI, KH, KW, KB, VC) + + ovshape = (1, CO//VC, OH//VH, OW//VW, VH, VW, VC) + oshape = (1, CO, OH, OW) + + DOPAD = (HPAD != 0 and WPAD != 0) + if DOPAD: + data_pad = pad(data_q, (0, 0, 0, HPAD, WPAD), name="data_pad") + else: + data_pad = data_q + + data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw, b: \ + data_pad[b][n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec') + + kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, b, vc: \ + kernel_q[b][co*VC+vc][ci][dh][dw], name='kernel_vec') + + ci = tvm.reduce_axis((0, CI), name='ci') + dh = tvm.reduce_axis((0, KH), name='dh') + dw = tvm.reduce_axis((0, KW), name='dw') + b1 = tvm.reduce_axis((0, IB), name='ib') + b2 = tvm.reduce_axis((0, KB), name='kb') + + def _conv(n, co, h, w, vh, vw, vc): + b1b2 = (b1+b2).astype(out_dtype) + if dorefa: + return tvm.sum((tvm.popcount( + data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) & + kernel_vec[co, ci, dh, dw, b2, vc].astype(out_dtype)) - + tvm.popcount( + data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1].astype(out_dtype) + & ~kernel_vec[co, ci, dh, dw, b2, vc]).astype(out_dtype)) << b1b2, + axis=[ci, dh, dw, b1, b2]) + + return tvm.sum((tvm.popcount( + data_vec[n, h, w, ci, vh*HSTR+dh, vw*WSTR+dw, b1] & + kernel_vec[co, ci, dh, dw, b2, vc])).astype(out_dtype) << b1b2, + axis=[ci, dh, dw, b1, b2]) + + conv = tvm.compute(ovshape, _conv, name='conv_out') + + return tvm.compute(oshape, lambda n, co, h, w: + conv[n][co//VC][h//VH][w//VW][h%VH][w%VW][co%VC], + name='conv_vec', tag='spatial_bitserial_conv_nchw') + +def spatial_pack_nhwc(data, kernel, stride, padding, in_bits, weight_bits, + pack_dtype, out_dtype, dorefa=False): + """ Compute convolution with pack on spatial axes. 
""" + assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" + data_q = bitpack(data, in_bits, pack_axis=3, bit_axis=4, pack_type=pack_dtype) + kernel_q = bitpack(kernel, weight_bits, pack_axis=2, bit_axis=4, pack_type=pack_dtype) + _, H, W, CI, IB = data_q.shape + KH, KW, _, CO, KB = kernel_q.shape + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + HCAT, WCAT = KH-1, KW-1 + + wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC") + sch = _get_schedule(wkl, "NHWC") + VH = sch.vh + VW = sch.vw + VC = sch.vc + + PAD_H = H + 2*HPAD + PAD_W = W + 2*WPAD + OH = (H + 2*HPAD - KH) // HSTR + 1 + OW = (W + 2*WPAD - KW) // WSTR + 1 + + dvshape = (1, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, CI, IB) + kvshape = (CO, KH, KW, CI, VC, KB) + ovshape = (1, OH, OW, CO, VH, VW, VC) + oshape = (1, OH, OW, CO) + + if (HPAD != 0 and WPAD != 0): + data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad") + else: + data_pad = data_q + + data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, ci, b: \ + data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][ci][b], name='data_vec') + + kernel_vec = tvm.compute(kvshape, lambda co, dh, dw, ci, vc, b: \ + kernel_q[dh][dw][ci][co*VC+vc][b], name='kernel_vec') + + ci = tvm.reduce_axis((0, CI), name='ci') + dh = tvm.reduce_axis((0, KH), name='dh') + dw = tvm.reduce_axis((0, KW), name='dw') + b1 = tvm.reduce_axis((0, IB), name='ib') + b2 = tvm.reduce_axis((0, KB), name='kb') + + def _conv(n, h, w, co, vh, vw, vc): + b1b2 = (b1+b2).astype(out_dtype) + if dorefa: + return tvm.sum( + (tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) & + kernel_vec[co, dh, dw, ci, vc, b2].astype(out_dtype)) - + tvm.popcount(data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1].astype(out_dtype) & + ~kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype)) << b1b2, + axis=[dh, dw, ci, b1, b2]) + + return tvm.sum(tvm.popcount( + data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ci, b1] & + kernel_vec[co, dh, dw, ci, vc, b2]).astype(out_dtype) << b1b2, + axis=[dh, dw, ci, b1, b2]) + + conv = tvm.compute(ovshape, _conv, name='conv') + + return tvm.compute(oshape, lambda n, h, w, co: + conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC], + name='output_unpack', tag='spatial_bitserial_conv_nhwc') + +def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"): + """Packs data into format necessary for bitserial computation + pack_axis : int + index of the axis to pack in data + bit_axis : int + index of axis to place bit axis in resulting packed data""" + ishape = data.shape + n = len(ishape) + if pack_type == 'uint8': + data_width = 8 + elif pack_type == 'uint16': + data_width = 16 + elif pack_type == 'uint32': + data_width = 32 + elif pack_type == 'uint64': + data_width = 64 + + # Data must be in multiples of the data_width + assert get_const_int(ishape[pack_axis]) % data_width == 0, "Not a multiple of word size" + + shape_vec = list(ishape) + shape_vec[pack_axis] = (shape_vec[pack_axis] // data_width) + shape_vec.insert(bit_axis, 1) + bitserial_oshape = tuple(shape_vec) + masks = np.array([0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80]) + + # pack axis shifts if bit axis comes before + if bit_axis <= pack_axis: + pack_axis += 1 + + def _bitpack(*indices): + packed_data = [tvm.const(0, pack_type)] * bits + for k in range(data_width): + # Translate indices for packed data back to original + idx = [0] * n + 
j = 0 + for i in range(n+1): + if i == bit_axis: + continue + elif i == pack_axis: + idx[j] = indices[i] * data_width + k + else: + idx[j] = indices[i] + j += 1 + + element = data(*idx) + for b in range(bits): + extracted_bit = ((element & tvm.const(masks[b])) >> b).astype(pack_type) + packed_data[b] = (packed_data[b] | extracted_bit) + if k < data_width - 1: + packed_data[b] = packed_data[b] << 1 + + if k == data_width - 1: + return tuple(packed_data) + return tuple(packed_data) + + output_tuple = tvm.compute(bitserial_oshape, _bitpack, name=name, tag='bitpack') + + if bits > 1: + return concatenate(output_tuple, axis=bit_axis) + return output_tuple + +_SCH_TO_DECL_FUNC_QUANT = { + SpatialPackNCHW: spatial_pack_nchw, + SpatialPackNHWC: spatial_pack_nhwc, +} diff --git a/topi/python/topi/rasp/__init__.py b/topi/python/topi/rasp/__init__.py index 31ecea5aba4e..270a48504468 100644 --- a/topi/python/topi/rasp/__init__.py +++ b/topi/python/topi/rasp/__init__.py @@ -4,3 +4,4 @@ from .conv2d import schedule_conv2d_nchw from .depthwise_conv2d import schedule_depthwise_conv2d_nchw +from .bitserial_conv2d import schedule_bitserial_conv2d_nhwc diff --git a/topi/python/topi/rasp/bitserial_conv2d.py b/topi/python/topi/rasp/bitserial_conv2d.py new file mode 100644 index 000000000000..7d292db8d298 --- /dev/null +++ b/topi/python/topi/rasp/bitserial_conv2d.py @@ -0,0 +1,365 @@ +# pylint: disable=invalid-name,unused-variable,invalid-name +"""Bitserial conv2d schedule on raspberry pi""" +from __future__ import absolute_import as _abs +from collections import namedtuple +import tvm +from .. import tag +from ..nn.pad import pad +from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload, bitpack +from ..nn.bitserial_conv2d import SpatialPackNCHW, _WORKLOADS, spatial_pack_nchw +from ..nn.util import get_pad_tuple +from ..util import get_const_int +from .. 
import generic + +RaspSpatialPack = namedtuple('SpatialPack', + ['vh', 'vw', 'vc', 'ba', 'bc', 'split_ci', 'kfactor']) + +_QUANTIZED_SCHEDULES_NHWC = [ + RaspSpatialPack(2, 2, 8, 1, 1, False, 8), + RaspSpatialPack(1, 4, 8, 4, 1, False, 8), + RaspSpatialPack(1, 4, 8, 1, 16, False, 8), + RaspSpatialPack(1, 4, 8, 4, 8, False, 8), + RaspSpatialPack(1, 7, 8, 3, 8, False, 16), + RaspSpatialPack(1, 2, 8, 1, 8, False, 16), + RaspSpatialPack(2, 1, 8, 1, 4, False, 16), + RaspSpatialPack(1, 7, 8, 1, 1, True, 16), + RaspSpatialPack(1, 1, 8, 1, 16, True, 16), + RaspSpatialPack(1, 1, 8, 1, 8, True, 16), + RaspSpatialPack(1, 1, 8, 1, 16, True, 16), +] + +_QUANTIZED_SCHEDULES_NCHW = [ + # resnet + SpatialPackNCHW(2, 2, 8, 1, 1), + SpatialPackNCHW(1, 4, 8, 4, 1), + SpatialPackNCHW(1, 4, 8, 1, 16), + SpatialPackNCHW(1, 4, 8, 4, 8), + SpatialPackNCHW(1, 7, 8, 3, 8), + SpatialPackNCHW(1, 2, 8, 1, 8), + SpatialPackNCHW(2, 1, 8, 1, 4), + SpatialPackNCHW(1, 7, 8, 1, 1), + SpatialPackNCHW(1, 1, 8, 1, 16), + SpatialPackNCHW(1, 1, 8, 1, 8), + SpatialPackNCHW(1, 1, 8, 1, 16), +] + +@_get_schedule.register("rasp") +def _get_schedule_bitserial_conv2d(wkl, layout): + if wkl not in _WORKLOADS: + raise ValueError("no schedule for such workload: {}".format(wkl)) + idx = _WORKLOADS.index(wkl) + if layout == "NCHW": + sch = _QUANTIZED_SCHEDULES_NCHW[idx] + elif layout == "NHWC": + sch = _QUANTIZED_SCHEDULES_NHWC[idx] + return sch + + +@bitserial_conv2d.register("rasp") +def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits, + layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False): + if out_dtype is None: + out_dtype = data.dtype + assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" + assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC" + if dorefa: + assert layout == "NCHW", "Cannot support dorefa with NHWC layout yet" + wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout) + sch = _get_schedule(wkl, layout) + if layout == "NCHW": + return spatial_pack_nchw(data, kernel, stride, padding, activation_bits, weight_bits, + pack_dtype=pack_dtype, out_dtype=out_dtype, dorefa=dorefa) + return _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, + weight_bits, out_dtype) + +def _kernel_vec_spatial_pack_nhwc(kernel, kernel_bits, VC): + kernel_q = bitpack(kernel, kernel_bits, pack_axis=2, bit_axis=2, pack_type='uint8') + KH, KW, KB, CI, CO = kernel_q.shape + kvshape = (CO//VC, KH, KW, KB, VC, CI) + return tvm.compute(kvshape, lambda co, dh, dw, b, vc, ci: \ + kernel_q[dh][dw][b][ci][co*VC+vc], name='kernel_vec') + +def _spatial_pack_nhwc(data, kernel, stride, padding, activation_bits, weight_bits, out_dtype): + """ Compute convolution with pack on spatial axes. 
""" + assert data.shape[0].value == 1, "spatial pack convolution only support batch size=1" + wkl = _get_workload(data, kernel, stride, padding, out_dtype, "NHWC") + sch = _get_schedule(wkl, "NHWC") + VH = sch.vh + VW = sch.vw + VC = sch.vc + + data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8') + kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC) + N, H, W, IB, CI = data_q.shape + OCO, KH, KW, KB, VC, _ = kernel_vec.shape + + CO = OCO * VC + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + HCAT, WCAT = KH-1, KW-1 + + PAD_H = H + 2*HPAD + PAD_W = W + 2*WPAD + OH = (H + 2*HPAD - KH) // HSTR + 1 + OW = (W + 2*WPAD - KW) // WSTR + 1 + dvshape = (N, PAD_H//(VH*HSTR), PAD_W//(VW*WSTR), VH*HSTR+HCAT, VW*WSTR+WCAT, IB, CI) + ovshape = (1, OH // VH, OW // VW, CO // VC, VH, VW, VC) + oshape = (1, OH, OW, CO) + + if (HPAD != 0 and WPAD != 0): + data_pad = pad(data_q, (0, HPAD, WPAD, 0, 0), name="data_pad") + else: + data_pad = data_q + + data_vec = tvm.compute(dvshape, lambda n, h, w, vh, vw, b, ci: \ + data_pad[n][h*VH*HSTR+vh][w*VW*WSTR+vw][b][ci], name='data_vec') + + ci = tvm.reduce_axis((0, CI), name='ci') + dh = tvm.reduce_axis((0, KH), name='dh') + dw = tvm.reduce_axis((0, KW), name='dw') + ib = tvm.reduce_axis((0, IB), name='ib') + kb = tvm.reduce_axis((0, KB), name='kb') + + def _conv(n, h, w, co, vh, vw, vc): + return tvm.sum((tvm.popcount( + kernel_vec[co, dh, dw, kb, vc, ci].astype('uint16') & + data_vec[n, h, w, vh*HSTR+dh, vw*WSTR+dw, ib, ci].astype('uint16')) + << (kb + ib).astype('uint16')), axis=[dh, dw, kb, ib, ci]) + + conv = tvm.compute(ovshape, _conv, name='conv') + + return tvm.compute(oshape, lambda n, h, w, co: + conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype), + name='output_vec', tag='spatial_bitserial_conv_nhwc') + +def _intrin_popcount(m, k_i, w_b, x_b): + dtype = 'uint8' + w = tvm.placeholder((w_b, m, k_i), dtype=dtype, name='w') + x = tvm.placeholder((x_b, k_i,), dtype=dtype, name='x') + k = tvm.reduce_axis((0, k_i), name='k') + bw = tvm.reduce_axis((0, w_b), name='bw') + bx = tvm.reduce_axis((0, x_b), name='bx') + z = tvm.compute((m,), lambda i: + tvm.sum(tvm.popcount(w[bw, i, k].astype('uint16') & + x[bx, k].astype('uint16')) + << (bw+bx).astype('uint16'), axis=[bw, bx, k]), name='z') + + Wb = tvm.decl_buffer(w.shape, w.dtype, + name="W", + offset_factor=k_i, + strides=[tvm.var('ldw'), tvm.var('ldw'), 1]) + Xb = tvm.decl_buffer(x.shape, x.dtype, + name="X", + offset_factor=k_i, + strides=[tvm.var('ldw'), 1]) + + def _intrin_func(ins, outs): + ww, xx = ins + zz = outs[0] + vpadd = "llvm.arm.neon.vpadd.v8u8" + vpadalu = "llvm.arm.neon.vpadalu.v16u8.v8u16" + args_1 = tvm.const(1, 'uint32') + args_2 = tvm.const(2, 'uint32') + + def _instr(index): + irb = tvm.ir_builder.create() + if index == 1: + irb.emit(zz.vstore(0, tvm.const(0, 'uint16x8'))) + return irb.get() + + cnts8 = [None] * 8 + cnts4 = [None] * 4 + cnts2 = [None] * 2 + for bw in range(w_b): + for bx in range(x_b): + if k_i == 16: + for i in range(m): + ands = ww.vload([bw, i, 0], 'uint8x16') & xx.vload([bx, 0], 'uint8x16') + cnts = tvm.popcount(ands) + upper_half = tvm.call_pure_intrin('uint8x8', 'vectorhigh', cnts) + lower_half = tvm.call_pure_intrin('uint8x8', 'vectorlow', cnts) + cnts8[i] = upper_half + lower_half + for i in range(m//2): + cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + args_1, cnts8[i*2], cnts8[i*2+1]) + for i in 
range(m//4): + cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + args_1, cnts4[i*2], cnts4[i*2+1]) + cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1]) + shifted_cnts = cnts << tvm.const(bw+bx, dtype) + out = tvm.call_llvm_intrin('uint16x8', vpadalu, + args_2, zz.vload(0, 'uint16x8'), shifted_cnts) + else: # ki == 8 + for i in range(m): + ands = ww.vload([bw, i, 0], 'uint8x8') & xx.vload([bx, 0], 'uint8x8') + cnts8[i] = tvm.popcount(ands) + for i in range(m//2): + cnts4[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + args_1, cnts8[i*2], cnts8[i*2+1]) + for i in range(m//4): + cnts2[i] = tvm.call_llvm_intrin('uint8x8', vpadd, + args_1, cnts4[i*2], cnts4[i*2+1]) + cnts = tvm.call_pure_intrin('uint8x16', 'vectorcombine', cnts2[0], cnts2[1]) + shifted_cnts = cnts << tvm.const(bw+bx, dtype) + out = tvm.call_llvm_intrin('uint16x8', vpadalu, + args_2, zz.vload(0, 'uint16x8'), shifted_cnts) + irb.emit(zz.vstore(0, out)) + return irb.get() + # body, reset, update + return _instr(0), _instr(1), _instr(2) + with tvm.build_config(offset_factor=1, partition_const_loop=True): + return tvm.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb}) + +# ARM specific schedule that using custom microkernel +def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, + kernel, kernel_q, kernel_vec, + conv_out, output, last): + # no stride and padding info here + _, H, W, IB, CI = data_q.shape + KH, KW, KB, _, CO = kernel_q.shape + KB = get_const_int(KB) + IB = get_const_int(IB) + + if data_pad is None: + padding = (0, 0) + _, in_h, in_w, _, _ = data_q.shape + kern_h, kern_w, _, _ = kernel.shape + _, out_h, out_w, _ = output.shape + hstride = (in_h - kern_h) // (out_h - 1) + wstride = (in_w - kern_w) // (out_w - 1) + stride = get_const_int(hstride), get_const_int(wstride) + else: + _, in_h, in_w, _, _ = data_q.shape + _, pad_h, pad_w, _, _ = data_pad.shape + hpad = (pad_h - in_h) // 2 + wpad = (pad_w - in_w) // 2 + padding = get_const_int(hpad), get_const_int(wpad) + + _, in_h, in_w, _, _ = data_pad.shape + kern_h, kern_w, _, _ = kernel.shape + _, out_h, out_w, _ = output.shape + hstride = (in_h - kern_h) // (out_h - 1) + wstride = (in_w - kern_w) // (out_w - 1) + stride = get_const_int(hstride), get_const_int(wstride) + + wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NHWC") + sch = _get_schedule(wkl, "NHWC") + + VH = sch.vh + VW = sch.vw + VC = sch.vc + ba = sch.ba + bc = sch.bc + + ##### Schedule data packing + if data_pad is not None: + s[data_pad].compute_inline() + + _, h, _, _, _, _, _ = s[data_vec].op.axis + if ba == 1: + oaxis = h + paxis = h + else: + oh, ih = s[data_vec].split(h, ba) + oaxis = oh + paxis = ih + + s[data_vec].parallel(paxis) + s[data_vec].pragma(oaxis, "parallel_launch_point") + s[data_vec].pragma(paxis, "parallel_stride_pattern") + s[data_vec].pragma(oaxis, "parallel_barrier_when_finish") + + ##### Schedule kernel packing + co, _, _, _, _, _ = s[kernel_vec].op.axis + if bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[kernel_vec].split(co, bc) + oaxis = oco + paxis = ico + + s[kernel_vec].parallel(paxis) + s[kernel_vec].pragma(oaxis, "parallel_launch_point") + s[kernel_vec].pragma(paxis, "parallel_stride_pattern") + s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish") + + ##### Schedule Convolution + n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis + dh, dw, kb, ib, ci = s[conv_out].op.reduce_axis + + kfactor = sch.kfactor + if sch.split_ci: + oci, ici = s[conv_out].split(ci, kfactor) + s[conv_out].reorder(n, 
oh, ow, co, vh, vw, dh, dw, oci, kb, ib, vc, ici) + else: + s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, kb, ib, vc, ci) + + pc = _intrin_popcount(8, kfactor, KB, IB) + s[conv_out].tensorize(kb, pc) + + n, h, w, co = s[last].op.axis + co, vc = s[last].split(co, VC) + oh, ow, vh, vw = s[last].tile(h, w, VH, VW) + s[last].reorder(n, oh, ow, co, vc, vh, vw) + s[last].vectorize(vw) + if last != output: + s[last].compute_inline() + + s[conv_out].compute_at(s[last], ow) + if co == 1: + oaxis = oh + paxis = oh + else: + oho, iho = s[last].split(oh, bc) + oaxis = oho + paxis = iho + + s[last].parallel(paxis) + s = s.normalize() + return s + +@generic.schedule_bitserial_conv2d_nhwc.register(["rasp"]) +def schedule_bitserial_conv2d_nhwc(outs): + """Raspberry Pi schedule for bitserial conv2d""" + s = tvm.create_schedule([x.op for x in outs]) + def traverse(op): + """Traverse operators from computation graph""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag): + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + + if 'spatial_bitserial_conv_nhwc' in op.tag: + output = op.output(0) + conv_out = op.input_tensors[0] + kernel_vec = conv_out.op.input_tensors[0] + kernel_q = kernel_vec.op.input_tensors[0] + kernel = kernel_q.op.input_tensors[0] + if "QuantizeInput" in kernel.op.name: + # Need to go up 1 further, from the combine in bitpack + kernel = kernel.op.input_tensors[0] + data_vec = conv_out.op.input_tensors[1] + data_q = data_vec.op.input_tensors[0] + data = data_q.op.input_tensors[0] + data_pad = None + if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag: + data_pad = data_q + data_q = data + data = data_q.op.input_tensors[0] + if "QuantizeInput" in data.op.name: + # Need to go up 1 further, from the combine in bitpack + data = data.op.input_tensors[0] + + _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, + kernel, kernel_q, kernel_vec, conv_out, output, outs[0]) + + traverse(outs[0].op) + return s diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index d001b5fdca57..c146419fcec9 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -8,3 +8,4 @@ from .nn import * from .injective import * from .pooling import schedule_pool, schedule_global_pool +from .bitserial_conv2d import schedule_bitserial_conv2d diff --git a/topi/python/topi/x86/bitserial_conv2d.py b/topi/python/topi/x86/bitserial_conv2d.py new file mode 100644 index 000000000000..1c01b96f9c30 --- /dev/null +++ b/topi/python/topi/x86/bitserial_conv2d.py @@ -0,0 +1,316 @@ +# pylint: disable=invalid-name,unused-variable,invalid-name +"""Bitserial conv2d schedule on x86""" +import tvm +from topi.util import get_const_int +from .. 
import generic, tag +from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload +from ..nn.bitserial_conv2d import SpatialPackNCHW, SpatialPackNHWC +from ..nn.bitserial_conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC_QUANT + +_QUANTIZED_SCHEDULES_NCHW = [ + # resnet + SpatialPackNCHW(2, 2, 8, 1, 1), + SpatialPackNCHW(1, 4, 8, 4, 1), + SpatialPackNCHW(1, 4, 8, 1, 16), + SpatialPackNCHW(1, 4, 8, 4, 8), + SpatialPackNCHW(1, 7, 8, 3, 8), + SpatialPackNCHW(1, 2, 8, 1, 8), + SpatialPackNCHW(2, 1, 8, 1, 4), + SpatialPackNCHW(1, 7, 8, 1, 1), + SpatialPackNCHW(1, 1, 8, 1, 16), + SpatialPackNCHW(1, 1, 8, 1, 8), + SpatialPackNCHW(1, 1, 8, 1, 16), + + SpatialPackNCHW(3, 3, 16, 3, 16), + SpatialPackNCHW(1, 1, 16, 2, 16), + SpatialPackNCHW(1, 1, 8, 1, 16), + SpatialPackNCHW(1, 1, 8, 1, 16), +] + +_QUANTIZED_SCHEDULES_NHWC = [ + # resnet + SpatialPackNHWC(2, 2, 8, 1, 1), + SpatialPackNHWC(1, 4, 8, 4, 1), + SpatialPackNHWC(1, 4, 8, 1, 16), + SpatialPackNHWC(1, 4, 8, 4, 8), + SpatialPackNHWC(1, 7, 8, 3, 8), + SpatialPackNHWC(1, 2, 8, 1, 8), + SpatialPackNHWC(2, 1, 8, 1, 4), + SpatialPackNHWC(1, 7, 8, 1, 1), + SpatialPackNHWC(1, 1, 8, 1, 16), + SpatialPackNHWC(1, 1, 8, 1, 8), + SpatialPackNHWC(1, 1, 8, 1, 16), +] + +@_get_schedule.register("cpu") +def _get_schedule_bitserial_conv2d(wkl, layout): + if wkl not in _WORKLOADS: + raise ValueError("no schedule for such workload: {}".format(wkl)) + idx = _WORKLOADS.index(wkl) + if layout == "NCHW": + sch = _QUANTIZED_SCHEDULES_NCHW[idx] + elif layout == "NHWC": + sch = _QUANTIZED_SCHEDULES_NHWC[idx] + return sch + +@bitserial_conv2d.register("cpu") +def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits, + layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False): + if out_dtype is None: + out_dtype = data.dtype + assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp" + assert layout == "NCHW" or layout == "NHWC", "only support layouts NCHW and NHWC" + + wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout) + sch = _get_schedule(wkl, layout) + return _SCH_TO_DECL_FUNC_QUANT[type(sch)](data, kernel, stride, padding, activation_bits, + weight_bits, pack_dtype, out_dtype, dorefa) + +@generic.schedule_bitserial_conv2d_nchw.register(["cpu"]) +@generic.schedule_bitserial_conv2d_nhwc.register(["cpu"]) +def schedule_bitserial_conv2d(outs): + """CPU schedule for bitserial convolutions NCHW and NHWC""" + s = tvm.create_schedule([x.op for x in outs]) + + def traverse(op): + """Traverse operators from computation graph""" + output = op.output(0) + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(op.tag) or 'elemwise' in op.tag: + if op not in s.outputs: + s[op].compute_inline() + for tensor in op.input_tensors: + if tensor.op.input_tensors: + traverse(tensor.op) + + elif 'spatial_bitserial_conv_nchw' in op.tag or 'spatial_bitserial_conv_nhwc' in op.tag: + conv_out = op.input_tensors[0] + kernel_vec = conv_out.op.input_tensors[1] + kernel_q = kernel_vec.op.input_tensors[0] + kernel = kernel_q.op.input_tensors[0] + data_vec = conv_out.op.input_tensors[0] + data_q = data_vec.op.input_tensors[0] + data = data_q.op.input_tensors[0] + data_pad = None + if isinstance(data_q.op, tvm.tensor.ComputeOp) and "pad" in data_q.op.tag: + data_pad = data_q + data_q = data + data = data_q.op.input_tensors[0] + if "QuantizeInput" in kernel.op.name: + # Need to go up 1 further, from the combine in bitpack + kernel = kernel.op.input_tensors[0] + if 
"QuantizeInput" in data.op.name: + # Need to go up 1 further, from the combine in bitpack + data = data.op.input_tensors[0] + + if 'spatial_bitserial_conv_nchw' in op.tag: + _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, + kernel, kernel_q, kernel_vec, + conv_out, output, outs[0]) + elif 'spatial_bitserial_conv_nhwc' in op.tag: + _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, + kernel, kernel_q, kernel_vec, + conv_out, output, outs[0]) + + traverse(outs[0].op) + return s + +def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec, + kernel, kernel_q, kernel_vec, + conv_out, output, last): + IB, _, CI, IH, IW = data_q.shape + KB, CO, _, KH, KW = kernel_q.shape + _, _, OH, OW = output.shape + + # Infer padding and stride + if data_pad is None: + padding = (0, 0) + TH, TW = IH, IW + else: + _, _, _, TH, TW = data_pad.shape + hpad = get_const_int((TH - IH) // 2) + wpad = get_const_int((TW - IW) // 2) + padding = (hpad, wpad) + + hstride = get_const_int((TH - KH) // (OH - 1)) + wstride = get_const_int((TW - KW) // (OW - 1)) + stride = (hstride, wstride) + + wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NCHW") + sch = _get_schedule(wkl, "NCHW") + VH = sch.vh + VW = sch.vw + VC = sch.vc + ba = sch.ba + bc = sch.bc + + CC = s.cache_write(conv_out, "global") + n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis + s[conv_out].vectorize(vc) + + s[CC].compute_at(s[conv_out], ow) + n, co, oh, ow, vh, vw, vc = s[CC].op.axis + ci, dh, dw, b1, b2 = s[CC].op.reduce_axis + s[CC].reorder(ci, dh, vh, dw, vw, b1, b2, vc) + s[CC].unroll(b1) + s[CC].unroll(b2) + s[CC].vectorize(vc) + + ##### Schedule A + if data_pad is not None: + s[data_pad].compute_inline() + + _, h, _, _, _, _, vw = s[data_vec].op.axis + s[data_vec].vectorize(vw) + if ba == 1: + oaxis = h + paxis = h + else: + oh, ih = s[data_vec].split(h, ba) + oaxis = oh + paxis = ih + + s[data_vec].parallel(paxis) + s[data_vec].pragma(oaxis, "parallel_launch_point") + s[data_vec].pragma(paxis, "parallel_stride_pattern") + s[data_vec].pragma(oaxis, "parallel_barrier_when_finish") + + + ##### Schedule B + co, _, _, _, _, vc = s[kernel_vec].op.axis + s[kernel_vec].vectorize(vc) + if bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[kernel_vec].split(co, bc) + oaxis = oco + paxis = ico + + s[kernel_vec].parallel(paxis) + s[kernel_vec].pragma(oaxis, "parallel_launch_point") + s[kernel_vec].pragma(paxis, "parallel_stride_pattern") + s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish") + + + ##### Schedule C + n, co, h, w = s[last].op.axis + co, vc = s[last].split(co, VC) + oh, ow, vh, vw = s[last].tile(h, w, VH, VW) + s[last].reorder(n, co, oh, ow, vh, vw, vc) + if last != output: + s[output].compute_inline() + s[conv_out].compute_at(s[last], ow) + + if bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[last].split(co, bc) + oaxis = oco + paxis = ico + + s[last].parallel(paxis) + s[last].pragma(oaxis, "parallel_launch_point") + s[last].pragma(paxis, "parallel_stride_pattern") + s[last].pragma(oaxis, "parallel_barrier_when_finish") + + return s + +def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec, + kernel, kernel_q, kernel_vec, + conv_out, output, last): + # no stride and padding info here + _, IH, IW, CI, IB = data_q.shape + KH, KW, _, CO, KB = kernel_q.shape + _, OH, OW, _ = output.shape + # Infer padding and stride + if data_pad is None: + padding = (0, 0) + TH, TW = IH, IW + else: + _, TH, TW, _, _ = data_pad.shape + hpad = get_const_int((TH - IH) 
// 2) + wpad = get_const_int((TW - IW) // 2) + padding = (hpad, wpad) + + hstride = get_const_int((TH - KH) // (OH - 1)) + wstride = get_const_int((TW - KW) // (OW - 1)) + stride = (hstride, wstride) + + wkl = _get_workload(data, kernel, stride, padding, last.dtype, "NHWC") + sch = _get_schedule(wkl, "NHWC") + VH = sch.vh + VW = sch.vw + VC = sch.vc + ba = sch.ba + bc = sch.bc + + ##### Schedule data packing + if data_pad is not None: + s[data_pad].compute_inline() + + _, h, _, _, _, _, _ = s[data_vec].op.axis + if ba == 1: + oaxis = h + paxis = h + else: + oh, ih = s[data_vec].split(h, ba) + oaxis = oh + paxis = ih + s[data_vec].parallel(paxis) + s[data_vec].pragma(oaxis, "parallel_launch_point") + s[data_vec].pragma(paxis, "parallel_stride_pattern") + s[data_vec].pragma(oaxis, "parallel_barrier_when_finish") + + + ##### Schedule kernel packing + co, _, _, _, _, _ = s[kernel_vec].op.axis + if bc == 1: + oaxis = co + paxis = co + else: + oco, ico = s[kernel_vec].split(co, bc) + oaxis = oco + paxis = ico + + s[kernel_vec].parallel(paxis) + s[kernel_vec].pragma(oaxis, "parallel_launch_point") + s[kernel_vec].pragma(paxis, "parallel_stride_pattern") + s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish") + + + ##### Schedule Convolution + n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis + dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis + + s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2) + + s[conv_out].unroll(b1) + s[conv_out].unroll(b2) + s[conv_out].vectorize(vc) + + # # Schedule output + n, h, w, co = s[last].op.axis + co, vc = s[last].split(co, VC) + oh, ow, vh, vw = s[last].tile(h, w, VH, VW) + s[last].reorder(n, oh, ow, co, vh, vw, vc) + s[last].vectorize(vc) + if last != output: + s[output].compute_inline() + s[conv_out].compute_at(s[last], ow) + + if bc == 1: + oaxis = oh + paxis = oh + else: + oho, iho = s[last].split(oh, bc) + oaxis = oho + paxis = iho + + s[last].parallel(paxis) + s[last].pragma(oaxis, "parallel_launch_point") + s[last].pragma(paxis, "parallel_stride_pattern") + s[last].pragma(oaxis, "parallel_barrier_when_finish") + + return s diff --git a/topi/tests/python/test_topi_bitserial_conv2d.py b/topi/tests/python/test_topi_bitserial_conv2d.py new file mode 100644 index 000000000000..6df18483a45f --- /dev/null +++ b/topi/tests/python/test_topi_bitserial_conv2d.py @@ -0,0 +1,112 @@ +import os +import numpy as np +import tvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple +from tvm.contrib import util +from tvm.contrib.pickle_memoize import memoize + +def generate_quantized_np(shape, bits, out_dtype): + min_val = 0 + max_val = 1 << bits + return np.random.randint(min_val, max_val, size=shape).astype(out_dtype) + +def verify_bitserial_conv2d_nchw(batch, in_size, in_channel, num_filter, kernel, stride, padding, + activation_bits, weight_bits, dorefa): + in_height = in_width = in_size + input_type='uint32' + out_dtype='int32' + + with tvm.target.create('llvm'): + A = tvm.placeholder((batch, in_channel, in_height, in_width), dtype=input_type, name='A') + W = tvm.placeholder((num_filter, in_channel, kernel, kernel), dtype=input_type, name='W') + B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, + out_dtype=out_dtype, layout="NCHW", dorefa=dorefa) + s = topi.generic.schedule_bitserial_conv2d_nchw([B]) + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + dtype = A.dtype + + def get_ref_data(): + a_np = 
generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type) + w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type) + if dorefa: + w_ = np.copy(w_np).astype(out_dtype) + for x in np.nditer(w_, op_flags=['readwrite']): + x[...] = 1 if x == 1 else -1 + b_np = topi.testing.conv2d_nchw_python(a_np.astype(out_dtype), w_, stride, padding) + else: + b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + ctx = tvm.cpu(0) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], "llvm") + func(a, w, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + +def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, + activation_bits, weight_bits, dorefa): + in_height = in_width = in_size + input_type='uint32' + out_dtype='int32' + + with tvm.target.create('llvm'): + A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') + B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, + layout="NHWC", dorefa=dorefa) + s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + dtype = A.dtype + + def get_ref_data(): + a_np = generate_quantized_np(get_const_tuple(A.shape), activation_bits, input_type) + w_np = generate_quantized_np(get_const_tuple(W.shape), weight_bits, input_type) + if dorefa: + w_ = np.copy(w_np).astype(out_dtype) + for x in np.nditer(w_, op_flags=['readwrite']): + x[...] 
= 1 if x == 1 else -1 + b_np = topi.testing.conv2d_nhwc_python(a_np, w_, stride, padding).astype(out_dtype) + else: + b_np = topi.testing.conv2d_nhwc_python(a_np, w_np, stride, padding).astype(out_dtype) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + ctx = tvm.cpu(0) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], 'llvm') + + func(a, w, b) + np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + +def test_bitserial_conv2d(): + in_size = 56 + ic, oc = 64, 64 + k = 3 + stride = 1 + pad = 1 + verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 1, 1, True) + verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 1, True) + verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 1, 1, False) + verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 1, False) + verify_bitserial_conv2d_nchw(1, in_size, ic, oc, k, stride, pad, 2, 2, False) + + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, True) + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, True) + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False) + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False) + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 2, False) + +if __name__ == "__main__": + test_bitserial_conv2d() \ No newline at end of file diff --git a/topi/tests/python/test_topi_bitserial_conv2d_rasp.py b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py new file mode 100644 index 000000000000..5789c5496205 --- /dev/null +++ b/topi/tests/python/test_topi_bitserial_conv2d_rasp.py @@ -0,0 +1,56 @@ +import os +import re +import numpy as np +import tvm +import topi +import topi.testing +from topi.util import get_const_tuple +from tvm.contrib import util + +target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon' + +def generate_quantized_np(shape, bits, out_dtype): + np.random.seed(0) + min_val = 0 + max_val = 1 << bits + return np.random.randint(min_val, max_val, size=shape).astype(out_dtype) + +# Verify that certain special instructions from the tensorize pass exist +def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel, stride, padding, + activation_bits, weight_bits, dorefa): + in_height = in_width = in_size + input_type='uint32' + out_dtype='int32' + + with tvm.target.rasp(): + A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W') + B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype, + layout="NHWC", dorefa=dorefa) + s = topi.generic.schedule_bitserial_conv2d_nhwc([B]) + + + func = tvm.build(s, [A, W, B], target) + + assembly = func.get_source('asm') + matches = re.findall("vpadal", assembly) + assert (len(matches) > 0) + matches = re.findall("vcnt", assembly) + assert (len(matches) > 0) + matches = re.findall("vpadd", assembly) + assert (len(matches) > 0) + +def test_bitserial_conv2d(): + in_size = 56 + ic, oc = 64, 64 + k = 3 + stride = 1 + pad = 1 + + + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 1, 1, False) + verify_bitserial_conv2d_nhwc(1, in_size, ic, oc, k, stride, pad, 2, 1, False) + +if __name__ == "__main__": + test_bitserial_conv2d() +
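
For reference, a minimal end-to-end NHWC usage sketch distilled from test_topi_bitserial_conv2d.py above; the 56x56x64 shape, 2-bit activations, 1-bit weights, and the plain 'llvm' CPU target are illustrative values borrowed from that test rather than a prescribed configuration:

import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

# Declare and schedule under a CPU target so the "cpu"-registered
# bitserial_conv2d and its NHWC schedule are dispatched.
with tvm.target.create('llvm'):
    A = tvm.placeholder((1, 56, 56, 64), dtype='uint32', name='A')  # NHWC activations
    W = tvm.placeholder((3, 3, 64, 64), dtype='uint32', name='W')   # HWIO weights
    B = topi.nn.bitserial_conv2d(A, W, 1, 1, 2, 1,                  # stride=1, pad=1, 2-bit act, 1-bit wt
                                 layout='NHWC', out_dtype='int32', dorefa=False)
    s = topi.generic.schedule_bitserial_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], 'llvm')

# Inputs must already be quantized to the declared bit widths.
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(0, 1 << 2, size=(1, 56, 56, 64)).astype('uint32'), ctx)
w = tvm.nd.array(np.random.randint(0, 1 << 1, size=(3, 3, 64, 64)).astype('uint32'), ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func(a, w, b)  # b now holds the plain (dorefa=False) bitserial convolution result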