Skip to content

Commit

Permalink
[TIR][Transform]Block scope hoisting added (#6238)
Browse files Browse the repository at this point in the history
* Block scope hoisting added

* lowering flow added with 2 variants

* Fake commit to trigger ci with pass default enabled

* CI Failure resolved

* Optimize for if var list iteration

* More test case added

* Fake commit to disable failed test cases

* Pass default value restored

* [1] Review comment handled

* [2] Review comments handled
  • Loading branch information
ANSHUMAN TRIPATHY authored Aug 31, 2020
1 parent 1767b08 commit 66b7ddb
Show file tree
Hide file tree
Showing 4 changed files with 629 additions and 58 deletions.
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def lower(sch,
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
tvm.tir.transform.HoistIfThenElse(),
]
pass_list += lower_phase1

Expand All @@ -205,6 +204,7 @@ def lower(sch,
]

pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [tvm.tir.transform.HoistIfThenElse()]
pass_list += lower_phase3

# Instrument BoundCheckers
Expand Down
24 changes: 22 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,31 @@ def VerifyMemory():
"""
return _ffi_api.VerifyMemory()

def HoistIfThenElse():
#pylint: disable=no-else-return,inconsistent-return-statements
def HoistIfThenElse(variant=None):
"""Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Parameters
----------
variant : Optional[String]
The variant of the pass.
variant can have any one of following values ["basic", None(Default)].
The basic variant supports basic hoisting scenarios where it expects
the For & If Nodes are in place consecutively and does not involve
global scope variables or more advanced scenarios.
Default variant supports all hoisting scenarios, i.e., {"Basic" + "Advanced"},
controlled through PassContext configs like below:
config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.HoistIfThenElse()
if variant == "basic":
return _ffi_api.HoistIfThenElseBasic()
elif variant is None:
return _ffi_api.HoistIfThenElse()
165 changes: 119 additions & 46 deletions src/tir/transforms/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@
namespace tvm {
namespace tir {

// Attrs node holding the options of the HoistIfThenElse pass.
struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> {
// When true, if-conditions that use block scope variables are hoisted as well.
// NOTE: the name is spelled "hosting"; it is the key written in the pass config,
// so it must not be renamed without updating every config that sets it.
bool support_block_scope_hosting;

TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") {
// The option is off unless the config explicitly enables it.
TVM_ATTR_FIELD(support_block_scope_hosting)
.describe("Hoist if cond with block scope variables")
.set_default(false);
}
};

// Object reference type for HoistIfThenElseConfigNode, derived from Attrs.
// The reference methods are generated by the TVM macro below
// (non-nullable, going by the macro name).
class HoistIfThenElseConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs,
HoistIfThenElseConfigNode);
};

// Register the config node type, and expose the config to PassContext under the
// key "tir.HoistIfThenElse" (the same key the pass reads via GetConfig).
TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);

// Maps a loop variable to the For node that defines it.
using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
// Recorder state of the candidate selector: <recording complete flag,
// For node passed to StopAndAddRecord, IfThenElse node to hoist>.
using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;

Expand Down Expand Up @@ -93,11 +112,33 @@ using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
* if (likely(j > 2))
* A[i+j+k] = B[i+j+k]
*
*
* This pass also hoists if conditions that use block scope variables,
* as shown below:
* Attr(IterVar: threadIdx.x)
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* for (k = 0; k < 5; k++)
* if (likely(threadIdx.x < 3))
* A[3*i+2j+k] = B[7*i+3j+k]
*
* Will be transformed to as below:
* Attr(IterVar: threadIdx.x)
* if (likely(threadIdx.x < 3))
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* for (k = 0; k < 5; k++)
* A[3*i+2j+k] = B[7*i+3j+k]
*
*/

// Select potential candidate IRs that can be hoisted.
class HoistCandidateSelector final : public StmtExprVisitor {
public:
explicit HoistCandidateSelector(bool support_block_scope_hosting)
: support_block_scope_hosting_(support_block_scope_hosting) {
InitRecorder();
}
HoistCandidateSelector() { InitRecorder(); }

void VisitStmt_(const ForNode* op) final {
Expand All @@ -108,16 +149,16 @@ class HoistCandidateSelector final : public StmtExprVisitor {
}

// Check if it is first for loop, then start the recorder
StartOrAddRecord(op);
StartOrAddRecord(GetRef<ObjectRef>(op));
StmtExprVisitor::VisitStmt_(op);
RemoveRecord(op);
RemoveRecord(GetRef<ObjectRef>(op));
}

void VisitStmt_(const SeqStmtNode* op) final {
// If SeqStmt is encountered in the middle of recording
// then need to purge all, as it can not be hoisted
if (IsRecordingOn()) {
ResetRecorder();
ResetRecorderInternal();
}
StmtExprVisitor::VisitStmt_(op);
}
Expand All @@ -126,10 +167,19 @@ class HoistCandidateSelector final : public StmtExprVisitor {
// Maintain list of all vars in AttrStmt
// To stop hoisting if any of the block variables are used.
//
// NOTE: If in future
// hoisting is required for any specific case,
// then add exception to only those case
// rather than allowing for all.
// If hoisting is used in between passes that depend on the positioning
// of if nodes relative to block scope variables,
// it is better to disable this section.
if (support_block_scope_hosting_) {
if (IsRecordingOn()) {
StartOrAddRecord(GetRef<ObjectRef>(op));
StmtExprVisitor::VisitStmt_(op);
RemoveRecord(GetRef<ObjectRef>(op));
return;
} else {
return StmtExprVisitor::VisitStmt_(op);
}
}
UpdateAttrVarList(op);
StmtExprVisitor::VisitStmt_(op);
RemoveAttrVarList(op);
Expand All @@ -147,26 +197,23 @@ class HoistCandidateSelector final : public StmtExprVisitor {

if (CheckValidIf()) {
// Check corresponding for loop
bool match_found = false;
size_t match_for_loop_pos = 0;
int match_for_loop_pos = -1;
for (auto var : if_var_list_) {
for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
if (ordered_for_list_[i] == var_for_map_[var]) {
for (int i = 0; i < static_cast<int>(ordered_list_.size()); ++i) {
if ((ordered_list_[i] == var_for_map_[var]) || (ordered_list_[i] == var)) {
if (match_for_loop_pos < i) {
match_for_loop_pos = i;
}
match_found = true;
break;
}
}
}
// If none of the for loop has the matching loop variable as if condition,
// then the if node need to be hoisted on top of all, provided no parent loop exists.
int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
int target_for_pos = GetNextLoopPos(match_for_loop_pos);

// Check if target for loop is not the parent of current if node
if (!IsParentForLoop(target_for_pos)) {
StopAndAddRecord(ordered_for_list_[target_for_pos], op);
// Check if valid position
if (target_for_pos >= 0) {
StopAndAddRecord(static_cast<const ForNode*>(ordered_list_[target_for_pos]), op);
if_var_list_.clear();
return;
}
Expand All @@ -186,13 +233,10 @@ class HoistCandidateSelector final : public StmtExprVisitor {
HoistForIfTuple hoist_for_if_recorder;

void ResetRecorder() {
if (is_recorder_on_) {
CHECK_GT(ordered_for_list_.size(), 0);
is_recorder_on_ = false;
}
ordered_for_list_.clear();
var_for_map_.clear();
hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
ResetRecorderInternal();

// Reset Block scope vars also here
attr_var_list_.clear();
}

bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }
Expand All @@ -202,25 +246,24 @@ class HoistCandidateSelector final : public StmtExprVisitor {
const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }

private:
void ResetRecorderInternal() {
if (is_recorder_on_) {
CHECK_GT(ordered_list_.size(), 0);
is_recorder_on_ = false;
}
ordered_list_.clear();
var_for_map_.clear();
hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
}
bool CheckValidIf() {
// If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
// hoisting
return ((!if_var_list_.empty()) && (!CheckAttrVar()));
}

bool IsParentForLoop(int loop_pos) {
// Check if the loop position is higher than the parent loop position
for (auto var : if_var_list_) {
if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
return true;
}
}
return false;
}

int GetParentLoopPos(const Object* node) {
for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
if (ordered_for_list_[i] == node) {
int GetNextLoopPos(int cur_pos) {
for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) {
if (ordered_list_[i]->IsInstance<ForNode>()) {
return i;
}
}
Expand All @@ -233,18 +276,25 @@ class HoistCandidateSelector final : public StmtExprVisitor {

bool IsRecordingOn() { return is_recorder_on_; }

void StartOrAddRecord(const ForNode* op) {
void StartOrAddRecord(const ObjectRef& op) {
is_recorder_on_ = true;
if (!var_for_map_.count(op->loop_var.get())) {
var_for_map_.insert({op->loop_var.get(), op});
if (const auto* node = op.as<ForNode>()) {
if (!var_for_map_.count(node->loop_var.get()))
var_for_map_.insert({node->loop_var.get(), node});
ordered_list_.emplace_back(op.get());
} else if (const auto* node = op.as<AttrStmtNode>()) {
if (const auto* iv = node->node.as<IterVarNode>()) {
ordered_list_.emplace_back(iv->var.get());
} else if (const auto* iv = node->node.as<VarNode>()) {
ordered_list_.emplace_back(iv);
}
}
ordered_for_list_.emplace_back(op);
}

void RemoveRecord(const ForNode* op) {
void RemoveRecord(const ObjectRef& op) {
StopRecording();
var_for_map_.erase(op->loop_var.get());
if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
if (const auto* node = op.as<ForNode>()) var_for_map_.erase(node->loop_var.get());
if (ordered_list_.size() > 0) ordered_list_.pop_back();
}

void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
Expand Down Expand Up @@ -277,18 +327,22 @@ class HoistCandidateSelector final : public StmtExprVisitor {
return false;
}

std::vector<const ForNode*> ordered_for_list_;
// Ordered List maintains all ForNodes & AttrStmtNodes encountered in sequence
std::vector<const Object*> ordered_list_;
std::vector<const VarNode*> if_var_list_;
std::unordered_set<const VarNode*> attr_var_list_;
VarForMap var_for_map_;

bool is_if_cond_{false};
bool is_recorder_on_{false};
bool support_block_scope_hosting_{false};
};

class IfThenElseHoister : public StmtMutator {
public:
IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {}
explicit IfThenElseHoister(bool support_block_scope_hosting)
: hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {}

Stmt VisitAndMutate(Stmt stmt) {
hoist_selector_(stmt);
Expand Down Expand Up @@ -344,21 +398,40 @@ class IfThenElseHoister : public StmtMutator {
const IfThenElseNode* target_if_;
};

// Runs the if-then-else hoister over `stmt` and returns the rewritten statement.
// `support_block_scope_hosting` is forwarded unchanged to the hoister, which hands
// it to its candidate selector.
Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) {
  IfThenElseHoister hoister(support_block_scope_hosting);
  return hoister.VisitAndMutate(stmt);
}
// Runs the if-then-else hoister, built with its default constructor, over `stmt`
// and returns the rewritten statement.
Stmt HoistIfThenElse(Stmt stmt) {
  IfThenElseHoister hoister;
  return hoister.VisitAndMutate(stmt);
}

namespace transform {

// Creates the configurable HoistIfThenElse PrimFunc pass.
// The option support_block_scope_hosting is read from the PassContext config
// registered under "tir.HoistIfThenElse".
Pass HoistIfThenElse() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
// NOTE(review): this unconditional single-argument call runs before the
// config-driven call below, so as written the body is rewritten twice. It looks
// like the pre-change statement left over from the diff view — confirm whether
// it should be removed.
n->body = HoistIfThenElse(std::move(n->body));
// Look up the options the user supplied for this pass, if any.
auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");

// No options supplied: fall back to the defaults declared on the config node.
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
}
n->body = HoistIfThenElse(std::move(n->body), cfg.value()->support_block_scope_hosting);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {});
}

// Creates the "basic" variant of the hoisting pass. It always calls the
// single-argument HoistIfThenElse and reads no PassContext configuration.
Pass HoistIfThenElseBasic() {
  auto hoist_body = [=](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    auto* func_node = func.CopyOnWrite();
    // Rewrite the function body, then store the result back on the same node.
    Stmt hoisted_body = HoistIfThenElse(std::move(func_node->body));
    func_node->body = std::move(hoisted_body);
    return func;
  };
  return CreatePrimFuncPass(hoist_body, 0, "tir.HoistIfThenElseBasic", {});
}

// Expose both pass constructors as global functions. The names match the
// _ffi_api.HoistIfThenElse / _ffi_api.HoistIfThenElseBasic calls on the Python side.
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);

TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic);

} // namespace transform

} // namespace tir
Expand Down
Loading

0 comments on commit 66b7ddb

Please sign in to comment.