Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

conv2d data re-layout fix out of threads bug #514

Merged
merged 3 commits into from
Oct 6, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions topi/python/topi/cuda/conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
ofactor = 16
hfactor = 2
ow_size = util.get_const_int(Out.shape[3])
num_thread = ow_size*hfactor
vthread = hfactor
num_thread = ow_size * hfactor
vthread = ofactor
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")

i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=ofactor)
ooc, ioc = s[Out].split(oc, factor=vthread)
oh, ih = s[Out].split(h, factor=hfactor)
s[Out].reorder(ooc, oh, ioc, ih, w)
oc = s[Out].fuse(ooc, oh)
w = s[Out].fuse(w, ih)

s[Out].bind(w, thread_x)
s[Out].bind(ioc, thread_xz)
s[Out].bind(oc, block_x)
Expand Down Expand Up @@ -261,9 +260,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):

else:
# scheduler params
vthread_x = util.get_const_int(Out.shape[2])
vthread_x = min(8, util.get_const_int(Out.shape[2]))
num_thread_x = 16
num_thread_y = util.get_const_int(Out.shape[3])
num_thread_y = min(8, util.get_const_int(Out.shape[3]))
ofactor = 8
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
Expand All @@ -272,10 +271,12 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):

i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=num_thread_x)
s[Out].reorder(i, ooc, h, w, ioc)
oh, ih = s[Out].split(h, factor=vthread_x)
ow, iw = s[Out].split(w, factor=num_thread_y)
s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc)
s[Out].bind(ioc, thread_x)
s[Out].bind(w, thread_y)
s[Out].bind(h, thread_xz)
s[Out].bind(iw, thread_y)
s[Out].bind(ih, thread_xz)
s[Out].bind(ooc, block_x)

s[Out_L].compute_at(s[Out], ioc)
Expand All @@ -289,21 +290,19 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)

rfactor = util.get_const_int(Filter.shape[1])
thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x")
num_thread = 512
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

i, ic, h, w = s[temp].op.axis
ic = s[temp].fuse(ic, h, w)
oic, iic = s[temp].split(ic, factor=rfactor)
s[temp].bind(iic, thread_xx)
s[temp].bind(oic, block_xx)

i, h, w, oic, iic = s[temp_R].op.axis
ic = s[temp_R].fuse(oic, iic)
s[temp_R].bind(ic, thread_xx)
h = s[temp_R].fuse(h, w)
s[temp_R].bind(h, block_xx)
i = s[temp].fuse(*s[temp].op.axis)
bx, tx = s[temp].split(i, factor=num_thread)
s[temp].bind(tx, thread_xx)
s[temp].bind(bx, block_xx)

i = s[temp_R].fuse(*s[temp_R].op.axis)
bx, tx = s[temp_R].split(i, factor=num_thread)
s[temp_R].bind(tx, thread_xx)
s[temp_R].bind(bx, block_xx)

#schedule temp_S shared mem load
i, h, w, oc, ic = s[temp_S].op.axis
Expand Down