Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PASS] UnrollLoop, isolate arithmetic module. #32

Merged
merged 1 commit into from
Feb 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);

/*!
* \brief inline all calls of f in stmt.
Expand Down Expand Up @@ -97,6 +97,13 @@ Stmt Inline(Stmt stmt,
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief unroll the constant loops
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down Expand Up @@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);


} // namespace ir
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
<< num_args << " passed"
<< "but request arg" << i;
<< " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]);
}

Expand Down
1 change: 0 additions & 1 deletion python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def build(sch,
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")

if record_codes is not None:
output_ssa = False
Expand Down
1 change: 1 addition & 0 deletions src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
- api API functionr registration
- lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR.
- arithmetic Arithmetic expression and set simplification
- pass The optimization pass on the IR structure
- runtime Minimum runtime related codes.
10 changes: 10 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>

namespace tvm {
Expand All @@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
}
});

TVM_REGISTER_API(_pass_PostOrderVisit)
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
f(n);
});
});

// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
Expand All @@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
Expand Down
10 changes: 5 additions & 5 deletions src/schedule/compute_expr.h → src/arithmetic/compute_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_

#include <tvm/ir.h>
#include <pass/Interval.h>

namespace tvm {
namespace schedule {
namespace arith {

using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow;
Expand Down Expand Up @@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b);
}

} // namespace schedule
} // namespace arith
} // namespace tvm
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_
#endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_
96 changes: 12 additions & 84 deletions src/schedule/int_set.cc → src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2016 by Contributors
* \file int_set_impl.cc
* Copyright (c) 2017 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
Expand All @@ -10,7 +10,7 @@
#include "./compute_expr.h"

namespace tvm {
namespace schedule {
namespace arith {

using Halide::Internal::Interval;

Expand Down Expand Up @@ -94,6 +94,12 @@ bool IntSet::is_single_point() const {
return (s_int && s_int->i.is_single_point());
}

Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point());
return s_int->i.min;
}

IntSet IntSet::everything() {
return IntervalSet::make(Interval::everything());
}
Expand All @@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) {
}

// Check if a is created from b.
inline bool MatchRange(const IntSet& a,
const Range& b) {
bool IntSet::match_range(const Range& b) const {
const IntSet& a = *this;
const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false;
const Interval& i = a_int->i;
Expand Down Expand Up @@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return CombineSets<OP>(a, b);
}

// Implementation of Evaluations and passing.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
MatchRange(outer, dom_map.at(s->outer)) &&
MatchRange(inner, dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());

*parent = Combine<Add>(
Combine<Add>(
Combine<Mul>(outer, IntSet::single_point(factor)), inner),
IntSet::single_point(parent_min));
}

void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));

if (MatchRange(fused, dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}

Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;

const IntervalSet* fused_int = fused.as<IntervalSet>();

if (fused_int && fused_int->i.is_single_point()) {
Expr value = fused_int->i.min;
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}


void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (MatchRange(rebased, dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = Combine<Add>(rebased, IntSet::single_point(parent_min));
}

// Evaluator to evalute the epxression.
class IntSetEvaluator {
public:
Expand Down Expand Up @@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});


} // namespace schedule
} // namespace arith
} // namespace tvm
75 changes: 17 additions & 58 deletions src/schedule/int_set.h → src/arithmetic/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
* \file int_set.h
* \brief Abstraction for all integer set operations.
*/
#ifndef TVM_SCHEDULE_INT_SET_H_
#define TVM_SCHEDULE_INT_SET_H_
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>

namespace tvm {
namespace schedule {
namespace arith {

// internal node container of int set.
class IntSetNode;
Expand Down Expand Up @@ -44,6 +44,18 @@ class IntSet : public NodeRef {
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
/*! \return Whether the set contains everything */
static IntSet everything();
/*!
Expand Down Expand Up @@ -88,59 +100,6 @@ IntSet EvalSet(Expr e,
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);

/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);

/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* parent);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
Expand All @@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}

} // namespace schedule
} // namespace arith
} // namespace tvm

#endif // TVM_SCHEDULE_INT_SET_H_
#endif // TVM_ARITHMETIC_INT_SET_H_
22 changes: 18 additions & 4 deletions src/pass/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,24 @@ class IRInline : public IRMutator {
if (op->func == f_) {
CHECK_EQ(op->value_index, 0);
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
CHECK_EQ(args_.size(), op->args.size());

bool has_side_effect = false;
for (size_t i = 0; i < op->args.size(); ++i) {
if (HasSideEffect(op->args[i])) has_side_effect = true;
}

if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
} else {
Map<Var, Expr> vmap;
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->args[i]);
}
expr = Substitute(
Evaluate::make(expr), vmap).as<Evaluate>()->value;
}
return expr;
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/pass/simple_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
std::unordered_map<const Variable*, Expr> smap;
};

Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first->var.get()] = kv.second;
m.smap[kv.first.get()] = kv.second;
}
return m.Mutate(stmt);
}
Expand Down
Loading