diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index d8d784bec958..23f41e1676a6 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -58,18 +58,27 @@ using arith::DeduceBound; using arith::Intersect; using arith::IntSet; -using PartitionKey = std::pair; +using PartitionKey = std::pair; struct PartitionKeyHash { std::size_t operator()(PartitionKey const& k) const noexcept { - std::size_t h1 = std::hash{}(k.first); + std::size_t h1 = ObjectPtrHash{}(k.first); // NOLINT(whitespace/braces) std::size_t h2 = std::hash{}(k.second); return h1 ^ h2; } }; +struct PartitionKeyEqual { + bool operator()(const PartitionKey& k1, const PartitionKey& k2) const { + // NOLINTNEXTLINE(whitespace/braces) + return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first); + } +}; + // Each mapping (cond, cond_value) -> interval represents the fact that // condition cond is proven to have value cond_value (true or false) in interval. -using Partition = std::unordered_map; +using Partition = std::unordered_map; + +using ExpressionSet = std::unordered_set; bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) { bool success = false; @@ -101,7 +110,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { - candidates.insert(op); + candidates.insert(GetRef(op)); } record_.erase(var); } else { @@ -119,7 +128,7 @@ class CandidateSelector final : public StmtExprVisitor { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { - candidates.insert(op); + candidates.insert(GetRef(op)); } record_.erase(var.get()); return; @@ -160,7 +169,7 @@ class CandidateSelector final : public StmtExprVisitor { } } - std::unordered_set candidates; + std::unordered_set candidates; private: bool in_likely_{false}; @@ -224,14 +233,14 @@ class PartitionFinder : public StmtExprVisitor { IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); if (!interval.IsNothing()) { // cond is true within interval - partitions[{cond.get(), true}] = interval; + partitions[{cond, true}] = interval; } PrimExpr inverse_cond = InverseCond(cond); if (inverse_cond.defined()) { IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); if (!interval.IsNothing()) { // cond is false within interval - partitions[{cond.get(), false}] = interval; + partitions[{cond, false}] = interval; } } } @@ -276,25 +285,25 @@ class PartitionFinder : public StmtExprVisitor { // Replace the set of conditions given by ps with cond_value (true or false) class ConditionEliminator : public StmtExprMutator { public: - explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) + explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true) : ps_(ps), cond_value_(cond_value) {} PrimExpr VisitExpr(const PrimExpr& e) final { - if (ps_.find(e.get()) != ps_.end()) { + if (ps_.find(e) != ps_.end()) { return VisitExpr(cond_value_ ? const_true() : const_false()); } return StmtExprMutator::VisitExpr(e); } private: - std::unordered_set ps_; + ExpressionSet ps_; bool cond_value_; }; // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public StmtMutator { public: - explicit ThreadPartitionInserter(const std::unordered_set& ps, PrimExpr cond) + explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -316,7 +325,7 @@ class ThreadPartitionInserter : public StmtMutator { } private: - const std::unordered_set& ps_; + const ExpressionSet& ps_; PrimExpr cond_; bool innermost_thread_scope_; }; @@ -334,9 +343,9 @@ class LoopPartitioner : public StmtMutator { } Stmt VisitStmt_(const ForNode* op) final { - if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, GetRef(op), op->loop_var, op->min, op->min + op->extent - 1, - op->body, false); + auto fs = GetRef(op); + if (selector.candidates.count(fs)) { + Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; } @@ -356,8 +365,9 @@ class LoopPartitioner : public StmtMutator { const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; - if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, GetRef(op), var, 0, op->value - 1, op->body, true); + auto as = GetRef(op); + if (selector.candidates.count(as)) { + Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true); if (s.defined()) return s; } @@ -378,11 +388,12 @@ class LoopPartitioner : public StmtMutator { } private: - Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, - Stmt body, bool partition_thread_scope); + Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body, + bool partition_thread_scope); - std::pair> GetIntervalAndCondset( - const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value); + std::pair GetIntervalAndCondset(const Partition& partitions, + const arith::IntervalSet& for_interval, + bool cond_value); inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body); @@ -395,10 +406,10 @@ class LoopPartitioner : public StmtMutator { // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value -std::pair> LoopPartitioner::GetIntervalAndCondset( +std::pair LoopPartitioner::GetIntervalAndCondset( const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) { Array sets; - std::unordered_set cond_set; + ExpressionSet cond_set; for (const auto& kv : partitions) { if (kv.first.second == cond_value) { @@ -460,8 +471,8 @@ std::pair> LoopPartitioner::GetInterva * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ -Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min, - PrimExpr max, Stmt body, bool partition_thread_scope) { +Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body, + bool partition_thread_scope) { using namespace arith; // include hint of var. hint_map_.insert({var.get(), IntSet::Interval(min, max)}); @@ -475,7 +486,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; - std::unordered_set cond_set; + ExpressionSet cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = GetIntervalAndCondset(finder.partitions, for_interval, true); @@ -516,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var } if (!partition_thread_scope) { Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); - pre_stmt = MakeFor(node, body_begin - min, pre_body); + pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body); } } } else { @@ -541,7 +552,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var } if (!partition_thread_scope) { Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); - post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); + post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body); } } } else { @@ -557,7 +568,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); - mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body); + mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body); // Recurse for each non-empty subrange only if there are at least // two non-empty subranges