From d66ec9308370c08be8a974cde8eea62a713adb38 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 17 Mar 2022 08:33:31 -0500 Subject: [PATCH] [TIR] Bugfix in StorageFlatten, index flattening in PrefetchNode This resolves a bug introduced in https://github.com/apache/tvm/pull/9727, and adds a test to catch this failure mode. This bug occurred because StorageFlatten's visitor for PrefetchNode inserted additional pre-flattened `BufferLoad` nodes after visiting the body of the Prefetch, preventing those `BufferLoad` nodes from being flattened. Moving this visit to after the insertion of the `BufferLoad` nodes allows the usual buffer flattening to apply to the newly inserted nodes. --- src/tir/transforms/storage_flatten.cc | 8 +++----- .../python/unittest/test_tir_transform_storage_flatten.py | 6 ++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index ed36f5828d13..0d57f7928f47 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1489,10 +1489,6 @@ class StorageFlattener : public StmtExprMutator { } Stmt VisitStmt_(const PrefetchNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op != nullptr); - const BufferEntry& e = GetBufferEntry(op->buffer); ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope."; @@ -1524,6 +1520,8 @@ class StorageFlattener : public StmtExprMutator { vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); args.push_back(vars.back() + op->bounds[i]->min); } + + Stmt stmt = GetRef(op); for (int i = starts; i >= 0; --i) { if (i < starts) { stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt); @@ -1536,7 +1534,7 @@ class StorageFlattener : public StmtExprMutator { stmt = For(vars[i], 0, extent, ForKind::kSerial, stmt); } } - return stmt; + return this->VisitStmt(stmt); } PrimExpr VisitExpr_(const ProducerLoadNode* op) final { diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 8e430b035606..17afe7881184 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -57,6 +57,12 @@ def test_flatten_prefetch(): assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 + def assert_flat_loads(stmt): + if isinstance(stmt, tvm.tir.BufferLoad): + assert len(stmt.indices) == 1, "All prefetch indices should be flattened" + + tvm.tir.stmt_functor.post_order_visit(stmt, assert_flat_loads) + def test_flatten_storage_align(): m = 8