Patch summary (free text before the first "diff --git" header):
  * StoragePlanRewriter: when a storage entry holds exactly one allocation,
    the new Allocate is built with a single extent equal to the product of
    the original extents instead of the original multi-dimensional extents.
  * StoragePlanRewriter: when merging several allocations, a constant element
    count is rebuilt as an int64 constant (with a warning) if
    element_count * bits_per_element would exceed the int32 maximum.
  * Adds test_large_input, which lowers a 16384 x 16384 int32 computation and
    checks that every Allocate has a first extent of 268435456 (16384 * 16384).

diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 331b60a865ed..9ba9dcde63c9 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -550,8 +550,10 @@ class StoragePlanRewriter : public IRMutator {
         }
         if (e->allocs.size() == 1) {
           // simply use the original allocation.
+          Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
+                                              make_const(Int(32), 1));
           e->new_alloc = Allocate::make(
-              e->alloc_var, alloc_type, e->allocs[0]->extents,
+              e->alloc_var, alloc_type, {sz},
               e->allocs[0]->condition, Evaluate::make(0));
           if (e->scope.tag.length() != 0) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
@@ -564,8 +566,19 @@ class StoragePlanRewriter : public IRMutator {
           Expr combo_size;
           for (const Allocate* op : e->allocs) {
             Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
+            auto nbits = op->type.bits() * op->type.lanes();
+            if (const auto* imm = sz.as<IntImm>()) {
+              if (imm->value > std::numeric_limits<int>::max() / nbits) {
+                LOG(WARNING) << "The allocation requires : " << imm->value
+                             << " * " << nbits
+                             << " bits, which is greater than the maximum of"
+                                " int32. The size is cast to int64."
+                             << "\n";
+                sz = make_const(Int(64), imm->value);
+              }
+            }
             // transform to bits
-            auto sz_nbits = sz * (op->type.bits() * op->type.lanes());
+            auto sz_nbits = sz * nbits;
             if (combo_size.defined()) {
               combo_size = max(combo_size, sz_nbits);
             } else {
@@ -578,7 +591,7 @@ class StoragePlanRewriter : public IRMutator {
           combo_size = combo_size / type_bits;
           // round up for can not divided
           if (!divided) {
-            combo_size = combo_size + make_const(Int(32), 1);
+            combo_size = combo_size + make_const(Int(32), 1);
           }
           combo_size = ir::Simplify(combo_size);
           e->new_alloc = Allocate::make(
diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py
index faf70204c29e..52851d4afe95 100644
--- a/tests/python/unittest/test_pass_storage_rewrite.py
+++ b/tests/python/unittest/test_pass_storage_rewrite.py
@@ -477,6 +477,30 @@ def test_replace_dataflow():
     assert isinstance(bounds, tvm.container.Map)
 
 
+def test_large_input():
+    @tvm.hybrid.script
+    def compute(a, b):
+        n = 16384
+        c = output_tensor((n, n), 'int32')
+        for i in range(n):
+            for j in range(n):
+                c[i, j] = a[i, j] - b[i, j]
+        return c
+
+    n = 16384
+    shape = (n, n)
+    a = tvm.placeholder(shape, name='a', dtype='int32')
+    b = tvm.placeholder(shape, name='b', dtype='int32')
+    c = tvm.compute(shape, lambda i, j: compute(a, b)[i, j])
+    c = tvm.compute(shape, lambda i, j: 1 + c[i, j])
+    s = tvm.create_schedule(c.op)
+    stmt = tvm.lower(s, [a, b, c], simple_mode=True)
+    def verify(n):
+        if isinstance(n, tvm.stmt.Allocate):
+            assert n.extents[0].value == 268435456
+    tvm.ir_pass.PostOrderVisit(stmt, verify)
+
+
 if __name__ == "__main__":
     test_alloc_seq()
     test_alloc_different_dtypes()
@@ -492,3 +516,4 @@ def test_replace_dataflow():
     test_alloc_seq_type2()
     test_reuse_small_buffer()
     test_replace_dataflow()
+    test_large_input()