Skip to content

Commit

Permalink
Set split node's range to minimum of ext and split factor or split np…
Browse files Browse the repository at this point in the history
…arts, but only when PassDownDomain is called with allow_missing == false, i.e. by InferBound. Add a helper PassUpThreadBinding() to get a map telling whether an IterVar has at least one leaf IterVar deriving from it binding to a thread. Add two unit tests. (apache#5044)
  • Loading branch information
yongfeng-nv authored and zhiics committed Apr 17, 2020
1 parent 49e95e5 commit 9b6b0e3
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 3 deletions.
76 changes: 73 additions & 3 deletions src/te/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,66 @@ void Update(std::unordered_map<IterVar, Range>* p_state,
}
}

/*!
* \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to
* a thread.
*
* \param stage The stage to operate on.
* \param p_state The propagation result of each IterVar.
*/
void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>* p_state) {
auto bound_to_thread = [&stage](const IterVar& iv) {
bool bound = false;
auto it = stage->iter_var_attrs.find(iv);
if (it != stage->iter_var_attrs.end()) {
bound = (*it).second->bind_thread.defined();
}
return bound;
};

auto& state = *p_state;
// Fill p_state with leaf itervars
for (const IterVar& iv : stage->leaf_iter_vars) {
state[iv] = bound_to_thread(iv);
}
// Traverse the graph bottom-up to propagate thread binding information
for (size_t i = stage->relations.size(); i != 0; --i) {
IterVarRelation rel = stage->relations[i - 1];
if (const SplitNode* s = rel.as<SplitNode>()) {
state[s->parent] = state[s->inner] || state[s->outer];
} else if (const FuseNode* s = rel.as<FuseNode>()) {
state[s->inner] = state[s->fused];
state[s->outer] = state[s->fused];
} else if (const RebaseNode* s = rel.as<RebaseNode>()) {
state[s->parent] = state[s->rebased];
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
}
}
}

void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
}
return actx->Simplify(indexdiv(a + (b - 1), b));
};

auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) {
if (actx->CanProve(a < b)) {
return actx->Simplify(a);
}
return actx->Simplify(b);
};

std::unordered_map<IterVar, bool> dominating_thread;
PassUpThreadBinding(stage, &dominating_thread);

auto& state = *p_state;
// forwar iteration on relations
for (IterVarRelation rel : stage->relations) {
Expand All @@ -72,14 +121,35 @@ void PassDownDomain(const Stage& stage,
}
CHECK(!state.count(r->inner));
const Range& range_parent = state.at(r->parent);
// Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the
// following conditions are met:
// 1. No leaf IterVar derived from iv binds to any thread. People may use split
// to force an IterVar extent to match the number of allocated threads to fuse stages
// that require different number of threads. We don't want to change these extents.
// 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound,
// rather than by an early compiler phase, such as rfactor(). We don't want to tighten an
// IterVar in an early phase allowing missing IterVars, because it may bind to a thread later.
// 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one
// zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent
// IterVar. We don't touch it.
auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) {
return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent)
? factor_or_nparts
: minimum_or_later(range_parent->extent, factor_or_nparts);
};
if (r->factor.defined()) {
Update(p_state, r->inner,
Range::make_by_min_extent(0, r->factor), actx);
Range::make_by_min_extent(
0, resolve_min_extent_for_split(r->inner, r->factor)),
actx);
Update(p_state, r->outer,
Range::make_by_min_extent(
0, ceil_div(range_parent->extent, r->factor)), actx);
} else {
Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx);
Update(p_state, r->outer,
Range::make_by_min_extent(
0, resolve_min_extent_for_split(r->outer, r->nparts)),
actx);
Update(p_state, r->inner,
Range::make_by_min_extent(
0, ceil_div(range_parent->extent, r->nparts)), actx);
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,32 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)

def test_bound_split_ext_less_than_factor():
m = 8
I = te.placeholder((m,), name='I')
EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
s = te.create_schedule([E.op])
xo, xi = s[E].split(s[E].op.axis[0], factor = 32)
s[EF].compute_at(s[E], xo)

bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xi].extent.value == m

def test_bound_split_ext_less_than_naprts():
m = 8
I = te.placeholder((m,), name='I')
EF = te.compute((m,), lambda i: I[i] * 2, name = "EF")
E = te.compute((m,), lambda i: EF[i] * 2, name = "E")
s = te.create_schedule([E.op])
xo, xi = s[E].split(s[E].op.axis[0], nparts = 32)
s[EF].compute_at(s[E], xo)

bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent.value == m

def test_bound_split_divisible():
m = te.var('m')
l = te.var('l')
Expand Down

0 comments on commit 9b6b0e3

Please sign in to comment.