From 0a62b7fe72406048526b2ecfdda240d9fc028e6c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 8 Apr 2023 20:12:20 -0400 Subject: [PATCH] [Cherry-Pick][ARITH] Enhance CanProve to handle symbolic bound (#14523) (#175) This PR enhances CanProve to handle symbolic bound. Such analysis is essential to eliminate predicates in dynamic shape workloads. We also the int set analysis singlepoint check to avoid recursion and improve the overall analysis speed. Added CanProveSinglePoint to serve previous stronger checks. The new CanProve comes with additinal strength argument that can only be used in top-level setting with stronger analysis. Added comment for future implementation efficiency. Testcases are added to cover the cases. Co-authored-by: Tianqi Chen --- include/tvm/arith/analyzer.h | 23 +++++++++- include/tvm/arith/int_set.h | 16 +++++++ python/tvm/arith/__init__.py | 2 +- python/tvm/arith/analyzer.py | 27 ++++++++++++ src/arith/analyzer.cc | 43 +++++++++++++++++-- src/arith/int_set.cc | 24 +++++++++-- src/arith/interval_set.h | 9 ++-- src/arith/rewrite_simplify.cc | 8 ++++ src/arith/rewrite_simplify.h | 1 - .../analysis/block_access_region_detector.cc | 4 +- src/tir/schedule/primitive/compute_at.cc | 7 +-- .../schedule/primitive/loop_transformation.cc | 2 +- tests/python/unittest/test_arith_simplify.py | 31 +++++++++++++ 13 files changed, 178 insertions(+), 19 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 885c23f49186..e64426aca3db 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -59,6 +59,22 @@ enum DivMode { kFloorDiv }; +/*! + * \brief The strength used in top-level condition proves + * \note The higher, the more time consuming it can be. + * + * Do not use level beyond kDefault in internal recursive rewriting in arith + * analysis and only use it at top-level simplification to avoid speed issues. + */ +enum class ProofStrength : int { + /*! \brief default strength, can be used in. */ + kDefault = 0, + /*! + * \brief Prove using symbolic bound analysis + */ + kSymbolicBound = 1 +}; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. @@ -656,11 +672,16 @@ class TVM_DLL Analyzer { * \brief Whether can we prove condition. * * \param cond The expression to be proved. + * \param strength the strength of the prove. + * * \return The result. * * \note Analyzer will call into sub-analyzers to get the result. + * Do not use strength beyond default in sub-analyzers and + * only use it in top-level predicate analysis. */ - bool CanProve(const PrimExpr& cond); + bool CanProve(const PrimExpr& cond, ProofStrength strength = ProofStrength::kDefault); + /*! * \brief Simplify expr. * diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 60d7c53d28e8..f09564d050ca 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -85,6 +85,22 @@ class IntSet : public ObjectRef { bool IsEverything() const; /*! \return Whether the set is a single point */ bool IsSinglePoint() const; + /*! + * \brief Check if we can prove it is a single point. + * + * Unlike IsSinglePoint, which only checks ptr equality + * this function will invoke analyzer to do stonger proofs + * but also takes longer time. + * + * Use this function in some of the primitives but do not + * use it in the inner loop of simplification. + * + * \param ana Analyzer used in the proof. + * \return Whether we can prove it is a single point + */ + bool CanProveSinglePoint(Analyzer* ana) const; + // TODO(tvm-team): update all CanProve to explicitly take + // analyzer to encourage more analyzer reuse /*! \return Whether the set is proved to be bigger than 0 */ bool CanProvePositive() const; /*! \return Whether the set is proved to be smaller than 0 */ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 423aafe5d69f..401836aa1968 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -23,7 +23,7 @@ estimate_region_strict_bound, estimate_region_upper_bound, ) -from .analyzer import ModularSet, ConstIntBound, Analyzer +from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 28adbe9d815f..5ea2dfad9dc6 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -15,11 +15,19 @@ # specific language governing permissions and limitations # under the License. """Arithmetic data structure and utility""" +from enum import IntEnum import tvm._ffi from tvm.runtime import Object from . import _ffi_api +class ProofStrength(IntEnum): + """Proof strength of the analysis""" + + DEFAULT = 0 + SYMBOLIC_BOUND = 1 + + @tvm._ffi.register_object("arith.ModularSet") class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" @@ -91,6 +99,7 @@ def __init__(self): self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") self._can_prove_equal = _mod("can_prove_equal") + self._can_prove = _mod("can_prove") def const_int_bound(self, expr): """Find constant integer bound for expr. @@ -190,6 +199,24 @@ def int_set(self, expr, dom_map): """ return self._int_set(expr, dom_map) + def can_prove(self, expr, strength=ProofStrength.DEFAULT): + """Check whether we can prove expr to be true. + + Parameters + ---------- + expr : PrimExpr + The expression. + + strength: ProofStrength + The proof strength + + Returns + ------- + result : Expr + The result. + """ + return self._can_prove(expr, strength) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 4714cf1df59f..89dcb8301a1b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -115,15 +115,47 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { return CanProve(lhs - rhs == 0); } -bool Analyzer::CanProve(const PrimExpr& expr) { +bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // Avoid potentially expensive simplification unless required. if (const auto* ptr = expr.as()) { return ptr->value != 0; } - PrimExpr simplified = Simplify(expr); const int64_t* as_int = tir::as_const_int(simplified); - return as_int && *as_int; + if (as_int && *as_int) return true; + if (strength >= ProofStrength::kSymbolicBound) { + // NOTE: we intentionally only pattern match common bound predicate i < bound + // and put this implementation at the top-level. + // This is to avoid repeatitive calling of this function + // that causes speed issues. + // This strategy can only be called from top-level and not from sub-analyzers. + Optional pos_diff; + int lower_bound = 0; + if (const auto* ptr_lt = expr.as()) { + pos_diff = ptr_lt->b - ptr_lt->a; + lower_bound = 1; + } + if (const auto* ptr_le = expr.as()) { + pos_diff = ptr_le->b - ptr_le->a; + lower_bound = 0; + } + if (const auto* ptr_gt = expr.as()) { + pos_diff = ptr_gt->a - ptr_gt->b; + lower_bound = 1; + } + if (const auto* ptr_ge = expr.as()) { + pos_diff = ptr_ge->a - ptr_ge->b; + lower_bound = 0; + } + if (pos_diff) { + IntSet iset = this->int_set(this->Simplify(pos_diff.value())); + if (iset.HasLowerBound()) { + ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min())); + if (relaxed_lower_bound->min_value >= lower_bound) return true; + } + } + } + return false; } PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { @@ -189,6 +221,11 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu self->Bind(args[0], args[1].operator PrimExpr()); } }); + } else if (name == "can_prove") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + int strength = args[1]; + *ret = self->CanProve(args[0], static_cast(strength)); + }); } else if (name == "enter_constraint_context") { return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { // can't use make_shared due to noexcept(false) decl in destructor, diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a75d316a7ece..1ad182aa8351 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -492,6 +492,11 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const CastNode* op) final { IntervalSet value_set = this->Eval(op->value); + // short cut for the int set. + if (value_set->min_value.same_as(value_set->max_value)) { + if (value_set->IsEmpty()) return value_set; + return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value)); + } PrimExpr min_value = value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); PrimExpr max_value = @@ -723,6 +728,13 @@ bool IntSet::IsSinglePoint() const { return (s_int && s_int->IsSinglePoint()); } +bool IntSet::CanProveSinglePoint(Analyzer* ana) const { + const IntervalSetNode* s_int = (*this).as(); + if (!s_int) return false; + if (s_int->IsSinglePoint()) return true; + return ana->CanProveEqual(s_int->min_value, s_int->max_value); +} + bool IntSet::CanProvePositive() const { Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); @@ -943,9 +955,15 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map) { } IntSet IntSet::Vector(PrimExpr x) { - Analyzer ana; - Map dmap; - return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); + // short cut: simply get single point + if (x.dtype().lanes() == 1) { + return IntSet::SinglePoint(x); + } else { + // vector case. + Analyzer ana; + Map dmap; + return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); + } } IntSet EvalSet(PrimExpr e, const Map& dom_map) { diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 98fe5bdc2bc6..dc40fa9d4dee 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -60,12 +60,11 @@ class IntervalSetNode : public IntSetNode { bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); } /*! \return Whether the interval is a single point. */ bool IsSinglePoint() const { - if (min_value.same_as(max_value)) { - return true; - } - Analyzer analyzer; - return analyzer.CanProveEqual(min_value, max_value); + // NOTE: we are only doing cheap check as this is a frequently called routine, + // do manual prove of min and max for stronger single point check. + return min_value.same_as(max_value); } + /*! \return whether interval represent nothing */ bool IsEmpty() const { // during computations, either extreme could occur. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 0b646ab3205a..a6b71d724178 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -150,6 +150,12 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const Pr // try to prove x equals val CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) { + // NOTE on implementation: this function can be called many times and can be a bottleneck, + // As a result, we keep comparison here lightweight. + // We only do constant int bound analysis here. + // + // For stronger comparison proof that is out of the recursive simplifcation + // consider look at analyzer::CanProveStrong PrimExpr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { @@ -176,6 +182,8 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val if (dbound->max_value <= val) { return CompareResult::kLE; } + + // modular analysis if (val == 0) { ModularSet dmod = analyzer_->modular_set(diff); if (dmod->base != 0) { diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b8e7fcdd9433..22e7a0b74c40 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -104,7 +104,6 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // maximum number of recursion allowed during a single pass. static const constexpr int kMaxRecurDepth = 5; - /*! * \brief try to compare x against val. * \param x The expression to be evaluated. diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 409356c2b155..057cec475d84 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -76,6 +76,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; + /*!\ brief Internal analyzer. */ + arith::Analyzer ana_; /*! * \brief Update read/write buffers and regions with provided buffer and region @@ -318,7 +320,7 @@ Array BlockReadWriteDetector::CollectRegions( ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - if (range.IsSinglePoint()) { + if (range.CanProveSinglePoint(&ana_)) { PrimExpr min = range.min(); region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); } else { diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 988c73c3f071..75ea308de8a3 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -455,8 +455,9 @@ void UpdateBlockVarDomainDimwise( arith::IntSet required = required_region[i]; PrimExpr dim_max = max(buffer->shape[i] - 1, 0); - if (provided.IsSinglePoint() && is_const_int(provided.min())) { - ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min())); + if (provided.CanProveSinglePoint(analyzer) && is_const_int(provided.min())) { + ICHECK(required.CanProveSinglePoint(analyzer) && + analyzer->CanProveEqual(provided.min(), required.min())); continue; } @@ -515,7 +516,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array& std::unordered_map* iter_doms) { // we only support single point provided region now, which could cover most cases for (const auto& intset : provided_region) { - if (!intset.IsSinglePoint()) return false; + if (!intset.CanProveSinglePoint(analyzer)) return false; } // calculate forward mapping (block vars -> provided region point) Map dom_map; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 992817e87e2d..ac7be0baeb8b 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -430,7 +430,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array &opaque_block_reuse)(std::move(new_stmt)); // Step 3. Update predicate to guard the loop PrimExpr predicate = substitute_value < loop->extent; - if (!analyzer.CanProve(predicate)) { + if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); } // Step 4. Generate nested loops to replace the original loop and simplify the binding diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index aa9d5179aa3f..754bf36d7ab2 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -34,5 +34,36 @@ def test_simplify_reshape_flattened_index(): ) +def test_simplify_symbolic_comparison(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64") + outer = (n + 31) // 32 + ana.bind(i0, tvm.ir.Range(0, outer)) + ana.bind(i1, tvm.ir.Range(0, 32)) + PS = tvm.arith.ProofStrength + + assert not ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.DEFAULT) + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, PS.SYMBOLIC_BOUND) + assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, PS.SYMBOLIC_BOUND) + assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND) + + +def test_regression_simplify_inf_recursion(): + ana = tvm.arith.Analyzer() + cond = tir.Var("cond", "int32") + + res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype( + "int32" + ) == 0 + # regression in a previous case + # try compare and int set recursive call can cause infinite loop + ana.rewrite_simplify(res) + + if __name__ == "__main__": tvm.testing.main()