[TOPI] bitserial_conv2d move to autotvm template and updates #2819

Merged: 10 commits, Apr 4, 2019
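
This change replaces the fixed per-workload schedule tables with AutoTVM templates, so new workloads can be tuned instead of falling back to hand-picked parameters. Below is a rough tuning sketch (not part of this diff) using the x86 NHWC task name registered in the x86 file further down; the shapes, dtypes, and compute argument order are illustrative, since the updated bitserial_conv2d_nhwc signature lives in the non-rendered nn/bitserial_conv2d.py diff.

import tvm
import topi  # importing topi registers the task templates, including the x86 ones
from tvm import autotvm
from tvm.autotvm.task.topi_integration import serialize_args

# Hypothetical workload: 1x56x56x64 input, 64 3x3 filters,
# 2 activation bits, 1 weight bit, uint32 bit-packing.
data = tvm.placeholder((1, 56, 56, 64), dtype='uint32', name='data')
kernel = tvm.placeholder((3, 3, 64, 64), dtype='uint32', name='kernel')
args = serialize_args([data, kernel, 1, 1, 2, 1, 'uint32', 'int16', True])

task = autotvm.task.create('topi_x86_bitserial_conv_nhwc', args=args, target='llvm')
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10))
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=100, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('bitserial_nhwc.log')])
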
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
@@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
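
The extra branch matters because the new bitserial templates can carry unsigned integer constants (for example bit widths or packing parameters) in their argument lists; without a case for expr.UIntImm, args_to_workload cannot fold such a constant into a hashable workload tuple. A minimal illustration, assuming 2019-era tvm.const behaviour:

import tvm
from tvm.autotvm.task.task import args_to_workload

c = tvm.const(2, 'uint32')    # constructed as expr.UIntImm
print(args_to_workload(c))    # -> 2, same treatment as IntImm/FloatImm constants
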
429 changes: 203 additions & 226 deletions topi/python/topi/arm_cpu/bitserial_conv2d.py

Large diffs are not rendered by default.

397 changes: 259 additions & 138 deletions topi/python/topi/nn/bitserial_conv2d.py

Large diffs are not rendered by default.

281 changes: 98 additions & 183 deletions topi/python/topi/x86/bitserial_conv2d.py
@@ -1,74 +1,33 @@
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Bitserial conv2d schedule on x86"""
import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import deserialize_args
from topi.util import get_const_int
from .. import generic, tag
from ..nn.bitserial_conv2d import bitserial_conv2d, _get_schedule, _get_workload
from ..nn.bitserial_conv2d import SpatialPackNCHW, SpatialPackNHWC
from ..nn.bitserial_conv2d import _WORKLOADS, _SCH_TO_DECL_FUNC_QUANT

_QUANTIZED_SCHEDULES_NCHW = [
# resnet
SpatialPackNCHW(2, 2, 8, 1, 1),
SpatialPackNCHW(1, 4, 8, 4, 1),
SpatialPackNCHW(1, 4, 8, 1, 16),
SpatialPackNCHW(1, 4, 8, 4, 8),
SpatialPackNCHW(1, 7, 8, 3, 8),
SpatialPackNCHW(1, 2, 8, 1, 8),
SpatialPackNCHW(2, 1, 8, 1, 4),
SpatialPackNCHW(1, 7, 8, 1, 1),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 8),
SpatialPackNCHW(1, 1, 8, 1, 16),

SpatialPackNCHW(3, 3, 16, 3, 16),
SpatialPackNCHW(1, 1, 16, 2, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
SpatialPackNCHW(1, 1, 8, 1, 16),
]

_QUANTIZED_SCHEDULES_NHWC = [
# resnet
SpatialPackNHWC(2, 2, 8, 1, 1),
SpatialPackNHWC(1, 4, 8, 4, 1),
SpatialPackNHWC(1, 4, 8, 1, 16),
SpatialPackNHWC(1, 4, 8, 4, 8),
SpatialPackNHWC(1, 7, 8, 3, 8),
SpatialPackNHWC(1, 2, 8, 1, 8),
SpatialPackNHWC(2, 1, 8, 1, 4),
SpatialPackNHWC(1, 7, 8, 1, 1),
SpatialPackNHWC(1, 1, 8, 1, 16),
SpatialPackNHWC(1, 1, 8, 1, 8),
SpatialPackNHWC(1, 1, 8, 1, 16),
]

@_get_schedule.register("cpu")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
if layout == "NCHW":
sch = _QUANTIZED_SCHEDULES_NCHW[idx]
elif layout == "NHWC":
sch = _QUANTIZED_SCHEDULES_NHWC[idx]
return sch

@bitserial_conv2d.register("cpu")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
out_dtype = data.dtype
assert data.shape[0].value == 1, "only support batch size=1 convolution on rasp"
assert layout in ("NCHW", "NHWC"), "only support layouts NCHW and NHWC"

wkl = _get_workload(data, kernel, stride, padding, out_dtype, layout)
sch = _get_schedule(wkl, layout)
return _SCH_TO_DECL_FUNC_QUANT[type(sch)](data, kernel, stride, padding, activation_bits,
weight_bits, pack_dtype, out_dtype, dorefa)

@generic.schedule_bitserial_conv2d_nchw.register(["cpu"])
@generic.schedule_bitserial_conv2d_nhwc.register(["cpu"])
def schedule_bitserial_conv2d(outs):
from ..nn.bitserial_conv2d import bitserial_conv2d_nhwc, bitserial_conv2d_nchw

@autotvm.task.register("topi_x86_bitserial_conv_nhwc")
def topi_bitserial_conv2d_nhwc(*args, **kwargs):
args = deserialize_args(args)
C = bitserial_conv2d_nhwc(*args, **kwargs)
s = generic.nn.schedule_bitserial_conv2d_nhwc([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]

@autotvm.task.register("topi_arm_cpu_bitserial_conv_nchw")
@autotvm.task.register("topi_x86_bitserial_conv_nchw")
def topi_bitserial_conv2d_nchw(*args, **kwargs):
args = deserialize_args(args)
C = bitserial_conv2d_nchw(*args, **kwargs)
s = generic.nn.schedule_bitserial_conv2d_nchw([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]
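
Outside of tuning, the same template entry point can be driven under apply_history_best to pick up logged configurations. A minimal sketch, assuming the log file and serialized arguments from the tuning sketch near the top of this page, and that the compute attaches its workload so the dispatcher can look the config up:

import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import serialize_args
from topi.x86.bitserial_conv2d import topi_bitserial_conv2d_nhwc

data = tvm.placeholder((1, 56, 56, 64), dtype='uint32', name='data')
kernel = tvm.placeholder((3, 3, 64, 64), dtype='uint32', name='kernel')
args = serialize_args([data, kernel, 1, 1, 2, 1, 'uint32', 'int16', True])

with autotvm.apply_history_best('bitserial_nhwc.log'):
    with tvm.target.create('llvm'):
        s, arg_bufs = topi_bitserial_conv2d_nhwc(*args)
        func = tvm.build(s, arg_bufs, target='llvm')
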

@autotvm.register_topi_schedule(generic.nn.schedule_bitserial_conv2d_nchw, ['cpu'], 'direct')
def schedule_bitserial_conv2d(cfg, outs):
"""CPU schedule for bitserial convolutions NCHW and NHWC"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
@@ -88,7 +47,6 @@ def traverse(op):
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel_q = kernel_vec.op.input_tensors[0]
kernel = kernel_q.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data_q = data_vec.op.input_tensors[0]
data = data_q.op.input_tensors[0]
@@ -97,29 +55,27 @@ def traverse(op):
data_pad = data_q
data_q = data
data = data_q.op.input_tensors[0]
if "QuantizeInput" in kernel.op.name:
# Need to go up 1 further, from the combine in bitpack
kernel = kernel.op.input_tensors[0]

if "QuantizeInput" in data.op.name:
# Need to go up 1 further, from the combine in bitpack
data = data.op.input_tensors[0]

if 'spatial_bitserial_conv_nchw' in op.tag:
_schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, outs[0])
_schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, outs[0])
elif 'spatial_bitserial_conv_nhwc' in op.tag:
_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, outs[0])
_schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, outs[0])
scheduled_ops.append(op)

traverse(outs[0].op)
return s

def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
def _schedule_bitserial_conv2d_nchw(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, last):
IB, _, CI, IH, IW = data_q.shape
KB, CO, _, KH, KW = kernel_q.shape
_, _, OH, OW = output.shape
@@ -138,37 +94,21 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
wstride = get_const_int((TW - KW) // (OW - 1))
stride = (hstride, wstride)

wkl = _get_workload(data, kernel, stride, padding, output.dtype, "NCHW")
sch = _get_schedule(wkl, "NCHW")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]

CC = s.cache_write(conv_out, "global")
n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
s[conv_out].vectorize(vc)

s[CC].compute_at(s[conv_out], ow)
n, co, oh, ow, vh, vw, vc = s[CC].op.axis
ci, dh, dw, b1, b2 = s[CC].op.reduce_axis
s[CC].reorder(ci, dh, vh, dw, vw, b1, b2, vc)
s[CC].unroll(b1)
s[CC].unroll(b2)
s[CC].vectorize(vc)

##### Schedule A
##### Schedule data padding and bitpacking
if data_pad is not None:
s[data_pad].compute_inline()

_, h, _, _, _, _, vw = s[data_vec].op.axis
s[data_vec].vectorize(vw)
if ba == 1:
oaxis = h
paxis = h
_, _, h, _, _, _, _ = s[data_vec].op.axis
cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
if cfg["tile_ah"].size[1] == 1:
oaxis = oh
paxis = oh
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih

@@ -178,14 +118,14 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")


##### Schedule B
co, _, _, _, _, vc = s[kernel_vec].op.axis
s[kernel_vec].vectorize(vc)
if bc == 1:
oaxis = co
paxis = co
##### Schedule kernel bitpacking
co, _, _, _, _, _ = s[kernel_vec].op.axis
cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
if cfg["tile_bco"].size[1] == 1:
oaxis = oco
paxis = oco
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico

@@ -195,7 +135,23 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")


##### Schedule C
##### Schedule Convolution
n, co, oh, ow, vh, vw, vc = s[conv_out].op.axis
ci, dh, dw, ib, kb = s[conv_out].op.reduce_axis

# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg["reorder_0"].apply(s, conv_out, [n, co, oh, ow, vc, vh, vw, dh, dw, kb, ib, ci])
cfg["ann_reduce"].apply(s, conv_out, [kb, ib, dh, dw],
axis_lens=[get_const_int(kb.dom.extent),
get_const_int(ib.dom.extent),
get_const_int(dh.dom.extent),
get_const_int(dw.dom.extent)],
max_unroll=16,
cfg=cfg)

s[conv_out].vectorize(vc)

##### Schedule output
n, co, h, w = s[last].op.axis
co, vc = s[last].split(co, VC)
oh, ow, vh, vw = s[last].tile(h, w, VH, VW)
@@ -204,89 +160,58 @@ def _schedule_spatial_conv2d_nchw(s, data, data_q, data_pad, data_vec,
s[output].compute_inline()
s[conv_out].compute_at(s[last], ow)

if bc == 1:
oaxis = co
paxis = co
oco, ico = cfg["tile_oh"].apply(s, last, co)
if cfg["tile_oh"].size[1] == 1:
oaxis = oco
paxis = oco
else:
oco, ico = s[last].split(co, bc)
oaxis = oco
paxis = ico

s[last].parallel(paxis)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")

s[last].parallel(oco)
return s

def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec,
conv_out, output, last):
def _schedule_bitserial_conv2d_nhwc(cfg, s, data_q, data_pad, data_vec,
kernel_q, kernel_vec,
conv_out, output, last):
# no stride and padding info here
_, IH, IW, CI, IB = data_q.shape
KH, KW, _, CO, KB = kernel_q.shape
_, OH, OW, _ = output.shape
# Infer padding and stride
if data_pad is None:
padding = (0, 0)
TH, TW = IH, IW
else:
_, TH, TW, _, _ = data_pad.shape
hpad = get_const_int((TH - IH) // 2)
wpad = get_const_int((TW - IW) // 2)
padding = (hpad, wpad)

hstride = get_const_int((TH - KH) // (OH - 1))
wstride = get_const_int((TW - KW) // (OW - 1))
stride = (hstride, wstride)
VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1]
VW = cfg["tile_ow"].size[-1]

wkl = _get_workload(data, kernel, stride, padding, last.dtype, "NHWC")
sch = _get_schedule(wkl, "NHWC")
VH = sch.vh
VW = sch.vw
VC = sch.vc
ba = sch.ba
bc = sch.bc

##### Schedule data packing
##### Schedule data padding and packing
if data_pad is not None:
s[data_pad].compute_inline()

_, h, _, _, _, _, _ = s[data_vec].op.axis
if ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[data_vec].split(h, ba)
oaxis = oh
paxis = ih
s[data_vec].parallel(paxis)
s[data_vec].pragma(oaxis, "parallel_launch_point")
s[data_vec].pragma(paxis, "parallel_stride_pattern")
s[data_vec].pragma(oaxis, "parallel_barrier_when_finish")

cfg.define_split("tile_ah", cfg.axis(h), policy="all", num_outputs=2, max_factor=32)
oh, ih = cfg["tile_ah"].apply(s, data_vec, h)
s[data_vec].parallel(oh)

##### Schedule kernel packing
co, _, _, _, _, _ = s[kernel_vec].op.axis
if bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[kernel_vec].split(co, bc)
oaxis = oco
paxis = ico

s[kernel_vec].parallel(paxis)
s[kernel_vec].pragma(oaxis, "parallel_launch_point")
s[kernel_vec].pragma(paxis, "parallel_stride_pattern")
s[kernel_vec].pragma(oaxis, "parallel_barrier_when_finish")

cfg.define_split("tile_bco", cfg.axis(co), policy="all", num_outputs=2, max_factor=32)
oco, ico = cfg["tile_bco"].apply(s, kernel_vec, co)
s[kernel_vec].parallel(oco)

##### Schedule Convolution
n, oh, ow, co, vh, vw, vc = s[conv_out].op.axis
dh, dw, ci, b1, b2 = s[conv_out].op.reduce_axis

s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
# s[conv_out].reorder(n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2)
cfg["reorder_0"].apply(s, conv_out, [n, oh, ow, co, vh, vw, dh, dw, ci, vc, b1, b2])
cfg["ann_reduce"].apply(s, conv_out, [b1, b2, dh, dw],
axis_lens=[get_const_int(b1.dom.extent),
get_const_int(b2.dom.extent),
get_const_int(dh.dom.extent),
get_const_int(dw.dom.extent)],
max_unroll=16,
cfg=cfg)

s[conv_out].unroll(b1)
s[conv_out].unroll(b2)
@@ -302,17 +227,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
s[output].compute_inline()
s[conv_out].compute_at(s[last], ow)

if bc == 1:
oaxis = oh
paxis = oh
else:
oho, iho = s[last].split(oh, bc)
oaxis = oho
paxis = iho

s[last].parallel(paxis)
s[last].pragma(oaxis, "parallel_launch_point")
s[last].pragma(paxis, "parallel_stride_pattern")
s[last].pragma(oaxis, "parallel_barrier_when_finish")
oho, iho = cfg["tile_oh"].apply(s, last, oh) # reuse parameter
s[last].parallel(oho)

return s
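
Both schedules consume knobs named tile_co, tile_oh, tile_ow, reorder_0, and ann_reduce, which the spatial-pack compute declaration (in the non-rendered topi/nn/bitserial_conv2d.py diff) is expected to define. Below is a standalone sketch of what such a configuration space looks like with the AutoTVM space API; the axis lengths, split counts, candidate orders, and policies are assumptions for illustration, not the PR's actual definitions.

from tvm.autotvm.task.space import ConfigSpace

# Stand-in axis lengths: 64 output channels, 56x56 output, 3x3 kernel,
# 1 weight bit, 2 activation bits, 64 input channels.
cfg = ConfigSpace()
cfg.define_split('tile_co', cfg.axis(64), num_outputs=2)
cfg.define_split('tile_oh', cfg.axis(56), num_outputs=2)
cfg.define_split('tile_ow', cfg.axis(56), num_outputs=2)
cfg.define_annotate('ann_reduce',
                    [cfg.axis(1), cfg.axis(2), cfg.axis(3), cfg.axis(3)],
                    policy='try_unroll')
# Loop-order knob consumed above as cfg["reorder_0"].apply(...); two example orders.
axes = [cfg.axis(n) for n in (1, 64, 56, 56, 8, 2, 2, 3, 3, 1, 2, 64)]
cfg.define_reorder('reorder_0', axes, policy='candidate',
                   candidate=[axes, axes[:4] + axes[7:] + axes[4:7]])
print(len(cfg))  # number of candidate configurations in this sketch
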