diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index d1e67cae1192..963044957bc3 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
-from .. import Cast
+from .. import IntImm, Cast
 from ..._ffi import register_func
 from ...runtime import convert
 from .. import TensorIntrin
@@ -315,6 +315,97 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     return mma_sync_desc, mma_sync_impl
 
 
+def get_mma_fill_intrin(dtype, local_size):
+    zero = IntImm("int32", 0).astype(dtype)
+
+    # Assume M = N = 16
+    index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+    @T.prim_func
+    def mma_fill_desc(a: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    i, j = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = index_map(i, j)
+                    T.reads()
+                    T.writes(C_warp[thread_id, local_id])
+                    C_warp[thread_id, local_id] = zero
+
+    @T.prim_func
+    def mma_fill_impl(a: T.handle) -> None:
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+        )
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))
+
+    return mma_fill_desc, mma_fill_impl
+
+
+def get_mma_store_intrin(dtype, local_size):
+    # Assume M = N = 16
+    index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+    @T.prim_func
+    def mma_store_desc(a: T.handle, c: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
+        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope="global")
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    v0, v1 = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = index_map(v0, v1)
+                    T.reads(C_warp[thread_id, local_id])
+                    T.writes(C[v0, v1])
+                    C[v0, v1] = C_warp[thread_id, local_id]
+
+    @T.prim_func
+    def mma_store_impl(a: T.handle, c: T.handle) -> None:
+        s0 = T.var("int32")
+        s1 = T.var("int32")
+
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+        )
+        C = T.match_buffer(
+            c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1]
+        )
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(
+                T.mma_store(
+                    M_DIM,
+                    N_DIM,
+                    C.access_ptr("w"),
+                    C_warp.data,
+                    C_warp.elem_offset,
+                    s0,
+                    dtype=dtype,
+                )
+            )
+
+    return mma_store_desc, mma_store_impl
+
+
 LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
 TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))
 
@@ -352,3 +443,21 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 
 MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
 TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))
+
+MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
+TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8))
+
+MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
+TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8))
+
+MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
+TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))
+
+MMA_store_16x16_f32_INTRIN = "mma_store_16x16_f32"
+TensorIntrin.register(MMA_store_16x16_f32_INTRIN, *get_mma_store_intrin("float32", 8))
+
+MMA_store_16x16_f16_INTRIN = "mma_store_16x16_f16"
+TensorIntrin.register(MMA_store_16x16_f16_INTRIN, *get_mma_store_intrin("float16", 8))
+
+MMA_store_16x16_i32_INTRIN = "mma_store_16x16_i32"
+TensorIntrin.register(MMA_store_16x16_i32_INTRIN, *get_mma_store_intrin("int32", 8))
diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
index 043ab4a345e5..fe866e5fea51 100644
--- a/tests/python/unittest/test_mma_16x8x16_4k_tune.py
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -1,6 +1,4 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
 import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm import te, tir
 from tvm import meta_schedule as ms
@@ -8,84 +6,14 @@
     LDMATRIX_16x16_A_INTRIN,
     LDMATRIX_16x16_B_INTRIN,
     MMA_f16f16f32_INTRIN,
+    MMA_fill_16x16_f32_INTRIN,
+    MMA_store_16x16_f32_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
 )
 import tvm.testing
 import numpy as np
 
-
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.float32(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -214,8 +142,8 @@ def index_map(i, j):
     sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
     sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_INTRIN)
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f32_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f32_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py b/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
index cc6032846825..408579d22044 100644
--- a/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
@@ -1,12 +1,11 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
-import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm import te, tir
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_INTRIN,
     LDMATRIX_16x16_B_TRANS_INTRIN,
     MMA_f16f16f32_TRANS_INTRIN,
+    MMA_fill_16x16_f32_INTRIN,
+    MMA_store_16x16_f32_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
 )
 from tvm import meta_schedule as ms
@@ -14,78 +13,6 @@
 import numpy as np
 
-
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.float32(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -231,8 +158,8 @@ def index_map(i, j):
     sch.tensorize(loop_b, LDMATRIX_16x16_B_TRANS_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_TRANS_INTRIN)
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f32_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f32_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
diff --git a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py
index f0f59c0f5209..7572fc3def15 100644
--- a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py
+++ b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune.py
@@ -1,12 +1,12 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
 import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm import te, tir
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_INTRIN,
     LDMATRIX_16x16_B_INTRIN,
     MMA_f16f16f16_INTRIN,
+    MMA_fill_16x16_f16_INTRIN,
+    MMA_store_16x16_f16_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
 )
 from tvm import meta_schedule as ms
@@ -14,78 +14,6 @@
 import numpy as np
 
-
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="float16", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="float16", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float16"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.float16(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float16"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -229,8 +157,8 @@ def index_map(i, j):
     sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
     sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f16_INTRIN)
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f16_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f16_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
diff --git a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py
index d716016a6130..7b9fdfcd9202 100644
--- a/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py
+++ b/tests/python/unittest/test_mma_16x8x16_fp16_4k_tune_trans.py
@@ -1,12 +1,11 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
-import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm import te, tir
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x16_A_INTRIN,
     LDMATRIX_16x16_B_TRANS_INTRIN,
     MMA_f16f16f16_TRANS_INTRIN,
+    MMA_fill_16x16_f16_INTRIN,
+    MMA_store_16x16_f16_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
 )
 from tvm import meta_schedule as ms
@@ -14,78 +13,6 @@
 import numpy as np
 
-
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="float16", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="float16", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float16"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.float16(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="float16", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float16"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -231,8 +158,8 @@ def index_map(i, j):
     sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
     sch.tensorize(loop_b, LDMATRIX_16x16_B_TRANS_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f16_TRANS_INTRIN)
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f16_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f16_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
diff --git a/tests/python/unittest/test_mma_16x8x32_4k_tune.py b/tests/python/unittest/test_mma_16x8x32_4k_tune.py
index b504114872cf..c0663849cf9c 100644
--- a/tests/python/unittest/test_mma_16x8x32_4k_tune.py
+++ b/tests/python/unittest/test_mma_16x8x32_4k_tune.py
@@ -1,11 +1,11 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
 import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x32_A_INTRIN,
     LDMATRIX_32x16_B_INTRIN,
     MMA_i8i8i32_INTRIN,
+    MMA_fill_16x16_i32_INTRIN,
+    MMA_store_16x16_i32_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
     shared_32x16_to_ldmatrix_32x16_layout,
     shared_16x32_to_ldmatrix_32x16_layout,
 )
@@ -16,78 +16,6 @@
 import numpy as np
 
-
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="int32", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="int32"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.int32(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -249,8 +177,8 @@ def index_map_C(i, j):
     sch.tensorize(loop_a, LDMATRIX_16x32_A_INTRIN)
     sch.tensorize(loop_b, LDMATRIX_32x16_B_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_i8i8i32_INTRIN)  # "mma_sync")
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_i32_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_i32_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
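For context, schedules like the ones above are usually validated end to end in the following style; this is a hypothetical driver, not code from this diff, and the actual checks in each test may differ:

    # Hypothetical check for the f16f16f32 variant: build the tensorized
    # module (`sch` as in the tests above), run it on GPU, and compare
    # against a numpy reference.
    import numpy as np
    import tvm
    import tvm.testing

    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
    dev = tvm.device("cuda", 0)
    a_np = np.random.uniform(size=(4096, 4096)).astype("float16")
    b_np = np.random.uniform(size=(4096, 4096)).astype("float16")
    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(np.zeros((4096, 4096), dtype="float32"), dev)
    f(a, b, c)
    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)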
diff --git a/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py b/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py
index d2b4c6b9cb26..9a90ce0f921c 100644
--- a/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py
+++ b/tests/python/unittest/test_mma_16x8x32_4k_tune_trans.py
@@ -1,12 +1,11 @@
-import tempfile
 import tvm
-from tvm.script import tir as T
-import tvm.meta_schedule.testing.te_workload as te_workload
 from tvm import te, tir
 from tvm.tir.tensor_intrin.cuda import (
     LDMATRIX_16x32_A_INTRIN,
     LDMATRIX_16x32_B_TRANS_INTRIN,
     MMA_i8i8i32_TRANS_INTRIN,
+    MMA_fill_16x16_i32_INTRIN,
+    MMA_store_16x16_i32_INTRIN,
     shared_16x16_to_ldmatrix_32x8_layout,
     shared_16x32_to_ldmatrix_32x16_layout,
 )
@@ -16,78 +15,6 @@
 import numpy as np
 
-
-@T.prim_func
-def mma_store_desc(a: T.handle, c: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
-    C = T.match_buffer(c, [16, 16], dtype="int32", scope="global")
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                v0, v1 = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
-                T.reads(C_warp[thread_id, local_id])
-                T.writes(C[v0, v1])
-                C[v0, v1] = C_warp[thread_id, local_id]
-
-
-@T.prim_func
-def mma_store_impl(a: T.handle, c: T.handle) -> None:
-    s1 = T.var("int32")
-    s0 = T.var("int32")
-
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
-    C = T.match_buffer(
-        c, [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[s1, s0]
-    )
-
-    with T.block("root"):
-        T.reads(C_warp[0:32, 0:8])
-        T.writes(C[0:16, 0:16])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(
-            T.mma_store(
-                16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="int32"
-            )
-        )
-
-
-@T.prim_func
-def mma_fill_desc(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp")
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        for i0, i1 in T.grid(16, 16):
-            with T.block("C_warp"):
-                i_init, j_init = T.axis.remap("SS", [i0, i1])
-                thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
-                T.reads()
-                T.writes(C_warp[thread_id, local_id])
-                C_warp[thread_id, local_id] = T.int32(0)
-
-
-@T.prim_func
-def mma_fill_impl(a: T.handle) -> None:
-    C_warp = T.match_buffer(a, [32, 8], dtype="int32", scope="warp", offset_factor=1)
-
-    with T.block("root"):
-        T.reads()
-        T.writes(C_warp[0:32, 0:8])
-        tx = T.env_thread("threadIdx.x")
-        T.launch_thread(tx, 32)
-
-        T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="int32"))
-
-
-tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
-tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)
-
 N = 4096
 M = 4096
 K = 4096
@@ -240,8 +167,8 @@ def index_map_C(i, j):
     sch.tensorize(loop_a, LDMATRIX_16x32_A_INTRIN)
     sch.tensorize(loop_b, LDMATRIX_16x32_B_TRANS_INTRIN)
     sch.tensorize(sch.get_loops(block_inner)[-3], MMA_i8i8i32_TRANS_INTRIN)
-    sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
-    sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
+    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_i32_INTRIN)
+    sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_i32_INTRIN)
 
 
 ir_module = tvm.IRModule({"main": workload})
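One property both new intrinsics rely on is that shared_16x16_to_ldmatrix_32x8_layout maps the 16x16 tile one-to-one onto the 32x8 warp buffer. A quick standalone check (illustrative only, not part of the diff):

    # Every (thread_id, local_id) pair must be hit exactly once; otherwise
    # mma_fill/mma_store would touch some warp-buffer element twice or never.
    import numpy as np
    from tvm.tir.tensor_intrin.cuda import shared_16x16_to_ldmatrix_32x8_layout

    hits = np.zeros((32, 8), dtype=int)
    for i in range(16):
        for j in range(16):
            thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            hits[thread_id, local_id] += 1
    assert (hits == 1).all()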