[testing][py_converter] Enhance py_converter to better support entire modules #13769

Merged Feb 15, 2023 (8 commits)
83 changes: 65 additions & 18 deletions python/tvm/relay/testing/py_converter.py
@@ -38,7 +38,7 @@
# import tvm
# from tvm import relay
# from tvm import nd
- # from tvm.runtime import import container as _container
+ # from tvm.runtime import container as _container
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias("numpy", None)]),
@@ -60,7 +60,7 @@ class PythonConverter(ExprFunctor):
def __init__(self, mod, target) -> None:
super().__init__()
self.mod = mod
- self.tgt = target
+ self.tgt = target if isinstance(target, tvm.target.Target) else tvm.target.Target(target)
self.tec = te_compiler.get()
self.fun_no = 0
self.var_no = 0
@@ -98,15 +98,31 @@ def optimize(self, prog: Expr):
# unwrap tuple wrappers (some op calls produce them)
unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
assert relay.analysis.well_formed(unwrapped)
- mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)
+ # For a lone global var, there is nothing we need to do
+ if isinstance(unwrapped, relay.GlobalVar):
+     return unwrapped
+
+ # main might be in the mod already and from_expr will not override it if it's there,
+ # so we need a new name
+ target_name = self.generate_function_name("target")
+
+ wrapped = unwrapped
+ if not isinstance(unwrapped, relay.Function):
+     wrapped = relay.Function(relay.analysis.free_vars(unwrapped), unwrapped)
+
+ # easiest way to make a deep copy -- note that main will not be overridden if it's present
+ copy_mod = tvm.IRModule.from_expr(
+     relay.Tuple([]), self.mod.functions, self.mod.type_definitions
+ )
+ copy_mod[target_name] = wrapped

# necessary passes: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
opts = tvm.transform.Sequential(
[relay.transform.SimplifyInference(), relay.transform.FuseOps(fuse_opt_level=0)]
)
- mod = opts(mod)
- optimized = mod["main"]
+ copy_mod = opts(copy_mod)
+ optimized = copy_mod[target_name]
return optimized if isinstance(unwrapped, Function) else optimized.body

def sanitize(self, name: str) -> str:
@@ -197,7 +213,7 @@ def convert_func_node(self, func: Function, name_var=None):

var_names = [self.get_var_name(var) for var in func.params]
body, defs = self.visit(func.body)
- ret = self.create_def(func_name, var_names, defs + [Return(body)])
+ ret = self.create_def(func_name, var_names, defs + [Return(body)], register_packed=True)
return (ret, func_name)

def convert_module(self):
Expand All @@ -219,10 +235,25 @@ def create_call(self, func_name: str, arguments):
"""Creates a simple function call."""
return ast.Call(self.parse_name(func_name), arguments, [])

- def create_def(self, func_name: str, arguments: [str], body):
-     """Wrapper over function definition AST node, whose constructor is inconvenient."""
+ def create_def(self, func_name: str, arguments: [str], body, register_packed: bool = False):
+     """
+     Wrapper over the function definition AST node, whose constructor is inconvenient.
+
+     If register_packed is true, a tvm.register_func decorator is included on the
+     generated function. This option should be used for Relay functions
+     (warning: it clobbers the registry!).
+     """
inner_args = [ast.arg(argument, None) for argument in arguments]

+ # add a decorator to register as a PackedFunc so the function will be an ObjectRef
+ # and will allow for putting functions into tuples or refs
+ decorator_list = [
+     ast.Call(
+         self.parse_name("tvm.register_func"),
+         [ast.Constant(value=func_name)],
+         [ast.keyword(arg="override", value=ast.Constant(value=True))],
+     )
+ ]

global __MAJOR__, __MINOR__
if __MAJOR__ == 3 and __MINOR__ >= 8:
arguments = ast.arguments([], inner_args, None, [], [], None, [])
@@ -233,10 +264,19 @@ def create_def(self, func_name: str, arguments: [str], body):
func_name,
arguments,
body,
- [],
+ decorator_list if register_packed else [],
None,
)

+ def create_tuple(self, fields):
+     """
+     Given the ASTs for tuple fields, produce an AST that creates a
+     tuple value with those fields.
+     """
+     # Use the FFI API directly so that PackedFuncs will be correctly converted to ObjectRef.
+     # Using tvm.runtime.container.tuple_object fails to convert PackedFuncs in Python.
+     return self.create_call("_container._ffi_api.Tuple", fields)

def create_op_call(self, op: Function, relay_args, py_args):
"""Lowers the passed primitive function, registers it in TVM's
global compiler, and produces a call to the lowered function in
@@ -290,8 +330,7 @@ def convert_output(ret_type):
assignments += inner_assignments
extra_args += inner_args
fields.append(inner_output)
- fields = [ast.List(fields, Load())]
- return (assignments, extra_args, self.create_call("_container.tuple_object", fields))
+ return (assignments, extra_args, self.create_tuple(fields))

# create a function to wrap the call of the lowered op and return
# a call to that function
@@ -418,7 +457,9 @@ def visit_var(self, var: Expr):
def visit_global_var(self, gvar: Expr):
# we don't need to add numbers to global var names because
# the *names* are checked for uniqueness in the mod
- return (Name(str(gvar.name_hint), Load()), [])
+ func_name = str(gvar.name_hint)
+ # load in the packed func
+ return (self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]), [])

def visit_let(self, letexp: Expr):
# To properly account for scoping and ensure that the entire node produces an expression,
@@ -456,8 +497,7 @@ def let_thunk(var):

def visit_tuple(self, tup: Expr):
fields, ret_defs = self.convert_fields(tup.fields)
- fields = [ast.List(fields, Load())]
- return (self.create_call("_container.tuple_object", fields), ret_defs)
+ return (self.create_tuple(fields), ret_defs)

def visit_tuple_getitem(self, tgi: Expr):
tup, tup_defs = self.visit(tgi.tuple_value)
@@ -471,7 +511,7 @@ def visit_if(self, if_block: Expr):

# need to get the value out of an NDArray to check the condition
# equivalent to: val.numpy()
- cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], [])
+ cond_check = ast.Call(ast.Attribute(cond_body, "numpy", Load()), [], [])
ret = ast.IfExp(cond_check, true_body, false_body)
return (ret, cond_defs + true_defs + false_defs)

@@ -490,7 +530,11 @@ def visit_constant(self, constant: Expr):
def visit_function(self, func: Expr):
# Python's lambdas are very restrictive, so we do "name" inline functions
converted_func, func_name = self.convert_func_node(func)
- return (Name(func_name, Load()), [converted_func])
+ # load in the PackedFunc
+ return (
+     self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]),
+     [converted_func],
+ )

def visit_call(self, call: Expr):
"""For calls, we must distinguish between ordinary functions,
@@ -546,7 +590,7 @@ def visit_ref_write(self, write: Expr):
+ val_defs
+ [
Assign([ast.Attribute(ref, "value", Store())], val),
- Return(self.create_call("_container.tuple_object", [])),
+ Return(self.create_tuple([])),
],
)
return (self.create_call(thunk_name, []), [thunk])
@@ -602,7 +646,10 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):

def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
"""Converts the given Relay expression into a Python script and
- executes it."""
+ executes it.
+
+ Note that closures will be returned as PackedFuncs.
+ """
mod = mod if mod is not None else tvm.IRModule()
py_ast = to_python(expr, mod, target)
code = compile(py_ast, "<string>", "exec")
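The net effect of the py_converter changes above: every generated Python function is registered as a PackedFunc, and every reference to it compiles to a registry lookup. A minimal sketch of that round trip, assuming a local TVM build; the name "__sketch_identity" is illustrative, not one the converter actually generates:

import numpy as np
import tvm

# Roughly what the converter emits for a Relay function: the definition is
# registered under a generated name (override=True clobbers the registry).
@tvm.register_func("__sketch_identity", override=True)
def _identity(x):
    return x

# Call sites become registry lookups, so the function value is an ObjectRef
# and can be stored in tuples and refs.
identity = tvm.get_global_func("__sketch_identity")
res = identity(tvm.nd.array(np.array([0.0], dtype="float32")))
print(res.numpy())  # [0.]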
1 change: 1 addition & 0 deletions src/runtime/container.cc
@@ -202,5 +202,6 @@ TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple sh
ICHECK_LT(idx, shape.size());
return shape[idx];
});
+
} // namespace runtime
} // namespace tvm
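The create_tuple helper added above emits a direct call to the container FFI constructor with the fields as varargs. A short sketch under the same assumptions as the previous example (the "__sketch_identity" registration is hypothetical):

import tvm
from tvm.runtime import container as _container

f = tvm.get_global_func("__sketch_identity")
# tuple_object fails to convert PackedFunc fields from Python, so the
# converter calls the FFI constructor directly; fields are passed as varargs.
tup = _container._ffi_api.Tuple(f, f)
assert len(tup) == 2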
65 changes: 63 additions & 2 deletions tests/python/relay/test_py_converter.py
@@ -18,7 +18,7 @@
import tvm
from tvm import te
from tvm import relay
- from tvm.relay.testing import to_python, run_as_python
+ from tvm.relay.testing import run_as_python
from tvm.relay.prelude import Prelude
from tvm.runtime.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
@@ -70,7 +70,6 @@ def test_create_empty_tuple():
def test_create_scalar():
scalar = relay.const(1)
tensor_val = run_as_python(scalar)
- print(type(tensor_val))
assert_tensor_value(tensor_val, 1)


@@ -611,3 +610,65 @@ def reference(x, gamma, beta, moving_mean, moving_var):
verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)])
verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)])
verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)])


def test_return_global_var():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
mod = tvm.IRModule()
mod["main"] = identity
main_var = mod.get_global_var("main")
main_func = run_as_python(main_var, mod=mod)

arg = tvm.nd.array(np.array([0.0], dtype="float32"))
res = main_func(arg)
assert arg.numpy() == res.numpy()


def test_closure_in_tuple():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
tup = relay.Tuple([identity, identity])
index = relay.TupleGetItem(tup, 0)

func = run_as_python(index)
arg = tvm.nd.array(np.array([0.0], dtype="float32"))
res = func(arg)
assert arg.numpy() == res.numpy()


def test_closure_in_ref():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
gv = relay.GlobalVar("id")

r = relay.Var("r")
seq = relay.Let(
r,
relay.RefCreate(gv),
relay.Call(relay.RefRead(r), [relay.const(np.array([0.0], dtype="float32"))]),
)

mod = tvm.IRModule()
mod[gv] = identity
res = run_as_python(seq, mod=mod)
assert res.numpy() == np.array([0.0], dtype="float32")


def test_compiling_with_main():
unit_type = relay.TupleType([])
unit = relay.Function([], relay.Tuple([]), ret_type=unit_type)

x = relay.Var("x", type_annotation=unit_type)
identity = relay.Function([x], x, ret_type=unit_type)

mod = tvm.IRModule()
mod["unit"] = unit
mod["main"] = identity

res = run_as_python(mod.get_global_var("main")(mod.get_global_var("unit")()), mod=mod)
assert isinstance(res, ADT)
assert len(res) == 0
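A sketch of the docstring note that closures come back as PackedFuncs, in the style of the tests above; the assertions reflect the new visit_function behavior as described in the diff:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import run_as_python

tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)

# The converted function is loaded back from the global registry, so the
# result is a PackedFunc rather than a Python closure.
f = run_as_python(identity)
assert isinstance(f, tvm.runtime.PackedFunc)
res = f(tvm.nd.array(np.array([0.0], dtype="float32")))
assert res.numpy() == np.array([0.0], dtype="float32")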