Skip to content

Commit

Permalink
Fix bias support. Add testcase
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Sep 18, 2019
1 parent 5ab19d4 commit d175104
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
18 changes: 11 additions & 7 deletions topi/python/topi/cuda/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _schedule(C):
batch, _ = get_const_tuple(A.shape)
if batch < 32:
return schedule_dense_small_batch(cfg, s, C, outs)
return schedule_dense_large_batch(cfg, s, C)
return schedule_dense_large_batch(cfg, s, C, outs)

scheduled_ops = []

Expand Down Expand Up @@ -157,7 +157,7 @@ def schedule_dense_small_batch(cfg, s, C, outs):
s[C].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))

def schedule_dense_large_batch(cfg, s, C):
def schedule_dense_large_batch(cfg, s, C, outs):
"""Schedule float32/64 dense with large batch size"""
A, B = C.op.input_tensors
batch, in_dim = get_const_tuple(A.shape)
Expand Down Expand Up @@ -201,29 +201,28 @@ def schedule_dense_large_batch(cfg, s, C):
else:
cfg['tile_k'] = SplitEntity([-1, 1, 1])

# scheduling template
# memory access
# Explicit memory access
AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")

# split and reorder computation
# Split and reorder computation
bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
s[CC].compute_at(s[C], tx)

# binding
# Binding
s[C].bind(by, tvm.thread_axis("blockIdx.y"))
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tyz, tvm.thread_axis("vthread"))
s[C].bind(txz, tvm.thread_axis("vthread"))
s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

# split reduction
# Split reduction
yo, xo = CC.op.axis
ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
s[CC].reorder(ko, kt, ki, yo, xo)
Expand Down Expand Up @@ -251,6 +250,11 @@ def schedule_dense_large_batch(cfg, s, C):
s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
s[BB].double_buffer()

# Deal with bias
if C.op not in s.outputs:
Out = outs[0].op.output(0)
s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))

@autotvm.register_topi_compute(dense, ['cuda'], ['int8'])
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
Expand Down
3 changes: 2 additions & 1 deletion topi/tests/python/test_topi_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ def check_device(device):
def test_dense():
verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False)

verify_dense(2, 1024, 1000, use_bias=True)
verify_dense(128, 1024, 1000, use_bias=False)
verify_dense(128, 1024, 1000, use_bias=True)


def test_dense_int8():
Expand Down

0 comments on commit d175104

Please sign in to comment.