Skip to content

Commit

Permalink
[TVM] Rewrite simplification rule to eliminate unnecessary conditiona…
Browse files Browse the repository at this point in the history
…ls. (apache#4076)

The current bounds checking infrastructure inserts checks like:

```
for (i, 0, bounds[n]) {
  if (likely(i < bounds[n]) {
     ...
  }
}
```

into the TVM IR which is currently not removed by simplification infrastructure.
This is a little unclean, as these are trivially true since for a loop var `i`
with a given min and extent, we are guaranteed that `i >= min` and `i < min +
extent`. Thus, we can insert these checks into the IR and use them to eliminate
trivial bounds checks early on.
  • Loading branch information
ajtulloch authored and Animesh Jain committed Oct 17, 2019
1 parent ab02884 commit f44f196
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 1 deletion.
2 changes: 2 additions & 0 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ class RewriteSimplifier {
const Expr& new_expr,
bool override = false);

std::function<void()> EnterConstraint(const Expr& constraint);

private:
friend class Analyzer;
friend class ConstraintContext;
Expand Down
4 changes: 3 additions & 1 deletion src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ void ConstraintContext::EnterWithScope() {
// entering the scope.
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_);
// recovery function.
exit_ = [f0, f1]() {
exit_ = [f0, f1, f2]() {
if (f2 != nullptr) f2();
if (f1 != nullptr) f1();
if (f0 != nullptr) f0();
};
Expand Down
23 changes: 23 additions & 0 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,17 @@ Mutate_(const Add* op, const Expr& self) {
return ret;
}

std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& constraint) {
size_t old_literal_size = literal_constraints_.size();
literal_constraints_.push_back(constraint);
size_t new_literal_size = literal_constraints_.size();
auto frecover = [old_literal_size, new_literal_size, this]() {
CHECK_EQ(literal_constraints_.size(), new_literal_size);
literal_constraints_.resize(old_literal_size);
};
return frecover;
}

Expr RewriteSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
Expand Down Expand Up @@ -1705,6 +1716,14 @@ Mutate_(const Call* op, const Expr& self) {
return op->args[0] & op->args[1];
}
}
if (op->is_intrinsic(Call::likely)) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (Equal(constraint, op->args[0])) {
return make_const(op->type, true);
}
}
}
return ret;
}

Expand Down Expand Up @@ -1761,6 +1780,10 @@ void RewriteSimplifier::Update(const Var& var,
impl_->Update(var, info, override);
}

std::function<void()> RewriteSimplifier::EnterConstraint(const Expr& constraint) {
return impl_->EnterConstraint(constraint);
}

RewriteSimplifier::RewriteSimplifier(Analyzer* parent)
: impl_(new Impl(parent)) {
}
Expand Down
6 changes: 6 additions & 0 deletions src/arithmetic/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/expr_operator.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
#include <vector>
#include "const_fold.h"
#include "pattern_match.h"
#include "ir_mutator_with_analyzer.h"
Expand Down Expand Up @@ -74,6 +75,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Expr Mutate_(const Cast* op, const Expr& self) override;
Expr Mutate_(const Let* op, const Expr& self) override;

std::function<void()> EnterConstraint(const Expr& constraint);

protected:
/*! \brief internal structure for comparison. */
enum CompareResult {
Expand All @@ -89,6 +92,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
int recur_depth_{0};
// internal variable map
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;

std::vector<Expr> literal_constraints_;

// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;

Expand Down
7 changes: 7 additions & 0 deletions src/arithmetic/stmt_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Mutate(stmt);
}

Stmt Mutate_(const For* op, const Stmt& s) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return IRMutator::Mutate_(op, s);
}

Stmt Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
if (!ir::HasSideEffect(value)) {
Expand Down
56 changes: 56 additions & 0 deletions tests/python/unittest/test_arith_stmt_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,62 @@ def test_thread_extent_simplify():
assert isinstance(body.body.body.body, tvm.stmt.Store)


def test_basic_likely_elimination():
n = tvm.var('n')
X = tvm.placeholder(shape=(n,), name="x")
W = tvm.placeholder(shape=(n + 1,), dtype="int32", name="w")

def f(i):
start = W[i]
extent = W[i+1] - W[i]
rv = tvm.reduce_axis((0, extent))
return tvm.sum(X[rv + start], axis=rv)
Y = tvm.compute(X.shape, f, name="y")
s = tvm.create_schedule([Y.op])
stmt = tvm.lower(s, [X, W, Y], simple_mode=True)
assert('if' not in str(stmt))

def test_complex_likely_elimination():
def cumsum(X):
"""
Y[i] = sum(X[:i])
"""
(m, ) = X.shape
s_state = tvm.placeholder((m + 1, ), dtype="int32", name="state")
s_init = tvm.compute((1, ), lambda _: tvm.const(0, "int32"))
s_update = tvm.compute((m + 1, ), lambda l: s_state[l - 1] + X[l - 1])
return tvm.scan(s_init, s_update, s_state, inputs=[X], name="cumsum")

def sparse_lengths_sum(data, indices, lengths):
oshape = list(data.shape)
oshape[0] = lengths.shape[0]
length_offsets = cumsum(lengths)

def sls(n, d):
gg = tvm.reduce_axis((0, lengths[n]))
indices_idx = length_offsets[n] + gg
data_idx = indices[indices_idx]
data_val = data[data_idx, d]
return tvm.sum(data_val, axis=gg)

return tvm.compute(oshape, sls)

m, n, d, i, l = tvm.var('m'), tvm.var('n'), tvm.var('d'), tvm.var('i'), tvm.var('l')
data_ph = tvm.placeholder((m, d * 32), name="data")
indices_ph = tvm.placeholder((i,), name="indices", dtype="int32")
lengths_ph = tvm.placeholder((n,), name="lengths", dtype="int32")
Y = sparse_lengths_sum(data_ph, indices_ph, lengths_ph)
s = tvm.create_schedule([Y.op])
(n, d) = s[Y].op.axis
(do, di) = s[Y].split(d, factor=32)
(gg,) = s[Y].op.reduce_axis
s[Y].reorder(n, do, gg, di)
s[Y].vectorize(di)
stmt = tvm.lower(s, [data_ph, indices_ph, lengths_ph, Y], simple_mode=True)
assert('if' not in str(stmt))

if __name__ == "__main__":
test_stmt_simplify()
test_thread_extent_simplify()
test_basic_likely_elimination()
test_complex_likely_elimination()

0 comments on commit f44f196

Please sign in to comment.