From 215f656d757163b9c729a9b85bfcf82fa03ffabc Mon Sep 17 00:00:00 2001
From: "Tang, Shizhi"
Date: Sat, 9 May 2020 08:52:30 +0800
Subject: [PATCH] [TE] Fix MakeLoopNest for warp memory (#5382)

---
 src/te/operation/op_util.cc                     | 14 ++++++-
 .../test_tir_transform_lower_warp_memory.py     | 37 +++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index f7e0e51fd16a..bee573e854b9 100644
--- a/src/te/operation/op_util.cc
+++ b/src/te/operation/op_util.cc
@@ -163,9 +163,21 @@ MakeLoopNest(const Stage& stage,
       value_map[iv] = dom->min;
     } else {
       runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
-      if (stage->scope == "" || stage->scope == "warp" ||
+      if (stage->scope == "" ||
           static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
         value_map[iv] = var;
+      } else if (stage->scope == "warp" && ts.rank == 1) {
+        // To determine whether a thread index is inside or outside a warp, we need
+        // to know the thread extent. We emit a warning for now.
+        if (ts.dim_index == 0) {
+          value_map[iv] = var;
+        } else {
+          LOG(WARNING)
+              << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
+              << "TVM assumes only threadIdx.x indicates threads inside a warp, "
+              << "while threadIdx.y and threadIdx.z indicate different warps.";
+          value_map[iv] = dom->min;
+        }
       } else {
         value_map[iv] = dom->min;
       }
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 51be480a7cba..bd553772e087 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -47,6 +47,42 @@ def test_lower_warp_memory_local_scope():
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
 
+def test_lower_warp_memory_correct_indices():
+    n = 32
+    A = te.placeholder((2, n, n), name='A', dtype="float32")
+    C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')
+
+    s = te.create_schedule(C.op)
+    bk_x = te.thread_axis("blockIdx.x")
+    th_y = te.thread_axis("threadIdx.y")
+    th_x = te.thread_axis("threadIdx.x")
+    B = s.cache_read(A, "warp", [C])
+    cx, ci, cj = C.op.axis
+    bx, bi, bj = B.op.axis
+    s[C].bind(cj, th_x)
+    s[C].bind(cx, bk_x)
+    s[B].compute_at(s[C], cx)
+    s[B].bind(bi, th_y)
+    s[B].bind(bj, th_x)
+
+    bounds = tvm.te.schedule.InferBound(s)
+    ir = tvm.te.schedule.ScheduleOps(s, bounds)
+    inner_func = ir.body.body.body.body
+    store_A_warp = inner_func.body.seq[0].body.body
+    indices = list(store_A_warp.args)
+
+    # A.warp is actually many buffers, one for each warp, although they are all named A.warp.
+    # 1. If we are accessing from different threads within the same warp (different
+    #    threadIdx.x), we need to distinguish between elements using threadIdx.x,
+    #    so threadIdx.x is one of the indices.
+    # 2. If we are accessing from different warps (different threadIdx.y), we are actually
+    #    accessing different buffers, so there is no need to distinguish between elements,
+    #    and therefore threadIdx.y is NOT an index.
+    idx_names = list(map(lambda x: x.name,
+        filter(lambda x: type(x) is tvm.tir.expr.Var, indices)))
+    assert "threadIdx.x" in idx_names
+    assert "threadIdx.y" not in idx_names
+
 def test_lower_warp_memory_cuda_end_to_end():
     def check_cuda(dtype):
         if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
@@ -182,6 +218,7 @@ def check_cuda(dtype):
 
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
+    test_lower_warp_memory_correct_indices()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
    test_lower_warp_memory_cuda_2_buffers()
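
Note (not part of the patch): a minimal standalone sketch of the schedule exercised by
test_lower_warp_memory_correct_indices, for readers who want to reproduce the inspected IR
outside the test harness. It assumes a TVM build contemporary with this patch and reuses the
same te APIs as the test; the printed statement should show the A.warp store indexed by
threadIdx.x (lane within a warp) but not by threadIdx.y (warp id).

    import tvm
    from tvm import te

    # Same warp-scope cache_read schedule as in the new test: threadIdx.y picks
    # the warp, threadIdx.x picks the lane within a warp.
    n = 32
    A = te.placeholder((2, n, n), name='A', dtype="float32")
    C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')

    s = te.create_schedule(C.op)
    bk_x = te.thread_axis("blockIdx.x")
    th_y = te.thread_axis("threadIdx.y")
    th_x = te.thread_axis("threadIdx.x")
    B = s.cache_read(A, "warp", [C])
    cx, ci, cj = C.op.axis
    bx, bi, bj = B.op.axis
    s[C].bind(cj, th_x)
    s[C].bind(cx, bk_x)
    s[B].compute_at(s[C], cx)
    s[B].bind(bi, th_y)
    s[B].bind(bj, th_x)

    # Inspect the scheduled statement the same way the test does; the store into
    # A.warp should use threadIdx.x but not threadIdx.y in its indices.
    bounds = tvm.te.schedule.InferBound(s)
    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
    print(stmt)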