diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/hybrid/__init__.py index 7aca007ecd94..11ecbc8f7b60 100644 --- a/python/tvm/hybrid/__init__.py +++ b/python/tvm/hybrid/__init__.py @@ -31,6 +31,8 @@ from __future__ import absolute_import as _abs +import inspect + from .._ffi.base import decorate from .._ffi.function import _init_api from ..build_module import form_body @@ -55,7 +57,9 @@ def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring from .util import _is_tvm_arg_types if _is_tvm_arg_types(args): src = _pruned_source(func) - return source_to_op(src, func.__globals__, args) + closure_vars = inspect.getclosurevars(func).nonlocals + closure_vars.update(inspect.getclosurevars(func).globals) + return source_to_op(src, args, func.__globals__, closure_vars) from .runtime import _enter_hybrid_runtime, _restore_runtime intersect = _enter_hybrid_runtime(func) diff --git a/python/tvm/hybrid/module.py b/python/tvm/hybrid/module.py index 297dd0b9941a..13e45a7516fa 100644 --- a/python/tvm/hybrid/module.py +++ b/python/tvm/hybrid/module.py @@ -62,7 +62,7 @@ def __init__(self, src=None, name=None): def __call__(self, *args): if _is_tvm_arg_types(args): - return source_to_op(self.root_, globals(), args) + return source_to_op(self.root_, args, globals(), {}) return self.func_(*args) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 1c1525e11be8..40ea1714fc35 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -25,7 +25,7 @@ from enum import Enum -from .util import _internal_assert +from .util import _internal_assert, _apply_indices from . import calls from . import util from .preprocessor import determine_variable_usage @@ -112,7 +112,7 @@ class HybridParser(ast.NodeVisitor): } - def __init__(self, args, usage, symbols, func_name=None): + def __init__(self, args, usage, symbols, closure_vars, func_name=None): """ Parameters ---------- @@ -122,6 +122,12 @@ def __init__(self, args, usage, symbols, func_name=None): usage: A dict of variables used in last in this function Provided by last lower pass, which collects this information + symbols : list of str + The symbol list of the global context of the function. + + closure_vars: dict + A dict of external name reference captured by this function. + Returns ------- func_name: str @@ -136,6 +142,8 @@ def __init__(self, args, usage, symbols, func_name=None): if isinstance(v, types.FunctionType): self.add_symbol(k, Symbol.Callable, v) + self.closure_vars = closure_vars + self.binds = {} # Thread binds self.device = 0 # Is it generating device @@ -236,7 +244,11 @@ def visit_Expr(self, node): def visit_Name(self, node): name = node.id if sys.version_info[0] == 2 and name in ['True', 'False']: - return _api.convert(eval(name)) #pylint: disable=eval-used + return _api.convert(ast.literal_eval(name)) + + if name in self.closure_vars: + return _api.convert(self.closure_vars[name]) + ty, entry = self.symbols[name] _internal_assert(name in self.symbols, "Unknown symbol %s!" % name) if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]: @@ -356,10 +368,12 @@ def visit_Attribute(self, node): buf = self.visit(node.value) return getattr(buf, node.attr) - def visit_Subscript(self, node): args = self.visit(node.slice) if isinstance(node.value, ast.Name): + if node.value.id in self.closure_vars: + args = ast.literal_eval(str(args)) + return _api.convert(_apply_indices(self.closure_vars[node.value.id], args)) buf = self.visit(node.value) if isinstance(buf, Array): @@ -576,7 +590,7 @@ def visit_Assert(self, node): return _make.AssertStmt(test, mesg, util.make_nop()) -def parse_python(src, symbols, args): +def parse_python(src, args, symbols, closure_vars): """The helper function of calling the AST visitor Parameters @@ -585,14 +599,17 @@ def parse_python(src, symbols, args): If an ast.node, then directly lower it. If a str, then parse it to ast and lower it. - symbols : str - The symbol list of the global context of the function. - args : list of Tensors or Vars The argument lists to the function. It is NOT encouraged to write a function without arguments. It is NOT encouraged to write a function with side effect. + symbols : list of str + The symbol list of the global context of the function. + + closure_vars: dict + A dict of external name reference captured by this function. + Returns ------- root : Stmt @@ -600,14 +617,14 @@ def parse_python(src, symbols, args): """ root = ast.parse(src) if isinstance(src, str) else src _internal_assert(root, ast.AST) - var_usage = determine_variable_usage(root, args, symbols) - parser = HybridParser(args, var_usage, symbols) + var_usage = determine_variable_usage(root, args, symbols, closure_vars) + parser = HybridParser(args, var_usage, symbols, closure_vars) parser.parsed_body = parser.visit(root) _internal_assert(parser.returned, 'No valid return found in the function body!') return parser -def source_to_op(src, symbols, args): +def source_to_op(src, args, symbols, closure_vars): """Another level of wrapper Parameters @@ -616,20 +633,23 @@ def source_to_op(src, symbols, args): If an ast.node, then directly lower it. If a str, then parse it to ast and lower it. - symbols : str - The symbol list of the global context of the function. - args : list of Tensors or Vars The argument lists to the function. It is NOT encouraged to write a function without arguments. It is NOT encouraged to write a function with side effect. + symbols : list of str + The symbol list of the global context of the function. + + closure_vars: dict + A dict of external name reference captured by this function. + Returns ------- res : list of output tensors The result of output tensors of the formed OpNode. """ - parser = parse_python(src, symbols, args) + parser = parse_python(src, args, symbols, closure_vars) input_tensors = [] for i in args: diff --git a/python/tvm/hybrid/preprocessor.py b/python/tvm/hybrid/preprocessor.py index 117ebd3091ed..1a9de4e3f801 100644 --- a/python/tvm/hybrid/preprocessor.py +++ b/python/tvm/hybrid/preprocessor.py @@ -26,14 +26,14 @@ class PyVariableUsage(ast.NodeVisitor): """The vistor class to determine the declaration, r/w status, and last use of each variable""" #pylint: disable=invalid-name #pylint: disable=missing-docstring - def __init__(self, args, symbols): + def __init__(self, args, symbols, closure_vars): self.status = {} self.scope_level = [] self._args = {} self.args = args self.aug_assign_ = False self.symbols = symbols - + self.closure_vars = closure_vars def visit_FunctionDef(self, node): self.scope_level.append(node) @@ -89,6 +89,14 @@ def visit_Name(self, node): "Iter var cannot be overwritten") if node.id not in self.status.keys(): + # It is a captured value in closure + if node.id in self.closure_vars: + try: + ast.literal_eval(str(self.closure_vars[node.id])) + except ValueError: + raise ValueError("Only support capturing constant values in closure") + return + _internal_assert(isinstance(node.ctx, ast.Store), \ 'Undeclared variable %s' % node.id) if self.aug_assign_: @@ -102,8 +110,8 @@ def visit_Name(self, node): self.status[node.id] = (decl, loop, usage) -def determine_variable_usage(root, args, symbols): +def determine_variable_usage(root, args, symbols, closure_vars): """The helper function for calling the dedicated visitor.""" - visitor = PyVariableUsage(args, symbols) + visitor = PyVariableUsage(args, symbols, closure_vars) visitor.visit(root) return visitor.status diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0dd1fa141329..058c5aa30af7 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -101,3 +101,9 @@ def _is_tvm_arg_types(args): _internal_assert(isinstance(elem, np_arg_types), \ "Expect a numpy type but %s get!" % str(type(elem))) return False + +def _apply_indices(value, indices): + """Apply multidimensional index""" + if indices: + return _apply_indices(value[indices[0]], indices[1:]) + return value diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 254264662fdc..805cff8f5d15 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -768,6 +768,24 @@ def outer_product(a, b): # Test loop binds +def test_capture(): + n = 8 + + constant_tuple = (10, n) + constant_list = [[1, 2], [3, n]] + const_value = 1 + + @tvm.hybrid.script + def add_something(a): + c = output_tensor((constant_tuple[1],), 'int32') + for i in range(constant_tuple[1]): + c[i] = a[i] + constant_list[1][const_value] + return c + + a = tvm.placeholder((n, ), dtype='int32', name='a') + + func, ins, outs = run_and_check(add_something, [a]) + run_and_check(func, ins, outs=outs) if __name__ == "__main__": test_outer_product() @@ -786,5 +804,6 @@ def outer_product(a, b): test_bool() test_const_range() test_schedule() + test_capture() # TODO: # test_inplace()