Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite #12364

Merged
merged 2 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand All @@ -223,6 +222,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
if (!disable_storage_rewrite) {
pass_list.push_back(tir::transform::StorageRewrite());
}
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::UnrollLoop());

// Add user-defined phase-2 passes
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,10 @@ class StoragePlanRewriter : public StmtExprMutator {
};

// Checks whether the storage_scope is especially tagged for a specific memory.
// Special memory is all combined into a single allocation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've verified that with the special tagged memory helper change we are getting correct behavior for aligned access in the cases that I mentioned to fail previously.

bool IsSpecialTaggedMemory(const StorageScope& scope) {
return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace";
return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" &&
scope.tag != ".vtcm";
}

// Alllocate entry of node.
Expand Down Expand Up @@ -655,8 +657,6 @@ class StoragePlanRewriter : public StmtExprMutator {

if (e->allocs.size() == 1) {
// simply use the original allocation.
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), e->allocs[0]->extents);
e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0));
if (IsSpecialTaggedMemory(e->scope)) {
Expand Down