From 6bcad2ec0a799153643823d209042b519d8a4569 Mon Sep 17 00:00:00 2001
From: wrongtest
Date: Tue, 6 Jul 2021 08:35:35 +0800
Subject: [PATCH] fix storage rewrite index remap (#8338)

---
Reviewer notes (placed after the "---" separator, so not part of the commit
message):

* storage_rewrite.cc: RemapIndex now converts e->bits_offset to an index
  offset by dividing by dtype.bits() alone; the old divisor was
  dtype.bits() * dtype.lanes().
* test_tir_transform_storage_rewrite.py: adds
  test_storage_combine_with_vectorization, which checks that the pipeline
  leaves a single Allocate and that the two ramp loads feeding the Add have
  n lanes each and non-overlapping index ranges. The new test is also added
  to the list of direct calls at the bottom of the file.
* NOTE(review): the last hunk header declares 6 old / 7 new lines, but only
  5 / 6 are present here; the tail of the patch looks truncated. Confirm
  against the upstream commit before applying.
* NOTE(review): the From: header carries no e-mail address; presumably it was
  stripped during extraction. Verify before feeding this to `git am`.

 src/tir/transforms/storage_rewrite.cc         |  2 +-
 .../test_tir_transform_storage_rewrite.py     | 42 +++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 36eeddb17d89..c755576e2b88 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -506,7 +506,7 @@ class StoragePlanRewriter : public StmtExprMutator {
   // Remap the index
   PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
     if (e->bits_offset == 0) return index;
-    uint64_t elem_bits = dtype.bits() * dtype.lanes();
+    uint64_t elem_bits = dtype.bits();
     ICHECK_EQ(e->bits_offset % elem_bits, 0U);
     return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
   }
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index dbe7e04700d9..70e77ff69fea 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -228,6 +228,47 @@ def verify(n):
     assert num_alloc[0] == 1
 
 
+def test_storage_combine_with_vectorization():
+    n = 1024
+    A = te.placeholder((n,), name="A")
+    B = te.placeholder((n,), name="B")
+    C = te.compute((n,), lambda i: A[i] + B[i], name="C")
+    s = te.create_schedule(C.op)
+    AA = s.cache_read(A, "global:tag", readers=[C])
+    BB = s.cache_read(B, "global:tag", readers=[C])
+    CC = s.cache_write(C, "global:tag")
+    s[CC].vectorize(s[CC].op.axis[0])
+    bounds = tvm.te.schedule.InferBound(s)
+    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], stmt, None)
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.StorageFlatten(64)(mod)
+    mod = tvm.tir.transform.VectorizeLoop()(mod)
+    mod = tvm.tir.transform.StorageRewrite()(mod)
+    mod = tvm.tir.transform.Simplify()(mod)
+    stmt = mod["main"].body
+    num_alloc = [0]
+
+    def verify(v):
+        # find the Add node whose two operands are both Load nodes
+        if (
+            isinstance(v, tvm.tir.Add)
+            and isinstance(v.a, tvm.tir.Load)
+            and isinstance(v.b, tvm.tir.Load)
+        ):
+            lhs_ramp = v.a.index
+            rhs_ramp = v.b.index
+            # the two ramp loads must not overlap: their [base, base + n) ranges are disjoint
+            assert lhs_ramp.lanes == n
+            assert rhs_ramp.lanes == n
+            assert lhs_ramp.base >= rhs_ramp.base + n or rhs_ramp.base >= lhs_ramp.base + n
+        elif isinstance(v, tvm.tir.Allocate):
+            num_alloc[0] += 1
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+    assert num_alloc[0] == 1
+
+
 def test_storage_share_gpu():
     m = te.var("m")
     A = [te.placeholder((m), name="A")]
@@ -648,6 +689,7 @@ def verify(n):
     test_parallel_alloc()
     test_while_alloc()
     test_storage_combine()
+    test_storage_combine_with_vectorization()
     test_storage_share_gpu()
     test_inplace_rule2()