Skip to content

Commit

Permalink
CI Failure resolved
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Aug 15, 2020
1 parent 44a52e1 commit 3d5b347
Showing 1 changed file with 60 additions and 25 deletions.
85 changes: 60 additions & 25 deletions src/tir/transforms/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,16 @@ class HoistCandidateSelector final : public StmtExprVisitor {
}

// Check if it is first for loop, then start the recorder
StartOrAddRecord(op);
StartOrAddRecord(GetRef<ObjectRef>(op));
StmtExprVisitor::VisitStmt_(op);
RemoveRecord(op);
RemoveRecord(GetRef<ObjectRef>(op));
}

void VisitStmt_(const SeqStmtNode* op) final {
// If SeqStmt is encountered in the middle of recording
// then need to purge all, as it can not be hoisted
if (IsRecordingOn()) {
ResetRecorder();
ResetRecorderInternal();
}
StmtExprVisitor::VisitStmt_(op);
}
Expand All @@ -153,6 +153,13 @@ class HoistCandidateSelector final : public StmtExprVisitor {
// which have interdependencies of the postioning of if nodes with scope var
// it is better to disable this section
if (support_block_scope_hosting_) {
if (IsRecordingOn()) {
StartOrAddRecord(GetRef<ObjectRef>(op));
StmtExprVisitor::VisitStmt_(op);
RemoveRecord(GetRef<ObjectRef>(op));
return;
}

return StmtExprVisitor::VisitStmt_(op);
}
UpdateAttrVarList(op);
Expand All @@ -173,25 +180,30 @@ class HoistCandidateSelector final : public StmtExprVisitor {
if (CheckValidIf()) {
// Check corresponding for loop
bool match_found = false;
size_t match_for_loop_pos = 0;
int match_for_loop_pos = -1;
for (auto var : if_var_list_) {
for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
if (ordered_for_list_[i] == var_for_map_[var]) {
for (int i = 0; i < static_cast<int>(ordered_list_.size() - 1); ++i) {
if (ordered_list_[i] == var_for_map_[var]) {
if (match_for_loop_pos < i) {
match_for_loop_pos = i;
}
match_found = true;
break;
} else if (ordered_list_[i] == var) {
if (match_for_loop_pos < i) {
match_for_loop_pos = i;
}
}
}
}
// If none of the for loop has the matching loop variable as if condition,
// then the if node need to be hoisted on top of all, provided no parent loop exists.
int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
int target_for_pos =
match_found ? match_for_loop_pos + 1 : GetNextLoopPos(match_for_loop_pos);

// Check if target for loop is not the parent of current if node
if (!IsParentForLoop(target_for_pos)) {
StopAndAddRecord(ordered_for_list_[target_for_pos], op);
StopAndAddRecord(static_cast<const ForNode*>(ordered_list_[target_for_pos]), op);
if_var_list_.clear();
return;
}
Expand All @@ -211,13 +223,10 @@ class HoistCandidateSelector final : public StmtExprVisitor {
HoistForIfTuple hoist_for_if_recorder;

void ResetRecorder() {
if (is_recorder_on_) {
CHECK_GT(ordered_for_list_.size(), 0);
is_recorder_on_ = false;
}
ordered_for_list_.clear();
var_for_map_.clear();
hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
ResetRecorderInternal();

// Reset Block scope vars also here
attr_var_list_.clear();
}

bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }
Expand All @@ -227,6 +236,15 @@ class HoistCandidateSelector final : public StmtExprVisitor {
const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }

private:
void ResetRecorderInternal() {
if (is_recorder_on_) {
CHECK_GT(ordered_list_.size(), 0);
is_recorder_on_ = false;
}
ordered_list_.clear();
var_for_map_.clear();
hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
}
bool CheckValidIf() {
// If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
// hoisting
Expand All @@ -244,8 +262,17 @@ class HoistCandidateSelector final : public StmtExprVisitor {
}

int GetParentLoopPos(const Object* node) {
for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
if (ordered_for_list_[i] == node) {
for (size_t i = 0; i < ordered_list_.size(); ++i) {
if (ordered_list_[i] == node) {
return i;
}
}
return -1;
}

int GetNextLoopPos(int cur_pos) {
for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) {
if (ordered_list_[i]->IsInstance<ForNode>()) {
return i;
}
}
Expand All @@ -258,18 +285,25 @@ class HoistCandidateSelector final : public StmtExprVisitor {

bool IsRecordingOn() { return is_recorder_on_; }

void StartOrAddRecord(const ForNode* op) {
void StartOrAddRecord(const ObjectRef& op) {
is_recorder_on_ = true;
if (!var_for_map_.count(op->loop_var.get())) {
var_for_map_.insert({op->loop_var.get(), op});
if (const auto* node = op.as<ForNode>()) {
if (!var_for_map_.count(node->loop_var.get()))
var_for_map_.insert({node->loop_var.get(), node});
ordered_list_.emplace_back(op.get());
} else if (const auto* node = op.as<AttrStmtNode>()) {
if (const auto* iv = node->node.as<IterVarNode>()) {
ordered_list_.emplace_back(iv->var.get());
} else if (const auto* iv = node->node.as<VarNode>()) {
ordered_list_.emplace_back(iv);
}
}
ordered_for_list_.emplace_back(op);
}

void RemoveRecord(const ForNode* op) {
void RemoveRecord(const ObjectRef& op) {
StopRecording();
var_for_map_.erase(op->loop_var.get());
if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
if (const auto* node = op.as<ForNode>()) var_for_map_.erase(node->loop_var.get());
if (ordered_list_.size() > 0) ordered_list_.pop_back();
}

void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
Expand Down Expand Up @@ -302,7 +336,7 @@ class HoistCandidateSelector final : public StmtExprVisitor {
return false;
}

std::vector<const ForNode*> ordered_for_list_;
std::vector<const Object*> ordered_list_;
std::vector<const VarNode*> if_var_list_;
std::unordered_set<const VarNode*> attr_var_list_;
VarForMap var_for_map_;
Expand Down Expand Up @@ -383,6 +417,7 @@ Pass HoistIfThenElse() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");

if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
}
Expand Down

0 comments on commit 3d5b347

Please sign in to comment.