diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e6e85d8e2589..950fb26f9226 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -716,36 +716,8 @@ def legalize_bitserial_conv2d(attrs, inputs, types): reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) - # bitserial_dense -@reg.register_compute("nn.bitserial_dense") -def compute_bitserial_dense(attrs, inputs, out_type): - """Compute definition of bitserial_dense""" - data_bits = attrs.data_bits - weight_bits = attrs.weight_bits - pack_dtype = attrs.pack_dtype - out_dtype = attrs.out_dtype - out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype - unipolar = attrs.unipolar - return [ - topi.nn.bitserial_dense( - inputs[0], - inputs[1], - data_bits, - weight_bits, - pack_dtype, - out_dtype, - unipolar) - ] - - -# @reg.register_schedule("nn.bitserial_dense") -# def schedule_bitserial_dense(attrs, outputs, target): -# """Schedule definition of bitserial_dense""" -# with target: -# return topi.generic.schedule_bitserial_dense(outputs) - - +reg.register_strategy("nn.bitserial_dense", strategy.bitserial_dense_strategy) reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 138af85b15b7..c2d6aab58e64 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -32,3 +32,12 @@ def schedule_concatenate_arm_cpu(_, outs, target): """schedule concatenate for arm cpu""" with target: return topi.arm_cpu.schedule_concatenate(outs) + +@bitserial_dense_strategy.register("arm_cpu") +def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target): + """bitserial_dense arm cpu strategy""" + strategy = _op.OpStrategy() + strategy.add_implement( + wrap_compute_bitserial_dense(topi.arm_cpu.bitserial_dense_default), + wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_dense)) + return strategy diff --git 
a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index f91613ad1eb3..df6c548f8b2f 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -445,3 +445,30 @@ def schedule_argwhere(attrs, outs, target): """schedule argwhere""" with target: return topi.generic.schedule_argwhere(outs) + +# bitserial_dense +def wrap_compute_bitserial_dense(topi_func): + """wrap bitserial_dense topi compute""" + def compute_bitserial_dense(attrs, inputs, out_type): + """Compute definition of bitserial dense""" + data_bits = attrs.data_bits + weight_bits = attrs.weight_bits + pack_dtype = attrs.pack_dtype + out_dtype = attrs.out_dtype + out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype + unipolar = attrs.unipolar + return [ + topi_func(inputs[0], inputs[1], data_bits, weight_bits, pack_dtype, + out_dtype, unipolar) + ] + return compute_bitserial_dense + + +@override_native_generic_func("bitserial_dense_strategy") +def bitserial_dense_strategy(attrs, inputs, out_type, target): + """bitserial_dense generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implement( + wrap_compute_bitserial_dense(topi.nn.bitserial_dense), + wrap_topi_schedule(topi.generic.schedule_bitserial_dense)) + return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 02ba428829c3..e05ef3b941b4 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -125,3 +125,13 @@ def roi_align_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implement(wrap_compute_roi_align(topi.x86.roi_align_nchw), wrap_topi_schedule(topi.generic.schedule_roi_align)) return strategy + +@bitserial_dense_strategy.register("cpu") +def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target): + """bitserial_dense x86 strategy""" + strategy = _op.OpStrategy() + strategy.add_implement( + wrap_compute_bitserial_dense( + topi.x86.bitserial_dense.bitserial_dense_default), + wrap_topi_schedule(topi.x86.schedule_bitserial_dense)) + return strategy diff --git 
a/topi/python/topi/arm_cpu/bitserial_dense.py b/topi/python/topi/arm_cpu/bitserial_dense.py index 8bd6c5d15f8c..9418393aab00 100644 --- a/topi/python/topi/arm_cpu/bitserial_dense.py +++ b/topi/python/topi/arm_cpu/bitserial_dense.py @@ -21,14 +21,12 @@ from tvm import autotvm from topi.util import get_const_tuple from .. import tag -from .. import generic from .bitserial_conv2d import _intrin_popcount from ..nn.pad import pad -from ..nn.bitserial_dense import bitserial_dense from ..nn.bitserial_util import bitpack, binary_op_multiplier -@autotvm.register_topi_compute(bitserial_dense, ['arm_cpu'], 'direct') -def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype, +@autotvm.register_topi_compute2('bitserial_dense.arm_cpu') +def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype, unipolar): """The default implementation of bitserial dense in topi. @@ -111,7 +109,7 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp return matmul -@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['arm_cpu'], 'direct') +@autotvm.register_topi_schedule2('bitserial_dense.arm_cpu') def schedule_bitserial_dense(cfg, outs): """Schedule for binary_dense. diff --git a/topi/python/topi/nn/bitserial_dense.py b/topi/python/topi/nn/bitserial_dense.py index d77a1b7b0fc2..f2234f5d479d 100644 --- a/topi/python/topi/nn/bitserial_dense.py +++ b/topi/python/topi/nn/bitserial_dense.py @@ -18,11 +18,9 @@ """Bitserial Dense operator.""" from __future__ import absolute_import import tvm -from tvm import autotvm from topi.util import get_const_tuple -from .bitserial_util import bitpack, binary_op_multiplier -@tvm.target.generic_func + def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32', out_dtype='int16', unipolar=True): """The default implementation of bitserial dense in topi. 
@@ -66,78 +64,3 @@ def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32', if unipolar: return matmul_unipolar return matmul - - -@autotvm.register_topi_compute(bitserial_dense, ['cpu'], 'direct') -def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32', - out_dtype='int16', unipolar=True): - """Bitserial dense implementation. TODO: Why are these separate - - Parameters - ---------- - data : tvm.Tensor - 2-D with shape [batch, in_dim] - weight : tvm.Tensor - 2-D with shape [out_dim, in_dim] or - 3-D with shape [out_dim, weight_bits, in_dim] - Returns - ------- - output : tvm.Tensor - 2-D with shape [batch, out_dim] - """ - data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype) - if len(weight.shape) == 2: - weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype) - else: - weight_packed = weight - Y, DB, K = get_const_tuple(data_packed.shape) - X, WB, _ = get_const_tuple(weight_packed.shape) - ######## Search space - x, y = cfg.axis(X), cfg.axis(Y) - db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K) - ko, ki = cfg.define_split('tile_k', k, num_outputs=2) - yo, yi = cfg.define_split('tile_y', y, num_outputs=2) - xo, xi = cfg.define_split('tile_x', x, num_outputs=2) - - cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi], - policy='candidate', candidate=[ - [yo, xo, ko, yi, wb, db, ki, xi], - [yo, xo, yi, ko, wb, db, ki, xi]]) - - cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll') - cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec') - - ###### Compute rule - VX = cfg['tile_x'].size[-1] - - wvshape = (X//VX, WB, VX, K) - oshape = (Y, X) - - k = tvm.reduce_axis((0, K), name='k') - db = tvm.reduce_axis((0, DB), name='db') - wb = tvm.reduce_axis((0, WB), name='wb') - - # Tile data and weights - weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k: - 
weight_packed[xo*VX+vx][wb][k], name='weight_vec') - - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod - - matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum( - (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) - - tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) - ).astype(out_dtype) - << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar') - - matmul = tvm.compute(oshape, lambda i, j: tvm.sum( - tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k] - ).astype(out_dtype) - << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense') - - # binary ops - cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype)) - - if unipolar: - return matmul_unipolar - return matmul diff --git a/topi/python/topi/x86/bitserial_dense.py b/topi/python/topi/x86/bitserial_dense.py index 47b972fa1319..e9751345633b 100644 --- a/topi/python/topi/x86/bitserial_dense.py +++ b/topi/python/topi/x86/bitserial_dense.py @@ -20,10 +20,84 @@ import tvm from tvm import autotvm -from topi.util import get_const_int +from topi.util import get_const_int, get_const_tuple +from ..nn.bitserial_util import bitpack, binary_op_multiplier from .. import tag -from .. import generic -@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['cpu'], 'direct') +@autotvm.register_topi_compute2('bitserial_dense.x86') +def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32', + out_dtype='int16', unipolar=True): + """Bitserial dense implementation. 
TODO: Why are these separate + + Parameters + ---------- + data : tvm.Tensor + 2-D with shape [batch, in_dim] + weight : tvm.Tensor + 2-D with shape [out_dim, in_dim] or + 3-D with shape [out_dim, weight_bits, in_dim] + Returns + ------- + output : tvm.Tensor + 2-D with shape [batch, out_dim] + """ + data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype) + if len(weight.shape) == 2: + weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype) + else: + weight_packed = weight + Y, DB, K = get_const_tuple(data_packed.shape) + X, WB, _ = get_const_tuple(weight_packed.shape) + ######## Search space + x, y = cfg.axis(X), cfg.axis(Y) + db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K) + ko, ki = cfg.define_split('tile_k', k, num_outputs=2) + yo, yi = cfg.define_split('tile_y', y, num_outputs=2) + xo, xi = cfg.define_split('tile_x', x, num_outputs=2) + + cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi], + policy='candidate', candidate=[ + [yo, xo, ko, yi, wb, db, ki, xi], + [yo, xo, yi, ko, wb, db, ki, xi]]) + + cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll') + cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec') + + ###### Compute rule + VX = cfg['tile_x'].size[-1] + + wvshape = (X//VX, WB, VX, K) + oshape = (Y, X) + + k = tvm.reduce_axis((0, K), name='k') + db = tvm.reduce_axis((0, DB), name='db') + wb = tvm.reduce_axis((0, WB), name='wb') + + # Tile data and weights + weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k: + weight_packed[xo*VX+vx][wb][k], name='weight_vec') + + idxdiv = tvm.indexdiv + idxmod = tvm.indexmod + + matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum( + (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) - + tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) + ).astype(out_dtype) + << (db+wb).astype(out_dtype), axis=[wb, db, k]), 
tag='bitserial_dense_unipolar') + + matmul = tvm.compute(oshape, lambda i, j: tvm.sum( + tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k] + ).astype(out_dtype) + << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense') + + # binary ops + cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype)) + + if unipolar: + return matmul_unipolar + return matmul + +@autotvm.register_topi_schedule2('bitserial_dense.x86') def schedule_bitserial_dense(cfg, outs): """Schedule for bitserial_dense.