diff --git a/include/tvm/expr.h b/include/tvm/expr.h index b7a6a458876f..f446fb4ee591 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -35,6 +35,8 @@ using Halide::Internal::make_const; using Halide::Internal::make_zero; using Halide::Internal::as_const_int; using Halide::Internal::as_const_uint; +using Halide::Internal::const_true; +using Halide::Internal::const_false; inline Type TVMType2Type(TVMType t) { return Type(static_cast(t.code), t.bits, t.lanes); @@ -53,8 +55,8 @@ class Var : public Halide::VarExpr { public: explicit Var(const std::string& name_hint = "v", Type t = Int(32)) : VarExpr(name_hint, t) {} - explicit Var(std::shared_ptr n) : VarExpr(n) {} + explicit Var(VarExpr v) : VarExpr(v) {} /*! \brief type indicate the container type */ using ContainerType = Variable; diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 1a84cb24a1e8..13df25a8c136 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -62,12 +62,12 @@ class IRMutator { virtual Stmt Mutate_(const Allocate* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s); - virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e); - virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e); - virtual Stmt Mutate_(const Provide* op, const Stmt& e); + virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s); + virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s); + virtual Stmt Mutate_(const Provide* op, const Stmt& s); virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s); - virtual Stmt Mutate_(const Evaluate* op, const Stmt& e); + virtual Stmt Mutate_(const Evaluate* op, const Stmt& s); virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const Load* op, const Expr& e); diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 542ec34424cd..f1ee06188b06 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -137,6 +137,13 @@ Stmt InjectVirtualThread(Stmt stmt); */ Stmt LiftAllocate(Stmt stmt); +/*! + * \brief partition loops in the stmt + * \param stmt The stmt to do loop partition + * \return Transformed stmt. + */ +Stmt LoopPartition(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 1866d9b49970..59ddb47536e8 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -29,7 +29,9 @@ TVM_REGISTER_API(_arith_EvalModular) TVM_REGISTER_API(_arith_DeduceBound) .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DeduceBound(args[0], args[1], args[2]); + *ret = DeduceBound(args[0], args[1], + args[2].operator Map(), + args[3].operator Map()); }); TVM_REGISTER_API(_IntervalSetGetMin) diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 1192dc25dd76..f995f13d1cdc 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -69,6 +69,7 @@ REGISTER_PASS4(MakeAPI); REGISTER_PASS1(SplitHostDevice); REGISTER_PASS1(LiftAllocate); REGISTER_PASS1(InjectVirtualThread); +REGISTER_PASS1(LoopPartition); } // namespace ir } // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index b83215c4a36a..f264a3294b9a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -21,7 +21,7 @@ using Halide::Internal::Interval; // from a expression. class VariablePathFinder: public IRVisitor { public: - explicit VariablePathFinder(Var target) : target_(target) {} + explicit VariablePathFinder(Expr target) : target_(target) {} void Visit(const NodeRef& node) final { if (visited_.count(node.get()) != 0) return; @@ -37,13 +37,13 @@ class VariablePathFinder: public IRVisitor { private: bool found_{false}; - Var target_; + Expr target_; std::unordered_set visited_; }; // get the path to the variable, // return empty vector to represent failure -std::vector GetPath(Var target, Expr expr) { +std::vector GetPath(Expr target, Expr expr) { VariablePathFinder v(target); v.Visit(expr); return v.path_; @@ -56,11 +56,11 @@ class BoundDeducer: public IRVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; - BoundDeducer(Var target, Expr expr, - const std::unordered_map& dom_map) - : target_(target), expr_(expr), dom_map_(dom_map) {} + BoundDeducer(Expr target, Expr expr, + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) + : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} - bool Init(); void Deduce(); void Visit(const NodeRef& e) final { @@ -137,9 +137,14 @@ class BoundDeducer: public IRVisitor { bool success{true}; private: - Var target_; + void Init(); + void Transform(); + void Relax(); + + Expr target_; Expr expr_; - const std::unordered_map& dom_map_; + const std::unordered_map& hint_map_; + const std::unordered_map& relax_map_; ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; @@ -163,10 +168,13 @@ class BoundDeduceInputChecker: public IRVisitor { size_t target_count{0}; }; -bool BoundDeducer::Init() { +void BoundDeducer::Init() { BoundDeduceInputChecker checker; if (!checker.Check(this)) success = false; + Transform(); +} +void BoundDeducer::Transform() { if (const LT* op = expr_.as()) { is_greater = false; is_equal = false; @@ -190,30 +198,35 @@ bool BoundDeducer::Init() { } else { success = false; } - return success; } void BoundDeducer::Deduce() { Init(); if (!success) return; + Relax(); // get the path path_ = GetPath(target_, expr_); // get the sign of every subexpr - expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); + expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); } -// assuming e >= 0, deduce the bound of variable from it. -// return empty set to represent deduce failure. -IntSet DeduceBound(Var v, Expr e, - const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first.get()] = kv.second; +void BoundDeducer::Relax() { + if (is_greater) { + expr_ = EvalSet(expr_ , relax_map_).min(); + result = EvalSet(result, relax_map_).max(); + } else { + expr_ = EvalSet(expr_ , relax_map_).max(); + result = EvalSet(result, relax_map_).min(); } - BoundDeducer d(v, e, dmap); +} + +IntSet DeduceBound(Expr v, Expr e, + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) { + BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success) return IntSet::nothing(); Expr min = Interval::neg_inf, max = Interval::pos_inf; @@ -225,5 +238,21 @@ IntSet DeduceBound(Var v, Expr e, return IntSet::interval(min, max); } +// assuming e >= 0, deduce the bound of variable from it. +// return empty set to represent deduce failure. +IntSet DeduceBound(Expr v, Expr e, + const Map& hint_map, + const Map& relax_map) { + std::unordered_map hmap; + for (auto kv : hint_map) { + hmap[kv.first.get()] = kv.second; + } + std::unordered_map rmap; + for (auto kv : relax_map) { + rmap[kv.first.get()] = kv.second; + } + return DeduceBound(v, e, hmap, rmap); +} + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 709da26a648f..1a66d060acc5 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -162,11 +162,11 @@ inline bool MatchPoint(const IntSet& a, return i.is_single_point() && i.min.same_as(b); } -IntSet Union(const Array& set) { - if (set.size() == 1) return set[0]; - Interval x = set[0].cover_interval().as()->i; - for (size_t i = 1; i < set.size(); ++i) { - IntSet s = set[i].cover_interval(); +IntSet Union(const Array& sets) { + if (sets.size() == 1) return sets[0]; + Interval x = sets[0].cover_interval().as()->i; + for (size_t i = 1; i < sets.size(); ++i) { + IntSet s = sets[i].cover_interval(); const Interval& y = s.as()->i; if (can_prove(x.max + 1 >= y.min)) { x.max = y.max; @@ -179,6 +179,15 @@ IntSet Union(const Array& set) { return IntervalSet::make(x); } +IntSet Intersect(const Array& sets) { + Interval x = sets[0].cover_interval().as()->i; + for (size_t i = 1; i < sets.size(); ++i) { + Interval y = sets[i].cover_interval().as()->i; + x = Interval::make_intersection(x, y); + } + return IntervalSet::make(x); +} + // type traits template struct is_logical_op { diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 2fc25f55be2d..113c9bd013af 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -8,6 +8,7 @@ #include #include +#include namespace tvm { namespace arith { @@ -157,6 +158,13 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r, */ IntSet Union(const Array& sets); +/*! + * \brief Create an union set of all sets + * \param sets The sets to be intersected + * \return the set after intersected + */ +IntSet Intersect(const Array& sets); + // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); @@ -169,11 +177,17 @@ inline const IntSetNode* IntSet::operator->() const { * * \param v The target variable to be deduced. * \param cond The conditional expression. - * \param dom_map The domain of each variable. + * \param hint_map The domain of variable, used to help deduce. + * \param relax The domain of each variable, used to relax the domain. * \return An integer set that can cover all the possible values. */ -IntSet DeduceBound(Var v, Expr cond, - const Map& dom_map); +IntSet DeduceBound(Expr v, Expr cond, + const Map& hint_map, + const Map& relax_map); +IntSet DeduceBound(Expr v, Expr e, + const std::unordered_map& hint_map, + const std::unordered_map& relax_map); + } // namespace arith } // namespace tvm diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index bab7471c0561..5fd141ed2f55 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -128,7 +128,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { Expr condition = this->Mutate(op->condition); Stmt then_case = this->Mutate(op->then_case); Stmt else_case; - if (else_case.defined()) { + if (op->else_case.defined()) { else_case = this->Mutate(op->else_case); } if (condition.same_as(op->condition) && diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc new file mode 100644 index 000000000000..2bd3db2bc56f --- /dev/null +++ b/src/pass/loop_partition.cc @@ -0,0 +1,193 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file loop_partition.cc + */ +#include +#include +#include +#include +#include +#include +#include "../arithmetic/int_set.h" +#include "../arithmetic/int_set_internal.h" + +namespace tvm { +namespace ir { + +using arith::IntSet; + +// a partition means the expr is equal to true in the interval +struct Partition { + Expr expr; + IntSet interval; +}; + +bool ExprUseVars(Expr expr, const std::unordered_set& vars) { + bool success = false; + PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { + if (const Variable* v = node.as()) { + if (vars.count(v)) { + success = true; + return; + } + } + }); + return success; +} + +class PartitionFinder : public IRVisitor { + public: + explicit PartitionFinder(VarExpr loop_var, + const std::unordered_map& dom_map) + : target_var_(loop_var), out_vars_(dom_map.size()), hint_map_(dom_map) { + for (const auto& kv : dom_map) out_vars_.insert(kv.first); + } + + void Visit_(const For* op) { + if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; + + hint_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + relax_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + IRVisitor::Visit_(op); + relax_map_.erase(op->loop_var.get()); + hint_map_.erase(op->loop_var.get()); + } + + void Visit_(const IfThenElse* op) { + if (ExprUseVars(op->condition, std::unordered_set({target_var_.get()}))) { + IntSet interval = DeduceBound(target_var_, op->condition, hint_map_, relax_map_); + partitions[op->condition.get()] = Partition{op->condition, interval}; + } else { + IRVisitor::Visit_(op); + } + } + + std::unordered_map partitions; + + private: + VarExpr target_var_; + std::unordered_set out_vars_; + std::unordered_map hint_map_; + std::unordered_map relax_map_; +}; + +class PartitionReplacer : public IRMutator { + public: + explicit PartitionReplacer(const std::unordered_map& ps) + : ps_(ps) {} + + Expr Mutate(Expr e) override { + if (ps_.count(e.get())) { + return Mutate(const_true()); + } + return IRMutator::Mutate(e); + } + using IRMutator::Mutate; + + private: + const std::unordered_map& ps_; +}; + +class LoopPartitioner : public IRMutator { + public: + LoopPartitioner() {} + + Stmt Mutate_(const For* op, const Stmt& stmt) { + if (!is_const(op->min) || !is_const(op->extent)) { + Stmt s = DoPartition(op, stmt); + if (s.defined()) return s; + } + dom_map_.insert({op->loop_var.get(), + IntSet::interval(op->min, op->min + op->extent - 1)}); + Stmt res = IRMutator::Mutate_(op, stmt); + dom_map_.erase(op->loop_var.get()); + return res; + } + + private: + Stmt DoPartition(const For* op, const Stmt& stmt); + + std::unordered_map dom_map_; +}; + +Stmt LoopPartitioner::DoPartition(const For* op, const Stmt& stmt) { + PartitionFinder finder(op->loop_var, dom_map_); + finder.Visit(op->body); + const auto& partitions = finder.partitions; + + if (partitions.empty()) return Stmt(); + + Expr min = op->min; + Expr max = op->min + op->extent - 1; + Array sets; + // merge partitions (take their intersect) + for (const auto& kv : partitions) { + sets.push_back(kv.second.interval); + } + IntSet true_itrv = Intersect(sets); + + Stmt pre_stmt; + Expr body_begin; + if (true_itrv.as()->i.has_lower_bound()) { + body_begin = true_itrv.min(); + if (!can_prove(body_begin == min)) { + if (!can_prove(body_begin - min >= 0)) { + LOG(WARNING) << "cannot prove: " << (body_begin - min >= 0) + << ", when generating the pre doubt loop"; + body_begin = Max::make(body_begin, min); + } + // [min, body_begin) + Stmt body = Substitute(op->body, + {{Var{op->loop_var}, op->loop_var + min}}); + pre_stmt = For::make(op->loop_var, 0, + body_begin - min, op->for_type, op->device_api, body); + } + } else { + body_begin = min; + } + + Stmt post_stmt; + Expr post_doubt_begin; + if (true_itrv.as()->i.has_upper_bound()) { + post_doubt_begin = true_itrv.max() + 1; + if (!can_prove(true_itrv.max() == max)) { + if (!can_prove(max - post_doubt_begin >= 0)) { + LOG(WARNING) << "Cannot prove: " << (max - post_doubt_begin >= 0) + << ", when generating the post doubt loop"; + post_doubt_begin = Min::make(post_doubt_begin, max); + } + // [post_doubt_begin, max] + Stmt body = Substitute(op->body, + {{Var{op->loop_var}, op->loop_var + post_doubt_begin}}); + post_stmt = For::make(op->loop_var, 0, + max - post_doubt_begin + 1, op->for_type, op->device_api, body); + } + } else { + post_doubt_begin = max + 1; + } + + // [body_begin, post_doubt_begin) + Stmt simplified_body = PartitionReplacer(partitions).Mutate(op->body); + Stmt body = Substitute(simplified_body, {{Var{op->loop_var}, op->loop_var + body_begin}}); + Stmt simplified_stmt = For::make(op->loop_var, 0, + post_doubt_begin - body_begin, op->for_type, op->device_api, body); + Stmt s = simplified_stmt; + if (pre_stmt.defined()) { + s = Block::make(pre_stmt, s); + } + if (post_stmt.defined()) { + s = Block::make(s, post_stmt); + } + + return Simplify(ConvertSSA(s)); +} + +Stmt LoopPartition(Stmt stmt) { + stmt = LoopPartitioner().Mutate(stmt); + return stmt; +} + +} // namespace ir +} // namespace tvm diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index b677ea6ec6fa..fa2ba7235dfd 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -16,20 +16,25 @@ def test_deduce(): d_s = tvm.arith.intset_interval(-3, -1) e0 = (-b)*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) + res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = (d-c)/(-b)+(-1) assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) e1 = (a*4+b < c) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ans1 = (c-b)/4+(-2) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) e2 = (tvm.max(5, a * 4) < 0) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) assert str(res2.max()) == "neg_inf" assert str(res2.min()) == "pos_inf" + e3 = (-b)+a*c-d + res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + ans3 = 2/c+1 + assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) + def test_check(): a = tvm.Var('a') b = tvm.Var('b') @@ -41,15 +46,15 @@ def test_check(): d_s = tvm.arith.intset_interval(-3, -1) # no compare operator - res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}) + res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) assert res1.is_nothing() # multiple compare operators - res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}) + res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}, {}) assert res1.is_nothing() # multiple target variable - res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}) + res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) assert res1.is_nothing() if __name__ == "__main__": diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py new file mode 100644 index 000000000000..fd0662c8d906 --- /dev/null +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -0,0 +1,60 @@ +import tvm + +def test_basic(): + n = tvm.Var('n') + A = tvm.placeholder((n, ), name='A') + B = tvm.placeholder((n, ), name='B') + + T = tvm.compute((n, ), lambda i: A[i]+B[i]) + s = tvm.Schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=4) + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.LoopPartition(stmt) + assert('if' not in str(stmt.body.body.body.first)) + print(stmt) + +def test_multi_loop(): + i = tvm.Var('i') + j = tvm.Var('j') + k = tvm.Var('k') + m = tvm.Var('m') + n = tvm.Var('n') + stmt = tvm.make.For( + i, 0, 4, 0, 0, + tvm.make.For( + j, 0, n, 0, 0, + tvm.make.For( + k, 0, m, 0, 0, + tvm.make.IfThenElse( + (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))))) + stmt = tvm.ir_pass.LoopPartition(stmt) + assert('if' not in str(stmt.body.first)) + print(stmt) + +def test_multi_if(): + i = tvm.Var('i') + j = tvm.Var('j') + k = tvm.Var('k') + m = tvm.Var('m') + n = tvm.Var('n') + stmt = tvm.make.For( + i, 0, 4, 0, 0, + tvm.make.For( + j, 0, n, 0, 0, + tvm.make.For( + k, 0, m, 0, 0, + tvm.make.Block( + tvm.make.IfThenElse((i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)), + tvm.make.IfThenElse((i*m+j-k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)) + )))) + stmt = tvm.ir_pass.LoopPartition(stmt) + assert('if' not in str(stmt.body.first)) + print(stmt) + + +if __name__ == "__main__": + test_basic() + test_multi_loop() + test_multi_if()