Skip to content

Commit

Permalink
[TE] Fix MakeLoopNest for warp memory (apache#5382)
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck authored May 9, 2020
1 parent aded92d commit 06efa7f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/te/operation/op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,21 @@ MakeLoopNest(const Stage& stage,
value_map[iv] = dom->min;
} else {
runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
if (stage->scope == "" || stage->scope == "warp" ||
if (stage->scope == "" ||
static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
value_map[iv] = var;
} else if (stage->scope == "warp" && ts.rank == 1) {
// To determine whether a thread index is inside or outside a warp, we need
// to know the thread extent. We leave a warning for now.
if (ts.dim_index == 0) {
value_map[iv] = var;
} else {
LOG(WARNING)
<< "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
<< "TVM assumes only threadIdx.x indicates threads inside a warp, "
<< "while threadIdx.y and threadIdx.z indicates different warps.";
value_map[iv] = dom->min;
}
} else {
value_map[iv] = dom->min;
}
Expand Down
37 changes: 37 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,42 @@ def test_lower_warp_memory_local_scope():
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)

def test_lower_warp_memory_correct_indices():
n = 32
A = te.placeholder((2, n, n), name='A', dtype="float32")
C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')

s = te.create_schedule(C.op)
bk_x = te.thread_axis("blockIdx.x")
th_y = te.thread_axis("threadIdx.y")
th_x = te.thread_axis("threadIdx.x")
B = s.cache_read(A, "warp", [C])
cx, ci, cj = C.op.axis
bx, bi, bj = B.op.axis
s[C].bind(cj, th_x)
s[C].bind(cx, bk_x)
s[B].compute_at(s[C], cx)
s[B].bind(bi, th_y)
s[B].bind(bj, th_x)

bounds = tvm.te.schedule.InferBound(s)
ir = tvm.te.schedule.ScheduleOps(s, bounds)
inner_func = ir.body.body.body.body
store_A_warp = inner_func.body.seq[0].body.body
indices = list(store_A_warp.args)

# A.warp is actually many buffers, one for each warp, although they are all called A.warp
# 1. If we are accessing from different threads within a same warp (different
# threadIdx.x), we need to distinguish between each elements using threadIdx.x,
# so threadIdx.x is one if the indices.
# 2. If we are accessing from different warps (different threadIdx.y), we are actually
# assessing different buffers, so there is no need to distinguish from elements,
# and therefore threadIdx.y is NOT a index.
idx_names = map(lambda x: x.name,
filter(lambda x: type(x) is tvm.tir.expr.Var, indices))
assert "threadIdx.x" in idx_names
assert "threadIdx.y" not in idx_names

def test_lower_warp_memory_cuda_end_to_end():
def check_cuda(dtype):
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
Expand Down Expand Up @@ -182,6 +218,7 @@ def check_cuda(dtype):

if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_correct_indices()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
test_lower_warp_memory_cuda_2_buffers()

0 comments on commit 06efa7f

Please sign in to comment.