Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] A Normal Form Canonicalization #2251

Merged
merged 5 commits into from
Jan 24, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,26 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into A Normal Form.
MarisaKirisame marked this conversation as resolved.
Show resolved Hide resolved
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A Normal Form).
MarisaKirisame marked this conversation as resolved.
Show resolved Hide resolved
*
* The scope of the root expression is the global scope.

* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A Normal Form
MarisaKirisame marked this conversation as resolved.
Show resolved Hide resolved
*/
Expr ToANF(const Expr& e, const Module& mod);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let use use ToANormalForm, so everyone can google and find out what is ANF

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the python side? It wa to_anf.


} // namespace relay
} // namespace tvm

Expand Down
42 changes: 34 additions & 8 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
Expand All @@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
Expand Down Expand Up @@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters
----------
t: tvm.relay.Type
t : tvm.relay.Type
The type to check
mod: tvm.relay.Module, optional
The global module
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
Expand Down Expand Up @@ -480,20 +480,46 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)


def to_anf(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_anf(expr, mod)


def gradient(expr, mod=None):
""".
"""
Transform a function to return original result paired with gradient of input.
Parameters
----------
expr : tvm.relay.Expr
The input expression, which is a Function or a GlobalVar.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
ret : tvm.relay.Expr
A function that calculate the original result paired with gradient.
expr : tvm.relay.Expr
The output expression.
"""
return _ir_pass.first_order_gradient(expr, mod)
6 changes: 5 additions & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
Expand Down Expand Up @@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
Expr Get(const Expr& body) const {
Expr Get(const Expr& body) {
CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
used_ = true;
return ret;
}

Expand Down Expand Up @@ -108,6 +111,7 @@ class LetList {

private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
MarisaKirisame marked this conversation as resolved.
Show resolved Hide resolved
};

} // namespace relay
Expand Down
Loading