From f1c4f5524a063191cd79bdc9f9a4b8bf1dcdec22 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 26 Jun 2020 22:19:47 +0800 Subject: [PATCH] [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching (#5924) --- src/arith/rewrite_simplify.cc | 13 +- .../unittest/test_arith_rewrite_simplify.py | 46 +++++ .../unittest/test_target_codegen_cuda.py | 184 ++++++++++++++++++ 3 files changed, 241 insertions(+), 2 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 6758c9b569a8..898eecc93845 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -722,8 +722,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); int64_t ramp_min = floordiv(bmod->base, c2val); int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { - return broadcast(floordiv(b1, c2), lanes).Eval(); + if (ramp_min == ramp_max) { + // If the coefficient of b1 is divisible by c2 + if (bmod->coeff % c2val == 0) { + return broadcast(floordiv(b1, c2), lanes).Eval(); + } + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { + return broadcast(floordiv(b1, c2), lanes).Eval(); + } } } } @@ -847,6 +854,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } else { return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } + } else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) { + return ramp(floormod(b1, c2), c1, lanes).Eval(); } } } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 53ba93dc65e7..c01898635488 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -38,6 +38,8 @@ def 
test_vector_simplify(): tvm.tir.Ramp(y + x, 1, 2)) ck.verify(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2")) + ck.verify(tvm.tir.Broadcast(0, 4) + y, + tvm.tir.Broadcast(y, 4)) # Sub rules ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x - y, 2, 4)) @@ -55,6 +57,8 @@ def test_vector_simplify(): tvm.tir.Ramp(x * 2, 8, 4)) ck.verify(2 * tvm.tir.Ramp(x, 4, 4), tvm.tir.Ramp(x * 2, 8, 4)) + ck.verify(tvm.tir.Broadcast(0, 4) * x, + tvm.tir.Broadcast(0, 4)) ## DivMod rules tdiv = tvm.tir.truncdiv @@ -69,6 +73,7 @@ def test_vector_simplify(): (x).astype("int32x4")) ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)) + # trunc mod ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")), tmod(y, x).astype("int32x2")) ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2), @@ -90,6 +95,27 @@ def test_vector_simplify(): (x).astype("int32x4")) ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)) + ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), + tvm.tir.Ramp(fld(x, 4), 2, 5)) + ck.verify(fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), + fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4))) + ck.verify(fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), + tvm.tir.Broadcast(x * 2, 4)) + ck.verify(fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)), + fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4))) + ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)), + fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4))) + ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), + tvm.tir.Broadcast(fld(x, 16), 4)) + ck.verify(fld(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), + tvm.tir.Broadcast(fld(x, 8), 4)) + ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5))) + ck.verify(fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 
4)), + fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4))) + ck.verify(fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), + fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4))) + # floor mod ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")) ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), @@ -98,6 +124,26 @@ def test_vector_simplify(): tvm.tir.Ramp(1, 1, 4)) ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tir.Ramp(1, 15, 4), 8)) + ck.verify(flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)), + tvm.tir.Broadcast(flm(x, 4), 4)) + ck.verify(flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), + flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4))) + ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), + tvm.tir.Ramp(0, 1, 4)) + ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 5), tvm.tir.Broadcast(4, 5)), + flm(tvm.tir.Ramp(0, 1, 5), tvm.tir.Broadcast(4, 5))) + ck.verify(flm(tvm.tir.Ramp(x * 8 + 7, 1, 4), tvm.tir.Broadcast(4, 4)), + flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4))) + ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), + tvm.tir.Ramp(flm(x * 4, 64), 1, 4)) + ck.verify(flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), + tvm.tir.Ramp(flm(x * 8, 64), 2, 4)) + ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + tvm.tir.Ramp(flm(x * 4, 64), 1, 5)) + ck.verify(flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), + tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4)) + ck.verify(flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), + flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4))) # Min/Max rules vx = te.var("vx", dtype="int32x2") diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 1a7163ff129d..c977334ed25a 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -485,6 +485,33 @@ def 
test_cuda_floordiv_with_vectorization(): func(a_nd, b_nd) tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3) +def test_cuda_floormod_with_vectorization(): + if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): + print("skip because cuda is not enabled..") + return + + with tvm.target.cuda(): + # B[i] = A[floormod(i, k)] + n = 256 + k = 37 + A = te.placeholder((n,), name='A') + B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name='B') + s = te.create_schedule(B.op) + xo, xi = s[B].split(B.op.axis[0], nparts=1) + xio, xii = s[B].split(xi, factor=4) + s[B].vectorize(xii) + s[B].bind(xo, bx) + s[B].bind(xio, tx) + func = tvm.build(s, [A, B], 'cuda') + + ctx = tvm.gpu(0) + a_np = np.random.uniform(size=(n,)).astype(A.dtype) + b_np = np.array([a_np[i % k] for i in range(0, n)]) + a_nd = tvm.nd.array(a_np, ctx) + b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx) + func(a_nd, b_nd) + tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3) + def test_vectorized_casts(): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled..") @@ -693,6 +720,160 @@ def check_cuda(dtype, n, l, padding, lanes): check_cuda("float16", 64, 16, 3, 4) check_cuda("float32", 64, 16, 3, 4) +def vcf_check_common(s, args): + N = 512 + + # To check if every vectorize loop transforms to ramp expr successfully + stmt = tvm.lower(s, args) + # Use this as a stack flag to show whether this stmt is inside a BroadcastNode + inside_broadcast = [False] + + # Possible patterns: + # Reduce init: Store[Ramp] = Broadcast(0) + # Shared memory copy: Store[Ramp] = Load[Ramp] + # Compute: Store[Ramp] = Load[Ramp] ... 
Broadcast[Load] + + def pre_visit(stmt): + if isinstance(stmt, tvm.tir.Broadcast): + inside_broadcast[0] = True + # Check Broadcast[Imm numbers] or Broadcast[Load] patterns + assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.Load)) + if isinstance(stmt, tvm.tir.Store): + # Check Store[Ramp] pattern + assert isinstance(stmt.index, tvm.tir.Ramp) + if isinstance(stmt, tvm.tir.Load): + # Check Broadcast[Load] or Load[Ramp] patterns + assert inside_broadcast[0] or isinstance(stmt.index, tvm.tir.Ramp) + # Skip the rest + return stmt + return None + + def post_visit(stmt): + if isinstance(stmt, tvm.tir.Broadcast): + inside_broadcast[0] = False + return None + + tvm.tir.stmt_functor.ir_transform(stmt['main'].body, pre_visit, post_visit) + + if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): + print("CUDA device not found, skip the verification.") + return + else: + tgt = tvm.target.cuda() + mod = tvm.build(s, args, tgt) + # To check if every vectorize loop transforms to correct instruction + # print(mod.imported_modules[0].get_source()) + + ctx = tvm.context("cuda", 0) + a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx) + b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx) + c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + +def test_vectorized_cooperative_fetching_x(): + N = 512 + A = te.placeholder((N, N), name='A', dtype='float32') + B = te.placeholder((N, N), name='B', dtype='float32') + k = te.reduce_axis((0, N), name='k') + C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k)) + s = te.create_schedule(C.op) + i, j = s[C].op.axis + k = s[C].op.reduce_axis[0] + + AA = s.cache_read(A, "shared", [C]) + BB = s.cache_read(B, "shared", [C]) + + i3, i4 = s[C].split(i, factor=4) + i2, i3 = s[C].split(i3, factor=2) + i1, i2 = s[C].split(i2, factor=8) + i0, i1 = 
s[C].split(i1, factor=1) + j3, j4 = s[C].split(j, factor=4) + j2, j3 = s[C].split(j3, factor=2) + j1, j2 = s[C].split(j2, factor=8) + j0, j1 = s[C].split(j1, factor=2) + k1, k2 = s[C].split(k, factor=8) + k0, k1 = s[C].split(k1, factor=8) + s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4) + block_it = s[C].fuse(i0, j0) + s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x")) + vthread_it = s[C].fuse(i1, j1) + s[C].bind(vthread_it, tvm.te.thread_axis("vthread")) + thread_it = s[C].fuse(i2, j2) + s[C].bind(thread_it, tvm.te.thread_axis("threadIdx.x")) + s[C].vectorize(j4) + + s[AA].compute_at(s[C], k0) + iaa, jaa = s[AA].op.axis + s[BB].compute_at(s[C], k0) + ibb, jbb = s[BB].op.axis + aa_fused = s[AA].fuse(iaa, jaa) + bb_fused = s[BB].fuse(ibb, jbb) + aa1, aa2 = s[AA].split(aa_fused, factor=4) + aa0, aa1 = s[AA].split(aa1, factor=64) + bb1, bb2 = s[BB].split(bb_fused, factor=4) + bb0, bb1 = s[BB].split(bb1, factor=64) + s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.x")) + s[AA].vectorize(aa2) + s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.x")) + s[BB].vectorize(bb2) + + vcf_check_common(s, [A, B, C]) + +def test_vectorized_cooperative_fetching_xy(): + N = 512 + A = te.placeholder((N, N), name='A') + B = te.placeholder((N, N), name='B') + k = te.reduce_axis((0, N), name='k') + C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k)) + s = te.create_schedule(C.op) + i, j = s[C].op.axis + k = s[C].op.reduce_axis[0] + + AA = s.cache_read(A, "shared", [C]) + BB = s.cache_read(B, "shared", [C]) + + i3, i4 = s[C].split(i, factor=4) + i2, i3 = s[C].split(i3, factor=2) + i1, i2 = s[C].split(i2, factor=8) + i0, i1 = s[C].split(i1, factor=1) + j3, j4 = s[C].split(j, factor=4) + j2, j3 = s[C].split(j3, factor=2) + j1, j2 = s[C].split(j2, factor=8) + j0, j1 = s[C].split(j1, factor=2) + k1, k2 = s[C].split(k, factor=8) + k0, k1 = s[C].split(k1, factor=8) + s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4) + block_it = s[C].fuse(i0, j0) 
+ s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x")) + vthread_it = s[C].fuse(i1, j1) + s[C].bind(vthread_it, tvm.te.thread_axis("vthread")) + s[C].bind(i2, tvm.te.thread_axis("threadIdx.y")) + s[C].bind(j2, tvm.te.thread_axis("threadIdx.x")) + s[C].vectorize(j4) + + s[AA].compute_at(s[C], k0) + iaa, jaa = s[AA].op.axis + s[BB].compute_at(s[C], k0) + ibb, jbb = s[BB].op.axis + aa_fused = s[AA].fuse(iaa, jaa) + bb_fused = s[BB].fuse(ibb, jbb) + aa2, aa3 = s[AA].split(aa_fused, factor=4) + aa1, aa2 = s[AA].split(aa2, factor=8) + aa0, aa1 = s[AA].split(aa1, factor=8) + bb2, bb3 = s[BB].split(bb_fused, factor=4) + bb1, bb2 = s[BB].split(bb2, factor=8) + bb0, bb1 = s[BB].split(bb1, factor=8) + s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.y")) + s[AA].bind(aa2, tvm.te.thread_axis("threadIdx.x")) + s[AA].vectorize(aa3) + s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.y")) + s[BB].bind(bb2, tvm.te.thread_axis("threadIdx.x")) + s[BB].vectorize(bb3) + + vcf_check_common(s, [A, B, C]) + if __name__ == "__main__": test_cuda_vectorize_add() test_cuda_multiply_add() @@ -709,7 +890,10 @@ def check_cuda(dtype, n, l, padding, lanes): test_cuda_reduction() test_cuda_mix_threaded_and_normal_reduction() test_cuda_floordiv_with_vectorization() + test_cuda_floormod_with_vectorization() test_vectorized_intrin1() test_vectorized_intrin2() test_vectorized_popcount() test_cuda_vectorize_load_permute_pad() + test_vectorized_cooperative_fetching_x() + test_vectorized_cooperative_fetching_xy()