[TensorIR] Update VerifyGPU #10405

Merged · 2 commits · Mar 2, 2022
66 changes: 66 additions & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
@@ -21,6 +21,68 @@
#include "../utils.h"

namespace tvm {
namespace tir {
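/*!
 * \brief Verify that loops bound to threadIdx.* have constant extents, that loops sharing a
 * thread tag agree on their extent, and that the running extent product respects the block
 * annotations `meta_schedule_thread_extent_{low,high}_inclusive`.
 */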
class ThreadExtentChecker : private StmtVisitor {
public:
static bool Check(const Stmt& stmt) {
try {
ThreadExtentChecker().VisitStmt(stmt);
return true;
} catch (const dmlc::Error& e) {
return false;
}
}

private:
void VisitStmt_(const ForNode* loop) {
if (IsThreadIdx(GetThreadScope(loop))) {
const std::string& thread_tag = loop->thread_binding.value()->thread_tag;
if (const int64_t* p_ext = GetLoopIntExtent(loop)) {
auto it = thread_tag2extent_.find(thread_tag);
bool new_thread = it == thread_tag2extent_.end();
if (new_thread) {
thread_extent_product *= *p_ext;
thread_tag2extent_[thread_tag] = *p_ext;
} else {
CHECK_EQ(it->second, *p_ext)
<< "ValueError: All loops that are bound to `" << thread_tag
<< "` should have the same extent. However, there are two loops with extent "
<< it->second << " and " << *p_ext << ", which are not equal";
}
StmtVisitor::VisitStmt_(loop);
if (new_thread) {
thread_extent_product /= *p_ext;
thread_tag2extent_.erase(thread_tag);
}
return;
} else {
throw dmlc::Error("Dynamic thread extent");
}
}
StmtVisitor::VisitStmt_(loop);
}

void VisitStmt_(const BlockNode* block) {
if (Optional<Integer> low_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
if (Optional<Integer> high_inclusive =
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_high_inclusive)) {
int64_t low = low_inclusive.value()->value;
int64_t high = high_inclusive.value()->value;
if (!(low <= thread_extent_product && thread_extent_product <= high)) {
throw dmlc::Error("Thread extent");
}
}
}
StmtVisitor::VisitStmt_(block);
}

/*! \brief Product of the extents of all threadIdx-bound loops on the current visiting path. */
int64_t thread_extent_product = 1;

/*! \brief A mapping from a thread tag to its thread extent */
std::unordered_map<std::string, int64_t> thread_tag2extent_;
};
} // namespace tir
namespace meta_schedule {

/*! \brief Extract attribute from a target. */
@@ -66,6 +128,9 @@ class VerifyGPUCodeNode : public PostprocNode {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
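// Cheap pre-check: reject dynamic thread extents, or extent products outside the
// annotated bounds, before running the lowering passes below.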
if (!tir::ThreadExtentChecker::Check(prim_func->body)) {
return false;
}
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
@@ -81,6 +146,7 @@
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
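// Run software pipelining during verification as well, so pipelined candidates are
// checked on the same lowered form the actual build would produce.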
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
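For intuition, here is a hypothetical standalone Python sketch (not TVM API) of the bookkeeping ThreadExtentChecker performs: multiply the extents of distinct threadIdx.* tags along a path, require repeated tags to agree, and test the running product against a block's inclusive bounds. The function name and data layout are illustrative only.

from typing import Dict, List, Tuple

def check_thread_extents(bindings: List[Tuple[str, int]], low: int, high: int) -> bool:
    """Return True iff the product of distinct thread-tag extents lies in [low, high]."""
    seen: Dict[str, int] = {}
    product = 1
    for tag, extent in bindings:
        if tag in seen:
            # Mirrors the CHECK_EQ above: loops sharing a tag must share an extent.
            if seen[tag] != extent:
                raise ValueError(f"loops bound to `{tag}` disagree: {seen[tag]} vs {extent}")
        else:
            seen[tag] = extent
            product *= extent
    return low <= product <= high

# GmmCuda1 below binds threadIdx.x with extent 128 but annotates
# thread_extent_high_inclusive = 32, so its block fails the check:
assert check_thread_extents([("threadIdx.x", 128)], low=0, high=1024)
assert not check_thread_extents([("threadIdx.x", 128)], low=0, high=32)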
189 changes: 189 additions & 0 deletions tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
@@ -194,6 +194,176 @@ def main(a: T.handle, b: T.handle) -> None:
for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)

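# GmmCuda0: threadIdx.x extent is 128 and there are no thread-extent annotations,
# so VerifyGPUCode accepts this candidate (test_postproc_verify_gpu_4).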
@T.prim_func
def GmmCuda0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
for i1_3_init, i2_4_init in T.grid(4, 2):
with T.block("Z_init"):
b = T.axis.spatial(1, 0)
i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
T.reads()
T.writes(Z_local[b, i, j])
Z_local[b, i, j] = T.float32(0)
for i3_0 in T.serial(4):
for ax0_ax1_ax2_fused_0 in T.serial(4):
for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in T.vectorized(2):
with T.block("X_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
T.reads(X[v0, v1, v2])
T.writes(X_shared[v0, v1, v2])
X_shared[v0, v1, v2] = X[v0, v1, v2]
for ax0_ax1_ax2_fused_0 in T.serial(8):
for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
with T.block("Y_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
T.reads(Y[v0, v1, v2])
T.writes(Y_shared[v0, v1, v2])
Y_shared[v0, v1, v2] = Y[v0, v1, v2]
for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
with T.block("Z_update"):
b = T.axis.spatial(1, 0)
i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
k = T.axis.reduce(128, i3_0 * 32 + i3_2)
T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
T.writes(Z_local[b, i, j])
Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
for ax0, ax1, ax2 in T.grid(1, 4, 2):
with T.block("Z_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
T.reads(Z_local[v0, v1, v2])
T.writes(Z[v0, v1, v2])
Z[v0, v1, v2] = Z_local[v0, v1, v2]

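# GmmCuda1: same schedule as GmmCuda0, but the Z_update block caps the thread extent at 32
# (high_inclusive) while threadIdx.x alone contributes 128, so ThreadExtentChecker rejects it
# (test_postproc_verify_gpu_5).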
@T.prim_func
def GmmCuda1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
for i1_3_init, i2_4_init in T.grid(4, 2):
with T.block("Z_init"):
b = T.axis.spatial(1, 0)
i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
T.reads()
T.writes(Z_local[b, i, j])
Z_local[b, i, j] = T.float32(0)
for i3_0 in T.serial(4):
for ax0_ax1_ax2_fused_0 in T.serial(4):
for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in T.vectorized(2):
with T.block("X_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
T.reads(X[v0, v1, v2])
T.writes(X_shared[v0, v1, v2])
X_shared[v0, v1, v2] = X[v0, v1, v2]
for ax0_ax1_ax2_fused_0 in T.serial(8):
for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
with T.block("Y_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
T.reads(Y[v0, v1, v2])
T.writes(Y_shared[v0, v1, v2])
Y_shared[v0, v1, v2] = Y[v0, v1, v2]
for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
with T.block("Z_update"):
b = T.axis.spatial(1, 0)
i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
k = T.axis.reduce(128, i3_0 * 32 + i3_2)
T.block_attr({
"meta_schedule.thread_extent_low_inclusive": 0,
"meta_schedule.thread_extent_high_inclusive": 32,
})
T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
T.writes(Z_local[b, i, j])
Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
for ax0, ax1, ax2 in T.grid(1, 4, 2):
with T.block("Z_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
T.reads(Z_local[v0, v1, v2])
T.writes(Z[v0, v1, v2])
Z[v0, v1, v2] = Z_local[v0, v1, v2]


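# GmmCuda2: requires a thread-extent product of at least 1024 (low_inclusive) while the actual
# product is 128, so the check fails from the other side (test_postproc_verify_gpu_6).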
@T.prim_func
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
for i1_3_init, i2_4_init in T.grid(4, 2):
with T.block("Z_init"):
b = T.axis.spatial(1, 0)
i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
T.reads()
T.writes(Z_local[b, i, j])
Z_local[b, i, j] = T.float32(0)
for i3_0 in T.serial(4):
for ax0_ax1_ax2_fused_0 in T.serial(4):
for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in T.vectorized(2):
with T.block("X_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
T.reads(X[v0, v1, v2])
T.writes(X_shared[v0, v1, v2])
X_shared[v0, v1, v2] = X[v0, v1, v2]
for ax0_ax1_ax2_fused_0 in T.serial(8):
for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
with T.block("Y_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
T.reads(Y[v0, v1, v2])
T.writes(Y_shared[v0, v1, v2])
Y_shared[v0, v1, v2] = Y[v0, v1, v2]
for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
with T.block("Z_update"):
b = T.axis.spatial(1, 0)
i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
k = T.axis.reduce(128, i3_0 * 32 + i3_2)
T.block_attr({
"meta_schedule.thread_extent_low_inclusive": 1024,
"meta_schedule.thread_extent_high_inclusive": 1024,
})
T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
T.writes(Z_local[b, i, j])
Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
for ax0, ax1, ax2 in T.grid(1, 4, 2):
with T.block("Z_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
T.reads(Z_local[v0, v1, v2])
T.writes(Z[v0, v1, v2])
Z[v0, v1, v2] = Z_local[v0, v1, v2]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
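The tests below rely on the _target() and _create_context() helpers defined earlier in this file, outside the hunks shown here. A plausible sketch of those helpers, assuming the meta schedule API of this TVM version; the exact target string is illustrative:

from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.postproc import VerifyGPUCode
from tvm.target import Target

def _target() -> Target:
    # Any CUDA target with registered thread and shared-memory limits works here.
    return Target("nvidia/geforce-rtx-3080")

def _create_context(mod, target) -> TuneContext:
    ctx = TuneContext(mod=mod, target=target, postprocs=[VerifyGPUCode()], task_name="test")
    for postproc in ctx.postprocs:
        postproc.initialize_with_tune_context(ctx)
    return ctx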
@@ -226,6 +396,25 @@ def test_postproc_verify_gpu_3():
sch = tir.Schedule(mod, debug_mask="all")
assert not ctx.postprocs[0].apply(sch)

def test_postproc_verify_gpu_4():
mod = GmmCuda0
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_5():
mod = GmmCuda1
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert not ctx.postprocs[0].apply(sch)


def test_postproc_verify_gpu_6():
mod = GmmCuda2
ctx = _create_context(mod, target=_target())
sch = tir.Schedule(mod, debug_mask="all")
assert not ctx.postprocs[0].apply(sch)

if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))