diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 1ec85faea619..44489aa9cf7d 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -38,7 +38,7 @@ # import tvm # from tvm import relay # from tvm import nd -# from tvm.runtime import import container as _container +# from tvm.runtime import container as _container # from tvm.relay.backend.interpreter import RefValue, ConstructorValue PROLOGUE = [ ast.Import([alias("numpy", None)]), @@ -60,7 +60,7 @@ class PythonConverter(ExprFunctor): def __init__(self, mod, target) -> None: super().__init__() self.mod = mod - self.tgt = target + self.tgt = target if isinstance(target, tvm.target.Target) else tvm.target.Target(target) self.tec = te_compiler.get() self.fun_no = 0 self.var_no = 0 @@ -98,15 +98,31 @@ def optimize(self, prog: Expr): # unwrap tuple wrappers (some op calls produce them) unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog assert relay.analysis.well_formed(unwrapped) - mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions) + # For a lone global var, there is nothing we need to do + if isinstance(unwrapped, relay.GlobalVar): + return unwrapped + + # main might be in the mod already and from_expr will not override it if it's there, + # so we need a new name + target_name = self.generate_function_name("target") + + wrapped = unwrapped + if not isinstance(unwrapped, relay.Function): + wrapped = relay.Function(relay.analysis.free_vars(unwrapped), unwrapped) + + # easiest way to make a deep copy -- note that main will not be overridden if it's present + copy_mod = tvm.IRModule.from_expr( + relay.Tuple([]), self.mod.functions, self.mod.type_definitions + ) + copy_mod[target_name] = wrapped # necessary pass: SimplifyInference (otherwise we can't generate code for some operators) # and fusion (to get primitive functions) opts = tvm.transform.Sequential( [relay.transform.SimplifyInference(), relay.transform.FuseOps(fuse_opt_level=0)] ) - mod = opts(mod) - optimized = mod["main"] + copy_mod = opts(copy_mod) + optimized = copy_mod[target_name] return optimized if isinstance(unwrapped, Function) else optimized.body def sanitize(self, name: str) -> str: @@ -197,7 +213,7 @@ def convert_func_node(self, func: Function, name_var=None): var_names = [self.get_var_name(var) for var in func.params] body, defs = self.visit(func.body) - ret = self.create_def(func_name, var_names, defs + [Return(body)]) + ret = self.create_def(func_name, var_names, defs + [Return(body)], register_packed=True) return (ret, func_name) def convert_module(self): @@ -219,10 +235,25 @@ def create_call(self, func_name: str, arguments): """Creates a simple function call.""" return ast.Call(self.parse_name(func_name), arguments, []) - def create_def(self, func_name: str, arguments: [str], body): - """Wrapper over function definition AST node, whose constructor is inconvenient.""" + def create_def(self, func_name: str, arguments: [str], body, register_packed: bool = False): + """ + Wrapper over function definition AST node, whose constructor is inconvenient. + + register_packed includes a tvm.register_func decorator on the generated function if true. + This option should be used for Relay functions (warning: clobbers registry!) + """ inner_args = [ast.arg(argument, None) for argument in arguments] + # add a decorator to register as a PackedFunc so the function will be an ObjectRef + # and will allow for putting functions into tuples or refs + decorator_list = [ + ast.Call( + self.parse_name("tvm.register_func"), + [ast.Constant(value=func_name)], + [ast.keyword(arg="override", value=ast.Constant(value=True))], + ) + ] + global __MAJOR__, __MINOR__ if __MAJOR__ == 3 and __MINOR__ >= 8: arguments = ast.arguments([], inner_args, None, [], [], None, []) @@ -233,10 +264,19 @@ def create_def(self, func_name: str, arguments: [str], body): func_name, arguments, body, - [], + decorator_list if register_packed else [], None, ) + def create_tuple(self, fields): + """ + Given the ASTs for tuple fields, produce an AST that creates a + tuple value with those fields + """ + # Use the FFI API directly so that PackedFuncs will be correctly converted to ObjectRef. + # Using tvm.runtime.container.tuple_object fails to convert PackedFuncs in Python + return self.create_call("_container._ffi_api.Tuple", fields) + def create_op_call(self, op: Function, relay_args, py_args): """Lowers the passed primitive function, registers it in TVM's global compiler, and produces a call to the lowered function in @@ -290,8 +330,7 @@ def convert_output(ret_type): assignments += inner_assignments extra_args += inner_args fields.append(inner_output) - fields = [ast.List(fields, Load())] - return (assignments, extra_args, self.create_call("_container.tuple_object", fields)) + return (assignments, extra_args, self.create_tuple(fields)) # create a function to wrap the call of the lowered op and return # a call to that function @@ -418,7 +457,9 @@ def visit_var(self, var: Expr): def visit_global_var(self, gvar: Expr): # we don't need to add numbers to global var names because # the *names* are checked for uniqueness in the mod - return (Name(str(gvar.name_hint), Load()), []) + func_name = str(gvar.name_hint) + # load in the packed func + return (self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]), []) def visit_let(self, letexp: Expr): # To properly account for scoping and ensure that the entire node produces an expression, @@ -456,8 +497,7 @@ def let_thunk(var): def visit_tuple(self, tup: Expr): fields, ret_defs = self.convert_fields(tup.fields) - fields = [ast.List(fields, Load())] - return (self.create_call("_container.tuple_object", fields), ret_defs) + return (self.create_tuple(fields), ret_defs) def visit_tuple_getitem(self, tgi: Expr): tup, tup_defs = self.visit(tgi.tuple_value) @@ -471,7 +511,7 @@ def visit_if(self, if_block: Expr): # need to get the value out of a NDArray to check the condition # equvialent to: val.numpy() - cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], []) + cond_check = ast.Call(ast.Attribute(cond_body, "numpy", Load()), [], []) ret = ast.IfExp(cond_check, true_body, false_body) return (ret, cond_defs + true_defs + false_defs) @@ -490,7 +530,11 @@ def visit_constant(self, constant: Expr): def visit_function(self, func: Expr): # Python's lambdas are very restrictive, so we do "name" inline functions converted_func, func_name = self.convert_func_node(func) - return (Name(func_name, Load()), [converted_func]) + # load in the PackedFunc + return ( + self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]), + [converted_func], + ) def visit_call(self, call: Expr): """For calls, we must distinguish between ordinary functions, @@ -546,7 +590,7 @@ def visit_ref_write(self, write: Expr): + val_defs + [ Assign([ast.Attribute(ref, "value", Store())], val), - Return(self.create_call("_container.tuple_object", [])), + Return(self.create_tuple([])), ], ) return (self.create_call(thunk_name, []), [thunk]) @@ -602,7 +646,10 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")): def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")): """Converts the given Relay expression into a Python script and - executes it.""" + executes it. + + Note that closures will be returned as PackedFuncs + """ mod = mod if mod is not None else tvm.IRModule() py_ast = to_python(expr, mod, target) code = compile(py_ast, "", "exec") diff --git a/src/runtime/container.cc b/src/runtime/container.cc index adcaecbc64cf..7b5105a3fc94 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -202,5 +202,6 @@ TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple sh ICHECK_LT(idx, shape.size()); return shape[idx]; }); + } // namespace runtime } // namespace tvm diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index bd5635e8cf09..d43ec5861b10 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -18,7 +18,7 @@ import tvm from tvm import te from tvm import relay -from tvm.relay.testing import to_python, run_as_python +from tvm.relay.testing import run_as_python from tvm.relay.prelude import Prelude from tvm.runtime.container import ADT from tvm.relay.backend.interpreter import RefValue, ConstructorValue @@ -70,7 +70,6 @@ def test_create_empty_tuple(): def test_create_scalar(): scalar = relay.const(1) tensor_val = run_as_python(scalar) - print(type(tensor_val)) assert_tensor_value(tensor_val, 1) @@ -611,3 +610,65 @@ def reference(x, gamma, beta, moving_mean, moving_var): verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)]) verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)]) verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)]) + + +def test_return_global_var(): + tt = relay.TensorType([1], "float32") + x = relay.Var("x", type_annotation=tt) + identity = relay.Function([x], x, ret_type=tt) + mod = tvm.IRModule() + mod["main"] = identity + main_var = mod.get_global_var("main") + main_func = run_as_python(main_var, mod=mod) + + arg = tvm.nd.array(np.array([0.0], dtype="float32")) + res = main_func(arg) + assert arg.numpy() == res.numpy() + + +def test_closure_in_tuple(): + tt = relay.TensorType([1], "float32") + x = relay.Var("x", type_annotation=tt) + identity = relay.Function([x], x, ret_type=tt) + tup = relay.Tuple([identity, identity]) + index = relay.TupleGetItem(tup, 0) + + func = run_as_python(index) + arg = tvm.nd.array(np.array([0.0], dtype="float32")) + res = func(arg) + assert arg.numpy() == res.numpy() + + +def test_closure_in_ref(): + tt = relay.TensorType([1], "float32") + x = relay.Var("x", type_annotation=tt) + identity = relay.Function([x], x, ret_type=tt) + gv = relay.GlobalVar("id") + + r = relay.Var("r") + seq = relay.Let( + r, + relay.RefCreate(gv), + relay.Call(relay.RefRead(r), [relay.const(np.array([0.0], dtype="float32"))]), + ) + + mod = tvm.IRModule() + mod[gv] = identity + res = run_as_python(seq, mod=mod) + assert res.numpy() == np.array([0.0], dtype="float32") + + +def test_compiling_with_main(): + unit_type = relay.TupleType([]) + unit = relay.Function([], relay.Tuple([]), ret_type=unit_type) + + x = relay.Var("x", type_annotation=unit_type) + identity = relay.Function([x], x, ret_type=unit_type) + + mod = tvm.IRModule() + mod["unit"] = unit + mod["main"] = identity + + res = run_as_python(mod.get_global_var("main")(mod.get_global_var("unit")()), mod=mod) + assert isinstance(res, ADT) + assert len(res) == 0