[TensorIR] Update VerifyGPU (apache#10405)
* update VerifyGPU

* address comments
Hzfengsy authored and pfk-beta committed Apr 11, 2022
1 parent 84a61a7 commit 436bc79
Showing 2 changed files with 255 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
@@ -21,6 +21,68 @@
#include "../utils.h"

namespace tvm {
namespace tir {
class ThreadExtentChecker : private StmtVisitor {
 public:
  static bool Check(const Stmt& stmt) {
    try {
      ThreadExtentChecker().VisitStmt(stmt);
      return true;
    } catch (const dmlc::Error& e) {
      return false;
    }
  }

 private:
  void VisitStmt_(const ForNode* loop) {
    if (IsThreadIdx(GetThreadScope(loop))) {
      const std::string& thread_tag = loop->thread_binding.value()->thread_tag;
      if (const int64_t* p_ext = GetLoopIntExtent(loop)) {
        auto it = thread_tag2extent_.find(thread_tag);
        bool new_thread = it == thread_tag2extent_.end();
        if (new_thread) {
          thread_extent_product *= *p_ext;
          thread_tag2extent_[thread_tag] = *p_ext;
        } else {
          CHECK_EQ(it->second, *p_ext)
              << "ValueError: All loops that are bound to `" << thread_tag
              << "` should have the same extent. However, there are two loops with extent "
              << it->second << " and " << *p_ext << ", which are not equal";
        }
        StmtVisitor::VisitStmt_(loop);
        if (new_thread) {
          thread_extent_product /= *p_ext;
          thread_tag2extent_.erase(thread_tag);
        }
        return;
      } else {
        throw dmlc::Error("Dynamic thread extent");
      }
    }
    StmtVisitor::VisitStmt_(loop);
  }

  void VisitStmt_(const BlockNode* block) {
    if (Optional<Integer> low_inclusive =
            GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
      if (Optional<Integer> high_inclusive =
              GetAnn<Integer>(block, attr::meta_schedule_thread_extent_high_inclusive)) {
        int64_t low = low_inclusive.value()->value;
        int64_t high = high_inclusive.value()->value;
        if (!(low <= thread_extent_product && thread_extent_product <= high)) {
          throw dmlc::Error("Thread extent");
        }
      }
    }
    StmtVisitor::VisitStmt_(block);
  }

  /*! \brief The product of the extents of all threadIdx loops currently in scope */
  int64_t thread_extent_product = 1;

  /*! \brief A mapping from a thread tag to its thread extent */
  std::unordered_map<std::string, int64_t> thread_tag2extent_;
};
} // namespace tir
namespace meta_schedule {

/*! \brief Extract attribute from a target. */
@@ -66,6 +128,9 @@ class VerifyGPUCodeNode : public PostprocNode {
      const GlobalVar& g_var = kv.first;
      const BaseFunc& base_func = kv.second;
      if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
        if (!tir::ThreadExtentChecker::Check(prim_func->body)) {
          return false;
        }
        IRModule lowered{nullptr};
        try {
          auto pass_list = Array<tvm::transform::Pass>();
@@ -81,6 +146,7 @@ class VerifyGPUCodeNode : public PostprocNode {
          pass_list.push_back(tir::transform::UnifyThreadBinding());
          pass_list.push_back(tir::transform::CompactBufferAllocation());
          pass_list.push_back(tir::transform::LowerMatchBuffer());
          pass_list.push_back(tir::transform::InjectSoftwarePipeline());
          pass_list.push_back(tir::transform::FlattenBuffer());
          pass_list.push_back(tir::transform::BF16Legalize());
          pass_list.push_back(tir::transform::NarrowDataType(32));
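As a side note, here is a minimal sketch (not part of this commit) of how the two annotations consumed by ThreadExtentChecker could be attached from the scheduling side, assuming the tir.Schedule.annotate API and the annotation keys that also appear in the tests below; the helper name and the bounds are illustrative only:

from tvm import tir

def annotate_thread_extent_bounds(sch: tir.Schedule, block_name: str) -> None:
    # Hypothetical helper: require the product of threadIdx extents around this block
    # to fall within [32, 1024]; the keys mirror the attrs read by ThreadExtentChecker
    # via attr::meta_schedule_thread_extent_{low,high}_inclusive.
    block = sch.get_block(block_name)
    sch.annotate(block, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)
    sch.annotate(block, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)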
189 changes: 189 additions & 0 deletions tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
@@ -194,6 +194,176 @@ def main(a: T.handle, b: T.handle) -> None:
for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on

@T.prim_func
def GmmCuda0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]

@T.prim_func
def GmmCuda1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 0,
                                "meta_schedule.thread_extent_high_inclusive": 32,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]


@T.prim_func
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
@@ -226,6 +396,25 @@ def test_postproc_verify_gpu_3():
    sch = tir.Schedule(mod, debug_mask="all")
    assert not ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_4():
    # GmmCuda0 carries no thread-extent bound annotations, so the postproc accepts it.
    mod = GmmCuda0
    ctx = _create_context(mod, target=_target())
    sch = tir.Schedule(mod, debug_mask="all")
    assert ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_5():
    # GmmCuda1 bounds the thread extent product to [0, 32], but the kernel uses
    # 128 threads per block, so the postproc rejects it.
    mod = GmmCuda1
    ctx = _create_context(mod, target=_target())
    sch = tir.Schedule(mod, debug_mask="all")
    assert not ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_6():
    # GmmCuda2 requires the thread extent product to be exactly 1024, but the kernel
    # uses 128 threads per block, so the postproc rejects it.
    mod = GmmCuda2
    ctx = _create_context(mod, target=_target())
    sch = tir.Schedule(mod, debug_mask="all")
    assert not ctx.postprocs[0].apply(sch)
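

# --- Illustrative sketch, not part of the diff above --------------------------
# Plain-Python mirror of the arithmetic ThreadExtentChecker performs for the
# GmmCuda kernels: every loop bound to "threadIdx.x" has extent 128, and
# re-binding an already-seen thread tag with an equal extent does not multiply
# the running product again.
def _thread_extent_product_sketch() -> int:
    extents = {"threadIdx.x": 128}  # thread tag -> extent, as in thread_tag2extent_
    product = 1
    for extent in extents.values():
        product *= extent
    return product


# GmmCuda0 has no bounds and is accepted with product == 128; GmmCuda1 requires the
# product to lie in [0, 32] and GmmCuda2 in [1024, 1024], so both are rejected.
assert _thread_extent_product_sketch() == 128
assert not (0 <= _thread_extent_product_sketch() <= 32)
assert not (1024 <= _thread_extent_product_sketch() <= 1024)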

if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))
