-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PASS] Add GPU IR verifier #1296
Changes from 4 commits
459e35b
9d24748
a573596
630cc39
a3aa9db
7546c20
576ab67
6420092
f50f040
7ff9f3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file verify_gpu_code.cc | ||
* \brief Verify the correctness of a GPU IR. | ||
 * It will check whether the amount of shared memory or
* the number of threads in a block exceeds the limit | ||
*/ | ||
|
||
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>

#include <cstdint>
#include <string>
#include <unordered_set>
|
||
namespace tvm { | ||
namespace ir { | ||
|
||
class GPUCodeVerifier : public IRVisitor { | ||
public: | ||
bool verify(tvm::Stmt stmt, int max_shared_memory_per_block, int max_thread_per_block) { | ||
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block); | ||
max_thread_per_block_ = static_cast<size_t>(max_thread_per_block); | ||
|
||
this->Visit(stmt); | ||
|
||
return valid; | ||
} | ||
|
||
void Visit_(const ProducerConsumer *op) { | ||
if (nest_level_ == 0) { | ||
// enter a new kernel, reset statistics | ||
reset_(); | ||
} | ||
|
||
if (op->is_producer) { | ||
nest_level_++; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. prefer --nest_+level |
||
IRVisitor::Visit_(op); | ||
nest_level_--; | ||
} else { | ||
IRVisitor::Visit_(op); | ||
} | ||
|
||
if (nest_level_ == 0) { | ||
// exit a kernel, check the validity | ||
if (thread_per_block_ > max_thread_per_block_) { | ||
valid = false; | ||
} | ||
if (shared_memory_per_block_ > max_shared_memory_per_block_) { | ||
valid = false; | ||
} | ||
} | ||
} | ||
|
||
void Visit_(const Allocate *op) { | ||
IRVisitor::Visit_(op); | ||
// visit an allocation of a buffer in shared memory, record its size | ||
if (shared_buffers_.count(op->buffer_var.get()) != 0) { | ||
int64_t size = op->type.bytes(); | ||
for (auto dim : op->extents) { | ||
size *= dim.as<IntImm>()->value; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. op-> constant_allocation_size() |
||
} | ||
shared_memory_per_block_ += size; | ||
} | ||
} | ||
|
||
void Visit_(const AttrStmt *op) { | ||
if (op->attr_key == attr::storage_scope) { | ||
if (op->value.as<StringImm>()->value == "shared") { | ||
shared_buffers_.insert(op->node.as<tvm::Variable>()); | ||
} | ||
} else if (op->attr_key == attr::thread_extent) { | ||
VarExpr var = op->node.as<tvm::IterVarNode>()->var; | ||
const auto *extent = op->value.as<IntImm>(); | ||
CHECK(extent); | ||
|
||
// record the number of threads in a block | ||
std::string name = var.get()->name_hint; | ||
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") { | ||
if (visited_threads_.find(name) == visited_threads_.end()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. !count(name) |
||
visited_threads_.insert(name); | ||
thread_per_block_ *= extent->value; | ||
} | ||
} | ||
} | ||
IRVisitor::Visit_(op); | ||
} | ||
|
||
private: | ||
int nest_level_{0}; | ||
|
||
std::unordered_set<const tvm::Variable *> shared_buffers_; | ||
std::unordered_set<std::string> visited_threads_; | ||
size_t shared_memory_per_block_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
size_t thread_per_block_; | ||
|
||
size_t max_shared_memory_per_block_; | ||
size_t max_thread_per_block_; | ||
|
||
bool valid{true}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. valid_ |
||
|
||
void reset_() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reset() There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this visitor is only used once, reset is not necessary There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reset is needed because there might be several gpu kernels in one Stmt. |
||
shared_buffers_.clear(); | ||
shared_memory_per_block_ = 0; | ||
thread_per_block_ = 1; | ||
visited_threads_.clear(); | ||
} | ||
}; | ||
|
||
/*!
 * \brief Verify that a GPU IR statement respects the given device limits.
 * \param stmt The lowered statement to verify.
 * \param max_shared_memory_per_block Shared-memory limit per block, in bytes.
 * \param max_thread_per_block Limit on the number of threads per block.
 * \return Whether the statement stays within both limits.
 */
bool VerifyGPUCode(Stmt stmt,
                   int max_shared_memory_per_block,
                   int max_thread_per_block) {
  return GPUCodeVerifier().verify(stmt,
                                  max_shared_memory_per_block,
                                  max_thread_per_block);
}
|
||
} // namespace ir | ||
} // namespace tvm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""Test gpu code verifier""" | ||
import tvm | ||
|
||
# Module-level verification flag, written by the closure returned from
# cuda_verify_pass() and read by the test assertions below.
# (A bare `global valid` statement at module scope is a no-op in Python;
# define the name explicitly instead.)
valid = None
|
||
def cuda_verify_pass(max_shared_memory, max_num_thread):
    """Build a lowering pass that runs the GPU code verifier.

    The returned pass leaves the statement untouched and records the
    verification result in the module-level ``valid`` flag, which the
    tests below read after ``tvm.build``.
    """
    def verify_pass(stmt):
        global valid
        result = tvm.ir_pass.VerifyGPUCode(stmt, max_shared_memory, max_num_thread)
        valid = result
        return stmt
    return verify_pass
|
||
def test_shared_memory():
    """Verifier must reject a kernel whose shared memory exceeds the limit."""
    n = 1024
    m = 128

    A = tvm.placeholder((n,), name='A', dtype='float32')
    B = tvm.compute((n,), lambda i: A[i], name='B')

    s = tvm.create_schedule([B.op])
    AA = s.cache_read(A, "shared", [B])
    outer, inner = s[B].split(s[B].op.axis[0], m)
    s[AA].compute_at(s[B], outer)
    s[B].bind(outer, tvm.thread_axis("blockIdx.x"))
    s[B].bind(inner, tvm.thread_axis("threadIdx.x"))

    # shared memory usage: m * 4B; thread usage: m
    for target in ('opencl', 'cuda'):
        if not tvm.context(target).exist:
            continue
        global valid

        # limit one byte below the usage -> verification must fail
        with tvm.build_config(add_lower_pass=[(2, cuda_verify_pass(4 * m - 1, m))]):
            tvm.build(s, [A, B], target)
        assert not valid

        # limit exactly at the usage -> verification must succeed
        with tvm.build_config(add_lower_pass=[(2, cuda_verify_pass(4 * m, m))]):
            tvm.build(s, [A, B], target)
        assert valid
|
||
|
||
def test_num_thread():
    """Verifier must reject a kernel whose thread count exceeds the limit."""
    n = 1024
    m = 128

    A = tvm.placeholder((n,), name='A', dtype='float32')
    B = tvm.compute((n,), lambda i: A[i], name='B')

    s = tvm.create_schedule([B.op])
    outer, inner = s[B].split(s[B].op.axis[0], m)

    s[B].bind(outer, tvm.thread_axis('threadIdx.x'))
    s[B].bind(inner, tvm.thread_axis("threadIdx.y"))

    # shared memory usage: 0; thread usage: n
    for target in ('opencl', 'cuda'):
        if not tvm.context(target).exist:
            continue
        global valid

        # limit one thread below the usage -> verification must fail
        with tvm.build_config(add_lower_pass=[(2, cuda_verify_pass(0, n - 1))]):
            tvm.build(s, [A, B], target)
        assert not valid

        # limit exactly at the usage -> verification must succeed
        with tvm.build_config(add_lower_pass=[(2, cuda_verify_pass(0, n))]):
            tvm.build(s, [A, B], target)
        assert valid
|
||
|
||
if __name__ == "__main__":
    test_shared_memory()
    test_num_thread()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CamelCase for functions