From 2d03e87f8f5289b60716eab9f7dfce62f1d1939b Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 10 Aug 2022 13:14:25 -0700 Subject: [PATCH 1/2] [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite Vtcm allocations were being moved inside loops even if they were originally allocated outside of the loops. Normally PlanAndUpdateBufferAllocationLocation moves allocations as close to use as possible and then StorageRewrite moves them back out as far as possible. However, with Vtcm allocation, PlanAndUpdateBufferAllocationLocation would move the Vtcm allocation close to the compute, then LowerVtcm would convert the allocation to a LetStmt. StorageRewrite would not move this LetStmt as it only handles allocations. Moving LowerVtcmAlloc to after StorageRewrite ensures that the vtcm allocations are in their final spot before converting them to a LetStmt. --- src/driver/driver_api.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 6f4fb618d334..02d6c5e6d1e3 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -204,7 +204,6 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::LowerVtcmAlloc()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -223,6 +222,8 @@ Array CreatePassList(bool disable_loop_partition) { if (!disable_storage_rewrite) { pass_list.push_back(tir::transform::StorageRewrite()); } + // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations + pass_list.push_back(tir::transform::LowerVtcmAlloc()); pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes From f9abbca4d78356cbcb42bffcfc6aa9a9ddf15d74 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 11 Aug 2022 13:48:55 -0700 Subject: [PATCH 2/2] fix issues with tagging and storage rewrite --- src/tir/transforms/storage_rewrite.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 5a326d9fac8d..d15bed56fd4a 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -583,8 +583,10 @@ class StoragePlanRewriter : public StmtExprMutator { }; // Checks whether the storage_scope is especially tagged for a specific memory. + // Special memory is all combined into a single allocation. bool IsSpecialTaggedMemory(const StorageScope& scope) { - return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace"; + return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" && + scope.tag != ".vtcm"; } // Alllocate entry of node. @@ -655,8 +657,6 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->allocs.size() == 1) { // simply use the original allocation. - PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents, e->allocs[0]->condition, Evaluate(0)); if (IsSpecialTaggedMemory(e->scope)) {