-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR][Transform]Block scope hoisting added #6238
Changes from 9 commits
3986c93
69a15fa
fbbdd4d
8e23c24
94aced5
80c6dd1
c73d39d
cd54666
33ea539
4461094
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,25 @@ | |
namespace tvm { | ||
namespace tir { | ||
|
||
struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> { | ||
bool support_block_scope_hosting; | ||
|
||
TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") { | ||
TVM_ATTR_FIELD(support_block_scope_hosting) | ||
.describe("Hoist if cond with block scope variables") | ||
.set_default(false); | ||
} | ||
}; | ||
|
||
class HoistIfThenElseConfig : public Attrs { | ||
public: | ||
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs, | ||
HoistIfThenElseConfigNode); | ||
}; | ||
|
||
TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode); | ||
TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig); | ||
|
||
using VarForMap = std::unordered_map<const VarNode*, const ForNode*>; | ||
using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>; | ||
|
||
|
@@ -93,11 +112,33 @@ using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>; | |
* if (likely(j > 2)) | ||
* A[i+j+k] = B[i+j+k] | ||
* | ||
* | ||
* This pass also does hoisting for block scope variables (variables bound by AttrStmt nodes). | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The term "Block scope variables" is a little bit confusing to me. Is it equivalent to say variables defined in There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes. It is referring to Attr nodes. |
||
* As below: | ||
* Attr(IterVar: threadIdx.x) | ||
* for (i = 0; i < 3; i++) | ||
* for (j = 0; j < 4; j++) | ||
* for (k = 0; k < 5; k++) | ||
* if (likely(threadIdx.x < 3)) | ||
* A[3*i+2j+k] = B[7*i+3j+k] | ||
* | ||
* Will be transformed as below: | ||
* Attr(IterVar: threadIdx.x) | ||
* if (likely(threadIdx.x < 3)) | ||
* for (i = 0; i < 3; i++) | ||
* for (j = 0; j < 4; j++) | ||
* for (k = 0; k < 5; k++) | ||
* A[3*i+2j+k] = B[7*i+3j+k] | ||
* | ||
*/ | ||
|
||
// Select potential candidate IRs that can be hoisted. | ||
class HoistCandidateSelector final : public StmtExprVisitor { | ||
public: | ||
explicit HoistCandidateSelector(bool support_block_scope_hosting) | ||
: support_block_scope_hosting_(support_block_scope_hosting) { | ||
InitRecorder(); | ||
} | ||
HoistCandidateSelector() { InitRecorder(); } | ||
|
||
void VisitStmt_(const ForNode* op) final { | ||
|
@@ -108,16 +149,16 @@ class HoistCandidateSelector final : public StmtExprVisitor { | |
} | ||
|
||
// Check if it is first for loop, then start the recorder | ||
StartOrAddRecord(op); | ||
StartOrAddRecord(GetRef<ObjectRef>(op)); | ||
StmtExprVisitor::VisitStmt_(op); | ||
RemoveRecord(op); | ||
RemoveRecord(GetRef<ObjectRef>(op)); | ||
} | ||
|
||
void VisitStmt_(const SeqStmtNode* op) final { | ||
// If SeqStmt is encountered in the middle of recording | ||
// then need to purge all, as it can not be hoisted | ||
if (IsRecordingOn()) { | ||
ResetRecorder(); | ||
ResetRecorderInternal(); | ||
} | ||
StmtExprVisitor::VisitStmt_(op); | ||
} | ||
|
@@ -126,10 +167,19 @@ class HoistCandidateSelector final : public StmtExprVisitor { | |
// Maintain list of all vars in AttrStmt | ||
// To stop hoisting if any of the block variables are used. | ||
// | ||
// NOTE: If in future | ||
// hoisting is required for any specific case, | ||
// then add exception to only those case | ||
// rather than allowing for all. | ||
// In case we want to use hoisting in between certain passes | ||
// which have interdependencies of the postioning of if nodes with scope var | ||
// it is better to disable this section | ||
if (support_block_scope_hosting_) { | ||
if (IsRecordingOn()) { | ||
StartOrAddRecord(GetRef<ObjectRef>(op)); | ||
StmtExprVisitor::VisitStmt_(op); | ||
RemoveRecord(GetRef<ObjectRef>(op)); | ||
return; | ||
} else { | ||
return StmtExprVisitor::VisitStmt_(op); | ||
} | ||
} | ||
UpdateAttrVarList(op); | ||
StmtExprVisitor::VisitStmt_(op); | ||
RemoveAttrVarList(op); | ||
|
@@ -147,26 +197,23 @@ class HoistCandidateSelector final : public StmtExprVisitor { | |
|
||
if (CheckValidIf()) { | ||
// Check corresponding for loop | ||
bool match_found = false; | ||
size_t match_for_loop_pos = 0; | ||
int match_for_loop_pos = -1; | ||
for (auto var : if_var_list_) { | ||
for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) { | ||
if (ordered_for_list_[i] == var_for_map_[var]) { | ||
for (int i = 0; i < static_cast<int>(ordered_list_.size()); ++i) { | ||
ANSHUMAN87 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if ((ordered_list_[i] == var_for_map_[var]) || (ordered_list_[i] == var)) { | ||
if (match_for_loop_pos < i) { | ||
match_for_loop_pos = i; | ||
} | ||
match_found = true; | ||
break; | ||
} | ||
} | ||
} | ||
// If none of the for loop has the matching loop variable as if condition, | ||
// then the if node need to be hoisted on top of all, provided no parent loop exists. | ||
int target_for_pos = match_found ? match_for_loop_pos + 1 : 0; | ||
int target_for_pos = GetNextLoopPos(match_for_loop_pos); | ||
|
||
// Check if target for loop is not the parent of current if node | ||
if (!IsParentForLoop(target_for_pos)) { | ||
StopAndAddRecord(ordered_for_list_[target_for_pos], op); | ||
// Check if valid position | ||
if (target_for_pos >= 0) { | ||
StopAndAddRecord(static_cast<const ForNode*>(ordered_list_[target_for_pos]), op); | ||
if_var_list_.clear(); | ||
return; | ||
} | ||
|
@@ -186,13 +233,10 @@ class HoistCandidateSelector final : public StmtExprVisitor { | |
HoistForIfTuple hoist_for_if_recorder; | ||
|
||
void ResetRecorder() { | ||
if (is_recorder_on_) { | ||
CHECK_GT(ordered_for_list_.size(), 0); | ||
is_recorder_on_ = false; | ||
} | ||
ordered_for_list_.clear(); | ||
var_for_map_.clear(); | ||
hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); | ||
ResetRecorderInternal(); | ||
|
||
// Reset Block scope vars also here | ||
attr_var_list_.clear(); | ||
} | ||
|
||
bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); } | ||
|
@@ -202,25 +246,24 @@ class HoistCandidateSelector final : public StmtExprVisitor { | |
const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); } | ||
|
||
private: | ||
void ResetRecorderInternal() { | ||
if (is_recorder_on_) { | ||
CHECK_GT(ordered_list_.size(), 0); | ||
is_recorder_on_ = false; | ||
} | ||
ordered_list_.clear(); | ||
var_for_map_.clear(); | ||
hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); | ||
} | ||
bool CheckValidIf() { | ||
// If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop | ||
// hoisting | ||
return ((!if_var_list_.empty()) && (!CheckAttrVar())); | ||
} | ||
|
||
bool IsParentForLoop(int loop_pos) { | ||
// Check if the loop position is higher than the parent loop position | ||
for (auto var : if_var_list_) { | ||
if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
int GetParentLoopPos(const Object* node) { | ||
for (size_t i = 0; i < ordered_for_list_.size(); ++i) { | ||
if (ordered_for_list_[i] == node) { | ||
int GetNextLoopPos(int cur_pos) { | ||
for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) { | ||
if (ordered_list_[i]->IsInstance<ForNode>()) { | ||
return i; | ||
} | ||
} | ||
|
@@ -233,18 +276,25 @@ class HoistCandidateSelector final : public StmtExprVisitor { | |
|
||
bool IsRecordingOn() { return is_recorder_on_; } | ||
|
||
void StartOrAddRecord(const ForNode* op) { | ||
void StartOrAddRecord(const ObjectRef& op) { | ||
is_recorder_on_ = true; | ||
if (!var_for_map_.count(op->loop_var.get())) { | ||
var_for_map_.insert({op->loop_var.get(), op}); | ||
if (const auto* node = op.as<ForNode>()) { | ||
if (!var_for_map_.count(node->loop_var.get())) | ||
var_for_map_.insert({node->loop_var.get(), node}); | ||
ordered_list_.emplace_back(op.get()); | ||
} else if (const auto* node = op.as<AttrStmtNode>()) { | ||
if (const auto* iv = node->node.as<IterVarNode>()) { | ||
ordered_list_.emplace_back(iv->var.get()); | ||
} else if (const auto* iv = node->node.as<VarNode>()) { | ||
ordered_list_.emplace_back(iv); | ||
} | ||
} | ||
ordered_for_list_.emplace_back(op); | ||
} | ||
|
||
void RemoveRecord(const ForNode* op) { | ||
void RemoveRecord(const ObjectRef& op) { | ||
StopRecording(); | ||
var_for_map_.erase(op->loop_var.get()); | ||
if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back(); | ||
if (const auto* node = op.as<ForNode>()) var_for_map_.erase(node->loop_var.get()); | ||
if (ordered_list_.size() > 0) ordered_list_.pop_back(); | ||
} | ||
|
||
void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) { | ||
|
@@ -277,18 +327,22 @@ class HoistCandidateSelector final : public StmtExprVisitor { | |
return false; | ||
} | ||
|
||
std::vector<const ForNode*> ordered_for_list_; | ||
// Ordered List maintains all ForNodes & AttrStmtNodes encountered in sequence | ||
std::vector<const Object*> ordered_list_; | ||
ANSHUMAN87 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
std::vector<const VarNode*> if_var_list_; | ||
std::unordered_set<const VarNode*> attr_var_list_; | ||
VarForMap var_for_map_; | ||
|
||
bool is_if_cond_{false}; | ||
bool is_recorder_on_{false}; | ||
bool support_block_scope_hosting_{false}; | ||
}; | ||
|
||
class IfThenElseHoister : public StmtMutator { | ||
public: | ||
IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {} | ||
explicit IfThenElseHoister(bool support_block_scope_hosting) | ||
: hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {} | ||
|
||
Stmt VisitAndMutate(Stmt stmt) { | ||
hoist_selector_(stmt); | ||
|
@@ -344,21 +398,40 @@ class IfThenElseHoister : public StmtMutator { | |
const IfThenElseNode* target_if_; | ||
}; | ||
|
||
Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) { | ||
return IfThenElseHoister(support_block_scope_hosting).VisitAndMutate(stmt); | ||
} | ||
Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); } | ||
|
||
namespace transform { | ||
|
||
Pass HoistIfThenElse() { | ||
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { | ||
auto* n = f.CopyOnWrite(); | ||
n->body = HoistIfThenElse(std::move(n->body)); | ||
auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse"); | ||
|
||
if (!cfg.defined()) { | ||
cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>(); | ||
} | ||
n->body = HoistIfThenElse(std::move(n->body), cfg.value()->support_block_scope_hosting); | ||
return f; | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {}); | ||
} | ||
|
||
Pass HoistIfThenElseBasic() { | ||
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { | ||
auto* n = f.CopyOnWrite(); | ||
n->body = HoistIfThenElse(std::move(n->body)); | ||
return f; | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElseBasic", {}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); | ||
|
||
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); | ||
|
||
} // namespace transform | ||
|
||
} // namespace tir | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need this "basic" pass? Can we directly perform a complete pass in Phase 3?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we can also combine both and place the pass in the end phase.
I understand the dilemma here; I had the same one myself :)
The reasoning I came up with later is as below.
The current PR supports the feature for "Block scope vars" or "Attr nodes", which happens to be more applicable in specific cases (for example, in CUDA kernels). There is also a slight (linear) increase in time complexity.
So, to sum up, we have 2 cases:
Case 1: "Basic" or "Default": The scenarios covered here should be the more general (simpler version) ones across use cases.
Case 2: "Advanced": The scenarios covered here should be enabled in particular settings.
Please let me know your thoughts on the above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think a little more compile-time complexity is a real problem. Perhaps the pass can even be faster than those Python bindings. We can always perform the "advanced" version as long as it won't break the correctness or do negative optimizations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are right! The compile-time complexity is not the real point here. I just shared the info :)
I am okay with both approaches.
However, I think it would be good to have one additional segregation inside the pass, to have more control over the different scenarios it covers. This can provide a more user-friendly experience when a user wants to combine this pass with only specific passes, without needing to write special config parameters.
Let us gather some more opinions on this point, to conclude better.
@MarisaKirisame: Would you please help share your thoughts on the above point? TIA!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should use the advanced option. It is more general, and the compile time is light judging from the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @MarisaKirisame for sharing your thoughts here.
With this, I think we can conclude the comment here by removing the "Basic" pass embedded in the lower sequence of passes and placing only one hoisting pass at the third phase, as per the latest change.
If my understanding is wrong, please let me know. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is good for me.