Skip to content

Commit

Permalink
tightening bounding box for IntSet fused in PassUpDomain (#3073)
Browse files Browse the repository at this point in the history
Apply suggestions from code review

Co-Authored-By: Wei Chen <[email protected]>
  • Loading branch information
2 people authored and tqchen committed Jul 18, 2019
1 parent ce363d6 commit 54f903a
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,24 @@ void PassUpDomain(const FuseNode* s,
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
Expr fused_extent = (fused.max() - fused.min() + 1);
Expr inner_extent = dom_map.at(s->inner)->extent;
*outer = IntSet::interval(outer_min + fused.min() / inner_extent,
outer_min + fused.max() / inner_extent);
if (is_zero(Simplify(inner_extent % fused_extent)) &&
is_zero(Simplify(fused.min() % fused_extent)) ) {
// fused never spans multiple rows, make a tight bounding box
// there may be other cases when bounding box could be tightened
*inner = IntSet::interval(inner_min + fused.min() % inner_extent,
inner_min + fused.max() % inner_extent);
} else { // fused may span multiple rows, use full row widths
if (!is_zero(Simplify(fused_extent % inner_extent)) ||
!is_zero(Simplify(fused.min() % inner_extent))) {
LOG(WARNING) <<
"fused and original axes are not aligned, this may cause redundant computations";
}
*inner = IntSet::range(dom_map.at(s->inner));
}
return;
}
}
Expand Down
51 changes: 51 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,55 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)

def test_bound_fusesplit1():
m = tvm.var('m')
l = tvm.var('l')
split1 = tvm.var('s')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

s = tvm.create_schedule(A2.op)
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
xo, xi = s[A2].split(fused_axes, split1)
s[A1].compute_at(s[A2], xo)

bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split1) / l ).value == 0)

expected_extent = (((xo + 1) * split1 - 1) / l - (xo * split1) / l + 1)
for i in range(1, 6):
for j in range(1, 6):
for k in range(1, 6):
vars = tvm.convert({split1: tvm.const(i, "int32"), l: tvm.const(j, "int32"), xo.var: tvm.const(k, "int32")})
comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value
exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value
assert(comp_ext == exp_ext)

assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0)

def test_bound_fusesplit2():
m = tvm.var("m")
l = tvm.convert(6)
split = tvm.convert(3)
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')

s = tvm.create_schedule(A2.op)
fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1])
xo, xi = s[A2].split(fused_axes, split)
s[A1].compute_at(s[A2], xo)

bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
vars = tvm.convert({xo.var: tvm.const(5, "int32")})
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2)
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3)
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1)
assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3)


def test_bound_warp():
m = tvm.var('m')
Expand Down Expand Up @@ -342,3 +391,5 @@ def _check(B, A=A):
test_bound_warp()
test_bound_tensor_compute_op()
test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()

0 comments on commit 54f903a

Please sign in to comment.