diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index e25f0b12210d..938d39377f1c 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -172,6 +172,22 @@ Array ScheduleRule::DefaultCUDA() { Array ScheduleRule::DefaultCUDATensorCore() { Array> intrin_groups = { + // Tensor Cores f32 += f16 * f16 + { + {"init", "wmma_fill_16x16x16_f32"}, + {"load_a", "wmma_load_16x16x16_f16_a"}, + {"load_b", "wmma_load_16x16x16_f16_b"}, + {"compute", "wmma_sync_16x16x16_f16f16f32"}, + {"store", "wmma_store_16x16x16_f32_shared"}, + }, + { + {"init", "wmma_fill_16x16x16_f32"}, + {"load_a", "wmma_load_16x16x16_f16_a"}, + {"load_b", "wmma_load_16x16x16_f16_b_trans"}, + {"compute", "wmma_sync_16x16x16_f16f16f32_trans"}, + {"store", "wmma_store_16x16x16_f32_shared"}, + }, + // Tensor Cores f16 += f16 * f16 { {"init", "wmma_fill_16x16x16_f16"}, {"load_a", "wmma_load_16x16x16_f16_a"}, @@ -186,6 +202,7 @@ Array ScheduleRule::DefaultCUDATensorCore() { {"compute", "wmma_sync_16x16x16_f16f16f16_trans"}, {"store", "wmma_store_16x16x16_f16_shared"}, }, + // Tensor Cores s32 += s8 * s8 { {"init", "wmma_fill_16x16x16_s32"}, {"load_a", "wmma_load_16x16x16_s8_a"},