[Arith][GPU] Rewrite simplify fix for Vectorized Cooperative Fetching (a…
jcf94 authored and Trevor Morris committed Jun 30, 2020
1 parent d2d88b8 commit 4187057
Showing 3 changed files with 241 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/arith/rewrite_simplify.cc
@@ -722,8 +722,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
return broadcast(floordiv(b1, c2), lanes).Eval();
if (ramp_min == ramp_max) {
// If c2 divides the coefficient of b1's modular set
if (bmod->coeff % c2val == 0) {
return broadcast(floordiv(b1, c2), lanes).Eval();
}
// If all indices are guaranteed to fall within a single coeff span
if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {
return broadcast(floordiv(b1, c2), lanes).Eval();
}
}
}
}
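For intuition, here is a minimal plain-Python sketch (not part of the commit; the helper name rewrite_fires is hypothetical) of the guard this hunk implements for floordiv(Ramp(b1, c1, lanes), Broadcast(c2, lanes)), assuming the modular-set analysis has proved b1 = coeff * q + base:

def rewrite_fires(base, coeff, c1, c2, lanes):
    # All lanes must land in the same division bucket for the known base.
    ramp_min = base // c2
    ramp_max = (base + (lanes - 1) * c1) // c2
    if ramp_min != ramp_max:
        return False
    # Branch 1: c2 divides b1's modular coefficient.
    # Branch 2: c2 is a multiple of the coefficient and the whole lane
    # window fits inside a single coeff span.
    return coeff % c2 == 0 or (c2 % coeff == 0 and base + (lanes - 1) * c1 < coeff)

# Mirrors the new test fld(Ramp(x * 4, 1, 4), Broadcast(64, 4)) ->
# Broadcast(fld(x, 16), 4): b1 = x * 4 gives coeff = 4, base = 0.
assert rewrite_fires(base=0, coeff=4, c1=1, c2=64, lanes=4)
for q in range(256):  # brute-force the claim for b1 = 4 * q
    assert all((4 * q + i) // 64 == (4 * q) // 64 for i in range(4))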
@@ -847,6 +854,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
} else {
return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
}
} else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) {
return ramp(floormod(b1, c2), c1, lanes).Eval();
}
}
}
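Likewise for the new FloorMod branch: when c2 is a multiple of b1's modular coefficient and all lanes stay in one division bucket, the modulo distributes over the ramp. A minimal plain-Python check (not part of the commit), mirroring the new test flm(Ramp(x * 4, 1, 4), Broadcast(64, 4)) -> Ramp(flm(x * 4, 64), 1, 4):

# b1 = x * 4 gives coeff = 4, base = 0, and 64 % 4 == 0, so the four
# lanes b1 .. b1 + 3 can never wrap past a multiple of 64.
for q in range(256):
    b1 = 4 * q
    assert [(b1 + i) % 64 for i in range(4)] == [b1 % 64 + i for i in range(4)]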
46 changes: 46 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
@@ -38,6 +38,8 @@ def test_vector_simplify():
tvm.tir.Ramp(y + x, 1, 2))
ck.verify(y.astype("int32x2") + x.astype("int32x2"),
(y + x).astype("int32x2"))
ck.verify(tvm.tir.Broadcast(0, 4) + y,
tvm.tir.Broadcast(y, 4))
# Sub rules
ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4),
tvm.tir.Ramp(x - y, 2, 4))
@@ -55,6 +57,8 @@ def test_vector_simplify():
tvm.tir.Ramp(x * 2, 8, 4))
ck.verify(2 * tvm.tir.Ramp(x, 4, 4),
tvm.tir.Ramp(x * 2, 8, 4))
ck.verify(tvm.tir.Broadcast(0, 4) * x,
tvm.tir.Broadcast(0, 4))

## DivMod rules
tdiv = tvm.tir.truncdiv
@@ -69,6 +73,7 @@
(x).astype("int32x4"))
ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
# trunc mod
ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")),
tmod(y, x).astype("int32x2"))
ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2),
@@ -90,6 +95,27 @@ def test_vector_simplify():
(x).astype("int32x4"))
ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)),
tvm.tir.Ramp(fld(x, 4), 2, 5))
ck.verify(fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Broadcast(x * 2, 4))
ck.verify(fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Broadcast(fld(x, 16), 4))
ck.verify(fld(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Broadcast(fld(x, 8), 4))
ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)))
ck.verify(fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)))
# floor mod
ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")),
flm(y, x).astype("int32x2"))
ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2),
@@ -98,6 +124,26 @@ def test_vector_simplify():
tvm.tir.Ramp(1, 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8),
flm(tvm.tir.Ramp(1, 15, 4), 8))
ck.verify(flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Broadcast(flm(x, 4), 4))
ck.verify(flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Ramp(0, 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 5), tvm.tir.Broadcast(4, 5)),
flm(tvm.tir.Ramp(0, 1, 5), tvm.tir.Broadcast(4, 5)))
ck.verify(flm(tvm.tir.Ramp(x * 8 + 7, 1, 4), tvm.tir.Broadcast(4, 4)),
flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Ramp(flm(x * 4, 64), 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Ramp(flm(x * 8, 64), 2, 4))
ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
tvm.tir.Ramp(flm(x * 4, 64), 1, 5))
ck.verify(flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)))

# Min/Max rules
vx = te.var("vx", dtype="int32x2")
184 changes: 184 additions & 0 deletions tests/python/unittest/test_target_codegen_cuda.py
@@ -485,6 +485,33 @@ def test_cuda_floordiv_with_vectorization():
func(a_nd, b_nd)
tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)

def test_cuda_floormod_with_vectorization():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return

with tvm.target.cuda():
# B[i] = A[floormod(i, k)]
n = 256
k = 37
A = te.placeholder((n,), name='A')
B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name='B')
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], nparts=1)
xio, xii = s[B].split(xi, factor=4)
s[B].vectorize(xii)
s[B].bind(xo, bx)
s[B].bind(xio, tx)
func = tvm.build(s, [A, B], 'cuda')

ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(n,)).astype(A.dtype)
b_np = np.array([a_np[i % k] for i in range(0, n)])
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
func(a_nd, b_nd)
tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)

def test_vectorized_casts():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
@@ -693,6 +720,160 @@ def check_cuda(dtype, n, l, padding, lanes):
check_cuda("float16", 64, 16, 3, 4)
check_cuda("float32", 64, 16, 3, 4)

def vcf_check_common(s, args):
N = 512

# Check that every vectorized loop has been lowered to a ramp expression
stmt = tvm.lower(s, args)
# One-element stack used to track whether the current stmt sits inside a Broadcast node
inside_broadcast = [False]

# Possible patterns:
# Reduce init: Store[Ramp] = Broadcast(0)
# Shared memory copy: Store[Ramp] = Load[Ramp]
# Compute: Store[Ramp] = Load[Ramp] ... Broadcast[Load]

def pre_visit(stmt):
if isinstance(stmt, tvm.tir.Broadcast):
inside_broadcast[0] = True
# Check Broadcast[Imm numbers] or Broadcast[Load] patterns
assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.Load))
if isinstance(stmt, tvm.tir.Store):
# Check Store[Ramp] pattern
assert isinstance(stmt.index, tvm.tir.Ramp)
if isinstance(stmt, tvm.tir.Load):
# Check Broadcast[Load] or Load[Ramp] patterns
assert inside_broadcast[0] or isinstance(stmt.index, tvm.tir.Ramp)
# Skip the rest
return stmt
return None

def post_visit(stmt):
if isinstance(stmt, tvm.tir.Broadcast):
inside_broadcast[0] = False
return None

tvm.tir.stmt_functor.ir_transform(stmt['main'].body, pre_visit, post_visit)

if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("CUDA device not found, skip the verification.")
return
else:
tgt = tvm.target.cuda()
mod = tvm.build(s, args, tgt)
# Check that every vectorized loop lowers to the expected instructions
# print(mod.imported_modules[0].get_source())

ctx = tvm.context("cuda", 0)
a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), ctx)
mod(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.dot(
a.asnumpy(), b.asnumpy()), rtol=1e-5)

def test_vectorized_cooperative_fetching_x():
N = 512
A = te.placeholder((N, N), name='A', dtype='float32')
B = te.placeholder((N, N), name='B', dtype='float32')
k = te.reduce_axis((0, N), name='k')
C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
s = te.create_schedule(C.op)
i, j = s[C].op.axis
k = s[C].op.reduce_axis[0]

AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])

i3, i4 = s[C].split(i, factor=4)
i2, i3 = s[C].split(i3, factor=2)
i1, i2 = s[C].split(i2, factor=8)
i0, i1 = s[C].split(i1, factor=1)
j3, j4 = s[C].split(j, factor=4)
j2, j3 = s[C].split(j3, factor=2)
j1, j2 = s[C].split(j2, factor=8)
j0, j1 = s[C].split(j1, factor=2)
k1, k2 = s[C].split(k, factor=8)
k0, k1 = s[C].split(k1, factor=8)
s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
block_it = s[C].fuse(i0, j0)
s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
vthread_it = s[C].fuse(i1, j1)
s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
thread_it = s[C].fuse(i2, j2)
s[C].bind(thread_it, tvm.te.thread_axis("threadIdx.x"))
s[C].vectorize(j4)

s[AA].compute_at(s[C], k0)
iaa, jaa = s[AA].op.axis
s[BB].compute_at(s[C], k0)
ibb, jbb = s[BB].op.axis
aa_fused = s[AA].fuse(iaa, jaa)
bb_fused = s[BB].fuse(ibb, jbb)
aa1, aa2 = s[AA].split(aa_fused, factor=4)
aa0, aa1 = s[AA].split(aa1, factor=64)
bb1, bb2 = s[BB].split(bb_fused, factor=4)
bb0, bb1 = s[BB].split(bb1, factor=64)
s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.x"))
s[AA].vectorize(aa2)
s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.x"))
s[BB].vectorize(bb2)

vcf_check_common(s, [A, B, C])

def test_vectorized_cooperative_fetching_xy():
N = 512
A = te.placeholder((N, N), name='A')
B = te.placeholder((N, N), name='B')
k = te.reduce_axis((0, N), name='k')
C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
s = te.create_schedule(C.op)
i, j = s[C].op.axis
k = s[C].op.reduce_axis[0]

AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])

i3, i4 = s[C].split(i, factor=4)
i2, i3 = s[C].split(i3, factor=2)
i1, i2 = s[C].split(i2, factor=8)
i0, i1 = s[C].split(i1, factor=1)
j3, j4 = s[C].split(j, factor=4)
j2, j3 = s[C].split(j3, factor=2)
j1, j2 = s[C].split(j2, factor=8)
j0, j1 = s[C].split(j1, factor=2)
k1, k2 = s[C].split(k, factor=8)
k0, k1 = s[C].split(k1, factor=8)
s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
block_it = s[C].fuse(i0, j0)
s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
vthread_it = s[C].fuse(i1, j1)
s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
s[C].bind(i2, tvm.te.thread_axis("threadIdx.y"))
s[C].bind(j2, tvm.te.thread_axis("threadIdx.x"))
s[C].vectorize(j4)

s[AA].compute_at(s[C], k0)
iaa, jaa = s[AA].op.axis
s[BB].compute_at(s[C], k0)
ibb, jbb = s[BB].op.axis
aa_fused = s[AA].fuse(iaa, jaa)
bb_fused = s[BB].fuse(ibb, jbb)
aa2, aa3 = s[AA].split(aa_fused, factor=4)
aa1, aa2 = s[AA].split(aa2, factor=8)
aa0, aa1 = s[AA].split(aa1, factor=8)
bb2, bb3 = s[BB].split(bb_fused, factor=4)
bb1, bb2 = s[BB].split(bb2, factor=8)
bb0, bb1 = s[BB].split(bb1, factor=8)
s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.y"))
s[AA].bind(aa2, tvm.te.thread_axis("threadIdx.x"))
s[AA].vectorize(aa3)
s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.y"))
s[BB].bind(bb2, tvm.te.thread_axis("threadIdx.x"))
s[BB].vectorize(bb3)

vcf_check_common(s, [A, B, C])

if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
@@ -709,7 +890,10 @@ def check_cuda(dtype, n, l, padding, lanes):
test_cuda_reduction()
test_cuda_mix_threaded_and_normal_reduction()
test_cuda_floordiv_with_vectorization()
test_cuda_floormod_with_vectorization()
test_vectorized_intrin1()
test_vectorized_intrin2()
test_vectorized_popcount()
test_cuda_vectorize_load_permute_pad()
test_vectorized_cooperative_fetching_x()
test_vectorized_cooperative_fetching_xy()
