[testing][py_converter] Enhance py_converter to better support entire modules (#13769)

This PR makes a few improvements to py_converter that make it more useful for fuzz testing, especially when running larger modules.

In particular, these changes support returning the definition of a global var directly (e.g., run_as_python(main_var, mod=mod) now returns a function corresponding to mod["main"]; see the example below) and correct two bugs in the previous implementation:

Previously, it was not possible to insert a function into a runtime container object like an ADT, because the converter simply compiled Relay functions into ordinary Python functions. This change solves the problem by registering the generated functions as PackedFuncs. A second fix was also needed: even though PackedFunc is an ObjectRef in C++, the Python bindings do not recognize PackedFuncs as Objects, so the code now calls the FFI API tuple constructor directly.
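A minimal sketch of the mechanism (the global name example.identity is made up for illustration):

import tvm
from tvm.runtime import container as _container

# Registering under a global name turns the Python function into a PackedFunc,
# which is an ObjectRef on the C++ side and can live inside runtime containers.
@tvm.register_func("example.identity", override=True)
def identity(x):
    return x

f = tvm.get_global_func("example.identity")

# tuple_object does not convert PackedFuncs from Python, so the converter now
# calls the FFI constructor directly:
tup = _container._ffi_api.Tuple(f, f)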
The implementation relied on IRModule.from_expr to wrap passed-in expressions in a module. However, from_expr will not overwrite an existing main function if one is passed in via the functions argument, so if the user passed in a module that already defined main, the wrapping was done incorrectly and main ended up copied many times. This PR corrects the error by not assuming that the name main is available and instead constructing a new module with a reserved name for the target.
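A sketch of the fix, assuming expr is the expression being wrapped, mod is the user's module, and target_0 stands in for the freshly generated reserved name:

import tvm
from tvm import relay

# from_expr keeps a pre-existing main from mod.functions rather than
# overwriting it with expr, so the old wrapping broke when main was defined.
# The fix wraps expr under a generated name instead:
copy_mod = tvm.IRModule.from_expr(
    relay.Tuple([]), mod.functions, mod.type_definitions
)
copy_mod["target_0"] = relay.Function(relay.analysis.free_vars(expr), expr)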
None of the cases above had been tested before; tests for them are now included.
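For example, the new global-var behavior (mirroring the included test_return_global_var) looks like this:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import run_as_python

tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], x, ret_type=tt)

# Passing a bare GlobalVar returns the function it refers to (as a PackedFunc).
main_func = run_as_python(mod.get_global_var("main"), mod=mod)
res = main_func(tvm.nd.array(np.array([0.0], dtype="float32")))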
slyubomirsky authored Feb 15, 2023
1 parent f6cebb5 commit e516eaa
Showing 3 changed files with 129 additions and 20 deletions.
83 changes: 65 additions & 18 deletions python/tvm/relay/testing/py_converter.py
@@ -38,7 +38,7 @@
# import tvm
# from tvm import relay
# from tvm import nd
# from tvm.runtime import import container as _container
# from tvm.runtime import container as _container
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias("numpy", None)]),
@@ -60,7 +60,7 @@ class PythonConverter(ExprFunctor):
def __init__(self, mod, target) -> None:
super().__init__()
self.mod = mod
self.tgt = target
self.tgt = target if isinstance(target, tvm.target.Target) else tvm.target.Target(target)
self.tec = te_compiler.get()
self.fun_no = 0
self.var_no = 0
@@ -98,15 +98,31 @@ def optimize(self, prog: Expr):
# unwrap tuple wrappers (some op calls produce them)
unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
assert relay.analysis.well_formed(unwrapped)
mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)
# For a lone global var, there is nothing we need to do
if isinstance(unwrapped, relay.GlobalVar):
return unwrapped

# main might be in the mod already and from_expr will not override it if it's there,
# so we need a new name
target_name = self.generate_function_name("target")

wrapped = unwrapped
if not isinstance(unwrapped, relay.Function):
wrapped = relay.Function(relay.analysis.free_vars(unwrapped), unwrapped)

# easiest way to make a deep copy -- note that main will not be overridden if it's present
copy_mod = tvm.IRModule.from_expr(
relay.Tuple([]), self.mod.functions, self.mod.type_definitions
)
copy_mod[target_name] = wrapped

# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
opts = tvm.transform.Sequential(
[relay.transform.SimplifyInference(), relay.transform.FuseOps(fuse_opt_level=0)]
)
mod = opts(mod)
optimized = mod["main"]
copy_mod = opts(copy_mod)
optimized = copy_mod[target_name]
return optimized if isinstance(unwrapped, Function) else optimized.body

def sanitize(self, name: str) -> str:
@@ -197,7 +213,7 @@ def convert_func_node(self, func: Function, name_var=None):

var_names = [self.get_var_name(var) for var in func.params]
body, defs = self.visit(func.body)
ret = self.create_def(func_name, var_names, defs + [Return(body)])
ret = self.create_def(func_name, var_names, defs + [Return(body)], register_packed=True)
return (ret, func_name)

def convert_module(self):
@@ -219,10 +235,25 @@ def create_call(self, func_name: str, arguments):
"""Creates a simple function call."""
return ast.Call(self.parse_name(func_name), arguments, [])

def create_def(self, func_name: str, arguments: [str], body):
"""Wrapper over function definition AST node, whose constructor is inconvenient."""
def create_def(self, func_name: str, arguments: [str], body, register_packed: bool = False):
"""
Wrapper over function definition AST node, whose constructor is inconvenient.
register_packed includes a tvm.register_func decorator on the generated function if true.
This option should be used for Relay functions (warning: clobbers registry!)
"""
inner_args = [ast.arg(argument, None) for argument in arguments]

# add a decorator to register as a PackedFunc so the function will be an ObjectRef
# and will allow for putting functions into tuples or refs
decorator_list = [
ast.Call(
self.parse_name("tvm.register_func"),
[ast.Constant(value=func_name)],
[ast.keyword(arg="override", value=ast.Constant(value=True))],
)
]

global __MAJOR__, __MINOR__
if __MAJOR__ == 3 and __MINOR__ >= 8:
arguments = ast.arguments([], inner_args, None, [], [], None, [])
@@ -233,10 +264,19 @@ def create_def(self, func_name: str, arguments: [str], body):
func_name,
arguments,
body,
[],
decorator_list if register_packed else [],
None,
)

def create_tuple(self, fields):
"""
Given the ASTs for tuple fields, produce an AST that creates a
tuple value with those fields
"""
# Use the FFI API directly so that PackedFuncs will be correctly converted to ObjectRef.
# Using tvm.runtime.container.tuple_object fails to convert PackedFuncs in Python
return self.create_call("_container._ffi_api.Tuple", fields)

def create_op_call(self, op: Function, relay_args, py_args):
"""Lowers the passed primitive function, registers it in TVM's
global compiler, and produces a call to the lowered function in
@@ -290,8 +330,7 @@ def convert_output(ret_type):
assignments += inner_assignments
extra_args += inner_args
fields.append(inner_output)
fields = [ast.List(fields, Load())]
return (assignments, extra_args, self.create_call("_container.tuple_object", fields))
return (assignments, extra_args, self.create_tuple(fields))

# create a function to wrap the call of the lowered op and return
# a call to that function
@@ -418,7 +457,9 @@ def visit_var(self, var: Expr):
def visit_global_var(self, gvar: Expr):
# we don't need to add numbers to global var names because
# the *names* are checked for uniqueness in the mod
return (Name(str(gvar.name_hint), Load()), [])
func_name = str(gvar.name_hint)
# load in the packed func
return (self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]), [])

def visit_let(self, letexp: Expr):
# To properly account for scoping and ensure that the entire node produces an expression,
@@ -456,8 +497,7 @@ def let_thunk(var):

def visit_tuple(self, tup: Expr):
fields, ret_defs = self.convert_fields(tup.fields)
fields = [ast.List(fields, Load())]
return (self.create_call("_container.tuple_object", fields), ret_defs)
return (self.create_tuple(fields), ret_defs)

def visit_tuple_getitem(self, tgi: Expr):
tup, tup_defs = self.visit(tgi.tuple_value)
@@ -471,7 +511,7 @@ def visit_if(self, if_block: Expr):

# need to get the value out of an NDArray to check the condition
# equivalent to: val.numpy()
cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], [])
cond_check = ast.Call(ast.Attribute(cond_body, "numpy", Load()), [], [])
ret = ast.IfExp(cond_check, true_body, false_body)
return (ret, cond_defs + true_defs + false_defs)

@@ -490,7 +530,11 @@ def visit_constant(self, constant: Expr):
def visit_function(self, func: Expr):
# Python's lambdas are very restrictive, so we do "name" inline functions
converted_func, func_name = self.convert_func_node(func)
return (Name(func_name, Load()), [converted_func])
# load in the PackedFunc
return (
self.create_call("tvm.get_global_func", [ast.Constant(value=func_name)]),
[converted_func],
)

def visit_call(self, call: Expr):
"""For calls, we must distinguish between ordinary functions,
@@ -546,7 +590,7 @@ def visit_ref_write(self, write: Expr):
+ val_defs
+ [
Assign([ast.Attribute(ref, "value", Store())], val),
Return(self.create_call("_container.tuple_object", [])),
Return(self.create_tuple([])),
],
)
return (self.create_call(thunk_name, []), [thunk])
@@ -602,7 +646,10 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):

def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
"""Converts the given Relay expression into a Python script and
executes it."""
executes it.
Note that closures will be returned as PackedFuncs
"""
mod = mod if mod is not None else tvm.IRModule()
py_ast = to_python(expr, mod, target)
code = compile(py_ast, "<string>", "exec")
1 change: 1 addition & 0 deletions src/runtime/container.cc
@@ -202,5 +202,6 @@ TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple sh
ICHECK_LT(idx, shape.size());
return shape[idx];
});

} // namespace runtime
} // namespace tvm
65 changes: 63 additions & 2 deletions tests/python/relay/test_py_converter.py
@@ -18,7 +18,7 @@
import tvm
from tvm import te
from tvm import relay
from tvm.relay.testing import to_python, run_as_python
from tvm.relay.testing import run_as_python
from tvm.relay.prelude import Prelude
from tvm.runtime.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
@@ -70,7 +70,6 @@ def test_create_empty_tuple():
def test_create_scalar():
scalar = relay.const(1)
tensor_val = run_as_python(scalar)
print(type(tensor_val))
assert_tensor_value(tensor_val, 1)


@@ -611,3 +610,65 @@ def reference(x, gamma, beta, moving_mean, moving_var):
verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)])
verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)])
verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)])


def test_return_global_var():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
mod = tvm.IRModule()
mod["main"] = identity
main_var = mod.get_global_var("main")
main_func = run_as_python(main_var, mod=mod)

arg = tvm.nd.array(np.array([0.0], dtype="float32"))
res = main_func(arg)
assert arg.numpy() == res.numpy()


def test_closure_in_tuple():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
tup = relay.Tuple([identity, identity])
index = relay.TupleGetItem(tup, 0)

func = run_as_python(index)
arg = tvm.nd.array(np.array([0.0], dtype="float32"))
res = func(arg)
assert arg.numpy() == res.numpy()


def test_closure_in_ref():
tt = relay.TensorType([1], "float32")
x = relay.Var("x", type_annotation=tt)
identity = relay.Function([x], x, ret_type=tt)
gv = relay.GlobalVar("id")

r = relay.Var("r")
seq = relay.Let(
r,
relay.RefCreate(gv),
relay.Call(relay.RefRead(r), [relay.const(np.array([0.0], dtype="float32"))]),
)

mod = tvm.IRModule()
mod[gv] = identity
res = run_as_python(seq, mod=mod)
assert res.numpy() == np.array([0.0], dtype="float32")


def test_compiling_with_main():
unit_type = relay.TupleType([])
unit = relay.Function([], relay.Tuple([]), ret_type=unit_type)

x = relay.Var("x", type_annotation=unit_type)
identity = relay.Function([x], x, ret_type=unit_type)

mod = tvm.IRModule()
mod["unit"] = unit
mod["main"] = identity

res = run_as_python(mod.get_global_var("main")(mod.get_global_var("unit")()), mod=mod)
assert isinstance(res, ADT)
assert len(res) == 0
