diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index cce0823ca048..5ef755a1b5a1 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -35,9 +35,10 @@ namespace tir { class GPUCodeVerifier : public StmtExprVisitor { public: - bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, - int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, - int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) { + std::vector Verify(Stmt stmt, int64_t max_local_memory_per_block, + int64_t max_shared_memory_per_block, int64_t max_threads_per_block, + int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z, + int64_t max_vthread, int64_t max_vector_bytes) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); @@ -52,7 +53,7 @@ class GPUCodeVerifier : public StmtExprVisitor { // TODO(jcf94): Add support of detecting CUDA Misaligned Address error this->VisitStmt(stmt); - return valid_; + return errors_; } void VisitStmt_(const AllocateNode* op) final { @@ -66,7 +67,13 @@ class GPUCodeVerifier : public StmtExprVisitor { shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } if (op->dtype.lanes() > 1) { - valid_ &= static_cast(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_; + if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + std::stringstream s; + s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" + << op->dtype.bytes() << ") for dtype " << op->dtype + << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + errors_.push_back(s.str()); + } } } @@ -98,27 +105,39 @@ class GPUCodeVerifier : public StmtExprVisitor { visited_threads_.insert(name); thread_per_block_ *= length; + auto err = [this](std::string id, size_t ext, size_t m) { + if (ext > m) { + std::stringstream s; + s << "Extent of " << id << " (" << ext << ") is greater than maximum allowed (" << m + << ");"; + errors_.push_back(s.str()); + } + }; + if (name == "threadIdx.x") { - valid_ &= length <= max_thread_x_; + err("threadIdx.x", length, max_thread_x_); thread_x_extent_ = length; } else if (name == "threadIdx.y") { - valid_ &= length <= max_thread_y_; + err("threadIdx.y", length, max_thread_y_); thread_y_extent_ = length; } else if (name == "threadIdx.z") { - valid_ &= length <= max_thread_z_; + err("threadIdx.z", length, max_thread_z_); thread_z_extent_ = length; } else if (name == "vthread") { - valid_ &= length <= max_vthread_; + err("vthread", length, max_vthread_); } } else { // the thread should be bound to axes with the same length - if (name == "threadIdx.x") { - valid_ &= length == thread_x_extent_; - } else if (name == "threadIdx.y") { - valid_ &= length == thread_y_extent_; - } else if (name == "threadIdx.z") { - valid_ &= length == thread_z_extent_; - } + auto err = [this, name](std::string id, size_t ext, size_t m) { + if (name == id && ext != m) { + std::stringstream s; + s << "Extent of " << id << " (" << ext << ") does not match the bound " << m; + errors_.push_back(s.str()); + } + }; + err("threadIdx.x", length, thread_x_extent_); + err("threadIdx.y", length, thread_y_extent_); + err("threadIdx.z", length, thread_z_extent_); } } @@ -128,10 +147,17 @@ class GPUCodeVerifier : public StmtExprVisitor { if (nest_level_ == 0) { // exit a kernel, check the validity - valid_ &= thread_per_block_ <= max_threads_per_block_; - - valid_ &= local_memory_per_block_ <= max_local_memory_per_block_; - valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_; + auto err = [this](std::string id, size_t num, size_t m) { + if (num > m) { + std::stringstream s; + s << "Used " << id << " (" << num << ") is greater than the allowed maximum (" << m + << ")"; + errors_.push_back(s.str()); + } + }; + err("threads per block", thread_per_block_, max_threads_per_block_); + err("local memory per block", local_memory_per_block_, max_local_memory_per_block_); + err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_); } } else { StmtVisitor::VisitStmt_(op); @@ -143,7 +169,13 @@ class GPUCodeVerifier : public StmtExprVisitor { const auto* extent = op->extent.as(); CHECK(extent); - valid_ &= static_cast(extent->value) <= max_vthread_; + size_t num_vthread = static_cast(extent->value); + if (num_vthread > max_vthread_) { + std::stringstream s; + s << "Number of vthreads (" << num_vthread << ") is greater than the allowed maximum (" + << max_vthread_ << ")"; + errors_.push_back(s.str()); + } } StmtVisitor::VisitStmt_(op); @@ -151,15 +183,27 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitExpr_(const LoadNode* op) { if (op->dtype.lanes() > 1) { - valid_ &= static_cast(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_; + if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + std::stringstream s; + s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" + << op->dtype.bytes() << ") for dtype " << op->dtype + << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + errors_.push_back(s.str()); + } } ExprVisitor::VisitExpr_(op); } void VisitStmt_(const StoreNode* op) { if (op->index->dtype.lanes() > 1) { - valid_ &= static_cast(op->index->dtype.lanes() * op->index->dtype.bytes()) <= - max_vector_bytes_; + if (static_cast(op->index->dtype.lanes() * op->index->dtype.bytes()) > + max_vector_bytes_) { + std::stringstream s; + s << "Number of lanes (" << op->index->dtype.lanes() << ") times number of bytes (" + << op->index->dtype.bytes() << ") for dtype " << op->index->dtype + << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + errors_.push_back(s.str()); + } } StmtVisitor::VisitStmt_(op); } @@ -183,7 +227,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_; size_t max_vector_bytes_; - bool valid_{true}; + std::vector errors_; void Reset_() { visited_local_buffers_.clear(); @@ -196,7 +240,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } }; -bool VerifyGPUCode(const PrimFunc& func, Map constraints) { +std::vector VerifyGPUCode_(const PrimFunc& func, Map constraints) { GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; @@ -236,6 +280,11 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { max_vthread, max_vector_bytes); } +bool VerifyGPUCode(const PrimFunc& func, Map constraints) { + auto errs = VerifyGPUCode_(func, constraints); + return errs.size() == 0; +} + TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); namespace transform { @@ -245,7 +294,16 @@ Pass VerifyGPUCode(Map constraints) { for (auto kv : mod->functions) { if (auto* n = kv.second.as()) { auto func = GetRef(n); - CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU constraint violated" << func; + auto errs = VerifyGPUCode_(func, constraints); + if (errs.size() != 0) { + std::stringstream s; + for (auto& err : errs) { + s << " " << err << std::endl; + } + LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n" + << s.str() << " In function\n" + << func; + } } } return mod; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index dfad5494ad75..64097e1d343a 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -62,20 +62,14 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } /// Verification result - bool Failed() const { return failure_; } + std::vector Errors() const { return errs_; } protected: /// Visitor implementation //@{ - void VisitExpr(const PrimExpr& n) final { - if (Failed()) return; - StmtExprVisitor::VisitExpr(n); - } + void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); } - void VisitStmt(const Stmt& n) final { - if (Failed()) return; - StmtExprVisitor::VisitStmt(n); - } + void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); } void VisitStmt_(const LetStmtNode* op) final { // Book keep definitions @@ -139,7 +133,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { if (!IsFromFunctionArgs(var.get())) return; // The verification fails in this case. - SetFailure(); + std::stringstream s; + s << "Variable `" << var + << "` is directly accessed by host memory (it is not contained in a thread environment or in " + "the function arguments."; + errs_.push_back(s.str()); } /// Status getter/setter @@ -147,7 +145,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { bool InThreadEnv() const { return in_thread_env_; } void EnterThreadEnv() { in_thread_env_ = true; } void ExitThreadEnv() { in_thread_env_ = false; } - void SetFailure() { failure_ = true; } //@} /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. @@ -162,7 +159,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Status of visitor //@{ bool in_thread_env_{false}; - bool failure_{false}; ///< If the verification fails (i.e. has illegal access) + std::vector errs_; //@} tir::PrimFunc func_{nullptr}; ///< Function to be verified. int dev_type_{kDLCPU}; ///< Device type @@ -171,7 +168,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } // namespace /// Interface of VerifyMemory pass -bool VerifyMemory(const PrimFunc& func) { +std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; @@ -179,30 +176,37 @@ bool VerifyMemory(const PrimFunc& func) { CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->kind->device_type); v.Run(); - return !v.Failed(); + return v.Errors(); } else { - return true; + return {}; } } +bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; } + TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); namespace transform { Pass VerifyMemory() { - auto pass_func = - [=](IRModule mod, PassContext ctx) { - for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - auto func = GetRef(n); - CHECK(VerifyMemory(func)) - << "RuntimeError: Direct host side access to device memory is detected." - << " Did you forget to bind?\n" - << func; + auto pass_func = [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + auto errs = VerifyMemory_(func); + if (errs.size() > 0) { + std::stringstream s; + for (auto& err : errs) { + s << " " << err << "\n"; } + LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n" + << s.str() << " Did you forget to bind?\n" + << func; } - return mod; - }; + } + } + return mod; + }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); }