-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] Utility function to decide loop mapping for auto tensorization #11050
Changes from all commits
062d8c2
e0c3337
51df94d
84801b6
fcd7917
65682c2
fcca9fb
4eb5845
f759f43
f0caa77
0df73cb
46eed2a
ecb3ebc
0860abc
ec39b62
9ec0974
2909a06
f474003
8750b4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -16,6 +16,9 @@ | |||
* specific language governing permissions and limitations | ||||
* under the License. | ||||
*/ | ||||
#include <tvm/runtime/container/optional.h> | ||||
#include <tvm/tir/expr.h> | ||||
|
||||
#include "../utils.h" | ||||
|
||||
namespace tvm { | ||||
|
@@ -492,8 +495,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, | |||
} | ||||
} | ||||
|
||||
std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) { | ||||
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); | ||||
std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) { | ||||
std::vector<IterVarType> results; | ||||
results.reserve(block->iter_vars.size()); | ||||
for (const IterVar& iter_var : block->iter_vars) { | ||||
|
@@ -502,6 +504,11 @@ std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) { | |||
return results; | ||||
} | ||||
|
||||
std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) { | ||||
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); | ||||
return GetBlockVarTypes(block); | ||||
} | ||||
|
||||
bool IsWriteCache(const StmtSRef& block_sref) { | ||||
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); | ||||
if (block->writes.size() != 1) { | ||||
|
@@ -2028,5 +2035,161 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // | |||
} | ||||
} | ||||
|
||||
TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); | ||||
|
||||
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self, | ||||
const tir::StmtSRef& block_sref, | ||||
const tir::PrimFunc& desc_func) { | ||||
arith::Analyzer analyzer; | ||||
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); | ||||
// Step 1. Analyze desc_func, extract its block, loops and loop vars | ||||
const tir::BlockRealizeNode* desc_block = nullptr; | ||||
std::vector<const tir::ForNode*> desc_loops; | ||||
std::unordered_set<const tir::VarNode*> desc_loop_vars; | ||||
const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>(); | ||||
ICHECK(desc_scope_realize); | ||||
{ | ||||
auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, | ||||
&analyzer](const ObjectRef& obj) -> bool { | ||||
// Extract the block | ||||
if (const auto* block = obj.as<tir::BlockRealizeNode>()) { | ||||
desc_block = block; | ||||
return false; | ||||
} | ||||
// Extract loops | ||||
if (const auto* loop = obj.as<tir::ForNode>()) { | ||||
desc_loops.push_back(loop); | ||||
desc_loop_vars.insert(loop->loop_var.get()); | ||||
if (!analyzer.CanProve(loop->min == 0)) { | ||||
return false; | ||||
} | ||||
} | ||||
return true; | ||||
}; | ||||
tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); | ||||
std::reverse(desc_loops.begin(), desc_loops.end()); | ||||
ICHECK(desc_block); | ||||
} | ||||
// Step 2. Collect loops from block_sref | ||||
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); | ||||
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); | ||||
std::vector<const tir::ForNode*> block_loops; | ||||
std::unordered_set<const tir::VarNode*> block_loop_vars; | ||||
{ | ||||
for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { | ||||
const auto* loop = loop_sref->StmtAs<tir::ForNode>(); | ||||
if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) { | ||||
break; | ||||
} | ||||
block_loops.push_back(loop); | ||||
block_loop_vars.insert(loop->loop_var.get()); | ||||
if (!analyzer.CanProve(loop->min == 0)) { | ||||
return NullOpt; | ||||
} | ||||
} | ||||
std::reverse(block_loops.begin(), block_loops.end()); | ||||
} | ||||
// Step 3. Map from block loops to desc block loops | ||||
ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>(); | ||||
const int n_block_vars = block->iter_values.size(); | ||||
const int n_desc_vars = desc_block->iter_values.size(); | ||||
const int offset = n_block_vars - n_desc_vars; | ||||
|
||||
if (offset < 0) { | ||||
return NullOpt; | ||||
} | ||||
|
||||
const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref); | ||||
const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get()); | ||||
|
||||
ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars)); | ||||
ICHECK(block_loops.size() == iter_types_block.size()); | ||||
|
||||
// We assume that the orders of iter_vars in the target and the desc block are consistent. | ||||
// Based on that assumption, the following logic supports arbitrary permutations of a loop order, | ||||
// such as | ||||
|
||||
// for k: | ||||
// for i: | ||||
// for j: | ||||
// C[i, j] += A[i, k] * B[k, j] | ||||
|
||||
// or | ||||
|
||||
// for i: | ||||
// for j: | ||||
// for k: | ||||
// C[i, j] += A[i, k] * B[k, j] | ||||
|
||||
int next_block_ind = block_loops.size() - 1; | ||||
for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { | ||||
// Step 3.1. Find the corresponding loop of the i_desc-th block var of desc | ||||
const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; | ||||
const tir::ForNode* desc_loop = nullptr; | ||||
IterVarType iter_type_desc = iter_types_desc[i_desc]; | ||||
for (int i = 0, n = desc_loops.size(); i < n; ++i) { | ||||
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars | ||||
PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); | ||||
if (!UsesVar(residual, | ||||
[&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { | ||||
desc_loop = desc_loops[i]; | ||||
iter_type_desc = iter_types_desc[i]; | ||||
break; | ||||
} | ||||
} | ||||
if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) { | ||||
return NullOpt; | ||||
} | ||||
|
||||
const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>(); | ||||
|
||||
// Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type | ||||
PrimExpr block_bind; | ||||
for (int i = next_block_ind; i >= 0; --i) { | ||||
if (iter_types_block[i] == iter_type_desc) { | ||||
next_block_ind = i - 1; | ||||
block_bind = block->iter_values[i]; | ||||
break; | ||||
} | ||||
} | ||||
|
||||
if (!block_bind.defined()) return NullOpt; | ||||
|
||||
// Step 3.3. Find the corresponding loop of the target block | ||||
for (int i = 0, n = block_loops.size(); i < n; ++i) { | ||||
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars | ||||
const tir::ForNode* block_loop = block_loops[i]; | ||||
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; | ||||
// Skip i-th loop if it has already been mapped | ||||
if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue; | ||||
|
||||
PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var); | ||||
if (UsesVar(residual, | ||||
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) | ||||
continue; | ||||
|
||||
const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>(); | ||||
|
||||
// Check divisibility | ||||
if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { | ||||
return NullOpt; | ||||
} | ||||
|
||||
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop)); | ||||
break; | ||||
} | ||||
} | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic here is very different from the one in the original code https://github.com/spectrometerHBH/tvm/blob/auto-tensorization/src/tir/schedule/analysis/analysis.cc#L1246. I was not able to understand why the original code has been written that way and it didn't work for the case where matching loops in the target block are not in the innermost positions (conv2d NCHWc on CPU, a test in
I think my change is simple and obvious. The condition for a match is (1) divisibility of loop extent and (2) matching iterator types (reduction vs spatial). Mapping is determined starting from the innermost axis. Please have a look at this change carefully, and let me know if I need to bring back some logic in the original code @spectrometerHBH @vinx13 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would love to have @spectrometerHBH review this change before merging There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The goal of the original mapping is to support for k:
for i:
for j:
C[i, j] += A[i, k] * B[k, j] where loops are not in the same order as the tensor intrinsic description function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But it also makes sense if we don't support such cases for this PR. So I approve it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @spectrometerHBH, I now understand the original code and was able to integrate the original logic to support loop permutations. Please have a look at the current diff, also cc @vinx13 @Hzfengsy @MasterJH5574 The key difference between the original code and the code I submitted yesterday was that, my code was looking at only the loop nest ( |
||||
|
||||
for (int i = 0, n = desc_loops.size(); i < n; ++i) { | ||||
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i)); | ||||
} | ||||
return TensorizeInfo(ret); | ||||
} | ||||
|
||||
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") | ||||
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { | ||||
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func); | ||||
}); | ||||
|
||||
} // namespace tir | ||||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i.e., no matter what the permutation of loop is, we should always have
for GEMM.
I think this is a reasonable assumption. Correct me if I'm wrong @spectrometerHBH @junrushao1994 @vinx13
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree this is a reasonable assumption. Though there might be corner cases, it covers all of the current use cases