Skip to content

Commit

Permalink
update strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Jan 8, 2020
1 parent 3d3984d commit a9300fa
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 114 deletions.
30 changes: 1 addition & 29 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,35 +717,7 @@ def legalize_bitserial_conv2d(attrs, inputs, types):
# Register the fusion pattern for nn.bitserial_conv2d as OUT_ELEMWISE_FUSABLE.
reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# bitserial_dense
@reg.register_compute("nn.bitserial_dense")
def compute_bitserial_dense(attrs, inputs, out_type):
    """Compute definition of bitserial_dense.

    Parameters
    ----------
    attrs : object
        Supplies ``data_bits``, ``weight_bits``, ``pack_dtype``,
        ``out_dtype`` and ``unipolar``.
    inputs : sequence
        ``inputs[0]`` and ``inputs[1]`` are forwarded, in that order, as the
        first two arguments of ``topi.nn.bitserial_dense``.
    out_type : object
        Not used by this function.

    Returns
    -------
    outputs : list
        Single-element list holding the result of ``topi.nn.bitserial_dense``.
    """
    data_bits = attrs.data_bits
    weight_bits = attrs.weight_bits
    pack_dtype = attrs.pack_dtype
    out_dtype = attrs.out_dtype
    # An empty out_dtype means "use the dtype of the first input".
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    unipolar = attrs.unipolar
    return [
        topi.nn.bitserial_dense(
            inputs[0],
            inputs[1],
            data_bits,
            weight_bits,
            pack_dtype,
            out_dtype,
            unipolar)
    ]


# @reg.register_schedule("nn.bitserial_dense")
# def schedule_bitserial_dense(attrs, outputs, target):
# """Schedule definition of bitserial_dense"""
# with target:
# return topi.generic.schedule_bitserial_dense(outputs)


# nn.bitserial_dense is registered with the bitserial_dense_strategy function
# from the strategy module, and with the OUT_ELEMWISE_FUSABLE fusion pattern.
reg.register_strategy("nn.bitserial_dense", strategy.bitserial_dense_strategy)
reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@ def schedule_injective_arm_cpu(_, outs, target):
def schedule_concatenate_arm_cpu(_, outs, target):
    """Schedule concatenate for arm_cpu.

    Enters ``target`` as a context manager and returns the result of
    ``topi.arm_cpu.schedule_concatenate(outs)``. The first argument is
    ignored.
    """
    with target:
        return topi.arm_cpu.schedule_concatenate(outs)

@bitserial_dense_strategy.register("arm_cpu")
def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
    """bitserial_dense strategy registered under the "arm_cpu" key.

    Builds an ``OpStrategy`` with a single implementation: the compute is
    ``topi.arm_cpu.bitserial_dense_default`` wrapped by
    ``wrap_compute_bitserial_dense``, the schedule is
    ``topi.arm_cpu.schedule_bitserial_dense`` wrapped by
    ``wrap_topi_schedule``. None of the four arguments is read here.

    Returns
    -------
    strategy : OpStrategy
        The populated strategy object.
    """
    strategy = _op.OpStrategy()
    strategy.add_implement(
        wrap_compute_bitserial_dense(topi.arm_cpu.bitserial_dense_default),
        wrap_topi_schedule(topi.arm_cpu.schedule_bitserial_dense))
    return strategy
25 changes: 25 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,28 @@ def proposal_strategy(attrs, inputs, out_type, target):
def schedule_argwhere(attrs, outs, target):
    """Schedule argwhere.

    Enters ``target`` as a context manager and returns the result of
    ``topi.generic.schedule_argwhere(outs)``. ``attrs`` is not read.
    """
    with target:
        return topi.generic.schedule_argwhere(outs)

# bitserial_dense
def wrap_compute_bitserial_dense(topi_func):
    """Wrap a bitserial dense topi compute function.

    Parameters
    ----------
    topi_func : callable
        Called as ``topi_func(data, weight, data_bits, weight_bits,
        pack_dtype, out_dtype, unipolar)``.

    Returns
    -------
    compute_bitserial_dense : callable
        Function taking ``(attrs, inputs, out_type)`` that unpacks the
        attributes and returns ``[topi_func(...)]``.
    """
    def compute_bitserial_dense(attrs, inputs, out_type):
        """Compute definition of bitserial dense"""
        data_bits = attrs.data_bits
        weight_bits = attrs.weight_bits
        pack_dtype = attrs.pack_dtype
        out_dtype = attrs.out_dtype
        # An empty out_dtype means "use the dtype of the first input".
        out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
        unipolar = attrs.unipolar
        return [
            topi_func(inputs[0], inputs[1], data_bits, weight_bits, pack_dtype,
                      out_dtype, unipolar)
        ]
    return compute_bitserial_dense


@override_native_generic_func("bitserial_dense_strategy")
def bitserial_dense_strategy(attrs, inputs, out_type, target):
    """bitserial_dense generic strategy.

    Builds an ``OpStrategy`` with a single implementation: the compute is
    ``topi.nn.bitserial_dense`` wrapped by ``wrap_compute_bitserial_dense``,
    the schedule is ``topi.generic.schedule_bitserial_dense`` wrapped by
    ``wrap_topi_schedule``. None of the four arguments is read here.

    Returns
    -------
    strategy : OpStrategy
        The populated strategy object.
    """
    strategy = _op.OpStrategy()
    strategy.add_implement(
        wrap_compute_bitserial_dense(topi.nn.bitserial_dense),
        wrap_topi_schedule(topi.generic.schedule_bitserial_dense))
    return strategy
9 changes: 9 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,12 @@ def roi_align_strategy_cpu(attrs, inputs, out_type, target):
strategy.add_implement(wrap_compute_roi_align(topi.x86.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align))
return strategy

@bitserial_dense_strategy.register("cpu")
def bitserial_dense_strategy_cpu(attrs, inputs, out_type, target):
    """bitserial_dense strategy registered under the "cpu" key.

    Builds an ``OpStrategy`` with a single implementation: the compute is
    ``topi.x86.bitserial_dense.bitserial_dense_default`` wrapped by
    ``wrap_compute_bitserial_dense``, the schedule is
    ``topi.x86.schedule_bitserial_dense`` wrapped by ``wrap_topi_schedule``.
    None of the four arguments is read here.

    Returns
    -------
    strategy : OpStrategy
        The populated strategy object.
    """
    strategy = _op.OpStrategy()
    strategy.add_implement(
        wrap_compute_bitserial_dense(
            topi.x86.bitserial_dense.bitserial_dense_default),
        wrap_topi_schedule(topi.x86.schedule_bitserial_dense))
    return strategy
10 changes: 5 additions & 5 deletions topi/python/topi/arm_cpu/bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
from tvm import autotvm
from topi.util import get_const_tuple
from .. import tag
from .. import generic
from .bitserial_conv2d import _intrin_popcount
from ..nn.pad import pad
from ..nn.bitserial_dense import bitserial_dense
from ..nn.bitserial_util import bitpack, binary_op_multiplier

@autotvm.register_topi_compute(bitserial_dense, ['arm_cpu'], 'direct')
def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
#@autotvm.register_topi_compute(bitserial_dense, ['arm_cpu'], 'direct')
@autotvm.register_topi_compute2('bitserial_dense.arm_cpu')
def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype, out_dtype,
unipolar):
"""The default implementation of bitserial dense in topi.
Expand Down Expand Up @@ -111,7 +110,8 @@ def bitserial_dense_generic(cfg, data, weight, data_bits, weight_bits, pack_dtyp
return matmul


@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['arm_cpu'], 'direct')
#@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['arm_cpu'], 'direct')
@autotvm.register_topi_schedule2('bitserial_dense.arm_cpu')
def schedule_bitserial_dense(cfg, outs):
"""Schedule for binary_dense.
Expand Down
80 changes: 2 additions & 78 deletions topi/python/topi/nn/bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
"""Bitserial Dense operator."""
from __future__ import absolute_import
import tvm
from tvm import autotvm
from topi.util import get_const_tuple
from .bitserial_util import bitpack, binary_op_multiplier

@tvm.target.generic_func

@tvm.target.override_native_generic_func("bitserial_dense")
def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
out_dtype='int16', unipolar=True):
"""The default implementation of bitserial dense in topi.
Expand Down Expand Up @@ -66,78 +65,3 @@ def bitserial_dense(data, weight, data_bits, weight_bits, pack_dtype='uint32',
if unipolar:
return matmul_unipolar
return matmul


@autotvm.register_topi_compute(bitserial_dense, ['cpu'], 'direct')
def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
                            out_dtype='int16', unipolar=True):
    """Bitserial dense implementation. TODO: Why are these separate

    Parameters
    ----------
    cfg : autotvm config object
        Receives the tiling, reorder and annotation knobs defined below, and
        the flop count.
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim]
    data_bits : int
        Bit count passed to ``bitpack`` for ``data``.
    weight_bits : int
        Bit count passed to ``bitpack`` for a 2-D ``weight``.
    pack_dtype : str
        Passed to ``bitpack`` as ``pack_type`` and to ``binary_op_multiplier``.
    out_dtype : str
        Dtype the popcount results and shift amounts are cast to.
    unipolar : bool
        Selects the unipolar compute (popcount of ``w & d`` minus popcount of
        ``~w & d``) over the plain popcount compute.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    # A 3-D weight is used as given; only a 2-D weight is bit-packed here.
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)
    ######## Search space
    x, y = cfg.axis(X), cfg.axis(Y)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)

    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
                       policy='candidate', candidate=[
                           [yo, xo, ko, yi, wb, db, ki, xi],
                           [yo, xo, yi, ko, wb, db, ki, xi]])

    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')

    ###### Compute rule
    VX = cfg['tile_x'].size[-1]

    wvshape = (X//VX, WB, VX, K)
    oshape = (Y, X)

    # From here on k, db and wb name the reduce axes of the compute,
    # replacing the config-space axes bound to the same names above.
    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
                             weight_packed[xo*VX+vx][wb][k], name='weight_vec')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
        (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
         tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
        ).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')

    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
        tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
                    ).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')

    # binary ops
    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))

    if unipolar:
        return matmul_unipolar
    return matmul
80 changes: 78 additions & 2 deletions topi/python/topi/x86/bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,86 @@
import tvm
from tvm import autotvm
from topi.util import get_const_int, get_const_tuple
from .. import tag
from .. import generic
from ..nn.bitserial_util import bitpack, binary_op_multiplier

@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['cpu'], 'direct')
#@autotvm.register_topi_compute(bitserial_dense, ['cpu'], 'direct')
@autotvm.register_topi_compute2('bitserial_dense.x86')
def bitserial_dense_default(cfg, data, weight, data_bits, weight_bits, pack_dtype='uint32',
                            out_dtype='int16', unipolar=True):
    """Bitserial dense implementation. TODO: Why are these separate

    Parameters
    ----------
    cfg : autotvm config object
        Receives the tiling, reorder and annotation knobs defined below, and
        the flop count.
    data : tvm.Tensor
        2-D with shape [batch, in_dim]
    weight : tvm.Tensor
        2-D with shape [out_dim, in_dim] or
        3-D with shape [out_dim, weight_bits, in_dim]
    data_bits : int
        Bit count passed to ``bitpack`` for ``data``.
    weight_bits : int
        Bit count passed to ``bitpack`` for a 2-D ``weight``.
    pack_dtype : str
        Passed to ``bitpack`` as ``pack_type`` and to ``binary_op_multiplier``.
    out_dtype : str
        Dtype the popcount results and shift amounts are cast to.
    unipolar : bool
        Selects the unipolar compute (popcount of ``w & d`` minus popcount of
        ``~w & d``) over the plain popcount compute.

    Returns
    -------
    output : tvm.Tensor
        2-D with shape [batch, out_dim]
    """
    data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    # A 3-D weight is used as given; only a 2-D weight is bit-packed here.
    if len(weight.shape) == 2:
        weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
    else:
        weight_packed = weight
    Y, DB, K = get_const_tuple(data_packed.shape)
    X, WB, _ = get_const_tuple(weight_packed.shape)
    ######## Search space
    x, y = cfg.axis(X), cfg.axis(Y)
    db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
    ko, ki = cfg.define_split('tile_k', k, num_outputs=2)
    yo, yi = cfg.define_split('tile_y', y, num_outputs=2)
    xo, xi = cfg.define_split('tile_x', x, num_outputs=2)

    cfg.define_reorder('reorder_0', [yo, xo, ko, yi, wb, db, ki, xi],
                       policy='candidate', candidate=[
                           [yo, xo, ko, yi, wb, db, ki, xi],
                           [yo, xo, yi, ko, wb, db, ki, xi]])

    cfg.define_annotate('ann_reduce', [db, wb], policy='try_unroll')
    cfg.define_annotate('ann_spatial', [yi, xi], policy='try_unroll_vec')

    ###### Compute rule
    VX = cfg['tile_x'].size[-1]

    wvshape = (X//VX, WB, VX, K)
    oshape = (Y, X)

    # From here on k, db and wb name the reduce axes of the compute,
    # replacing the config-space axes bound to the same names above.
    k = tvm.reduce_axis((0, K), name='k')
    db = tvm.reduce_axis((0, DB), name='db')
    wb = tvm.reduce_axis((0, WB), name='wb')

    # Tile data and weights
    weight_vec = tvm.compute(wvshape, lambda xo, wb, vx, k:
                             weight_packed[xo*VX+vx][wb][k], name='weight_vec')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    matmul_unipolar = tvm.compute(oshape, lambda i, j: tvm.sum(
        (tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]) -
         tvm.popcount(~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k])
        ).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense_unipolar')

    matmul = tvm.compute(oshape, lambda i, j: tvm.sum(
        tvm.popcount(weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
                    ).astype(out_dtype)
        << (db+wb).astype(out_dtype), axis=[wb, db, k]), tag='bitserial_dense')

    # binary ops
    cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))

    if unipolar:
        return matmul_unipolar
    return matmul

#@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_dense, ['cpu'], 'direct')
@autotvm.register_topi_schedule2('biserial_dense.x86')
def schedule_bitserial_dense(cfg, outs):
"""Schedule for bitserial_dense.
Expand Down

0 comments on commit a9300fa

Please sign in to comment.