Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
save

do it
  • Loading branch information
MarisaKirisame committed Jan 18, 2019
1 parent b374192 commit 8d4698e
Show file tree
Hide file tree
Showing 6 changed files with 577 additions and 11 deletions.
20 changes: 20 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,26 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};

/*! \brief turn a dataflow graph into A Normal Form.
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A Normal Form).
*
* The scope of the root expression is the global scope.
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);

} // namespace relay
} // namespace tvm

Expand Down
42 changes: 34 additions & 8 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
----------
expr : tvm.relay.Expr
The input expression.
fvisit : function
The visitor function to be applied.
"""
Expand All @@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
checked_expr : tvm.relay.Expr
Expand Down Expand Up @@ -112,11 +112,11 @@ def check_kind(t, mod=None):
Parameters
----------
t: tvm.relay.Type
t : tvm.relay.Type
The type to check
mod: tvm.relay.Module, optional
The global module
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
Expand Down Expand Up @@ -480,20 +480,46 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)


def to_anf(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_anf(expr, mod)


def gradient(expr, mod=None):
""".
"""
Transform a function to return original result paired with gradient of input.
Parameters
----------
expr : tvm.relay.Expr
The input expression, which is a Function or a GlobalVar.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
ret : tvm.relay.Expr
A function that calculate the original result paired with gradient.
expr : tvm.relay.Expr
The output expression.
"""
return _ir_pass.first_order_gradient(expr, mod)
6 changes: 5 additions & 1 deletion src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
Expand Down Expand Up @@ -71,11 +72,13 @@ class LetList {
*
* \return the wrapped expr.
*/
Expr Get(const Expr& body) const {
Expr Get(const Expr& body) {
CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
}
used_ = true;
return ret;
}

Expand Down Expand Up @@ -108,6 +111,7 @@ class LetList {

private:
std::vector<std::pair<Var, Expr> > lets_;
bool used_ = false;
};

} // namespace relay
Expand Down
Loading

0 comments on commit 8d4698e

Please sign in to comment.