Skip to content

Commit

Permalink
implement
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jan 11, 2019
1 parent 547a091 commit bf57ea5
Show file tree
Hide file tree
Showing 10 changed files with 673 additions and 17 deletions.
24 changes: 22 additions & 2 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ Expr FoldConstant(const Expr& expr);
/*!
* \brief Fuse operations into expr into seperate functions.
* \param expr The expression.
* \param mod The global module.
* \param fuse_opt_level Optimization level.
* \return The optimized expression.
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);
Expr FuseOps(const Expr& expr, const Module& mod, int fuse_opt_level);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
Expand Down Expand Up @@ -188,7 +189,6 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand All @@ -212,6 +212,26 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into A Normal Form.
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A Normal Form).
*
* The scope of the root expression is the global scope.
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);

} // namespace relay
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def optimize(self, expr):
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_expr)
fused_expr = ir_pass.fuse_ops(ck_expr, mod=self.mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def build(func,
func = optimize(func, target, params)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
func = ir_pass.fuse_ops(func, mod=None, opt_level=cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
Expand Down
41 changes: 35 additions & 6 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
Expand All @@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
Expand Down Expand Up @@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters
----------
t: tvm.relay.Type
t : tvm.relay.Type
The type to check
mod: tvm.relay.Module, optional
The global module
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
Expand Down Expand Up @@ -305,14 +305,17 @@ def fold_constant(expr):
return _ir_pass.FoldConstant(expr)


def fuse_ops(expr, opt_level=1):
def fuse_ops(expr, mod=None, opt_level=1):
"""Fuse operators in expr together.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module.
opt_level : int
The level of fuse optimization.
Expand All @@ -321,7 +324,7 @@ def fuse_ops(expr, opt_level=1):
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr, opt_level)
return _ir_pass.FuseOps(expr, mod, opt_level)


def combine_parallel_conv2d(expr):
Expand Down Expand Up @@ -357,3 +360,29 @@ def alter_op_layout(expr):
Transformed expression with alternated layout.
"""
return _ir_pass.AlterOpLayout(expr)


def to_anf(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression
"""
return _ir_pass.to_anf(expr, mod)
2 changes: 1 addition & 1 deletion src/relay/pass/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ConstantFolder : public ExprMutator {
// Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) {
expr = InferType(expr, Module(nullptr));
expr = FuseOps(expr, 0);
expr = FuseOps(expr, Module(nullptr), 0);
expr = InferType(expr, Module(nullptr));
return ValueToExpr(executor_(expr));
}
Expand Down
28 changes: 25 additions & 3 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
return std::move(groups_);
}

Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set<GlobalVar>* gv);

class FuseMutator : private ExprMutator {
public:
// Run the transform
Expand All @@ -667,8 +669,13 @@ class FuseMutator : private ExprMutator {
return this->Mutate(body);
}

FuseMutator(const Module& mod, int fuse_opt_level, std::set<GlobalVar>* visited) :
mod_(mod), fuse_opt_level_(fuse_opt_level), visited_(visited) { }

private:
Module mod_;
int fuse_opt_level_;
std::set<GlobalVar>* visited_;
/*! \brief Temporary information from each group. */
struct GroupInfo {
public:
Expand Down Expand Up @@ -751,6 +758,16 @@ class FuseMutator : private ExprMutator {
return new_tuple;
}

Expr VisitExpr_(const GlobalVarNode* node) {
GlobalVar gv = GetRef<GlobalVar>(node);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv,
Downcast<Function>(FuseOps(mod_->Lookup(gv), mod_, fuse_opt_level_, visited_)));
}
return gv;
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
Expand Down Expand Up @@ -790,17 +807,22 @@ class FuseMutator : private ExprMutator {
};


Expr FuseOps(const Expr& expr, int fuse_opt_level) {
Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level) {
std::set<GlobalVar> gv;
return FuseOps(expr, m, fuse_opt_level, &gv);
}

Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set<GlobalVar>* gv) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
return FuseMutator().Transform(expr, fuse_opt_level);
return FuseMutator(m, fuse_opt_level, gv).Transform(expr, fuse_opt_level);
}

TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[0], args[1]);
*ret = FuseOps(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm
6 changes: 5 additions & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
Expand Down Expand Up @@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
Expr Get(const Expr& body) const {
Expr Get(const Expr& body) {
CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
used_ = true;
return ret;
}

Expand Down Expand Up @@ -108,6 +111,7 @@ class LetList {

private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
};

} // namespace relay
Expand Down
Loading

0 comments on commit bf57ea5

Please sign in to comment.