Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement
Browse files Browse the repository at this point in the history
fix lint

save

add test

save

fix lint

update comment

fix build for gpu

Update ir_pass.py

save

fix error

fix lint

add test

fix lint

fix test

fix test

reboot pytest

Update to_anf.cc

address review comment

save

fused topo

remove dead code

save

save

do it
MarisaKirisame committed Jan 11, 2019
1 parent 547a091 commit 464fe86
Showing 10 changed files with 684 additions and 17 deletions.
34 changes: 32 additions & 2 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
@@ -154,10 +154,11 @@ Expr FoldConstant(const Expr& expr);
/*!
* \brief Fuse operations into expr into seperate functions.
* \param expr The expression.
* \param mod The global module.
* \param fuse_opt_level Optimization level.
* \return The optimized expression.
*/
Expr FuseOps(const Expr& expr, int fuse_opt_level);
Expr FuseOps(const Expr& expr, const Module& mod, int fuse_opt_level);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
@@ -188,7 +189,6 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
@@ -212,6 +212,36 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into A Normal Form.
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A Normal Form).
*
* The scope of the root expression is the global scope.
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);

inline bool IsPrimitiveFunction(const Function& fn) {
NodeRef res = FunctionGetAttr(fn, "Primitive");
const ir::IntImm* pval = res.as<ir::IntImm>();
return pval && (pval->value != 0);
}

inline bool IsPrimitiveFunction(const Expr& e) {
return e.as<FunctionNode>() && IsPrimitiveFunction(Downcast<Function>(e));
}

} // namespace relay
} // namespace tvm

2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
@@ -218,7 +218,7 @@ def optimize(self, expr):
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, mod=self.mod)
fused_expr = ir_pass.fuse_ops(ck_expr)
fused_expr = ir_pass.fuse_ops(ck_expr, mod=self.mod)
ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod)
return ck_fused

2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
@@ -234,7 +234,7 @@ def build(func,
func = optimize(func, target, params)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
func = ir_pass.fuse_ops(func, mod=None, opt_level=cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
41 changes: 35 additions & 6 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
@@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
@@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters
----------
t: tvm.relay.Type
t : tvm.relay.Type
The type to check
mod: tvm.relay.Module, optional
The global module
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
@@ -305,14 +305,17 @@ def fold_constant(expr):
return _ir_pass.FoldConstant(expr)


def fuse_ops(expr, opt_level=1):
def fuse_ops(expr, mod=None, opt_level=1):
"""Fuse operators in expr together.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module.
opt_level : int
The level of fuse optimization.
@@ -321,7 +324,7 @@ def fuse_ops(expr, opt_level=1):
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr, opt_level)
return _ir_pass.FuseOps(expr, mod, opt_level)


def combine_parallel_conv2d(expr):
@@ -357,3 +360,29 @@ def alter_op_layout(expr):
Transformed expression with alternated layout.
"""
return _ir_pass.AlterOpLayout(expr)


def to_anf(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression
"""
return _ir_pass.to_anf(expr, mod)
2 changes: 1 addition & 1 deletion src/relay/pass/fold_constant.cc
Original file line number Diff line number Diff line change
@@ -128,7 +128,7 @@ class ConstantFolder : public ExprMutator {
// Constant evaluate a expression.
Expr ConstEvaluate(Expr expr) {
expr = InferType(expr, Module(nullptr));
expr = FuseOps(expr, 0);
expr = FuseOps(expr, Module(nullptr), 0);
expr = InferType(expr, Module(nullptr));
return ValueToExpr(executor_(expr));
}
28 changes: 25 additions & 3 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
@@ -650,6 +650,8 @@ GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
return std::move(groups_);
}

Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set<GlobalVar>* gv);

class FuseMutator : private ExprMutator {
public:
// Run the transform
@@ -667,8 +669,13 @@ class FuseMutator : private ExprMutator {
return this->Mutate(body);
}

FuseMutator(const Module& mod, int fuse_opt_level, std::set<GlobalVar>* visited) :
mod_(mod), fuse_opt_level_(fuse_opt_level), visited_(visited) { }

private:
Module mod_;
int fuse_opt_level_;
std::set<GlobalVar>* visited_;
/*! \brief Temporary information from each group. */
struct GroupInfo {
public:
@@ -751,6 +758,16 @@ class FuseMutator : private ExprMutator {
return new_tuple;
}

Expr VisitExpr_(const GlobalVarNode* node) {
GlobalVar gv = GetRef<GlobalVar>(node);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv,
Downcast<Function>(FuseOps(mod_->Lookup(gv), mod_, fuse_opt_level_, visited_)));
}
return gv;
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
@@ -790,17 +807,22 @@ class FuseMutator : private ExprMutator {
};


Expr FuseOps(const Expr& expr, int fuse_opt_level) {
Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level) {
std::set<GlobalVar> gv;
return FuseOps(expr, m, fuse_opt_level, &gv);
}

Expr FuseOps(const Expr& expr, const Module& m, int fuse_opt_level, std::set<GlobalVar>* gv) {
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primtive
// then we convert these primtive functions into
// new operators.
return FuseMutator().Transform(expr, fuse_opt_level);
return FuseMutator(m, fuse_opt_level, gv).Transform(expr, fuse_opt_level);
}

TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[0], args[1]);
*ret = FuseOps(args[0], args[1], args[2]);
});
} // namespace relay
} // namespace tvm
6 changes: 5 additions & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
@@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
Expr Get(const Expr& body) const {
Expr Get(const Expr& body) {
CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
used_ = true;
return ret;
}

@@ -108,6 +111,7 @@ class LetList {

private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
};

} // namespace relay
Loading

0 comments on commit 464fe86

Please sign in to comment.