Skip to content

Commit

Permalink
[TOPI] Update depthwise conv2d schedule on rasp (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored and tqchen committed Sep 30, 2017
1 parent 9e7a667 commit a7f0125
Showing 1 changed file with 129 additions and 7 deletions.
136 changes: 129 additions & 7 deletions topi/python/topi/rasp/depthwise_conv2d.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,147 @@
# pylint: disable=invalid-name,unused-variable
"""Schedule for depthwise_conv2d with auto fusion"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple


# One depthwise conv2d problem instance. The field order is the positional
# order used when `_get_workload` constructs a workload.
_Workload = namedtuple(
    'Workload',
    'height width channel multiplier '
    'hkernel wkernel hpad wpad hstride wstride')

# Tuning knobs selected for a workload (see `_SCHEDULES`).
_Schedule = namedtuple('Schedule', 'vh vw vc bc unroll')

# Depthwise conv workloads of MobileNet on ImageNet.
# Columns: height, width, channel, multiplier,
#          hkernel, wkernel, hpad, wpad, hstride, wstride
# NOTE(review): the last row pairs 1024 channels with a 14x14 input; in the
# published MobileNet architecture that layer appears to run on 7x7 --
# confirm the intended shape.
_WORKLOADS = list(map(_Workload._make, (
    (112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
    (112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
    (56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
    (56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
    (28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
    (28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
    (14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
    (14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
    (14, 14, 1024, 1, 3, 3, 1, 1, 1, 1),
)))

# Index-aligned with `_WORKLOADS`: entry i holds the knobs for workload i.
# Columns: vh, vw, vc, bc, unroll
_SCHEDULES = list(map(_Schedule._make, (
    (2, 1, 4, 1, True),
    (2, 4, 4, 2, True),
    (2, 1, 4, 2, False),
    (2, 4, 4, 1, True),
    (4, 1, 4, 8, True),
    (1, 1, 4, 2, True),
    (1, 1, 8, 8, True),
    (1, 1, 4, 1, False),
    (2, 1, 4, 16, False),
)))

def _get_workload(data, kernel, stride, padding):
    """Build the `_Workload` record describing a depthwise conv2d call.

    Parameters
    ----------
    data : tensor-like
        Object with a 4-element ``shape`` whose entries expose ``.value``.
        The first entry is ignored; the remaining three fill the
        ``channel``, ``height`` and ``width`` fields, in that order.
    kernel : tensor-like
        Object with a 4-element ``shape`` whose entries expose ``.value``.
        The first entry is ignored; the remaining three fill the
        ``multiplier``, ``hkernel`` and ``wkernel`` fields, in that order.
    stride : int or tuple/list of two
        A tuple/list is unpacked as (vertical, horizontal); any other value
        is used for both directions.
    padding
        Forwarded unchanged to ``get_pad_tuple`` together with ``kernel``;
        only the first two values it returns are kept.

    Returns
    -------
    _Workload
        The populated workload tuple.
    """
    _, channel, in_height, in_width = [dim.value for dim in data.shape]
    _, multiplier, kernel_h, kernel_w = [dim.value for dim in kernel.shape]

    pad_h, pad_w, _, _ = get_pad_tuple(padding, kernel)

    # Promote a scalar stride to a (vertical, horizontal) pair.
    if not isinstance(stride, (tuple, list)):
        stride = (stride, stride)
    stride_h, stride_w = stride

    return _Workload(
        height=in_height,
        width=in_width,
        channel=channel,
        multiplier=multiplier,
        hkernel=kernel_h,
        wkernel=kernel_w,
        hpad=pad_h,
        wpad=pad_w,
        hstride=stride_h,
        wstride=stride_w,
    )


def _schedule(s, data, data_pad, kernel, output, last):
    """Apply a pre-tuned schedule to a depthwise conv2d computation.

    The padding and stride are inferred with ``infer_pad`` / ``infer_stride``
    and combined into a workload via ``_get_workload``.  If that workload is
    not listed in ``_WORKLOADS`` the schedule is returned untouched;
    otherwise the entry at the same index in ``_SCHEDULES`` supplies the
    split/tile factors and the unroll flag used below.

    Parameters
    ----------
    s : schedule
        Schedule object; it is mutated in place and also returned.
    data : tensor
        Input of the convolution (used for shape/padding inference only).
    data_pad : tensor or None
        Padded input.  ``None`` is tolerated while inferring the stride.
        NOTE(review): the tuned path below calls ``s.cache_read`` and
        ``s[...]`` on ``data_pad`` unconditionally -- confirm a ``None``
        ``data_pad`` can never reach it with a workload from the table.
    kernel : tensor
        Depthwise kernel.
    output : tensor
        Tensor produced by the depthwise convolution; the stage that gets
        split, tiled and parallelised.
    last : tensor
        Not read by this function; kept so the signature stays compatible
        with existing callers.

    Returns
    -------
    schedule
        The same ``s`` that was passed in.

    Notes
    -----
    Only ``vh``, ``vw``, ``vc`` and ``unroll`` of the selected ``_Schedule``
    are consumed here; ``bc`` is not used.
    """
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding)

    if wkl not in _WORKLOADS:
        # No tuned entry for this shape: leave the schedule as it is.
        return s

    # _SCHEDULES is index-aligned with _WORKLOADS.
    sch = _SCHEDULES[_WORKLOADS.index(wkl)]
    VH, VW, VC = sch.vh, sch.vw, sch.vc

    B, C = kernel, output
    A0 = data_pad

    # --- padded input: cached read for C, channel axis split by VC ---------
    A1 = s.cache_read(A0, "global", C)
    _, c, h, w = s[A1].op.axis
    c, vc = s[A1].split(c, VC)
    s[A1].reorder(c, h, w, vc)

    A2 = s.cache_write(A1, 'global')
    s[A0].compute_inline()
    s[A1].compute_inline()

    # --- kernel: same treatment as the input --------------------------------
    B0 = s.cache_read(B, "global", C)
    c, m, h, w = s[B0].op.axis
    c, vc = s[B0].split(c, VC)
    s[B0].reorder(c, m, h, w, vc)

    B1 = s.cache_write(B0, 'global')
    s[B0].compute_inline()

    # --- output stage: split channels before creating its write cache -------
    _, c, h, w = s[C].op.axis
    dh, dw = s[C].op.reduce_axis
    c, vc = s[C].split(c, VC)
    s[C].reorder(c, h, w, vc)

    # The five-way unpack below relies on the cache stage exposing the split
    # channel axis as a separate trailing axis.
    C0 = s.cache_write(C, 'global')
    _, c, h, w, vc = s[C0].op.axis
    dh, dw = s[C0].op.reduce_axis
    oh, ow, ih, iw = s[C0].tile(h, w, VH, VW)
    s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
    if sch.unroll:
        s[C0].unroll(iw)
    s[C0].vectorize(vc)

    # NOTE(review): h, w (and dh, dw) were last bound from s[C0], not s[C],
    # and the tile factors are the literals 2 and 4 rather than VH/VW --
    # confirm both are intended.
    oh, ow, ih, iw = s[C].tile(h, w, 2, 4)
    s[C].reorder(oh, ow, dh, dw, ih, iw)
    s[C].unroll(ih)
    s[C].vectorize(iw)

    # The first axis of C is tagged as the launch point; the second axis is
    # parallelised and tagged with the same pragma plus two more.
    launch, c, _, _ = s[C].op.axis
    s[C].pragma(launch, "parallel_launch_point")

    s[C].parallel(c)
    s[C].pragma(c, "parallel_launch_point")
    s[C].pragma(c, "parallel_stride_pattern")
    s[C].pragma(c, "parallel_barrier_when_finish")

    # Nest the write cache under the launch axis and parallelise its channels.
    s[C0].compute_at(s[C], launch)
    _, c, h, w, vc = s[C0].op.axis
    s[C0].parallel(c)
    s[C0].pragma(c, "parallel_stride_pattern")
    s[C0].pragma(c, "parallel_barrier_when_finish")

    # NOTE(review): `oh` here is the outer axis returned by s[C].tile above,
    # not the one from s[C0].tile -- confirm it is the intended attach point.
    s[A2].compute_at(s[C0], oh)

    # Kernel cache is likewise nested under the launch axis.
    s[B1].compute_at(s[C], launch)
    c, m, h, w, vc = s[B1].op.axis
    s[B1].parallel(c)
    s[B1].pragma(c, "parallel_stride_pattern")
    s[B1].pragma(c, "parallel_barrier_when_finish")

    return s


Expand Down

0 comments on commit a7f0125

Please sign in to comment.