Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS]LoopPartition #56

Merged
merged 13 commits into from
Mar 4, 2017
8 changes: 4 additions & 4 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ class IRMutator {
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& s);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& s);

virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ Stmt InjectVirtualThread(Stmt stmt);
*/
Stmt LiftAllocate(Stmt stmt);

/*!
 * \brief Partition loops in the stmt.
 *
 *  Loops whose min and extent are both constant are left as they are.
 *  For other loops, conditions in the body that use the loop variable
 *  are used to split the loop range, so that the condition can be
 *  replaced by a constant inside one of the resulting loops.
 *
 * \param stmt The stmt to be transformed.
 * \return Transformed stmt.
 */
Stmt LoopPartition(Stmt stmt);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

always document functions in header


/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down
3 changes: 2 additions & 1 deletion src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ TVM_REGISTER_API(_arith_EvalModular)

TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
*ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>());
});

TVM_REGISTER_API(_IntervalSetGetMin)
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);

} // namespace ir
} // namespace tvm
33 changes: 19 additions & 14 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using Halide::Internal::Interval;
// from a expression.
class VariablePathFinder: public IRVisitor {
public:
explicit VariablePathFinder(Var target) : target_(target) {}
explicit VariablePathFinder(Expr target) : target_(target) {}

void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
Expand All @@ -37,13 +37,13 @@ class VariablePathFinder: public IRVisitor {

private:
bool found_{false};
Var target_;
Expr target_;
std::unordered_set<const Node*> visited_;
};

// get the path to the variable,
// return empty vector to represent failure
std::vector<const Node*> GetPath(Var target, Expr expr) {
std::vector<const Node*> GetPath(Expr target, Expr expr) {
  // Run the finder over expr and hand back whatever path it recorded.
  VariablePathFinder finder(target);
  finder.Visit(expr);
  return finder.path_;
}
Expand All @@ -56,7 +56,7 @@ class BoundDeducer: public IRVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(Var target, Expr expr,
BoundDeducer(Expr target, Expr expr,
const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_(target), expr_(expr), dom_map_(dom_map) {}

Expand Down Expand Up @@ -137,7 +137,7 @@ class BoundDeducer: public IRVisitor {
bool success{true};

private:
Var target_;
Expr target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& dom_map_;
ExprIntSetMap expr_map_;
Expand Down Expand Up @@ -205,15 +205,9 @@ void BoundDeducer::Deduce() {
Visit(expr_);
}

// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Var v, Expr e,
const Map<Var, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first.get()] = kv.second;
}
BoundDeducer d(v, e, dmap);
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet> dom_map) {
BoundDeducer d(v, e, dom_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
Expand All @@ -225,5 +219,16 @@ IntSet DeduceBound(Var v, Expr e,
return IntSet::interval(min, max);
}

// Assuming e >= 0, deduce the bound of v from it.
// Returns an empty set to represent deduce failure.
// This overload re-keys dom_map by the raw Variable pointer and forwards
// to the unordered_map overload.
IntSet DeduceBound(Expr v, Expr e,
                   const Map<Var, IntSet>& dom_map) {
  std::unordered_map<const Variable*, IntSet> by_variable;
  for (const auto& entry : dom_map) {
    const Variable* key = entry.first.get();
    by_variable[key] = entry.second;
  }
  return DeduceBound(v, e, by_variable);
}

} // namespace arith
} // namespace tvm
35 changes: 30 additions & 5 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,19 @@ inline bool MatchPoint(const IntSet& a,
return i.is_single_point() && i.min.same_as(b);
}

IntSet Union(const Array<IntSet>& set) {
if (set.size() == 1) return set[0];
Interval x = set[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < set.size(); ++i) {
IntSet s = set[i].cover_interval();
// Array overload: copy the elements into a std::vector and forward to
// the std::vector overload of Union.
IntSet Union(const Array<IntSet>& sets) {
  std::vector<IntSet> as_vector;
  for (const auto& item : sets) {
    as_vector.push_back(item);
  }
  return Union(as_vector);
}

IntSet Union(const std::vector<IntSet>& sets) {
if (sets.size() == 1) return sets[0];
Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < sets.size(); ++i) {
IntSet s = sets[i].cover_interval();
const Interval& y = s.as<IntervalSet>()->i;
if (can_prove(x.max + 1 >= y.min)) {
x.max = y.max;
Expand All @@ -179,6 +187,23 @@ IntSet Union(const Array<IntSet>& set) {
return IntervalSet::make(x);
}

// Array overload: copy the elements into a std::vector and forward to
// the std::vector overload of Intersect.
IntSet Intersect(const Array<IntSet>& sets) {
  std::vector<IntSet> as_vector;
  for (const auto& item : sets) {
    as_vector.push_back(item);
  }
  return Intersect(as_vector);
}

// Intersect the covering intervals of all sets.
// Precondition: sets is non-empty (the first element is read unconditionally).
IntSet Intersect(const std::vector<IntSet>& sets) {
  // Start from the first covering interval and narrow it with each of the rest.
  Interval result = sets[0].cover_interval().as<IntervalSet>()->i;
  for (auto it = sets.begin() + 1; it != sets.end(); ++it) {
    Interval next = it->cover_interval().as<IntervalSet>()->i;
    result = Interval::make_intersection(result, next);
  }
  return IntervalSet::make(result);
}

// type traits
template<typename OP>
struct is_logical_op {
Expand Down
9 changes: 8 additions & 1 deletion src/arithmetic/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ ExprIntSetMap EvalSetForEachSubExpr(Expr r,
* \return the set after union
*/
IntSet Union(const Array<IntSet>& sets);
IntSet Union(const std::vector<IntSet>& sets);

/*!
 * \brief Create an intersection set of all sets.
 *  Each input set is first replaced by its covering interval.
 * \param sets The sets to be intersected; must be non-empty.
 * \return the set after intersection
 */
IntSet Intersect(const Array<IntSet>& sets);
/*!
 * \brief Same as above, taking the sets as a std::vector.
 * \param sets The sets to be intersected; must be non-empty.
 * \return the set after intersection
 */
IntSet Intersect(const std::vector<IntSet>& sets);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can simply construct the Array outside and pass it in. It should be the same as using a vector.


// implementation
inline const IntSetNode* IntSet::operator->() const {
Expand All @@ -172,8 +176,11 @@ inline const IntSetNode* IntSet::operator->() const {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values.
*/
IntSet DeduceBound(Var v, Expr cond,
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& dom_map);
IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet> dom_map);


} // namespace arith
} // namespace tvm
Expand Down
203 changes: 203 additions & 0 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*!
* Copyright (c) 2016 by Contributors
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2017

* \file loop_partition.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "../arithmetic/int_set.h"

namespace tvm {
namespace ir {

using arith::IntSet;
using Halide::Internal::const_true;
using Halide::Internal::const_false;
using Halide::Internal::Interval; // for pos_inf & neg_inf

// A partition means that condition is equal to true_value within interval.
struct Partition {
// The condition found in an IfThenElse node.
Expr condition;
// The expression that PartitionReplacer looks for (by identity) and replaces.
Expr old_expr;
// The value old_expr is replaced with.
Expr true_value;
// The interval of the loop variable deduced from the condition.
IntSet interval;
};

// Return true if any node visited in expr is the given variable node.
// The comparison is by node address.
bool ExprUseVar(Expr expr, const Variable* var) {
  bool found = false;
  PostOrderVisit(expr, [&](const NodeRef& visited) {
    found = found || (visited.get() == var);
  });
  return found;
}

inline bool IsConstDomain(Expr min, Expr extent) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can simply fold it in, since the expression is so simple and fits on one line.

return is_const(min) && is_const(extent);
}

class PartitionFinder : public IRVisitor {
public:
explicit PartitionFinder(VarExpr loop_var,
const std::unordered_map<const Variable*, IntSet>& dom_map,
const std::unordered_set<const Variable*>& variables)
: loop_var_(loop_var), dom_map_(dom_map), variables_(variables) {}

void Visit_(const For* op) {
if (IsConstDomain(op->min, op->extent)) {
dom_map_.insert({op->loop_var.get(),
IntSet::interval(op->min, op->min + op->extent - 1)});
IRVisitor::Visit_(op);
dom_map_.erase(op->loop_var.get());
} else {
variables_.insert(op->loop_var.get());
IRVisitor::Visit_(op);
variables_.erase(op->loop_var.get());
}
}

void Visit_(const IfThenElse* op) {
if (ExprUseVar(op->condition, loop_var_.get())) {
for (auto var : variables_) {
if (ExprUseVar(op->condition, var)) IRVisitor::Visit_(op);
}

IntSet interval = DeduceBound(loop_var_, op->condition, dom_map_);
if (interval.min().same_as(Interval::neg_inf)) {
IntSet upper_bound = EvalSet(interval.max(), dom_map_);
interval = IntSet::interval(interval.min(), upper_bound.min());
} else if (interval.max().same_as(Interval::pos_inf)) {
IntSet lower_bound = EvalSet(interval.min(), dom_map_);
interval = IntSet::interval(lower_bound.max(), interval.max());
} else {
// Assume the partition is always a infinite set
LOG(WARNING) << "interval wrong";
}
partitions.push_back(Partition{op->condition, op->condition, const_true(), interval});
}
IRVisitor::Visit_(op);
}

std::vector<Partition> partitions;
private:
VarExpr loop_var_;
std::unordered_map<const Variable*, IntSet> dom_map_;
std::unordered_set<const Variable*> variables_;
};

class PartitionReplacer : public IRMutator {
public:
PartitionReplacer(const Partition& p)
: p_(p) {}

Expr Mutate(Expr e) override {
if (e.same_as(p_.old_expr)) {
return Mutate(p_.true_value);
}
return IRMutator::Mutate(e);
}

Stmt Mutate(Stmt s) override { // ? will raise error if no this function
return IRMutator::Mutate(s);
}

private:
const Partition& p_;
};

// LoopPartitioner will try to partition the loop variable in the IR.
// Loop variables can be divided into two categories:
//
// - those whose range is fixed: the min and the extent are both constant.
//
// For now, we do not partition this kind of loop variable; we add it
// to dom_map so that its domain can be used when deducing follow-up
// partitions.
//
// - those whose range is variable.
//
// We try to partition this kind of loop variable. On success, we
// mutate the stmt and return (only partitioning of the outermost
// loop is considered so far). On failure, we mark it as a variable
// (add it to variables_), so that in the follow-up procedure we know
// a condition cannot be deduced if it uses this variable.

class LoopPartitioner : public IRMutator {
public:
explicit LoopPartitioner() {}
Expr Mutate(Expr e) override {
return IRMutator::Mutate(e);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to override if there is no change

}
// Forwards to IRMutator::Mutate without any change of its own.
Stmt Mutate(Stmt s) override {
return IRMutator::Mutate(s);
}

// Partition a For loop.
//
// - Constant-range loops are not partitioned; their domain is recorded in
//   dom_map_ while the body is mutated.
// - For other loops, PartitionFinder collects conditions on the loop
//   variable. If none is found, the loop variable is recorded in variables_
//   while the body is mutated. Otherwise the loop is split into
//     [universe.min, true.min - 1]  original body (guarded by an if)
//     [true.min, true.max]          body with the conditions replaced
//     [true.max + 1, universe.max]  original body (guarded by an if)
//   where "true" is the intersection of the loop range with all partitions.
//
// \param op The For node being visited.
// \param stmt The statement that holds op.
// \return The mutated statement.
Stmt Mutate_(const For* op, const Stmt& stmt) {
  if (IsConstDomain(op->min, op->extent)) {
    // if the range of loop_var is constant, we will not partition it,
    // instead, we will use the fixed domain to deduce.
    dom_map_.insert({op->loop_var.get(),
                     IntSet::interval(op->min, op->min + op->extent - 1)});
    Stmt res = IRMutator::Mutate_(op, stmt);
    dom_map_.erase(op->loop_var.get());
    return res;
  }

  PartitionFinder finder(op->loop_var, dom_map_, variables_);
  finder.Visit(op->body);

  if (finder.partitions.empty()) {
    variables_.insert(op->loop_var.get());
    // Keep the mutated result: inner loops may have been partitioned.
    Stmt res = IRMutator::Mutate_(op, stmt);
    variables_.erase(op->loop_var.get());
    return res;
  }

  IntSet universe = IntSet::interval(op->min, op->min + op->extent - 1);
  std::vector<IntSet> sets{universe};
  // merge partitions (take their intersect)
  for (const auto& p : finder.partitions) {
    sets.push_back(p.interval);
  }
  IntSet true_itrv = Intersect(sets);

  // Inside true_itrv every collected condition holds, so replace each one.
  Stmt simplified_body = op->body;
  for (const auto& p : finder.partitions) {
    simplified_body = PartitionReplacer(p).Mutate(simplified_body);
  }

  Stmt s = For::make(op->loop_var, true_itrv.min(),
                     true_itrv.max() - true_itrv.min() + 1,
                     op->for_type, op->device_api, simplified_body);

  // Iterations before true_itrv: [universe.min, true_itrv.min - 1].
  // The range stops one short of true_itrv.min so that iteration is not
  // executed a second time by the main loop.
  Expr pre_doubt_cond = (true_itrv.min() != universe.min());
  Stmt pre_stmt = For::make(op->loop_var, universe.min(),
                            true_itrv.min() - universe.min(),
                            op->for_type, op->device_api, op->body);
  s = Block::make(IfThenElse::make(pre_doubt_cond, pre_stmt), s);

  // Iterations after true_itrv: [true_itrv.max + 1, universe.max].
  Expr post_doubt_cond = (true_itrv.max() != universe.max());
  Stmt post_stmt = For::make(op->loop_var, true_itrv.max() + 1,
                             universe.max() - true_itrv.max(),
                             op->for_type, op->device_api, op->body);
  s = Block::make(s, IfThenElse::make(post_doubt_cond, post_stmt));
  return s;
}

private:
std::unordered_map<const Variable*, IntSet> dom_map_;
std::unordered_set<const Variable*> variables_;
};

// Entry point of the pass: run a fresh LoopPartitioner over stmt.
Stmt LoopPartition(Stmt stmt) {
  LoopPartitioner partitioner;
  Stmt result = partitioner.Mutate(stmt);
  return result;
}

} // namespace ir
} // namespace tvm
Loading