diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 1e55cb22ee26..dc14e4512f1e 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -33,6 +33,7 @@ #include #include +#include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" namespace tvm { @@ -43,6 +44,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; + int unroll_local_access; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") { TVM_ATTR_FIELD(auto_max_step) @@ -57,6 +59,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); + TVM_ATTR_FIELD(unroll_local_access) + .describe("Whether to always unroll local access") + .set_default(false); } }; @@ -68,14 +73,30 @@ class UnrollLoopConfig : public Attrs { TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode); TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); +class VarLocalAccessMarker : public ExprVisitor { + public: + explicit VarLocalAccessMarker( + std::unordered_set* var_touched_local) + : var_touched_local_(var_touched_local) {} + + void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef(op)); } + + private: + std::unordered_set* var_touched_local_; +}; + +// The Visitor is used to check whether var is used as write index in a local memory +// If a loop var is used as indices to a local memory, it must be unrolled so +// the local memory access can be turned into register access. 
class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll) + bool explicit_unroll, bool unroll_local_access) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) {} + explicit_unroll_(explicit_unroll), + unroll_local_access_(unroll_local_access) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -96,6 +117,7 @@ class LoopUnroller : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) { + // Post order so we can collect more information Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); int value = GetExtent(op); @@ -111,6 +133,12 @@ class LoopUnroller : public StmtExprMutator { auto_unroll = true; } + // If a loop var is used as indices to a local memory, it must be unrolled so + // the local memory access can be turned into register access. + if (this->var_touched_local_.count(op->loop_var) && value > 0 && unroll_local_access_) { + auto_unroll = true; + } + if (auto_unroll) { step_count_ *= value; unroll_depth_ += 1; @@ -137,8 +165,32 @@ class LoopUnroller : public StmtExprMutator { LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + if (unroll_local_access_) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (storage_scope.rank == runtime::StorageRank::kLocal || + storage_scope.rank == runtime::StorageRank::kWarp) { + VarLocalAccessMarker marker(&var_touched_local_); + for (PrimExpr e : op->indices) { + marker(e); + } + } + } + return GetRef(op); + } + Stmt VisitStmt_(const BufferStoreNode* op) final { ++step_count_; + if (unroll_local_access_) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (storage_scope.rank == runtime::StorageRank::kLocal || + storage_scope.rank == runtime::StorageRank::kWarp) { + VarLocalAccessMarker marker(&var_touched_local_); + for (PrimExpr e : op->indices) { + marker(e); + } + } + } return StmtExprMutator::VisitStmt_(op); } @@ -161,7 +213,7 @@ class LoopUnroller : public StmtExprMutator { unroll_depth_ = std::max(unroll_depth_, unroll_depth); return ret; }; - return StmtMutator::VisitSeqStmt_(op, false, fmutate); + return StmtExprMutator::VisitSeqStmt_(op, false, fmutate); } Stmt Unroll(const ForNode* op) { @@ -202,19 +254,23 @@ class LoopUnroller : public StmtExprMutator { // this not not count the total steps, only count the number of loops int auto_max_extent_; bool explicit_unroll_; + // Whether to unroll loops to local access. + bool unroll_local_access_{false}; // Number of normal loops in scope int normal_loop_depth_{0}; // number of unrolled cases in current scope. 
int unroll_depth_{0}; // Number of total steps unrolled int step_count_{0}; + // set of indices touched during visit local memory + std::unordered_set var_touched_local_; // analyzer arith::Analyzer analyzer_; }; Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, - cfg->explicit_unroll)(stmt); + cfg->explicit_unroll, cfg->unroll_local_access)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index a76e6135b3c4..a05a085eeb64 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -134,7 +134,49 @@ def main(): tvm.ir.assert_structural_equal(after, expected) +def test_unroll_local_access(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((64,), "float32")): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(4, thread="threadIdx.x"): + A_local_data = T.allocate([4], dtype="float32", scope="local") + A_local = T.Buffer([4], dtype="float32", data=A_local_data) + for i in T.serial(4): + A_local[i] = T.float32(i) + + @tvm.script.ir_module + class Expected: + @T.prim_func + def main(B: T.Buffer((64,), "float32")): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(4, thread="threadIdx.x"): + A_local_data = T.allocate([4], dtype="float32", scope="local") + A_local = T.Buffer([4], dtype="float32", data=A_local_data) + A_local[0] = T.float32(0) + A_local[1] = T.float32(1) + A_local[2] = T.float32(2) + A_local[3] = T.float32(3) + + with tvm.transform.PassContext( + config={ + "tir.UnrollLoop": { + "auto_max_depth": 0, + "auto_max_extent": 1, + "explicit_unroll": True, + "unroll_local_access": True, + } + } + ): + after = tvm.tir.transform.UnrollLoop()(Before) + after = 
tvm.tir.transform.Simplify()(after) + + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": + test_unroll_local_access() test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops()