From 3b873226b777365f24b4ab25175244ffa19fade5 Mon Sep 17 00:00:00 2001
From: Shizhi Tang
Date: Sat, 18 Apr 2020 15:56:27 +0800
Subject: [PATCH 1/2] fix recursion in lower_warp_memory

---
 src/tir/transforms/lower_warp_memory.cc      |  2 +-
 .../test_tir_transform_lower_warp_memory.py  | 49 +++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 71e7cfaf4832..543d098a654c 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -379,7 +379,7 @@ class WarpMemoryRewriter : private StmtMutator {
   Stmt VisitStmt_(const AllocateNode* op) {
     if (warp_buffer_.count(op->buffer_var.get())) {
       WarpAccessRewriter rewriter(warp_size_, &analyzer_);
-      return rewriter.Rewrite(op);
+      return StmtMutator::VisitStmt_(rewriter.Rewrite(op).as<AllocateNode>());
     } else {
       return StmtMutator::VisitStmt_(op);
     }
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index a761cf1a95d8..51be480a7cba 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -132,7 +132,56 @@ def check_cuda(dtype):
     check_cuda("float32")
     check_cuda("float16")
 
+def test_lower_warp_memory_cuda_2_buffers():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 32
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.placeholder((m,), name='B', dtype=dtype)
+        C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name='C')
+
+        cuda_target = tvm.target.create("cuda")
+        assert m <= cuda_target.thread_warp_size
+        with cuda_target:
+            s = te.create_schedule(C.op)
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+
+            AA = s.cache_read(A, "warp", [C])
+            BB = s.cache_read(B, "warp", [C])
+            xo, xi = s[C].split(C.op.axis[0], nparts=1)
+            s[C].bind(xi, tx)
+            s[C].bind(xo, bx)
+            s[AA].compute_at(s[C], xo)
+            s[BB].compute_at(s[C], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
+            s[AA].bind(xo, bx)
+            s[AA].bind(xi, tx)
+            xo, xi = s[BB].split(s[BB].op.axis[0], nparts=1)
+            s[BB].bind(xo, bx)
+            s[BB].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B, C], "cuda")
+            AB_np = np.array(list(range(m)), dtype=dtype)
+            C_np = np.array(list(range(1, m)) + [0], dtype=dtype) * 2
+            A_nd = tvm.nd.array(AB_np, ctx)
+            B_nd = tvm.nd.array(AB_np, ctx)
+            C_nd = tvm.nd.array(np.zeros(C_np.shape, dtype=C_np.dtype), ctx)
+            func(A_nd, B_nd, C_nd)
+            tvm.testing.assert_allclose(C_nd.asnumpy(), C_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
+    test_lower_warp_memory_cuda_2_buffers()
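Note on the first commit: the pre-fix VisitStmt_ returned rewriter.Rewrite(op) directly, so the body of the rewritten allocate was never visited, and a second warp buffer nested inside the first one was left unlowered. The new test_lower_warp_memory_cuda_2_buffers above builds exactly that two-buffer pattern. Below is a minimal, TVM-free sketch of the two traversal shapes; Alloc and rewrite_warp are hypothetical stand-ins for AllocateNode and WarpAccessRewriter::Rewrite, not TVM APIs.

    # Minimal sketch, not TVM code: only the traversal logic is the point.

    class Alloc:
        def __init__(self, name, scope, body=None):
            self.name, self.scope, self.body = name, scope, body

    def rewrite_warp(alloc):
        # Lowers a single "warp" allocation to "local" and leaves the body
        # untouched, mirroring that Rewrite() does not recurse on its own.
        return Alloc(alloc.name, "local", alloc.body)

    def visit_buggy(node):
        # Pre-fix shape: returning the rewrite result skips the children.
        if isinstance(node, Alloc):
            if node.scope == "warp":
                return rewrite_warp(node)  # body is never visited
            return Alloc(node.name, node.scope, visit_buggy(node.body))
        return node

    def visit_first_fix(node):
        # First commit's shape: rewrite, then keep visiting the children
        # of the rewritten allocate (StmtMutator::VisitStmt_ on the new node).
        if isinstance(node, Alloc):
            if node.scope == "warp":
                node = rewrite_warp(node)
            return Alloc(node.name, node.scope, visit_first_fix(node.body))
        return node

    # Two warp buffers, one nested in the other's scope:
    tree = Alloc("A.warp", "warp", Alloc("B.warp", "warp"))
    print(visit_buggy(tree).body.scope)      # warp  -- inner buffer missed
    print(visit_first_fix(tree).body.scope)  # local -- both buffers lowered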
From 0fdb6dc13fc4ca73d76323902f68d4b134fec9b7 Mon Sep 17 00:00:00 2001
From: Shizhi Tang
Date: Sun, 19 Apr 2020 08:54:21 +0800
Subject: [PATCH 2/2] post-order mutation

---
 src/tir/transforms/lower_warp_memory.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 543d098a654c..0aee3c284422 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -377,12 +377,13 @@ class WarpMemoryRewriter : private StmtMutator {
 
  private:
   Stmt VisitStmt_(const AllocateNode* op) {
+    auto ret = StmtMutator::VisitStmt_(op);
+    op = ret.as<AllocateNode>();
     if (warp_buffer_.count(op->buffer_var.get())) {
       WarpAccessRewriter rewriter(warp_size_, &analyzer_);
-      return StmtMutator::VisitStmt_(rewriter.Rewrite(op).as<AllocateNode>());
-    } else {
-      return StmtMutator::VisitStmt_(op);
+      ret = rewriter.Rewrite(op);
     }
+    return ret;
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) {
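Note on the second commit: it reshapes the visitor into a plain post-order mutation. StmtMutator::VisitStmt_(op) mutates the children first, the result is cast back to an AllocateNode, and only then is the warp rewrite applied to the node itself, leaving a single exit path and never re-dispatching on a node the rewriter has already replaced. A sketch of that shape using the same hypothetical Alloc/rewrite_warp stand-ins (again not TVM APIs):

    # Minimal sketch, not TVM code: post-order mutation of the toy tree.

    class Alloc:
        def __init__(self, name, scope, body=None):
            self.name, self.scope, self.body = name, scope, body

    def rewrite_warp(alloc):
        return Alloc(alloc.name, "local", alloc.body)  # body untouched

    def visit_post_order(node):
        # Second commit's shape: children first via the base-class visitor,
        # then rewrite the already-mutated allocate if it is a warp buffer.
        if isinstance(node, Alloc):
            ret = Alloc(node.name, node.scope, visit_post_order(node.body))
            if ret.scope == "warp":
                ret = rewrite_warp(ret)
            return ret
        return node

    tree = Alloc("A.warp", "warp", Alloc("B.warp", "warp"))
    print(visit_post_order(tree).body.scope)  # local -- both buffers lowered

Post-order also means an inner warp buffer is lowered before its enclosing one, so the outer rewrite always sees an already-lowered body.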