From d1751044bdfeea38753605251c90fbaeb4a02aff Mon Sep 17 00:00:00 2001 From: Cody Hao Yu Date: Wed, 18 Sep 2019 04:24:31 +0000 Subject: [PATCH] Fix bias supprot. Add testcase --- topi/python/topi/cuda/dense.py | 18 +++++++++++------- topi/tests/python/test_topi_dense.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/topi/python/topi/cuda/dense.py b/topi/python/topi/cuda/dense.py index 2cada84f79042..5136ee694783d 100644 --- a/topi/python/topi/cuda/dense.py +++ b/topi/python/topi/cuda/dense.py @@ -105,7 +105,7 @@ def _schedule(C): batch, _ = get_const_tuple(A.shape) if batch < 32: return schedule_dense_small_batch(cfg, s, C, outs) - return schedule_dense_large_batch(cfg, s, C) + return schedule_dense_large_batch(cfg, s, C, outs) scheduled_ops = [] @@ -157,7 +157,7 @@ def schedule_dense_small_batch(cfg, s, C, outs): s[C].set_store_predicate(thread_x.var.equal(0)) s[Out].set_store_predicate(thread_x.var.equal(0)) -def schedule_dense_large_batch(cfg, s, C): +def schedule_dense_large_batch(cfg, s, C, outs): """Schedule float32/64 dense with large batch size""" A, B = C.op.input_tensors batch, in_dim = get_const_tuple(A.shape) @@ -201,21 +201,20 @@ def schedule_dense_large_batch(cfg, s, C): else: cfg['tile_k'] = SplitEntity([-1, 1, 1]) - # scheduling template - # memory access + # Explicit memory access AA = s.cache_read(A, "shared", [C]) BB = s.cache_read(B, "shared", [C]) AL = s.cache_read(AA, "local", [C]) BL = s.cache_read(BB, "local", [C]) CC = s.cache_write(C, "local") - # split and reorder computation + # Split and reorder computation bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0]) by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1]) s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi) s[CC].compute_at(s[C], tx) - # binding + # Binding s[C].bind(by, tvm.thread_axis("blockIdx.y")) s[C].bind(bx, tvm.thread_axis("blockIdx.x")) s[C].bind(tyz, tvm.thread_axis("vthread")) @@ -223,7 +222,7 @@ def schedule_dense_large_batch(cfg, s, C): s[C].bind(ty, tvm.thread_axis("threadIdx.y")) s[C].bind(tx, tvm.thread_axis("threadIdx.x")) - # split reduction + # Split reduction yo, xo = CC.op.axis ko, kt, ki = cfg['tile_k'].apply(s, CC, k) s[CC].reorder(ko, kt, ki, yo, xo) @@ -251,6 +250,11 @@ def schedule_dense_large_batch(cfg, s, C): s[BB].bind(tx, tvm.thread_axis("threadIdx.x")) s[BB].double_buffer() + # Deal with bias + if C.op not in s.outputs: + Out = outs[0].op.output(0) + s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y")) + s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x")) @autotvm.register_topi_compute(dense, ['cuda'], ['int8']) def dense_int8(cfg, data, weight, bias=None, out_dtype=None): diff --git a/topi/tests/python/test_topi_dense.py b/topi/tests/python/test_topi_dense.py index 412eb30501bda..3b747712a173a 100644 --- a/topi/tests/python/test_topi_dense.py +++ b/topi/tests/python/test_topi_dense.py @@ -117,8 +117,9 @@ def check_device(device): def test_dense(): verify_dense(1, 1024, 1000, use_bias=True) verify_dense(1, 1024, 1000, use_bias=False) - verify_dense(2, 1024, 1000, use_bias=True) + verify_dense(128, 1024, 1000, use_bias=False) + verify_dense(128, 1024, 1000, use_bias=True) def test_dense_int8():