From 3986c935cc4f97f434140634b918aa1492e5f308 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 8 Aug 2020 13:28:00 +0530 Subject: [PATCH 01/10] Block scope hoisting added --- src/tir/transforms/hoist_if_then_else.cc | 56 +++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index f58eb965584d..26e1ad4aa340 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -37,6 +37,25 @@ namespace tvm { namespace tir { +struct HoistIfThenElseConfigNode : public tvm::AttrsNode { + bool support_block_scope_hosting; + + TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") { + TVM_ATTR_FIELD(support_block_scope_hosting) + .describe("Hoist if cond with block scope variables") + .set_default(false); + } +}; + +class HoistIfThenElseConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs, + HoistIfThenElseConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig); + using VarForMap = std::unordered_map; using HoistForIfTuple = std::tuple; @@ -98,6 +117,10 @@ using HoistForIfTuple = std::tuple; // Select potential candidate IRs that can be hoisted. class HoistCandidateSelector final : public StmtExprVisitor { public: + HoistCandidateSelector(bool support_block_scope_hosting) + : support_block_scope_hosting_(support_block_scope_hosting) { + InitRecorder(); + } HoistCandidateSelector() { InitRecorder(); } void VisitStmt_(const ForNode* op) final { @@ -126,10 +149,12 @@ class HoistCandidateSelector final : public StmtExprVisitor { // Maintain list of all vars in AttrStmt // To stop hoisting if any of the block variables are used. // - // NOTE: If in future - // hoisting is required for any specific case, - // then add exception to only those case - // rather than allowing for all. + // In case we want to use hoisting in beetween certain passes + // which interdependencies of the postioning of if nodes with scope var + // it is better to disable this section + if (!support_block_scope_hosting_) { + return StmtExprVisitor::VisitStmt_(op); + } UpdateAttrVarList(op); StmtExprVisitor::VisitStmt_(op); RemoveAttrVarList(op); @@ -284,11 +309,14 @@ class HoistCandidateSelector final : public StmtExprVisitor { bool is_if_cond_{false}; bool is_recorder_on_{false}; + bool support_block_scope_hosting_{false}; }; class IfThenElseHoister : public StmtMutator { public: IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {} + IfThenElseHoister(bool support_block_scope_hosting) + : hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {} Stmt VisitAndMutate(Stmt stmt) { hoist_selector_(stmt); @@ -344,6 +372,9 @@ class IfThenElseHoister : public StmtMutator { const IfThenElseNode* target_if_; }; +Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) { + return IfThenElseHoister(support_block_scope_hosting).VisitAndMutate(stmt); +} Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); } namespace transform { @@ -351,14 +382,29 @@ namespace transform { Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - n->body = HoistIfThenElse(std::move(n->body)); + auto cfg = ctx->GetConfig("tir.HoistIfThenElse"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + n->body = HoistIfThenElse(std::move(n->body), cfg.value()->support_block_scope_hosting); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {}); } +Pass HoistIfThenElseBasic() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = HoistIfThenElse(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElseBasic", {}); +} + TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); +TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); + } // namespace transform } // namespace tir From 69a15fa18887ce85c7ddf79842641498c6d66586 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 8 Aug 2020 21:51:51 +0530 Subject: [PATCH 02/10] lowering flow added with 2 variants --- python/tvm/driver/build_module.py | 3 ++- python/tvm/tir/transform/transform.py | 12 ++++++++++-- src/tir/transforms/hoist_if_then_else.cc | 10 +++++----- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index ff4b56bdba1d..fdb3b5ce7ee5 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -181,7 +181,7 @@ def lower(sch, tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), - tvm.tir.transform.HoistIfThenElse(), + tvm.tir.transform.HoistIfThenElse("basic"), ] pass_list += lower_phase1 @@ -205,6 +205,7 @@ def lower(sch, ] pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] + pass_list += [tvm.tir.transform.HoistIfThenElse()] pass_list += lower_phase3 # Instrument BoundCheckers diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index d2f5acd199e6..467dd7c91330 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -500,11 +500,19 @@ def VerifyMemory(): """ return _ffi_api.VerifyMemory() -def HoistIfThenElse(): +def HoistIfThenElse(variant=None): """Hoist loop-invariant IfThenElse nodes to outside the elligible loops. + + Parameters + ---------- + variant : str + The variant of the pass. + Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.HoistIfThenElse() + if variant is None: + return _ffi_api.HoistIfThenElse() + return _ffi_api.HoistIfThenElseBasic() diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index 26e1ad4aa340..d7e8b359a3f6 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -117,7 +117,7 @@ using HoistForIfTuple = std::tuple; // Select potential candidate IRs that can be hoisted. class HoistCandidateSelector final : public StmtExprVisitor { public: - HoistCandidateSelector(bool support_block_scope_hosting) + explicit HoistCandidateSelector(bool support_block_scope_hosting) : support_block_scope_hosting_(support_block_scope_hosting) { InitRecorder(); } @@ -149,10 +149,10 @@ class HoistCandidateSelector final : public StmtExprVisitor { // Maintain list of all vars in AttrStmt // To stop hoisting if any of the block variables are used. // - // In case we want to use hoisting in beetween certain passes - // which interdependencies of the postioning of if nodes with scope var + // In case we want to use hoisting in between certain passes + // which have interdependencies of the postioning of if nodes with scope var // it is better to disable this section - if (!support_block_scope_hosting_) { + if (support_block_scope_hosting_) { return StmtExprVisitor::VisitStmt_(op); } UpdateAttrVarList(op); @@ -315,7 +315,7 @@ class HoistCandidateSelector final : public StmtExprVisitor { class IfThenElseHoister : public StmtMutator { public: IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {} - IfThenElseHoister(bool support_block_scope_hosting) + explicit IfThenElseHoister(bool support_block_scope_hosting) : hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {} Stmt VisitAndMutate(Stmt stmt) { From fbbdd4d48fdb6d0fc0a7cc5f26fc01ef7d1fe21a Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Mon, 10 Aug 2020 21:32:49 +0530 Subject: [PATCH 03/10] Fake commit to trigger ci with pass default enabled --- src/tir/transforms/hoist_if_then_else.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index d7e8b359a3f6..e7f61bd4c6c4 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -43,7 +43,7 @@ struct HoistIfThenElseConfigNode : public tvm::AttrsNode Date: Sat, 15 Aug 2020 13:44:21 +0530 Subject: [PATCH 04/10] CI Failure resolved --- src/tir/transforms/hoist_if_then_else.cc | 86 ++++++++++++++++-------- 1 file changed, 59 insertions(+), 27 deletions(-) diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index e7f61bd4c6c4..6e4b7d1d3b5b 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -131,16 +131,16 @@ class HoistCandidateSelector final : public StmtExprVisitor { } // Check if it is first for loop, then start the recorder - StartOrAddRecord(op); + StartOrAddRecord(GetRef(op)); StmtExprVisitor::VisitStmt_(op); - RemoveRecord(op); + RemoveRecord(GetRef(op)); } void VisitStmt_(const SeqStmtNode* op) final { // If SeqStmt is encountered in the middle of recording // then need to purge all, as it can not be hoisted if (IsRecordingOn()) { - ResetRecorder(); + ResetRecorderInternal(); } StmtExprVisitor::VisitStmt_(op); } @@ -153,6 +153,13 @@ class HoistCandidateSelector final : public StmtExprVisitor { // which have interdependencies of the postioning of if nodes with scope var // it is better to disable this section if (support_block_scope_hosting_) { + if (IsRecordingOn()) { + StartOrAddRecord(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + RemoveRecord(GetRef(op)); + return; + } + return StmtExprVisitor::VisitStmt_(op); } UpdateAttrVarList(op); @@ -172,26 +179,28 @@ class HoistCandidateSelector final : public StmtExprVisitor { if (CheckValidIf()) { // Check corresponding for loop - bool match_found = false; - size_t match_for_loop_pos = 0; + int match_for_loop_pos = -1; for (auto var : if_var_list_) { - for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) { - if (ordered_for_list_[i] == var_for_map_[var]) { + for (int i = 0; i < static_cast(ordered_list_.size() - 1); ++i) { + if (ordered_list_[i] == var_for_map_[var]) { if (match_for_loop_pos < i) { match_for_loop_pos = i; } - match_found = true; break; + } else if (ordered_list_[i] == var) { + if (match_for_loop_pos < i) { + match_for_loop_pos = i; + } } } } // If none of the for loop has the matching loop variable as if condition, // then the if node need to be hoisted on top of all, provided no parent loop exists. - int target_for_pos = match_found ? match_for_loop_pos + 1 : 0; + int target_for_pos = GetNextLoopPos(match_for_loop_pos); // Check if target for loop is not the parent of current if node if (!IsParentForLoop(target_for_pos)) { - StopAndAddRecord(ordered_for_list_[target_for_pos], op); + StopAndAddRecord(static_cast(ordered_list_[target_for_pos]), op); if_var_list_.clear(); return; } @@ -211,13 +220,10 @@ class HoistCandidateSelector final : public StmtExprVisitor { HoistForIfTuple hoist_for_if_recorder; void ResetRecorder() { - if (is_recorder_on_) { - CHECK_GT(ordered_for_list_.size(), 0); - is_recorder_on_ = false; - } - ordered_for_list_.clear(); - var_for_map_.clear(); - hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); + ResetRecorderInternal(); + + // Reset Block scope vars also here + attr_var_list_.clear(); } bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); } @@ -227,6 +233,15 @@ class HoistCandidateSelector final : public StmtExprVisitor { const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); } private: + void ResetRecorderInternal() { + if (is_recorder_on_) { + CHECK_GT(ordered_list_.size(), 0); + is_recorder_on_ = false; + } + ordered_list_.clear(); + var_for_map_.clear(); + hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); + } bool CheckValidIf() { // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop // hoisting @@ -244,8 +259,17 @@ class HoistCandidateSelector final : public StmtExprVisitor { } int GetParentLoopPos(const Object* node) { - for (size_t i = 0; i < ordered_for_list_.size(); ++i) { - if (ordered_for_list_[i] == node) { + for (size_t i = 0; i < ordered_list_.size(); ++i) { + if (ordered_list_[i] == node) { + return i; + } + } + return -1; + } + + int GetNextLoopPos(int cur_pos) { + for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) { + if (ordered_list_[i]->IsInstance()) { return i; } } @@ -258,18 +282,25 @@ class HoistCandidateSelector final : public StmtExprVisitor { bool IsRecordingOn() { return is_recorder_on_; } - void StartOrAddRecord(const ForNode* op) { + void StartOrAddRecord(const ObjectRef& op) { is_recorder_on_ = true; - if (!var_for_map_.count(op->loop_var.get())) { - var_for_map_.insert({op->loop_var.get(), op}); + if (const auto* node = op.as()) { + if (!var_for_map_.count(node->loop_var.get())) + var_for_map_.insert({node->loop_var.get(), node}); + ordered_list_.emplace_back(op.get()); + } else if (const auto* node = op.as()) { + if (const auto* iv = node->node.as()) { + ordered_list_.emplace_back(iv->var.get()); + } else if (const auto* iv = node->node.as()) { + ordered_list_.emplace_back(iv); + } } - ordered_for_list_.emplace_back(op); } - void RemoveRecord(const ForNode* op) { + void RemoveRecord(const ObjectRef& op) { StopRecording(); - var_for_map_.erase(op->loop_var.get()); - if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back(); + if (const auto* node = op.as()) var_for_map_.erase(node->loop_var.get()); + if (ordered_list_.size() > 0) ordered_list_.pop_back(); } void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) { @@ -302,7 +333,7 @@ class HoistCandidateSelector final : public StmtExprVisitor { return false; } - std::vector ordered_for_list_; + std::vector ordered_list_; std::vector if_var_list_; std::unordered_set attr_var_list_; VarForMap var_for_map_; @@ -383,6 +414,7 @@ Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto cfg = ctx->GetConfig("tir.HoistIfThenElse"); + if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } From 94aced53093fe1a58491110d72b87da58c333227 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Wed, 19 Aug 2020 20:52:49 +0530 Subject: [PATCH 05/10] Optimize for if var list iteration --- src/tir/transforms/hoist_if_then_else.cc | 50 +++++++++++------------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index 6e4b7d1d3b5b..4d0eb4ab6f3b 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -112,6 +112,24 @@ using HoistForIfTuple = std::tuple; * if (likely(j > 2)) * A[i+j+k] = B[i+j+k] * + * + * This pass do hoisting for Block scope variables also. + * As below: + * Attr(IterVar: threadIdx.x) + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(threadIdx.x < 3)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Will be transformed to as below: + * Attr(IterVar: threadIdx.x) + * if (likely(threadIdx.x < 3)) + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * */ // Select potential candidate IRs that can be hoisted. @@ -181,13 +199,8 @@ class HoistCandidateSelector final : public StmtExprVisitor { // Check corresponding for loop int match_for_loop_pos = -1; for (auto var : if_var_list_) { - for (int i = 0; i < static_cast(ordered_list_.size() - 1); ++i) { - if (ordered_list_[i] == var_for_map_[var]) { - if (match_for_loop_pos < i) { - match_for_loop_pos = i; - } - break; - } else if (ordered_list_[i] == var) { + for (int i = 0; i < static_cast(ordered_list_.size()); ++i) { + if ((ordered_list_[i] == var_for_map_[var]) || (ordered_list_[i] == var)) { if (match_for_loop_pos < i) { match_for_loop_pos = i; } @@ -198,8 +211,8 @@ class HoistCandidateSelector final : public StmtExprVisitor { // then the if node need to be hoisted on top of all, provided no parent loop exists. int target_for_pos = GetNextLoopPos(match_for_loop_pos); - // Check if target for loop is not the parent of current if node - if (!IsParentForLoop(target_for_pos)) { + // Check if valid position + if (target_for_pos >= 0) { StopAndAddRecord(static_cast(ordered_list_[target_for_pos]), op); if_var_list_.clear(); return; @@ -248,25 +261,6 @@ class HoistCandidateSelector final : public StmtExprVisitor { return ((!if_var_list_.empty()) && (!CheckAttrVar())); } - bool IsParentForLoop(int loop_pos) { - // Check if the loop position is higher than the parent loop position - for (auto var : if_var_list_) { - if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) { - return true; - } - } - return false; - } - - int GetParentLoopPos(const Object* node) { - for (size_t i = 0; i < ordered_list_.size(); ++i) { - if (ordered_list_[i] == node) { - return i; - } - } - return -1; - } - int GetNextLoopPos(int cur_pos) { for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) { if (ordered_list_[i]->IsInstance()) { From 80c6dd17a715b271665e4b0d615d1d170e97d8a1 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Wed, 19 Aug 2020 20:54:24 +0530 Subject: [PATCH 06/10] More test case added --- .../unittest/test_tir_transform_hoist_if.py | 502 +++++++++++++++++- 1 file changed, 501 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 4ca952af00d4..b4ffee702f4b 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -16,6 +16,9 @@ # under the License. import tvm from tvm import te +from tvm import relay +import numpy as np +from tvm.relay.testing import ctx_list var_list = [] @@ -255,6 +258,487 @@ def test_multi_if(): ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) +def test_no_hoisting_1(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + n = te.var("n") + + with ib.for_range(0, 10, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 10, "k") as k: + with ib.if_scope(k >= 3): + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_2(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + n = te.var("n") + x = te.var("x") + + with ib.for_range(0, 10, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 10, "k") as k: + with ib.if_scope(i >= 3): + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3 + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_3(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + ib.scope_attr(bx, "thread_extent", dshape_inner[1]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_4(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_5(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + ib.scope_attr(bx, "thread_extent", dshape_inner[1]) + with ib.for_range(0, n, "k") as k: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_6(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + k) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_no_hoisting_7(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.if_scope((tx + j) < 9): + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + k) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_hoisting_block_scope_1(): + n = te.size_var("n") + m = te.size_var("m") + A = te.placeholder((n, m), name='A') + k = te.reduce_axis((0, m), "k") + B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B") + s = te.create_schedule(B.op) + ko, ki = s[B].split(B.op.reduce_axis[0], factor=16) + BF = s.rfactor(B, ki) + xo, xi = s[B].split(s[B].op.axis[0], factor=32) + s[B.op].bind(xo, te.thread_axis("blockIdx.x")) + s[B.op].bind(xi, te.thread_axis("threadIdx.y")) + s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x")) + s[BF].compute_at(s[B], s[B].op.reduce_axis[0]) + func = tvm.driver.build_module.form_irmodule( + s, [A, B], "main", None)["main"] + stmt = func.body + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_2(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + #ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + #tvm.ir.assert_structural_equal(new_stmt, stmt) + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_3(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + dshape_inner = (33, 63) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + ib.scope_attr(tx, "thread_extent", dshape_inner[0]) + ib.scope_attr(bx, "thread_extent", dshape_inner[1]) + with ib.for_range(0, n, "k") as k: + with ib.if_scope(tx < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + #tvm.ir.assert_structural_equal(new_stmt, stmt) + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_4(): + nn = 1024 + n = tvm.runtime.convert(nn) + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + AA = te.compute((n,), lambda *i: A(*i), name='A') + BB = te.compute((n,), lambda *i: B(*i), name='B') + T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T') + C = te.compute(A.shape, lambda *i: T(*i), name='C') + s = te.create_schedule(C.op) + xo, xi = s[C].split(C.op.axis[0], factor=4) + xo1, xo2 = s[C].split(xo, factor=13) + s[C].parallel(xo2) + s[C].pragma(xo1, "parallel_launch_point") + s[C].pragma(xo2, "parallel_stride_pattern") + s[C].pragma(xo2, "parallel_barrier_when_finish") + s[C].vectorize(xi) + func = tvm.driver.build_module.form_irmodule( + s, [A, B, C], "main", None)["main"] + stmt = func.body + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_5(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + g = te.var('g') + + ib.scope_attr(data, "storage_scope", "global") + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(data[g] < 3): + data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3 + with ib.else_scope(): + data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + + stmt = new_stmt + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + +def test_hoisting_block_scope_6(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + n) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_block_scope_7(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope((tx + i) < 3): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + tvm.ir.assert_structural_equal(new_stmt, stmt) + + with tvm.transform.PassContext(config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + assert(not tvm.ir.structural_equal(new_stmt, stmt)) + +def test_hoisting_op_conv(): + dtype = "float32" + dshape = (1, 80, 73, 73) + kshape = (192, 80, 3, 3) + padding=(1, 1) + groups=1 + dilation=(1, 1) + kernel_size=(3, 3) + channels=192 + scale=1 + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", shape=kshape, dtype=dtype) + y = relay.nn.conv2d(x, w, padding=padding, + dilation=dilation, + groups=groups, + channels=channels, + kernel_size=kernel_size) + + func = relay.Function([x, w], y) + mod = tvm.IRModule() + mod['main'] = func + mod = relay.transform.InferType()(mod) + + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + + params = {'w': tvm.nd.array(kernel)} + for target, ctx in ctx_list(): + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build_module.build(mod, target=target, params=params) + m = tvm.contrib.graph_runtime.create(graph, lib, ctx) + x = np.random.uniform(size=dshape) + data_tvm = tvm.nd.array(data) + m.set_input('x', data_tvm) + m.set_input(**params) + m.run() + e = m.module.time_evaluator("run", ctx, number=300, repeat=3) + t1 = e(data_tvm).results + t1 = np.array(t1) * 1000 + print('{} ms'.format(t1.mean())) + + with tvm.transform.PassContext(opt_level=3, config={ + "tir.HoistIfThenElse": {"support_block_scope_hosting": True} + }): + graph, lib, params = relay.build_module.build(mod, target=target, params=params) + m = tvm.contrib.graph_runtime.create(graph, lib, ctx) + x = np.random.uniform(size=dshape) + data_tvm = tvm.nd.array(data) + m.set_input('x', data_tvm) + m.set_input(**params) + m.run() + e = m.module.time_evaluator("run", ctx, number=300, repeat=3) + t2 = e(data_tvm).results + t2 = np.array(t2) * 1000 + + print('{} ms'.format(t2.mean())) + tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1) if __name__ == "__main__": test_hoist_top_for() @@ -265,4 +749,20 @@ def test_multi_if(): test_nested_for() test_if_block() test_multi_if() - + test_no_hoisting_1() + test_no_hoisting_2() + test_no_hoisting_3() + test_no_hoisting_4() + test_no_hoisting_5() + test_no_hoisting_6() + test_no_hoisting_7() + test_hoisting_block_scope_1() + test_hoisting_block_scope_2() + test_hoisting_block_scope_3() + test_hoisting_block_scope_4() + test_hoisting_block_scope_5() + test_hoisting_block_scope_6() + test_hoisting_block_scope_7() + + # Test with Conv Op + test_hoisting_op_conv() From c73d39d42c569c2c7e24e775ec10471a09e98abe Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Wed, 19 Aug 2020 21:29:32 +0530 Subject: [PATCH 07/10] Fake commit to disable failed test cases --- tests/python/unittest/test_tir_transform_hoist_if.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index b4ffee702f4b..e0bf8228f4a3 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -18,6 +18,7 @@ from tvm import te from tvm import relay import numpy as np +import pytest from tvm.relay.testing import ctx_list var_list = [] @@ -465,6 +466,7 @@ def test_no_hoisting_7(): new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) +@pytest.mark.xfail() def test_hoisting_block_scope_1(): n = te.size_var("n") m = te.size_var("m") @@ -491,6 +493,7 @@ def test_hoisting_block_scope_1(): new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body assert(not tvm.ir.structural_equal(new_stmt, stmt)) +@pytest.mark.xfail() def test_hoisting_block_scope_2(): ib = tvm.tir.ir_builder.create() dshape = (32, 64) @@ -525,6 +528,7 @@ def test_hoisting_block_scope_2(): #tvm.ir.assert_structural_equal(new_stmt, stmt) assert(not tvm.ir.structural_equal(new_stmt, stmt)) +@pytest.mark.xfail() def test_hoisting_block_scope_3(): ib = tvm.tir.ir_builder.create() dshape = (32, 64) @@ -560,6 +564,7 @@ def test_hoisting_block_scope_3(): #tvm.ir.assert_structural_equal(new_stmt, stmt) assert(not tvm.ir.structural_equal(new_stmt, stmt)) +@pytest.mark.xfail() def test_hoisting_block_scope_4(): nn = 1024 n = tvm.runtime.convert(nn) @@ -620,6 +625,7 @@ def test_hoisting_block_scope_5(): new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body tvm.ir.assert_structural_equal(new_stmt, stmt) +@pytest.mark.xfail() def test_hoisting_block_scope_6(): ib = tvm.tir.ir_builder.create() dshape = (32, 64) @@ -651,6 +657,7 @@ def test_hoisting_block_scope_6(): new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body assert(not tvm.ir.structural_equal(new_stmt, stmt)) +@pytest.mark.xfail() def test_hoisting_block_scope_7(): ib = tvm.tir.ir_builder.create() dshape = (32, 64) @@ -682,6 +689,7 @@ def test_hoisting_block_scope_7(): new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body assert(not tvm.ir.structural_equal(new_stmt, stmt)) +@pytest.mark.skip() def test_hoisting_op_conv(): dtype = "float32" dshape = (1, 80, 73, 73) @@ -738,7 +746,7 @@ def test_hoisting_op_conv(): t2 = np.array(t2) * 1000 print('{} ms'.format(t2.mean())) - tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1) + #tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1) if __name__ == "__main__": test_hoist_top_for() From cd546662fb91a91c3038841e9e54a3fc27ba9bba Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 22 Aug 2020 12:30:07 +0530 Subject: [PATCH 08/10] Pass default value restored --- src/tir/transforms/hoist_if_then_else.cc | 2 +- tests/python/unittest/test_tir_transform_hoist_if.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index 4d0eb4ab6f3b..1ac5e1ce701a 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -43,7 +43,7 @@ struct HoistIfThenElseConfigNode : public tvm::AttrsNode Date: Tue, 25 Aug 2020 14:26:31 +0530 Subject: [PATCH 09/10] [1] Review comment handled --- python/tvm/tir/transform/transform.py | 18 ++++++++++--- src/tir/transforms/hoist_if_then_else.cc | 5 ++-- .../unittest/test_tir_transform_hoist_if.py | 26 +------------------ 3 files changed, 18 insertions(+), 31 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 467dd7c91330..55dc98c72462 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -505,14 +505,24 @@ def HoistIfThenElse(variant=None): Parameters ---------- - variant : str + variant : str, optional The variant of the pass. + variant can have any one of following values ["basic", None(Default)]. + + The basic variant supports basic hoisting scenarios where it exepects + the For & If Nodes are in place consecutively and does not involve + global scope variables or more advanced scenarios. + + Default variant supports all hoisting scenarios,i.e., {"Basic" + "Advanced"} + supported with control with PassContext configs like below: + + config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}} Returns ------- fpass : tvm.transform.Pass The result pass """ - if variant is None: - return _ffi_api.HoistIfThenElse() - return _ffi_api.HoistIfThenElseBasic() + if variant == "basic": + return _ffi_api.HoistIfThenElseBasic() + return _ffi_api.HoistIfThenElse() diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc index 1ac5e1ce701a..4e7589c3a795 100644 --- a/src/tir/transforms/hoist_if_then_else.cc +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -176,9 +176,9 @@ class HoistCandidateSelector final : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); RemoveRecord(GetRef(op)); return; + } else { + return StmtExprVisitor::VisitStmt_(op); } - - return StmtExprVisitor::VisitStmt_(op); } UpdateAttrVarList(op); StmtExprVisitor::VisitStmt_(op); @@ -327,6 +327,7 @@ class HoistCandidateSelector final : public StmtExprVisitor { return false; } + // Ordered List maintains all ForNodes & AttrStmtNodes encountered in sequence std::vector ordered_list_; std::vector if_var_list_; std::unordered_set attr_var_list_; diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index b2e23dc1b7f0..186a52d12da1 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -743,28 +743,4 @@ def test_hoisting_op_conv(): tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1) if __name__ == "__main__": - test_hoist_top_for() - test_hoist_multi_var_if() - test_hoist_no_match_for() - test_no_else() - test_attr_stmt() - test_nested_for() - test_if_block() - test_multi_if() - test_no_hoisting_1() - test_no_hoisting_2() - test_no_hoisting_3() - test_no_hoisting_4() - test_no_hoisting_5() - test_no_hoisting_6() - test_no_hoisting_7() - test_hoisting_block_scope_1() - test_hoisting_block_scope_2() - test_hoisting_block_scope_3() - test_hoisting_block_scope_4() - test_hoisting_block_scope_5() - test_hoisting_block_scope_6() - test_hoisting_block_scope_7() - - # Test with Conv Op - test_hoisting_op_conv() + pytest.main([__file__]) From 446109472aa58a61327533b4ebc386ac4303fab0 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Mon, 31 Aug 2020 21:39:37 +0530 Subject: [PATCH 10/10] [2] Review comments handled --- python/tvm/driver/build_module.py | 1 - python/tvm/tir/transform/transform.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index fdb3b5ce7ee5..9a3c473737c3 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -181,7 +181,6 @@ def lower(sch, tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), - tvm.tir.transform.HoistIfThenElse("basic"), ] pass_list += lower_phase1 diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 55dc98c72462..3f7fb41e7ff4 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -500,12 +500,13 @@ def VerifyMemory(): """ return _ffi_api.VerifyMemory() +#pylint: disable=no-else-return,inconsistent-return-statements def HoistIfThenElse(variant=None): """Hoist loop-invariant IfThenElse nodes to outside the elligible loops. Parameters ---------- - variant : str, optional + variant : Optional[String] The variant of the pass. variant can have any one of following values ["basic", None(Default)]. @@ -525,4 +526,5 @@ def HoistIfThenElse(variant=None): """ if variant == "basic": return _ffi_api.HoistIfThenElseBasic() - return _ffi_api.HoistIfThenElse() + elif variant is None: + return _ffi_api.HoistIfThenElse()