From b46fceb4665e43075ac56c81bc84f7fe890673f4 Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Fri, 29 Apr 2022 14:16:08 +0800
Subject: [PATCH] move utils.cc to tir/schedule/analysis

---
 src/meta_schedule/schedule_rule/auto_bind.cc |   2 +-
 src/meta_schedule/utils.cc                   | 155 -------------------
 src/meta_schedule/utils.h                    |  27 ----
 src/tir/schedule/analysis/analysis.cc        | 129 +++++++++++++++
 src/tir/schedule/utils.h                     |  17 ++
 5 files changed, 147 insertions(+), 183 deletions(-)
 delete mode 100644 src/meta_schedule/utils.cc

diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index 8e63f5a4209f8..22cc82232646f 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -62,7 +62,7 @@ Array<tir::Schedule> AutoBindNode::Apply(const tir::Schedule& sch, const tir::Bl
 
   return {BindThreadsForUnboundBlock(sch, block_rv, max_num_threads_, max_threadblock_,
                                      thread_extents)};
-};
+}
 
 ScheduleRule ScheduleRule::AutoBind(int max_threadblock,  //
                                     Array<Integer> thread_extents) {
diff --git a/src/meta_schedule/utils.cc b/src/meta_schedule/utils.cc
deleted file mode 100644
index fbf43dd53b95b..0000000000000
--- a/src/meta_schedule/utils.cc
+++ /dev/null
@@ -1,155 +0,0 @@
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include "./utils.h"
-
-namespace tvm {
-namespace tir {
-
-bool IsUnboundBlock(const StmtSRef& block_sref) {
-  for (const StmtSRefNode* p = block_sref->parent; p != nullptr; p = p->parent) {
-    if (p->stmt->IsInstance<ForNode>()) {
-      For loop = Downcast<For>(GetRef<Stmt>(p->stmt));
-      if (loop->kind == ForKind::kThreadBinding) return false;
-    }
-  }
-  return true;
-}
-
-/*!
- * \brief Check the combination of bindings to be added to the block
- * \param block_sref The block to be checked
- * \param fuse_first_num The number of loops to be fused
- * \return The type of binding to be added to the block
- */
-BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
-  Array<StmtSRef> loops = tir::GetLoops(block_sref);
-  int n = loops.size();
-  if (n == 0) {
-    return BindType::kNoBind;
-  }
-  int i_block_idx = -1;
-  int i_thread_idx = -1;
-  int i_multi_child = -1;
-  int i_spatial_loop = -1;
-  for (int i = 0; i < n; ++i) {
-    const StmtSRef& loop_sref = loops[i];
-    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
-    runtime::ThreadScope thread_scope = GetThreadScope(loop);
-    if (IsBlockIdx(thread_scope)) {
-      if (i_block_idx == -1) {
-        i_block_idx = i;
-      }
-    }
-    if (IsThreadIdx(thread_scope)) {
-      if (i_thread_idx == -1) {
-        i_thread_idx = i;
-      }
-    }
-    if (loop->kind != tir::ForKind::kSerial) {
-      if (i_multi_child == -1) {
-        i_multi_child = i;
-      }
-    }
-    if (!IsSingleStmt(loop->body)) {
-      if (i_multi_child == -1) {
-        i_multi_child = i + 1;
-      }
-    }
-    if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
-      if (i_spatial_loop == i - 1) {
-        ++i_spatial_loop;
-      }
-    }
-  }
-  if (i_multi_child == -1) {
-    i_multi_child = n;
-  }
-  if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
-    return BindType::kNoBind;
-  } else if (i_block_idx != -1 && i_thread_idx == -1) {
-    ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
-    throw;
-  } else if (i_block_idx == -1 && i_thread_idx != -1) {
-    *fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
-    return BindType::kBindBlock;
-  } else {  // i_block_idx == -1 && i_thread_idx == -1
-    *fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1);
-    return BindType::kBindBlockThread;
-  }
-}
-
-Schedule BindThreadsForUnboundBlock(const Schedule& sch,      //
-                                    const BlockRV& block_rv,  //
-                                    int max_threadblock,      //
-                                    int max_num_threads,      //
-                                    Array<Integer> thread_extents) {
-  tir::StmtSRef block_sref = sch->GetSRef(block_rv);
-
-  int fuse_first_num = 0;
-  tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num);
-  if (bind_type == tir::BindType::kNoBind) {
-    return {sch};
-  }
-
-  Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
-  LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num});
-  if (bind_type == tir::BindType::kBindBlock) {
-    sch->Bind(fused, "blockIdx.x");
-  } else if (bind_type == tir::BindType::kBindBlockThread) {
-    int64_t extent_size = int64_t(1) << 60;
-    if (const int64_t* extent_ptr = tir::GetLoopIntExtent(sch->Get(fused).get())) {
-      extent_size = *extent_ptr;
-    }
-
-    Array<Integer> updated_extents;
-    for (const Integer extent : thread_extents) {
-      if (extent->value <= extent_size) updated_extents.push_back(extent);
-    }
-
-    if (extent_size <= max_threadblock * max_num_threads) {
-      tir::ExprRV factor;
-      if (updated_extents.empty()) {
-        factor = Integer(std::min(static_cast<int64_t>(max_num_threads), extent_size));
-      } else if (updated_extents.size() == 1) {
-        factor = updated_extents[0];
-      } else {
-        // Sample a factor
-        int n = updated_extents.size();
-        Array<FloatImm> probs(n, FloatImm(DataType::Float(64), 1.0 / n));
-        factor = sch->SampleCategorical(updated_extents, probs);
-      }
-      Array<LoopRV> splits = sch->Split(fused, {NullOpt, factor});
-      ICHECK_EQ(splits.size(), 2);
-      sch->Bind(splits[0], "blockIdx.x");
-      sch->Bind(splits[1], "threadIdx.x");
-    } else {
-      Array<LoopRV> splits =
-          sch->Split(fused, {NullOpt, Integer(max_threadblock), Integer(max_num_threads)});
-      ICHECK_EQ(splits.size(), 3);
-      sch->Reorder({splits[1], splits[2], splits[0]});
-      sch->Bind(splits[1], "blockIdx.x");
-      sch->Bind(splits[2], "threadIdx.x");
-    }
-  }
-  return {sch};
-}
-}  // namespace tir
-}  // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 62dc7a45b3532..4a3f4fae8eb09 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -54,33 +54,6 @@
 #include "../tir/schedule/utils.h"
 
 namespace tvm {
-namespace tir {
-
-/*! \brief The rewrite type for an unbound block */
-enum class BindType : int32_t {
-  /*! \brief No additional thread binding is needed */
-  kNoBind = 0,
-  /*! \brief Need to bind to blockIdx */
-  kBindBlock = 1,
-  /*! \brief Need to bind to both blockIdx and threadIdx */
-  kBindBlockThread = 2,
-};
-
-/*!
- * \brief Bind loops nesting to threadIdx for an unbound block.
- * \param sch The input schedule.
- * \param block_rv The input block.
- * \param max_threadblock The max number of thread blocks
- * \param max_num_threads The max number of threads per block
- * \param thread_extents The candidates for thread extent
- * \return The result schedule.
- */
-Schedule BindThreadsForUnboundBlock(const Schedule& sch,      //
-                                    const BlockRV& block_rv,  //
-                                    int max_threadblock,      //
-                                    int max_num_threads,      //
-                                    Array<Integer> thread_extents);
-}  // namespace tir
 
 namespace meta_schedule {
 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4777ee2657b39..33855b1e86947 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -2191,5 +2191,134 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
       return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
     });
 
+/*! \brief The rewrite type for an unbound block */
+enum class BindType : int32_t {
+  /*! \brief No additional thread binding is needed */
+  kNoBind = 0,
+  /*! \brief Need to bind to blockIdx */
+  kBindBlock = 1,
+  /*! \brief Need to bind to both blockIdx and threadIdx */
+  kBindBlockThread = 2,
+};
+
+/*!
+ * \brief Check the combination of bindings to be added to the block
+ * \param block_sref The block to be checked
+ * \param fuse_first_num The number of loops to be fused
+ * \return The type of binding to be added to the block
+ */
+BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
+  Array<StmtSRef> loops = tir::GetLoops(block_sref);
+  int n = loops.size();
+  if (n == 0) {
+    return BindType::kNoBind;
+  }
+  int i_block_idx = -1;
+  int i_thread_idx = -1;
+  int i_multi_child = -1;
+  int i_spatial_loop = -1;
+  for (int i = 0; i < n; ++i) {
+    const StmtSRef& loop_sref = loops[i];
+    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    runtime::ThreadScope thread_scope = GetThreadScope(loop);
+    if (IsBlockIdx(thread_scope)) {
+      if (i_block_idx == -1) {
+        i_block_idx = i;
+      }
+    }
+    if (IsThreadIdx(thread_scope)) {
+      if (i_thread_idx == -1) {
+        i_thread_idx = i;
+      }
+    }
+    if (loop->kind != tir::ForKind::kSerial) {
+      if (i_multi_child == -1) {
+        i_multi_child = i;
+      }
+    }
+    if (!IsSingleStmt(loop->body)) {
+      if (i_multi_child == -1) {
+        i_multi_child = i + 1;
+      }
+    }
+    if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
+      if (i_spatial_loop == i - 1) {
+        ++i_spatial_loop;
+      }
+    }
+  }
+  if (i_multi_child == -1) {
+    i_multi_child = n;
+  }
+  if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
+    return BindType::kNoBind;
+  } else if (i_block_idx != -1 && i_thread_idx == -1) {
+    ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
+    throw;
+  } else if (i_block_idx == -1 && i_thread_idx != -1) {
+    *fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
+    return BindType::kBindBlock;
+  } else {  // i_block_idx == -1 && i_thread_idx == -1
+    *fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1);
+    return BindType::kBindBlockThread;
+  }
+}
+
+Schedule BindThreadsForUnboundBlock(const Schedule& sch,      //
+                                    const BlockRV& block_rv,  //
+                                    int max_threadblock,      //
+                                    int max_num_threads,      //
+                                    Array<Integer> thread_extents) {
+  tir::StmtSRef block_sref = sch->GetSRef(block_rv);
+
+  int fuse_first_num = 0;
+  tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num);
+  if (bind_type == tir::BindType::kNoBind) {
+    return {sch};
+  }
+
+  Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+  LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num});
+  if (bind_type == tir::BindType::kBindBlock) {
+    sch->Bind(fused, "blockIdx.x");
+  } else if (bind_type == tir::BindType::kBindBlockThread) {
+    int64_t extent_size = int64_t(1) << 60;
+    if (const int64_t* extent_ptr = tir::GetLoopIntExtent(sch->Get(fused).get())) {
+      extent_size = *extent_ptr;
+    }
+
+    Array<Integer> updated_extents;
+    for (const Integer extent : thread_extents) {
+      if (extent->value <= extent_size) updated_extents.push_back(extent);
+    }
+
+    if (extent_size <= max_threadblock * max_num_threads) {
+      tir::ExprRV factor;
+      if (updated_extents.empty()) {
+        factor = Integer(std::min(static_cast<int64_t>(max_num_threads), extent_size));
+      } else if (updated_extents.size() == 1) {
+        factor = updated_extents[0];
+      } else {
+        // Sample a factor
+        int n = updated_extents.size();
+        Array<FloatImm> probs(n, FloatImm(DataType::Float(64), 1.0 / n));
+        factor = sch->SampleCategorical(updated_extents, probs);
+      }
+      Array<LoopRV> splits = sch->Split(fused, {NullOpt, factor});
+      ICHECK_EQ(splits.size(), 2);
+      sch->Bind(splits[0], "blockIdx.x");
+      sch->Bind(splits[1], "threadIdx.x");
+    } else {
+      Array<LoopRV> splits =
+          sch->Split(fused, {NullOpt, Integer(max_threadblock), Integer(max_num_threads)});
+      ICHECK_EQ(splits.size(), 3);
+      sch->Reorder({splits[1], splits[2], splits[0]});
+      sch->Bind(splits[1], "blockIdx.x");
+      sch->Bind(splits[2], "threadIdx.x");
+    }
+  }
+  return {sch};
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 53cafa798b548..e298a6e29906d 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -431,6 +431,23 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl
   }
 }
 
+/********** Helper Functions for AutoBind and RewriteUnBoundBlock **********/
+
+/*!
+ * \brief Bind loops nesting to threadIdx for an unbound block.
+ * \param sch The input schedule.
+ * \param block_rv The input block.
+ * \param max_threadblock The max number of thread blocks
+ * \param max_num_threads The max number of threads per block
+ * \param thread_extents The candidates for thread extent
+ * \return The result schedule.
+ */
+Schedule BindThreadsForUnboundBlock(const Schedule& sch,      //
+                                    const BlockRV& block_rv,  //
+                                    int max_threadblock,      //
+                                    int max_num_threads,      //
+                                    Array<Integer> thread_extents);
+
 /******** Helper functions for enum conversion ********/
 
 /*!
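
Not part of the patch: a minimal, hypothetical sketch of how a caller could use the relocated helper once it is declared in src/tir/schedule/utils.h. The function name ExampleBindUnboundBlock, the include path, and the concrete limits and extents below are illustrative assumptions, not values taken from this change; the actual caller in this patch is AutoBindNode::Apply in src/meta_schedule/schedule_rule/auto_bind.cc.

    // Sketch only (assumed usage): fuse, split, and bind an unbound block's loop
    // nest to blockIdx.x / threadIdx.x via the helper moved by this patch.
    #include "../tir/schedule/utils.h"  // path relative to the including file; illustrative

    namespace tvm {
    namespace tir {

    Schedule ExampleBindUnboundBlock(const Schedule& sch, const BlockRV& block_rv) {
      // Candidate thread extents the helper may sample from when splitting the fused loop.
      Array<Integer> thread_extents{Integer(32), Integer(64), Integer(128), Integer(256)};
      // Argument order follows the declaration: max_threadblock first, then max_num_threads.
      return BindThreadsForUnboundBlock(sch, block_rv, /*max_threadblock=*/256,
                                        /*max_num_threads=*/1024, thread_extents);
    }

    }  // namespace tir
    }  // namespace tvm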