From a7eab357807e55c5df0865fa7c78bb48a8d7cd35 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 6 Apr 2023 19:25:11 -0400 Subject: [PATCH] [ARITH] Enhance CanProve to handle symbolic bound This PR enhances CanProve to handle symbolic bound. Such analysis is essential to eliminate predicates in dynamic shape workloads. We also the int set analysis singlepoint check to avoid recursion and improve the overall analysis speed. Added CanProveSinglePoint to serve previous stronger checks. The new CanProve comes with additinal strength argument that can only be used in top-level setting with stronger analysis. Added comment for future implementation efficiency. Testcases are added to cover the cases. --- include/tvm/arith/analyzer.h | 23 +++++++++- include/tvm/arith/int_set.h | 16 +++++++ python/tvm/arith/__init__.py | 2 +- python/tvm/arith/analyzer.py | 27 ++++++++++++ src/arith/analyzer.cc | 43 +++++++++++++++++-- src/arith/int_set.cc | 24 +++++++++-- src/arith/interval_set.h | 9 ++-- src/arith/rewrite_simplify.cc | 8 ++++ src/arith/rewrite_simplify.h | 5 ++- .../analysis/block_access_region_detector.cc | 4 +- src/tir/schedule/primitive/compute_at.cc | 7 +-- .../schedule/primitive/loop_transformation.cc | 2 +- tests/python/unittest/test_arith_simplify.py | 31 +++++++++++++ 13 files changed, 182 insertions(+), 19 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 885c23f491862..e64426aca3db1 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -59,6 +59,22 @@ enum DivMode { kFloorDiv }; +/*! + * \brief The strength used in top-level condition proves + * \note The higher, the more time consuming it can be. + * + * Do not use level beyond kDefault in internal recursive rewriting in arith + * analysis and only use it at top-level simplification to avoid speed issues. + */ +enum class ProofStrength : int { + /*! \brief default strength, can be used in. */ + kDefault = 0, + /*! + * \brief Prove using symbolic bound analysis + */ + kSymbolicBound = 1 +}; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. @@ -656,11 +672,16 @@ class TVM_DLL Analyzer { * \brief Whether can we prove condition. * * \param cond The expression to be proved. + * \param strength the strength of the prove. + * * \return The result. * * \note Analyzer will call into sub-analyzers to get the result. + * Do not use strength beyond default in sub-analyzers and + * only use it in top-level predicate analysis. */ - bool CanProve(const PrimExpr& cond); + bool CanProve(const PrimExpr& cond, ProofStrength strength = ProofStrength::kDefault); + /*! * \brief Simplify expr. * diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 60d7c53d28e84..f09564d050caa 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -85,6 +85,22 @@ class IntSet : public ObjectRef { bool IsEverything() const; /*! \return Whether the set is a single point */ bool IsSinglePoint() const; + /*! + * \brief Check if we can prove it is a single point. + * + * Unlike IsSinglePoint, which only checks ptr equality + * this function will invoke analyzer to do stonger proofs + * but also takes longer time. + * + * Use this function in some of the primitives but do not + * use it in the inner loop of simplification. + * + * \param ana Analyzer used in the proof. + * \return Whether we can prove it is a single point + */ + bool CanProveSinglePoint(Analyzer* ana) const; + // TODO(tvm-team): update all CanProve to explicitly take + // analyzer to encourage more analyzer reuse /*! \return Whether the set is proved to be bigger than 0 */ bool CanProvePositive() const; /*! \return Whether the set is proved to be smaller than 0 */ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 423aafe5d69f4..401836aa19683 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -23,7 +23,7 @@ estimate_region_strict_bound, estimate_region_upper_bound, ) -from .analyzer import ModularSet, ConstIntBound, Analyzer +from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 28adbe9d815f5..5ea2dfad9dc65 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -15,11 +15,19 @@ # specific language governing permissions and limitations # under the License. """Arithmetic data structure and utility""" +from enum import IntEnum import tvm._ffi from tvm.runtime import Object from . import _ffi_api +class ProofStrength(IntEnum): + """Proof strength of the analysis""" + + DEFAULT = 0 + SYMBOLIC_BOUND = 1 + + @tvm._ffi.register_object("arith.ModularSet") class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" @@ -91,6 +99,7 @@ def __init__(self): self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") self._can_prove_equal = _mod("can_prove_equal") + self._can_prove = _mod("can_prove") def const_int_bound(self, expr): """Find constant integer bound for expr. @@ -190,6 +199,24 @@ def int_set(self, expr, dom_map): """ return self._int_set(expr, dom_map) + def can_prove(self, expr, strength=ProofStrength.DEFAULT): + """Check whether we can prove expr to be true. + + Parameters + ---------- + expr : PrimExpr + The expression. + + strength: ProofStrength + The proof strength + + Returns + ------- + result : Expr + The result. + """ + return self._can_prove(expr, strength) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 4714cf1df59fe..89dcb8301a1b7 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -115,15 +115,47 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { return CanProve(lhs - rhs == 0); } -bool Analyzer::CanProve(const PrimExpr& expr) { +bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // Avoid potentially expensive simplification unless required. if (const auto* ptr = expr.as()) { return ptr->value != 0; } - PrimExpr simplified = Simplify(expr); const int64_t* as_int = tir::as_const_int(simplified); - return as_int && *as_int; + if (as_int && *as_int) return true; + if (strength >= ProofStrength::kSymbolicBound) { + // NOTE: we intentionally only pattern match common bound predicate i < bound + // and put this implementation at the top-level. + // This is to avoid repeatitive calling of this function + // that causes speed issues. + // This strategy can only be called from top-level and not from sub-analyzers. + Optional pos_diff; + int lower_bound = 0; + if (const auto* ptr_lt = expr.as()) { + pos_diff = ptr_lt->b - ptr_lt->a; + lower_bound = 1; + } + if (const auto* ptr_le = expr.as()) { + pos_diff = ptr_le->b - ptr_le->a; + lower_bound = 0; + } + if (const auto* ptr_gt = expr.as()) { + pos_diff = ptr_gt->a - ptr_gt->b; + lower_bound = 1; + } + if (const auto* ptr_ge = expr.as()) { + pos_diff = ptr_ge->a - ptr_ge->b; + lower_bound = 0; + } + if (pos_diff) { + IntSet iset = this->int_set(this->Simplify(pos_diff.value())); + if (iset.HasLowerBound()) { + ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min())); + if (relaxed_lower_bound->min_value >= lower_bound) return true; + } + } + } + return false; } PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { @@ -189,6 +221,11 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu self->Bind(args[0], args[1].operator PrimExpr()); } }); + } else if (name == "can_prove") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + int strength = args[1]; + *ret = self->CanProve(args[0], static_cast(strength)); + }); } else if (name == "enter_constraint_context") { return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { // can't use make_shared due to noexcept(false) decl in destructor, diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a75d316a7ece4..b9b829c1e5c5a 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -492,6 +492,11 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const CastNode* op) final { IntervalSet value_set = this->Eval(op->value); + // short cut for the int set. + if (value_set->min_value.same_as(value_set->max_value)) { + if (value_set->IsEmpty()) return value_set; + return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value)); + } PrimExpr min_value = value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); PrimExpr max_value = @@ -723,6 +728,13 @@ bool IntSet::IsSinglePoint() const { return (s_int && s_int->IsSinglePoint()); } +bool IntSet::CanProveSinglePoint(Analyzer* ana) const { + const IntervalSetNode* s_int = (*this).as(); + if (!s_int) return false; + if (s_int->IsSinglePoint()) return true; + return ana->CanProveEqual(s_int->min_value, s_int->max_value); +} + bool IntSet::CanProvePositive() const { Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); @@ -943,9 +955,15 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map) { } IntSet IntSet::Vector(PrimExpr x) { - Analyzer ana; - Map dmap; - return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); + // short cutm simply get single point + if (x.dtype().lanes() == 1) { + return IntSet::SinglePoint(x); + } else { + // vector case. + Analyzer ana; + Map dmap; + return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); + } } IntSet EvalSet(PrimExpr e, const Map& dom_map) { diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 98fe5bdc2bc6f..dc40fa9d4deee 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -60,12 +60,11 @@ class IntervalSetNode : public IntSetNode { bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); } /*! \return Whether the interval is a single point. */ bool IsSinglePoint() const { - if (min_value.same_as(max_value)) { - return true; - } - Analyzer analyzer; - return analyzer.CanProveEqual(min_value, max_value); + // NOTE: we are only doing cheap check as this is a frequently called routine, + // do manual prove of min and max for stronger single point check. + return min_value.same_as(max_value); } + /*! \return whether interval represent nothing */ bool IsEmpty() const { // during computations, either extreme could occur. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 40a5977ec54ca..c9acc8f751a68 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -150,6 +150,12 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const Pr // try to prove x equals val CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) { + // NOTE on implementation: this function can be called many times and can be a bottleneck, + // As a result, we keep comparison here lightweight. + // We only do constant int bound analysis here. + // + // For stronger comparison proof that is out of the recursive simplifcation + // consider look at analyzer::CanProveStrong PrimExpr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { @@ -176,6 +182,8 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val if (dbound->max_value <= val) { return CompareResult::kLE; } + + // modular analysis if (val == 0) { ModularSet dmod = analyzer_->modular_set(diff); if (dmod->base != 0) { diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b8e7fcdd94337..1933c0d32a12b 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -90,6 +90,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { protected: // counter to record recursive rewrite depth. int recur_depth_{0}; + // counter to record recursive comparison depth that invokes set analysis + int symbolic_set_eval_depth_{0}; // internal variable map std::unordered_map var_map_; @@ -104,7 +106,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // maximum number of recursion allowed during a single pass. static const constexpr int kMaxRecurDepth = 5; - + // maximum number of set eval recursion allowed during a single pass. + static const constexpr int kMaxSymbolicSetEvalDepth = 1; /*! * \brief try to compare x against val. * \param x The expression to be evaluated. diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 409356c2b1553..057cec475d84f 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -76,6 +76,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; + /*!\ brief Internal analyzer. */ + arith::Analyzer ana_; /*! * \brief Update read/write buffers and regions with provided buffer and region @@ -318,7 +320,7 @@ Array BlockReadWriteDetector::CollectRegions( ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - if (range.IsSinglePoint()) { + if (range.CanProveSinglePoint(&ana_)) { PrimExpr min = range.min(); region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); } else { diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 988c73c3f0711..75ea308de8a34 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -455,8 +455,9 @@ void UpdateBlockVarDomainDimwise( arith::IntSet required = required_region[i]; PrimExpr dim_max = max(buffer->shape[i] - 1, 0); - if (provided.IsSinglePoint() && is_const_int(provided.min())) { - ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min())); + if (provided.CanProveSinglePoint(analyzer) && is_const_int(provided.min())) { + ICHECK(required.CanProveSinglePoint(analyzer) && + analyzer->CanProveEqual(provided.min(), required.min())); continue; } @@ -515,7 +516,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& std::unordered_map* iter_doms) { // we only support single point provided region now, which could cover most cases for (const auto& intset : provided_region) { - if (!intset.IsSinglePoint()) return false; + if (!intset.CanProveSinglePoint(analyzer)) return false; } // calculate forward mapping (block vars -> provided region point) Map dom_map; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index d9c58a0381035..a26843b7bd05b 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -430,7 +430,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array &opaque_block_reuse)(std::move(new_stmt)); // Step 3. Update predicate to guard the loop PrimExpr predicate = substitute_value < loop->extent; - if (!analyzer.CanProve(predicate)) { + if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); } // Step 4. Generate nested loops to replace the original loop and simplify the binding diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index aa9d5179aa3f8..754bf36d7ab28 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -34,5 +34,36 @@ def test_simplify_reshape_flattened_index(): ) +def test_simplify_symbolic_comparison(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64") + outer = (n + 31) // 32 + ana.bind(i0, tvm.ir.Range(0, outer)) + ana.bind(i1, tvm.ir.Range(0, 32)) + PS = tvm.arith.ProofStrength + + assert not ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.DEFAULT) + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND) + + +def test_regression_simplify_inf_recursion(): + ana = tvm.arith.Analyzer() + cond = tir.Var("cond", "int32") + + res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype( + "int32" + ) == 0 + # regression in a previous case + # try compare and int set recursive call can cause infinite loop + ana.rewrite_simplify(res) + + if __name__ == "__main__": tvm.testing.main()