From 8d76443d9d449c27df1e143146efe3faf57dd1c0 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 1 Mar 2022 11:33:52 +0000 Subject: [PATCH] address comments --- src/meta_schedule/postproc/verify_gpu_code.cc | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 18b08e938a34..6b34f69bc0b1 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -36,10 +36,24 @@ class ThreadExtentChecker : private StmtVisitor { private: void VisitStmt_(const ForNode* loop) { if (IsThreadIdx(GetThreadScope(loop))) { + const std::string& thread_tag = loop->thread_binding.value()->thread_tag; if (const int64_t* p_ext = GetLoopIntExtent(loop)) { - thread_extent_product *= *p_ext; + auto it = thread_tag2extent_.find(thread_tag); + bool new_thread = it == thread_tag2extent_.end(); + if (new_thread) { + thread_extent_product *= *p_ext; + thread_tag2extent_[thread_tag] = *p_ext; + } else { + CHECK_EQ(it->second, *p_ext) + << "ValueError: All loops that are bound to `" << thread_tag + << "` should have the same extent. However, there are two loops with extent " + << it->second << " and " << p_ext << ", which are not equal"; + } StmtVisitor::VisitStmt_(loop); - thread_extent_product /= *p_ext; + if (new_thread) { + thread_extent_product /= *p_ext; + thread_tag2extent_.erase(thread_tag); + } return; } else { throw dmlc::Error("Dynamic thread extent"); @@ -64,6 +78,9 @@ class ThreadExtentChecker : private StmtVisitor { } int64_t thread_extent_product = 1; + + /*! \brief A mapping from a thread tag to its thread extent */ + std::unordered_map thread_tag2extent_; }; } // namespace tir namespace meta_schedule { @@ -126,9 +143,8 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::FlattenBuffer());