diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index edf13e36bef4..6b34f69bc0b1 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -21,6 +21,68 @@ #include "../utils.h" namespace tvm { +namespace tir { +class ThreadExtentChecker : private StmtVisitor { + public: + static bool Check(const Stmt& stmt) { + try { + ThreadExtentChecker().VisitStmt(stmt); + return true; + } catch (const dmlc::Error& e) { + return false; + } + } + + private: + void VisitStmt_(const ForNode* loop) { + if (IsThreadIdx(GetThreadScope(loop))) { + const std::string& thread_tag = loop->thread_binding.value()->thread_tag; + if (const int64_t* p_ext = GetLoopIntExtent(loop)) { + auto it = thread_tag2extent_.find(thread_tag); + bool new_thread = it == thread_tag2extent_.end(); + if (new_thread) { + thread_extent_product *= *p_ext; + thread_tag2extent_[thread_tag] = *p_ext; + } else { + CHECK_EQ(it->second, *p_ext) + << "ValueError: All loops that are bound to `" << thread_tag + << "` should have the same extent. However, there are two loops with extent " + << it->second << " and " << p_ext << ", which are not equal"; + } + StmtVisitor::VisitStmt_(loop); + if (new_thread) { + thread_extent_product /= *p_ext; + thread_tag2extent_.erase(thread_tag); + } + return; + } else { + throw dmlc::Error("Dynamic thread extent"); + } + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockNode* block) { + if (Optional low_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_low_inclusive)) { + if (Optional high_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { + int64_t low = low_inclusive.value()->value; + int64_t high = high_inclusive.value()->value; + if (!(low <= thread_extent_product && thread_extent_product <= high)) { + throw dmlc::Error("Thread extent"); + } + } + } + StmtVisitor::VisitStmt_(block); + } + + int64_t thread_extent_product = 1; + + /*! \brief A mapping from a thread tag to its thread extent */ + std::unordered_map thread_tag2extent_; +}; +} // namespace tir namespace meta_schedule { /*! \brief Extract attribute from a target. */ @@ -66,6 +128,9 @@ class VerifyGPUCodeNode : public PostprocNode { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { + if (!tir::ThreadExtentChecker::Check(prim_func->body)) { + return false; + } IRModule lowered{nullptr}; try { auto pass_list = Array(); @@ -81,6 +146,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index bebfec6122b3..a4be6186d887 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -194,6 +194,176 @@ def main(a: T.handle, b: T.handle) -> None: for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on +@T.prim_func +def GmmCuda0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + +@T.prim_func +def GmmCuda1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 0, + "meta_schedule.thread_extent_high_inclusive": 32, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + + +@T.prim_func +def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 1024, + "meta_schedule.thread_extent_high_inclusive": 1024, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant @@ -226,6 +396,25 @@ def test_postproc_verify_gpu_3(): sch = tir.Schedule(mod, debug_mask="all") assert not ctx.postprocs[0].apply(sch) +def test_postproc_verify_gpu_4(): + mod = GmmCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_5(): + mod = GmmCuda1 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_6(): + mod = GmmCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))