Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Partial Eval now Support Interprocedural optimization and has termination check. #3033

Merged
merged 1 commit into from
Jun 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e);
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);

/*!
* \brief Fold constant expressions.
Expand Down Expand Up @@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e);
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);

/*!
* \brief Bind the free variables to a Relay expression.
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param inline_once whether or not to inline binding used one.
*
* \return the pass.
*/
TVM_DLL Pass DeadCodeElimination();
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);

/*!
* \brief Fold constant expressions.
Expand Down
77 changes: 42 additions & 35 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def well_formed(expr):

Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression

Returns
Expand Down Expand Up @@ -175,7 +175,7 @@ def free_vars(expr):

Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression

Returns
Expand All @@ -197,7 +197,7 @@ def bound_vars(expr):

Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression

Returns
Expand All @@ -213,7 +213,7 @@ def all_vars(expr):

Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression

Returns
Expand All @@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):

Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional

mod : Optional[tvm.relay.Module]
The global module

Returns
Expand All @@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):

Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional

mod : Optional[tvm.relay.Module]
The global module

Returns
Expand All @@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):

Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module

Returns
Expand All @@ -286,12 +288,12 @@ def simplify_inference(expr):

Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression

Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
Expand All @@ -304,48 +306,50 @@ def canonicalize_ops(expr):

Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression

Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)


def dead_code_elimination(expr):
def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).

Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression

inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr)
return _ir_pass.dead_code_elimination(expr, inline_once)


def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).

Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.

rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.

Returns
-------
result: bool
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
Expand All @@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):

Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.

rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.

Returns
-------
result: bool
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
Expand All @@ -378,12 +382,12 @@ def structural_hash(value):

Parameters
----------
expr: tvm.relay.Expr or tvm.relay.Type
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.

Returns
-------
result: int
result : int
The hash value
"""
if isinstance(value, Expr):
Expand Down Expand Up @@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr
The input expression.

mod: Optional[tvm.relay.Module]
mod : Optional[tvm.relay.Module]
The global module.

Returns
-------
expr: tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
Expand All @@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
The input expression
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
Expand Down Expand Up @@ -612,7 +616,7 @@ def get_total_mac_number(expr):

Returns
-------
ret : int64
result : int64
The number of MACs (multiply-accumulate) of a model
"""
return _ir_pass.GetTotalMacNumber(expr)
Expand All @@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr
The input expression.

fskip: function
fskip : function
The callback function that decides whether an expression should be skipped.

Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)

def partial_evaluate(expr):
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.

Expand All @@ -646,12 +650,15 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr
The input expression.

mod : Optional[tvm.relay.Module]
The global module

Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
return _ir_pass.partial_evaluate(expr, mod)

def unmatched_cases(match, mod=None):
"""
Expand Down
6 changes: 3 additions & 3 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better use the command in https://github.com/dmlc/tvm/blob/master/.clang-format to do the formatting so we maintain a consistent code style. The indentation style is 2 space I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not work - the result is completely different from all other files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably due to we use different rules for macro.

<< node->attrs << ", " << node->type_args << ")";
});

Let LetNode::make(Var var, Expr value, Expr body) {
Expand Down Expand Up @@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
return temp->Realize();
});

} // namespace relay
Expand Down
28 changes: 18 additions & 10 deletions src/relay/pass/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ namespace relay {
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e) {
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e);
}

Expand Down Expand Up @@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
const VarSet& letrec_set,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;

bool HasLet(const Var& v) {
// TODO(@jroesch): MK fix me
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
switch (use_map_[v]) {
case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
}

Expr VisitExpr_(const VarNode* op) final {
Expand All @@ -144,19 +152,19 @@ class CalcDep : private ExprVisitor {
};
};

Expr DeadCodeElimination(const Expr& e) {
return CalcDep::Eliminate(e);
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once);
}

TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination);

namespace transform {

Pass DeadCodeElimination() {
Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f));
return Downcast<Function>(DeadCodeElimination(f, inline_once));
};
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}
Expand Down
Loading