Skip to content

Commit

Permalink
Improve error messages for memory verifier and gpu memory verifier (a…
Browse files Browse the repository at this point in the history
…pache#6281)

* [FIX] Print exactly what issues the GPU memory verifier encountered.

* [FIX] Print exactly why memory verifier failed.
  • Loading branch information
Tristan Konolige authored and Trevor Morris committed Sep 2, 2020
1 parent cbae698 commit 80aa387
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 53 deletions.
112 changes: 85 additions & 27 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ namespace tir {

class GPUCodeVerifier : public StmtExprVisitor {
public:
bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) {
std::vector<String> Verify(Stmt stmt, int64_t max_local_memory_per_block,
int64_t max_shared_memory_per_block, int64_t max_threads_per_block,
int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z,
int64_t max_vthread, int64_t max_vector_bytes) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
Expand All @@ -52,7 +53,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
// TODO(jcf94): Add support of detecting CUDA Misaligned Address error
this->VisitStmt(stmt);

return valid_;
return errors_;
}

void VisitStmt_(const AllocateNode* op) final {
Expand All @@ -66,7 +67,13 @@ class GPUCodeVerifier : public StmtExprVisitor {
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
<< op->dtype.bytes() << ") for dtype " << op->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
}

Expand Down Expand Up @@ -98,27 +105,39 @@ class GPUCodeVerifier : public StmtExprVisitor {
visited_threads_.insert(name);
thread_per_block_ *= length;

auto err = [this](std::string id, size_t ext, size_t m) {
if (ext > m) {
std::stringstream s;
s << "Extent of " << id << " (" << ext << ") is greater than maximum allowed (" << m
<< ");";
errors_.push_back(s.str());
}
};

if (name == "threadIdx.x") {
valid_ &= length <= max_thread_x_;
err("threadIdx.x", length, max_thread_x_);
thread_x_extent_ = length;
} else if (name == "threadIdx.y") {
valid_ &= length <= max_thread_y_;
err("threadIdx.y", length, max_thread_y_);
thread_y_extent_ = length;
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
err("threadIdx.z", length, max_thread_z_);
thread_z_extent_ = length;
} else if (name == "vthread") {
valid_ &= length <= max_vthread_;
err("vthread", length, max_vthread_);
}
} else {
// the thread should be bound to axes with the same length
if (name == "threadIdx.x") {
valid_ &= length == thread_x_extent_;
} else if (name == "threadIdx.y") {
valid_ &= length == thread_y_extent_;
} else if (name == "threadIdx.z") {
valid_ &= length == thread_z_extent_;
}
auto err = [this, name](std::string id, size_t ext, size_t m) {
if (name == id && ext != m) {
std::stringstream s;
s << "Extent of " << id << " (" << ext << ") does not match the bound " << m;
errors_.push_back(s.str());
}
};
err("threadIdx.x", length, thread_x_extent_);
err("threadIdx.y", length, thread_y_extent_);
err("threadIdx.z", length, thread_z_extent_);
}
}

Expand All @@ -128,10 +147,17 @@ class GPUCodeVerifier : public StmtExprVisitor {

if (nest_level_ == 0) {
// exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_threads_per_block_;

valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
auto err = [this](std::string id, size_t num, size_t m) {
if (num > m) {
std::stringstream s;
s << "Used " << id << " (" << num << ") is greater than the allowed maximum (" << m
<< ")";
errors_.push_back(s.str());
}
};
err("threads per block", thread_per_block_, max_threads_per_block_);
err("local memory per block", local_memory_per_block_, max_local_memory_per_block_);
err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_);
}
} else {
StmtVisitor::VisitStmt_(op);
Expand All @@ -143,23 +169,41 @@ class GPUCodeVerifier : public StmtExprVisitor {
const auto* extent = op->extent.as<IntImmNode>();
CHECK(extent);

valid_ &= static_cast<size_t>(extent->value) <= max_vthread_;
size_t num_vthread = static_cast<size_t>(extent->value);
if (num_vthread > max_vthread_) {
std::stringstream s;
s << "Number of vthreads (" << num_vthread << ") is greater than the allowed maximum ("
<< max_vthread_ << ")";
errors_.push_back(s.str());
}
}

StmtVisitor::VisitStmt_(op);
}

void VisitExpr_(const LoadNode* op) {
if (op->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
<< op->dtype.bytes() << ") for dtype " << op->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
ExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const StoreNode* op) {
if (op->index->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
max_vector_bytes_;
if (static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->index->dtype.lanes() << ") times number of bytes ("
<< op->index->dtype.bytes() << ") for dtype " << op->index->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
StmtVisitor::VisitStmt_(op);
}
Expand All @@ -183,7 +227,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
size_t max_vector_bytes_;

bool valid_{true};
std::vector<String> errors_;

void Reset_() {
visited_local_buffers_.clear();
Expand All @@ -196,7 +240,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
};

bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
std::vector<String> VerifyGPUCode_(const PrimFunc& func, Map<String, PrimExpr> constraints) {
GPUCodeVerifier verifier;

int64_t max_local_memory_per_block = INT64_MAX;
Expand Down Expand Up @@ -236,6 +280,11 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
max_vthread, max_vector_bytes);
}

bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
auto errs = VerifyGPUCode_(func, constraints);
return errs.size() == 0;
}

TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);

namespace transform {
Expand All @@ -245,7 +294,16 @@ Pass VerifyGPUCode(Map<String, PrimExpr> constraints) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU constraint violated" << func;
auto errs = VerifyGPUCode_(func, constraints);
if (errs.size() != 0) {
std::stringstream s;
for (auto& err : errs) {
s << " " << err << std::endl;
}
LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n"
<< s.str() << " In function\n"
<< func;
}
}
}
return mod;
Expand Down
56 changes: 30 additions & 26 deletions src/tir/analysis/verify_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,14 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
}

/// Verification result
bool Failed() const { return failure_; }
std::vector<String> Errors() const { return errs_; }

protected:
/// Visitor implementation
//@{
void VisitExpr(const PrimExpr& n) final {
if (Failed()) return;
StmtExprVisitor::VisitExpr(n);
}
void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); }

void VisitStmt(const Stmt& n) final {
if (Failed()) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); }

void VisitStmt_(const LetStmtNode* op) final {
// Book keep definitions
Expand Down Expand Up @@ -139,15 +133,18 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
if (!IsFromFunctionArgs(var.get())) return;

// The verification fails in this case.
SetFailure();
std::stringstream s;
s << "Variable `" << var
<< "` is directly accessed by host memory (it is not contained in a thread environment or in "
"the function arguments.";
errs_.push_back(s.str());
}

/// Status getter/setter
//@{
bool InThreadEnv() const { return in_thread_env_; }
void EnterThreadEnv() { in_thread_env_ = true; }
void ExitThreadEnv() { in_thread_env_ = false; }
void SetFailure() { failure_ = true; }
//@}

/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
Expand All @@ -162,7 +159,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
/// Status of visitor
//@{
bool in_thread_env_{false};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
std::vector<String> errs_;
//@}
tir::PrimFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
Expand All @@ -171,38 +168,45 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
} // namespace

/// Interface of VerifyMemory pass
bool VerifyMemory(const PrimFunc& func) {
std::vector<String> VerifyMemory_(const PrimFunc& func) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";

if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->kind->device_type);
v.Run();
return !v.Failed();
return v.Errors();
} else {
return true;
return {};
}
}

bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; }

TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory);

namespace transform {

Pass VerifyMemory() {
auto pass_func =
[=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
CHECK(VerifyMemory(func))
<< "RuntimeError: Direct host side access to device memory is detected."
<< " Did you forget to bind?\n"
<< func;
auto pass_func = [=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
auto errs = VerifyMemory_(func);
if (errs.size() > 0) {
std::stringstream s;
for (auto& err : errs) {
s << " " << err << "\n";
}
LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
<< s.str() << " Did you forget to bind?\n"
<< func;
}
return mod;
};
}
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {});
}

Expand Down

0 comments on commit 80aa387

Please sign in to comment.