Skip to content

Commit

Permalink
[TVM][Bugfix] fix storage_rewrite bug when input is big (apache#2580)
Browse files Browse the repository at this point in the history
* fix storage_rewrite bug when input is big

* cast when necessary

* simplification

* simplification

* int64->uint32

* revert uint32->int64
zhiics authored and libing4752 committed Feb 18, 2019
1 parent b0737c0 commit b6099dc
Showing 2 changed files with 41 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/pass/storage_rewrite.cc
Original file line number Diff line number Diff line change
@@ -550,8 +550,10 @@ class StoragePlanRewriter : public IRMutator {
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
make_const(Int(32), 1));
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, e->allocs[0]->extents,
e->alloc_var, alloc_type, {sz},
e->allocs[0]->condition, Evaluate::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
@@ -564,8 +566,19 @@ class StoragePlanRewriter : public IRMutator {
Expr combo_size;
for (const Allocate* op : e->allocs) {
Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
auto nbits = op->type.bits() * op->type.lanes();
if (const auto* imm = sz.as<IntImm>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
LOG(WARNING) << "The allocation requires : " << imm->value
<< " * " << nbits
<< " bits, which is greater than the maximum of"
" int32. The size is cast to int64."
<< "\n";
sz = make_const(Int(64), imm->value);
}
}
// transform to bits
auto sz_nbits = sz * (op->type.bits() * op->type.lanes());
auto sz_nbits = sz * nbits;
if (combo_size.defined()) {
combo_size = max(combo_size, sz_nbits);
} else {
@@ -578,7 +591,7 @@ class StoragePlanRewriter : public IRMutator {
combo_size = combo_size / type_bits;
// round up for can not divided
if (!divided) {
combo_size = combo_size + make_const(Int(32), 1);
combo_size = combo_size + make_const(Int(32), 1);
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
25 changes: 25 additions & 0 deletions tests/python/unittest/test_pass_storage_rewrite.py
Original file line number Diff line number Diff line change
@@ -477,6 +477,30 @@ def test_replace_dataflow():
assert isinstance(bounds, tvm.container.Map)


def test_large_input():
@tvm.hybrid.script
def compute(a, b):
n = 16384
c = output_tensor((n, n), 'int32')
for i in range(n):
for j in range(n):
c[i, j] = a[i, j] - b[i, j]
return c

n = 16384
shape = (n, n)
a = tvm.placeholder(shape, name='a', dtype='int32')
b = tvm.placeholder(shape, name='b', dtype='int32')
c = tvm.compute(shape, lambda i, j: compute(a, b)[i, j])
c = tvm.compute(shape, lambda i, j: 1 + c[i, j])
s = tvm.create_schedule(c.op)
stmt = tvm.lower(s, [a, b, c], simple_mode=True)
def verify(n):
if isinstance(n, tvm.stmt.Allocate):
assert n.extents[0].value == 268435456
tvm.ir_pass.PostOrderVisit(stmt, verify)


if __name__ == "__main__":
test_alloc_seq()
test_alloc_different_dtypes()
@@ -492,3 +516,4 @@ def test_replace_dataflow():
test_alloc_seq_type2()
test_reuse_small_buffer()
test_replace_dataflow()
test_large_input()

0 comments on commit b6099dc

Please sign in to comment.