diff --git a/tests/python/unittest/test_mma_16x8x32_4k_tune.py b/tests/python/unittest/test_mma_16x8x32_4k_tune.py
new file mode 100644
index 000000000000..b05f6fedd6fa
--- /dev/null
+++ b/tests/python/unittest/test_mma_16x8x32_4k_tune.py
@@ -0,0 +1,503 @@
+import tempfile
+import tvm
+from tvm.script import tir as T
+import tvm.meta_schedule.testing.te_workload as te_workload
+from tvm import te, tir
+from tvm import meta_schedule as ms
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
+    A_shared = T.match_buffer(a, (16, 32), "int8", align=128, offset_factor=16, scope="shared")
+    A_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(A_shared[0:16, 0:32])
+        T.writes(A_warp[0:32, 0:16])
+
+        for ax0, ax1 in T.grid(16, 32):
+            with T.block("A_shared_warp"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(A_shared[v0, v1])
+                T.writes(A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4])
+                A_warp[v0 % 8 * 4 + v1 % 16 // 4, v1 // 16 * 8 + v0 // 8 * 4 + v1 % 4] = A_shared[
+                    v0, v1
+                ]
+
+
+@T.prim_func
+def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+    A_shared = T.match_buffer(
+        a,
+        (16, 32),
+        "int8",
+        align=128,
+        offset_factor=16,
+        scope="shared",
+        strides=[s1, s0],
+    )
+    A_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+    with T.block("root"):
+        T.reads(A_shared[0:16, 0:32])
+        T.writes(A_warp[0:32, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_ldmatrix(
+                0,
+                4,
+                ".b16",
+                A_warp.data,
+                16 * tx,
+                A_shared.data,
+                s1 * (tx % 16) + 16 * (tx // 16),
+                dtype="int8",
+            )
+        )
+
+
+@T.prim_func
+def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
+    B_shared = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="shared")
+    B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(B_shared[0:32, 0:16])
+        T.writes(B_warp[0:32, 0:16])
+
+        for ax0, ax1 in T.grid(32, 16):
+            with T.block("B_shared_warp"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(B_shared[v0, v1])
+                T.writes(B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4])
+                B_warp[v1 % 8 * 4 + v0 % 4, v1 // 8 * 8 + v0 // 16 * 4 + v0 % 4] = B_shared[
+                    v0, v1
+                ]
+
+
+@T.prim_func
+def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+    B_shared = T.match_buffer(
+        a,
+        (32, 16),
+        "int8",
+        align=128,
+        offset_factor=16,
+        scope="shared",
+        strides=[s1, s0],
+    )
+    B_warp = T.match_buffer(c, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+    with T.block("root"):
+        T.reads(B_shared[0:32, 0:16])
+        T.writes(B_warp[0:32, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_ldmatrix(
+                1,
+                4,
+                ".b16",
+                B_warp.data,
+                16 * tx,
+                B_shared.data,
+                s1,
+                dtype="int8",
+            )
+        )
+
+
+@T.prim_func
+def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+    B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+    C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
+        T.writes(C[0:32, 0:8])
+        for i, j, k in T.grid(16, 16, 32):
+            with T.block("C"):
+                i, j, k = T.axis.remap("SSR", [i, j, k])
+                T.reads(
+                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
+                    A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4],
+                    B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4],
+                )
+                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
+                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
+                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
+                ] + T.cast(
+                    A[i % 8 * 4 + k % 16 // 4, k // 16 * 8 + i // 8 * 4 + k % 4], "int32"
+                ) * T.cast(
+                    B[j % 8 * 4 + k % 4, j // 8 * 8 + k // 16 * 4 + k % 4], "int32"
+                )
+
+
+@T.prim_func
+def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+    B = T.match_buffer(b, (32, 16), "int8", align=128, offset_factor=16, scope="warp")
+    C = T.match_buffer(c, (32, 8), "int32", align=128, offset_factor=16, scope="warp")
+
+    with T.block("root"):
+        T.reads(C[0:32, 0:8], A[0:32, 0:16], B[0:32, 0:16])
+        T.writes(C[0:32, 0:8])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.ptx_mma(
+                "m16n8k32",
+                "row",
+                "col",
+                "int8",
+                "int8",
+                "int32",
+                A.data,
+                A.elem_offset + tx * 16,
+                B.data,
+                B.elem_offset + tx * 16,
+                C.data,
+                C.elem_offset + tx * 8,
+                False,
+                dtype="int32",
+            )
+        )
+
+        T.evaluate(
+            T.ptx_mma(
+                "m16n8k32",
+                "row",
+                "col",
+                "int8",
+                "int8",
+                "int32",
+                A.data,
+                A.elem_offset + tx * 16,
+                B.data,
+                B.elem_offset + tx * 16 + 8,
+                C.data,
+                C.elem_offset + tx * 8 + 4,
+                False,
+                dtype="int32",
+            )
+        )
+
+
+@T.prim_func
+def mma_store_desc(a: T.handle, c: T.handle) -> None:
+    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
+    C = T.match_buffer(c, [16, 16], dtype="int32", scope="global")
+
+    with T.block("root"):
+        T.reads(C_warp[0:32, 0:8])
+        T.writes(C[0:16, 0:16])
+        for ax1_0, i0, i1 in T.grid(2, 32, 4):
+            with T.block("C_warp"):
+                v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
+                v1 = T.axis.spatial(16, ax1_0 * 8 + i0 % 4 * 2 + i1 % 2)
+
+                T.reads(C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
+                T.writes(C[v0, v1])
+                C[v0, v1] = C_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2]
+
+
+@T.prim_func
+def mma_store_impl(a: T.handle, c: T.handle) -> None:
+    s1 = T.var("int32")
+    s0 = T.var("int32")
+
+    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
+    C = T.match_buffer(
+        c, [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[s1, s0]
+    )
+
+    with T.block("root"):
+        T.reads(C_warp[0:32, 0:8])
+        T.writes(C[0:16, 0:16])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(
+            T.mma_store(
+                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="int32"
+            )
+        )
+
+
+@T.prim_func
+def mma_fill_desc(a: T.handle) -> None:
+    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
+
+    with T.block("root"):
+        T.reads()
+        T.writes(C_warp[0:32, 0:8])
+        for i0, i1 in T.grid(32, 8):
+            with T.block("C_warp"):
+                i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
+                j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+                T.reads()
+                T.writes(
+                    C_warp[
+                        i_init % 8 * 4 + j_init % 8 // 2,
+                        j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
+                    ]
+                )
+                C_warp[
+                    i_init % 8 * 4 + j_init % 8 // 2,
+                    j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
+                ] = T.int32(0)
+
+
+@T.prim_func
+def mma_fill_impl(a: T.handle) -> None:
+    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
+
+    with T.block("root"):
+        T.reads()
+        T.writes(C_warp[0:32, 0:8])
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(tx, 32)
+
+        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32"))
+
+
+tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
+tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
+tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
+tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
+tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
+
+N = 4096
+M = 4096
+K = 4096
+
+
+def matmul_int8(n, m, k):
+    a = te.placeholder((n, k), name="A", dtype="int8")
+    b = te.placeholder((k, m), name="B", dtype="int8")
+    k = te.reduce_axis((0, k), name="k")
+
+    def f_compute(i, j):
+        v_a = tir.Cast(dtype="int32", value=a[i, k])
+        v_b = tir.Cast(dtype="int32", value=b[k, j])
+        return te.sum(v_a * v_b, axis=[k])
+
+    c = te.compute((n, m), f_compute, name="C")
+    return (a, b, c)
+
+
+workload = te.create_prim_func(matmul_int8(n=N, m=M, k=K))
+
+tune = False
+
+
+def schedule(sch: tir.Schedule):
+    block = sch.get_block("C")
+    i, j, k = sch.get_loops(block)
+    i, i_tc = sch.split(i, factors=[None, 16])
+    j, j_tc = sch.split(j, factors=[None, 32])
+    k, k_tc = sch.split(k, factors=[None, 16])
+
+    sch.reorder(
+        i,
+        j,
+        k,
+        i_tc,
+        j_tc,
+        k_tc,
+    )
+    block_inner = sch.blockize(i_tc)
+
+    block_outer, block_inner = block_inner, block
+
+    if tune:
+        i_factors = sch.sample_perfect_tile(i, n=5)
+        j_factors = sch.sample_perfect_tile(j, n=5)
+        k_factors = sch.sample_perfect_tile(k, n=3)
+        num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
+    else:
+        i_factors = [4, 8, 2, 4, 1]
+        j_factors = [1, 32, 2, 1, 2]
+        k_factors = [128, 2, 1]
+
+        num_ty = i_factors[2] * j_factors[2]
+
+    i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
+    j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
+    k0, k1, k2 = sch.split(k, k_factors)
+
+    sch.reorder(
+        i0,
+        j0,  # S => blockIdx.x
+        i1,
+        j1,  # S => blockIdx.y
+        j2,
+        i2,  # S => threadIdx.y
+        # cache_write here
+        k0,  # R
+        # vectorized cooperative fetching here
+        k1,  # R
+        i3,
+        j3,  # S
+        k2,  # R
+        i4,
+        j4,
+        # S
+    )
+
+    block_idx = sch.fuse(i0, j0)
+    block_idy = sch.fuse(i1, j1)
+    thread_idy = sch.fuse(j2, i2)
+    sch.bind(block_idx, "blockIdx.x")
+    sch.bind(block_idy, "blockIdx.y")
+    sch.bind(thread_idy, "threadIdx.y")
+
+    def fetch_to_shared(block, idx, ndim):
+        block_read = sch.cache_read(block, idx, "shared")
+        sch.compute_at(block_read, k0)
+        vector_size = 16
+        warp_size = 32
+        fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+        f_0, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
+        sch.bind(f_2, "threadIdx.x")
+        sch.bind(f_1, "threadIdx.y")
+        sch.vectorize(f_3)
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)
+
+        return block_read
+
+    A_sh = fetch_to_shared(block_outer, 0, 2)
+    B_sh = fetch_to_shared(block_outer, 1, 2)
+
+    loop = sch.get_loops(block_outer)[-1]
+
+    A_warp = sch.cache_read(block_outer, 0, "warp")
+    B_warp = sch.cache_read(block_outer, 1, "warp")
+
+    sch.compute_at(A_warp, k1)
+    sch.compute_at(B_warp, k1)
+
+    C_warp = sch.cache_write(block_outer, 0, "warp")
+    sch.reverse_compute_at(C_warp, thread_idy)
+
+    ii, jj = sch.get_loops(C_warp)[-2:]
+    io, ii = sch.split(ii, factors=[None, 16])
+    jo, ji = sch.split(jj, factors=[None, 16])
+    sch.reorder(io, jo, ii, ji)
+
+    block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+
+    def tile_wmma_fragment(block_read, height, is_b=False):
+        i, j = sch.get_loops(block_read)[-2:]
+        i0, i1 = sch.split(i, factors=[None, height])
+        if is_b:
+            j0, j1 = sch.split(j, factors=[32, None])
+        else:
+            j0, j1 = sch.split(j, factors=[None, 32])
+        sch.reorder(i0, j0, i1, j1)
+        return i1
+
+    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+        i_0 = i // 16
+        j_0 = j // 16
+
+        i = i % 16
+        j = j % 16
+
+        thread_id = 4 * (i % 8) + (j % 8) // 2
+        return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
+
+    def shared_16x32_to_ldmatrix_32x16_layout(i, j):
+        i_0 = i // 16
+        j_0 = j // 32
+
+        i = i % 16
+        j = j % 32
+
+        thread_id = 4 * (i % 8) + (j % 16) // 4
+        return i_0, j_0, thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
+
+    def shared_32x16_to_ldmatrix_32x16_layout(i, j):
+        i_0 = i // 32
+        j_0 = j // 16
+
+        i = i % 32
+        j = j % 16
+
+        thread_id = (i % 4) + 4 * (j % 8)
+        return i_0, j_0, thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+
+    loop_a = tile_wmma_fragment(A_warp, 16)
+    loop_b = tile_wmma_fragment(B_warp, 16, True)
+
+    sch.transform_layout(A_warp, 0, "write", index_map=shared_16x32_to_ldmatrix_32x16_layout)
+    sch.transform_layout(B_warp, 0, "write", index_map=shared_32x16_to_ldmatrix_32x16_layout)
+    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+
+    sch.tensorize(loop_a, "mma.ldmatrix_a")
+    sch.tensorize(loop_b, "mma.ldmatrix_b")
+
+    mma_loop = sch.get_loops(block_inner)[-3]
+    sch.tensorize(mma_loop, "mma_sync")
+
+    block_init_c = sch.get_block("C_init")
+    init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
+    f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
+    f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
+    sch.reorder(f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(fused_1, "mma_fill")
+
+    warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
+    f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
+    outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
+    sch.reorder(outer, f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(outer, "mma_store")
+    # print(sch.mod.script())
+    # return
+
+
+ir_module = tvm.IRModule({"main": workload})
+sch = tvm.tir.Schedule(ir_module)
+schedule(sch)
+print(sch.mod.script())
+
+# if tune:
+#     with tempfile.TemporaryDirectory() as work_dir:
+#         sch = ms.tune_tir(
+#             mod=workload,
+#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+#             config=ms.TuneConfig(
+#                 strategy="evolutionary",
+#                 num_trials_per_iter=32,
+#                 max_trials_per_task=128,
+#                 max_trials_global=128,
+#             ),
+#             work_dir=work_dir,
+#             space=ms.space_generator.ScheduleFn(schedule),
+#         )
+#         if sch is None:
+#             print("No valid schedule found!")
+#         else:
+#             print(sch.mod.script())
+#             print(sch.trace)
+# else:
+#     target = "cuda"
+#     f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+#     dev = tvm.device("cuda", 0)
+#     a_np = np.random.uniform(size=(N, K)).astype("int8")
+#     b_np = np.random.uniform(size=(K, M)).astype("int8")
+#     c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
+#     a = tvm.nd.array(a_np, dev)
+#     b = tvm.nd.array(b_np, dev)
+#     c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
+#     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+#     print(f.imported_modules[0].get_source())
+#     f(a, b, c)
+#     tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+#     print("ok")
+
+#     evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+#     gflops = (N * M * K) * 2 / 1e9
+#     time_ms = evaluator(a, b, c).mean * 1e3
+#     print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))