Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Unifier hotfix #2437

Merged
merged 55 commits into from
Jan 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
862825a
Expand unification in type solver
slyubomirsky Nov 29, 2018
c60e4c8
Only register child type relations for incomplete child types
slyubomirsky Nov 29, 2018
318d85f
Removed redundant registrations
slyubomirsky Nov 29, 2018
5155426
Be sure to copy linked list nodes for child types
slyubomirsky Nov 30, 2018
94c5d97
Add unifier tests
slyubomirsky Nov 30, 2018
2bad1af
Add a negative test case
slyubomirsky Nov 30, 2018
0175c53
Check for recursive equalities when unifying
slyubomirsky Nov 30, 2018
6f5db7c
Minor tweaks to error messages
slyubomirsky Nov 30, 2018
16de70d
Improve recursive unification test
slyubomirsky Nov 30, 2018
8829afe
Do not copy relation list nodes if already resolved
slyubomirsky Nov 30, 2018
5027d76
Add visitor for type resolution, have more complicated unification te…
slyubomirsky Nov 30, 2018
ce59f5b
Avoid catastrophic failure in Resolve() by only recursing in one bran…
slyubomirsky Nov 30, 2018
4d1bde6
Add a null check before resolution
slyubomirsky Nov 30, 2018
c44bb73
Move null check into resolution visitor
slyubomirsky Nov 30, 2018
2b18e21
Rename RecurrenceChecker to OccursChecker
slyubomirsky Dec 2, 2018
beb2bb5
Recursively propagate type relations to child types when a constraint…
slyubomirsky Dec 3, 2018
36ae3f8
Make use of new unification in type inference, add generalization, fi…
slyubomirsky Dec 7, 2018
5f35576
Intermediate progress on fixing the unifier (this breaks stuff)
slyubomirsky Dec 7, 2018
db247ad
Generalize only after all other type unification (still broken)
slyubomirsky Dec 7, 2018
599d4a6
Correct error in type var instantiation
slyubomirsky Dec 8, 2018
600d024
Do not permit free type variables in tests
slyubomirsky Dec 10, 2018
face113
Remove commented-out function inference code
slyubomirsky Dec 10, 2018
aa9c837
Instantiate generalizer once in type inference
slyubomirsky Dec 10, 2018
59eabc7
Fix error in type inference cpp test
slyubomirsky Dec 10, 2018
32240e0
Use alpha equality in cpp test
slyubomirsky Dec 10, 2018
a7e6fa6
Use free var pass for collecting type params in generalization. Fix d…
slyubomirsky Dec 11, 2018
299574d
Handle recursion in let, generalize early in that case
slyubomirsky Dec 11, 2018
0dcd205
Fix FreeVar pass doc
slyubomirsky Dec 11, 2018
ebf9688
Refactoring of free var visitors to ensure fixed order (mostly @jroes…
slyubomirsky Dec 12, 2018
6b15cf7
Add tests for various variable collection passes, add Python handles …
slyubomirsky Dec 12, 2018
6ae595a
Unifier should not keep additional state, ensure that type params *ma…
slyubomirsky Dec 17, 2018
13ecb53
Attempt at keeping type env in inferencer (broken)
slyubomirsky Dec 20, 2018
6cb0ba2
Revert "Attempt at keeping type env in inferencer (broken)"
slyubomirsky Dec 21, 2018
c2088f3
Revert "Unifier should not keep additional state, ensure that type pa…
slyubomirsky Dec 21, 2018
8ac6570
Remove generalization and tests related to it for now; it needs to be…
slyubomirsky Dec 21, 2018
b4dec40
Rename method Bounded() to MarkBounded()
slyubomirsky Dec 22, 2018
b00eb4f
Require all type args to be specified or none (however, type arg test…
slyubomirsky Dec 22, 2018
1eb81e4
Whitespace
slyubomirsky Dec 22, 2018
11bda1c
Remove redundant func type creation in instantiation in TypeSolver
slyubomirsky Dec 22, 2018
02824df
Style change (variable rename)
slyubomirsky Dec 22, 2018
f7c14e8
Instantiate type vars in type_infer instead of unifier, do not expose…
slyubomirsky Dec 22, 2018
19aa3c2
Whitespace fixes and redundant check
slyubomirsky Dec 22, 2018
b697a59
Simplify function literal type inference case
slyubomirsky Dec 22, 2018
fa4f548
Style nitpick
slyubomirsky Dec 24, 2018
af32ff4
Don't drop type params in unifier
slyubomirsky Dec 24, 2018
755611a
Clean up and better document type var instantiation hack
slyubomirsky Dec 24, 2018
a4f5c09
MergeFromTo gathers rel links recursively
slyubomirsky Dec 28, 2018
1ee69cc
Copy links over to avoid circular links
slyubomirsky Dec 28, 2018
8cf9962
Use set for storing rels, propagate after merging typenodes
slyubomirsky Dec 28, 2018
0f9d580
Propagator should be able to propagate multiple relations at once
slyubomirsky Dec 28, 2018
eb2754d
Correct description of AllVars() utility
slyubomirsky Jan 14, 2019
02e1f1c
Ensure null annotations replaced in AD
slyubomirsky Jan 15, 2019
96e4ebb
Unnecessary check
slyubomirsky Jan 15, 2019
7c1a42f
Leave grad ret type to be inferred iannotations gone
slyubomirsky Jan 16, 2019
c02375c
lint
slyubomirsky Jan 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*/
bool WellFormed(const Expr& expr);

/*! \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> BoundVars(const Expr& expr);

/*! \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
Expand All @@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr);
*/
tvm::Array<Var> FreeVars(const Expr& expr);

/*! \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> AllVars(const Expr& expr);

/*! \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function
Expand All @@ -130,6 +149,55 @@ tvm::Array<Var> FreeVars(const Expr& expr);
*/
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);

/*! \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param t the type.
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> FreeTypeVars(const Type& t);

/*! \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);

/*! \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
*
* \param t the type
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> BoundTypeVars(const Type& t);

/*! \brief Get all type variables in expression expr.
*
* \param expr the expression.
*
* \return List of type vars, in the PostDFS order in the expression.
*/
tvm::Array<TypeVar> AllTypeVars(const Expr& expr);

/*! \brief Get all type variables in type t.
*
* \param t the type.
*
* \return List of type vars, in the PostDFS order visited by type.
*/
tvm::Array<TypeVar> AllTypeVars(const Type& t);

/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced, and branches that will
Expand Down
68 changes: 66 additions & 2 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,38 @@ def free_vars(expr):
return _ir_pass.free_vars(expr)


def bound_vars(expr):
"""Get bound vars from expression expr in post-DFS order.

Parameters
----------
expr: tvm.relay.Expr
The input expression

Returns
-------
free : List[tvm.relay.Var]
The list of bound variables in post-DFS order.
"""
return _ir_pass.bound_vars(expr)


def all_vars(expr):
"""Get all vars from expression expr in post-DFS order.

Parameters
----------
expr: tvm.relay.Expr
The input expression

Returns
-------
free : List[tvm.relay.Var]
The list of all variables in post-DFS order.
"""
return _ir_pass.all_vars(expr)


def free_type_vars(expr):
"""Get free type variables from expression/type e

Expand All @@ -168,12 +200,44 @@ def free_type_vars(expr):

Returns
-------
free : List[tvm.relay.TypeParam]
The list of free type variables
free : List[tvm.relay.TypeVar]
The list of free type variables in post-DFS order
"""
return _ir_pass.free_type_vars(expr)


def bound_type_vars(expr):
"""Get bound type variables from expression/type e

Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type

Returns
-------
free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order
"""
return _ir_pass.bound_type_vars(expr)


def all_type_vars(expr):
"""Get all type variables from expression/type e

Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type

Returns
-------
free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order
"""
return _ir_pass.all_type_vars(expr)


def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase.

Expand Down
19 changes: 15 additions & 4 deletions src/relay/pass/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
});
return Pair(res.foward, grad);
});

// if type annotations are provided, we will construct a ret type;
// otherwise, leave it to be inferred
Type ret_type = Type();
std::vector<Type> vt;
bool missing = !f->ret_type.defined();
for (const auto& p : f->params) {
if (missing || !p->type_annotation.defined()) {
missing = true;
break;
}
vt.push_back(p->type_annotation);
}
return FunctionNode::make(f->params,
body,
TupleTypeNode::make({f->ret_type, TupleTypeNode::make({})}),
{});

if (!missing) {
ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}

return FunctionNode::make(f->params, body, ret_type, {});
}

TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
Expand Down
Loading