Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Mar 1, 2022
1 parent 8c0139c commit 8d76443
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,24 @@ class ThreadExtentChecker : private StmtVisitor {
private:
void VisitStmt_(const ForNode* loop) {
if (IsThreadIdx(GetThreadScope(loop))) {
const std::string& thread_tag = loop->thread_binding.value()->thread_tag;
if (const int64_t* p_ext = GetLoopIntExtent(loop)) {
thread_extent_product *= *p_ext;
auto it = thread_tag2extent_.find(thread_tag);
bool new_thread = it == thread_tag2extent_.end();
if (new_thread) {
thread_extent_product *= *p_ext;
thread_tag2extent_[thread_tag] = *p_ext;
} else {
CHECK_EQ(it->second, *p_ext)
<< "ValueError: All loops that are bound to `" << thread_tag
<< "` should have the same extent. However, there are two loops with extent "
<< it->second << " and " << p_ext << ", which are not equal";
}
StmtVisitor::VisitStmt_(loop);
thread_extent_product /= *p_ext;
if (new_thread) {
thread_extent_product /= *p_ext;
thread_tag2extent_.erase(thread_tag);
}
return;
} else {
throw dmlc::Error("Dynamic thread extent");
Expand All @@ -64,6 +78,9 @@ class ThreadExtentChecker : private StmtVisitor {
}

int64_t thread_extent_product = 1;

/*! \brief A mapping from a thread tag to its thread extent */
std::unordered_map<std::string, int64_t> thread_tag2extent_;
};
} // namespace tir
namespace meta_schedule {
Expand Down Expand Up @@ -126,9 +143,8 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::FlattenBuffer());
Expand Down

0 comments on commit 8d76443

Please sign in to comment.