From ad8d6bc9cd10c867da37ca533d583c2e5276737e Mon Sep 17 00:00:00 2001 From: "Tang, Shizhi" Date: Mon, 6 Apr 2020 23:43:38 +0800 Subject: [PATCH] fix lower_warp_memory (#5247) --- src/tir/transforms/lower_warp_memory.cc | 6 +-- .../test_tir_transform_lower_warp_memory.py | 51 ++++++++++++++++++- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 0361100f1f57..1921db53cb06 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -219,13 +219,13 @@ class WarpAccessRewriter : protected StmtExprMutator { } protected: - PrimExpr Mutate_(const VarNode* op) { + PrimExpr VisitExpr_(const VarNode* op) override { CHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } - Stmt VisitStmt_(const StoreNode* op) { + Stmt VisitStmt_(const StoreNode* op) override { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); @@ -235,7 +235,7 @@ class WarpAccessRewriter : protected StmtExprMutator { } } - PrimExpr Mutate_(const LoadNode* op) { + PrimExpr VisitExpr_(const LoadNode* op) override { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index cf6ef721fcc5..25204eb1d906 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -16,8 +16,11 @@ # under the License. import tvm from tvm import te +from tvm.contrib.nvcc import have_fp16 -def test_lower_warp_mem(): +import numpy as np + +def test_lower_warp_memory_local_scope(): m = 128 A = te.placeholder((m,), name='A') B = te.compute((m,), lambda i: A[i] + 3, name='B') @@ -44,6 +47,50 @@ def test_lower_warp_mem(): assert(fdevice.body.body.value.value == "local") assert(fdevice.body.body.body.extents[0].value == 2) +def test_lower_warp_memory_cuda_end_to_end(): + def check_cuda(dtype): + if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): + print("Skip because gpu does not have fp16 support") + return + + m = 128 + A = te.placeholder((m,), name='A', dtype=dtype) + B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B') + + cuda_target = tvm.target.create("cuda") + assert cuda_target.thread_warp_size == 32 + with cuda_target: + s = te.create_schedule(B.op) + AA = s.cache_read(A, "warp", [B]) + xo, xi = s[B].split(B.op.axis[0], 64) + xi0, xi1 = s[B].split(xi, factor=32) + tx = te.thread_axis("threadIdx.x") + s[B].bind(xi1, tx) + s[B].bind(xo, te.thread_axis("blockIdx.x")) + s[AA].compute_at(s[B], xo) + xo, xi = s[AA].split(s[AA].op.axis[0], 32) + s[AA].bind(xi, tx) + + ctx = tvm.gpu(0) + func = tvm.build(s, [A, B], "cuda") + A_np = np.array(list(range(m)), dtype=dtype) + B_np = np.array( + list(range(1, 32)) + [0] + + list(range(33, 64)) + [32] + + list(range(65, 96)) + [64] + + list(range(97, 128)) + [96], + dtype=dtype) + A_nd = tvm.nd.array(A_np, ctx) + B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx) + func(A_nd, B_nd) + tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3) + + check_cuda("float32") + check_cuda("float16") if __name__ == "__main__": - test_lower_warp_mem() + test_lower_warp_memory_local_scope() + test_lower_warp_memory_cuda_end_to_end()