diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index ff4b56bdba1d..9a3c473737c3 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -181,7 +181,6 @@ def lower(sch,
         tvm.tir.transform.BF16Legalize(),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
-        tvm.tir.transform.HoistIfThenElse(),
     ]
     pass_list += lower_phase1
 
@@ -205,6 +204,7 @@ def lower(sch,
     ]
     pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
+    pass_list += [tvm.tir.transform.HoistIfThenElse()]
     pass_list += lower_phase3
 
     # Instrument BoundCheckers
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index d2f5acd199e6..3f7fb41e7ff4 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -500,11 +500,31 @@ def VerifyMemory():
     """
     return _ffi_api.VerifyMemory()
 
-def HoistIfThenElse():
+#pylint: disable=no-else-return,inconsistent-return-statements
+def HoistIfThenElse(variant=None):
     """Hoist loop-invariant IfThenElse nodes to outside the elligible loops.
+
+    Parameters
+    ----------
+    variant : Optional[String]
+        The variant of the pass.
+        variant can have any one of the following values: ["basic", None (default)].
+
+        The "basic" variant supports basic hoisting scenarios where it expects
+        the For and If nodes to be placed consecutively, and it does not handle
+        global scope variables or more advanced scenarios.
+
+        The default variant supports all hoisting scenarios, i.e. {"basic" + "advanced"},
+        controlled via PassContext config options such as:
+
+        config={"tir.HoistIfThenElse": {"support_block_scope_hosting": True}}
+
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.HoistIfThenElse()
+    if variant == "basic":
+        return _ffi_api.HoistIfThenElseBasic()
+    elif variant is None:
+        return _ffi_api.HoistIfThenElse()
diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc
index f58eb965584d..4e7589c3a795 100644
--- a/src/tir/transforms/hoist_if_then_else.cc
+++ b/src/tir/transforms/hoist_if_then_else.cc
@@ -37,6 +37,25 @@
 namespace tvm {
 namespace tir {
 
+struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> {
+  bool support_block_scope_hosting;
+
+  TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") {
+    TVM_ATTR_FIELD(support_block_scope_hosting)
+        .describe("Hoist if cond with block scope variables")
+        .set_default(false);
+  }
+};
+
+class HoistIfThenElseConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs,
+                                            HoistIfThenElseConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
+
 using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
 using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
 
@@ -93,11 +112,33 @@ using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
  *    if (likely(j > 2))
  *        A[i+j+k] = B[i+j+k]
  *
+ *
+ * This pass does hoisting for block scope variables as well.
+ * As below:
+ * Attr(IterVar: threadIdx.x)
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(threadIdx.x < 3))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Will be transformed to:
+ * Attr(IterVar: threadIdx.x)
+ * if (likely(threadIdx.x < 3))
+ *    for (i = 0; i < 3; i++)
+ *        for (j = 0; j < 4; j++)
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
  */
 
 // Select potential candidate IRs that can be hoisted.
 class HoistCandidateSelector final : public StmtExprVisitor {
  public:
+  explicit HoistCandidateSelector(bool support_block_scope_hosting)
+      : support_block_scope_hosting_(support_block_scope_hosting) {
+    InitRecorder();
+  }
   HoistCandidateSelector() { InitRecorder(); }
 
   void VisitStmt_(const ForNode* op) final {
@@ -108,16 +149,16 @@ class HoistCandidateSelector final : public StmtExprVisitor {
     }
 
     // Check if it is first for loop, then start the recorder
-    StartOrAddRecord(op);
+    StartOrAddRecord(GetRef<ObjectRef>(op));
     StmtExprVisitor::VisitStmt_(op);
-    RemoveRecord(op);
+    RemoveRecord(GetRef<ObjectRef>(op));
   }
 
   void VisitStmt_(const SeqStmtNode* op) final {
     // If SeqStmt is encountered in the middle of recording
     // then need to purge all, as it can not be hoisted
     if (IsRecordingOn()) {
-      ResetRecorder();
+      ResetRecorderInternal();
     }
     StmtExprVisitor::VisitStmt_(op);
   }
@@ -126,10 +167,19 @@ class HoistCandidateSelector final : public StmtExprVisitor {
     // Maintain list of all vars in AttrStmt
     // To stop hoisting if any of the block variables are used.
     //
-    // NOTE: If in future
-    // hoisting is required for any specific case,
-    // then add exception to only those case
-    // rather than allowing for all.
+    // In case we want to use hoisting in between certain passes
+    // which have interdependencies on the positioning of if nodes with scope vars,
+    // it is better to disable this section.
+    if (support_block_scope_hosting_) {
+      if (IsRecordingOn()) {
+        StartOrAddRecord(GetRef<ObjectRef>(op));
+        StmtExprVisitor::VisitStmt_(op);
+        RemoveRecord(GetRef<ObjectRef>(op));
+        return;
+      } else {
+        return StmtExprVisitor::VisitStmt_(op);
+      }
+    }
     UpdateAttrVarList(op);
     StmtExprVisitor::VisitStmt_(op);
     RemoveAttrVarList(op);
@@ -147,26 +197,23 @@ class HoistCandidateSelector final : public StmtExprVisitor {
 
     if (CheckValidIf()) {
       // Check corresponding for loop
-      bool match_found = false;
-      size_t match_for_loop_pos = 0;
+      int match_for_loop_pos = -1;
       for (auto var : if_var_list_) {
-        for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
-          if (ordered_for_list_[i] == var_for_map_[var]) {
+        for (int i = 0; i < static_cast<int>(ordered_list_.size()); ++i) {
+          if ((ordered_list_[i] == var_for_map_[var]) || (ordered_list_[i] == var)) {
             if (match_for_loop_pos < i) {
               match_for_loop_pos = i;
             }
-            match_found = true;
-            break;
           }
         }
       }
       // If none of the for loop has the matching loop variable as if condition,
       // then the if node need to be hoisted on top of all, provided no parent loop exists.
-      int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+      int target_for_pos = GetNextLoopPos(match_for_loop_pos);
 
-      // Check if target for loop is not the parent of current if node
-      if (!IsParentForLoop(target_for_pos)) {
-        StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+      // Check if valid position
+      if (target_for_pos >= 0) {
+        StopAndAddRecord(static_cast<const ForNode*>(ordered_list_[target_for_pos]), op);
         if_var_list_.clear();
         return;
       }
@@ -186,13 +233,10 @@ class HoistCandidateSelector final : public StmtExprVisitor {
   HoistForIfTuple hoist_for_if_recorder;
 
   void ResetRecorder() {
-    if (is_recorder_on_) {
-      CHECK_GT(ordered_for_list_.size(), 0);
-      is_recorder_on_ = false;
-    }
-    ordered_for_list_.clear();
-    var_for_map_.clear();
-    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+    ResetRecorderInternal();
+
+    // Reset block scope vars also here
+    attr_var_list_.clear();
   }
 
   bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }
@@ -202,25 +246,24 @@ class HoistCandidateSelector final : public StmtExprVisitor {
   const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
 
  private:
+  void ResetRecorderInternal() {
+    if (is_recorder_on_) {
+      CHECK_GT(ordered_list_.size(), 0);
+      is_recorder_on_ = false;
+    }
+    ordered_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
   bool CheckValidIf() {
     // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
     // hoisting
     return ((!if_var_list_.empty()) && (!CheckAttrVar()));
   }
 
-  bool IsParentForLoop(int loop_pos) {
-    // Check if the loop position is higher than the parent loop position
-    for (auto var : if_var_list_) {
-      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  int GetParentLoopPos(const Object* node) {
-    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
-      if (ordered_for_list_[i] == node) {
+  int GetNextLoopPos(int cur_pos) {
+    for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) {
+      if (ordered_list_[i]->IsInstance<ForNode>()) {
         return i;
       }
     }
@@ -233,18 +276,25 @@ class HoistCandidateSelector final : public StmtExprVisitor {
   bool IsRecordingOn() { return is_recorder_on_; }
 
-  void StartOrAddRecord(const ForNode* op) {
+  void StartOrAddRecord(const ObjectRef& op) {
     is_recorder_on_ = true;
-    if (!var_for_map_.count(op->loop_var.get())) {
-      var_for_map_.insert({op->loop_var.get(), op});
+    if (const auto* node = op.as<ForNode>()) {
+      if (!var_for_map_.count(node->loop_var.get()))
+        var_for_map_.insert({node->loop_var.get(), node});
+      ordered_list_.emplace_back(op.get());
+    } else if (const auto* node = op.as<AttrStmtNode>()) {
+      if (const auto* iv = node->node.as<IterVarNode>()) {
+        ordered_list_.emplace_back(iv->var.get());
+      } else if (const auto* iv = node->node.as<VarNode>()) {
+        ordered_list_.emplace_back(iv);
+      }
     }
-    ordered_for_list_.emplace_back(op);
   }
 
-  void RemoveRecord(const ForNode* op) {
+  void RemoveRecord(const ObjectRef& op) {
     StopRecording();
-    var_for_map_.erase(op->loop_var.get());
-    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+    if (const auto* node = op.as<ForNode>()) var_for_map_.erase(node->loop_var.get());
+    if (ordered_list_.size() > 0) ordered_list_.pop_back();
   }
 
   void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
@@ -277,18 +327,22 @@ class HoistCandidateSelector final : public StmtExprVisitor {
     return false;
   }
 
-  std::vector<const ForNode*> ordered_for_list_;
+  // Ordered list maintains all ForNodes & AttrStmtNodes encountered in sequence
+  std::vector<const Object*> ordered_list_;
   std::vector<const VarNode*> if_var_list_;
   std::unordered_set<const VarNode*> attr_var_list_;
   VarForMap var_for_map_;
 
   bool is_if_cond_{false};
   bool is_recorder_on_{false};
+  bool support_block_scope_hosting_{false};
 };
 
 class IfThenElseHoister : public StmtMutator {
  public:
   IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {}
+  explicit IfThenElseHoister(bool support_block_scope_hosting)
+      : hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {}
 
   Stmt VisitAndMutate(Stmt stmt) {
     hoist_selector_(stmt);
@@ -344,6 +398,9 @@ class IfThenElseHoister : public StmtMutator {
   const IfThenElseNode* target_if_;
 };
 
+Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) {
+  return IfThenElseHoister(support_block_scope_hosting).VisitAndMutate(stmt);
+}
 Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); }
 
 namespace transform {
@@ -351,14 +408,30 @@ namespace transform {
 Pass HoistIfThenElse() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = HoistIfThenElse(std::move(n->body));
+    auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
+
+    if (!cfg.defined()) {
+      cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
+    }
+    n->body = HoistIfThenElse(std::move(n->body), cfg.value()->support_block_scope_hosting);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {});
 }
 
+Pass HoistIfThenElseBasic() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = HoistIfThenElse(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElseBasic", {});
+}
+
 TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);
+TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic);
+
 }  // namespace transform
 }  // namespace tir
diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py
index 4ca952af00d4..186a52d12da1 100644
--- a/tests/python/unittest/test_tir_transform_hoist_if.py
+++ b/tests/python/unittest/test_tir_transform_hoist_if.py
@@ -16,6 +16,10 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm import relay
+import numpy as np
+import pytest
+from tvm.relay.testing import ctx_list
 
 var_list = []
 
@@ -255,14 +259,488 @@ def test_multi_if():
                        ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
     verify_structure(new_stmt, expected_struct)
 
+def test_no_hoisting_1():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")
 
-if __name__ == "__main__":
-    test_hoist_top_for()
-    test_hoist_multi_var_if()
-    test_hoist_no_match_for()
-    test_no_else()
-    test_attr_stmt()
-    test_nested_for()
-    test_if_block()
-    test_multi_if()
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(k >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_2():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    n = te.var("n")
+    x = te.var("x")
+
+    with ib.for_range(0, 10, "i") as i:
+        with ib.for_range(0, 10, "j") as j:
+            with ib.for_range(0, 10, "k") as k:
+                with ib.if_scope(i >= 3):
+                    data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.3
+                data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_3():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_4():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_5():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_6():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + k) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_no_hoisting_7():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.if_scope((tx + j) < 9):
+                with ib.for_range(0, n, "k") as k:
+                    with ib.if_scope((tx + k) < 3):
+                        data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_1():
+    n = te.size_var("n")
+    m = te.size_var("m")
+    A = te.placeholder((n, m), name='A')
+    k = te.reduce_axis((0, m), "k")
+    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
+    s = te.create_schedule(B.op)
+    ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
+    BF = s.rfactor(B, ki)
+    xo, xi = s[B].split(s[B].op.axis[0], factor=32)
+    s[B.op].bind(xo, te.thread_axis("blockIdx.x"))
+    s[B.op].bind(xi, te.thread_axis("threadIdx.y"))
+    s[B].bind(s[B].op.reduce_axis[0], te.thread_axis("threadIdx.x"))
+    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
+    func = tvm.driver.build_module.form_irmodule(
+        s, [A, B], "main", None)["main"]
+    stmt = func.body
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+        assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_2():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    #ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                ib.scope_attr(bx, "thread_extent", dshape[1])
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        #tvm.ir.assert_structural_equal(new_stmt, stmt)
+        assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_3():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    dshape_inner = (33, 63)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            ib.scope_attr(tx, "thread_extent", dshape_inner[0])
+            ib.scope_attr(bx, "thread_extent", dshape_inner[1])
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(tx < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        #tvm.ir.assert_structural_equal(new_stmt, stmt)
+        assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_4():
+    nn = 1024
+    n = tvm.runtime.convert(nn)
+    A = te.placeholder((n,), name='A')
+    B = te.placeholder((n,), name='B')
+    AA = te.compute((n,), lambda *i: A(*i), name='A')
+    BB = te.compute((n,), lambda *i: B(*i), name='B')
+    T = te.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
+    C = te.compute(A.shape, lambda *i: T(*i), name='C')
+    s = te.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=4)
+    xo1, xo2 = s[C].split(xo, factor=13)
+    s[C].parallel(xo2)
+    s[C].pragma(xo1, "parallel_launch_point")
+    s[C].pragma(xo2, "parallel_stride_pattern")
+    s[C].pragma(xo2, "parallel_barrier_when_finish")
+    s[C].vectorize(xi)
+    func = tvm.driver.build_module.form_irmodule(
+        s, [A, B, C], "main", None)["main"]
+    stmt = func.body
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(tvm.IRModule.from_expr(func))["main"].body
+        assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_5():
+    ib = tvm.tir.ir_builder.create()
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+    g = te.var('g')
+
+    ib.scope_attr(data, "storage_scope", "global")
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope(data[g] < 3):
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 0.3
+                with ib.else_scope():
+                    data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+    stmt = new_stmt
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+def test_hoisting_block_scope_6():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + n) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+def test_hoisting_block_scope_7():
+    ib = tvm.tir.ir_builder.create()
+    dshape = (32, 64)
+    data = ib.pointer("float32", name="data")
+    l = te.var('l')
+    m = te.var('m')
+    n = te.var('n')
+
+    tx = te.thread_axis("threadIdx.x")
+    bx = te.thread_axis("blockIdx.x")
+    ib.scope_attr(tx, "thread_extent", dshape[0])
+    ib.scope_attr(bx, "thread_extent", dshape[1])
+    with ib.for_range(0, l, "i") as i:
+        with ib.for_range(0, m, "j") as j:
+            with ib.for_range(0, n, "k") as k:
+                with ib.if_scope((tx + i) < 3):
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.3
+                with ib.else_scope():
+                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.3
+
+    stmt = ib.get()
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    tvm.ir.assert_structural_equal(new_stmt, stmt)
+
+    with tvm.transform.PassContext(config={
+        "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+    }):
+        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        assert(not tvm.ir.structural_equal(new_stmt, stmt))
+
+@pytest.mark.skip()
+def test_hoisting_op_conv():
+    dtype = "float32"
+    dshape = (1, 80, 73, 73)
+    kshape = (192, 80, 3, 3)
+    padding = (1, 1)
+    groups = 1
+    dilation = (1, 1)
+    kernel_size = (3, 3)
+    channels = 192
+    scale = 1
+    x = relay.var("x", shape=dshape, dtype=dtype)
+    w = relay.var("w", shape=kshape, dtype=dtype)
+    y = relay.nn.conv2d(x, w, padding=padding,
+                        dilation=dilation,
+                        groups=groups,
+                        channels=channels,
+                        kernel_size=kernel_size)
+
+    func = relay.Function([x, w], y)
+    mod = tvm.IRModule()
+    mod['main'] = func
+    mod = relay.transform.InferType()(mod)
+
+    data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
+    kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
+
+    params = {'w': tvm.nd.array(kernel)}
+    for target, ctx in ctx_list():
+        with tvm.transform.PassContext(opt_level=3):
+            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+            m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+            x = np.random.uniform(size=dshape)
+            data_tvm = tvm.nd.array(data)
+            m.set_input('x', data_tvm)
+            m.set_input(**params)
+            m.run()
+            e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+            t1 = e(data_tvm).results
+            t1 = np.array(t1) * 1000
+            print('{} ms'.format(t1.mean()))
+
+        with tvm.transform.PassContext(opt_level=3, config={
+            "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
+        }):
+            graph, lib, params = relay.build_module.build(mod, target=target, params=params)
+            m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+            x = np.random.uniform(size=dshape)
+            data_tvm = tvm.nd.array(data)
+            m.set_input('x', data_tvm)
+            m.set_input(**params)
+            m.run()
+            e = m.module.time_evaluator("run", ctx, number=300, repeat=3)
+            t2 = e(data_tvm).results
+            t2 = np.array(t2) * 1000
+            print('{} ms'.format(t2.mean()))
+        tvm.testing.assert_allclose(t1.mean(), t2.mean(), atol=1, rtol=1e-1)
+
+if __name__ == "__main__":
+    pytest.main([__file__])
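
For reference, a minimal usage sketch of the two entry points this patch introduces, following the same pattern as the tests above; the small PrimFunc constructed here is illustrative only and not part of the patch.

import tvm
from tvm import te

# Build a tiny PrimFunc whose If condition uses only a block scope variable (threadIdx.x).
ib = tvm.tir.ir_builder.create()
data = ib.pointer("float32", name="data")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", 32)
with ib.for_range(0, 16, "i") as i:
    with ib.if_scope(tx < 3):
        data[i] = data[i] + 1.0
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], ib.get()))

# Default variant: block scope hoisting happens only when enabled via PassContext config.
with tvm.transform.PassContext(config={
    "tir.HoistIfThenElse": {"support_block_scope_hosting": True}
}):
    hoisted = tvm.tir.transform.HoistIfThenElse()(mod)

# "basic" variant: handles only consecutive For/If nests, no block scope hoisting.
basic = tvm.tir.transform.HoistIfThenElse("basic")(mod)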