[MetaSchedule] Use shared.dyn for Tensor Core Schedule Rules (#13891)
This PR adds Tensor Core intrinsics with `shared.dyn` scope and changes the default rules to use `shared.dyn`.
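For context, `shared.dyn` is the TIR storage scope for dynamically allocated (rather than statically sized) shared memory. Below is a minimal, illustrative sketch (not taken from this PR; the fp16 GEMM prim func is a hypothetical example) of requesting that scope for a shared-memory cache stage:

```python
# Illustrative sketch only: the matmul prim func is a hypothetical example,
# not code from this PR. It shows the "shared.dyn" storage scope being used
# for a shared-memory cache stage in a TIR schedule.
import tvm
from tvm.script import tir as T


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (1024, 1024), "float16")
    B = T.match_buffer(b, (1024, 1024), "float16")
    C = T.match_buffer(c, (1024, 1024), "float32")
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk].astype("float32") * B[vk, vj].astype("float32")


sch = tvm.tir.Schedule(matmul)
block = sch.get_block("C")
# "shared.dyn" is the memory scope this PR switches the default tensor core rules to.
A_shared = sch.cache_read(block, read_buffer_index=0, storage_scope="shared.dyn")
```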

Here are the performance improvements for GEMM 1024x1024x1024 on my device (RTX 3080):

|            | Use `shared`      | Use `shared.dyn`  | Speedup |
| ---------- | ----------------- | ----------------- | ------- |
| fp16-16-16 | 66399.8766 GFLOPS | 71778.3808 GFLOPS | 8.1%    |
| fp16-16-32 | 44292.5893 GFLOPS | 49070.2514 GFLOPS | 10.8%   |
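
The speedup column follows directly from the throughput numbers (speedup = new / old - 1); a quick check:

```python
# Verify the speedup column from the GFLOPS numbers above.
for old, new in [(66399.8766, 71778.3808), (44292.5893, 49070.2514)]:
    print(f"{(new / old - 1) * 100:.1f}%")  # 8.1%, 10.8%
```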

cc @vinx13 @junrushao @masahi
Hzfengsy authored Feb 1, 2023
1 parent d6f78b1 commit 9bbc2c0
Showing 5 changed files with 198 additions and 119 deletions.
166 changes: 112 additions & 54 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -18,6 +18,8 @@
"""Intrinsics for tensorization on NVIDIA GPU."""
from typing import Dict, Tuple

from typing_extensions import Literal

from tvm.script import tir as T
from tvm.tir.function import PrimFunc

@@ -815,54 +817,101 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
*get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
)

WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a"
WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, False),
)

WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b"
WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN = "wmma_load_16x16x16_f16_a_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, False),
)

WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False),
)

WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans"
WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN = "wmma_load_16x16x16_f16_b_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, False),
)

WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True),
)

WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans"
WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN = "wmma_load_16x16x16_f16_a_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, True),
)

WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
)

WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a"
WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN = "wmma_load_16x16x16_f16_b_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, True),
)

WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, False),
)

WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b"
WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN = "wmma_load_16x16x16_s8_a_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, False),
)

WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False),
)

WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans"
WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN = "wmma_load_16x16x16_s8_b_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, False),
)

WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, True),
)

WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans"
WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN = "wmma_load_16x16x16_s8_a_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, True),
)

WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, True),
)

WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN = "wmma_load_16x16x16_s8_b_trans_shared_dyn"
TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True),
)

WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32"))
@@ -878,16 +927,34 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared")
)

WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN = "wmma_store_16x16x16_f32_shared_dyn"
TensorIntrin.register(
WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN,
*get_wmma_store_intrin(16, 16, 16, "float32", "shared.dyn"),
)

WMMA_STORE_16x16x16_F16_SHARED_INTRIN = "wmma_store_16x16x16_f16_shared"
TensorIntrin.register(
WMMA_STORE_16x16x16_F16_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "shared")
)

WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN = "wmma_store_16x16x16_f16_shared_dyn"
TensorIntrin.register(
WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN,
*get_wmma_store_intrin(16, 16, 16, "float16", "shared.dyn"),
)

WMMA_STORE_16x16x16_S32_SHARED_INTRIN = "wmma_store_16x16x16_s32_shared"
TensorIntrin.register(
WMMA_STORE_16x16x16_S32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "shared")
)

WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN = "wmma_store_16x16x16_s32_shared_dyn"
TensorIntrin.register(
WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN,
*get_wmma_store_intrin(16, 16, 16, "int32", "shared.dyn"),
)

WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
TensorIntrin.register(
WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global")
@@ -905,14 +972,21 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:


def get_wmma_intrin_group(
store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
load_scope: Literal["shared", "shared.dyn"],
store_scope: Literal["global", "shared", "shared.dyn"],
in_dtype: str,
out_dtype: str,
trans_b: bool,
) -> Dict[str, str]:
"""Get a group of intrinsics for wmma tensor core with the given configurations
Parameters
----------
store_scope : str
Must be one of ["global", "shared"]. The memory scope of the result buffer.
load_scope : Literal["shared", "shared.dyn"]
The memory scope of the input buffer.
store_scope : Literal["global", "shared", "shared.dyn"]
The memory scope of the result buffer.
in_dtype : str
The input data type.
@@ -928,51 +1002,35 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
ret : Dict[str, str]
A group of tensor intrinsics.
"""
assert store_scope in ["global", "shared"]
assert load_scope in ["shared", "shared.dyn"]
assert store_scope in ["global", "shared", "shared.dyn"]
assert in_dtype in ["float16", "int8"]
assert out_dtype in ["float16", "float32", "int32"]

load_a_intrins = {
"float16": WMMA_LOAD_16x16x16_F16_A_INTRIN,
"int8": WMMA_LOAD_16x16x16_S8_A_INTRIN,
}
load_b_intrins = {
"float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
if trans_b
else WMMA_LOAD_16x16x16_F16_B_INTRIN,
"int8": WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN if trans_b else WMMA_LOAD_16x16x16_S8_B_INTRIN,
}
compute_intrins = {
"float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_f16f16f16_INTRIN,
"float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
"int32": WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
}
init_intrins = {
"float16": WMMA_FILL_16x16x16_F16_INTRIN,
"float32": WMMA_FILL_16x16x16_F32_INTRIN,
"int32": WMMA_FILL_16x16x16_S32_INTRIN,
}
store_intrins = {
"float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
"float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
"int32": WMMA_STORE_16x16x16_S32_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN,
}
shape = "16x16x16"
in_dtype = "f16" if in_dtype == "float16" else "s8"
out_dtype = "f16" if out_dtype == "float16" else "f32" if out_dtype == "float32" else "s32"
# convert "shared.dyn" to "shared_dyn"
load_scope = load_scope.replace(".", "_")
store_scope = store_scope.replace(".", "_")
trans_a = ""
trans_b = "_trans" if trans_b else ""

# e.g. wmma_load_16x16x16_f16_a_shared
load_a_intrin = f"wmma_load_{shape}_{in_dtype}_a{trans_a}_{load_scope}"
# e.g. wmma_load_16x16x16_f16_b_trans_shared_dyn
load_b_intrin = f"wmma_load_{shape}_{in_dtype}_b{trans_b}_{load_scope}"
# e.g. wmma_sync_16x16x16_f16f16f32_trans
compute_intrin = f"wmma_sync_{shape}_{in_dtype}{in_dtype}{out_dtype}{trans_b}"
# e.g. wmma_fill_16x16x16_f16
init_intrin = f"wmma_fill_{shape}_{out_dtype}"
# e.g. wmma_store_16x16x16_f16_shared_dyn
store_intrin = f"wmma_store_{shape}_{out_dtype}_{store_scope}"

return {
"init": init_intrins[out_dtype],
"load_a": load_a_intrins[in_dtype],
"load_b": load_b_intrins[in_dtype],
"compute": compute_intrins[out_dtype],
"store": store_intrins[out_dtype],
"init": init_intrin,
"load_a": load_a_intrin,
"load_b": load_b_intrin,
"compute": compute_intrin,
"store": store_intrin,
}
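
A minimal usage sketch (not part of this diff) of the updated helper, requesting the `shared.dyn` variants for an f32 += f16 * f16 schedule. The expected names follow from the naming scheme above, assuming the module import path matches the file location:

```python
# Usage sketch for the new load_scope/store_scope parameters (not part of this diff).
from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group

group = get_wmma_intrin_group(
    load_scope="shared.dyn",
    store_scope="shared.dyn",
    in_dtype="float16",
    out_dtype="float32",
    trans_b=True,
)
# Expected result, per the f-string naming scheme above:
#   group["init"]    == "wmma_fill_16x16x16_f32"
#   group["load_a"]  == "wmma_load_16x16x16_f16_a_shared_dyn"
#   group["load_b"]  == "wmma_load_16x16x16_f16_b_trans_shared_dyn"
#   group["compute"] == "wmma_sync_16x16x16_f16f16f32_trans"
#   group["store"]   == "wmma_store_16x16x16_f32_shared_dyn"
```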
40 changes: 20 additions & 20 deletions src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -175,47 +175,47 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
// Tensor Cores f32 += f16 * f16
{
{"init", "wmma_fill_16x16x16_f32"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f32"},
{"store", "wmma_store_16x16x16_f32_shared"},
{"store", "wmma_store_16x16x16_f32_shared_dyn"},
},
{
{"init", "wmma_fill_16x16x16_f32"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b_trans"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f32_trans"},
{"store", "wmma_store_16x16x16_f32_shared"},
{"store", "wmma_store_16x16x16_f32_shared_dyn"},
},
// Tensor Cores f16 += f16 * f16
{
{"init", "wmma_fill_16x16x16_f16"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f16"},
{"store", "wmma_store_16x16x16_f16_shared"},
{"store", "wmma_store_16x16x16_f16_shared_dyn"},
},
{
{"init", "wmma_fill_16x16x16_f16"},
{"load_a", "wmma_load_16x16x16_f16_a"},
{"load_b", "wmma_load_16x16x16_f16_b_trans"},
{"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
{"compute", "wmma_sync_16x16x16_f16f16f16_trans"},
{"store", "wmma_store_16x16x16_f16_shared"},
{"store", "wmma_store_16x16x16_f16_shared_dyn"},
},
// Tensor Cores s32 += s8 * s8
{
{"init", "wmma_fill_16x16x16_s32"},
{"load_a", "wmma_load_16x16x16_s8_a"},
{"load_b", "wmma_load_16x16x16_s8_b"},
{"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_s8_b_shared_dyn"},
{"compute", "wmma_sync_16x16x16_s8s8s32"},
{"store", "wmma_store_16x16x16_s32_shared"},
{"store", "wmma_store_16x16x16_s32_shared_dyn"},
},
{
{"init", "wmma_fill_16x16x16_s32"},
{"load_a", "wmma_load_16x16x16_s8_a"},
{"load_b", "wmma_load_16x16x16_s8_b_trans"},
{"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
{"load_b", "wmma_load_16x16x16_s8_b_trans_shared_dyn"},
{"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
{"store", "wmma_store_16x16x16_s32_shared"},
{"store", "wmma_store_16x16x16_s32_shared_dyn"},
},
};
Array<ScheduleRule> results{
@@ -229,11 +229,11 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
/*reuse_read=*/
Map<String, ObjectRef>{{"req", String("must")},
{"levels", Array<Integer>{4}}, //
{"scope", String("shared")}},
{"scope", String("shared.dyn")}},
/*reuse_write=*/
Map<String, ObjectRef>{{"req", String("must")},
{"levels", Array<Integer>{2}}, //
{"scope", String("shared")}},
{"scope", String("shared.dyn")}},
/*use_software_pipeline=*/false) //
};
Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
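
For reference, a rough Python sketch of building an equivalent rule with the new defaults. It assumes the `tvm.meta_schedule.schedule_rule` wrappers (`MultiLevelTilingTensorCore`, `ReuseType`) mirror the C++ constructor used above; the tiling structure, thread bindings, and vector-load lengths are illustrative placeholders since they are not shown in this hunk:

```python
# Rough sketch, not part of this diff. Only the intrin group names, the
# "shared.dyn" reuse scopes, and use_software_pipeline=False come from the
# changes above; the remaining arguments are illustrative placeholders.
from tvm.meta_schedule.schedule_rule import MultiLevelTilingTensorCore, ReuseType

rule = MultiLevelTilingTensorCore(
    intrin_groups=[
        {
            "init": "wmma_fill_16x16x16_f32",
            "load_a": "wmma_load_16x16x16_f16_a_shared_dyn",
            "load_b": "wmma_load_16x16x16_f16_b_shared_dyn",
            "compute": "wmma_sync_16x16x16_f16f16f32",
            "store": "wmma_store_16x16x16_f32_shared_dyn",
        },
    ],
    structure="SSSRRSRS",                                    # placeholder
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],  # placeholder
    max_innermost_factor=4,                                  # placeholder
    vector_load_lens=[1, 2, 3, 4, 8, 16],                    # placeholder
    reuse_read=ReuseType(req="must", levels=[4], scope="shared.dyn"),
    reuse_write=ReuseType(req="must", levels=[2], scope="shared.dyn"),
    use_software_pipeline=False,
)
```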