Skip to content

Commit

Permalink
move utils.cc to tir/schedule/analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Apr 29, 2022
1 parent d4736fb commit 44029b3
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 184 deletions.
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/auto_bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Array<tir::Schedule> AutoBindNode::Apply(const tir::Schedule& sch, const tir::Bl

return {BindThreadsForUnboundBlock(sch, block_rv, max_num_threads_, max_threadblock_,
thread_extents)};
};
}

ScheduleRule ScheduleRule::AutoBind(int max_threadblock, //
Array<Integer> thread_extents) {
Expand Down
155 changes: 0 additions & 155 deletions src/meta_schedule/utils.cc

This file was deleted.

28 changes: 0 additions & 28 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,34 +54,6 @@
#include "../tir/schedule/utils.h"

namespace tvm {
namespace tir {

/*! \brief The rewrite type for an unbound block */
enum class BindType : int32_t {
/*! \brief No additional thread binding is needed */
kNoBind = 0,
/*! \brief Need to bind to blockIdx */
kBindBlock = 1,
/*! \brief Need to bind to both blockIdx and threadIdx */
kBindBlockThread = 2,
};

/*!
* \brief Bind loops nesting to threadIdx for an unbound block.
* \param sch The input schedule.
* \param block_rv The input block.
* \param max_threadblock The max number of thread blocks
* \param max_num_threads The max number of threads per block
* \param thread_extents The candidates for thread extent
* \return The result schedule.
*/
Schedule BindThreadsForUnboundBlock(const Schedule& sch, //
const BlockRV& block_rv, //
int max_threadblock, //
int max_num_threads, //
Array<Integer> thread_extents);
} // namespace tir

namespace meta_schedule {

/*! \brief The type of the random state */
Expand Down
129 changes: 129 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2191,5 +2191,134 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
});

/*! \brief The rewrite type for an unbound block */
enum class BindType : int32_t {
/*! \brief No additional thread binding is needed */
kNoBind = 0,
/*! \brief Need to bind to blockIdx */
kBindBlock = 1,
/*! \brief Need to bind to both blockIdx and threadIdx */
kBindBlockThread = 2,
};

/*!
* \brief Check the combination of bindings to be added to the block
* \param block_sref The block to be checked
* \param fuse_first_num The number of loops to be fused
* \return The type of binding to be added to the block
*/
BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
Array<StmtSRef> loops = tir::GetLoops(block_sref);
int n = loops.size();
if (n == 0) {
return BindType::kNoBind;
}
int i_block_idx = -1;
int i_thread_idx = -1;
int i_multi_child = -1;
int i_spatial_loop = -1;
for (int i = 0; i < n; ++i) {
const StmtSRef& loop_sref = loops[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
runtime::ThreadScope thread_scope = GetThreadScope(loop);
if (IsBlockIdx(thread_scope)) {
if (i_block_idx == -1) {
i_block_idx = i;
}
}
if (IsThreadIdx(thread_scope)) {
if (i_thread_idx == -1) {
i_thread_idx = i;
}
}
if (loop->kind != tir::ForKind::kSerial) {
if (i_multi_child == -1) {
i_multi_child = i;
}
}
if (!IsSingleStmt(loop->body)) {
if (i_multi_child == -1) {
i_multi_child = i + 1;
}
}
if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
if (i_spatial_loop == i - 1) {
++i_spatial_loop;
}
}
}
if (i_multi_child == -1) {
i_multi_child = n;
}
if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
return BindType::kNoBind;
} else if (i_block_idx != -1 && i_thread_idx == -1) {
ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
throw;
} else if (i_block_idx == -1 && i_thread_idx != -1) {
*fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
return BindType::kBindBlock;
} else { // i_block_idx == -1 && i_thread_idx == -1
*fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1);
return BindType::kBindBlockThread;
}
}

Schedule BindThreadsForUnboundBlock(const Schedule& sch, //
const BlockRV& block_rv, //
int max_threadblock, //
int max_num_threads, //
Array<Integer> thread_extents) {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);

int fuse_first_num = 0;
tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num);
if (bind_type == tir::BindType::kNoBind) {
return {sch};
}

Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num});
if (bind_type == tir::BindType::kBindBlock) {
sch->Bind(fused, "blockIdx.x");
} else if (bind_type == tir::BindType::kBindBlockThread) {
int64_t extent_size = int64_t(1) << 60;
if (const int64_t* extent_ptr = tir::GetLoopIntExtent(sch->Get(fused).get())) {
extent_size = *extent_ptr;
}

Array<Integer> updated_extents;
for (const Integer extent : thread_extents) {
if (extent->value <= extent_size) updated_extents.push_back(extent);
}

if (extent_size <= max_threadblock * max_num_threads) {
tir::ExprRV factor;
if (updated_extents.empty()) {
factor = Integer(std::min(static_cast<int64_t>(max_num_threads), extent_size));
} else if (updated_extents.size() == 1) {
factor = updated_extents[0];
} else {
// Sample a factor
int n = updated_extents.size();
Array<FloatImm> probs(n, FloatImm(DataType::Float(64), 1.0 / n));
factor = sch->SampleCategorical(updated_extents, probs);
}
Array<LoopRV> splits = sch->Split(fused, {NullOpt, factor});
ICHECK_EQ(splits.size(), 2);
sch->Bind(splits[0], "blockIdx.x");
sch->Bind(splits[1], "threadIdx.x");
} else {
Array<LoopRV> splits =
sch->Split(fused, {NullOpt, Integer(max_threadblock), Integer(max_num_threads)});
ICHECK_EQ(splits.size(), 3);
sch->Reorder({splits[1], splits[2], splits[0]});
sch->Bind(splits[1], "blockIdx.x");
sch->Bind(splits[2], "threadIdx.x");
}
}
return {sch};
}

} // namespace tir
} // namespace tvm
17 changes: 17 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,23 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl
}
}

/********** Helper Functions for AutoBind and RewriteUnBoundBlock **********/

/*!
* \brief Bind loops nesting to threadIdx for an unbound block.
* \param sch The input schedule.
* \param block_rv The input block.
* \param max_threadblock The max number of thread blocks
* \param max_num_threads The max number of threads per block
* \param thread_extents The candidates for thread extent
* \return The result schedule.
*/
Schedule BindThreadsForUnboundBlock(const Schedule& sch, //
const BlockRV& block_rv, //
int max_threadblock, //
int max_num_threads, //
Array<Integer> thread_extents);

/******** Helper functions for enum conversion ********/

/*!
Expand Down

0 comments on commit 44029b3

Please sign in to comment.