diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h
index 2a6507b62a33..d46387eea822 100644
--- a/include/tvm/relay/adt.h
+++ b/include/tvm/relay/adt.h
@@ -243,7 +243,7 @@ class MatchNode : public ExprNode {
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("data", &data);
-    v->Visit("clause", &clauses);
+    v->Visit("clauses", &clauses);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
   }
diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h
index e888c54c17ac..0b582514f5c3 100644
--- a/include/tvm/relay/module.h
+++ b/include/tvm/relay/module.h
@@ -180,17 +180,19 @@ class ModuleNode : public RelayNode {
 
   /*! \brief Construct a module from a standalone expression.
    *
-   * Allows one to optionally pass a global function map as
-   * well.
+   * Allows one to optionally pass a global function map and
+   * map of type definitions as well.
    *
    * \param expr The expression to set as the main function to the module.
    * \param global_funcs The global function map.
+   * \param type_definitions Map of global type definitions.
    *
    * \returns A module with expr set as the main function.
    */
  TVM_DLL static Module FromExpr(
      const Expr& expr,
-     const tvm::Map<GlobalVar, Function>& global_funcs = {});
+     const tvm::Map<GlobalVar, Function>& global_funcs = {},
+     const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
 
  static constexpr const char* _type_key = "relay.Module";
  TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py
index 462dda9488c2..64024d6a8b11 100644
--- a/python/tvm/relay/backend/interpreter.py
+++ b/python/tvm/relay/backend/interpreter.py
@@ -74,9 +74,9 @@ class Closure(Value):
 
 @register_relay_node
 class ConstructorValue(Value):
-    def __init__(self, tag, fields, constructor, types):
+    def __init__(self, tag, fields, constructor):
         self.__init_handle_by_constructor__(
-            _make.ConstructorValue, tag, fields, constructor, types)
+            _make.ConstructorValue, tag, fields, constructor)
 
 
 @register_relay_node
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index 9ca094158c1c..e814f87083bf 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -183,7 +183,7 @@ def visit_tuple_getitem(self, t):
 
     def visit_match(self, m):
         self.visit(m.data)
-        for c in m.clause:
+        for c in m.clauses:
             self.visit(c.rhs)
diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py
index 8ac15f743fc4..e0511a257e6d 100644
--- a/python/tvm/relay/module.py
+++ b/python/tvm/relay/module.py
@@ -179,5 +179,26 @@ def get_constructor(self, tag):
         return _module.Module_LookupTag(self, tag)
 
     @staticmethod
-    def from_expr(expr):
-        return _module.Module_FromExpr(expr)
+    def from_expr(expr, functions=None, type_defs=None):
+        """Construct a module from a standalone expression.
+
+        Parameters
+        ----------
+        expr: Expr
+            The starting expression.
+        functions: Optional[dict]
+            Map of global vars to function definitions.
+        type_defs: Optional[dict]
+            Map of global type vars to type definitions.
+
+        Returns
+        -------
+        mod: Module
+            A module containing the passed definitions,
+            where expr is set as the entry point
+            (wrapped in a function if necessary).
+        """
+        funcs = functions if functions is not None else {}
+        defs = type_defs if type_defs is not None else {}
+        return _module.Module_FromExpr(expr, funcs, defs)
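As a quick illustration of the extended API, a module can now be seeded with type definitions in the same call. This is a hedged sketch (not part of the patch) using the same dummy `box` ADT that the new tests define in `init_box_adt`:

    import tvm
    from tvm import relay

    # a one-field ADT, as in init_box_adt() in the new tests below
    box = relay.GlobalTypeVar('box')
    a = relay.TypeVar('a')
    box_ctor = relay.Constructor('box', [a], box)

    # from_expr now threads the type-definition map through to ModuleNode::FromExpr
    mod = relay.Module.from_expr(box_ctor(relay.const(1)),
                                 type_defs={box: relay.TypeData(box, [a], [box_ctor])})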
diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py
index de9e55b369d1..871250e9c9f3 100644
--- a/python/tvm/relay/testing/__init__.py
+++ b/python/tvm/relay/testing/__init__.py
@@ -35,6 +35,7 @@
 from .config import ctx_list
 from .init import create_workload
 from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
+from .py_converter import to_python, run_as_python
 
 def run_opt_pass(expr, opt_pass):
diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py
index a76a340f113d..eb71120610d3 100644
--- a/python/tvm/relay/testing/nat.py
+++ b/python/tvm/relay/testing/nat.py
@@ -168,8 +168,8 @@ def make_nat_value(prelude, n):
     constructs a ConstructorValue representing that value as a nat.
     """
     if n == 0:
-        return ConstructorValue(prelude.z.tag, [], None, [])
-    return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
+        return ConstructorValue(prelude.z.tag, [], None)
+    return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None)
 
 
 def make_nat_expr(prelude, n):
diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py
new file mode 100644
index 000000000000..c003fe788a11
--- /dev/null
+++ b/python/tvm/relay/testing/py_converter.py
@@ -0,0 +1,592 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utility for converting Relay code into a Python script with equivalent semantics"""
+import ast
+from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str
+import re
+
+import tvm
+from tvm import relay
+from tvm.relay.adt import Pattern
+from tvm.relay.backend import compile_engine
+from tvm.relay.expr import Expr, Function, GlobalVar, Var
+from tvm.relay.expr_functor import ExprFunctor
+
+OUTPUT_VAR_NAME = '_py_out'
+
+# corresponds to:
+#   import numpy
+#   import tvm
+#   from tvm import relay
+#   from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue
+PROLOGUE = [
+    ast.Import([alias('numpy', None)]),
+    ast.Import([alias('tvm', None)]),
+    ast.ImportFrom('tvm', [alias('relay', None)], 0),
+    ast.ImportFrom('tvm.relay.backend.interpreter',
+                   [alias('RefValue', None),
+                    alias('TupleValue', None),
+                    alias('TensorValue', None),
+                    alias('ConstructorValue', None)],
+                   0)
+]
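The generated script is assembled exactly this way: a list of statement nodes wrapped in an ast.Module, with source locations filled in before compilation. A minimal sketch of the mechanism (assuming Python 3.8+, where ast.Module requires type_ignores and ast.Constant replaces the older Num/Str/NameConstant helpers that this file itself uses):

    import ast

    # one statement: _py_out = 1
    body = [ast.Assign([ast.Name('_py_out', ast.Store())], ast.Constant(1))]
    mod_ast = ast.fix_missing_locations(ast.Module(body=body, type_ignores=[]))
    env = {}
    exec(compile(mod_ast, '<string>', 'exec'), env, env)
    assert env['_py_out'] == 1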
+class PythonConverter(ExprFunctor):
+    """Functor for translating Relay programs into Python ASTs."""
+
+    def __init__(self, mod, target) -> None:
+        super().__init__()
+        self.mod = mod
+        self.tgt = target
+        self.engine = compile_engine.get()
+        self.fun_no = 0
+        self.var_no = 0
+        self.var_map = {}
+
+
+    def convert(self, prog: Expr):
+        """This method converts the passed Relay expression into a Python
+        AST object with equivalent semantics.
+
+        The Python AST can be executed using exec(); it can be turned
+        into text and inspected using astor.
+        """
+        optimized = self.optimize(prog)
+
+        # start with conversion prelude (imports) and convert global defs
+        body = []
+        body += PROLOGUE
+        body += self.convert_module()
+
+        prog_body, extra_defs = self.visit(optimized)
+        body += extra_defs
+
+        # we finally must assign the final expression to the output var
+        # so it can be read after running exec()
+        body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body))
+
+        return ast.fix_missing_locations(ast.Module(body=body))
+
+
+    def optimize(self, prog: Expr):
+        """Performs optimizations necessary to be able to generate code for prog."""
+        # unwrap tuple wrappers (some op calls produce them)
+        unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
+        assert relay.analysis.well_formed(unwrapped)
+        mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)
+
+        # necessary passes: SimplifyInference (otherwise we can't generate code
+        # for some operators) and fusion (to get primitive functions)
+        opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
+                                           relay.transform.FuseOps(fuse_opt_level=0)])
+        mod = opts(mod)
+        optimized = mod['main']
+        return optimized if isinstance(unwrapped, Function) else optimized.body
+
+
+    def sanitize(self, name: str) -> str:
+        """Removes any invalid characters (only underscores, numbers, and letters
+        permitted) from the given name. Since we append a number and underscore to
+        var names anyway, it doesn't matter if the name is the empty string."""
+        return re.sub(r'\W', '', name)
+
+
+    def generate_var_name(self, name_hint: str) -> str:
+        """Generates a unique variable name starting from the hint."""
+        name = '{}_var_{}'.format(self.sanitize(name_hint), self.var_no)
+        self.var_no += 1
+        return name
+
+
+    def generate_function_name(self, name_hint: str) -> str:
+        """Generates a unique function name starting from the hint."""
+        name = '{}_fun_{}'.format(self.sanitize(name_hint), self.fun_no)
+        self.fun_no += 1
+        return name
+
+
+    def get_var_name(self, var: Expr) -> str:
+        """Returns the var name for the given Relay variable."""
+        if var in self.var_map:
+            return self.var_map[var]
+        name = self.generate_var_name(var.name_hint)
+        self.var_map[var] = name
+        return name
+
+
+    def include_var(self, var: Expr, assign=False):
+        """Returns a variable AST node for the given Relay var, depending on
+        whether it must appear in an assignment or not."""
+        name = self.get_var_name(var)
+        return Name(name, Store() if assign else Load())
+
+
+    def parse_name(self, name: str):
+        """Given the name of a Python method with dots (e.g., 'relay.var'),
+        returns an appropriate AST object corresponding to that name."""
+        attributes = name.split('.')
+        ret = Name(attributes[0], Load())
+        for i in range(len(attributes) - 1):
+            ret = ast.Attribute(ret, attributes[i+1], Load())
+        return ret
+
+
+    def parse_numpy_array(self, arr):
+        """Given a Numpy array, produces an appropriate Python array
+        or numerical literal representing its contents."""
+        parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i)
+        if arr.ndim == 0:
+            return parse_single(arr.item())
+        if arr.ndim == 1:
+            return ast.List([parse_single(i.item()) for i in arr], Load())
+
+        elts = []
+        for row in arr:
+            elts.append(self.parse_numpy_array(row))
+        return ast.List(elts, Load())
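For instance, parse_numpy_array turns a 2D array into a nested list literal. A hedged example, runnable once this patch is applied:

    import ast
    import numpy as np
    from tvm import relay
    from tvm.relay.testing.py_converter import PythonConverter

    conv = PythonConverter(relay.Module(), 'llvm')
    arr_ast = conv.parse_numpy_array(np.array([[1, 2], [3, 4]]))
    # a nested ast.List of ast.Num literals, i.e. [[1, 2], [3, 4]]
    print(ast.dump(arr_ast))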
+    def convert_fields(self, fields: [Expr]):
+        """Given a list of call args or tuple fields, converts
+        each and returns their ASTs and their defs lists (in order)."""
+        bodies = []
+        defs = []
+        for field in fields:
+            member_body, member_defs = self.visit(field)
+            bodies.append(member_body)
+            defs += member_defs
+        return (bodies, defs)
+
+
+    def convert_to_thunk(self, name_hint: str, expr: Expr):
+        """Wraps the passed expression in a thunk."""
+        body, defs = self.visit(expr)
+        thunk_name = self.generate_function_name(name_hint)
+        thunk = self.create_def(thunk_name, [], defs + [Return(body)])
+        return (thunk, thunk_name)
+
+
+    def convert_func_node(self, func: Function, name_var=None):
+        """Converts the given Relay function into a Python function, with
+        special handling for named functions (locally or globally)."""
+        if name_var is None:
+            func_name = self.generate_function_name('_anon_func')
+        if isinstance(name_var, GlobalVar):
+            func_name = name_var.name_hint
+        if isinstance(name_var, Var):
+            func_name = self.get_var_name(name_var)
+
+        var_names = [self.get_var_name(var) for var in func.params]
+        body, defs = self.visit(func.body)
+        ret = self.create_def(func_name, var_names, defs + [Return(body)])
+        return (ret, func_name)
+
+
+    def convert_module(self):
+        """Converts all the global functions defined in the module and returns
+        them as a list of definitions."""
+        defs = []
+        for var, func in self.mod.functions.items():
+            # optimize the definition so any operators used are lowered
+            opt_func = self.optimize(func)
+            converted_func, _ = self.convert_func_node(opt_func, var)
+            defs.append(converted_func)
+        return defs
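Note that global functions keep their (sanitized) names, while local vars get unique numeric suffixes that are then memoized per Var. A hedged illustration of the naming helpers above:

    from tvm import relay
    from tvm.relay.testing.py_converter import PythonConverter

    conv = PythonConverter(relay.Module(), 'llvm')
    v = relay.Var('my-var!')                       # sanitize() strips the punctuation
    assert conv.get_var_name(v) == 'myvar_var_0'   # numbered once...
    assert conv.get_var_name(v) == 'myvar_var_0'   # ...then memoized in var_map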
+    def create_call(self, func_name: str, arguments):
+        """Creates a simple function call."""
+        return ast.Call(self.parse_name(func_name), arguments, [])
+
+
+    def create_def(self, func_name: str, arguments: [str], body):
+        """Wrapper over function definition AST node, whose constructor is inconvenient."""
+        return ast.FunctionDef(
+            func_name,
+            ast.arguments([ast.arg(argument, None)
+                           for argument in arguments],
+                          None, [], [], None, []),
+            body, [], None)
+
+
+    def create_op_call(self, op: Function, relay_args, py_args):
+        """Lowers the passed primitive function, registers it in TVM's
+        global compiler, and produces a call to the lowered function in
+        the generated Python code."""
+
+        # compile the function and register globally
+        cc_key = compile_engine.CCacheKey(op, self.tgt)
+        func_hash = relay.analysis.structural_hash(op)
+        op_name = '_lowered_op_{}'.format(func_hash)
+        if not tvm.get_global_func(op_name, allow_missing=True):
+            jitted = self.engine.jit(cc_key, self.tgt)
+            tvm.register_func(op_name, jitted)
+
+        def convert_input(py_input, arg_type):
+            """Use the types of the function arguments to determine whether we expect
+            a tensor or tuple (returns list of inputs to the lowered op call)."""
+            # equivalent: input.data
+            if isinstance(arg_type, relay.TensorType):
+                return [ast.Attribute(py_input, 'data', Load())]
+            assert isinstance(arg_type, relay.TupleType)
+            # convert each input.fields[i]
+            ret = []
+            for i in range(len(arg_type.fields)):
+                ret += convert_input(
+                    ast.Subscript(
+                        ast.Attribute(py_input, 'fields', Load()),
+                        ast.Index(Num(i)), Load()),
+                    arg_type.fields[i])
+            return ret
+
+        def convert_output(ret_type):
+            """Use the function return type to produce auxiliary variables to store outputs.
+            Returns ([assignments of output vars], [extra arguments to pass to op call],
+            expression collecting output)."""
+            if isinstance(ret_type, relay.TensorType):
+                output_var_name = self.generate_var_name('_out')
+                output_var = Name(output_var_name, Load())
+                shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
+                # create a new TensorValue of the right shape and dtype
+                assign_output = Assign(
+                    [Name(output_var_name, Store())],
+                    self.create_call('TensorValue', [
+                        self.create_call('numpy.empty', [shape, Str(ret_type.dtype)])
+                    ]))
+                # we pass the data field as an argument
+                extra_arg = ast.Attribute(output_var, 'data', Load())
+                return ([assign_output], [extra_arg], output_var)
+            assert isinstance(ret_type, relay.TupleType)
+            assignments = []
+            extra_args = []
+            fields = []
+            for t in ret_type.fields:
+                inner_assignments, inner_args, inner_output = convert_output(t)
+                assignments += inner_assignments
+                extra_args += inner_args
+                fields.append(inner_output)
+            return (assignments, extra_args, self.create_call('TupleValue', fields))
+
+        # create a function to wrap the call of the lowered op and return
+        # a call to that function
+        wrap_name = self.generate_function_name('_{}_wrapper'.format(op_name))
+        wrap_args = [self.generate_var_name('_arg_{}'.format(i)) for i in range(len(py_args))]
+
+        inner_call_args = []
+        for i in range(len(py_args)):
+            inner_call_args += convert_input(Name(wrap_args[i], Load()),
+                                             relay_args[i].checked_type)
+        output_assignments, aux_args, output = convert_output(op.checked_type.ret_type)
+        # equiv: _op = tvm.get_global_func(op_name)
+        op_var = self.generate_var_name('_op')
+        op_call = self.create_call('tvm.get_global_func', [Str(op_name)])
+        op_assign = Assign([Name(op_var, Store())], op_call)
+        # equiv: _op(args)
+        inner_call = self.create_call(op_var, inner_call_args + aux_args)
+        body = output_assignments + [op_assign, ast.Expr(inner_call), Return(output)]
+        wrap_def = self.create_def(wrap_name, wrap_args, body)
+        return wrap_def, self.create_call(wrap_name, py_args)
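Illustratively, for a fused scalar add the wrapper emitted by create_op_call has roughly the following shape; the lowered op is called in destination-passing style, with the preallocated output's data field passed as an extra argument. Names and the NNN hash here are invented for the example (the real ones come from generate_var_name and structural_hash):

    def __lowered_op_NNN_wrapper_fun_0(_arg_0_var_1, _arg_1_var_2):
        _out_var_3 = TensorValue(numpy.empty((), 'int32'))
        _op_var_4 = tvm.get_global_func('_lowered_op_NNN')
        _op_var_4(_arg_0_var_1.data, _arg_1_var_2.data, _out_var_3.data)
        return _out_var_3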
+    def create_match_check(self, pattern: Pattern, data):
+        """Given an ADT match pattern and a (Python) expression pointing to
+        an ADT value, this generates a Python expression that checks if the
+        ADT value matches the given pattern (returning True or False)."""
+
+        # wildcard or var patterns match everything
+        if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
+            return NameConstant(True)
+
+        # constructor patterns check whether the constructors match
+        # and also the matches of any nested patterns
+
+        # equiv: (arg.tag == pattern_constructor.tag)
+        conds = [ast.Compare(ast.Attribute(data, 'tag', Load()),
+                             [ast.Eq()],
+                             [ast.Num(pattern.constructor.tag)])]
+
+        # now check for any nested patterns
+        for i in range(len(pattern.patterns)):
+            nested_pat = pattern.patterns[i]
+            # can safely skip var or wildcard patterns: they will
+            # never cause a check to fail
+            if not isinstance(nested_pat, relay.PatternConstructor):
+                continue
+
+            # index into the value corresponding to the subpattern
+            field_index = ast.Subscript(ast.Attribute(data, 'fields', Load()),
+                                        ast.Index(Num(i)), Load())
+            conds.append(self.create_match_check(nested_pat, field_index))
+
+        # if we do not need to check a nested pattern, just return the single check
+        if len(conds) == 1:
+            return conds[0]
+        # otherwise AND together any nested checks
+        return ast.BoolOp(ast.And(), conds)
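For example, matching the single-element-list pattern cons(_, nil) against a value v produces a check roughly like the one below. The numeric literals stand in for the constructors' tags (this sketch is for readability; the generated AST contains only numbers):

    # for PatternConstructor(cons, [PatternWildcard(), PatternConstructor(nil, [])]):
    v.tag == 4 and v.fields[1].tag == 3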
+    def create_match_clause_body(self, pattern: Pattern, body: Expr):
+        """Given a match clause pattern and a clause body,
+        generates a Python function that, when called with an ADT
+        that matches the pattern, returns the result of evaluating
+        the clause body. This function returns a function definition
+        and the name of the generated function."""
+
+        def collect_var_assignments(pat, val):
+            """This helper function ensures that the pattern is used to
+            properly assign all subfields of the given AST for use
+            in the clause body.
+
+            E.g., for PatternConstructor(A, PatternVar(v), PatternWildcard(),
+            PatternConstructor(B, PatternVar(w))),
+            we would want to have
+                v = a.fields[0]
+                w = a.fields[2].fields[0]
+            """
+            if isinstance(pat, relay.PatternWildcard):
+                return []
+            if isinstance(pat, relay.PatternVar):
+                return [Assign([self.include_var(pat.var, assign=True)], val)]
+            # constructor pattern: assign each field of the value
+            # based on subpatterns
+            assignments = []
+            for i in range(len(pat.patterns)):
+                # we want the assignments for val.fields[i]
+                field = ast.Subscript(ast.Attribute(val, 'fields', Load()),
+                                      ast.Index(Num(i)), Load())
+                assignments += collect_var_assignments(pat.patterns[i], field)
+            return assignments
+
+        func_name = self.generate_function_name('_match_clause_body')
+        arg_name = self.generate_var_name('_match_clause_body')
+
+        clause_body, defs = self.visit(body)
+        assignments = collect_var_assignments(pattern, Name(arg_name, Load()))
+
+        func_def = self.create_def(func_name, [arg_name],
+                                   defs + assignments + [Return(clause_body)])
+        return (func_def, func_name)
+
+
+    # Convention for the expr visitor: Each visit function returns a tuple of two members.
+    #
+    # The first is a Python AST comprised of a single *expression* that evaluates to an
+    # equivalent result to the desired Relay expression (and executes all effects in the
+    # right order).
+    #
+    # The second is a list of function definition *statements* defining thunks and other
+    # auxiliary functions needed in the translated AST object. The defs in the second object
+    # will always have unique names and will never perform any effects, so as long as they
+    # appear in the Python program before the first statement is executed, there should not
+    # be any problems.
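A hedged demonstration of this convention (the visitor can be invoked directly once the patch is applied):

    from tvm import relay
    from tvm.relay.testing.py_converter import PythonConverter

    conv = PythonConverter(relay.Module(), 'llvm')
    expr_ast, defs = conv.visit(relay.Function([], relay.Tuple([])))
    # expr_ast is a bare ast.Name; defs holds the single FunctionDef it refers to
    assert len(defs) == 1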
+    def visit_var(self, var: Expr):
+        return (self.include_var(var, assign=False), [])
+
+
+    def visit_global_var(self, gvar: Expr):
+        # we don't need to add numbers to global var names because
+        # the *names* are checked for uniqueness in the mod
+        return (Name(gvar.name_hint, Load()), [])
+
+
+    def visit_let(self, letexp: Expr):
+        # To properly account for scoping and ensure that the entire node produces
+        # an expression, we translate the let binding as a function that we call
+        # with the value we intend to bind. Yes, this is somewhat ugly.
+        """
+        let var = value in body
+        =======================
+        def let_thunk(var):
+            return body
+        let_thunk(value)
+        """
+        bind_body, bind_defs = self.visit(letexp.body)
+
+        func_name = self.generate_function_name('_let_func')
+        binding_func = self.create_def(func_name, [self.get_var_name(letexp.var)],
+                                       bind_defs + [Return(bind_body)])
+
+        # we call the binding func with the intended value for the bound variable
+
+        # special case: if the value is a function literal, we must ensure it can be
+        # recursive by naming it after the var
+        if isinstance(letexp.value, Function):
+            value_def, value_name = self.convert_func_node(letexp.value, letexp.var)
+            return (self.create_call(func_name, [Name(value_name, Load())]),
+                    [value_def, binding_func])
+
+        value_body, value_defs = self.visit(letexp.value)
+        value_defs.append(binding_func)
+        binding_call = self.create_call(func_name, [value_body])
+        return (binding_call, value_defs)
+
+
+    def visit_tuple(self, tup: Expr):
+        fields, ret_defs = self.convert_fields(tup.fields)
+        return (self.create_call('TupleValue', fields), ret_defs)
+
+
+    def visit_tuple_getitem(self, tgi: Expr):
+        tup, tup_defs = self.visit(tgi.tuple_value)
+        ret = ast.Subscript(tup, ast.Index(Num(tgi.index)), Load())
+        return (ret, tup_defs)
+
+
+    def visit_if(self, if_block: Expr):
+        cond_body, cond_defs = self.visit(if_block.cond)
+        true_body, true_defs = self.visit(if_block.true_branch)
+        false_body, false_defs = self.visit(if_block.false_branch)
+
+        # need to get the value out of a TensorValue to check the condition
+        # equivalent to: val.asnumpy()
+        cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], [])
+        ret = ast.IfExp(cond_check, true_body, false_body)
+        return (ret, cond_defs + true_defs + false_defs)
+
+
+    def visit_constant(self, constant: Expr):
+        """Proceeds by converting the constant value to a numpy array
+        and converting it to the appropriate value in the generated
+        code (whether it be a Python scalar or a Numpy array)."""
+        value = constant.data.asnumpy()
+        const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
+                              [self.parse_numpy_array(value)],
+                              [ast.keyword('dtype', Str(constant.checked_type.dtype))])
+        return (self.create_call('TensorValue', [const_expr]), [])
+
+
+    def visit_function(self, func: Expr):
+        # Python's lambdas are very restrictive, so we do "name" inline functions
+        converted_func, func_name = self.convert_func_node(func)
+        return (Name(func_name, Load()), [converted_func])
+
+
+    def visit_call(self, call: Expr):
+        """For calls, we must distinguish between ordinary functions,
+        operators, and constructor calls."""
+        func = call.op
+        fields, field_defs = self.convert_fields(call.args)
+
+        if isinstance(func, relay.Op):
+            raise Exception('Operators should have been lowered and eliminated')
+
+        if isinstance(func, relay.Constructor):
+            # produce a constructor value
+            return (self.create_call('ConstructorValue',
+                                     [ast.Num(func.tag),
+                                      ast.List(fields, Load()),
+                                      NameConstant(None)]),
+                    field_defs)
+
+        # lowered operator: generate a call to a function that gets the PackedFunc
+        # from TVM's registry
+        if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
+            op_call_def, op_call = self.create_op_call(func, call.args, fields)
+            return (op_call, field_defs + [op_call_def])
+
+        # ordinary function
+        converted_func, defs = self.visit(func)
+        defs += field_defs
+        return (ast.Call(converted_func, fields, []), defs)
+
+
+    def visit_ref_create(self, ref: Expr):
+        val, defs = self.visit(ref.value)
+        return (self.create_call('RefValue', [val]), defs)
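These pieces compose: a let-bound value flows through the thunk call, and an if becomes a Python conditional expression. A small end-to-end check, mirroring the new tests below:

    from tvm import relay
    from tvm.relay.testing import run_as_python

    v = relay.Var('v')
    prog = relay.Let(v, relay.const(1),
                     relay.If(relay.const(True), v, relay.const(0)))
    assert run_as_python(prog).asnumpy() == 1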
+    def visit_ref_read(self, read: Expr):
+        ref, defs = self.visit(read.ref)
+        return (ast.Attribute(ref, 'value', Load()), defs)
+
+
+    def visit_ref_write(self, write: Expr):
+        """For writing refs, we wrap the update in a thunk
+        (returning an empty tuple to match Relay's semantics)
+        that we execute at the right time. This ensures such assignments
+        can be properly nested, since assignments are statements
+        in Python but expressions in Relay."""
+        ref, ref_defs = self.visit(write.ref)
+        val, val_defs = self.visit(write.value)
+        thunk_name = self.generate_function_name('_ref_write_thunk')
+        thunk = self.create_def(
+            thunk_name, [],
+            ref_defs + val_defs + [
+                Assign([ast.Attribute(ref, 'value', Store())], val),
+                Return(self.create_call('TupleValue', []))
+            ])
+        return (self.create_call(thunk_name, []), [thunk])
+
+
+    def visit_match(self, match: Expr):
+        """For matches, we wrap the entire expression in a thunk
+        because it is easiest to implement them using if statements.
+        For each clause, we generate a function that checks if the
+        pattern matches. If yes, we call a function that assigns
+        the variables appropriately and invokes the clause body."""
+        data, defs = self.visit(match.data)
+        data_var = self.generate_var_name('_match_data')
+
+        # must ensure the data expression is executed exactly once
+        thunk_body = [Assign([Name(data_var, Store())], data)]
+        for clause in match.clauses:
+            check_expr = self.create_match_check(clause.lhs, Name(data_var, Load()))
+            body_def, body_name = self.create_match_clause_body(clause.lhs, clause.rhs)
+            defs.append(body_def)
+
+            # equiv: if check(data): return body(data)
+            thunk_body.append(ast.If(
+                check_expr,
+                [Return(self.create_call(body_name, [Name(data_var, Load())]))],
+                []
+            ))
+
+        # finally if nothing matches we have a failed assert (should never happen)
+        thunk_body.append(ast.Assert(NameConstant(False), Str('Match was not exhaustive')))
+
+        thunk_name = self.generate_function_name('_match_thunk')
+        thunk_def = self.create_def(thunk_name, [], defs + thunk_body)
+        return (self.create_call(thunk_name, []), [thunk_def])
+
+
+    # these are both handled in the "call" case
+    def visit_constructor(self, _):
+        pass
+    def visit_op(self, _):
+        pass
+
+
+def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
+    """Converts the given Relay expression into a Python script (as a Python AST object).
+    For easiest debugging, import the astor package and use to_source()."""
+    mod = mod if mod is not None else relay.Module()
+    converter = PythonConverter(mod, target)
+    return converter.convert(expr)
+
+
+def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
+    """Converts the given Relay expression into a Python script and
+    executes it."""
+    mod = mod if mod is not None else relay.Module()
+    py_ast = to_python(expr, mod, target)
+    code = compile(py_ast, '<string>', 'exec')
+    var_map = {
+        OUTPUT_VAR_NAME: None
+    }
+    #pylint: disable=exec-used
+    exec(code, var_map, var_map)
+    return var_map[OUTPUT_VAR_NAME]
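Typical usage, following the debugging hint in the docstring (astor is an optional third-party package):

    import astor
    from tvm import relay
    from tvm.relay.testing import to_python, run_as_python

    print(astor.to_source(to_python(relay.const(1))))  # inspect the generated script
    print(run_as_python(relay.const(1)).asnumpy())     # 1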
diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc
index 0ad0a91efd21..6741f875838e 100644
--- a/src/relay/ir/module.cc
+++ b/src/relay/ir/module.cc
@@ -187,8 +187,9 @@ void ModuleNode::Update(const Module& mod) {
 
 Module ModuleNode::FromExpr(
     const Expr& expr,
-    const tvm::Map<GlobalVar, Function>& global_funcs) {
-  auto mod = ModuleNode::make(global_funcs, {});
+    const tvm::Map<GlobalVar, Function>& global_funcs,
+    const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
+  auto mod = ModuleNode::make(global_funcs, type_definitions);
   auto func_node = expr.as<FunctionNode>();
   Function func;
   if (func_node) {
@@ -266,9 +267,14 @@ TVM_REGISTER_API("relay._module.Module_LookupTag")
 });
 
 TVM_REGISTER_API("relay._module.Module_FromExpr")
-.set_body_typed([](Expr e) {
-  return ModuleNode::FromExpr(e);
-});
+.set_body_typed<
+  Module(Expr,
+         tvm::Map<GlobalVar, Function>,
+         tvm::Map<GlobalTypeVar, TypeData>)>([](Expr e,
+                                                tvm::Map<GlobalVar, Function> funcs,
+                                                tvm::Map<GlobalTypeVar, TypeData> type_defs) {
+  return ModuleNode::FromExpr(e, funcs, type_defs);
+});
 
 TVM_REGISTER_API("relay._module.Module_Update")
 .set_body_typed([](Module mod, Module from) {
diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py
index 390576f87c18..7be7c75dfe64 100644
--- a/tests/python/relay/test_adt.py
+++ b/tests/python/relay/test_adt.py
@@ -75,9 +75,9 @@ def count(e):
 # this is an example of creating the adt value in python side
 def make_nat(n):
     if n != 0:
-        return ConstructorValue(s, [make_nat(n - 1)], [])
+        return ConstructorValue(s, [make_nat(n - 1)])
     else:
-        return ConstructorValue(z, [], [])
+        return ConstructorValue(z, [])
 
 def make_nat_expr(n):
     assert n >= 0
diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py
index 0e5e981a5321..c1a19c4d9bb1 100644
--- a/tests/python/relay/test_backend_interpreter.py
+++ b/tests/python/relay/test_backend_interpreter.py
@@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple():
     prelude = relay.prelude.Prelude(mod)
     intrp = create_executor("debug", mod)
 
-    nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
+    nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
     cons_value = ConstructorValue(prelude.cons.tag, [
         TensorValue(np.random.rand(1, 10).astype('float32')),
         nil_value
-    ], prelude.cons, [relay.TensorType((1, 10), 'float32')])
+    ], prelude.cons)
 
     ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
     tuple_value = TupleValue(*[
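The interpreter change above drops the trailing list of type fields from ConstructorValue, so ADT values are now built from just a tag, the field values, and the Constructor (or None when no constructor object is at hand, as in make_nat_value). A hedged sketch of the updated construction pattern:

    import numpy as np
    nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
    one_elt_list = ConstructorValue(prelude.cons.tag,
                                    [TensorValue(np.ones((1,), 'float32')), nil_value],
                                    prelude.cons)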
diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py
new file mode 100644
index 000000000000..49a2219dcd04
--- /dev/null
+++ b/tests/python/relay/test_py_converter.py
@@ -0,0 +1,555 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.testing import to_python, run_as_python
+from tvm.relay.prelude import Prelude
+from tvm.relay.backend.interpreter import TensorValue, TupleValue, RefValue, ConstructorValue
+
+# helper: uses a dummy let binding to sequence a list
+# of expressions: expr1; expr2; expr3, etc.
+def seq(*exprs):
+    ret = exprs[0]
+    for expr in exprs[1:]:
+        ret = relay.Let(relay.var('_'), ret, expr)
+    return ret
+
+
+# creates a dummy ADT for testing
+def init_box_adt(mod):
+    box = relay.GlobalTypeVar('box')
+    a = relay.TypeVar('a')
+    box_ctor = relay.Constructor('box', [a], box)
+    mod[box] = relay.TypeData(box, [a], [box_ctor])
+    return (box, box_ctor)
+
+
+# assert that the candidate is a TensorValue with value val
+def assert_tensor_value(candidate, val):
+    assert isinstance(candidate, TensorValue)
+    assert np.array_equal(candidate.asnumpy(), np.array(val))
+
+
+# assert that the candidate is a TupleValue with the indicated number of fields
+def assert_tuple_value(candidate, fields):
+    assert isinstance(candidate, TupleValue)
+    assert len(candidate.fields) == fields
+
+
+# assert that the candidate is a ConstructorValue with the appropriate constructor
+# and number of fields
+def assert_constructor_value(candidate, constructor, fields):
+    assert isinstance(candidate, ConstructorValue)
+    assert candidate.tag == constructor.tag
+    assert len(candidate.fields) == fields
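The seq helper makes effect ordering testable. For instance, this hedged snippet (using the helpers above) writes to a ref and reads back the final value:

    r = relay.Var('r')
    prog = relay.Let(r, relay.RefCreate(relay.const(1)),
                     seq(relay.RefWrite(r, relay.const(2)),
                         relay.RefRead(r)))
    assert run_as_python(prog).asnumpy() == 2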
+def test_create_empty_tuple():
+    empty = relay.Tuple([])
+    tup_val = run_as_python(empty)
+    assert_tuple_value(tup_val, 0)
+
+
+def test_create_scalar():
+    scalar = relay.const(1)
+    tensor_val = run_as_python(scalar)
+    assert_tensor_value(tensor_val, 1)
+
+
+def test_create_tensor():
+    tensor = relay.const([[1, 1], [2, 2]])
+    tensor_val = run_as_python(tensor)
+    assert_tensor_value(tensor_val, [[1, 1], [2, 2]])
+
+
+def test_create_nested_tuple():
+    relay_tup = relay.Tuple([
+        relay.const(1), relay.const(2),
+        relay.Tuple([
+            relay.const(3),
+            relay.const(4)
+        ])
+    ])
+    tup_val = run_as_python(relay_tup)
+    assert_tuple_value(tup_val, 3)
+    for i in range(2):
+        assert_tensor_value(tup_val.fields[i], i + 1)
+    assert_tuple_value(tup_val.fields[2], 2)
+    for i in range(2):
+        assert_tensor_value(tup_val.fields[2].fields[i], i + 3)
+
+
+def test_tuple_get_item():
+    relay_tup = relay.Tuple([
+        relay.const(1), relay.const(2),
+        relay.Tuple([
+            relay.const(3),
+            relay.const(4)
+        ])
+    ])
+    for i in range(2):
+        index = relay.TupleGetItem(relay_tup, i)
+        val = run_as_python(index)
+        assert_tensor_value(val, i + 1)
+    # try the inner value too
+    for i in range(2):
+        index = relay.TupleGetItem(relay.TupleGetItem(relay_tup, 2), i)
+        val = run_as_python(index)
+        assert_tensor_value(val, i + 3)
+
+
+def test_create_let():
+    v = relay.Var('v')
+    let = relay.Let(v, relay.Tuple([]), relay.Tuple([v, v]))
+    tup_val = run_as_python(let)
+    assert_tuple_value(tup_val, 2)
+    assert_tuple_value(tup_val.fields[0], 0)
+    assert_tuple_value(tup_val.fields[1], 0)
+
+
+def test_create_ref():
+    relay_ref = relay.RefCreate(relay.Tuple([]))
+    ref_val = run_as_python(relay_ref)
+    assert isinstance(ref_val, RefValue)
+    assert_tuple_value(ref_val.value, 0)
+
+
+def test_ref_read():
+    v = relay.Var('v')
+    assign = relay.Let(v, relay.RefCreate(relay.Tuple([])), relay.RefRead(v))
+    read_val = run_as_python(assign)
+    assert_tuple_value(read_val, 0)
+
+
+def test_ref_write():
+    # check that the result of a ref write is an empty tuple
+    v = relay.Var('v')
+    initial_write = relay.Let(v, relay.RefCreate(relay.Tuple([relay.const(1)])),
+                              relay.RefWrite(v, relay.Tuple([relay.const(2)])))
+    write_val = run_as_python(initial_write)
+    assert_tuple_value(write_val, 0)
+
+    # now ensure that the value, once written, can be read back
+    # (we read the value before and after mutation)
+    w = relay.Var('w')
+    read_after_write = relay.Let(
+        v, relay.RefCreate(relay.Tuple([relay.const(1)])),
+        relay.Let(
+            w, relay.RefCreate(relay.RefRead(v)),
+            seq(relay.RefWrite(v, relay.Tuple([relay.const(2)])),
+                relay.Tuple([relay.RefRead(w), relay.RefRead(v)]))))
+    read_val = run_as_python(read_after_write)
+    assert_tuple_value(read_val, 2)
+    assert_tuple_value(read_val.fields[0], 1)
+    assert_tuple_value(read_val.fields[1], 1)
+    assert_tensor_value(read_val.fields[0].fields[0], 1)
+    assert_tensor_value(read_val.fields[1].fields[0], 2)
+
+
+def test_if():
+    # we will have effects in the blocks to ensure only the intended one is executed
+    true_cond = relay.const(True)
+    false_cond = relay.const(False)
+
+    v = relay.Var('v')
+    true_branch = seq(relay.RefWrite(v, relay.const(1)), relay.RefRead(v))
+    false_branch = seq(relay.RefWrite(v, relay.const(2)), relay.RefRead(v))
+
+    true_expr = relay.Let(v, relay.RefCreate(relay.const(0)),
+                          relay.If(true_cond, true_branch, false_branch))
+    false_expr = relay.Let(v, relay.RefCreate(relay.const(0)),
+                           relay.If(false_cond, true_branch, false_branch))
+
+    true_val = run_as_python(true_expr)
+    assert_tensor_value(true_val, 1)
+
+    false_val = run_as_python(false_expr)
+    assert_tensor_value(false_val, 2)
+
+
+def test_local_function():
+    v = relay.Var('v')
+    ident = relay.Function([v], v)
+    f = relay.Var('f')
+    call1 = relay.Let(f, ident, f(relay.Tuple([])))
+    call2 = relay.Let(f, ident, f(relay.const(2)))
+
+    call_val1 = run_as_python(call1)
+    assert_tuple_value(call_val1, 0)
+
+    call_val2 = run_as_python(call2)
+    assert_tensor_value(call_val2, 2)
+
+
+def test_global_function():
+    mod = relay.Module()
+    ident = relay.GlobalVar('ident')
+    a = relay.TypeVar('a')
+    v = relay.Var('v', a)
+    mod[ident] = relay.Function([v], v, a, [a])
+
+    call1 = ident(relay.const(1))
+    call2 = ident(relay.Tuple([relay.const(2), relay.const(2)]))
+
+    call_val1 = run_as_python(call1, mod)
+    assert_tensor_value(call_val1, 1)
+
+    call_val2 = run_as_python(call2, mod)
+    assert_tuple_value(call_val2, 2)
+    assert_tensor_value(call_val2.fields[0], 2)
+    assert_tensor_value(call_val2.fields[1], 2)
+
+
+def test_constructor():
+    mod = relay.Module()
+    box, box_ctor = init_box_adt(mod)
+
+    init_box_int = box_ctor(relay.const(1))
+    box_val_int = run_as_python(init_box_int, mod)
+
+    assert_constructor_value(box_val_int, box_ctor, 1)
+    assert_tensor_value(box_val_int.fields[0], 1)
+
+    init_box_tup = box_ctor(relay.Tuple([]))
+    box_val_tup = run_as_python(init_box_tup, mod)
+
+    assert_constructor_value(box_val_tup, box_ctor, 1)
+    assert_tuple_value(box_val_tup.fields[0], 0)
+def test_match_wildcard():
+    mod = relay.Module()
+    box, box_ctor = init_box_adt(mod)
+    v = relay.Var('v')
+    match = relay.Let(
+        v, box_ctor(relay.Tuple([])),
+        relay.Match(v, [
+            relay.Clause(relay.PatternWildcard(), relay.const(1))
+        ]))
+
+    match_val = run_as_python(match, mod)
+    assert_tensor_value(match_val, 1)
+
+
+def test_match_var():
+    mod = relay.Module()
+    box, box_ctor = init_box_adt(mod)
+    v = relay.Var('v')
+    w = relay.Var('w')
+    match = relay.Let(
+        v, box_ctor(relay.const(1)),
+        relay.Match(v, [
+            relay.Clause(relay.PatternVar(w), w)
+        ]))
+
+    match_val = run_as_python(match, mod)
+    assert_constructor_value(match_val, box_ctor, 1)
+    assert_tensor_value(match_val.fields[0], 1)
+
+
+def test_match_pattern():
+    mod = relay.Module()
+    box, box_ctor = init_box_adt(mod)
+    v = relay.Var('v')
+    w = relay.Var('w')
+    match = relay.Let(
+        v, box_ctor(relay.const(1)),
+        relay.Match(v, [
+            relay.Clause(relay.PatternConstructor(box_ctor, [relay.PatternVar(w)]), w)
+        ]))
+    match_val = run_as_python(match, mod)
+    assert_tensor_value(match_val, 1)
+
+
+def test_nested_match_pattern():
+    mod = relay.Module()
+    box, box_ctor = init_box_adt(mod)
+    v = relay.Var('v')
+    w = relay.Var('w')
+    match = relay.Let(
+        v, box_ctor(box_ctor(relay.const(2))),
+        relay.Match(v, [
+            relay.Clause(
+                relay.PatternConstructor(
+                    box_ctor, [
+                        relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])
+                    ]),
+                w)]))
+    match_val = run_as_python(match, mod)
+    assert_tensor_value(match_val, 2)
+
+
+def test_match_order():
+    mod = relay.Module()
+    box, box_ctor = init_box_adt(mod)
+    v = relay.Var('v')
+    w = relay.Var('w')
+    # wildcard pattern goes first
+    match = relay.Let(
+        v, box_ctor(box_ctor(relay.const(2))),
+        relay.Match(v, [
+            relay.Clause(relay.PatternWildcard(), relay.const(1)),
+            relay.Clause(
+                relay.PatternConstructor(
+                    box_ctor, [
+                        relay.PatternConstructor(box_ctor, [relay.PatternVar(w)])
+                    ]),
+                w)]))
+    match_val = run_as_python(match, mod)
+    assert_tensor_value(match_val, 1)
+
+
+def test_local_recursion():
+    mod = relay.Module()
+    p = Prelude(mod)
+
+    v = relay.Var('v')
+    h = relay.Var('h')
+    t = relay.Var('t')
+    f = relay.Var('f')
+
+    # just returns the same list
+    let = relay.Let(f, relay.Function([v], relay.Match(v, [
+        relay.Clause(relay.PatternConstructor(p.cons,
+                                              [relay.PatternVar(h), relay.PatternVar(t)]),
+                     p.cons(h, f(t))),
+        relay.Clause(relay.PatternConstructor(p.nil, []), p.nil())
+    ])),
+                    f(p.cons(relay.const(1),
+                             p.cons(relay.const(2),
+                                    p.cons(relay.const(3), p.nil())))))
+
+    val = run_as_python(let, mod)
+    assert_constructor_value(val, p.cons, 2)
+    assert_tensor_value(val.fields[0], 1)
+    assert_constructor_value(val.fields[1], p.cons, 2)
+    assert_tensor_value(val.fields[1].fields[0], 2)
+    assert_constructor_value(val.fields[1].fields[1], p.cons, 2)
+    assert_tensor_value(val.fields[1].fields[1].fields[0], 3)
+    assert_constructor_value(val.fields[1].fields[1].fields[1], p.nil, 0)
+
+
+def test_global_recursion():
+    mod = relay.Module()
+    p = Prelude(mod)
+    copy = relay.GlobalVar('copy')
+    # same as above: it copies the given list
+    a = relay.TypeVar('a')
+    v = relay.Var('v', p.l(a))
+    h = relay.Var('h')
+    t = relay.Var('t')
+    copy_def = relay.Function([v], relay.Match(v, [
+        relay.Clause(relay.PatternConstructor(p.cons,
+                                              [relay.PatternVar(h), relay.PatternVar(t)]),
+                     p.cons(h, copy(t))),
+        relay.Clause(relay.PatternConstructor(p.nil, []), p.nil())
+    ]), p.l(a), [a])
+    mod[copy] = copy_def
+
+    call1 = copy_def(p.cons(relay.const(1), p.cons(relay.const(2), p.nil())))
+    val1 = run_as_python(call1, mod)
+    assert_constructor_value(val1, p.cons, 2)
+    assert_tensor_value(val1.fields[0], 1)
+    assert_constructor_value(val1.fields[1], p.cons, 2)
+    assert_tensor_value(val1.fields[1].fields[0], 2)
+    assert_constructor_value(val1.fields[1].fields[1], p.nil, 0)
+
+    call2 = copy_def(p.cons(relay.Tuple([]), p.nil()))
+    val2 = run_as_python(call2, mod)
+    assert_constructor_value(val2, p.cons, 2)
+    assert_tuple_value(val2.fields[0], 0)
+    assert_constructor_value(val2.fields[1], p.nil, 0)
+def test_higher_order_call():
+    # test with anon func
+    h = relay.Var('h')
+    f = relay.Var('f')
+    x = relay.Var('x')
+    ho_anon = relay.Let(h, relay.Function([f], f(relay.Tuple([]))),
+                        h(relay.Function([x], relay.const(1))))
+
+    anon_val = run_as_python(ho_anon)
+    assert_tensor_value(anon_val, 1)
+
+    # test with named func
+    g = relay.Var('g')
+    ho_named = relay.Let(h, relay.Function([f], f(relay.Tuple([]))),
+                         relay.Let(g, relay.Function([x], relay.const(2)),
+                                   h(g)))
+    named_val = run_as_python(ho_named)
+    assert_tensor_value(named_val, 2)
+
+
+def test_match_effect_exactly_once():
+    mod = relay.Module()
+    p = Prelude(mod)
+
+    # the resulting list should be of length 1,
+    # unless we mistakenly execute the data expression more than once
+    r = relay.Var('r')
+    data = seq(relay.RefWrite(r, p.cons(relay.Tuple([]), relay.RefRead(r))), relay.RefRead(r))
+    match = relay.Let(
+        r, relay.RefCreate(p.nil()),
+        relay.Match(data, [
+            relay.Clause(relay.PatternConstructor(p.nil, []), relay.const(0)),
+            relay.Clause(
+                relay.PatternConstructor(
+                    p.cons,
+                    [relay.PatternWildcard(), relay.PatternConstructor(p.nil, [])]),
+                relay.const(1)),
+            relay.Clause(relay.PatternWildcard(), relay.const(2))
+        ]))
+
+    match_val = run_as_python(match, mod)
+    assert_tensor_value(match_val, 1)
+
+
+def test_arbitrary_let_nesting():
+    # something that is tricky to do in Python but comes naturally in Relay
+    mod = relay.Module()
+    p = Prelude(mod)
+    x = relay.Var('x')
+    r = relay.Var('r')
+    y = relay.Var('y')
+    z = relay.Var('z')
+    expr = relay.Tuple([
+        relay.Let(x, relay.Tuple([relay.const(1), relay.const(2)]),
+                  relay.TupleGetItem(x, 1)),
+        relay.Let(r, relay.RefCreate(relay.const(1)),
+                  seq(relay.RefWrite(r, relay.const(3)), relay.RefRead(r))),
+        relay.Let(y, p.id(relay.Let(z, relay.const(4), z)), y)
+    ])
+
+    tup_val = run_as_python(expr, mod)
+    assert_tuple_value(tup_val, 3)
+    assert_tensor_value(tup_val.fields[0], 2)
+    assert_tensor_value(tup_val.fields[1], 3)
+    assert_tensor_value(tup_val.fields[2], 4)
+
+
+def test_ref_execution_order():
+    # we want to have effects execute from left to right
+    x = relay.Var('x')
+    y = relay.Var('y')
+    f = relay.Var('f')
+    r = relay.Var('r')
+
+    expr = relay.Let(f, relay.Function([x, y], x),
+                     # r = 1
+                     relay.Let(r, relay.RefCreate(relay.const(1)),
+                               relay.Tuple([
+                                   # should be 1
+                                   relay.RefRead(r),
+                                   # set r to 2 and read back
+                                   seq(relay.RefWrite(r, relay.const(2)),
+                                       relay.RefRead(r)),
+                                   # set r to 3 and read back
+                                   seq(relay.RefWrite(r, relay.const(3)),
+                                       relay.RefRead(r)),
+                                   # set r to 4 and read as first arg to f
+                                   # set r to 5 and read as second arg to f
+                                   # f should evaluate to 4
+                                   f(
+                                       seq(relay.RefWrite(r, relay.const(4)),
+                                           relay.RefRead(r)),
+                                       seq(relay.RefWrite(r, relay.const(5)),
+                                           relay.RefRead(r))),
+                                   # read back 5
+                                   relay.RefRead(r)
+                               ])))
+
+    tup_val = run_as_python(expr)
+    assert_tuple_value(tup_val, 5)
+    assert_tensor_value(tup_val.fields[0], 1)
+    assert_tensor_value(tup_val.fields[1], 2)
+    assert_tensor_value(tup_val.fields[2], 3)
+    assert_tensor_value(tup_val.fields[3], 4)
+    assert_tensor_value(tup_val.fields[4], 5)
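The remaining tests exercise lowered operators, which go through the compile engine. An explicit target can also be passed through run_as_python if the default 'llvm' is not wanted (hedged; whatever target is named must be enabled in the TVM build):

    add = relay.add(relay.const(1), relay.const(2))
    val = run_as_python(add, None, tvm.target.create('llvm'))
    assert val.asnumpy() == 3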
+def test_op_add():
+    add = relay.add(relay.const(1), relay.const(2))
+    add_val = run_as_python(add)
+    assert_tensor_value(add_val, 3)
+
+
+# test an op with a tuple input
+# adapted from test_stack in test_op_level3
+def test_op_stack():
+    def verify_stack(dshapes, axis):
+        x_data = [np.random.normal(size=shape).astype('int32') for shape in dshapes]
+        ref_res = np.stack(x_data, axis=axis)
+
+        args = []
+        for data in x_data:
+            args.append(relay.const(data))
+        call = relay.stack(relay.Tuple(args), axis)
+        call_val = run_as_python(call)
+        assert_tensor_value(call_val, ref_res)
+
+    verify_stack([(2,), (2,), (2,)], -1)
+    verify_stack([(2,), (2,), (2,)], 0)
+    verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
+    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
+
+
+# test an op with a tuple output
+# adapted from test_split_infer_type in test_op_level3
+# and test_split in nnvm's test_top_level1
+def test_split():
+    def verify_split(shape, indices_or_sections, axis=0):
+        x = np.random.normal(size=shape).astype('float32')
+        ref_res = np.split(x, indices_or_sections, axis=axis)
+        call = relay.split(relay.const(x), indices_or_sections, axis=axis)
+        call_val = run_as_python(call)
+        assert_tuple_value(call_val, len(ref_res))
+        for i in range(len(ref_res)):
+            assert_tensor_value(call_val.fields[i], ref_res[i])
+
+    verify_split((2, 3), 2)
+    verify_split((5, 3), [3])
+    verify_split((5, 9, 3), [3, 4], 1)
+    verify_split((5, 5, 2, 2), 5, 1)
+    verify_split((5, 5, 2, 2), 5, 0)
+
+
+# ensure we can generate code for batch_norm, since it requires simplify_inference
+# adapted from test_batchnorm in nnvm's test_top_level1
+def test_batch_norm():
+    def verify_batch_norm(shapes):
+        data = [np.absolute(np.random.normal(size=shape).astype('float32'))
+                for shape in shapes]
+        relay_args = [relay.const(arg) for arg in data]
+
+        eps = 1e-5
+
+        def reference(x, gamma, beta, moving_mean, moving_var):
+            return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta
+        ref_res = reference(*data)
+
+        call = relay.nn.batch_norm(*relay_args, epsilon=eps)[0]
+        call_val = run_as_python(call)
+
+        # there will be a change in accuracy, so we need to check
+        # approximate equality
+        assert isinstance(call_val, TensorValue)
+        tvm.testing.assert_allclose(call_val.asnumpy(), ref_res, atol=eps, rtol=eps)
+
+    verify_batch_norm([(10, 20), (20,), (20,), (20,), (20,)])
+    verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)])
+    verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)])
+    verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)])