-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PASS]LoopPartition #56
Changes from 4 commits
4a31797
5c95c6f
335a8ad
9afb1e1
266931c
40c0b5f
63d3a7d
d7aba90
e2b39b1
7aa642c
b5786b5
8d0bbd9
d2a90b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -156,6 +156,10 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r, | |
* \return the set after union | ||
*/ | ||
IntSet Union(const Array<IntSet>& sets); | ||
IntSet Union(const std::vector<IntSet>& sets); | ||
|
||
IntSet Intersect(const Array<IntSet>& sets); | ||
IntSet Intersect(const std::vector<IntSet>& sets); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can simply construct Array in outside and pass it in. Should be same as vector |
||
|
||
// implementation | ||
inline const IntSetNode* IntSet::operator->() const { | ||
|
@@ -172,8 +176,11 @@ inline const IntSetNode* IntSet::operator->() const { | |
* \param dom_map The domain of each variable. | ||
* \return An integer set that can cover all the possible values. | ||
*/ | ||
IntSet DeduceBound(Var v, Expr cond, | ||
IntSet DeduceBound(Expr v, Expr cond, | ||
const Map<Var, IntSet>& dom_map); | ||
IntSet DeduceBound(Expr v, Expr e, | ||
const std::unordered_map<const Variable*, IntSet> dom_map); | ||
|
||
|
||
} // namespace arith | ||
} // namespace tvm | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2017 |
||
* \file loop_partition.cc | ||
*/ | ||
#include <tvm/ir.h> | ||
#include <tvm/ir_visitor.h> | ||
#include <tvm/ir_mutator.h> | ||
#include <tvm/ir_pass.h> | ||
#include <unordered_set> | ||
#include "../arithmetic/int_set.h" | ||
|
||
namespace tvm { | ||
namespace ir { | ||
|
||
using arith::IntSet; | ||
using Halide::Internal::const_true; | ||
using Halide::Internal::const_false; | ||
using Halide::Internal::Interval; // for pos_inf & neg_inf | ||
|
||
// a partition means condition is equal to true_value in the interval | ||
struct Partition { | ||
Expr condition; | ||
Expr old_expr; | ||
Expr true_value; | ||
IntSet interval; | ||
}; | ||
|
||
bool ExprUseVar(Expr expr, const Variable* var) { | ||
bool success = false; | ||
PostOrderVisit(expr, [&var, &success](const NodeRef& node) { | ||
if (node.get() == var) { | ||
success = true; | ||
return; | ||
} | ||
}); | ||
return success; | ||
} | ||
|
||
inline bool IsConstDomain(Expr min, Expr extent) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can simply fold it in, since the expression is so simple and fit into oneline |
||
return is_const(min) && is_const(extent); | ||
} | ||
|
||
class PartitionFinder : public IRVisitor { | ||
public: | ||
explicit PartitionFinder(VarExpr loop_var, | ||
const std::unordered_map<const Variable*, IntSet>& dom_map, | ||
const std::unordered_set<const Variable*>& variables) | ||
: loop_var_(loop_var), dom_map_(dom_map), variables_(variables) {} | ||
|
||
void Visit_(const For* op) { | ||
if (IsConstDomain(op->min, op->extent)) { | ||
dom_map_.insert({op->loop_var.get(), | ||
IntSet::interval(op->min, op->min + op->extent - 1)}); | ||
IRVisitor::Visit_(op); | ||
dom_map_.erase(op->loop_var.get()); | ||
} else { | ||
variables_.insert(op->loop_var.get()); | ||
IRVisitor::Visit_(op); | ||
variables_.erase(op->loop_var.get()); | ||
} | ||
} | ||
|
||
void Visit_(const IfThenElse* op) { | ||
if (ExprUseVar(op->condition, loop_var_.get())) { | ||
for (auto var : variables_) { | ||
if (ExprUseVar(op->condition, var)) IRVisitor::Visit_(op); | ||
} | ||
|
||
IntSet interval = DeduceBound(loop_var_, op->condition, dom_map_); | ||
if (interval.min().same_as(Interval::neg_inf)) { | ||
IntSet upper_bound = EvalSet(interval.max(), dom_map_); | ||
interval = IntSet::interval(interval.min(), upper_bound.min()); | ||
} else if (interval.max().same_as(Interval::pos_inf)) { | ||
IntSet lower_bound = EvalSet(interval.min(), dom_map_); | ||
interval = IntSet::interval(lower_bound.max(), interval.max()); | ||
} else { | ||
// Assume the partition is always a infinite set | ||
LOG(WARNING) << "interval wrong"; | ||
} | ||
partitions.push_back(Partition{op->condition, op->condition, const_true(), interval}); | ||
} | ||
IRVisitor::Visit_(op); | ||
} | ||
|
||
std::vector<Partition> partitions; | ||
private: | ||
VarExpr loop_var_; | ||
std::unordered_map<const Variable*, IntSet> dom_map_; | ||
std::unordered_set<const Variable*> variables_; | ||
}; | ||
|
||
class PartitionReplacer : public IRMutator { | ||
public: | ||
PartitionReplacer(const Partition& p) | ||
: p_(p) {} | ||
|
||
Expr Mutate(Expr e) override { | ||
if (e.same_as(p_.old_expr)) { | ||
return Mutate(p_.true_value); | ||
} | ||
return IRMutator::Mutate(e); | ||
} | ||
|
||
Stmt Mutate(Stmt s) override { // ? will raise error if no this function | ||
return IRMutator::Mutate(s); | ||
} | ||
|
||
private: | ||
const Partition& p_; | ||
}; | ||
|
||
// LoopPartitioner will try to partition the loop variable in the IR. | ||
// The loop variable can be divided into two categories: | ||
// | ||
// - whose range is fixed, the min and the extent both are constant. | ||
// | ||
// For now, we will not do partition on this kind loop variable, we | ||
// add them into dom_map in order to do deduce for follow-up | ||
// partitions. | ||
// | ||
// - whose range is variable | ||
// | ||
// We will try to do partition on this kind loop variable. If success, | ||
// we will mutate the stmt then return. (only consider the partition | ||
// on the outmost loop yet). If failed, we will mark them as variable | ||
// (add them into variables_), then in the follow-up procedure, we know | ||
// a condition is not able to be deduced if it use this variable. | ||
|
||
class LoopPartitioner : public IRMutator { | ||
public: | ||
explicit LoopPartitioner() {} | ||
Expr Mutate(Expr e) override { | ||
return IRMutator::Mutate(e); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to override if there is no change |
||
} | ||
Stmt Mutate(Stmt s) override { | ||
return IRMutator::Mutate(s); | ||
} | ||
|
||
Stmt Mutate_(const For* op, const Stmt& stmt) { | ||
if (IsConstDomain(op->min, op->extent)) { | ||
// if the range of loop_var is constant, we will not partition it, | ||
// instead, we will use the fixed domain to deduce. | ||
dom_map_.insert({op->loop_var.get(), | ||
IntSet::interval(op->min, op->min + op->extent - 1)}); | ||
Stmt res = IRMutator::Mutate_(op, stmt); | ||
dom_map_.erase(op->loop_var.get()); | ||
return res; | ||
} | ||
|
||
PartitionFinder finder(op->loop_var, dom_map_, variables_); | ||
finder.Visit(op->body); | ||
|
||
if (finder.partitions.empty()) { | ||
variables_.insert(op->loop_var.get()); | ||
IRMutator::Mutate_(op, stmt); | ||
variables_.erase(op->loop_var.get()); | ||
return stmt; | ||
} | ||
|
||
IntSet universe = IntSet::interval(op->min, op->min + op->extent - 1); | ||
std::vector<IntSet> sets{universe}; | ||
// merge partitions (take their intersect) | ||
for (auto p : finder.partitions) { | ||
sets.push_back(p.interval); | ||
} | ||
IntSet true_itrv = Intersect(sets); | ||
|
||
Stmt simplified_body = op->body; | ||
for (auto p : finder.partitions) { | ||
p.interval = true_itrv; | ||
simplified_body = PartitionReplacer(p).Mutate(simplified_body); | ||
} | ||
|
||
Stmt simplified_stmt = For::make(op->loop_var, true_itrv.min(), | ||
true_itrv.max() - true_itrv.min() + 1, op->for_type, op->device_api, simplified_body); | ||
Stmt s = simplified_stmt; | ||
|
||
Expr pre_doubt_cond = (true_itrv.min() != universe.min()); | ||
IntSet pre_doubt_itrv = IntSet::interval(universe.min(), true_itrv.min()); | ||
Stmt pre_stmt = For::make(op->loop_var, pre_doubt_itrv.min(), | ||
pre_doubt_itrv.max() - pre_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); | ||
s = Block::make(IfThenElse::make(pre_doubt_cond, pre_stmt), s); | ||
|
||
Expr post_doubt_cond = (true_itrv.max() != universe.max()); | ||
IntSet post_doubt_itrv = IntSet::interval(true_itrv.max(), universe.max()); | ||
Stmt post_stmt = For::make(op->loop_var, post_doubt_itrv.min(), | ||
post_doubt_itrv.max() - post_doubt_itrv.min() + 1, op->for_type, op->device_api, op->body); | ||
s = Block::make(s, IfThenElse::make(post_doubt_cond, post_stmt)); | ||
return s; | ||
} | ||
|
||
private: | ||
std::unordered_map<const Variable*, IntSet> dom_map_; | ||
std::unordered_set<const Variable*> variables_; | ||
}; | ||
|
||
Stmt LoopPartition(Stmt stmt) { | ||
stmt = LoopPartitioner().Mutate(stmt); | ||
return stmt; | ||
} | ||
|
||
} // namespace ir | ||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
always document functions in header