diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 5c972595ff00..8141462478ab 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -214,7 +214,40 @@ void ComputeOpNode::GatherBound( std::unordered_map* out_dom_map) const { const TensorDom& tdom = tensor_dom.at(self.output(0)); for (size_t i = 0; i < this->axis.size(); ++i) { - Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom); + // Bounds we get from the declaration of i + Range r_dom = this->axis[i]->dom; + // Bounds we get from the uses of the tensor + Range r_from_uses = arith::Union(tdom.data.at(i)).cover_range(r_dom); + // The result + Range r; + + if (can_prove(r_from_uses->extent <= r_dom->extent)) { + // Bounds from the uses are provably tighter, use them + if (can_prove(r_from_uses->extent == r_dom->extent)) { + // If the extents are equal, prefer using r_dom, as it probably has the simpler min + r = r_dom; + } else { + r = r_from_uses; + } + } else if (can_prove(r_dom->extent <= r_from_uses->extent)) { + // The declared bounds are better. This may mean one of the following two things: + // either we have an out-of-bounds error in the input user code, or the simplifier + // did a poor job simplifying call arguments before evaluating ranges. + // Use the declared bounds but issue a warning. + LOG(WARNING) << "GatherBound: the declared bounds " << r_dom + << " are tighter than the bounds from uses " << r_from_uses + << " for the variable " << this->axis[i]->var << " of the tensor " << self->name + << ". Either out-of-bounds or poor simplification."; + r = r_dom; + } else { + // We can prove neither. Issue a warning and use r_from_uses since it was the old behaviour + // and it leads to fewer problems. + LOG(WARNING) << "GatherBound: cannot prove either the declared bounds " << r_dom + << " or the bounds from uses " << r_from_uses + << " to be tighter than the other. Will use the bounds from uses."; + r = r_from_uses; + } + CHECK(!out_dom_map->count(this->axis[i])); (*out_dom_map)[this->axis[i]] = r; } diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 30be3783bbb3..218f5c73b122 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -259,6 +259,26 @@ def test_gemm_bound(): assert(bounds[CC.op.axis[0]].extent.value == 8) assert(bounds[CC.op.axis[1]].extent.value == 8) +def test_bound_simplification_failure(): + # Check that the bounds are not expanded + A = tvm.compute((2,), lambda j: j, "A") + + def _check(B, A=A): + s = tvm.create_schedule(B.op) + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.lower(s, [B, A], simple_mode=True) + if not bounds[A.op.axis[0]].extent.value <= 2: + print(stmt) + assert bounds[A.op.axis[0]].extent.value <= 2 + + # These are hard to simplify, moreover we don't simplify them + _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)])) + _check(tvm.compute((10,), lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)])) + _check(tvm.compute((10,), lambda i: A[-2*(i/2) - tvm.min(i, 0-i)])) + _check(tvm.compute((10,), lambda i: A[i + (0 - i)])) + # This would cause out of bounds, but we nevertheless include it + _check(tvm.compute((10,), lambda i: A[i])) if __name__ == "__main__": test_bound_nest_thread() @@ -273,3 +293,4 @@ def test_gemm_bound(): test_bound2() test_gemm_bound() test_bound_warp() + test_bound_simplification_failure() diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 8774514cfa17..bd1194c72a86 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -264,19 +264,6 @@ def _compute(*indice): stmt = tvm.schedule.ScheduleOps(s, bounds) -def test_schedule_bound_condition(): - A = tvm.placeholder((64,), name='A', dtype="float32") - Apad = tvm.compute((66,), lambda i: tvm.select(tvm.all(i>0, i < 65), A[i-1], tvm.const(0.)), name='Apad') - Apad2 = tvm.compute((66,), lambda i: Apad[i]*2, name='Apad2') - s = tvm.create_schedule(Apad2.op) - AL1 = s.cache_read(A,"local",[Apad]) - s = s.normalize() - bounds = tvm.schedule.InferBound(s) - stmt = tvm.schedule.ScheduleOps(s, bounds) - stmt = tvm.ir_pass.Simplify(stmt) - assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse)) - - def intrin_gemv(m, n): w = tvm.placeholder((m, n), name='w') x = tvm.placeholder((n,), name='x') @@ -420,7 +407,6 @@ def test_schedule_tensor_compute3(): test_schedule1() test_schedule2() test_schedule_cache() - test_schedule_bound_condition() test_schedule_tensor_compute1() test_schedule_tensor_compute2() test_schedule_tensor_compute3()