Skip to content

Commit

Permalink
Add vectorization to cuda conv2d_nhwc schedule
Browse files Browse the repository at this point in the history
Adding vectorization significantly improved performance: roughly a
6-7x boost.
  • Loading branch information
echuraev committed Aug 4, 2021
1 parent d38bef5 commit 1a931a1
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/tvm/topi/cuda/conv2d_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2])
cfg.define_knob("vthread_c", [1, 2])
cfg.define_knob("step", [16, 3, 32, 64])
cfg.define_knob("vectorize", [4, 2, 8, 16])

# fallback support
target = tvm.target.Target.current()
Expand All @@ -70,6 +71,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
vthread_n = cfg["vthread_n"].val
vthread_c = cfg["vthread_c"].val
step = cfg["step"].val
vec_factor = cfg["vectorize"].val
block_factor_c = tile_c * num_thread_c * vthread_c

offset = 8
Expand All @@ -85,8 +87,10 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy")

# Schedule for output
ni, hi, wi, fi = s[output].op.axis
bz = s[output].fuse(hi, wi)
ni, _, wi, fi = s[output].op.axis
bz = wi
fi, vec = s[output].split(fi, factor=vec_factor)
s[output].vectorize(vec)
tx, fi = s[output].split(fi, factor=tile_c)
txz, tx = s[output].split(tx, factor=num_thread_c)
bx, txz = s[output].split(txz, factor=vthread_c)
Expand Down Expand Up @@ -125,6 +129,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
_, _, ic, o = s[WW].op.axis
t = s[WW].fuse(ic, o)
s[WW].storage_align(ic, W_align - 1, W_align)
t, vec = s[WW].split(t, factor=vec_factor)
s[WW].vectorize(vec)
ty, tx = s[WW].split(t, factor=num_thread_c)
_, ty = s[WW].split(ty, factor=num_thread_n)
s[WW].bind(tx, thread_x)
Expand Down

0 comments on commit 1a931a1

Please sign in to comment.