[TIR] Add int8 CUDA tensor core intrinsics (apache#12354)
vinx13 authored and Mikael Sevenier committed Aug 12, 2022
1 parent 2eec5bc commit d8c904a
Showing 2 changed files with 66 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/default_config.py
@@ -364,10 +364,11 @@ def schedule_rules():
                 intrin_groups=[
                     get_wmma_intrin_group(
                         store_scope="shared",
-                        in_dtype="float16",
-                        out_dtype="float16",
+                        in_dtype=in_dtype,
+                        out_dtype=out_dtype,
                         trans_b=trans_b,
                     )
+                    for (in_dtype, out_dtype) in [("float16", "float16"), ("int8", "int32")]
                     for trans_b in [False, True]
                 ],
                 structure="SSSRRSRS",
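For reference, and not part of the commit itself, a small Python sketch of the four intrinsic-group requests that the updated comprehension in schedule_rules() now enumerates (itertools is used here purely for illustration):

# Illustration only: enumerate the (in_dtype, out_dtype, trans_b) combinations
# that the comprehension above passes to get_wmma_intrin_group.
from itertools import product

dtype_pairs = [("float16", "float16"), ("int8", "int32")]
for (in_dtype, out_dtype), trans_b in product(dtype_pairs, (False, True)):
    print(f"in={in_dtype} out={out_dtype} trans_b={trans_b}")
# in=float16 out=float16 trans_b=False
# in=float16 out=float16 trans_b=True
# in=int8 out=int32 trans_b=False
# in=int8 out=int32 trans_b=True

So alongside the existing float16 groups, the default CUDA tensor-core config now also requests the int8/int32 groups registered in cuda.py below.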
65 changes: 63 additions & 2 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -757,6 +757,18 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     *get_wmma_sync_intrin(16, 16, 16, "float16", "float16", True),
 )
 
+WMMA_SYNC_16x16x16_s8s8s32_INTRIN = "wmma_sync_16x16x16_s8s8s32"
+TensorIntrin.register(
+    WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
+    *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", False),
+)
+
+WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN = "wmma_sync_16x16x16_s8s8s32_trans"
+TensorIntrin.register(
+    WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN,
+    *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
+)
+
 WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_A_INTRIN,
@@ -781,12 +793,40 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
 )
 
+WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_A_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, False),
+)
+
+WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_B_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False),
+)
+
+WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, True),
+)
+
+WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, True),
+)
+
+
 WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
 TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32"))
 
 WMMA_FILL_16x16x16_F16_INTRIN = "wmma_fill_16x16x16_f16"
 TensorIntrin.register(WMMA_FILL_16x16x16_F16_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float16"))
 
+WMMA_FILL_16x16x16_S32_INTRIN = "wmma_fill_16x16x16_s32"
+TensorIntrin.register(WMMA_FILL_16x16x16_S32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "int32"))
+
 WMMA_STORE_16x16x16_F32_SHARED_INTRIN = "wmma_store_16x16x16_f32_shared"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared")
@@ -797,6 +837,11 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     WMMA_STORE_16x16x16_F16_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "shared")
 )
 
+WMMA_STORE_16x16x16_S32_SHARED_INTRIN = "wmma_store_16x16x16_s32_shared"
+TensorIntrin.register(
+    WMMA_STORE_16x16x16_S32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "shared")
+)
+
 WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global")
@@ -807,6 +852,11 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "global")
 )
 
+WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN = "wmma_store_16x16x16_s32_global"
+TensorIntrin.register(
+    WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "global")
+)
+
 
 def get_wmma_intrin_group(
     store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
@@ -836,11 +886,15 @@ def get_wmma_intrin_group(
     assert in_dtype in ["float16"]
     assert out_dtype in ["float16", "float32"]
 
-    load_a_intrins = {"float16": WMMA_LOAD_16x16x16_F16_A_INTRIN}
+    load_a_intrins = {
+        "float16": WMMA_LOAD_16x16x16_F16_A_INTRIN,
+        "int8": WMMA_LOAD_16x16x16_S8_A_INTRIN,
+    }
     load_b_intrins = {
         "float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
         if trans_b
-        else WMMA_LOAD_16x16x16_F16_B_INTRIN
+        else WMMA_LOAD_16x16x16_F16_B_INTRIN,
+        "int8": WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN if trans_b else WMMA_LOAD_16x16x16_S8_B_INTRIN,
     }
     compute_intrins = {
         "float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
@@ -849,10 +903,14 @@ def get_wmma_intrin_group(
         "float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
         if trans_b
         else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+        "int32": WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN
+        if trans_b
+        else WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
     }
     init_intrins = {
         "float16": WMMA_FILL_16x16x16_F16_INTRIN,
         "float32": WMMA_FILL_16x16x16_F32_INTRIN,
+        "int32": WMMA_FILL_16x16x16_S32_INTRIN,
     }
     store_intrins = {
         "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
@@ -861,6 +919,9 @@ def get_wmma_intrin_group(
         "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
         if store_scope == "shared"
        else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
+        "int32": WMMA_STORE_16x16x16_S32_SHARED_INTRIN
+        if store_scope == "shared"
+        else WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN,
     }
     return {
         "init": init_intrins[out_dtype],
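As a usage note, and not part of the diff, here is a minimal sketch of how the newly registered int8 intrinsics can be looked up by name; TensorIntrin.register/TensorIntrin.get are the same registry APIs used in the file above, while the tensorize call in the final comment is only indicative:

# Minimal sketch: retrieve one of the int8 WMMA intrinsics registered above.
import tvm.tir.tensor_intrin.cuda  # noqa: F401  (importing runs the TensorIntrin.register calls)
from tvm.tir import TensorIntrin

sync_intrin = TensorIntrin.get("wmma_sync_16x16x16_s8s8s32_trans")
print(sync_intrin)

# In a TensorIR schedule, the same string name is what gets passed to tensorize, e.g.:
#     sch.tensorize(loop_rv, "wmma_sync_16x16x16_s8s8s32_trans")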
