From fc766e7b9f39ae1e5bae90df64fb797fa00c823a Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 13:55:28 -0700 Subject: [PATCH 01/14] Added arm_cpu NHWC schedules. --- topi/python/topi/arm_cpu/conv2d.py | 147 ++++++++++++++++++++++++++++- 1 file changed, 143 insertions(+), 4 deletions(-) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 77b37ed5a1e2..628baf5dadb1 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -75,8 +75,12 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, - num_tile=2) + if layout == "NCHW": + return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, + num_tile=2) + else: + return _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, layout, + out_dtype, num_tile=2) @autotvm.register_topi_schedule( @@ -118,8 +122,12 @@ def _callback(op): kernel = kernel_vec if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - - _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) + + # TODO: move to schedule_nhwc later + if 'nhwc' in op.tag: + _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) + else: + _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) if 'winograd_conv2d_output' in op.tag: output = op.output(0) @@ -243,6 +251,90 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou name='output_unpack', tag='spatial_conv2d_output') return output +def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile): + assert layout == "NHWC", "Only support NHWC" + # create workload according to raw arguments + out_dtype = out_dtype or data.dtype + N, IH, IW, CI = get_const_tuple(data.shape) + + # TODO? 
Dilation + + if len(kernel.shape) == 4: + pre_packed = False + KH, KW, _, CO = get_const_tuple(kernel.shape) + else: + pre_packed = True + CO, _, KH, KW, VC = get_const_tuple(kernel.shape) + CO = CO * VC + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( padding, (KH, KW)) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) + OH = (IH + pad_top + pad_bottom - KH) // HSTR + 1 + OW = (IW + pad_left + pad_right - KW) // WSTR + 1 + data_pad = pad(data, [0, pad_top, pad_left, 0], [0, pad_bottom, pad_right, 0]) + + # ==================== define configuration space ==================== + n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW) + ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) + + if num_tile == 2: # for arm cpu + co, vc = cfg.define_split('tile_co', co, num_outputs=2) + oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2) + ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2) + elif num_tile == 3: # for mali gpu + co, _, vc = cfg.define_split('tile_co', co, num_outputs=3) + oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3) + ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3) + else: + raise RuntimeError("Invalid num_tile") + + cfg.define_reorder("reorder_0", + [n, oh, ow, co, ci, kh, kw, vh, vc, vw], + policy='candidate', candidate=[ + [n, oh, ow, co, ci, kh, kw, vh, vc, vw], + [n, oh, ow, co, ci, kh, kw, vc, vh, vw]]) + + cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll') + cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') + + # fallback support + # if cfg.is_fallback: + # if num_tile == 2: # arm cpu + # ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct') + # cfg.fallback_with_reference_log(ref_log) + + VC = cfg["tile_co"].size[-1] + VH = cfg["tile_oh"].size[-1] + VW = cfg["tile_ow"].size[-1] + + kvshape = (CO // VC, CI, KH, KW, VC) + ovshape = (N, OH // VH, OW // VW, CO // VC, VH, VW, VC) + oshape = (N, OH, OW, CO) + + # undilate input data + dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1) + data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: + data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], + name='data_vec') + + kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc: + kernel[co*VC+vc][ci][kh][kw], + name='kernel_vec') + + ci = tvm.reduce_axis((0, CI), name='ci') + kh = tvm.reduce_axis((0, KH), name='kh') + kw = tvm.reduce_axis((0, KW), name='kw') + + conv = tvm.compute(ovshape, lambda n, h, w, co, vh, vw, vc: \ + tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) * + kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), + axis=[ci, kh, kw]), name='conv') + + output = tvm.compute(oshape, lambda n, h, w, co: + conv[n][h//VH][w//VW][CO // VC][h%VH][w%VW][co%VC], + name='output_unpack', tag='spatial_conv2d_output_nhwc') + return output + def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last): """schedule implementation""" @@ -303,6 +395,53 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, return s +def _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec, + conv, output, last): + """schedule implementation""" + n, oh, ow, co, vh, vw, vc = s[conv].op.axis + ci, kh, kw = s[conv].op.reduce_axis + + # schedule conv + cfg["reorder_0"].apply(s, conv, [n, oh, ow, co, ci, kh, kw, vh, vw, vc]) + cfg["ann_reduce"].apply(s, conv, [kh, kw], + axis_lens=[get_const_int(kh.dom.extent), + get_const_int(kw.dom.extent)], + max_unroll=16, 
+ cfg=cfg) + cfg["ann_spatial"].apply(s, conv, [vh, vw, vc], + axis_lens=[cfg['tile_oh'].size[-1], + cfg['tile_ow'].size[-1], + cfg['tile_co'].size[-1]], + max_unroll=16, + cfg=cfg) + + # schedule fusion + n, h, w, co = s[last].op.axis + co, vc = cfg['tile_co'].apply(s, last, co) + oh, vh = cfg['tile_oh'].apply(s, last, h) + ow, vw = cfg['tile_ow'].apply(s, last, w) + s[last].reorder(n, oh, ow, co, vh, vw, vc) + if last != output: + s[output].compute_inline() + cfg["ann_spatial"].apply(s, last, [vh, vw, vc], + axis_lens=[cfg['tile_oh'].size[-1], + cfg['tile_ow'].size[-1], + cfg['tile_co'].size[-1]], + max_unroll=16, + cfg=cfg) + s[conv].compute_at(s[last], co) + + # mark parallel + s[last].parallel(oh) + + _, h, _, _, _, _ = s[data_vec].op.axis + s[data_vec].parallel(h) + + co, _, _, _, _ = s[kernel_vec].op.axis + s[kernel_vec].parallel(co) + return s + + @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd']) def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): """ TOPI compute callback. Use winograd template """ From a41a6ca7061383e47457f4246d7734a336cd435e Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 14:16:12 -0700 Subject: [PATCH 02/14] Fixed kernel shape legalization. --- topi/python/topi/arm_cpu/conv2d.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 628baf5dadb1..56d8e7755995 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -949,25 +949,19 @@ def _conv2d_legalize(attrs, inputs, arg_types): if attrs['data_layout'] == 'NHWC': data, kernel = inputs if attrs['kernel_layout'] == 'HWIO': - # Handle HWIO layout. This is common in TF graph. - kernel = relay.transpose(kernel, axes=(3, 2, 0, 1)) + # HWIO layout is expected for NHWC input. + return None elif attrs['kernel_layout'] == 'HWOI': # Handle HWOI layout. This is common in TF depthwise conv2d graph. - kernel = relay.transpose(kernel, axes=(2, 3, 0, 1)) - elif attrs['kernel_layout'] != 'OIHW': - return None + kernel = relay.transpose(kernel, axes=(0, 1, 3, 2)) + elif attrs['kernel_layout'] == 'OIHW': + kernel = relay.transpose(kernel, axes=(2, 3, 1, 0)) - logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to " - + "fallback to NCHW. This can result in performance degradation.") - # Set new attrs for the tranposed conv. + ## Set new attrs for the tranposed conv. new_attrs = {k: attrs[k] for k in attrs.keys()} - new_attrs['data_layout'] = 'NCHW' - new_attrs['kernel_layout'] = 'OIHW' + new_attrs['data_layout'] = 'NHWC' + new_attrs['kernel_layout'] = 'HWIO' - # Convert from NHWC to NCHW. - data = relay.transpose(data, axes=(0, 3, 1, 2)) conv = relay.nn.conv2d(data, kernel, **new_attrs) - # Convert back to original NHWC layout. - out = relay.transpose(conv, axes=(0, 2, 3, 1)) - return out + return conv return None From 20cba659b3e5d9f5e3c8883498442707edf720e2 Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 15:20:28 -0700 Subject: [PATCH 03/14] Added bitserial ops to relay. 
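A rough sketch of the intended Python-level usage of the new ops, mirroring
the unit tests added later in this series (shapes, dtypes and bit widths are
illustrative only):

    from tvm import relay

    # Pack 1-bit activations along the channel axis into uint16 lanes.
    x = relay.var("x", shape=(32, 32, 128, 128), dtype="int16")
    packed = relay.nn.bitpack(x, bits=1, pack_axis=1, bit_axis=4,
                              pack_type="uint16")

    # Bit-serial convolution on quantized int16 data and weights (default NCHW layout).
    d = relay.var("d", shape=(1, 32, 224, 224), dtype="int16")
    w = relay.var("w", shape=(32, 32, 3, 3), dtype="int16")
    y = relay.nn.bitserial_conv2d(d, w, kernel_size=(3, 3), padding=(0, 0),
                                  channels=32, activation_bits=1,
                                  weight_bits=1)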
--- python/tvm/relay/op/nn/_nn.py | 97 ++++++++++++ python/tvm/relay/op/nn/nn.py | 146 +++++++++++++++++++ python/tvm/relay/op/op_attrs.py | 15 ++ topi/python/topi/arm_cpu/bitserial_conv2d.py | 41 +++++- topi/python/topi/nn/bitserial_conv2d.py | 21 +++ topi/python/topi/nn/bitserial_util.py | 23 ++- 6 files changed, 341 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 03a04c951d59..216010d088f4 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -600,3 +600,100 @@ def schedule_deformable_conv2d(attrs, outs, target): reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +@reg.register_compute("nn.bitpack") +def compute_bitpack(attrs, inputs, out_dtype, target): + """Compute definition for bitpack""" + bits = attrs.bits + pack_axis = attrs.pack_axis + bit_axis = attrs.bit_axis + pack_type = attrs.pack_type + name = attrs.name + with target: + out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type, + name) + return [out] + +@reg.register_schedule("nn.bitpack") +def schedule_bitpack(attrs, outs, target): + with target: + return topi.generic.schedule_bitpack(outs) + +reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE) + + +@reg.register_compute("nn.bitserial_conv2d") +def compute_bitserial_conv2d(attrs, inputs, out_dtype, target): + """Compute definition for bitserial conv2d.""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + activation_bits = attrs.activation_bits + weight_bits = attrs.weight_bits + layout = attrs.data_layout + pack_dtype = attrs.pack_dtype + out_dtype = attrs.out_dtype + unipolar = attrs.unipolar + if layout == 'NCHW': + with target: + out = topi.nn.bitserial_conv2d_nchw( + inputs[0], inputs[1], strides, padding, activation_bits, + weight_bits, pack_dtype, out_dtype, unipolar) + elif layout == 'NHWC': + with target: + out = topi.nn.bitserial_conv2d_nhwc( + inputs[0], inputs[1], strides, padding, activation_bits, + weight_bits, pack_dtype, out_dtype, unipolar) + else: + raise ValueError("Data layout not supported.") + + return [out] + + +@reg.register_schedule("nn.bitserial_conv2d") +def schedule_bitserial_conv2d(attrs, outs, target): + """Schedule definition for bitserial conv2d.""" + layout = attrs.data_layout + if layout == 'NCHW': + with target: + return topi.generic.schedule_bitserial_conv2d_nchw(outs) + elif layout == 'NHWC': + with target: + return topi.generic.schedule_bitserial_conv2d_nhwc(outs) + else: + raise ValueError("Data layout not supported.") + + +reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# bitserial_dense +@reg.register_compute("nn.bitserial_dense") +def compute_bitserial_dense(attrs, inputs, out_type, target): + """Compute definition of bitserial_dense""" + data_bits = attrs.data_bits + weight_bits = attrs.weight_bits + pack_dtype = attrs.pack_dtype + out_dtype = attrs.out_dtype + out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype + unipolar = attrs.unipolar + return [ + topi.nn.bitserial_dense( + inputs[0], + inputs[1], + data_bits, + weight_bits, + pack_dtype, + out_dtype, + unipolar) + ] + + +@reg.register_schedule("nn.bitserial_dense") +def schedule_bitserial_dense(attrs, outputs, target): + """Schedule definition of bitserial_dense""" + with target: + return topi.generic.schedule_bitserial_dense(outputs) + + +reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/nn/nn.py 
b/python/tvm/relay/op/nn/nn.py index 946ea335e0db..74b2a5373377 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1459,3 +1459,149 @@ def deformable_conv2d(data, return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) + + +def bitpack(data, + bits=1, + pack_axis=1, + bit_axis=2, + pack_type="uint32", + name="BitPack"): + r"""Tensor packing for bitserial operations. + + Parameters + ---------- + data : tvm.relay.expr + The incoming tensor to be packed. + + bits : int + Number of bits that should be packed. + + pack_axis : int + Axis that should be decomposed and packed. + + bit_axis : int + New axis containing bitplane. + + pack_type : str + Datatype to pack bits into. + + name : str, optional + Name of the operation. + + Returns + ------- + result : tvm.relay.Expr + The packed tensor. + """ + return _make.bitpack(data, bits, pack_axis, bit_axis, pack_type, name) + + +def bitserial_conv2d(data, + weight, + strides=(1, 1), + padding=(0, 0), + channels=None, + kernel_size=(3, 3), + activation_bits=1, + weight_bits=1, + data_layout='NCHW', + pack_dtype='uint32', + out_dtype='int16', + unipolar=True): + r"""2D convolution using bitserial computation. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : tuple of int, optional + The strides of convolution. + + padding : tuple of int, optional + The padding of convolution on both sides of inputs before convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + activation_bits : int + Number of bits to pack for activations. + + weight_bits : int + Number of bits to pack for weights. + + data_layout : str, optional + Layout of the input. + + pack_dtype: str, optional + Datatype to pack bits into. + + out_dtype : str, optional + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + + return _make.bitserial_conv2d(data, weight, strides, padding, channels, + kernel_size, activation_bits, weight_bits, + data_layout, pack_dtype, out_dtype, unipolar) + + +def bitserial_dense(data, + weight, + units=None, + data_bits=1, + weight_bits=1, + pack_dtype='uint32', + out_dtype='int16', + unipolar=True): + """Bitserial Dense operator. + Applies a linear transformation + + .. math:: + + `Y = X * W` + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + units : int, optional + Number of hidden units of the dense transformation. + + data_bits : int + Number of bits incoming tensor should be packed with. + + weight_bits : int + Number of bits weight tensor should be packed with. + + pack_dtype : str, optional + Datatype to pack individual bits into before computation. + + out_dtype : str, optional + Specifies the output data type for mixed precision dense. + + unipolar : bool, optional + Whether to use unipolar or bipolar quantization for inputs. + + Returns + ------- + result : tvm.relay.Expr + The computed result. 
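+
+    Examples
+    --------
+    A minimal sketch of the expected call pattern (shapes and bit widths here
+    are illustrative only):
+
+    .. code-block:: python
+
+        x = relay.var("x", shape=(64, 128), dtype="int16")
+        w = relay.var("w", shape=(32, 128), dtype="int16")
+        y = relay.nn.bitserial_dense(x, w, units=32, data_bits=1,
+                                     weight_bits=1)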
+ """ + return _make.bitserial_dense(data, weight, units, data_bits, weight_bits, + pack_dtype, out_dtype, unipolar) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 48d3d2032f80..11f8ad1611cd 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -264,3 +264,18 @@ class MaxPool2DAttrs(Attrs): @register_relay_attr_node class AvgPool2DAttrs(Attrs): """Attributes used in avg_pool2d operators""" + + +@register_relay_attr_node +class BitPackAttrs(Attrs): + """Attributes used in bitpack operator""" + + +@register_relay_attr_node +class BinaryConv2DAttrs(Attrs): + """Attributes used in bitserial conv2d operators""" + + +@register_relay_attr_node +class BinaryDenseAttrs(Attrs): + """Attributes used in bitserial dense operators""" diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index 4198267cac60..c4b2125dd748 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -21,7 +21,7 @@ from tvm import autotvm from .. import tag from ..nn.pad import pad -from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc +from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc, bitserial_conv2d_legalize from ..nn.bitserial_util import bitpack, binary_op_multiplier from ..nn.util import get_pad_tuple from ..util import get_const_int, get_const_tuple @@ -350,3 +350,42 @@ def traverse(op): traverse(outs[0].op) return s + +@bitserial_conv2d_legalize.register("arm_cpu") +def _bitserial_conv2d_legalize(attrs, inputs, arg_types): + """Legalizes Bitserial Conv2D op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + + if attrs['data_layout'] == 'NHWC': + data, kernel = inputs + if attrs['kernel_layout'] == 'HWIO': + # HWIO layout is expected for NHWC input. + return None + elif attrs['kernel_layout'] == 'HWOI': + # Handle HWOI layout. This is common in TF depthwise conv2d graph. + kernel = relay.transpose(kernel, axes=(0, 1, 3, 2)) + elif attrs['kernel_layout'] == 'OIHW': + kernel = relay.transpose(kernel, axes=(2, 3, 1, 0)) + + ## Set new attrs for the tranposed conv. + new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs['data_layout'] = 'NHWC' + new_attrs['kernel_layout'] = 'HWIO' + + conv = relay.nn.bitserial_conv2d(data, kernel, **new_attrs) + return conv + return None diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index 99cac889deea..a7e62c7c1ede 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -414,3 +414,24 @@ def _conv(n, h, w, co, vh, vw, vc): return tvm.compute(oshape, lambda n, h, w, co: conv[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC], name='output_unpack', tag='spatial_bitserial_conv_nhwc') + +@tvm.target.generic_func +def bitserial_conv2d_legalize(attrs, inputs, types): + """Legalizes Bitserial Conv2D op. 
+ + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # not to change by default + return None diff --git a/topi/python/topi/nn/bitserial_util.py b/topi/python/topi/nn/bitserial_util.py index 09a301f7c962..1fd10d35ed6a 100644 --- a/topi/python/topi/nn/bitserial_util.py +++ b/topi/python/topi/nn/bitserial_util.py @@ -88,4 +88,25 @@ def binary_op_multiplier(pack_dtype): pack_dtype: string pack type for the operator (must be a uint)""" return int(pack_dtype[4:]) - \ No newline at end of file + + +@tvm.target.generic_func +def bitpack_legalize(attrs, inputs, types): + """Legalizes Bitpack op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # not to change by default + return None From 27497942f93abf3aac3ae379b4e068768532570c Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 19:33:53 -0700 Subject: [PATCH 04/14] Snapshot and more missing files. --- include/tvm/relay/attrs/bitserial.h | 113 +++++++++ python/tvm/relay/op/nn/_nn.py | 20 ++ python/tvm/relay/op/nn/nn.py | 7 +- src/relay/op/nn/bitserial.cc | 231 +++++++++++++++++++ topi/python/topi/arm_cpu/bitserial_conv2d.py | 29 ++- topi/python/topi/generic/nn.py | 17 ++ topi/python/topi/nn/bitserial_conv2d.py | 5 +- topi/python/topi/nn/bitserial_util.py | 22 -- 8 files changed, 405 insertions(+), 39 deletions(-) create mode 100644 include/tvm/relay/attrs/bitserial.h create mode 100644 src/relay/op/nn/bitserial.cc diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h new file mode 100644 index 000000000000..36d429d987f4 --- /dev/null +++ b/include/tvm/relay/attrs/bitserial.h @@ -0,0 +1,113 @@ +#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_ +#define TVM_RELAY_ATTRS_BITSERIAL_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Attributes used in bitpack operators */ +struct BitPackAttrs : public tvm::AttrsNode { + int bits; + int pack_axis; + int bit_axis; + DataType pack_type; + std::string name; + + TVM_DECLARE_ATTRS(BitPackAttrs, "relay.attrs.BitPackAttrs") { + TVM_ATTR_FIELD(bits).set_default(1).describe("Number of bits to quantize with."); + TVM_ATTR_FIELD(pack_axis).set_default(1).describe( + "Axis that should be compressed, typically channels."); + TVM_ATTR_FIELD(bit_axis).set_default(-1).describe("New axis for packed bits."); + TVM_ATTR_FIELD(pack_type) + .set_default(NullValue()) + .describe("Type of int to pack bits into."); + TVM_ATTR_FIELD(name).set_default("BitPack").describe("Name of operation."); + } +}; + +/*! 
\brief Attribues used in bitserial convolution operators */ +struct BinaryConv2DAttrs : public tvm::AttrsNode { + Array strides; + Array padding; + IndexExpr channels; + Array kernel_size; + int activation_bits; + int weight_bits; + std::string data_layout; + std::string kernel_layout; + DataType pack_dtype; + DataType out_dtype; + bool unipolar; + + TVM_DECLARE_ATTRS(BinaryConv2DAttrs, "relay.attrs.BinaryConv2DAttrs") { + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero the input is implicitly zero-padded" + "on both sides for padding number of points."); + TVM_ATTR_FIELD(kernel_size) + .set_default(Array({3, 3})) + .describe("Specifies the dimensions of the convolution window."); + TVM_ATTR_FIELD(channels) + .set_default(NullValue()) + .describe("Number of output channels, needed for shape inference."); + TVM_ATTR_FIELD(activation_bits) + .set_default(1) + .describe("Number of bits activation should be packed with."); + TVM_ATTR_FIELD(weight_bits) + .set_default(1) + .describe("Number of bits kernel should be packed with."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe("Dimension ordering of input data, can be 'NCHW' or NHWC'."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe("Dimension ordering of kernel data, can be 'OIHW' or HWIO'."); + TVM_ATTR_FIELD(pack_dtype) + .set_default(NullValue()) + .describe("Datatype to pack bits into."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output datatype."); + TVM_ATTR_FIELD(unipolar).set_default(true).describe( + "Whether to use unipolar or bipolar quantization."); + } +}; + +/*~ \brief Attributes for bitserial dense operator */ +struct BinaryDenseAttrs : public tvm::AttrsNode { + IndexExpr units; + int data_bits; + int weight_bits; + DataType pack_dtype; + DataType out_dtype; + bool unipolar; + + TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") { + TVM_ATTR_FIELD(units) + .describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(data_bits) + .set_default(1) + .describe("Number of bits to pack for incoming tensor."); + TVM_ATTR_FIELD(weight_bits) + .set_default(1) + .describe("Number of bits to pack for weight tensor."); + TVM_ATTR_FIELD(pack_dtype) + .set_default(NullValue()) + .describe("Datatype to pack bits into before computation."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type."); + TVM_ATTR_FIELD(unipolar) + .set_default(true) + .describe("Whether to use unipolar or bipolar quantization for inputs."); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_BITSERIAL_H_ diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 216010d088f4..d652977924ca 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -663,6 +663,26 @@ def schedule_bitserial_conv2d(attrs, outs, target): else: raise ValueError("Data layout not supported.") +@reg.register_legalize("nn.bitserial_conv2d") +def legalize_bitserial_conv2d(attrs, inputs, types): + """Legalize bitserial_conv2d op. 
+ + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types) + reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 74b2a5373377..bfdf1fd3dfe4 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1506,6 +1506,7 @@ def bitserial_conv2d(data, activation_bits=1, weight_bits=1, data_layout='NCHW', + kernel_layout='OIHW', pack_dtype='uint32', out_dtype='int16', unipolar=True): @@ -1540,6 +1541,9 @@ def bitserial_conv2d(data, data_layout : str, optional Layout of the input. + kernel_layout : str, optional + Layout of the kernel + pack_dtype: str, optional Datatype to pack bits into. @@ -1554,7 +1558,8 @@ def bitserial_conv2d(data, return _make.bitserial_conv2d(data, weight, strides, padding, channels, kernel_size, activation_bits, weight_bits, - data_layout, pack_dtype, out_dtype, unipolar) + data_layout, kernel_layout, pack_dtype, + out_dtype, unipolar) def bitserial_dense(data, diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc new file mode 100644 index 000000000000..af96f01b18c1 --- /dev/null +++ b/src/relay/op/nn/bitserial.cc @@ -0,0 +1,231 @@ +#include +#include +#include + +#include "../../pass/alter_op_layout.h" + +namespace tvm { +namespace relay { + +// relay.nn.bitpack +TVM_REGISTER_NODE_TYPE(BitPackAttrs); + +template +Array> BinaryConv2DInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array>& old_in_shapes) { + const T* params = attrs.as(); + + // We always make other operators to fit the layouts of convolution layers + // So this inference ignores all inputs + return Array>{{params->data_layout, params->kernel_layout}, {params->data_layout}}; +} + +bool BitPackRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const BitPackAttrs* param = attrs.as(); + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + CHECK(data); + int ndim = data->shape.size(); + int bits = param->bits; + int pack_axis = param->pack_axis; + int bit_axis = param->bit_axis; + DataType pack_type = param->pack_type; + + int pack_bits = pack_type.bits(); + + Array out_shape; + for (int i = 0; i < ndim; ++i) { + if (i == bit_axis) { + out_shape.push_back(bits); + if (i == pack_axis) { + out_shape.push_back(data->shape[i] / pack_bits); + } else { + out_shape.push_back(data->shape[i]); + } + } else if (i == pack_axis) { + out_shape.push_back(data->shape[i] / pack_bits); + } else { + out_shape.push_back(data->shape[i]); + } + } + // Add extra check for last axis expansion. 
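+  // (The loop above only visits axes 0..ndim-1, so a bit axis appended past
+  // the last original axis must be pushed back here explicitly.)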
+ if (bit_axis == ndim) { + out_shape.push_back(bits); + } + + reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type)); + return true; +} + +Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type, + std::string name) { + auto attrs = make_node(); + attrs->bits = bits; + attrs->pack_axis = pack_axis; + attrs->bit_axis = bit_axis; + attrs->pack_type = pack_type; + attrs->name = name; + static const Op& op = Op::Get("nn.bitpack"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack); + +RELAY_REGISTER_OP("nn.bitpack") + .describe(R"code(Bitpack layer that prepares data for bitserial operations. + +This layer backs the bits of an input into a single datatype, allowing +efficient implementation of bitserial operations. + +- **data**: Input tensor of any shape, dimension that is to be + packed must be divisible by number of bits. +- **out**: Packed tensor with shape appropriately compressed. +)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .set_attrs_type_key("relay.attrs.BitPackAttrs") + .add_argument("data", "Tensor", "Input data.") + .set_support_level(2) + .add_type_rel("BitPack", BitPackRel); + +// relay.nn.bitserial_conv2d +TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs); + +bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const BinaryConv2DAttrs* param = attrs.as(); + CHECK(param != nullptr); + + static const Layout kNCHW("NCHW"); + + const Layout in_layout(param->data_layout); + const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + CHECK(param->channels.defined()); + CHECK(param->kernel_size.defined()); + Array oshape({dshape_nchw[0], param->channels, 0, 0}); + oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1); + oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1); + DataType out_dtype = param->out_dtype; + oshape = trans_in_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; + +} + +// Positional relay function to create binaryconv2d operator +// used by frontend FFI. 
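+// (Registered below as relay.op.nn._make.bitserial_conv2d and wrapped by the
+// Python helper relay.nn.bitserial_conv2d.)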
+Expr MakeBinaryConv2D(Expr data, Expr weight, Array strides, Array padding, + IndexExpr channels, Array kernel_size, int activation_bits, + int weight_bits, std::string data_layout, std::string kernel_layout, + DataType pack_dtype, DataType out_dtype, bool unipolar) { + auto attrs = make_node(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->activation_bits = activation_bits; + attrs->weight_bits = weight_bits; + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->pack_dtype = std::move(pack_dtype); + attrs->out_dtype = std::move(out_dtype); + attrs->unipolar = unipolar; + static const Op& op = Op::Get("nn.bitserial_conv2d"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D); + +RELAY_REGISTER_OP("nn.bitserial_conv2d") + .describe(R"code(2D convolution using packed binary computation. + +This layer creates a convolution kernel that is convolved with the +layer input using bitserial computation. This enables faster processing +on some platforms. + +- **data**: 4D input tensor that can be either `NCHW` or `NHWC` layout. + +- **weight**: Weight tensor that can either be prepacked (5D) or unpacked (4D). + When data is NCHW, weight is expected to be OIHW or OIHWi. + When data is NHWC weight is expected to be HWIO or HWIOi. + +- **out**: Output with same layout as input. +)code" TVM_ADD_FILELINE) + .set_attrs_type_key("relay.attrs.BinaryConv2DAttrs") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("BinaryConv2D", BinaryConv2DRel) + .set_attr("FInferCorrectLayout", + BinaryConv2DInferCorrectLayout); + +// relay.nn.bitserial_dense +TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs); + +bool BinaryDenseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + const BinaryDenseAttrs* param = attrs.as(); + CHECK(param != nullptr); + + CHECK(static_cast(data->shape.size()) != 0); + CHECK(param->units.defined()); + + Array oshape = data->shape; + oshape.Set((oshape.size() - 1), param->units); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + // Assign output type. + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + +// Positional relay function to create bitserial dense operator used by frontend FFI. +Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int weight_bits, + DataType pack_dtype, DataType out_dtype, bool unipolar) { + auto attrs = make_node(); + attrs->units = units; + attrs->data_bits = data_bits; + attrs->weight_bits = weight_bits; + attrs->pack_dtype = pack_dtype; + attrs->out_dtype = out_dtype; + attrs->unipolar = unipolar; + static const Op& op = Op::Get("nn.bitserial_dense"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense); + +RELAY_REGISTER_OP("nn.bitserial_dense") + .describe(R"code(Applies a quantized linear transformation: :math:`Y = XW^T`. 
+ +- **data**: `(x1, x2, ..., xn, input_dim)` +- **weight**: `(units, input_dim)` +- **out**: `(x1, x2, ..., xn, units)`. + +)code" TVM_ADD_FILELINE) + .set_attrs_type_key("relay.attrs.BinaryDenseAttrs") + .set_num_inputs(2) + .add_argument("data", "2D Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("BinaryDense", BinaryDenseRel); + +} // namespace relay +} // namespace tvm diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index c4b2125dd748..b34301f71016 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -19,6 +19,7 @@ from __future__ import absolute_import as _abs import tvm from tvm import autotvm +from tvm import relay from .. import tag from ..nn.pad import pad from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc, bitserial_conv2d_legalize @@ -370,22 +371,20 @@ def _bitserial_conv2d_legalize(attrs, inputs, arg_types): The legalized expr """ + # Fix different kernel layouts where possible. if attrs['data_layout'] == 'NHWC': data, kernel = inputs - if attrs['kernel_layout'] == 'HWIO': + if len(kernel.data.shape) == 4: # HWIO layout is expected for NHWC input. - return None - elif attrs['kernel_layout'] == 'HWOI': - # Handle HWOI layout. This is common in TF depthwise conv2d graph. - kernel = relay.transpose(kernel, axes=(0, 1, 3, 2)) - elif attrs['kernel_layout'] == 'OIHW': - kernel = relay.transpose(kernel, axes=(2, 3, 1, 0)) - - ## Set new attrs for the tranposed conv. - new_attrs = {k: attrs[k] for k in attrs.keys()} - new_attrs['data_layout'] = 'NHWC' - new_attrs['kernel_layout'] = 'HWIO' - - conv = relay.nn.bitserial_conv2d(data, kernel, **new_attrs) - return conv + if attrs['kernel_layout'] == 'HWOI': + # Handle HWOI layout. This is common in TF depthwise conv2d graph. + kernel = relay.transpose(kernel, axes=(0, 1, 3, 2)) + elif attrs['kernel_layout'] == 'OIHW': + kernel = relay.transpose(kernel, axes=(2, 3, 1, 0)) + ## Set new attrs for the tranposed conv. + new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs['kernel_layout'] = 'HWIO' + + conv = relay.nn.bitserial_conv2d(data, kernel, **new_attrs) + return conv return None diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 38b66320b428..8fbedec3fef1 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -470,6 +470,23 @@ def schedule_binarize_pack(outs): return _default_schedule(outs, False) +@tvm.target.override_native_generic_func("schedule_bitpack") +def schedule_bitpack(outs): + """Schedule for bitpack + Parameters + ---------- + outs: Array of Tensor + The computation graph description of bitpack + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) + + @tvm.target.override_native_generic_func("schedule_binary_dense") def schedule_binary_dense(outs): """Schedule for binary_dense diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index a7e62c7c1ede..007e7c14ee24 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -65,7 +65,10 @@ def bitserial_conv2d_nchw(data, kernel, stride, padding, activation_bits, weight """ assert isinstance(stride, int) or len(stride) == 2 Input_q = bitpack(data, activation_bits, pack_axis=1, bit_axis=2, pack_type=pack_dtype) - Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) + if len(filter.shape) == 4: + Filter_q = bitpack(filter, weight_bits, pack_axis=1, bit_axis=4, pack_type=pack_dtype) + else: + Filter_q = filter batch, in_channel, activation_bits, in_height, in_width = Input_q.shape num_filter, _, kernel_h, kernel_w, weight_bits = Filter_q.shape diff --git a/topi/python/topi/nn/bitserial_util.py b/topi/python/topi/nn/bitserial_util.py index 1fd10d35ed6a..35a03a5b1e64 100644 --- a/topi/python/topi/nn/bitserial_util.py +++ b/topi/python/topi/nn/bitserial_util.py @@ -88,25 +88,3 @@ def binary_op_multiplier(pack_dtype): pack_dtype: string pack type for the operator (must be a uint)""" return int(pack_dtype[4:]) - - -@tvm.target.generic_func -def bitpack_legalize(attrs, inputs, types): - """Legalizes Bitpack op. - - Parameters - ---------- - attrs : tvm.attrs.Attrs - Attributes of current convolution - inputs : list of tvm.relay.Expr - The args of the Relay expr to be legalized - types : list of types - List of input and output types - - Returns - ------- - result : tvm.relay.Expr - The legalized expr - """ - # not to change by default - return None From a7e11d89016a4f3e26be027f6434a6cf9a98543c Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 20:55:36 -0700 Subject: [PATCH 05/14] Added dense testing. --- tests/python/relay/test_op_level1.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 66e65c5fd409..1b7225294753 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -337,6 +337,34 @@ def test_dense(): tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +def test_bitserial_dense(): + m, k = tvm.var("m"), tvm.var("k") + x = relay.var("x", relay.TensorType((m, k), "int16")) + w = relay.var("w", relay.TensorType((k, 32), "int16")) + y = relay.nn.bitserial_dense(x, w, units=32) + "units=8" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((m, 32), "int16") + + x = relay.var("x", shape=(32, 32), dtype="int16") + w = relay.var("w", shape=(32, 32), dtype="int16") + z = relay.nn.dense(x, w) + + # Check result. 
+ func = relay.Function([x, w], z) + x_data = np.random.randint(0, 10, size=(32, 32)).astype("int16") + w_data = np.random.randint(0, 10, size=(32, 32)).astype("int16") + ref_res = np.dot(x_data, w_data.T) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(x_data, w_data) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) + op_res2 = intrp2.evaluate(func)(x_data, w_data) + tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) + + if __name__ == "__main__": test_concatenate() test_bias_add() @@ -349,3 +377,4 @@ def test_dense(): test_dropout() test_batch_norm() test_dense() + test_bitserial_dense() From e24233ea7ae0fbe768b49da18a1997a09ed91095 Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 21:00:11 -0700 Subject: [PATCH 06/14] Added tests --- tests/python/relay/test_op_level1.py | 18 ------------------ tests/python/relay/test_op_level2.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 1b7225294753..c25393cf4026 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -346,24 +346,6 @@ def test_bitserial_dense(): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((m, 32), "int16") - x = relay.var("x", shape=(32, 32), dtype="int16") - w = relay.var("w", shape=(32, 32), dtype="int16") - z = relay.nn.dense(x, w) - - # Check result. - func = relay.Function([x, w], z) - x_data = np.random.randint(0, 10, size=(32, 32)).astype("int16") - w_data = np.random.randint(0, 10, size=(32, 32)).astype("int16") - ref_res = np.dot(x_data, w_data.T) - - for target, ctx in ctx_list(): - intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - intrp2 = relay.create_executor("debug", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(x_data, w_data) - tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, w_data) - tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) - if __name__ == "__main__": test_concatenate() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 5e9abdf0faf4..a94a203f4d79 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -105,8 +105,8 @@ def run_test_conv2d(dtype, out_dtype, scale, dshape, kshape, except_targets=None, **attrs): if except_targets is None: - except_targets = [] - + except_targets = [] + x = relay.var("x", shape=dshape, dtype=dtype) w = relay.var("w", dtype=dtype) y = relay.nn.conv2d(x, w, @@ -599,12 +599,35 @@ def _compile(input_dtype, weight_dtype, output_dtype, target): assert "vpmulld" in asm and "vpadd" in asm +def test_bitserial_conv2d_infer_type(): + # Basic shape test with ambiguous batch. + n, c, h, w = tvm.var("n"), 32, 224, 224 + x = relay.var("x", relay.ty.TensorType((n, c, h, w), "int16")) + w = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16")) + y = relay.nn.bitserial_conv2d( + x, w, kernel_size=(3, 3), padding=(0, 0), channels=32) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 32, 222, 222), "int16") + + +def test_bitpack_infer_type(): + # Test axis packing shape inference. 
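+    # With bits=1 and pack_type='uint16', pack_axis=1 shrinks from 32 to
+    # 32 // 16 = 2, and bit_axis=4 appends a new axis of length 1 (one bitplane).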
+ o, i, h, w = 32, 32, 128, 128 + x = relay.var("x", relay.ty.TensorType((o, i, h, w), "int16")) + y = relay.nn.bitpack(x, bit_axis=4, pack_axis=1, pack_type='uint16', bits=1) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (32, 2, 128, 128, 1), "uint16") + + if __name__ == "__main__": test_pool2d() test_avg_pool2d_no_count_pad() test_lrn() test_l2_normalize() test_conv2d_infer_type() + test_bitpack_infer_type() test_upsampling_infer_type() test_flatten_infer_type() test_pad_infer_type() @@ -612,6 +635,7 @@ def _compile(input_dtype, weight_dtype, output_dtype, target): test_conv2d_transpose_infer_type() test_conv2d_transpose_run() test_conv2d_run() + test_bitserial_conv2d_infer_type() test_batch_flatten() test_upsampling() test_conv2d_int8_intrinsics() From 7ce8d7b981b8a9035d3eedbee73fe9bfe0e445f6 Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 21:07:25 -0700 Subject: [PATCH 07/14] Added ASF header to new files. --- include/tvm/relay/attrs/bitserial.h | 24 ++++++++++++++++++++++++ src/relay/op/nn/bitserial.cc | 25 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h index 36d429d987f4..2a7376b72e64 100644 --- a/include/tvm/relay/attrs/bitserial.h +++ b/include/tvm/relay/attrs/bitserial.h @@ -1,3 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/bitserial.h + * \brief Auxiliary attributes for bitserial operators. + */ + #ifndef TVM_RELAY_ATTRS_BITSERIAL_H_ #define TVM_RELAY_ATTRS_BITSERIAL_H_ diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index af96f01b18c1..c9cdb8fb898d 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -1,3 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file bitserial.cc + * \brief Property def of bitserial operators. 
+ */ + #include #include #include From 252f4d5c3a6c4c74d17fbe0a9938311837c6b046 Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 21:10:14 -0700 Subject: [PATCH 08/14] cc lint --- src/relay/op/nn/bitserial.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index c9cdb8fb898d..6ee1ee675c06 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -131,18 +131,19 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr const Layout in_layout(param->data_layout); const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); - Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); CHECK(param->channels.defined()); CHECK(param->kernel_size.defined()); Array oshape({dshape_nchw[0], param->channels, 0, 0}); - oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1); - oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1); + oshape.Set( + 2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1); + oshape.Set( + 3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1); DataType out_dtype = param->out_dtype; oshape = trans_in_layout.BackwardShape(oshape); // assign output type reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); return true; - } // Positional relay function to create binaryconv2d operator From f48917af58bda9bdc493c1cafd4b5e2741ebb6d0 Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 21:12:21 -0700 Subject: [PATCH 09/14] Pylint change. --- topi/python/topi/nn/bitserial_conv2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index 007e7c14ee24..21abdf0de1ec 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, too-many-locals, too-many-arguments +# pylint: disable=unused-argument, redefined-builtin """Bitserial Conv2D operators""" from __future__ import absolute_import as _abs import tvm From 444191bc6513798f83aafb05959edfd94dff6cfe Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 21:17:05 -0700 Subject: [PATCH 10/14] pylint fixes. --- topi/python/topi/arm_cpu/bitserial_conv2d.py | 2 +- topi/python/topi/arm_cpu/conv2d.py | 30 +++++++++++--------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index b34301f71016..af9c5bebb998 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=invalid-name,unused-variable,invalid-name +# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument """Bitserial conv2d schedule on arm cpu""" from __future__ import absolute_import as _abs import tvm diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 56d8e7755995..9fb157c3129b 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -122,8 +122,8 @@ def _callback(op): kernel = kernel_vec if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - - # TODO: move to schedule_nhwc later + + # TODO: move to schedule_nhwc later if 'nhwc' in op.tag: _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) else: @@ -251,12 +251,14 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou name='output_unpack', tag='spatial_conv2d_output') return output -def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, num_tile): + +def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, + layout, out_dtype, num_tile): assert layout == "NHWC", "Only support NHWC" # create workload according to raw arguments out_dtype = out_dtype or data.dtype N, IH, IW, CI = get_const_tuple(data.shape) - + # TODO? Dilation if len(kernel.shape) == 4: @@ -267,7 +269,7 @@ def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, layou CO, _, KH, KW, VC = get_const_tuple(kernel.shape) CO = CO * VC - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple( padding, (KH, KW)) + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW)) HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) OH = (IH + pad_top + pad_bottom - KH) // HSTR + 1 OW = (IW + pad_left + pad_right - KW) // WSTR + 1 @@ -313,13 +315,15 @@ def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, layou # undilate input data dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1) - data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: - data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], - name='data_vec') + data_vec = tvm.compute( + dvshape, + lambda n, h, w, ci, vh, vw: data_pad[n][ci][h * VH * HSTR + vh][w * VW * WSTR + vw], + name='data_vec') - kernel_vec = tvm.compute(kvshape, lambda co, ci, kh, kw, vc: - kernel[co*VC+vc][ci][kh][kw], - name='kernel_vec') + kernel_vec = tvm.compute( + kvshape, + lambda co, ci, kh, kw, vc: kernel[co * VC + vc][ci][kh][kw], + name='kernel_vec') ci = tvm.reduce_axis((0, CI), name='ci') kh = tvm.reduce_axis((0, KH), name='kh') @@ -395,8 +399,8 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, return s -def _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec, - conv, output, last): +def _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec, conv, output, + last): """schedule implementation""" n, oh, ow, co, vh, vw, vc = s[conv].op.axis ci, kh, kw = s[conv].op.reduce_axis From 99026dc525691450841ec43b86097be3da2b6b33 Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 21:45:58 -0700 Subject: [PATCH 11/14] Change arm legalize test. 
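With the NHWC conv2d schedule in place, the arm_cpu legalize pass no longer
falls back to NCHW, so the test is reworked: the input graph now carries an
OIHW kernel while the reference graph keeps the HWIO layout. Roughly, the
pass is expected to rewrite (a sketch, not the literal pass output):

    # before: NHWC data with an OIHW kernel (as in the updated test below)
    data = relay.var("data", shape=(1, 224, 224, 16), dtype="float32")
    kernel = relay.var("kernel", shape=(32, 16, 3, 3), dtype="float32")
    y = relay.nn.conv2d(data, kernel, kernel_size=(3, 3), channels=32,
                        padding=(1, 1), data_layout='NHWC',
                        kernel_layout='OIHW')

into a single kernel transpose feeding an NHWC/HWIO convolution:

    kernel_hwio = relay.transpose(kernel, axes=(2, 3, 1, 0))
    y = relay.nn.conv2d(data, kernel_hwio, kernel_size=(3, 3), channels=32,
                        padding=(1, 1), data_layout='NHWC',
                        kernel_layout='HWIO')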
--- tests/python/relay/test_pass_legalize.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index c5303ef3c4e9..e7e3bbc5d59d 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -183,7 +183,7 @@ def get_output(func, data_val, parameters): out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy() return out - def before(): + def ref(): n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3 data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32')) kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32')) @@ -198,6 +198,21 @@ def before(): func = relay.Function([data, kernel], y) return func + def before(): + n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3 + data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32')) + kernel = relay.var("kernel", relay.TensorType((oc, ic, kh, kw), 'float32')) + y = relay.nn.conv2d(data, kernel, + kernel_size=(kh, kw), + channels=oc, + padding=(1, 1), + dilation=(1, 1), + data_layout='NHWC', + kernel_layout='OIHW', + out_dtype='float32') + func = relay.Function([data, kernel], y) + return func + @register_legalize("nn.conv2d", level=105) def legalize_conv2d(attrs, inputs, types): from topi.arm_cpu.conv2d import _conv2d_legalize @@ -207,10 +222,12 @@ def legalize_conv2d(attrs, inputs, types): b = run_opt_pass(a, transform.Legalize()) assert b.astext().count('transpose') == 3 - wdata = np.random.rand(3, 3, 16, 32) * 10 + wdata = np.random.rand(32, 16, 3, 3) * 10 + wref = wdata.transpose([2, 3, 1, 0]) parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))} + ref_parameters = {"kernel": tvm.nd.array(wref.astype('float32'))} data_val = np.random.rand(1, 224, 224, 16).astype('float32') - ref_out = get_output(a, data_val, parameters) + ref_out = get_output(ref(), data_val, ref_parameters) legalized_out = get_output(b, data_val, parameters) np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01) From 341e2aa0563ad833315f2b6af15d92447b9d94c3 Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Tue, 27 Aug 2019 21:49:04 -0700 Subject: [PATCH 12/14] Added assert check to arm legalize. 
--- tests/python/relay/test_pass_legalize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index e7e3bbc5d59d..fc43b121cff1 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -220,7 +220,7 @@ def legalize_conv2d(attrs, inputs, types): a = before() b = run_opt_pass(a, transform.Legalize()) - assert b.astext().count('transpose') == 3 + assert b.astext().count('transpose') == 1 wdata = np.random.rand(32, 16, 3, 3) * 10 wref = wdata.transpose([2, 3, 1, 0]) From a9d53db09c6458d3a2d44dc11b9b195a659b681a Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Fri, 30 Aug 2019 11:32:48 -0700 Subject: [PATCH 13/14] Added better documentation, fixed some bad style --- python/tvm/contrib/graph_runtime.py | 2 +- python/tvm/relay/op/nn/nn.py | 13 ++++++++++++- topi/python/topi/arm_cpu/conv2d.py | 11 ++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 0c9ce404c48e..ddb8de048f9b 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -129,7 +129,7 @@ def __init__(self, module): self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] self._load_params = module["load_params"] - self._share_params = module["share_params"] + #self._share_params = module["share_params"] def set_input(self, key=None, value=None, **params): """Set inputs to the module via kwargs diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index bfdf1fd3dfe4..19c50d6dc700 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1468,6 +1468,16 @@ def bitpack(data, pack_type="uint32", name="BitPack"): r"""Tensor packing for bitserial operations. + The values along the input tensor's pack_axis are quantized + and packed together into the specified pack_type in a new + bit axis. + + For example, consider bitpacking with data to be a tensor with shape [1, 64, 128, 128], + pack_axis=1, bit_axis=4, pack_type=uint8, and bits=2. The output in this case will + be of shape [1, 8, 128, 128, 2]. The dimension of axis 1 has been reduced by a factor + of 8 since each value is packed into an 8-bit uint8. Axis 4 is now two bitplanes + representing the quantized value of the incoming data. The output tensor is now + ready to be used in a bitserial operation. Parameters ---------- @@ -1571,7 +1581,8 @@ def bitserial_dense(data, out_dtype='int16', unipolar=True): """Bitserial Dense operator. - Applies a linear transformation + Applies matrix multiplication of two quantized matrices + using a fast bitserial algorithm. .. math:: diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 9fb157c3129b..f15bbd60e798 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -259,7 +259,8 @@ def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype = out_dtype or data.dtype N, IH, IW, CI = get_const_tuple(data.shape) - # TODO? 
Dilation + # TODO dilation not currently supported + assert dilation == 1 or tuple(dilation) == (1, 1), "Does not support dilation" if len(kernel.shape) == 4: pre_packed = False @@ -300,10 +301,10 @@ def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') # fallback support - # if cfg.is_fallback: - # if num_tile == 2: # arm cpu - # ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct') - # cfg.fallback_with_reference_log(ref_log) + if cfg.is_fallback: + if num_tile == 2: # arm cpu + ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct') + cfg.fallback_with_reference_log(ref_log) VC = cfg["tile_co"].size[-1] VH = cfg["tile_oh"].size[-1] From 491e4f28b7a965eda9a05a9c2dacd58be1d4b49c Mon Sep 17 00:00:00 2001 From: Joshua Fromm Date: Fri, 30 Aug 2019 12:07:06 -0700 Subject: [PATCH 14/14] Reverted arm conv2d nhwc changes. --- python/tvm/contrib/graph_runtime.py | 2 +- tests/python/relay/test_pass_legalize.py | 25 +--- topi/python/topi/arm_cpu/conv2d.py | 174 +++-------------------- topi/python/topi/nn/bitserial_util.py | 1 + 4 files changed, 24 insertions(+), 178 deletions(-) diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index ddb8de048f9b..0c9ce404c48e 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -129,7 +129,7 @@ def __init__(self, module): self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] self._load_params = module["load_params"] - #self._share_params = module["share_params"] + self._share_params = module["share_params"] def set_input(self, key=None, value=None, **params): """Set inputs to the module via kwargs diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index fc43b121cff1..c5303ef3c4e9 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -183,7 +183,7 @@ def get_output(func, data_val, parameters): out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy() return out - def ref(): + def before(): n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3 data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32')) kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32')) @@ -198,21 +198,6 @@ def ref(): func = relay.Function([data, kernel], y) return func - def before(): - n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3 - data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32')) - kernel = relay.var("kernel", relay.TensorType((oc, ic, kh, kw), 'float32')) - y = relay.nn.conv2d(data, kernel, - kernel_size=(kh, kw), - channels=oc, - padding=(1, 1), - dilation=(1, 1), - data_layout='NHWC', - kernel_layout='OIHW', - out_dtype='float32') - func = relay.Function([data, kernel], y) - return func - @register_legalize("nn.conv2d", level=105) def legalize_conv2d(attrs, inputs, types): from topi.arm_cpu.conv2d import _conv2d_legalize @@ -220,14 +205,12 @@ def legalize_conv2d(attrs, inputs, types): a = before() b = run_opt_pass(a, transform.Legalize()) - assert b.astext().count('transpose') == 1 + assert b.astext().count('transpose') == 3 - wdata = np.random.rand(32, 16, 3, 3) * 10 - wref = wdata.transpose([2, 3, 1, 0]) + wdata = np.random.rand(3, 3, 16, 32) * 10 parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))} - ref_parameters = {"kernel": 
tvm.nd.array(wref.astype('float32'))} data_val = np.random.rand(1, 224, 224, 16).astype('float32') - ref_out = get_output(ref(), data_val, ref_parameters) + ref_out = get_output(a, data_val, parameters) legalized_out = get_output(b, data_val, parameters) np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index f15bbd60e798..77b37ed5a1e2 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -75,12 +75,8 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - if layout == "NCHW": - return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, - num_tile=2) - else: - return _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, layout, - out_dtype, num_tile=2) + return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, + num_tile=2) @autotvm.register_topi_schedule( @@ -123,11 +119,7 @@ def _callback(op): if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - # TODO: move to schedule_nhwc later - if 'nhwc' in op.tag: - _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) - else: - _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) + _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0]) if 'winograd_conv2d_output' in op.tag: output = op.output(0) @@ -251,95 +243,6 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, layout, ou name='output_unpack', tag='spatial_conv2d_output') return output - -def _decl_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, - layout, out_dtype, num_tile): - assert layout == "NHWC", "Only support NHWC" - # create workload according to raw arguments - out_dtype = out_dtype or data.dtype - N, IH, IW, CI = get_const_tuple(data.shape) - - # TODO dilation not currently supported - assert dilation == 1 or tuple(dilation) == (1, 1), "Does not support dilation" - - if len(kernel.shape) == 4: - pre_packed = False - KH, KW, _, CO = get_const_tuple(kernel.shape) - else: - pre_packed = True - CO, _, KH, KW, VC = get_const_tuple(kernel.shape) - CO = CO * VC - - pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW)) - HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) - OH = (IH + pad_top + pad_bottom - KH) // HSTR + 1 - OW = (IW + pad_left + pad_right - KW) // WSTR + 1 - data_pad = pad(data, [0, pad_top, pad_left, 0], [0, pad_bottom, pad_right, 0]) - - # ==================== define configuration space ==================== - n, co, oh, ow = cfg.axis(N), cfg.axis(CO), cfg.axis(OH), cfg.axis(OW) - ci, kh, kw = cfg.reduce_axis(CI), cfg.reduce_axis(KH), cfg.reduce_axis(KW) - - if num_tile == 2: # for arm cpu - co, vc = cfg.define_split('tile_co', co, num_outputs=2) - oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2) - ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2) - elif num_tile == 3: # for mali gpu - co, _, vc = cfg.define_split('tile_co', co, num_outputs=3) - oh, _, vh = cfg.define_split('tile_oh', oh, num_outputs=3) - ow, _, vw = cfg.define_split('tile_ow', ow, num_outputs=3) - else: - raise RuntimeError("Invalid num_tile") - - cfg.define_reorder("reorder_0", - [n, oh, ow, co, ci, kh, kw, vh, vc, vw], - policy='candidate', candidate=[ 
- [n, oh, ow, co, ci, kh, kw, vh, vc, vw], - [n, oh, ow, co, ci, kh, kw, vc, vh, vw]]) - - cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll') - cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') - - # fallback support - if cfg.is_fallback: - if num_tile == 2: # arm cpu - ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct') - cfg.fallback_with_reference_log(ref_log) - - VC = cfg["tile_co"].size[-1] - VH = cfg["tile_oh"].size[-1] - VW = cfg["tile_ow"].size[-1] - - kvshape = (CO // VC, CI, KH, KW, VC) - ovshape = (N, OH // VH, OW // VW, CO // VC, VH, VW, VC) - oshape = (N, OH, OW, CO) - - # undilate input data - dvshape = (N, OH // VH, OW // VW, CI, VH*HSTR + KH-1, VW*WSTR + KW-1) - data_vec = tvm.compute( - dvshape, - lambda n, h, w, ci, vh, vw: data_pad[n][ci][h * VH * HSTR + vh][w * VW * WSTR + vw], - name='data_vec') - - kernel_vec = tvm.compute( - kvshape, - lambda co, ci, kh, kw, vc: kernel[co * VC + vc][ci][kh][kw], - name='kernel_vec') - - ci = tvm.reduce_axis((0, CI), name='ci') - kh = tvm.reduce_axis((0, KH), name='kh') - kw = tvm.reduce_axis((0, KW), name='kw') - - conv = tvm.compute(ovshape, lambda n, h, w, co, vh, vw, vc: \ - tvm.sum(data_vec[n, h, w, ci, vh*HSTR+kh, vw*WSTR+kw].astype(out_dtype) * - kernel_vec[co, ci, kh, kw, vc].astype(out_dtype), - axis=[ci, kh, kw]), name='conv') - - output = tvm.compute(oshape, lambda n, h, w, co: - conv[n][h//VH][w//VW][CO // VC][h%VH][w%VW][co%VC], - name='output_unpack', tag='spatial_conv2d_output_nhwc') - return output - def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last): """schedule implementation""" @@ -400,53 +303,6 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, return s -def _schedule_spatial_pack_nhwc(cfg, s, data_vec, kernel_vec, conv, output, - last): - """schedule implementation""" - n, oh, ow, co, vh, vw, vc = s[conv].op.axis - ci, kh, kw = s[conv].op.reduce_axis - - # schedule conv - cfg["reorder_0"].apply(s, conv, [n, oh, ow, co, ci, kh, kw, vh, vw, vc]) - cfg["ann_reduce"].apply(s, conv, [kh, kw], - axis_lens=[get_const_int(kh.dom.extent), - get_const_int(kw.dom.extent)], - max_unroll=16, - cfg=cfg) - cfg["ann_spatial"].apply(s, conv, [vh, vw, vc], - axis_lens=[cfg['tile_oh'].size[-1], - cfg['tile_ow'].size[-1], - cfg['tile_co'].size[-1]], - max_unroll=16, - cfg=cfg) - - # schedule fusion - n, h, w, co = s[last].op.axis - co, vc = cfg['tile_co'].apply(s, last, co) - oh, vh = cfg['tile_oh'].apply(s, last, h) - ow, vw = cfg['tile_ow'].apply(s, last, w) - s[last].reorder(n, oh, ow, co, vh, vw, vc) - if last != output: - s[output].compute_inline() - cfg["ann_spatial"].apply(s, last, [vh, vw, vc], - axis_lens=[cfg['tile_oh'].size[-1], - cfg['tile_ow'].size[-1], - cfg['tile_co'].size[-1]], - max_unroll=16, - cfg=cfg) - s[conv].compute_at(s[last], co) - - # mark parallel - s[last].parallel(oh) - - _, h, _, _, _, _ = s[data_vec].op.axis - s[data_vec].parallel(h) - - co, _, _, _, _ = s[kernel_vec].op.axis - s[kernel_vec].parallel(co) - return s - - @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd']) def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): """ TOPI compute callback. Use winograd template """ @@ -954,19 +810,25 @@ def _conv2d_legalize(attrs, inputs, arg_types): if attrs['data_layout'] == 'NHWC': data, kernel = inputs if attrs['kernel_layout'] == 'HWIO': - # HWIO layout is expected for NHWC input. - return None + # Handle HWIO layout. This is common in TF graph. 
+ kernel = relay.transpose(kernel, axes=(3, 2, 0, 1)) elif attrs['kernel_layout'] == 'HWOI': # Handle HWOI layout. This is common in TF depthwise conv2d graph. - kernel = relay.transpose(kernel, axes=(0, 1, 3, 2)) - elif attrs['kernel_layout'] == 'OIHW': - kernel = relay.transpose(kernel, axes=(2, 3, 1, 0)) + kernel = relay.transpose(kernel, axes=(2, 3, 0, 1)) + elif attrs['kernel_layout'] != 'OIHW': + return None - ## Set new attrs for the tranposed conv. + logger.warning("Legalize arm_cpu - NHWC schedule absent. Inserting layout transforms to " + + "fallback to NCHW. This can result in performance degradation.") + # Set new attrs for the tranposed conv. new_attrs = {k: attrs[k] for k in attrs.keys()} - new_attrs['data_layout'] = 'NHWC' - new_attrs['kernel_layout'] = 'HWIO' + new_attrs['data_layout'] = 'NCHW' + new_attrs['kernel_layout'] = 'OIHW' + # Convert from NHWC to NCHW. + data = relay.transpose(data, axes=(0, 3, 1, 2)) conv = relay.nn.conv2d(data, kernel, **new_attrs) - return conv + # Convert back to original NHWC layout. + out = relay.transpose(conv, axes=(0, 2, 3, 1)) + return out return None diff --git a/topi/python/topi/nn/bitserial_util.py b/topi/python/topi/nn/bitserial_util.py index 35a03a5b1e64..09a301f7c962 100644 --- a/topi/python/topi/nn/bitserial_util.py +++ b/topi/python/topi/nn/bitserial_util.py @@ -88,3 +88,4 @@ def binary_op_multiplier(pack_dtype): pack_dtype: string pack type for the operator (must be a uint)""" return int(pack_dtype[4:]) + \ No newline at end of file
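The restored legalization above falls back to the existing NCHW schedule by bracketing the convolution with layout transforms: NHWC data is transposed to NCHW, an HWIO or HWOI kernel is transposed to OIHW, and the NCHW result is transposed back to NHWC. A rough numpy sketch of the axis bookkeeping, using the shapes from the test (TVM is not required; the corresponding relay calls are only referenced in comments):

    import numpy as np

    data_nhwc = np.zeros((1, 224, 224, 16), dtype='float32')   # N, H, W, C
    kernel_hwio = np.zeros((3, 3, 16, 32), dtype='float32')    # H, W, I, O

    # relay.transpose(data, axes=(0, 3, 1, 2)): NHWC -> NCHW
    data_nchw = data_nhwc.transpose(0, 3, 1, 2)
    assert data_nchw.shape == (1, 16, 224, 224)

    # relay.transpose(kernel, axes=(3, 2, 0, 1)): HWIO -> OIHW
    kernel_oihw = kernel_hwio.transpose(3, 2, 0, 1)
    assert kernel_oihw.shape == (32, 16, 3, 3)

    # conv2d then runs in NCHW; relay.transpose(conv, axes=(0, 2, 3, 1)) restores NHWC,
    # and the layout round trip is the identity on the data:
    assert np.array_equal(data_nchw.transpose(0, 2, 3, 1), data_nhwc)

Because the inserted transposes are an exact layout round trip, the fallback changes only performance, not numerical results, which is what the rtol=0.01 comparison in the test exercises.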