From dc81767f9e59fe8ed832fd9b52ed39ac61e869e6 Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Sat, 27 Mar 2021 20:26:05 +0800 Subject: [PATCH] [ARITH] detect iter affine map with predicate (#7752) --- include/tvm/arith/analyzer.h | 10 + include/tvm/arith/iter_affine_map.h | 5 +- include/tvm/tir/analysis.h | 6 + python/tvm/arith/iter_affine_map.py | 14 +- src/arith/analyzer.cc | 7 + src/arith/expr_complexity.cc | 53 ++ src/arith/iter_affine_map.cc | 453 ++++++++++++++---- src/arith/solve_linear_inequality.cc | 51 +- .../unittest/test_arith_iter_affine_map.py | 104 +++- 9 files changed, 547 insertions(+), 156 deletions(-) create mode 100644 src/arith/expr_complexity.cc diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index cd20bdcf4d1a..adb037bfd050 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -458,6 +458,16 @@ class TVM_DLL Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ bool CanProveLess(const PrimExpr& expr, int64_t upper_bound); + /*! + * \brief Whether can we prove lhs == rhs. + * + * \param lhs The input lhs. + * \param rhs The input rhs. + * \return Whether we can prove lhs == rhs. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs); /*! * \brief Whether can we prove condition. * diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index e2e081d2be89..f786c013443c 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -136,6 +136,7 @@ class IterMark : public ObjectRef { TVM_DLL IterMark(PrimExpr source, PrimExpr extent); TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode); }; /*! @@ -259,7 +260,6 @@ class IterSumExpr : public IterMapExpr { /*! * \brief Detect if indices can be written as - * * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] * * Here y = some-quasi-affine-iter-map(input_iters) @@ -272,12 +272,15 @@ class IterSumExpr : public IterMapExpr { * * \param indices The indices to detect pattern for. * \param input_iters Map from variable to iterator's range. + * \param predicate The predicate constraints on the input iterators + * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. * * \return The detected pattern if a match exists, * otherwise return an empty array. */ Array DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer); } // namespace arith diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 1692a8cdacf3..250a84e782a2 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -172,6 +172,12 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constrain Array> GetBlockAccessRegion(const Block& block, const Map& buffer_var_map); +/*! + * \brief Calculate the expresion complexity based on number of symbols it contains. + * \param expr The expr to be calculated. + */ +TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr); + // Pass variants of verification analysis // directly throws RuntimeError when verification fails. namespace transform { diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 123d9b85480a..4033d797dff8 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -88,21 +88,27 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) -def detect_iter_map(indices, input_iters): - """Detect if indices can be written mapped iters from input_iters. +def detect_iter_map(indices, input_iters, predicate=True, require_bijective=False): + """Detect if indices can be written as mapped iters from input iters Parameters ---------- indices : List[PrimExpr] - The input indices. + The input indices input_iters : Map[Var, Range] The domain of each input iterators. + predicate : PrimExpr + The predicate constraints on the input iterators + + require_bijective : bool + A boolean flag that indicates whether the mapping should be bijective + Returns ------- results : List[IterSumExpr] The iter map matching result. Empty array if no match can be found. """ - return _ffi_api.DetectIterMap(indices, input_iters) + return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 9737b53703fd..08e32f576299 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -100,6 +100,13 @@ bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { return false; } +bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { + const auto* clhs = lhs.as(); + const auto* crhs = rhs.as(); + if (clhs && crhs) return clhs->value == crhs->value; + return CanProve(lhs - rhs == 0); +} + bool Analyzer::CanProve(const PrimExpr& expr) { if (const auto* ptr = expr.as()) { return ptr->value != 0; diff --git a/src/arith/expr_complexity.cc b/src/arith/expr_complexity.cc new file mode 100644 index 000000000000..e809668bb624 --- /dev/null +++ b/src/arith/expr_complexity.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/expr_complexity.cc + * \brief Calculate expr complexity. + */ +#include +#include + +namespace tvm { +namespace tir { + +/*! \brief Count the size of the PrimExpr. */ +class PrimExprSizeCounter : public ExprVisitor { + public: + PrimExprSizeCounter() = default; + + static size_t Count(const PrimExpr& expr) { + PrimExprSizeCounter prim_expr_size_counter; + prim_expr_size_counter.VisitExpr(expr); + return prim_expr_size_counter.counter_; + } + + private: + void VisitExpr(const PrimExpr& expr) final { + counter_++; + ExprVisitor::VisitExpr(expr); + } + + size_t counter_{0}; +}; + +size_t CalculateExprComplexity(const PrimExpr& expr) { return PrimExprSizeCounter::Count(expr); } + +} // namespace tir +} // namespace tvm diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 7efdd03fa11e..3757b5eb0d51 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -29,6 +29,7 @@ #include "../support/utils.h" #include "const_fold.h" +#include "pattern_match.h" namespace tvm { namespace arith { @@ -123,11 +124,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /*! - * \brief Collector that collects - * the outgoing split reference of each IterMark. + * \brief Collector that collects the outgoing split reference of each IterMark. * - * These out-going splits can then be used to - * check if the iterators are independent. + * These out-going splits can then be used to check if the iterators are independent. */ class IterMarkSplitCollector { public: @@ -161,8 +160,7 @@ class IterMarkSplitCollector { } }; -// Rewriter to rewrite PrimExpr to IterMapExpr -// when possible +/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; @@ -170,16 +168,19 @@ class IterMapRewriter : public ExprMutator { explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) : analyzer_(analyzer) { for (auto kv : input_iters) { - const auto& vrng = kv.second; - if (is_zero(vrng->min)) { - IterMark mark(kv.first, vrng->extent); - var_map_[kv.first] = IterSplitExpr(mark); + const Var& var = kv.first; + const Range& vrng = kv.second; + if (is_one(vrng->extent)) { + var_map_[var] = IterSumExpr({}, vrng->min); + } else if (is_zero(vrng->min)) { + IterMark mark(var, vrng->extent); + var_map_[var] = IterSplitExpr(mark); input_marks_.push_back(mark); } else { - IterMark mark(kv.first - vrng->min, vrng->extent); - auto sum_expr = ToIterSumExpr(IterSplitExpr(mark)); + IterMark mark(var - vrng->min, vrng->extent); + IterSumExpr sum_expr = ToIterSumExpr(IterSplitExpr(mark)); sum_expr.CopyOnWrite()->base = vrng->min; - var_map_[kv.first] = sum_expr; + var_map_[var] = sum_expr; input_marks_.push_back(mark); } } @@ -187,33 +188,88 @@ class IterMapRewriter : public ExprMutator { size_t unresolved_count() const { return unresolved_count_; } - IterSumExpr Rewrite(PrimExpr expr) { + IterSumExpr Rewrite(const PrimExpr& expr) { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - bool CheckBijective(const Array& indices) { - // This function checks two conditions: - // - C0: Each iter mark should be fully covered by non-overlapping splits. - // - C1: All of the input iterators are used. - // - // Example: given x in [0, 8) y in [0, 6) - // - indices = [x, x+1, y] won't pass because x and x+1 contribute - // two splits that overlaps with each other. - // - indices = [x / 4, x % 4, y] will pass because x / 4 and x % 4 - // contribute two non-overlapping splits that covers x. - // - indices = [x / 4, x % 4] won't pass because y is not used. - // + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, + const PrimExpr& predicate_induced_extent) { + return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_extent); + } + + /*! + * \brief If require_bijective is true, this function checks two conditions: + * - C0: Each iter mark should be fully covered by non-overlapping splits. + * - C1: All of the input iterators are used. + * Example: given x in [0, 8) y in [0, 6) + * - bindings = [x, x + 1, y] won't pass because x and x+1 contribute + * two splits that overlaps with each other. + * - bindings = [x / 4, x % 4, y] will pass because x / 4 and x % 4 + * contribute two non-overlapping splits that covers x. + * - bindings = [x / 4, x % 4] won't pass because y is not used. + * + * If require_bijective is false, this function checks one condition: + * - C0: Each iter mark has a chance to be fully covered by non-overlapping splits. + * Example: given x in [0, 8) y in [0, 6) + * - bindings = [x / 4] will pass because x / 4 can be one split of x + * - bindings = [x / 4, x % 4] will pass because x / 4 and x % 4 + * contribute two non-overlapping splits that covers x. + * - bindings = [x / 3] will not pass because x / 3 can not be one split of x + * \return whether the bindings are valid + */ + bool CheckMapping(const Array& bindings, bool require_bijective) { IterMarkSplitCollector collector; // We can check that for each iter mark: - // All the splits that refers to the itermark covers its extent. + // All the splits that refers to the iter_mark covers its extent. // The splits do not overlap with each other. - collector.Collect(indices); + collector.Collect(bindings); for (const IterMark& mark : collector.visited_) { - if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).empty()) return false; + if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) + return false; } - // all input marks must be visited - for (const auto& mark : input_marks_) { - if (collector.visited_.count(mark) == 0) return false; + if (require_bijective) { + // all input marks must be visited + for (const IterMark& mark : input_marks_) { + if (collector.visited_.count(mark) == 0) return false; + } + } + return true; + } + + /*! + * \brief Check the validity of iterator constraints + * The flattened forms of two different iterator constraints + * either 1) follow inclusion relation or 2) have no intersection + * + * For Example, x = i0*30 + i1*15 + i2*3 + i3, + * 1) [i0*2 + i1 < 3, i2*3 + i3 < 5] is valid, since {i0, i1} \intersect {i2, i3} = empty set. + * 2) [i0*2 + i1 < 3, i1*5 + i2 < 5] is not valid, + * since {i0, i1} \intersect {i1, i2} = {i1}, i0 \in {i0, i1}, i0 \notin {i1, i2} + * \return whether the predicates are valid; + */ + bool CheckConstraints() const { + // the constrained_iters_flattened_ are in the order of shorter to longer + // since we visit the predicates in the order of size + for (size_t i = 0; i < constrained_iters_flattened_.size(); ++i) { + for (size_t j = i + 1; j < constrained_iters_flattened_.size(); ++j) { + // state: 0(start), -1(no intersection), 1(inclusion) + int state = 0; + for (const IterSplitExpr& arg1 : constrained_iters_flattened_[i]->args) { + bool found = false; + for (const IterSplitExpr& arg2 : constrained_iters_flattened_[j]->args) { + if (IterSplitEqual(arg1, arg2)) { + found = true; + break; + } + } + // Check either it is inclusion or intersection, but not both + if (state == 0) { + state = found ? 1 : -1; + } else if ((state == -1 && found) || (state == 1 && !found)) { + return false; + } + } + } } return true; } @@ -243,25 +299,30 @@ class IterMapRewriter : public ExprMutator { size_t operator()(const IterSumExpr& value) const { // for now only hash on source index. size_t hash = value->args.size(); - for (const auto& arg : value->args) { + for (const IterSplitExpr& arg : value->args) { hash = support::HashCombine(hash, std::hash()(arg->source.get())); } return hash; } }; + static bool IterSplitEqual(const IterSplitExpr& lhs, const IterSplitExpr& rhs, + bool check_scale = true) { + tir::ExprDeepEqual equal; + if (!lhs->source.same_as(rhs->source)) return false; + if (!equal(lhs->lower_factor, rhs->lower_factor)) return false; + if (check_scale && !equal(lhs->scale, rhs->scale)) return false; + if (!equal(lhs->extent, rhs->extent)) return false; + return true; + } + struct IterSumEqual { bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const { tir::ExprDeepEqual equal; if (lhs->args.size() != rhs->args.size()) return false; if (!equal(lhs->base, rhs->base)) return false; for (size_t i = 0; i < lhs->args.size(); ++i) { - auto lvalue = lhs->args[i]; - auto rvalue = rhs->args[i]; - if (!lvalue->source.same_as(rvalue->source)) return false; - if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false; - if (!equal(lvalue->scale, rvalue->scale)) return false; - if (!equal(lvalue->extent, rvalue->extent)) return false; + if (!IterSplitEqual(lhs->args[i], rhs->args[i])) return false; } return true; } @@ -275,19 +336,64 @@ class IterMapRewriter : public ExprMutator { std::unordered_map var_map_; // input iter marks std::vector input_marks_; - // The canonical map for sum - std::unordered_map sum_fuse_map_; + // The map for sum that maps flattened form to IterMark with normal form and extent + // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // predicate: j*2 + k < 9 + // Then, flattened form = IterSum(IterSplit(i, scale=9), + // IterSplit(j, scale=2), + // IterSplit(k, scale=1)) + // normal form = IterSum(IterSplit(i, scale=9), + // IterSplit(IterMark(IterSum(IterSplit(j, scale=2), + // IterSplit(k, scale=1)), + // extent=9) + // scale=1)) + std::unordered_map sum_fuse_map_; + // The map for sum that maps normal form to flattened form + std::unordered_map flattened_map_; + // The flattened forms of constrained iters + std::vector constrained_iters_flattened_; /*! - * \brief Verify that splits fully covers mark in a non-overlapping fashion. - * If verification passes, return splits from outermost to inner most order. - * If not, return an empty array + * \brief Look for a split in splits that is not used such that its lower_factor is smallest. + * Note that here we use division to compare lower_factor. + * \param splits the split array to search in. + * \param used the input used array. + * \param expected_lower_factor the skipped lower factor. + * \return the index of the expected split, split.size() if not found. + */ + size_t SearchSkipLowerFactor(const std::vector& splits, + const std::vector& used, + const PrimExpr& expected_lower_factor) { + size_t res = splits.size(); + for (size_t i = 0; i < splits.size(); ++i) { + if (used[i]) continue; + if (!used[i] && !CanProveDivisible(splits[i]->lower_factor, expected_lower_factor)) { + // all the remaining unused splits should have their lower factor divisible + return splits.size(); + } + if (res == splits.size() || + CanProveDivisible(splits[res]->lower_factor, splits[i]->lower_factor)) { + // note down the split with smaller lower factor + res = i; + } + } + return res; + } + + /*! + * \brief If bijective is required, verify that splits fully covers mark in a non-overlapping + * fashion, If not, verify that splits are valid and compatible for the mark. + * If verification passes, return splits from outermost to innermost order. + * If not, return an empty array. * \param mark The iterator of interest. * \param splits The splits to be verified. + * \param require_bijective A boolean flag that indicates whether the bindings should be + * bijective. * \return The normalized splits. */ Array TryNormalizeSplits(const IterMark& mark, - const std::vector& splits) { + const std::vector& splits, + bool require_bijective) { std::vector used(splits.size(), false); std::vector iters; PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); @@ -296,31 +402,83 @@ class IterMapRewriter : public ExprMutator { size_t j = 0; for (; j < splits.size(); ++j) { if (used[j]) continue; - if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break; + if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) + break; } if (j == splits.size()) { - return Array(); + // we do not allow incomplete split if the bindings should be bijective + if (require_bijective) return Array(); + // look for the next split skipping this lower factor + // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2] + // It is valid to only have [y / 6, y % 2] if bijective is not required + // We can skip (y / 2) % 6 + j = SearchSkipLowerFactor(splits, used, expected_lower_factor); + // split not found + if (j == splits.size()) return Array(); } used[j] = true; iters.push_back(splits[j]); - expected_lower_factor *= splits[j]->extent; + expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; + } + // Case 1. bijective is required. + // We check the extent we calculate is consistent with the extent of the mark + // Case 2. bijective is not required. + // We check the extent we calculate is a factor of the extent of the mark + // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. + if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) || + (!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) { + return Array(); } - if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array(); return Array(iters.rbegin(), iters.rend()); } + /*! + * \brief Normalize the left hand side of iter constraint(expr < predicate_induced_extent) + * \param expr The left hand side of iter constraint. + * \param predicate_induced_extent Extent from iter constraint. + * \return The Normalized expression. + */ + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, + const PrimExpr& predicate_induced_extent) { + // We are normalizing the left hand side of iter constraint(iter < predicate_induced_extent) + Optional opt = TryFuseIters(expr); + // scale should be 1 + if (opt.defined() && is_one(opt.value()->scale)) { + IterSumExpr sum = Downcast(opt.value()->source->source); + // get the flattened form + auto it = flattened_map_.find(sum); + ICHECK(it != flattened_map_.end()); + IterSumExpr flattened_form = it->second; + // get the mark + auto it_mark = sum_fuse_map_.find(flattened_form); + ICHECK(it_mark != sum_fuse_map_.end()); + IterMark mark = it_mark->second; + mark.CopyOnWrite()->extent = min(predicate_induced_extent, mark->extent); + // update the bound of the lhs based on predicate_induced_extent + sum_fuse_map_[flattened_form] = mark; + // we need to note down the flattened form of constrained iterators + // to check the validity of constraints, see also CheckConstraints() + constrained_iters_flattened_.push_back(flattened_form); + expr.CopyOnWrite()->args = Array({opt.value()}); + return expr; + } + ++unresolved_count_; + return expr; + } + /*! * \brief Normalize expr to an iterator + offset. * \param expr The input expression. * \return The Normalized expression. */ IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { + // We are normalizing a regular iter if (expr->args.size() <= 1) return expr; PrimExpr base = expr->base; expr.CopyOnWrite()->base = make_zero(expr->dtype); - auto opt = TryFuseIters(expr); + Optional opt = TryFuseIters(expr); expr.CopyOnWrite()->base = base; - if (opt) { + if (opt.defined()) { expr.CopyOnWrite()->args = Array({opt.value()}); return expr; } else { @@ -329,13 +487,6 @@ class IterMapRewriter : public ExprMutator { } } - bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) { - const auto* clhs = lhs.as(); - const auto* crhs = rhs.as(); - if (clhs && crhs) return clhs->value == crhs->value; - return analyzer_->CanProve(lhs - rhs == 0); - } - /*! * \brief Create a IterSumExpr from expr. * \param expr The input expr. @@ -352,22 +503,24 @@ class IterMapRewriter : public ExprMutator { } } - // Try to normalize IterSum into a fused IterMark - // return a corresponding splitexpr if needed. - // IterSum = x1*c1 + x2*c2 + ... + xn*cn - // = (x1*s1 + x2*s2 + ... + xn)*cn - // = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) - // = [IterSplit(IterMark(y), scale=cn)] - // return a corresponding IterSplitExpr if needed. + /*! + * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + * = (x1*s1 + x2*s2 + ... + xn)*cn + * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + * = [IterSplit(IterMark(y), scale=cn)] + * return a corresponding IterSplitExpr if needed. + * Try to normalize IterSum into a fused IterMark + * \param expr The input sum. + * \return The split with the fused IterMark if succeed. + */ Optional TryFuseIters(IterSumExpr expr) { if (!is_zero(expr->base)) return NullOpt; if (expr->args.size() == 1) return expr->args[0]; // select the iterators in order std::vector visited(expr->args.size(), false); - std::vector iters; - iters.reserve(expr->args.size()); - // canonicalize the expression - // find the base scale first + std::vector flattened_iters, grouped_iters; + // canonicalize the expression into two different forms: flattened form and structured form + // step0. check if find the base scale first Optional base_scale = NullOpt; size_t base_index = 0; for (size_t i = 0; i < expr->args.size(); ++i) { @@ -381,35 +534,87 @@ class IterMapRewriter : public ExprMutator { if (!base_scale) return NullOpt; // check if it can be remapped into a fused pattern. PrimExpr expected_scale = base_scale.value(); - for (size_t i = 0; i < expr->args.size(); ++i) { + for (size_t i = 0; i < expr->args.size();) { + // find j such that expr->args[j] has expected scale size_t j = i == 0 ? base_index : 0; for (; j < expr->args.size(); ++j) { - if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break; + if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break; + } + if (j == expr->args.size()) return NullOpt; + // look for the longest constrained iter started from expr->args[j] + // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // predicate: j*2 + k < 9 + // We need to match the predicate in expr and adjust the expected scale, + // otherwise we expect the scale of i to be 2*5=10 + Optional constraint_to_match; + for (const IterSumExpr& iter : constrained_iters_flattened_) { + if (IterSplitEqual(expr->args[j], iter->args.back(), false)) { + // find a predicate started from expr->args[j] + if (!constraint_to_match || + constraint_to_match.value()->args.size() < iter->args.size()) { + constraint_to_match = iter; + } + } } - if (j == expr->args.size()) { - return NullOpt; + if (constraint_to_match) { + // match the predicate and mark the iterators in the constraint_to_match as visited + // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) + // predicate = j*2 + k < 9 + // then j*2 + k matches the lower two splits of expr + for (auto it = constraint_to_match.value()->args.rbegin(); + it != constraint_to_match.value()->args.rend(); ++it) { + size_t k = 0; + for (; k < expr->args.size(); ++k) { + if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) { + if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale)) + break; + } + } + if (k == expr->args.size()) return NullOpt; + visited[k] = true; + flattened_iters.push_back(expr->args[k]); + } + auto iter = sum_fuse_map_.find(constraint_to_match.value()); + ICHECK(iter != sum_fuse_map_.end()); + IterMark iter_matched = iter->second; + grouped_iters.emplace_back(iter_matched, expected_scale); + expected_scale *= iter_matched->extent; + // move forward + i += constraint_to_match.value()->args.size(); + } else { + // constraint_to_match not found, skip this iterator + visited[j] = true; + flattened_iters.push_back(expr->args[j]); + grouped_iters.push_back(expr->args[j]); + expected_scale *= expr->args[j]->extent; + ++i; } - visited[j] = true; - auto arg = expr->args[j]; - arg.CopyOnWrite()->scale = div(expr->args[j]->scale, base_scale.value()); - iters.push_back(arg); - expected_scale *= expr->args[j]->extent; - } - // update the iterator to use the canonicalized form - expr.CopyOnWrite()->args = Array(iters.rbegin(), iters.rend()); - auto it = sum_fuse_map_.find(expr); - if (it != sum_fuse_map_.end()) return it->second; - auto mark = IterMark(expr, div(expected_scale, base_scale.value())); - IterSplitExpr split(mark, base_scale.value()); - sum_fuse_map_[expr] = split; - return split; + } + // Get the flattened form and structured form + // both forms have splits from outermost to innermost + IterSumExpr structured_form = expr, flattened_form = expr; + flattened_form.CopyOnWrite()->args = + Array(flattened_iters.rbegin(), flattened_iters.rend()); + structured_form.CopyOnWrite()->args = + Array(grouped_iters.rbegin(), grouped_iters.rend()); + auto it = sum_fuse_map_.find(flattened_form); + if (it != sum_fuse_map_.end()) { + // old iter + return IterSplitExpr(it->second, base_scale.value()); + } else { + // new iter, form a new mark + IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value())); + sum_fuse_map_[flattened_form] = mark; + flattened_map_[structured_form] = flattened_form; + return IterSplitExpr(mark, base_scale.value()); + } } bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { const auto* clhs = lhs.as(); const auto* crhs = rhs.as(); if (clhs && crhs) return clhs->value % crhs->value == 0; - return analyzer_->CanProve(floormod(lhs, rhs) == 0); + return analyzer_->CanProveEqual(lhs, rhs) || analyzer_->CanProve(floormod(lhs, rhs) == 0); } PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig); @@ -459,27 +664,87 @@ class IterMapRewriter : public ExprMutator { } }; +/*! \brief An internal struct to represent range extent on iterators(iter < upper_bound). */ +struct IterConstraint { + // The expr of the iter + PrimExpr iter; + // The expr of the upper_bound + PrimExpr upper_bound; + // The size of the iter, which is the number of nodes + size_t expr_size = 0; + + IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size) + : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {} +}; + +/*! + * \brief Split the predicate into `(a < b) && (c < d) && ...` + * \param pred The predicate to be split. + * \return A list of pairs, each element of which are lhs and rhs of the '<' sign, + * empty if the split failed. + */ +std::vector MatchUpperBoundConstraints(PrimExpr pred) { + std::vector result; + arith::PVar lhs, rhs, rest; + for (;;) { + if ((rest && (lhs < rhs)).Match(pred)) { + result.emplace_back(lhs.Eval(), rhs.Eval(), 0); + pred = rest.Eval(); + } else if ((lhs < rhs).Match(pred)) { + result.emplace_back(lhs.Eval(), rhs.Eval(), 0); + break; + } else { + return std::vector(); + } + } + return result; +} + Array DetectIterMap(const Array& indices, const Map& input_iters, + const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer) { // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. + + std::vector constraints = MatchUpperBoundConstraints(predicate); + if (!is_one(predicate) && constraints.empty()) return Array(); + + // We have to make sure when we visit an iterator, all the constraints related with its successors + // in the iter var graph has been visited, where the expression of this iterator will contain the + // expression of its successor, so we sort them by their sizes. + for (IterConstraint& constraint : constraints) { + constraint.expr_size = CalculateExprComplexity(constraint.iter); + } + + std::sort( + constraints.begin(), constraints.end(), + [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; }); + IterMapRewriter rewriter(analyzer, input_iters); + // Step0.0: rewrite constraints in the order from size-small ones to size-big ones + for (const IterConstraint& constraint : constraints) { + PrimExpr res = rewriter.RewriteIterConstraint(constraint.iter, constraint.upper_bound); + if (rewriter.unresolved_count() != 0) return Array(); + } + if (!rewriter.CheckConstraints()) return Array(); + // Step0.1: rewrite indices Array results; - for (PrimExpr value : indices) { results.push_back(rewriter.Rewrite(value)); if (rewriter.unresolved_count() != 0) return Array(); } - if (!rewriter.CheckBijective(results)) return Array(); + // Step1: IterIndependenceChecker checks if the iterator are independent. + if (!rewriter.CheckMapping(results, require_bijective)) return Array(); return results; } TVM_REGISTER_GLOBAL("arith.DetectIterMap") - .set_body_typed([](const Array& indices, const Map& input_iters) { + .set_body_typed([](const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool is_bijective) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, &ana); + return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { @@ -675,7 +940,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (a->IsInstance()) { IterSumExpr ret = Downcast(a); - if (auto opt = TryFuseIters(ret)) { + if (Optional opt = TryFuseIters(ret)) { return SplitFloorDivConst(opt.value(), b, GetRef(op)); } else { ++unresolved_count_; @@ -750,7 +1015,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { if (a->IsInstance()) { IterSumExpr ret = Downcast(a); - if (auto opt = TryFuseIters(ret)) { + if (Optional opt = TryFuseIters(ret)) { return SplitFloorModConst(opt.value(), b, GetRef(op)); } else { ++unresolved_count_; diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index dd9044833546..6aad5b7b0a25 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -39,58 +39,9 @@ namespace arith { using namespace tvm::runtime; using namespace tvm::tir; -#define PLUS_ONE(OP) \ - void VisitExpr_(const OP* op) final { num_symbols_++; } - -#define PLUS_ONE_BINARY(OP) \ - void VisitExpr_(const OP* op) final { \ - num_symbols_++; \ - VisitExpr(op->a); \ - VisitExpr(op->b); \ - } - -/*! - * \brief Calculate the expresion complexity based on number of symbols it contains. - */ -class ExprComplexity : public ExprVisitor { - public: - size_t Eval(const PrimExpr& expr) { - VisitExpr(expr); - return num_symbols_; - } - - PLUS_ONE_BINARY(AddNode) - PLUS_ONE_BINARY(SubNode) - PLUS_ONE_BINARY(MulNode) - PLUS_ONE_BINARY(DivNode) - PLUS_ONE_BINARY(ModNode) - PLUS_ONE_BINARY(FloorDivNode) - PLUS_ONE_BINARY(FloorModNode) - PLUS_ONE_BINARY(MinNode) - PLUS_ONE_BINARY(MaxNode) - PLUS_ONE_BINARY(EQNode) - PLUS_ONE_BINARY(NENode) - PLUS_ONE_BINARY(LTNode) - PLUS_ONE_BINARY(LENode) - PLUS_ONE_BINARY(GTNode) - PLUS_ONE_BINARY(GENode) - PLUS_ONE_BINARY(AndNode) - PLUS_ONE_BINARY(OrNode) - PLUS_ONE(VarNode) - PLUS_ONE(FloatImmNode) - PLUS_ONE(IntImmNode) - void VisitExpr_(const NotNode* op) final { - num_symbols_++; - VisitExpr(op->a); - } - - private: - size_t num_symbols_{0}; -}; - struct ExprLess { bool operator()(const PrimExpr& l, const PrimExpr& r) const { - return ExprComplexity().Eval(l) < ExprComplexity().Eval(r); + return CalculateExprComplexity(l) < CalculateExprComplexity(r); } }; diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 6ab61fdd9592..ac05809449bd 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -19,13 +19,13 @@ from tvm import te -def ifuse(inputs): +def ifuse(inputs, pred_extent=None): """Fuse iterators""" value, extent = 0, 1 for i, ext in inputs: value = value * ext + i extent = extent * ext - return (value, extent) + return value, extent if pred_extent is None else pred_extent def isplit(axis, factor): @@ -67,7 +67,9 @@ def test_trivial(): assert_iter_sum_pattern(res[2], 1, 3) res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) - assert len(res) == 0 + assert len(res) == 2 + assert_iter_sum_pattern(res[0], 3, 0) + assert_iter_sum_pattern(res[1], 1, 3) # not independent res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) @@ -79,8 +81,6 @@ def test_fuse(): y = tvm.tir.Var("y", "int32") c = tvm.tir.SizeVar("c", "int32") c0 = tvm.tir.SizeVar("c0", "int32") - c1 = tvm.tir.SizeVar("c1", "int32") - c2 = tvm.tir.SizeVar("c1", "int32") res = tvm.arith.detect_iter_map([y * 3 + 1 + c + x], var_dom([(x, 3), (y, 4)])) assert len(res) == 1 @@ -121,10 +121,8 @@ def test_fuse(): def test_split(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("y", "int32") c0 = tvm.tir.SizeVar("c0", "int32") c1 = tvm.tir.SizeVar("c1", "int32") - c2 = tvm.tir.SizeVar("c1", "int32") fld = tvm.tir.floordiv flm = tvm.tir.floormod @@ -196,8 +194,100 @@ def test_compound(): tvm.ir.assert_structural_equal(sz, res[0]) +def test_predicate(): + x = tvm.tir.Var("x", "int32"), 13 + y = tvm.tir.Var("y", "int32"), 10 + + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128) + + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 128, 0) + + # duplicate constraint + res = tvm.arith.detect_iter_map( + [x[0] * 10 + y[0]], + var_dom([x, y]), + tvm.tir.all(x[0] * 10 + y[0] < 128, x[0] * 10 + y[0] < 64), + ) + + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 64, 0) + + # useless constraint + res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 140) + + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 130, 0) + + i1 = tvm.tir.Var("i1", "int32"), 7 + i2 = tvm.tir.Var("i2", "int32"), 2 + i3 = tvm.tir.Var("i3", "int32"), 4 + i4 = tvm.tir.Var("i4", "int32"), 3 + res = tvm.arith.detect_iter_map( + [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], + var_dom([i1, i2, i3, i4]), + ( + tvm.tir.all( + i1[0] * 2 + i2[0] < 13, + i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, + i3[0] * 3 + i4[0] < 10, + ) + ), + ) + assert len(res) == 1 + assert_iter_sum_pattern(res[0], 128, 0) + + i1 = tvm.tir.Var("i1", "int32"), 7 + i2 = tvm.tir.Var("i2", "int32"), 2 + i3 = tvm.tir.Var("i3", "int32"), 4 + i4 = tvm.tir.Var("i4", "int32"), 3 + + # wrong constraint + res = tvm.arith.detect_iter_map( + [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], + var_dom([i1, i2, i3, i4]), + ( + tvm.tir.all( + i1[0] * 2 + i2[0] < 13, + i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, + i3[0] * 3 + i4[0] < 7, + ) + ), + ) + assert len(res) == 0 + + # incompatible constraint + res = tvm.arith.detect_iter_map( + [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], + var_dom([i1, i2, i3, i4]), + ( + tvm.tir.all( + i1[0] * 2 + i2[0] < 13, + i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, + i3[0] * 3 + i4[0] < 10, + i1[0] * 4 + i3[0] < 20, + ) + ), + ) + assert len(res) == 0 + + res = tvm.arith.detect_iter_map( + [i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0]], + var_dom([i1, i2, i3, i4]), + ( + tvm.tir.all( + i1[0] * 2 + i2[0] < 13, + i1[0] * 20 + i2[0] * 10 + i3[0] * 3 + i4[0] < 128, + i1[0] * 4 + i3[0] < 20, + ) + ), + ) + assert len(res) == 0 + + if __name__ == "__main__": test_split() test_trivial() test_fuse() test_compound() + test_predicate()