diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 1090b46e59f0..8be1c3604813 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -245,6 +245,8 @@ class RewriteSimplifier { const Expr& new_expr, bool override = false); + std::function EnterConstraint(const Expr& constraint); + private: friend class Analyzer; friend class ConstraintContext; diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 8f832d7313f3..acd964935c25 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -67,8 +67,10 @@ void ConstraintContext::EnterWithScope() { // entering the scope. auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); + auto f2 = analyzer_->rewrite_simplify.EnterConstraint(constraint_); // recovery function. - exit_ = [f0, f1]() { + exit_ = [f0, f1, f2]() { + if (f2 != nullptr) f2(); if (f1 != nullptr) f1(); if (f0 != nullptr) f0(); }; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index e3b3e7aed09c..c55385331655 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -220,6 +220,17 @@ Mutate_(const Add* op, const Expr& self) { return ret; } +std::function RewriteSimplifier::Impl::EnterConstraint(const Expr& constraint) { + size_t old_literal_size = literal_constraints_.size(); + literal_constraints_.push_back(constraint); + size_t new_literal_size = literal_constraints_.size(); + auto frecover = [old_literal_size, new_literal_size, this]() { + CHECK_EQ(literal_constraints_.size(), new_literal_size); + literal_constraints_.resize(old_literal_size); + }; + return frecover; +} + Expr RewriteSimplifier::Impl:: Mutate_(const Sub* op, const Expr& self) { Expr ret = IRMutator::Mutate_(op, self); @@ -1705,6 +1716,14 @@ Mutate_(const Call* op, const Expr& self) { return op->args[0] & op->args[1]; } } + if (op->is_intrinsic(Call::likely)) { + for (const auto& constraint : literal_constraints_) { + // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } + if (Equal(constraint, op->args[0])) { + return make_const(op->type, true); + } + } + } return ret; } @@ -1761,6 +1780,10 @@ void RewriteSimplifier::Update(const Var& var, impl_->Update(var, info, override); } +std::function RewriteSimplifier::EnterConstraint(const Expr& constraint) { + return impl_->EnterConstraint(constraint); +} + RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) { } diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index 55965ce42d6a..202618599872 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "const_fold.h" #include "pattern_match.h" #include "ir_mutator_with_analyzer.h" @@ -74,6 +75,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { Expr Mutate_(const Cast* op, const Expr& self) override; Expr Mutate_(const Let* op, const Expr& self) override; + std::function EnterConstraint(const Expr& constraint); + protected: /*! \brief internal structure for comparison. */ enum CompareResult { @@ -89,6 +92,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { int recur_depth_{0}; // internal variable map std::unordered_map var_map_; + + std::vector literal_constraints_; + // maximum number of recursion allowed during a single pass. static const constexpr int kMaxRecurDepth = 5; diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index f784514e1302..ae8e35f51d8e 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -51,6 +51,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return Mutate(stmt); } + Stmt Mutate_(const For* op, const Stmt& s) final { + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); + With ctx1(analyzer_, op->loop_var >= op->min); + With ctx2(analyzer_, op->loop_var < op->min + op->extent); + return IRMutator::Mutate_(op, s); + } + Stmt Mutate_(const LetStmt* op, const Stmt& s) { Expr value = this->Mutate(op->value); if (!ir::HasSideEffect(value)) { diff --git a/tests/python/unittest/test_arith_stmt_simplify.py b/tests/python/unittest/test_arith_stmt_simplify.py index 2de4ee5bc735..272893e20c12 100644 --- a/tests/python/unittest/test_arith_stmt_simplify.py +++ b/tests/python/unittest/test_arith_stmt_simplify.py @@ -47,6 +47,62 @@ def test_thread_extent_simplify(): assert isinstance(body.body.body.body, tvm.stmt.Store) +def test_basic_likely_elimination(): + n = tvm.var('n') + X = tvm.placeholder(shape=(n,), name="x") + W = tvm.placeholder(shape=(n + 1,), dtype="int32", name="w") + + def f(i): + start = W[i] + extent = W[i+1] - W[i] + rv = tvm.reduce_axis((0, extent)) + return tvm.sum(X[rv + start], axis=rv) + Y = tvm.compute(X.shape, f, name="y") + s = tvm.create_schedule([Y.op]) + stmt = tvm.lower(s, [X, W, Y], simple_mode=True) + assert('if' not in str(stmt)) + +def test_complex_likely_elimination(): + def cumsum(X): + """ + Y[i] = sum(X[:i]) + """ + (m, ) = X.shape + s_state = tvm.placeholder((m + 1, ), dtype="int32", name="state") + s_init = tvm.compute((1, ), lambda _: tvm.const(0, "int32")) + s_update = tvm.compute((m + 1, ), lambda l: s_state[l - 1] + X[l - 1]) + return tvm.scan(s_init, s_update, s_state, inputs=[X], name="cumsum") + + def sparse_lengths_sum(data, indices, lengths): + oshape = list(data.shape) + oshape[0] = lengths.shape[0] + length_offsets = cumsum(lengths) + + def sls(n, d): + gg = tvm.reduce_axis((0, lengths[n])) + indices_idx = length_offsets[n] + gg + data_idx = indices[indices_idx] + data_val = data[data_idx, d] + return tvm.sum(data_val, axis=gg) + + return tvm.compute(oshape, sls) + + m, n, d, i, l = tvm.var('m'), tvm.var('n'), tvm.var('d'), tvm.var('i'), tvm.var('l') + data_ph = tvm.placeholder((m, d * 32), name="data") + indices_ph = tvm.placeholder((i,), name="indices", dtype="int32") + lengths_ph = tvm.placeholder((n,), name="lengths", dtype="int32") + Y = sparse_lengths_sum(data_ph, indices_ph, lengths_ph) + s = tvm.create_schedule([Y.op]) + (n, d) = s[Y].op.axis + (do, di) = s[Y].split(d, factor=32) + (gg,) = s[Y].op.reduce_axis + s[Y].reorder(n, do, gg, di) + s[Y].vectorize(di) + stmt = tvm.lower(s, [data_ph, indices_ph, lengths_ph, Y], simple_mode=True) + assert('if' not in str(stmt)) + if __name__ == "__main__": test_stmt_simplify() test_thread_extent_simplify() + test_basic_likely_elimination() + test_complex_likely_elimination()