Skip to content

Commit

Permalink
Add tuning knobs in depthwise schedule
Browse files Browse the repository at this point in the history
Change-Id: I15080e7f12b16e6c6aba99a04e42023845eeabf1
  • Loading branch information
Giuseppe Rossini committed Jul 21, 2020
1 parent e2168e5 commit cc43c78
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions topi/python/topi/arm_cpu/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, AnnotateEntity

from .. import nn
from ..util import traverse_inline, get_const_tuple, get_const_int
Expand Down Expand Up @@ -259,46 +260,70 @@ def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, o
idxmod(c, channel_multiplier)].astype(out_dtype),
axis=[reduce_h, reduce_w]),
name='depthwise_conv2d_nhwc_output')

return out

@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu")
def schedule_depthwise_conv2d_nhwc(_, outs):
def schedule_depthwise_conv2d_nhwc(cfg, outs):
"""Create the schedule for depthwise_conv2d_nhwc"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
out = outs[0]

##### space definition begin #####
n, h, w, c = s[out].op.axis
cfg.define_split('tile_c', c, num_outputs=2)
_, hi = cfg.define_split('tile_h', h, num_outputs=2)
_, wi = cfg.define_split('tile_w', w, num_outputs=2)
cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1)

# fallback support
if cfg.is_fallback:
cfg['tile_c'] = SplitEntity([-1, 8])
cfg['tile_h'] = SplitEntity([-1, 2])
cfg['tile_w'] = SplitEntity([-1, 2])
cfg['locate_output'] = AnnotateEntity([1])
##### space definition end #####

def schedule_conv(conv):
conv_data = conv.op.input_tensors[0]
if conv_data.name == "data_pad":
s[conv_data].compute_inline()

n, w, h, c = conv.op.axis
r_h, r_w = conv.op.reduce_axis
co, ci = s[conv].split(c, 8)
wo, wi = s[conv].split(w, 2)
ho, hi = s[conv].split(h, 2)
ho, hi = cfg['tile_h'].apply(s, conv, h)
wo, wi = cfg['tile_w'].apply(s, conv, w)
co, ci = cfg['tile_c'].apply(s, conv, c)

s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci)
s[conv].parallel(wo)
s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
fused_n_ho = s[conv].fuse(n, ho)
s[conv].parallel(fused_n_ho)
s[conv].vectorize(ci)

def schedule_conv_out(out):
n, h, w, c = out.op.axis
co, ci = s[out].split(c, 8)
wo, wi = s[out].split(w, 2)
ho, hi = s[out].split(h, 2)
ci_outer, ci_inner = s[out].split(ci, 4)
s[out].reorder(n, wo, ho, co, wi, hi)
s[out].vectorize(ci_inner)
compute_at_axis = hi
s[out].parallel(wo)
return compute_at_axis
co, ci = cfg['tile_c'].apply(s, out, c)
wo, wi = cfg['tile_w'].apply(s, out, w)
ho, hi = cfg['tile_h'].apply(s, out, h)

if out.dtype in ['int8', 'uint8']:
# For quantized convolutions, further split the channel axis into batches of 4 elements
# so that Arm intrinsics can be used to run fixed_point_multiplication
ci_outer, ci_inner = s[out].split(ci, 4)
s[out].reorder(n, ho, wo, co, hi, wi)
s[out].vectorize(ci_inner)

fused_n_ho = s[out].fuse(n, ho)
s[out].parallel(fused_n_ho)
return hi, wi

def _callback(op):
if op.name == 'depthwise_conv2d_nhwc_output':
conv = op.output(0)
if conv != out:
compute_at_axis = schedule_conv_out(out)
hi, wi = schedule_conv_out(out)
schedule_conv(conv)
s[conv].compute_at(s[out], compute_at_axis)
cfg['locate_output'].apply(s, out, [hi, wi], source=[[conv]])
else:
schedule_conv(out)

Expand Down

0 comments on commit cc43c78

Please sign in to comment.