From ac70e08fe2582cb112063fba3f6254aa97094cd9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 6 Mar 2023 22:42:11 -0500 Subject: [PATCH] [TIR] Enhance loop unroll with unroll local access This PR enhances the unroller with an unroll local access option. This option will detect loop variables that access local memories and unroll them independently of other options. A test case is added. This option is by default turned off and can be useful in certain cases to improve the unroller, as these local memory accesses have to be unrolled at some point to be lifted into registers --- src/tir/transforms/unroll_loop.cc | 64 +++++++++++++++++-- .../test_tir_transform_unroll_loop.py | 42 ++++++++++++ 2 files changed, 102 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 1e55cb22ee26..dc14e4512f1e 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -33,6 +33,7 @@ #include #include +#include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" namespace tvm { @@ -43,6 +44,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; + int unroll_local_access; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") { TVM_ATTR_FIELD(auto_max_step) @@ -57,6 +59,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); + TVM_ATTR_FIELD(unroll_local_access) + .describe("Whether to always unroll local access") + .set_default(false); } }; @@ -68,14 +73,30 @@ class UnrollLoopConfig : public Attrs { TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode); TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); +class VarLocalAccessMarker : public ExprVisitor { + public: + explicit VarLocalAccessMarker( + std::unordered_set* var_touched_local) + : 
var_touched_local_(var_touched_local) {} + + void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef(op)); } + + private: + std::unordered_set* var_touched_local_; +}; + +// The Visitor is used to check whether var is used as write index in a local memory +// If a loop var is used as indices to a local memory, it must be unrolled so +// the local memory access can be turned into register access. class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll) + bool explicit_unroll, bool unroll_local_access) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) {} + explicit_unroll_(explicit_unroll), + unroll_local_access_(unroll_local_access) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -96,6 +117,7 @@ class LoopUnroller : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) { + // Post order so we can collect more information Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); int value = GetExtent(op); @@ -111,6 +133,12 @@ class LoopUnroller : public StmtExprMutator { auto_unroll = true; } + // If a loop var is used as indices to a local memory, it must be unrolled so + // the local memory access can be turned into register access. + if (this->var_touched_local_.count(op->loop_var) && value > 0 && unroll_local_access_) { + auto_unroll = true; + } + if (auto_unroll) { step_count_ *= value; unroll_depth_ += 1; @@ -137,8 +165,32 @@ class LoopUnroller : public StmtExprMutator { LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + if (unroll_local_access_) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (storage_scope.rank == runtime::StorageRank::kLocal || + storage_scope.rank == runtime::StorageRank::kWarp) { + VarLocalAccessMarker marker(&var_touched_local_); + for (PrimExpr e : op->indices) { + marker(e); + } + } + } + return GetRef(op); + } + Stmt VisitStmt_(const BufferStoreNode* op) final { ++step_count_; + if (unroll_local_access_) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (storage_scope.rank == runtime::StorageRank::kLocal || + storage_scope.rank == runtime::StorageRank::kWarp) { + VarLocalAccessMarker marker(&var_touched_local_); + for (PrimExpr e : op->indices) { + marker(e); + } + } + } return StmtExprMutator::VisitStmt_(op); } @@ -161,7 +213,7 @@ class LoopUnroller : public StmtExprMutator { unroll_depth_ = std::max(unroll_depth_, unroll_depth); return ret; }; - return StmtMutator::VisitSeqStmt_(op, false, fmutate); + return StmtExprMutator::VisitSeqStmt_(op, false, fmutate); } Stmt Unroll(const ForNode* op) { @@ -202,19 +254,23 @@ class LoopUnroller : public StmtExprMutator { // this not not count the total steps, only count the number of loops int auto_max_extent_; bool explicit_unroll_; + // Whether to unroll loops that access local memory. + bool unroll_local_access_{false}; // Number of normal loops in scope int normal_loop_depth_{0}; // number of unrolled cases in current scope. 
int unroll_depth_{0}; // Number of total steps unrolled int step_count_{0}; + // set of indices touched during visit local memory + std::unordered_set var_touched_local_; // analyzer arith::Analyzer analyzer_; }; Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, - cfg->explicit_unroll)(stmt); + cfg->explicit_unroll, cfg->unroll_local_access)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index a76e6135b3c4..a05a085eeb64 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -134,7 +134,49 @@ def main(): tvm.ir.assert_structural_equal(after, expected) +def test_unroll_local_access(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(B: T.Buffer((64,), "float32")): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(4, thread="threadIdx.x"): + A_local_data = T.allocate([4], dtype="float32", scope="local") + A_local = T.Buffer([4], dtype="float32", data=A_local_data) + for i in T.serial(4): + A_local[i] = T.float32(i) + + @tvm.script.ir_module + class Expected: + @T.prim_func + def main(B: T.Buffer((64,), "float32")): + for bx in T.thread_binding(4, thread="blockIdx.x"): + for tx in T.thread_binding(4, thread="threadIdx.x"): + A_local_data = T.allocate([4], dtype="float32", scope="local") + A_local = T.Buffer([4], dtype="float32", data=A_local_data) + A_local[0] = T.float32(0) + A_local[1] = T.float32(1) + A_local[2] = T.float32(2) + A_local[3] = T.float32(3) + + with tvm.transform.PassContext( + config={ + "tir.UnrollLoop": { + "auto_max_depth": 0, + "auto_max_extent": 1, + "explicit_unroll": True, + "unroll_local_access": True, + } + } + ): + after = tvm.tir.transform.UnrollLoop()(Before) + after = 
tvm.tir.transform.Simplify()(after) + + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": + test_unroll_local_access() test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops()