generate mma fill/store
masahi committed May 17, 2022
1 parent bf23fc5 commit 2b05b5a
Showing 7 changed files with 134 additions and 460 deletions.
111 changes: 110 additions & 1 deletion python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for tensorization on NVIDIA GPU."""
from .. import Cast
from .. import IntImm, Cast
from ..._ffi import register_func
from ...runtime import convert
from .. import TensorIntrin
@@ -315,6 +315,97 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
return mma_sync_desc, mma_sync_impl


def get_mma_fill_intrin(dtype, local_size):
zero = IntImm("int32", 0).astype(dtype)

# Assume M = N = 16
index_map = shared_16x16_to_ldmatrix_32x8_layout

@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")

with T.block("root"):
T.reads()
T.writes(C_warp[0:WARP_SIZE, 0:local_size])
for i0, i1 in T.grid(M_DIM, N_DIM):
with T.block("C_warp"):
i, j = T.axis.remap("SS", [i0, i1])
thread_id, local_id = index_map(i, j)
T.reads()
T.writes(C_warp[thread_id, local_id])
C_warp[thread_id, local_id] = zero

@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
C_warp = T.match_buffer(
a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
)

with T.block("root"):
T.reads()
T.writes(C_warp[0:WARP_SIZE, 0:local_size])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, WARP_SIZE)

T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))

return mma_fill_desc, mma_fill_impl
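
# Context note, not part of this diff: the fill descriptor above and the store
# descriptor below both route their per-element accesses through
# shared_16x16_to_ldmatrix_32x8_layout, which is defined earlier in cuda.py and
# is not touched by this commit. A sketch of the mapping these intrinsics are
# assumed to use (the helper already in the file is authoritative if the
# arithmetic here differs): it maps a (row, col) coordinate of a 16x16 tile to
# a (thread_id, local_id) pair in the 32x8 warp-register layout produced by
# ldmatrix / consumed by mma.sync.
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    local_id = 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
    return thread_id, local_id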


def get_mma_store_intrin(dtype, local_size):
# Assume M = N = 16
index_map = shared_16x16_to_ldmatrix_32x8_layout

@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope="global")

with T.block("root"):
T.reads(C_warp[0:WARP_SIZE, 0:local_size])
T.writes(C[0:M_DIM, 0:N_DIM])
for i0, i1 in T.grid(M_DIM, N_DIM):
with T.block("C_warp"):
v0, v1 = T.axis.remap("SS", [i0, i1])
thread_id, local_id = index_map(v0, v1)
T.reads(C_warp[thread_id, local_id])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[thread_id, local_id]

@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
s0 = T.var("int32")
s1 = T.var("int32")

C_warp = T.match_buffer(
a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
)
C = T.match_buffer(
c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1]
)

with T.block("root"):
T.reads(C_warp[0:WARP_SIZE, 0:local_size])
T.writes(C[0:M_DIM, 0:N_DIM])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, WARP_SIZE)

T.evaluate(
T.mma_store(
M_DIM,
N_DIM,
C.access_ptr("w"),
C_warp.data,
C_warp.elem_offset,
s0,
dtype=dtype,
)
)

return mma_store_desc, mma_store_impl
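
# Context note, not part of this diff: every registration below passes
# local_size=8. That value is not an independent knob: a 16x16 accumulator tile
# spread over a 32-thread warp leaves 16*16/32 = 8 elements per thread, which is
# also the inner extent of the [WARP_SIZE, local_size] warp-scope buffers above.
# A one-line sanity check of that arithmetic (constants assumed to match the
# module-level M_DIM, N_DIM and WARP_SIZE):
M_DIM, N_DIM, WARP_SIZE = 16, 16, 32
assert M_DIM * N_DIM // WARP_SIZE == 8  # the local_size used in every registration below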


LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))

@@ -352,3 +443,21 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:

MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))

MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8))

MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8))

MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))

MMA_store_16x16_f32_INTRIN = "mma_store_16x16_f32"
TensorIntrin.register(MMA_store_16x16_f32_INTRIN, *get_mma_store_intrin("float32", 8))

MMA_store_16x16_f16_INTRIN = "mma_store_16x16_f16"
TensorIntrin.register(MMA_store_16x16_f16_INTRIN, *get_mma_store_intrin("float16", 8))

MMA_store_16x16_i32_INTRIN = "mma_store_16x16_i32"
TensorIntrin.register(MMA_store_16x16_i32_INTRIN, *get_mma_store_intrin("int32", 8))
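
# Context note, not part of this diff: the registered names above replace the
# local "mma_fill" / "mma_store" intrinsics that the tests below previously
# defined and registered by hand. A hypothetical helper sketching how a
# schedule refers to them, mirroring what the updated tests do inline
# (the block handles come from the caller's schedule):
from tvm import tir
from tvm.tir.tensor_intrin.cuda import (
    MMA_fill_16x16_f32_INTRIN,
    MMA_store_16x16_f32_INTRIN,
)


def tensorize_fill_and_store(sch: tir.Schedule, block_init_c, c_warp_block) -> None:
    # Tensorize the accumulator-init block and the warp->global copy block
    # with the registered fp32 fill/store intrinsics.
    sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f32_INTRIN)
    sch.tensorize(sch.get_loops(c_warp_block)[-2], MMA_store_16x16_f32_INTRIN)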
80 changes: 4 additions & 76 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -1,91 +1,19 @@
import tempfile
import tvm
from tvm.script import tir as T
import tvm.meta_schedule.testing.te_workload as te_workload
from tvm import te, tir
from tvm import meta_schedule as ms
from tvm.tir.tensor_intrin.cuda import (
LDMATRIX_16x16_A_INTRIN,
LDMATRIX_16x16_B_INTRIN,
MMA_f16f16f32_INTRIN,
MMA_fill_16x16_f32_INTRIN,
MMA_store_16x16_f32_INTRIN,
shared_16x16_to_ldmatrix_32x8_layout,
)
import tvm.testing
import numpy as np


@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")

with T.block("root"):
T.reads(C_warp[0:32, 0:8])
T.writes(C[0:16, 0:16])
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
v0, v1 = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.reads(C_warp[thread_id, local_id])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[thread_id, local_id]


@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
s1 = T.var("int32")
s0 = T.var("int32")

C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
C = T.match_buffer(
c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
)

with T.block("root"):
T.reads(C_warp[0:32, 0:8])
T.writes(C[0:16, 0:16])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)

T.evaluate(
T.mma_store(
16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
)
)


@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")

with T.block("root"):
T.reads()
T.writes(C_warp[0:32, 0:8])
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
i_init, j_init = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
T.reads()
T.writes(C_warp[thread_id, local_id])
C_warp[thread_id, local_id] = T.float32(0)


@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)

with T.block("root"):
T.reads()
T.writes(C_warp[0:32, 0:8])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)

T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))


tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)

N = 4096
M = 4096
K = 4096
@@ -214,8 +142,8 @@ def index_map(i, j):
sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_INTRIN)
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f32_INTRIN)
sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f32_INTRIN)


ir_module = tvm.IRModule({"main": workload})
81 changes: 4 additions & 77 deletions tests/python/unittest/test_mma_16x8x16_4k_tune_trans.py
@@ -1,91 +1,18 @@
import tempfile
import tvm
from tvm.script import tir as T
import tvm.meta_schedule.testing.te_workload as te_workload
from tvm import te, tir
from tvm.tir.tensor_intrin.cuda import (
LDMATRIX_16x16_A_INTRIN,
LDMATRIX_16x16_B_TRANS_INTRIN,
MMA_f16f16f32_TRANS_INTRIN,
MMA_fill_16x16_f32_INTRIN,
MMA_store_16x16_f32_INTRIN,
shared_16x16_to_ldmatrix_32x8_layout,
)
from tvm import meta_schedule as ms
import tvm.testing
import numpy as np


@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
C = T.match_buffer(c, [16, 16], dtype="float32", scope="global")

with T.block("root"):
T.reads(C_warp[0:32, 0:8])
T.writes(C[0:16, 0:16])
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
v0, v1 = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.reads(C_warp[thread_id, local_id])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[thread_id, local_id]


@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
s1 = T.var("int32")
s0 = T.var("int32")

C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)
C = T.match_buffer(
c, [16, 16], dtype="float32", scope="global", offset_factor=1, strides=[s1, s0]
)

with T.block("root"):
T.reads(C_warp[0:32, 0:8])
T.writes(C[0:16, 0:16])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)

T.evaluate(
T.mma_store(
16, 16, C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"
)
)


@T.prim_func
def mma_fill_desc(a: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")

with T.block("root"):
T.reads()
T.writes(C_warp[0:32, 0:8])
for i0, i1 in T.grid(16, 16):
with T.block("C_warp"):
i_init, j_init = T.axis.remap("SS", [i0, i1])
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init)
T.reads()
T.writes(C_warp[thread_id, local_id])
C_warp[thread_id, local_id] = T.float32(0)


@T.prim_func
def mma_fill_impl(a: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp", offset_factor=1)

with T.block("root"):
T.reads()
T.writes(C_warp[0:32, 0:8])
tx = T.env_thread("threadIdx.x")
T.launch_thread(tx, 32)

T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))


tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)

N = 4096
M = 4096
K = 4096
@@ -231,8 +158,8 @@ def index_map(i, j):
sch.tensorize(loop_b, LDMATRIX_16x16_B_TRANS_INTRIN)
sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_TRANS_INTRIN)

sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")
sch.tensorize(sch.get_loops(block_init_c)[-2], MMA_fill_16x16_f32_INTRIN)
sch.tensorize(sch.get_loops(C_warp)[-2], MMA_store_16x16_f32_INTRIN)


ir_module = tvm.IRModule({"main": workload})
(Remaining changed files not loaded.)
