Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Enhance loop unroll with unroll local access #14224

Merged
merged 1 commit into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <unordered_set>
#include <vector>

#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"

namespace tvm {
Expand All @@ -43,6 +44,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
int auto_max_depth;
int auto_max_extent;
int explicit_unroll;
int unroll_local_access;

TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
TVM_ATTR_FIELD(auto_max_step)
Expand All @@ -57,6 +59,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
TVM_ATTR_FIELD(explicit_unroll)
.describe("Whether to explicitly unroll the loop instead of setting a pragma")
.set_default(true);
TVM_ATTR_FIELD(unroll_local_access)
.describe("Whether to always unroll local access")
.set_default(false);
}
};

Expand All @@ -68,14 +73,30 @@ class UnrollLoopConfig : public Attrs {
TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

class VarLocalAccessMarker : public ExprVisitor {
public:
explicit VarLocalAccessMarker(
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>* var_touched_local)
: var_touched_local_(var_touched_local) {}

void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef<Var>(op)); }

private:
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>* var_touched_local_;
};

// The Visitor is used to check whether var is used as write index in a local memory
// If a loop var is used as indices to a local memory, it must be unrolled so
// the local memory access can be turned into register access.
class LoopUnroller : public StmtExprMutator {
public:
explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent,
bool explicit_unroll)
bool explicit_unroll, bool unroll_local_access)
: auto_max_step_(auto_max_step),
auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent),
explicit_unroll_(explicit_unroll) {}
explicit_unroll_(explicit_unroll),
unroll_local_access_(unroll_local_access) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
Expand All @@ -96,6 +117,7 @@ class LoopUnroller : public StmtExprMutator {
}

Stmt VisitStmt_(const ForNode* op) {
// Post order so we can collect more information
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
int value = GetExtent(op);
Expand All @@ -111,6 +133,12 @@ class LoopUnroller : public StmtExprMutator {
auto_unroll = true;
}

// If a loop var is used as indices to a local memory, it must be unrolled so
// the local memory access can be turned into register access.
if (this->var_touched_local_.count(op->loop_var) && value > 0 && unroll_local_access_) {
auto_unroll = true;
}

if (auto_unroll) {
step_count_ *= value;
unroll_depth_ += 1;
Expand All @@ -137,8 +165,32 @@ class LoopUnroller : public StmtExprMutator {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
if (unroll_local_access_) {
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
if (storage_scope.rank == runtime::StorageRank::kLocal ||
storage_scope.rank == runtime::StorageRank::kWarp) {
VarLocalAccessMarker marker(&var_touched_local_);
for (PrimExpr e : op->indices) {
marker(e);
}
}
}
return GetRef<PrimExpr>(op);
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
++step_count_;
if (unroll_local_access_) {
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
if (storage_scope.rank == runtime::StorageRank::kLocal ||
storage_scope.rank == runtime::StorageRank::kWarp) {
VarLocalAccessMarker marker(&var_touched_local_);
for (PrimExpr e : op->indices) {
marker(e);
}
}
}
return StmtExprMutator::VisitStmt_(op);
}

Expand All @@ -161,7 +213,7 @@ class LoopUnroller : public StmtExprMutator {
unroll_depth_ = std::max(unroll_depth_, unroll_depth);
return ret;
};
return StmtMutator::VisitSeqStmt_(op, false, fmutate);
return StmtExprMutator::VisitSeqStmt_(op, false, fmutate);
}

Stmt Unroll(const ForNode* op) {
Expand Down Expand Up @@ -202,19 +254,23 @@ class LoopUnroller : public StmtExprMutator {
// this not not count the total steps, only count the number of loops
int auto_max_extent_;
bool explicit_unroll_;
// Wether to unroll loops to local access.
bool unroll_local_access_{false};
// Number of normal loops in scope
int normal_loop_depth_{0};
// number of unrolled cases in current scope.
int unroll_depth_{0};
// Number of total steps unrolled
int step_count_{0};
// set of indices touched during visit local memory
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_touched_local_;
// analyzer
arith::Analyzer analyzer_;
};

Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
cfg->explicit_unroll)(stmt);
cfg->explicit_unroll, cfg->unroll_local_access)(stmt);
if (!ret.same_as(stmt)) {
return ConvertSSA(ret);
} else {
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_transform_unroll_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,49 @@ def main():
tvm.ir.assert_structural_equal(after, expected)


def test_unroll_local_access():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((64,), "float32")):
for bx in T.thread_binding(4, thread="blockIdx.x"):
for tx in T.thread_binding(4, thread="threadIdx.x"):
A_local_data = T.allocate([4], dtype="float32", scope="local")
A_local = T.Buffer([4], dtype="float32", data=A_local_data)
for i in T.serial(4):
A_local[i] = T.float32(i)

@tvm.script.ir_module
class Expected:
@T.prim_func
def main(B: T.Buffer((64,), "float32")):
for bx in T.thread_binding(4, thread="blockIdx.x"):
for tx in T.thread_binding(4, thread="threadIdx.x"):
A_local_data = T.allocate([4], dtype="float32", scope="local")
A_local = T.Buffer([4], dtype="float32", data=A_local_data)
A_local[0] = T.float32(0)
A_local[1] = T.float32(1)
A_local[2] = T.float32(2)
A_local[3] = T.float32(3)

with tvm.transform.PassContext(
config={
"tir.UnrollLoop": {
"auto_max_depth": 0,
"auto_max_extent": 1,
"explicit_unroll": True,
"unroll_local_access": True,
}
}
):
after = tvm.tir.transform.UnrollLoop()(Before)
after = tvm.tir.transform.Simplify()(after)

tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
test_unroll_local_access()
test_unroll_loop()
test_unroll_fake_loop()
test_unroll_single_count_loops()
Expand Down