From 1ec2c369128c9d57bb09087ab16cb3a2527dd9de Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 24 Aug 2022 17:44:22 +0800 Subject: [PATCH] [TIR][CompactBufferAllocation] Improve upperbound estimation of buffer compaction (#12527) Hi, this change makes some minor updates to the region estimator used by buffer compaction: - Add and clarify the distinction among `EstimateRegionStrictBound`, `EstimateRegionLowerBound` and `EstimateRegionUpperBound`. Originally we only had `EstimateRegionLowerBound`, which actually implements strict bound estimation IMO. Now add `upper` and `strict` versions for the places where we actually want them. - When estimating upper bounds (e.g. in buffer compaction), try to estimate each dimension independently when the accesses are dependent, where `EstimateRegionLowerBound` is expected to fail. E.g., `A[i, i], 3 < i < 16` fails via `EstimateRegionLowerBound`, which checks that the indices are independent. But we can still make a best effort by invoking strict bound analysis on each dimension individually. - If range->extent == 1 for `EvalSet(range, dom)`, invoke `EvalSet(range->min, dom)` instead. E.g., `EvalSet([k*k, k*k+1), dom_k)` results in [-inf, +inf] due to a limitation of the current algorithm, but `EvalSet(k*k, dom_k)` results in a range, which makes more sense. 
--- include/tvm/arith/int_set.h | 39 +- python/tvm/arith/__init__.py | 8 +- python/tvm/arith/int_set.py | 48 +++ src/arith/int_set.cc | 131 +++++-- src/tir/schedule/primitive/compute_at.cc | 2 +- src/tir/schedule/state.cc | 14 +- src/tir/schedule/utils.h | 18 - src/tir/transforms/compact_buffer_region.cc | 2 +- tests/python/unittest/test_arith_intset.py | 354 ++++++++++-------- ...est_tir_transform_compact_buffer_region.py | 100 +++++ 10 files changed, 496 insertions(+), 220 deletions(-) diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 7cc4efe6b012..5ef7108d9797 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -261,7 +261,29 @@ Array UnionRegionLowerBound(const Array>& nd_int_sets); IntSet Intersect(const Array& sets); /*! - * \brief Analyze the region with affine map, given the domain of variables and their predicate + * \brief Converts the Ranges to IntSets + * \param var_dom The ranges of variables + * \return The integer sets of the variables + */ +Map AsIntSet(const Map& var_dom); + +/*! + * \brief Analyze the region with affine map, given the domain of variables and their predicate. + * The result should be strict, i.e. no region is discarded or relaxed. + * \param region The region to be analyzed + * \param var_dom The ranges of the variables + * \param predicate The predicate for the affine map + * \param analyzer The analyzer used + * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis + */ +TVM_DLL Optional> EstimateRegionStrictBound(const Array& region, + const Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer); + +/*! + * \brief Analyze the region with affine map, given the domain of variables and their predicate. + * Some subregion may be discarded during the lower-bound analysis. 
* \param region The region to be analyzed * \param var_dom The ranges of the variables * \param predicate The predicate for the affine map @@ -273,6 +295,21 @@ TVM_DLL Optional> EstimateRegionLowerBound(const Array& reg const PrimExpr& predicate, arith::Analyzer* analyzer); +/*! + * \brief Analyze the region with affine map, given the domain of variables and their predicate + * Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added + * to the result. + * \param region The region to be analyzed + * \param var_dom The ranges of the variables + * \param predicate The predicate for the affine map + * \param analyzer The analyzer used + * \return an array of arith::IntSet as the result of analysis + */ +TVM_DLL Array EstimateRegionUpperBound(const Array& region, + const Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_INT_SET_H_ diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index f5a0478dc008..03c0769850c9 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -16,7 +16,13 @@ # under the License. 
"""Integer bound analysis, simplification and pattern detection.""" -from .int_set import IntSet, IntervalSet, estimate_region_lower_bound +from .int_set import ( + IntSet, + IntervalSet, + estimate_region_lower_bound, + estimate_region_strict_bound, + estimate_region_upper_bound, +) from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index b5f2100b7c7d..151461bcaf9f 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -83,6 +83,7 @@ def __init__(self, min_value, max_value): def estimate_region_lower_bound(region, var_dom, predicate): """Analyze the region with affine map, given the domain of variables and their predicate + Some subregion may be discarded during the lower-bound analysis. Parameters ---------- @@ -103,6 +104,53 @@ def estimate_region_lower_bound(region, var_dom, predicate): return _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate) +def estimate_region_strict_bound(region, var_dom, predicate): + """Analyze the region with affine map, given the domain of variables and their predicate + The result should be strict, i.e. no region is discarded or relaxed. + + Parameters + ---------- + region : List[Range] + The region to be analyzed. + + var_dom : Dict[Var, Range] + The ranges of the variables + + predicate : PrimExpr + The predicate for the affine map + + Returns + ---------- + region_int_set : Optional[List[IntSet]] + None if the detection fails, or an array of IntSets as the result of analysis + """ + return _ffi_api.EstimateRegionStrictBound(region, var_dom, predicate) + + +def estimate_region_upper_bound(region, var_dom, predicate): + """Analyze the region with affine map, given the domain of variables and their predicate + Relaxation of the region may be used in upper-bound analysis, + i.e. some extra region may be added to the result. 
+ + Parameters + ---------- + region : List[Range] + The region to be analyzed. + + var_dom : Dict[Var, Range] + The ranges of the variables + + predicate : PrimExpr + The predicate for the affine map + + Returns + ---------- + region_int_set : List[IntSet] + an array of IntSets as the result of analysis + """ + return _ffi_api.EstimateRegionUpperBound(region, var_dom, predicate) + + def pos_inf(): """Returns the symbolic positive infinity diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 584bbe8f04ea..e8e223ceca09 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -975,6 +975,9 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; + if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) { + return EvalSet(r->min, dom_map); + } IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; @@ -1035,15 +1038,57 @@ IntSet EvalSet(Range r, const Map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -Optional> EstimateRegionLowerBound(const Array& region, - const Map& var_dom, - const PrimExpr& predicate, Analyzer* analyzer) { +Map AsIntSet(const Map& var_dom) { + Map result; + for (auto kv : var_dom) { + const Var& var = kv.first; + const Range& range = kv.second; + result.Set(var, arith::IntSet::FromRange(range)); + } + return result; +} + +/*! \brief Helper function to convert IterSumExpr to the actual touched range. 
*/ +static Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, + Analyzer* analyzer) { + if (iter_min->args.empty()) { + return IntSet::FromMinExtent(iter_min->base, extent); + } + ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr"; + const IterSplitExpr& split = iter_min->args[0]; + if (!analyzer->CanProve(extent >= split->scale)) { + return NullOpt; + } + + const PrimExpr& base = iter_min->base; + // IterSplitExpr: (source // lower_factor) % extent * scale + // where `(source // lower_factor) % extent` is within [0, extent - 1] + if (analyzer->CanProve(split->scale < 0)) { + // If scale is negative, the var dom is [(extent - 1) * scale, 0] + // The total base is `base + (extent - 1) * scale`, + // while total extent is `dom_extent + (extent - 1) * (-scale)` + const PrimExpr& var_extent = (split->extent - 1) * split->scale; + return IntSet::FromMinExtent(base + var_extent, extent - var_extent); + } else { + // If scale is positive, the var dom is [0, (extent - 1) * scale] + // The total dom is [base, dom_extent + (extent - 1) * scale] + return IntSet::FromMinExtent(base, extent + (split->extent - 1) * split->scale); + } +} + +Optional> EstimateRegionStrictBound(const Array& region, + const Map& var_dom, + const PrimExpr& predicate, Analyzer* analyzer) { int ndim = region.size(); Array iter_sum_exprs{nullptr}; { Array affine_indices; affine_indices.reserve(ndim); for (const Range& range : region) { + if (!is_const_number(range->extent)) { + // dynamic extent is not supported yet. 
+ return NullOpt; + } affine_indices.push_back(range->min); } auto res = DetectIterMap( @@ -1060,31 +1105,57 @@ Optional> EstimateRegionLowerBound(const Array& region, for (int i = 0; i < ndim; ++i) { const IterSumExpr& sum_expr = iter_sum_exprs[i]; const Range& range = region[i]; - if (sum_expr->args.empty()) { - result.push_back(IntSet::FromMinExtent(sum_expr->base, range->extent)); - continue; - } - ICHECK_EQ(sum_expr->args.size(), 1); - const IterSplitExpr& split = sum_expr->args[0]; - if (!analyzer->CanProve(range->extent >= split->scale)) { + Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer); + if (int_set.defined()) { + result.push_back(int_set.value()); + } else { return NullOpt; } + } + return result; +} - const PrimExpr& base = sum_expr->base; - // IterSplitExpr: (source // lower_factor) % extent * scale - // where `(source // lower_factor) % extent` is within [0, extent - 1] - if (analyzer->CanProve(split->scale < 0)) { - // If scale is negative, the var dom is [(extent - 1) * scale, 0] - // The total base is `base + (extent - 1) * scale`, - // while total extent is `dom_extent + (extent - 1) * (-scale)` - const PrimExpr& var_extent = (split->extent - 1) * split->scale; - result.push_back(IntSet::FromMinExtent(base + var_extent, range->extent - var_extent)); - } else { - // If scale is positive, the var dom is [0, (extent - 1) * scale] - // The total dom is [base, dom_extent + (extent - 1) * scale] - result.push_back( - IntSet::FromMinExtent(base, range->extent + (split->extent - 1) * split->scale)); +Optional> EstimateRegionLowerBound(const Array& region, + const Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer) { + return EstimateRegionStrictBound(region, var_dom, predicate, analyzer); +} + +Array EstimateRegionUpperBound(const Array& region, const Map& var_dom, + const PrimExpr& predicate, Analyzer* analyzer) { + if (Optional> result = EstimateRegionStrictBound( + /*region=*/region, + /*var_dom=*/var_dom, + 
/*predicate=*/predicate, /*analyzer=*/analyzer)) { + return result.value(); + } + Array result; + result.reserve(region.size()); + // try estimate each dimension independently + for (const Range& range : region) { + auto res = DetectIterMap( + /*indices=*/{range->min}, /*input_iters=*/var_dom, + /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); + if (!res->indices.empty()) { + ICHECK_EQ(res->indices.size(), 1U); + IterSumExpr sum_expr = res->indices[0]; + + // dynamic extent is not supported yet. + PrimExpr extent = range->extent; + if (!is_const_number(extent)) { + IntSet relaxed = EvalSet(extent, AsIntSet(var_dom)); + ICHECK(relaxed.HasUpperBound()); + extent = relaxed.max(); + } + + if (Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { + result.push_back(int_set.value()); + continue; + } } + // fallback to coarse grained evalset + result.push_back(EvalSet(range, AsIntSet(var_dom))); } return result; } @@ -1118,6 +1189,18 @@ TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound") Analyzer analyzer; return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); }); +TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound") + .set_body_typed([](Array region, Map var_dom, + PrimExpr predicate) -> Optional> { + Analyzer analyzer; + return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer); + }); +TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound") + .set_body_typed([](Array region, Map var_dom, + PrimExpr predicate) -> Optional> { + Analyzer analyzer; + return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); + }); TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; }); TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; }); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 7b0d749f03dc..98a6b2400ee3 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ 
b/src/tir/schedule/primitive/compute_at.cc @@ -356,7 +356,7 @@ void RelaxBufferRegions(const Map& binding, runtime::StorageRank rank = scope.rank; if (rank != previous_rank || !var_dom.defined()) { previous_rank = rank; - var_dom = AsIntSet(LoopDomainOfSRefTreePath( + var_dom = arith::AsIntSet(LoopDomainOfSRefTreePath( /*low_inclusive=*/relax_path_low_inclusive, /*high_exclusive=*/relax_path_high_exclusive, /*extra_relax_scope=*/scope)); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index dadabba48540..07481ddb19e3 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -16,8 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -#include "./utils.h" +#include +#include "./utils.h" namespace tvm { namespace tir { @@ -44,13 +45,10 @@ Array AnalyzeRegionUpperBound(const BufferRegion& region, /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); - if (Optional> result = EstimateRegionLowerBound( - /*region=*/region->region, - /*var_dom=*/var_dom, - /*predicate=*/predicate, /*analyzer=*/analyzer)) { - return result.value(); - } - return arith::EvalSet(region->region, AsIntSet(var_dom)); + return EstimateRegionUpperBound( + /*region=*/region->region, + /*var_dom=*/var_dom, + /*predicate=*/predicate, /*analyzer=*/analyzer); } /*! diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 53cafa798b54..3db80989ae10 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -249,24 +249,6 @@ inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) { return thread_scope.rank == 1 && thread_scope.dim_index >= 0; } -/******** Integer set ********/ - -/*! 
- * \brief Converts the Ranges to IntSets - * \param var_dom The ranges of variables - * \return The integer sets of the variables - */ -inline Map AsIntSet(const Map& var_dom) { - std::unordered_map result; - result.reserve(var_dom.size()); - for (auto kv : var_dom) { - Var& var = kv.first; - Range& range = kv.second; - result.emplace(std::move(var), arith::IntSet::FromRange(std::move(range))); - } - return {result.begin(), result.end()}; -} - /**************** Loop extents ****************/ /*! diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 2844f1b35e9e..249b8cca77b0 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -88,7 +88,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate, var_dom[GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); } Optional> eval_res = - arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer); + arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer); if (eval_res.defined()) { return NDIntSet(eval_res.value().begin(), eval_res.value().end()); } diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 2302d0ed54f2..24228fb52703 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te from tvm import tir -from tvm.ir.base import structural_equal +from tvm.arith.analyzer import Analyzer class IntSetChecker: @@ -128,66 +129,139 @@ def test_select(): ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11)) -def test_region_lower_bound_not_independent(): +def check_region_bound(expect_region, var_dom, mode, predicate=None): + """Helper to check region bound estimation. 
+ + Parameters + ---------- + expect_region: dict + The keys are of form (begin, end) or PrimExpr as a single point. The values are + expected estimated region or region dict on different bindings. + + var_dom: dict + Map var to iteration domain range. + + mode: str + Specify "lowerbound", "upperbound" or else use strict bound estimation. + + predicate: PrimExpr + Extra predicate, defaults to True. + """ + if predicate is None: + predicate = tvm.tir.IntImm("bool", 1) + region = [] + expect = [] + for k, v in expect_region.items(): + if not isinstance(k, (tuple, list)): + k = (k, k + 1) + region.append(tvm.ir.Range.from_min_extent(k[0], Analyzer().simplify(k[1] - k[0]))) + expect.append(v) + if mode == "lowerbound": + result = tvm.arith.estimate_region_lower_bound( + region=region, var_dom=var_dom, predicate=predicate + ) + elif mode == "upperbound": + result = tvm.arith.estimate_region_upper_bound( + region=region, var_dom=var_dom, predicate=predicate + ) + else: + result = tvm.arith.estimate_region_strict_bound( + region=region, var_dom=var_dom, predicate=predicate + ) + if result is None: + assert all([_ is None for _ in expect]) + return + assert len(result) == len(expect) + for intset, expect_desc in zip(result, expect): + if isinstance(expect_desc, dict): + # check range on different free var bindings + for binding in expect_desc: + analyzer = Analyzer() + for k, v in binding: + analyzer.bind(k, v) + expect_begin, expect_end = expect_desc[binding] + result_begin = analyzer.simplify(intset.min_value, 3) + result_end = analyzer.simplify(intset.max_value + 1, 3) + print(result_end) + assert analyzer.can_prove_equal( + result_begin - expect_begin, 0 + ), f"{result_begin} vs {expect_begin}" + assert analyzer.can_prove_equal( + result_end - expect_end, 0 + ), f"{result_end} vs {expect_end}" + else: + # check range + expect_begin, expect_end = expect_desc + analyzer = Analyzer() + assert analyzer.can_prove_equal( + intset.min_value - expect_begin, 0 + ), 
f"{intset.min_value} vs {expect_begin}" + assert analyzer.can_prove_equal( + intset.max_value - expect_end + 1, 0 + ), f"{intset.max_value} vs {expect_end - 1}" + + +def test_region_bound_not_independent(): + # (i, i+2) and (i+2, i+4) are dependent, this the lowerbound is not available i = tvm.tir.Var("i", "int32") - result = tvm.arith.estimate_region_lower_bound( - region=[ - tvm.ir.Range(begin=i, end=i + 2), - tvm.ir.Range(begin=i + 1, end=i + 4), - ], - var_dom={ - i: tvm.ir.Range(begin=0, end=64), - }, - predicate=tvm.tir.IntImm("bool", 1), + var_dom = { + i: tvm.ir.Range(begin=0, end=64), + } + check_region_bound({(i, i + 2): None, (i + 2, i + 4): None}, var_dom, mode="lowerbound") + check_region_bound({(i, i + 2): (0, 65), (i + 2, i + 4): (2, 67)}, var_dom, mode="upperbound") + + # when only a subset of access indices are affine + i, j, k = tvm.tir.Var("i", "int32"), tvm.tir.Var("j", "int32"), tvm.tir.Var("k", "int32") + var_dom = { + i: tvm.ir.Range(begin=0, end=16), + j: tvm.ir.Range(begin=0, end=16), + k: tvm.ir.Range(begin=0, end=16), + } + check_region_bound( + {i // 4: None, j * 4 + i % 4: None, tir.truncdiv(k, 2): None}, + var_dom, + predicate=j * 4 + i % 4 > 3, + mode="lowerbound", + ) + check_region_bound( + {i // 4: (0, 4), j * 4 + i % 4: (4, 64), tir.truncdiv(k, 2): (0, 8)}, + var_dom, + predicate=j * 4 + i % 4 > 3, + mode="upperbound", ) - assert result is None -def test_region_lower_bound_stride_too_wide(): +def test_region_bound_stride_too_wide(): i = tvm.tir.Var("i", "int32") - result = tvm.arith.estimate_region_lower_bound( - region=[ - tvm.ir.Range(begin=i * 4, end=i * 4 + 2), - ], - var_dom={ - i: tvm.ir.Range(begin=0, end=64), - }, - predicate=tvm.tir.IntImm("bool", 1), - ) - assert result is None + var_dom = {i: tvm.ir.Range(begin=0, end=64)} + check_region_bound({(i * 4, i * 4 + 2): None}, var_dom, mode="lowerbound") + check_region_bound({(i * 4, i * 4 + 2): (0, 254)}, var_dom, mode="upperbound") -def 
test_region_lower_bound_small_stride(): +def test_region_bound_small_stride(): i = tvm.tir.Var("i", "int32") - (result,) = tvm.arith.estimate_region_lower_bound( - region=[ - tvm.ir.Range.from_min_extent(min_value=i * 4, extent=8), - ], - var_dom={ - i: tvm.ir.Range(begin=0, end=64), - }, - predicate=tvm.tir.IntImm("bool", 1), - ) - assert result.min_value.value == 0 - assert result.max_value.value == 259 + var_dom = { + i: tvm.ir.Range(begin=0, end=64), + } + check_region_bound({(i * 4, i * 4 + 8): (0, 260)}, var_dom, mode="lowerbound") def test_region_lower_bound_split_predicate(): x_o = tvm.tir.Var("xo", "int32") x_i = tvm.tir.Var("xi", "int32") x = x_o * 4 + x_i - (result,) = tvm.arith.estimate_region_lower_bound( - region=[ - tvm.ir.Range.from_min_extent(min_value=x * 4, extent=8), - ], - var_dom={ - x_o: tvm.ir.Range(begin=0, end=16), - x_i: tvm.ir.Range(begin=0, end=4), - }, + var_dom = { + x_o: tvm.ir.Range(begin=0, end=16), + x_i: tvm.ir.Range(begin=0, end=4), + } + check_region_bound({(x * 4, x * 4 + 8): (0, 256)}, var_dom, predicate=x < 63, mode="lowerbound") + + check_region_bound( + {(x * 4, x * 4 + 8): (0, 256), (x * 3, x * 3 + 5): (0, 191)}, + var_dom, predicate=x < 63, + mode="upperbound", ) - assert result.min_value.value == 0 - assert result.max_value.value == 255 def test_region_lower_bound_multiple_variables(): @@ -198,127 +272,94 @@ def test_region_lower_bound_multiple_variables(): i = div(x, 16) j = div(mod(x, 16), 4) * 8 + mod(x, 4) + div(wid, 32) * 4 k = wid % 32 - (i_int_set, j_int_set, k_int_set) = tvm.arith.estimate_region_lower_bound( - region=[ - tvm.ir.Range.from_min_extent(min_value=i, extent=1), - tvm.ir.Range.from_min_extent(min_value=j, extent=1), - tvm.ir.Range.from_min_extent(min_value=k, extent=1), - ], - var_dom={ - x: tvm.ir.Range(begin=0, end=32), - wid: tvm.ir.Range(begin=0, end=64), - }, - predicate=tvm.tir.IntImm("bool", 1), - ) - assert i_int_set.min_value.value == 0 - assert i_int_set.max_value.value == 1 - assert 
j_int_set.min_value.value == 0 - assert j_int_set.max_value.value == 31 - assert k_int_set.min_value.value == 0 - assert k_int_set.max_value.value == 31 + var_dom = { + x: tvm.ir.Range(begin=0, end=32), + wid: tvm.ir.Range(begin=0, end=64), + } + check_region_bound({i: (0, 2), j: (0, 32), k: (0, 32)}, var_dom, mode="lowerbound") def test_region_lower_bound_negative_scale(): i = tvm.tir.Var("i", "int32") j = tvm.tir.Var("j", "int32") - int_set_0, int_set_1 = tvm.arith.estimate_region_lower_bound( - region=[ - tvm.ir.Range.from_min_extent(min_value=1 - i, extent=4), - tvm.ir.Range.from_min_extent(min_value=20 - j * 4, extent=16), - ], - var_dom={ - i: tvm.ir.Range(begin=0, end=4), - j: tvm.ir.Range(begin=0, end=4), - }, - predicate=tvm.tir.IntImm("bool", 1), + var_dom = { + i: tvm.ir.Range(begin=0, end=4), + j: tvm.ir.Range(begin=0, end=4), + } + check_region_bound( + {(1 - i, 5 - i): (-2, 5), (20 - j * 4, 36 - j * 4): (8, 36)}, var_dom, mode="lowerbound" ) - assert int_set_0.min_value.value == -2 - assert int_set_0.max_value.value == 4 - assert int_set_1.min_value.value == 8 - assert int_set_1.max_value.value == 35 def test_region_lower_bound_for_non_perfect_tile(): h1 = tvm.tir.Var("h1", "int32") h2 = tvm.tir.Var("h2", "int32") h3 = tvm.tir.Var("h3", "int32") - analyzer = tvm.arith.Analyzer() - - def do_test_point_access(point, predicates, var_dom, expect): - regions = tvm.arith.estimate_region_lower_bound( - region=[ - tvm.ir.Range.from_min_extent(min_value=point, extent=1), - ], - var_dom=var_dom, - predicate=tvm.tir.all(*predicates), - ) - if expect is None: # expect a failure - assert regions is None - else: - assert len(regions) == 1 - for binding, expect_min, expect_max in expect: - min_diff = expect_min - regions[0].min_value - assert analyzer.simplify(tir.stmt_functor.substitute(min_diff, binding), 3) == 0 - max_diff = expect_max - regions[0].max_value - assert analyzer.simplify(tir.stmt_functor.substitute(max_diff, binding), 3) == 0 # non-uniform tiling, 
single inner variable - # h3 == 0: region is [1, 9] - # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9] - # h3 > 26: region is [h3 * 8, 223] - do_test_point_access( - point=h3 * 8 + h2, - predicates=[1 <= h3 * 8 + h2, h3 * 8 + h2 < 224], - var_dom={ - h2: tvm.ir.Range(begin=0, end=10), + var_dom = { + h2: tvm.ir.Range(begin=0, end=10), + } + check_region_bound( + { + h3 * 8 + + h2: { + (): ( + tvm.tir.max(h3 * 8, 1), + tvm.tir.max(h3 * 8, 1) + - tvm.tir.max(h3 * 8, 214) + - tvm.tir.max(1 - h3 * 8, 0) + + 224, + ), + ((h3, 0),): (1, 10), # h3 == 0: region is [1, 10) + ((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10) + ((h3, 27),): (h3 * 8, 224), # h3 > 26: region is [h3 * 8, 224) + } }, - expect=[ - ( - {}, - tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 223, - ), - ({h3: 0}, 1, 9), - ({h3: 10}, h3 * 8, h3 * 8 + 9), - ({h3: 27}, h3 * 8, 223), - ], + var_dom, + predicate=tvm.tir.all(1 <= h3 * 8 + h2, h3 * 8 + h2 < 224), + mode="lowerbound", ) # non-uniform tiling, two inner variables - do_test_point_access( - point=h3 * 8 + h2 * 5 + h1, - predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224], - var_dom={ - h2: tvm.ir.Range(begin=0, end=2), - h1: tvm.ir.Range(begin=0, end=5), + var_dom = { + h1: tvm.ir.Range(begin=0, end=5), + h2: tvm.ir.Range(begin=0, end=2), + } + check_region_bound( + { + h3 * 8 + + h2 * 5 + + h1: { + (): ( + tvm.tir.max(h3 * 8, 1), + tvm.tir.max(h3 * 8, 1) + - tvm.tir.max(h3 * 8, 214) + - tvm.tir.max(1 - h3 * 8, 0) + + 224, + ), + ((h3, 0),): (1, 10), + ((h3, 10),): (h3 * 8, h3 * 8 + 10), + ((h3, 27),): (h3 * 8, 224), + } }, - expect=[ - ( - {}, - tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 223, - ), - ({h3: 0}, 1, 9), - ({h3: 10}, h3 * 8, h3 * 8 + 9), - ({h3: 27}, h3 * 8, 223), - ], + var_dom, + predicate=tvm.tir.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 
224), + mode="lowerbound", ) - # should fail on incompatible predicates - do_test_point_access( - point=h3 * 8 + h2 * 5 + h1, - predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224], - var_dom={ - h2: tvm.ir.Range(begin=0, end=2), - h1: tvm.ir.Range(begin=0, end=5), - }, - expect=None, + # lowerbound should fail on incompatible predicates + check_region_bound( + {h3 * 8 + h2 * 5 + h1: None}, + var_dom, + predicate=tvm.tir.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224), + mode="lowerbound", + ) + check_region_bound( + {h3 * 8 + h2 * 5 + h1: (h3 * 8, h3 * 8 + 10)}, + var_dom, + predicate=tvm.tir.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224), + mode="upperbound", ) @@ -328,12 +369,7 @@ def test_region_lower_bound_unfusable(): tvm.tir.Var("j", "int32"): tvm.ir.Range(4), } i, j = var_dom - region = [ - tvm.ir.Range.from_min_extent((i + j) // 2, 1), - ] - result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True) - assert result[0].min_value == 0 - assert result[0].max_value == 5 + check_region_bound({(i + j) // 2: (0, 6)}, var_dom, mode="lowerbound") def test_union_lower_bound(): @@ -347,18 +383,4 @@ def test_union_lower_bound(): if __name__ == "__main__": - test_basic() - test_vector() - test_add_sub() - test_mul_div() - test_max_min() - test_select() - test_mod() - test_region_lower_bound_not_independent() - test_region_lower_bound_stride_too_wide() - test_region_lower_bound_small_stride() - test_region_lower_bound_split_predicate() - test_region_lower_bound_multiple_variables() - test_region_lower_bound_negative_scale() - test_region_lower_bound_for_non_perfect_tile() - test_union_lower_bound() + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 31bb9b8b7cdb..049de0bed4f9 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ 
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -909,5 +909,105 @@ def compacted_func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), _check(func, compacted_func) +def test_compact_dependent_buffer_indices(): + """Check the upper bound on different indices could be independently estimated.""" + + @T.prim_func + def diagonal_access(): + for i in range(8): + with T.block(): + A = T.alloc_buffer((256, 256), "float32") + for j, k in T.grid(8, 8): + with T.block(): + T.where(j * 8 + k < 60) + A[i * 64 + j * 8 + k, i * 64 + j * 8 + k] = 1.0 + + @T.prim_func + def diagonal_access_compacted() -> None: + for i in T.serial(8): + with T.block(): + A = T.alloc_buffer([60, 60], dtype="float32") + for j, k in T.grid(8, 8): + with T.block(): + T.where(j * 8 + k < 60) + A[j * 8 + k, j * 8 + k] = 1.0 + + _check(diagonal_access, diagonal_access_compacted) + + +def test_compact_dependent_buffer_indices_of_packed_matmul(): + """Check the outer dimension of the packed M-dim should be compacted to 1 wrt split condition.""" + + @T.prim_func + def nonuniform_packed_matmul_write_cache( + A: T.Buffer[(1020, 64), "float32"], + B: T.Buffer[(1000, 64), "float32"], + C: T.Buffer[(1020, 1000), "float32"], + ): + for i0, i1 in T.grid(4, 1): + with T.block(): + C_local2 = T.alloc_buffer([4, 1, 16, 1000, 16], dtype="float32", scope="local") + C_local1 = T.alloc_buffer([1020, 1000], dtype="float32", scope="local") + for ax0, ax1, ax2 in T.grid(255, 1000, 64): + with T.block("matmul"): + if ax2 == 0: + C_local1[i0 * 255 + ax0, ax1] = 0 + C_local1[i0 * 255 + ax0, ax1] = ( + C_local1[i0 * 255 + ax0, ax1] + A[i0 * 255 + ax0, ax2] * B[ax1, ax2] + ) + for ax0, ax1 in T.grid(255, 1000): + with T.block("st1"): + C_local2[ + (i0 * 255 + ax0) // 255, + 0, + (i0 * 255 + ax0) % 255 // 16, + ax1, + (i0 * 255 + ax0) % 255 % 16, + ] = C_local1[i0 * 255 + ax0, ax1] + for ax0, ax1, ax2 in T.grid(16, 16, 1000): + with T.block("st2"): + T.where(ax0 * 16 + ax1 < 255) + C[i0 * 255 + 
(ax0 * 16 + ax1), i1 * 1000 + ax2] = C_local2[ + (i0 * 255 + ax0 * 16 + ax1) // 255, + 0, + (i0 * 255 + ax0 * 16 + ax1) % 255 // 16, + i1 * 1000 + ax2, + (i0 * 255 + ax0 * 16 + ax1) % 255 % 16, + ] + + @T.prim_func + def nonuniform_packed_matmul_write_cache_compacted( + A: T.Buffer[(1020, 64), "float32"], + B: T.Buffer[(1000, 64), "float32"], + C: T.Buffer[(1020, 1000), "float32"], + ) -> None: + for i0, i1 in T.grid(4, 1): + with T.block(): + C_local2 = T.alloc_buffer([1, 1, 15, 1000, 16], dtype="float32", scope="local") + C_local1 = T.alloc_buffer([255, 1000], dtype="float32", scope="local") + for ax0, ax1, ax2 in T.grid(255, 1000, 64): + with T.block("matmul"): + if ax2 == 0: + C_local1[ax0, ax1] = 0 + C_local1[ax0, ax1] = ( + C_local1[ax0, ax1] + A[i0 * 255 + ax0, ax2] * B[ax1, ax2] + ) + for ax0, ax1 in T.grid(255, 1000): + with T.block("st1"): + C_local2[0, 0, ax0 // 16, ax1, ax0 % 16] = C_local1[ax0, ax1] + for ax0, ax1, ax2 in T.grid(16, 16, 1000): + with T.block("st2"): + T.where(ax0 * 16 + ax1 < 255) + C[i0 * 255 + ax0 * 16 + ax1, ax2] = C_local2[ + (ax0 * 16 + ax1) // 255, + 0, + (ax0 * 16 + ax1) % 255 // 16, + ax2, + (ax0 * 16 + ax1) % 255 % 16, + ] + + _check(nonuniform_packed_matmul_write_cache, nonuniform_packed_matmul_write_cache_compacted) + + if __name__ == "__main__": tvm.testing.main()