Skip to content

Commit

Permalink
fix storage rewrite index remap (apache#8338)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored Jul 6, 2021
1 parent ec47129 commit 6bcad2e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class StoragePlanRewriter : public StmtExprMutator {
// Remap the index
// Shifts `index` by the storage entry's bit offset, expressed in units of elem_bits.
// NOTE(review): this span is a rendered diff hunk (1 addition, 1 deletion), so two
// declarations of elem_bits appear below; only one of them exists in the real file.
PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
// No bit offset recorded for this entry: the index is returned unchanged.
if (e->bits_offset == 0) return index;
// presumably the pre-change line (the hunk's deletion) — confirm against the raw diff
uint64_t elem_bits = dtype.bits() * dtype.lanes();
// presumably the post-change line (the hunk's addition) — confirm against the raw diff
uint64_t elem_bits = dtype.bits();
// The bit offset must be an exact multiple of elem_bits.
ICHECK_EQ(e->bits_offset % elem_bits, 0U);
// Add the offset, converted to a count of elem_bits units, as a constant of the index's dtype.
return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
}
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_transform_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,47 @@ def verify(n):
assert num_alloc[0] == 1


def test_storage_combine_with_vectorization():
    """Check StorageRewrite output for a vectorized elementwise add.

    Builds ``C[i] = A[i] + B[i]`` over 1024 elements with cache stages for
    A, B and C in scope "global:tag", vectorizes the cached write stage, and
    runs StorageFlatten, VectorizeLoop, StorageRewrite and Simplify.  The
    resulting body must contain exactly one Allocate, and every Add whose two
    operands are both Loads must read two index ranges of ``n`` lanes each
    that do not overlap.
    """
    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute((n,), lambda i: A[i] + B[i], name="C")

    sched = te.create_schedule(C.op)
    # The returned cache-read stages are not needed afterwards; the calls are
    # kept because they modify the schedule.
    s_read_a = s_read_b = None
    s_read_a = sched.cache_read(A, "global:tag", readers=[C])
    s_read_b = sched.cache_read(B, "global:tag", readers=[C])
    write_cache = sched.cache_write(C, "global:tag")
    sched[write_cache].vectorize(sched[write_cache].op.axis[0])

    inferred_bounds = tvm.te.schedule.InferBound(sched)
    scheduled = tvm.te.schedule.ScheduleOps(sched, inferred_bounds)
    prim_func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C], scheduled, None)
    mod = tvm.IRModule.from_expr(prim_func)

    # Apply the passes one after another, in this exact order.
    pipeline = (
        tvm.tir.transform.StorageFlatten(64),
        tvm.tir.transform.VectorizeLoop(),
        tvm.tir.transform.StorageRewrite(),
        tvm.tir.transform.Simplify(),
    )
    for tir_pass in pipeline:
        mod = tir_pass(mod)
    body = mod["main"].body

    alloc_count = 0

    def verify(node):
        """Visitor: count Allocate nodes and check the operands of the add."""
        nonlocal alloc_count
        if not isinstance(node, tvm.tir.Add):
            if isinstance(node, tvm.tir.Allocate):
                alloc_count += 1
            return
        # Only an Add whose two operands are both Loads is inspected.
        if not isinstance(node.a, tvm.tir.Load):
            return
        if not isinstance(node.b, tvm.tir.Load):
            return
        left_index = node.a.index
        right_index = node.b.index
        # Each operand is read with an n-lane index ...
        assert left_index.lanes == n
        assert right_index.lanes == n
        # ... and one range has to start at or after the end of the other.
        assert left_index.base >= right_index.base + n or right_index.base >= left_index.base + n

    tvm.tir.stmt_functor.post_order_visit(body, verify)
    assert alloc_count == 1


def test_storage_share_gpu():
m = te.var("m")
A = [te.placeholder((m), name="A")]
Expand Down Expand Up @@ -648,6 +689,7 @@ def verify(n):
test_parallel_alloc()
test_while_alloc()
test_storage_combine()
test_storage_combine_with_vectorization()
test_storage_share_gpu()
test_inplace_rule2()

Expand Down

0 comments on commit 6bcad2e

Please sign in to comment.