diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 134676c287e1c..065894871ab9e 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -564,11 +564,14 @@ class StoragePlanRewriter : public IRMutator { Expr combo_size; for (const Allocate* op : e->allocs) { Expr sz = arith::ComputeReduce(op->extents, make_const(Int(32), 1)); + auto nbits = op->type.bits() * op->type.lanes(); if (const auto* imm = sz.as()) { - sz = make_const(Int(64), imm->value); + if (imm->value > std::numeric_limits::max() / nbits) { + sz = make_const(Int(64), imm->value); + } } // transform to bits - auto sz_nbits = sz * (op->type.bits() * op->type.lanes()); + auto sz_nbits = sz * nbits; if (combo_size.defined()) { combo_size = max(combo_size, sz_nbits); } else { @@ -581,7 +584,7 @@ class StoragePlanRewriter : public IRMutator { combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { - combo_size = combo_size + make_const(Int(64), 1); + combo_size = combo_size + make_const(Int(32), 1); } combo_size = ir::Simplify(combo_size); e->new_alloc = Allocate::make(