From a7f01253d51cc08a7d5e68eb39d8af7ddafeac99 Mon Sep 17 00:00:00 2001 From: ziheng Date: Fri, 29 Sep 2017 17:52:34 -0700 Subject: [PATCH] [TOPI] Update depthwise conv2d schedule on rasp (#500) --- topi/python/topi/rasp/depthwise_conv2d.py | 136 ++++++++++++++++++++-- 1 file changed, 129 insertions(+), 7 deletions(-) diff --git a/topi/python/topi/rasp/depthwise_conv2d.py b/topi/python/topi/rasp/depthwise_conv2d.py index 1446556dc207..00bab8c1b174 100644 --- a/topi/python/topi/rasp/depthwise_conv2d.py +++ b/topi/python/topi/rasp/depthwise_conv2d.py @@ -1,25 +1,147 @@ # pylint: disable=invalid-name,unused-variable """Schedule for depthwise_conv2d with auto fusion""" +from __future__ import absolute_import as _abs +from collections import namedtuple import tvm from .. import tag +from ..nn.util import infer_pad, infer_stride, get_pad_tuple + + +_Workload = namedtuple('Workload', + ['height', 'width', 'channel', 'multiplier', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + +_Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll']) + +# workloads of depthwise conv mobile net on imagenet +_WORKLOADS = [ + _Workload(112, 112, 32, 1, 3, 3, 1, 1, 1, 1), + _Workload(112, 112, 64, 1, 3, 3, 1, 1, 2, 2), + _Workload( 56, 56, 128, 1, 3, 3, 1, 1, 1, 1), + _Workload( 56, 56, 128, 1, 3, 3, 1, 1, 2, 2), + _Workload( 28, 28, 256, 1, 3, 3, 1, 1, 1, 1), + _Workload( 28, 28, 256, 1, 3, 3, 1, 1, 2, 2), + _Workload( 14, 14, 512, 1, 3, 3, 1, 1, 1, 1), + _Workload( 14, 14, 512, 1, 3, 3, 1, 1, 2, 2), + _Workload( 14, 14, 1024, 1, 3, 3, 1, 1, 1, 1), +] + +_SCHEDULES = [ + _Schedule(2, 1, 4, 1, True), + _Schedule(2, 4, 4, 2, True), + _Schedule(2, 1, 4, 2, False), + _Schedule(2, 4, 4, 1, True), + _Schedule(4, 1, 4, 8, True), + _Schedule(1, 1, 4, 2, True), + _Schedule(1, 1, 8, 8, True), + _Schedule(1, 1, 4, 1, False), + _Schedule(2, 1, 4, 16, False), +] + +def _get_workload(data, kernel, stride, padding): + _, C, IH, IW = [x.value for x in data.shape] + _, MT, KH, KW = [x.value for x in kernel.shape] + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + if isinstance(stride, (tuple, list)): + HSTR, WSTR = stride + else: + HSTR, WSTR = stride, stride + return _Workload(IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR) + def _schedule(s, data, data_pad, kernel, output, last): + padding = infer_pad(data, data_pad) + if data_pad is None: + stride = infer_stride(data, kernel, output) + else: + stride = infer_stride(data_pad, kernel, output) + wkl = _get_workload(data, kernel, stride, padding) + + if wkl not in _WORKLOADS: + return s + + # use specified schedule + sch = _SCHEDULES[_WORKLOADS.index(wkl)] + + H, W = wkl.height, wkl.width + CN = wkl.channel + MT = wkl.multiplier + + HK, WK = wkl.hkernel, wkl.wkernel + HPAD, WPAD = wkl.hpad, wkl.wpad + HSTR, WSTR = wkl.hstride, wkl.wstride + + VH, VW = sch.vh, sch.vw + BC = sch.bc + VC = sch.vc + + TH = H + 2*HPAD + TW = W + 2*WPAD + OH = (H + 2*HPAD - HK) / HSTR + 1 + OW = (W + 2*WPAD - WK) / WSTR + 1 + + A, B, C = data, kernel, output A0 = data_pad - C0 = last + + A1 = s.cache_read(A0, "global", C) + _, c, h, w = s[A1].op.axis + c, vc = s[A1].split(c, VC) + s[A1].reorder(c, h, w, vc) + + A2 = s.cache_write(A1, 'global') + s[A0].compute_inline() + s[A1].compute_inline() + + B0 = s.cache_read(B, "global", C) + c, m, h, w = s[B0].op.axis + c, vc = s[B0].split(c, VC) + s[B0].reorder(c, m, h, w, vc) + + B1 = s.cache_write(B0, 'global') + s[B0].compute_inline() _, c, h, w = s[C].op.axis - dh, dw = s[C].op.reduce_axis + c, vc = s[C].split(c, VC) + s[C].reorder(c, h, w, vc) + + + C0 = s.cache_write(C, 'global') + _, c, h, w, vc = s[C0].op.axis + dh, dw = s[C0].op.reduce_axis + oh, ow, ih, iw = s[C0].tile(h, w, VH, VW) + s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc) + if sch.unroll: + s[C0].unroll(iw) + s[C0].vectorize(vc) - oh, ow, ih, iw = s[C].tile(h, w, 2, 4) - s[C].reorder(oh, ow, dh, dw, ih, iw) - s[C].unroll(ih) - s[C].vectorize(iw) + + # # s[C0].compute_at(s[C0], ow) + launch, c, _, _ = s[C].op.axis + s[C].pragma(launch, "parallel_launch_point") s[C].parallel(c) - s[C].pragma(c, "parallel_launch_point") s[C].pragma(c, "parallel_stride_pattern") s[C].pragma(c, "parallel_barrier_when_finish") + + + s[C0].compute_at(s[C], launch) + _, c, h, w, vc = s[C0].op.axis + s[C0].parallel(c) + s[C0].pragma(c, "parallel_stride_pattern") + s[C0].pragma(c, "parallel_barrier_when_finish") + + + s[A2].compute_at(s[C0], oh) + # parallel(s[A2], s[A2].op.axis[1], BC) + + # # s[B0].compute_at(s[C0], ow) + s[B1].compute_at(s[C], launch) + c, m, h, w, vc = s[B1].op.axis + s[B1].parallel(c) + s[B1].pragma(c, "parallel_stride_pattern") + s[B1].pragma(c, "parallel_barrier_when_finish") + return s