Skip to content

Commit

Permalink
[TVM] Fix GatherBound to avoid allocating too much
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrechanik-h committed Nov 16, 2018
1 parent 59c70a0 commit 65967a1
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
35 changes: 34 additions & 1 deletion src/op/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,40 @@ void ComputeOpNode::GatherBound(
std::unordered_map<IterVar, Range>* out_dom_map) const {
const TensorDom& tdom = tensor_dom.at(self.output(0));
for (size_t i = 0; i < this->axis.size(); ++i) {
Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
// Bounds we get from the declaration of i
Range r_dom = this->axis[i]->dom;
// Bounds we get from the uses of the tensor
Range r_from_uses = arith::Union(tdom.data.at(i)).cover_range(r_dom);
// The result
Range r;

if (can_prove(r_from_uses->extent <= r_dom->extent)) {
// Bounds from the uses are provably tighter, use them
if (can_prove(r_from_uses->extent == r_dom->extent)) {
// If the extents are equal, prefer using r_dom, as it probably has the simpler min
r = r_dom;
} else {
r = r_from_uses;
}
} else if (can_prove(r_dom->extent <= r_from_uses->extent)) {
// The declared bounds are better. This may mean one of the following two things:
// either we have an out-of-bounds error in the input user code, or the simplifier
// did a poor job simplifying call arguments before evaluating ranges.
// Use the declared bounds but issue a warning.
LOG(WARNING) << "GatherBound: the declared bounds " << r_dom
<< " are tighter than the bounds from uses " << r_from_uses
<< " for the variable " << this->axis[i]->var << " of the tensor " << self->name
<< ". Either out-of-bounds or poor simplification.";
r = r_dom;
} else {
// We can prove neither. Issue a warning and use r_from_uses since it was the old behaviour
// and it leads to fewer problems.
LOG(WARNING) << "GatherBound: cannot prove either the declared bounds " << r_dom
<< " or the bounds from uses " << r_from_uses
<< " to be tighter than the other. Will use the bounds from uses.";
r = r_from_uses;
}

CHECK(!out_dom_map->count(this->axis[i]));
(*out_dom_map)[this->axis[i]] = r;
}
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,26 @@ def test_gemm_bound():
assert(bounds[CC.op.axis[0]].extent.value == 8)
assert(bounds[CC.op.axis[1]].extent.value == 8)

def test_bound_simplification_failure():
    """Check that bound inference does not expand the bounds of ``A``.

    ``A`` is declared with shape ``(2,)``.  Each consumer below reads ``A``
    through a different index expression; for every one of them the extent
    inferred for ``A``'s only axis must be at most 2.
    """
    A = tvm.compute((2,), lambda j: j, "A")

    def _check(B, A=A):
        # Build a normalized schedule for the consumer and infer its bounds.
        sched = tvm.create_schedule(B.op).normalize()
        inferred = tvm.schedule.InferBound(sched)
        # Lowered up front so it can be printed if the check below fails.
        lowered = tvm.lower(sched, [B, A], simple_mode=True)
        within_declared = inferred[A.op.axis[0]].extent.value <= 2
        if not within_declared:
            print(lowered)
        assert within_declared

    index_exprs = (
        # These are hard to simplify, and moreover we do not simplify them.
        lambda i: A[tvm.min(3*i, 4*i) + tvm.min(-3*i, -2*i)],
        lambda i: A[tvm.min(3*i, 4*i) + tvm.max(-3*i, -4*i)],
        lambda i: A[-2*(i/2) - tvm.min(i, 0-i)],
        lambda i: A[i + (0 - i)],
        # This one would cause out of bounds, but it is included nevertheless.
        lambda i: A[i],
    )
    # Each consumer is created immediately before it is checked.
    for index_expr in index_exprs:
        _check(tvm.compute((10,), index_expr))

if __name__ == "__main__":
test_bound_nest_thread()
Expand All @@ -273,3 +293,4 @@ def test_gemm_bound():
test_bound2()
test_gemm_bound()
test_bound_warp()
test_bound_simplification_failure()
14 changes: 0 additions & 14 deletions tests/python/unittest/test_schedule_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,19 +264,6 @@ def _compute(*indice):
stmt = tvm.schedule.ScheduleOps(s, bounds)


def test_schedule_bound_condition():
    """Schedule a padded compute with a cache read and inspect the result.

    After ``ScheduleOps`` and ``Simplify``, the node found at
    ``stmt.body.body.first.body.body.then_case`` must be an ``IfThenElse``.
    """
    src = tvm.placeholder((64,), name='A', dtype="float32")

    def _padded(i):
        # Read src[i-1] for 0 < i < 65, a zero constant elsewhere.
        return tvm.select(tvm.all(i>0, i < 65), src[i-1], tvm.const(0.))

    padded = tvm.compute((66,), _padded, name='Apad')
    doubled = tvm.compute((66,), lambda i: padded[i]*2, name='Apad2')

    sched = tvm.create_schedule(doubled.op)
    # The returned cache tensor is not used afterwards; the call is made for
    # its effect on the schedule.
    cached = sched.cache_read(src, "local", [padded])
    sched = sched.normalize()

    inferred = tvm.schedule.InferBound(sched)
    body = tvm.ir_pass.Simplify(tvm.schedule.ScheduleOps(sched, inferred))

    inner = body.body.body.first.body.body.then_case
    assert isinstance(inner, tvm.stmt.IfThenElse)

def intrin_gemv(m, n):
w = tvm.placeholder((m, n), name='w')
x = tvm.placeholder((n,), name='x')
Expand Down Expand Up @@ -420,7 +407,6 @@ def test_schedule_tensor_compute3():
test_schedule1()
test_schedule2()
test_schedule_cache()
test_schedule_bound_condition()
test_schedule_tensor_compute1()
test_schedule_tensor_compute2()
test_schedule_tensor_compute3()

0 comments on commit 65967a1

Please sign in to comment.