Skip to content

Commit

Permalink
[Relay][Testing] Relay-to-Python compilation (#3156)
Browse files Browse the repository at this point in the history
* First pass at Relay-to-Python converter testing utility

* Indicate astor as a dependency

* Add astor dep to host as well

* Typos and small bugs

* Handle ADTs and matching in Python conversion

* Remove any dependency on ast.parse

* Eliminate unnecessary type var field in Python version of ConstructorValue (already gone on C++ side)

* Update constructor value, fix syntax errors

* Don't forget keywords arg on Call nodes

* Fix some incorrect calls to ast nodes

* Fix more calls, a little more cleaning up

* Missing cases in attr conversion

* Lower op calls instead of running them through interpreter, as in @MarisaKirisame's AoT compiler

* We do still need the module

* Remove changes to op attrs: Will PR separately

* Smoke test and corrections

* More tests and fixes

* Ensure imports are properly global in generated Python code

* Add unit tests for refs

* Add unit test for tuple indexing

* Add unit test for if expression

* Remove astor dependency

* Remove astor from meta.yaml too

* Fix if test and add basic local function test

* Add global function test, refactor earlier tests

* Correct 'clause' field in ADT so Python and C++ field names match

* More fixes and tests for matching and constructors

* Dramatically simplify matching: no need for a thunk

* Improve ref writing test

* Ensure local recursion works

* cleanup

* Add test for global recursion

* Add test for higher-order calls

* Get ops working, add basic tests

* Remove accidentally duplicated test

* More docstrings to appease pylint

* Forgot to fix a test using constructor values

* Reduce optimization level in fusion and fix tuple input to operators

* Test op with tuple output, fix tuple output code

* Add unit test for batch norm

* Add a couple more tricky test cases

* Correct nat constructor to drop unnecessary field

* Fix the op attrs file (accidentally reduced it)

* Address review comments

* Adapt to new ConstructorValue representation (no more runtime dep on module)

* Use pass manager and updated interfaces. Extend module.from_expr to accommodate necessary demands

* Use sequential return value

* Lift out nested conditionals

* Replace triple single quotes with triple double quotes

* Use main variable instead of entry_func
  • Loading branch information
slyubomirsky authored and jroesch committed Jul 10, 2019
1 parent 93d1c06 commit db841c2
Show file tree
Hide file tree
Showing 12 changed files with 1,197 additions and 20 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class MatchNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("clause", &clauses);
v->Visit("clauses", &clauses);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
8 changes: 5 additions & 3 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,19 @@ class ModuleNode : public RelayNode {

/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map as
* well.
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});
const tvm::Map<GlobalVar, Function>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});

static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ class Closure(Value):

@register_relay_node
class ConstructorValue(Value):
def __init__(self, tag, fields, constructor, types):
def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor, types)
_make.ConstructorValue, tag, fields, constructor)


@register_relay_node
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def visit_tuple_getitem(self, t):

def visit_match(self, m):
self.visit(m.data)
for c in m.clause:
for c in m.clauses:
self.visit(c.rhs)


Expand Down
25 changes: 23 additions & 2 deletions python/tvm/relay/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,26 @@ def get_constructor(self, tag):
return _module.Module_LookupTag(self, tag)

@staticmethod
def from_expr(expr):
return _module.Module_FromExpr(expr)
def from_expr(expr, functions=None, type_defs=None):
"""Construct a module from a standalone expression.
Parameters
----------
expr: Expr
The starting expression
global_funcs: Optional[dict]
Map of global vars to function definitions
type_defs: Optional[dict]
Map of global type vars to type definitions
Returns
-------
mod: Module
A module containing the passed definitions,
where expr is set as the entry point
(wrapped in a function if necessary)
"""
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs)
1 change: 1 addition & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .config import ctx_list
from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
from .py_converter import to_python, run_as_python


def run_opt_pass(expr, opt_pass):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/testing/nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def make_nat_value(prelude, n):
constructs a ConstructorValue representing that value as a nat.
"""
if n == 0:
return ConstructorValue(prelude.z.tag, [], None, [])
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
return ConstructorValue(prelude.z.tag, [], None)
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None)


def make_nat_expr(prelude, n):
Expand Down
Loading

0 comments on commit db841c2

Please sign in to comment.