diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index dc021e17316c..e27b6ad4b4ab 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -364,10 +364,11 @@ def schedule_rules():
                 intrin_groups=[
                     get_wmma_intrin_group(
                         store_scope="shared",
-                        in_dtype="float16",
-                        out_dtype="float16",
+                        in_dtype=in_dtype,
+                        out_dtype=out_dtype,
                         trans_b=trans_b,
                     )
+                    for (in_dtype, out_dtype) in [("float16", "float16"), ("int8", "int32")]
                     for trans_b in [False, True]
                 ],
                 structure="SSSRRSRS",
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index e7d5defcf321..4ac9338ba86c 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -757,6 +757,18 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     *get_wmma_sync_intrin(16, 16, 16, "float16", "float16", True),
 )
 
+WMMA_SYNC_16x16x16_s8s8s32_INTRIN = "wmma_sync_16x16x16_s8s8s32"
+TensorIntrin.register(
+    WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
+    *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", False),
+)
+
+WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN = "wmma_sync_16x16x16_s8s8s32_trans"
+TensorIntrin.register(
+    WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN,
+    *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
+)
+
 WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_A_INTRIN,
@@ -781,12 +793,40 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
 )
 
+WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_A_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, False),
+)
+
+WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_B_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False),
+)
+
+WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, True),
+)
+
+WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, True),
+)
+
+
 WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
 TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32"))
 
 WMMA_FILL_16x16x16_F16_INTRIN = "wmma_fill_16x16x16_f16"
 TensorIntrin.register(WMMA_FILL_16x16x16_F16_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float16"))
 
+WMMA_FILL_16x16x16_S32_INTRIN = "wmma_fill_16x16x16_s32"
+TensorIntrin.register(WMMA_FILL_16x16x16_S32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "int32"))
+
 WMMA_STORE_16x16x16_F32_SHARED_INTRIN = "wmma_store_16x16x16_f32_shared"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared")
@@ -797,6 +837,11 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     WMMA_STORE_16x16x16_F16_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "shared")
 )
 
+WMMA_STORE_16x16x16_S32_SHARED_INTRIN = "wmma_store_16x16x16_s32_shared"
+TensorIntrin.register(
+    WMMA_STORE_16x16x16_S32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "shared")
+)
+
 WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global")
@@ -807,6 +852,11 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
     WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "global")
 )
 
+WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN = "wmma_store_16x16x16_s32_global"
+TensorIntrin.register(
+    WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "global")
+)
+
 
 def get_wmma_intrin_group(
     store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
@@ -836,11 +886,15 @@ def get_wmma_intrin_group(
-    assert in_dtype in ["float16"]
-    assert out_dtype in ["float16", "float32"]
+    assert in_dtype in ["float16", "int8"]
+    assert out_dtype in ["float16", "float32", "int32"]
 
-    load_a_intrins = {"float16": WMMA_LOAD_16x16x16_F16_A_INTRIN}
+    load_a_intrins = {
+        "float16": WMMA_LOAD_16x16x16_F16_A_INTRIN,
+        "int8": WMMA_LOAD_16x16x16_S8_A_INTRIN,
+    }
     load_b_intrins = {
         "float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
         if trans_b
-        else WMMA_LOAD_16x16x16_F16_B_INTRIN
+        else WMMA_LOAD_16x16x16_F16_B_INTRIN,
+        "int8": WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN if trans_b else WMMA_LOAD_16x16x16_S8_B_INTRIN,
     }
     compute_intrins = {
         "float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
@@ -849,10 +903,14 @@ def get_wmma_intrin_group(
         "float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
         if trans_b
         else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+        "int32": WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN
+        if trans_b
+        else WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
     }
     init_intrins = {
         "float16": WMMA_FILL_16x16x16_F16_INTRIN,
         "float32": WMMA_FILL_16x16x16_F32_INTRIN,
+        "int32": WMMA_FILL_16x16x16_S32_INTRIN,
     }
     store_intrins = {
         "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
@@ -861,6 +919,9 @@ def get_wmma_intrin_group(
         "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
         if store_scope == "shared"
         else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
+        "int32": WMMA_STORE_16x16x16_S32_SHARED_INTRIN
+        if store_scope == "shared"
+        else WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN,
     }
     return {
         "init": init_intrins[out_dtype],
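
For reference, a minimal usage sketch of the new int8 path. This is not part of the patch: it assumes the patch is applied, and the `"init"` key follows the (truncated) `return` block at the end of the diff above.

    # Importing the module executes the TensorIntrin.register(...) calls above,
    # making the new s8s8s32 intrins available by name.
    from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group

    # Request the int8 -> int32 intrin group the same way default_config.py now does.
    group = get_wmma_intrin_group(
        store_scope="shared", in_dtype="int8", out_dtype="int32", trans_b=False
    )
    # init_intrins["int32"] resolves to the registered name "wmma_fill_16x16x16_s32".
    assert group["init"] == "wmma_fill_16x16x16_s32"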