diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index 12c5db703ec58..34e756d369ef1 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -270,10 +270,24 @@ void PassUpDomain(const FuseNode* s, *outer = IntSet::single_point(v_outer); *inner = IntSet::single_point(v_inner); } else { - LOG(WARNING) << "use fallback inference rule in fuse"; - // simply use the entire set, this rule can be enhanced. - *outer = IntSet::range(dom_map.at(s->outer)); - *inner = IntSet::range(dom_map.at(s->inner)); + Expr fused_extent = (fused.max() - fused.min() + 1); + Expr inner_extent = dom_map.at(s->inner)->extent; + *outer = IntSet::interval(outer_min + fused.min() / inner_extent, + outer_min + fused.max() / inner_extent); + if ( is_zero(Simplify(inner_extent % fused_extent)) && + is_zero(Simplify(fused.min() % fused_extent)) ) { + // fused never spans multiple rows, make a tight bounding box + // there may be other cases when bounding box could be tightened + *inner = IntSet::interval(inner_min + fused.min() % inner_extent, + inner_min + fused.max() % inner_extent); + } else { // fused may span multiple rows, use full row widths + if (!is_zero(Simplify(fused_extent % inner_extent)) || + !is_zero(Simplify(fused.min() % inner_extent))) { + LOG(WARNING) << + "fused and original axes are not aligned, this may cause redundant computations"; + } + *inner = IntSet::range(dom_map.at(s->inner)); + } return; } } diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index f16305779a43a..740527116a6e5 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -69,6 +69,55 @@ def test_bound3(): assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) +def test_bound_fusesplit1(): + m = tvm.var('m') + l = tvm.var('l') + split1 = tvm.var('s') + A = tvm.placeholder((m, l), name='A') + A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + + s = tvm.create_schedule(A2.op) + fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1]) + xo, xi = s[A2].split(fused_axes, split1) + s[A1].compute_at(s[A2], xo) + + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[0]].min - (xo * split1) / l ).value == 0) + + expected_extent = (((xo + 1) * split1 - 1) / l - (xo * split1) / l + 1) + for i in range(1, 6): + for j in range(1, 6): + for k in range(1, 6): + vars = tvm.convert({split1: tvm.const(i, "int32"), l: tvm.const(j, "int32"), xo.var: tvm.const(k, "int32")}) + comp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value + exp_ext = tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(expected_extent, vars)).value + assert( comp_ext == exp_ext ) + + assert(tvm.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0) + +def test_bound_fusesplit2(): + m = tvm.var("m") + l = tvm.convert(6) + split = tvm.convert(3) + A = tvm.placeholder((m, l), name='A') + A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + + s = tvm.create_schedule(A2.op) + fused_axes = s[A2].fuse(A2.op.axis[0], A2.op.axis[1]) + xo, xi = s[A2].split(fused_axes, split) + s[A1].compute_at(s[A2], xo) + + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + vars = tvm.convert({xo.var: tvm.const(5, "int32")}) + assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2) + assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3) + assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1) + assert(tvm.ir_pass.Simplify(tvm.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3) + def test_bound_warp(): m = tvm.var('m') @@ -320,3 +369,5 @@ def _body(): test_gemm_bound() test_bound_warp() test_bound_tensor_compute_op() + test_bound_fusesplit1() + test_bound_fusesplit2()