tensorizing mma store
masahi committed May 17, 2022
1 parent e80a1f1 commit 71fe5fe
Showing 3 changed files with 93 additions and 55 deletions.
2 changes: 2 additions & 0 deletions include/tvm/tir/builtin.h
@@ -632,6 +632,8 @@ TVM_DLL const Op& ptx_mma_sp();
*/
TVM_DLL const Op& ptx_ldmatrix();

TVM_DLL const Op& mma_store();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
@@ -247,6 +247,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(mma_store)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

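As a quick sanity check (a sketch, not part of this diff, assuming a TVM build that includes this change), the new builtin should be visible from Python under the "tir." namespace that TIR_DEFINE_BUILTIN_FUNC adds, carrying the opaque call-effect kind set above:

import tvm

# Look up the op registered by TIR_DEFINE_BUILTIN_FUNC(mma_store) above.
op = tvm.ir.Op.get("tir.mma_store")
print(op)

# The attribute set in builtin.cc; it should report CallEffectKind::kOpaque.
print(op.get_attr("TCallEffectKind"))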
143 changes: 88 additions & 55 deletions tests/python/unittest/test_mma_16x8x8_4k_tune.py
@@ -10,12 +10,8 @@

@T.prim_func
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
A_shared = T.match_buffer(
a, (16, 8), "float16", align=128, offset_factor=1, scope="shared"
)
A_warp = T.match_buffer(
c, (32, 4), "float16", align=128, offset_factor=1, scope="warp"
)
A_shared = T.match_buffer(a, (16, 8), "float16", align=128, offset_factor=1, scope="shared")
A_warp = T.match_buffer(c, (32, 4), "float16", align=128, offset_factor=1, scope="warp")

with T.block("root"):
T.reads(A_shared[0:16, 0:8])
@@ -42,9 +38,7 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
scope="shared",
strides=[s1, s0],
)
A_warp = T.match_buffer(
c, (32, 4), "float16", align=128, offset_factor=1, scope="warp"
)
A_warp = T.match_buffer(c, (32, 4), "float16", align=128, offset_factor=1, scope="warp")
with T.block("root"):
T.reads(A_shared[0:16, 0:8])
T.writes(A_warp[0:32, 0:4])
@@ -67,12 +61,8 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:

@T.prim_func
def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
B_shared = T.match_buffer(
a, (8, 8), "float16", align=128, offset_factor=1, scope="shared"
)
B_shared_warp = T.match_buffer(
c, (32, 2), "float16", align=128, offset_factor=1, scope="warp"
)
B_shared = T.match_buffer(a, (8, 8), "float16", align=128, offset_factor=1, scope="shared")
B_shared_warp = T.match_buffer(c, (32, 2), "float16", align=128, offset_factor=1, scope="warp")

with T.block("root"):
T.reads(B_shared[0:8, 0:8])
@@ -99,9 +89,7 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
scope="shared",
strides=[s1, s0],
)
B_warp = T.match_buffer(
c, (32, 2), "float16", align=128, offset_factor=1, scope="warp"
)
B_warp = T.match_buffer(c, (32, 2), "float16", align=128, offset_factor=1, scope="warp")
with T.block("root"):
T.reads(B_shared[0:8, 0:8])
T.writes(B_warp[0:32, 0:2])
@@ -141,9 +129,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
T.writes(C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2])
C[i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2] = C[
i % 8 * 4 + j % 8 // 2, i % 16 // 8 * 2 + j % 2
] + T.cast(
A[i % 8 * 4 + k % 8 // 2, i % 16 // 8 * 2 + k % 2], "float32"
) * T.cast(
] + T.cast(A[i % 8 * 4 + k % 8 // 2, i % 16 // 8 * 2 + k % 2], "float32") * T.cast(
B[j % 8 * 4 + k % 8 // 2, k % 2], "float32"
)

@@ -179,9 +165,41 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
)


@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp")
C = T.match_buffer(c, [16, 8], dtype="float32", scope="global")

with T.block("root"):
T.reads(C_warp[0:32, 0:4])
T.writes(C[0:16, 0:8])
for i0, i1 in T.grid(32, 4):
with T.block("C_warp"):
v0 = T.axis.spatial(16, i1 // 2 * 8 + i0 // 4)
v1 = T.axis.spatial(8, (i0 % 4) * 2 + i1 % 2)
T.reads(C_warp[v0 % 8 * 4 + v1 // 2, v0 // 8 * 2 + v1 % 2])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[v0 % 8 * 4 + v1 // 2, v0 // 8 * 2 + v1 % 2]


@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp")
C = T.match_buffer(c, [16, 8], dtype="float32", scope="global")

with T.block("root"):
T.reads(C_warp[0:32, 0:4])
T.writes(C[0:16, 0:8])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)

T.evaluate(T.mma_store("m16n8", C.data, C.elem_offset, C_warp.access_ptr("r"), tx, dtype="float32"))


tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)

N = 4096
M = 4096
@@ -201,8 +219,12 @@ def schedule(sch: tir.Schedule):
k, k_tc = sch.split(k, factors=[None, 8])

sch.reorder(
i, j, k,
i_tc, j_tc, k_tc,
i,
j,
k,
i_tc,
j_tc,
k_tc,
)
block_inner = sch.blockize(i_tc)

@@ -224,16 +246,21 @@ def schedule(sch: tir.Schedule):
k0, k1, k2 = sch.split(k, k_factors)

sch.reorder(
i0, j0, # S => blockIdx.x
i1, j1, # S => blockIdx.y
j2, i2, # S => threadIdx.y
i0,
j0, # S => blockIdx.x
i1,
j1, # S => blockIdx.y
j2,
i2, # S => threadIdx.y
# cache_write here
k0, # R
k0, # R
# vectorized cooperative fetching here
k1, # R
i3, j3, # S
k2, # R
i4, j4,
k1, # R
i3,
j3, # S
k2, # R
i4,
j4,
# S
)

@@ -250,9 +277,7 @@ def fetch_to_shared(block, idx, ndim):
vector_size = 8
warp_size = 32
fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
f_0, f_1, f_2, f_3 = sch.split(
fused, factors=[None, num_ty, warp_size, vector_size]
)
f_0, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
sch.bind(f_2, "threadIdx.x")
sch.bind(f_1, "threadIdx.y")
sch.vectorize(f_3)
@@ -326,8 +351,9 @@ def lambda_b(i, j):
)

if use_ldmatrix:
sch.tensorize(loop_a, "mma.ldmatrix_a")
sch.tensorize(loop_b, "mma.ldmatrix_b")
# sch.tensorize(loop_a, "mma.ldmatrix_a")
# sch.tensorize(loop_b, "mma.ldmatrix_b")
pass
else:
warp_loop1, warp_loop2 = sch.get_loops(A_warp)[-2:]
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
@@ -361,13 +387,19 @@ def lambda_b(i, j):
sch.reorder(f_1, f_2, f_0, f_3)
fused_1 = sch.fuse(f_1, f_2)
fused_2 = sch.fuse(f_0, f_3)
sch.bind(fused_1, "threadIdx.x")

# print(sch.mod.script())

# return

sch.tensorize(fused_1, "mma_store")
# sch.bind(fused_1, "threadIdx.x")


ir_module = tvm.IRModule({"main": workload})
sch = tvm.tir.Schedule(ir_module)
schedule(sch)
# print(sch.mod.script())
print(sch.mod.script())

if tune:
with tempfile.TemporaryDirectory() as work_dir:
@@ -389,24 +421,25 @@ def lambda_b(i, j):
print(sch.mod.script())
print(sch.trace)
else:
print(sch.mod.script())
target = "cuda"
f = tvm.build(sch.mod["main"], target=target, name="dense")
print(f.imported_modules[0].get_source())

dev = tvm.device("cuda", 0)
a_np = np.random.uniform(size=(N, K)).astype("float16")
b_np = np.random.uniform(size=(K, M)).astype("float16")
c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
f = tvm.build(sch.mod["main"], target="cuda", name="dense")

f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
print("ok")

evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
gflops = (N * M * K) * 2 / 1e9
time_ms = evaluator(a, b, c).mean * 1e3
print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
# dev = tvm.device("cuda", 0)
# a_np = np.random.uniform(size=(N, K)).astype("float16")
# b_np = np.random.uniform(size=(K, M)).astype("float16")
# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
# a = tvm.nd.array(a_np, dev)
# b = tvm.nd.array(b_np, dev)
# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
# f = tvm.build(sch.mod["main"], target="cuda", name="dense")

# f(a, b, c)
# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
# print("ok")

# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
# gflops = (N * M * K) * 2 / 1e9
# time_ms = evaluator(a, b, c).mean * 1e3
# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
