Skip to content

Commit

Permalink
[TIR] Enhance loop unroll with unroll local access (#14224)
Browse files Browse the repository at this point in the history
This PR enhances the unroller with an unroll local access option.
This option will detect loop variables that access local memories
and unroll them independent of other options.

A test case is added. This option is by default turned off and
can be useful in certain cases to improve unroller as these
local memory access have to be unrolled at some time pt to be
lifted as registers
  • Loading branch information
tqchen authored Mar 7, 2023
1 parent ca48caf commit 56ddd37
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 4 deletions.
64 changes: 60 additions & 4 deletions src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <unordered_set>
#include <vector>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"

namespace tvm {
Expand All @@ -43,6 +44,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
int auto_max_depth;
int auto_max_extent;
int explicit_unroll;
int unroll_local_access;

TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
TVM_ATTR_FIELD(auto_max_step)
Expand All @@ -57,6 +59,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
TVM_ATTR_FIELD(explicit_unroll)
.describe("Whether to explicitly unroll the loop instead of setting a pragma")
.set_default(true);
TVM_ATTR_FIELD(unroll_local_access)
.describe("Whether to always unroll local access")
.set_default(false);
}
};

Expand All @@ -68,14 +73,30 @@ class UnrollLoopConfig : public Attrs {
TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

class VarLocalAccessMarker : public ExprVisitor {
public:
explicit VarLocalAccessMarker(
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>* var_touched_local)
: var_touched_local_(var_touched_local) {}

void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef<Var>(op)); }

private:
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>* var_touched_local_;
};

// The Visitor is used to check whether var is used as write index in a local memory
// If a loop var is used as indices to a local memory, it must be unrolled so
// the local memory access can be turned into register access.
class LoopUnroller : public StmtExprMutator {
public:
explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent,
bool explicit_unroll)
bool explicit_unroll, bool unroll_local_access)
: auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
explicit_unroll_(explicit_unroll) {}
explicit_unroll_(explicit_unroll),
unroll_local_access_(unroll_local_access) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
Expand All @@ -96,6 +117,7 @@ class LoopUnroller : public StmtExprMutator {
}

Stmt VisitStmt_(const ForNode* op) {
// Post order so we can collect more information
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
int value = GetExtent(op);
Expand All @@ -111,6 +133,12 @@ class LoopUnroller : public StmtExprMutator {
auto_unroll = true;
}

// If a loop var is used as indices to a local memory, it must be unrolled so
// the local memory access can be turned into register access.
if (this->var_touched_local_.count(op->loop_var) && value > 0 && unroll_local_access_) {
auto_unroll = true;
}

if (auto_unroll) {
step_count_ *= value;
unroll_depth_ += 1;
Expand All @@ -137,8 +165,32 @@ class LoopUnroller : public StmtExprMutator {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
if (unroll_local_access_) {
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
if (storage_scope.rank == runtime::StorageRank::kLocal ||
storage_scope.rank == runtime::StorageRank::kWarp) {
VarLocalAccessMarker marker(&var_touched_local_);
for (PrimExpr e : op->indices) {
marker(e);
}
}
}
return GetRef<PrimExpr>(op);
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
++step_count_;
if (unroll_local_access_) {
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
if (storage_scope.rank == runtime::StorageRank::kLocal ||
storage_scope.rank == runtime::StorageRank::kWarp) {
VarLocalAccessMarker marker(&var_touched_local_);
for (PrimExpr e : op->indices) {
marker(e);
}
}
}
return StmtExprMutator::VisitStmt_(op);
}

Expand All @@ -161,7 +213,7 @@ class LoopUnroller : public StmtExprMutator {
unroll_depth_ = std::max(unroll_depth_, unroll_depth);
return ret;
};
return StmtMutator::VisitSeqStmt_(op, false, fmutate);
return StmtExprMutator::VisitSeqStmt_(op, false, fmutate);
}

Stmt Unroll(const ForNode* op) {
Expand Down Expand Up @@ -202,19 +254,23 @@ class LoopUnroller : public StmtExprMutator {
// this not not count the total steps, only count the number of loops
int auto_max_extent_;
bool explicit_unroll_;
// Wether to unroll loops to local access.
bool unroll_local_access_{false};
// Number of normal loops in scope
int normal_loop_depth_{0};
// number of unrolled cases in current scope.
int unroll_depth_{0};
// Number of total steps unrolled
int step_count_{0};
// set of indices touched during visit local memory
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_touched_local_;
// analyzer
arith::Analyzer analyzer_;
};

Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
cfg->explicit_unroll)(stmt);
cfg->explicit_unroll, cfg->unroll_local_access)(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
} else {
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_transform_unroll_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,49 @@ def main():
tvm.ir.assert_structural_equal(after, expected)


def test_unroll_local_access():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((64,), "float32")):
for bx in T.thread_binding(4, thread="blockIdx.x"):
for tx in T.thread_binding(4, thread="threadIdx.x"):
A_local_data = T.allocate([4], dtype="float32", scope="local")
A_local = T.Buffer([4], dtype="float32", data=A_local_data)
for i in T.serial(4):
A_local[i] = T.float32(i)

@tvm.script.ir_module
class Expected:
@T.prim_func
def main(B: T.Buffer((64,), "float32")):
for bx in T.thread_binding(4, thread="blockIdx.x"):
for tx in T.thread_binding(4, thread="threadIdx.x"):
A_local_data = T.allocate([4], dtype="float32", scope="local")
A_local = T.Buffer([4], dtype="float32", data=A_local_data)
A_local[0] = T.float32(0)
A_local[1] = T.float32(1)
A_local[2] = T.float32(2)
A_local[3] = T.float32(3)

with tvm.transform.PassContext(
config={
"tir.UnrollLoop": {
"auto_max_depth": 0,
"auto_max_extent": 1,
"explicit_unroll": True,
"unroll_local_access": True,
}
}
):
after = tvm.tir.transform.UnrollLoop()(Before)
after = tvm.tir.transform.Simplify()(after)

tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
test_unroll_local_access()
test_unroll_loop()
test_unroll_fake_loop()
test_unroll_single_count_loops()
Expand Down

0 comments on commit 56ddd37

Please sign in to comment.