Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache object refs in loop partitioner instead of object pointers #6004

Merged
merged 2 commits into from
Jul 8, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 41 additions & 30 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,27 @@ using arith::DeduceBound;
using arith::Intersect;
using arith::IntSet;

using PartitionKey = std::pair<const Object*, bool>;
using PartitionKey = std::pair<PrimExpr, bool>;
struct PartitionKeyHash {
std::size_t operator()(PartitionKey const& k) const noexcept {
std::size_t h1 = std::hash<const Object*>{}(k.first);
std::size_t h1 = ObjectPtrHash{}(k.first); // NOLINT(whitespace/braces)
std::size_t h2 = std::hash<bool>{}(k.second);
return h1 ^ h2;
}
};

struct PartitionKeyEqual {
bool operator()(const PartitionKey& k1, const PartitionKey& k2) const {
// NOLINTNEXTLINE(whitespace/braces)
return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first);
}
};

// Each mapping (cond, cond_value) -> interval represents the fact that
// condition cond is proven to have value cond_value (true or false) in interval.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, PartitionKeyEqual>;

using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;

bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
bool success = false;
Expand Down Expand Up @@ -101,7 +110,7 @@ class CandidateSelector final : public StmtExprVisitor {
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
candidates.insert(op);
candidates.insert(GetRef<Stmt>(op));
}
record_.erase(var);
} else {
Expand All @@ -119,7 +128,7 @@ class CandidateSelector final : public StmtExprVisitor {
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
candidates.insert(op);
candidates.insert(GetRef<Stmt>(op));
}
record_.erase(var.get());
return;
Expand Down Expand Up @@ -160,7 +169,7 @@ class CandidateSelector final : public StmtExprVisitor {
}
}

std::unordered_set<const Object*> candidates;
std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;

private:
bool in_likely_{false};
Expand Down Expand Up @@ -224,14 +233,14 @@ class PartitionFinder : public StmtExprVisitor {
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond.get(), true}] = interval;
partitions[{cond, true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
partitions[{cond.get(), false}] = interval;
partitions[{cond, false}] = interval;
}
}
}
Expand Down Expand Up @@ -276,25 +285,25 @@ class PartitionFinder : public StmtExprVisitor {
// Replace the set of conditions given by ps with cond_value (true or false)
class ConditionEliminator : public StmtExprMutator {
public:
explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true)
: ps_(ps), cond_value_(cond_value) {}

PrimExpr VisitExpr(const PrimExpr& e) final {
if (ps_.find(e.get()) != ps_.end()) {
if (ps_.find(e) != ps_.end()) {
return VisitExpr(cond_value_ ? const_true() : const_false());
}
return StmtExprMutator::VisitExpr(e);
}

private:
std::unordered_set<const Object*> ps_;
ExpressionSet ps_;
bool cond_value_;
};

// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public StmtMutator {
public:
explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps, PrimExpr cond)
explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond)
: ps_(ps), cond_(cond), innermost_thread_scope_(false) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
Expand All @@ -316,7 +325,7 @@ class ThreadPartitionInserter : public StmtMutator {
}

private:
const std::unordered_set<const Object*>& ps_;
const ExpressionSet& ps_;
PrimExpr cond_;
bool innermost_thread_scope_;
};
Expand All @@ -334,9 +343,9 @@ class LoopPartitioner : public StmtMutator {
}

Stmt VisitStmt_(const ForNode* op) final {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var, op->min, op->min + op->extent - 1,
op->body, false);
auto fs = GetRef<Stmt>(op);
if (selector.candidates.count(fs)) {
Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
}

Expand All @@ -356,8 +365,9 @@ class LoopPartitioner : public StmtMutator {
const IterVarNode* iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, GetRef<Stmt>(op), var, 0, op->value - 1, op->body, true);
auto as = GetRef<Stmt>(op);
if (selector.candidates.count(as)) {
Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
}

Expand All @@ -378,11 +388,12 @@ class LoopPartitioner : public StmtMutator {
}

private:
Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max,
Stmt body, bool partition_thread_scope);
Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
bool partition_thread_scope);

std::pair<IntSet, std::unordered_set<const Object*>> GetIntervalAndCondset(
const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value);
std::pair<IntSet, ExpressionSet> GetIntervalAndCondset(const Partition& partitions,
const arith::IntervalSet& for_interval,
bool cond_value);

inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);

Expand All @@ -395,10 +406,10 @@ class LoopPartitioner : public StmtMutator {

// Returns an interval (in the first component) in which all the conditions
// given in the second component provably have value given by cond_value
std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetIntervalAndCondset(
std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) {
Array<IntSet> sets;
std::unordered_set<const Object*> cond_set;
ExpressionSet cond_set;

for (const auto& kv : partitions) {
if (kv.first.second == cond_value) {
Expand Down Expand Up @@ -460,8 +471,8 @@ std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetInterva
* which will eventually be simplified to empty code. And because only one loop was generated
* from loop 2 we stop recursing.
*/
Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min,
PrimExpr max, Stmt body, bool partition_thread_scope) {
Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
bool partition_thread_scope) {
using namespace arith;
// include hint of var.
hint_map_.insert({var.get(), IntSet::Interval(min, max)});
Expand All @@ -475,7 +486,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
arith::IntervalSet for_interval(min, max);
bool cond_value;
IntSet middle_interval;
std::unordered_set<const Object*> cond_set;
ExpressionSet cond_set;
// find an interval in which all conditions on var are true
std::tie(middle_interval, cond_set) =
GetIntervalAndCondset(finder.partitions, for_interval, true);
Expand Down Expand Up @@ -516,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
}
}
} else {
Expand All @@ -541,7 +552,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
}
if (!partition_thread_scope) {
Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
}
}
} else {
Expand All @@ -557,7 +568,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
// [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body);
mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body);

// Recurse for each non-empty subrange only if there are at least
// two non-empty subranges
Expand Down