From 3d5b34751b22e7f5e7e821b1751a989bbce131e9 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 15 Aug 2020 13:44:21 +0530 Subject: [PATCH] CI Failure resolved --- src/tir/transforms/hoist_if_then_else.cc | 85 +++++++++++++++++------- 1 file changed, 60 insertions(+), 25 deletions(-) diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index e7f61bd4c6c42..cd0083bdae25a 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -131,16 +131,16 @@ class HoistCandidateSelector final : public StmtExprVisitor { } // Check if it is first for loop, then start the recorder - StartOrAddRecord(op); + StartOrAddRecord(GetRef(op)); StmtExprVisitor::VisitStmt_(op); - RemoveRecord(op); + RemoveRecord(GetRef(op)); } void VisitStmt_(const SeqStmtNode* op) final { // If SeqStmt is encountered in the middle of recording // then need to purge all, as it can not be hoisted if (IsRecordingOn()) { - ResetRecorder(); + ResetRecorderInternal(); } StmtExprVisitor::VisitStmt_(op); } @@ -153,6 +153,13 @@ class HoistCandidateSelector final : public StmtExprVisitor { // which have interdependencies of the postioning of if nodes with scope var // it is better to disable this section if (support_block_scope_hosting_) { + if (IsRecordingOn()) { + StartOrAddRecord(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + RemoveRecord(GetRef(op)); + return; + } + return StmtExprVisitor::VisitStmt_(op); } UpdateAttrVarList(op); @@ -173,25 +180,30 @@ class HoistCandidateSelector final : public StmtExprVisitor { if (CheckValidIf()) { // Check corresponding for loop bool match_found = false; - size_t match_for_loop_pos = 0; + int match_for_loop_pos = -1; for (auto var : if_var_list_) { - for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) { - if (ordered_for_list_[i] == var_for_map_[var]) { + for (int i = 0; i < static_cast(ordered_list_.size() - 1); ++i) { + if (ordered_list_[i] == var_for_map_[var]) { if (match_for_loop_pos < i) { match_for_loop_pos = i; } match_found = true; break; + } else if (ordered_list_[i] == var) { + if (match_for_loop_pos < i) { + match_for_loop_pos = i; + } } } } // If none of the for loop has the matching loop variable as if condition, // then the if node need to be hoisted on top of all, provided no parent loop exists. - int target_for_pos = match_found ? match_for_loop_pos + 1 : 0; + int target_for_pos = + match_found ? match_for_loop_pos + 1 : GetNextLoopPos(match_for_loop_pos); // Check if target for loop is not the parent of current if node if (!IsParentForLoop(target_for_pos)) { - StopAndAddRecord(ordered_for_list_[target_for_pos], op); + StopAndAddRecord(static_cast(ordered_list_[target_for_pos]), op); if_var_list_.clear(); return; } @@ -211,13 +223,10 @@ class HoistCandidateSelector final : public StmtExprVisitor { HoistForIfTuple hoist_for_if_recorder; void ResetRecorder() { - if (is_recorder_on_) { - CHECK_GT(ordered_for_list_.size(), 0); - is_recorder_on_ = false; - } - ordered_for_list_.clear(); - var_for_map_.clear(); - hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); + ResetRecorderInternal(); + + // Reset Block scope vars also here + attr_var_list_.clear(); } bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); } @@ -227,6 +236,15 @@ class HoistCandidateSelector final : public StmtExprVisitor { const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); } private: + void ResetRecorderInternal() { + if (is_recorder_on_) { + CHECK_GT(ordered_list_.size(), 0); + is_recorder_on_ = false; + } + ordered_list_.clear(); + var_for_map_.clear(); + hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); + } bool CheckValidIf() { // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop // hoisting @@ -244,8 +262,17 @@ class HoistCandidateSelector final : public StmtExprVisitor { } int GetParentLoopPos(const Object* node) { - for (size_t i = 0; i < ordered_for_list_.size(); ++i) { - if (ordered_for_list_[i] == node) { + for (size_t i = 0; i < ordered_list_.size(); ++i) { + if (ordered_list_[i] == node) { + return i; + } + } + return -1; + } + + int GetNextLoopPos(int cur_pos) { + for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) { + if (ordered_list_[i]->IsInstance()) { return i; } } @@ -258,18 +285,25 @@ class HoistCandidateSelector final : public StmtExprVisitor { bool IsRecordingOn() { return is_recorder_on_; } - void StartOrAddRecord(const ForNode* op) { + void StartOrAddRecord(const ObjectRef& op) { is_recorder_on_ = true; - if (!var_for_map_.count(op->loop_var.get())) { - var_for_map_.insert({op->loop_var.get(), op}); + if (const auto* node = op.as()) { + if (!var_for_map_.count(node->loop_var.get())) + var_for_map_.insert({node->loop_var.get(), node}); + ordered_list_.emplace_back(op.get()); + } else if (const auto* node = op.as()) { + if (const auto* iv = node->node.as()) { + ordered_list_.emplace_back(iv->var.get()); + } else if (const auto* iv = node->node.as()) { + ordered_list_.emplace_back(iv); + } } - ordered_for_list_.emplace_back(op); } - void RemoveRecord(const ForNode* op) { + void RemoveRecord(const ObjectRef& op) { StopRecording(); - var_for_map_.erase(op->loop_var.get()); - if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back(); + if (const auto* node = op.as()) var_for_map_.erase(node->loop_var.get()); + if (ordered_list_.size() > 0) ordered_list_.pop_back(); } void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) { @@ -302,7 +336,7 @@ class HoistCandidateSelector final : public StmtExprVisitor { return false; } - std::vector ordered_for_list_; + std::vector ordered_list_; std::vector if_var_list_; std::unordered_set attr_var_list_; VarForMap var_for_map_; @@ -383,6 +417,7 @@ Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto cfg = ctx->GetConfig("tir.HoistIfThenElse"); + if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); }