diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py
index 5267731f4f521..d43217ca5dfc7 100644
--- a/python/tvm/hybrid/api.py
+++ b/python/tvm/hybrid/api.py
@@ -24,17 +24,15 @@ def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
         from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
         if _is_tvm_arg_types(args):
             src = _pruned_source(func)
-            parser = parse_python(src, args)
+            parser = parse_python(src, func.__globals__, args)

             input_tensors = []
             for i in args:
                 if isinstance(i, Tensor):
                     input_tensors.append(i)
-
             op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
                                          parser.outputs, parser.parsed_body)
             res = [op.output(i) for i in range(len(parser.outputs))]
-
             return res[0] if len(res) == 1 else res

         intersect = _enter_hybrid_runtime(func)
diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py
new file mode 100644
index 0000000000000..730b56f58bd2d
--- /dev/null
+++ b/python/tvm/hybrid/calls.py
@@ -0,0 +1,92 @@
+"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
+semantic support."""
+
+from .. import api as _api
+from .. import expr as _expr
+from .. import make as _make
+from ..container import Array
+from .. import ir_pass
+from ..stmt import For
+from .util import _internal_assert
+
+#pylint: disable=redefined-builtin
+
+LOOP_INTRIN = {
+    'range'    : For.Serial,
+    'unroll'   : For.Unrolled,
+    'parallel' : For.Parallel,
+    'vectorize': For.Vectorized,
+}
+
+def _range(annotation, args):
+    """Handling TVM loop types"""
+    n = len(args)
+    if n == 1:
+        low, ext = _api.const(0, dtype='int32'), args[0]
+    else:
+        _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
+        low, ext = args[0], args[1]
+    if not ir_pass.Equal(low, _api.const(0, dtype='int32')):
+        ext = ext - low
+    for_type = LOOP_INTRIN[annotation]
+    iter_var = None
+    return iter_var, low, ext, for_type
+
+
+range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
+
+
+def bind(func_id, args):
+    """Handling TVM thread binding"""
+    _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
+    _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
+    _internal_assert(isinstance(args[0], str), \
+                     "A loop bind's first argument should be a string!")
+    iter_var = _api.thread_axis(args[0])
+    low, ext = _api.const(0), args[1]
+    for_type = None
+    return iter_var, low, ext, for_type
+
+
+def _math_intrin(func_id, args):
+    from .. import intrin
+    return getattr(intrin, func_id)(*args)
+
+sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
+
+
+def _min_max(func_id, args):
+    _internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
+    return getattr(_make, func_id.title())(args[0], args[1])
+
+
+min = max = _min_max #pylint: disable=invalid-name
+
+
+def _allocate_tensor(func_id, args):
+    """Handling TVM tensor allocation.
+    You may refer to hybrid.intrin.allocate for more details."""
+    n = len(args)
+    _internal_assert(isinstance(_api.convert(args[0]), Array), \
+                     "allocate's first argument should be a tuple of shape!")
+    shape = args[0]
+    for i in shape:
+        _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
+    if n > 1:
+        _internal_assert(isinstance(args[1], str),
+                         "The data type should be a str")
+        _internal_assert(args[1].startswith('int') or args[1].startswith('float'), \
+                         "The data type should be either int or float!")
+        dtype = args[1]
+    else:
+        dtype = 'float32'
+    if n > 2:
+        _internal_assert(isinstance(args[2], str), \
+                         "The data scope should be a string")
+        _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
+        scope = args[2]
+    else:
+        scope = 'global' if func_id != 'output_tensor' else 'output'
+    return (shape, dtype, scope)
+
+output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py
index 92e259585b7a7..48e92a8bf5acc 100644
--- a/python/tvm/hybrid/intrin.py
+++ b/python/tvm/hybrid/intrin.py
@@ -1,7 +1,6 @@
-"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
+"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""

 import numpy
-from ..stmt import For

 class _range(object):
     """Base class of the loop ranges in hybrid script"""
@@ -102,15 +101,3 @@ def sigmoid(x):
     'sigmoid'  : sigmoid,
     'popcount' : popcount
 }
-
-
-LOOP_INTRIN = {
-    'range'    : For.Serial,
-    'unroll'   : For.Unrolled,
-    'parallel' : For.Parallel,
-    'vectorize': For.Vectorized,
-    'bind'     : None
-}
-
-
-MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py
index ba10dd8dde3c6..26b0e141d0dbd 100644
--- a/python/tvm/hybrid/parser.py
+++ b/python/tvm/hybrid/parser.py
@@ -4,24 +4,24 @@
 import operator
 import logging
 import sys
-from .util import make_nop, halide_imm_types, is_docstring, _internal_assert
-from .intrin import LOOP_INTRIN, MATH_INTRIN
+from .util import _internal_assert
+from . import calls
+from . import util
 from .var_decl import determine_variable_usage
-from ..api import thread_axis
 from ..api import all as _all
 from ..api import any as _any
+from ..tensor import Tensor, Operation
 from .. import expr as _expr
 from .. import make as _make
-from .. import intrin
 from .. import api as _api
 from .. import ir_pass as _ir_pass


 def list_to_block(visit, lst):
     """Convert a list of Python IR nodes to HalideIR Block"""
-    lst = [visit(stmt) for stmt in lst if not is_docstring(stmt)]
-    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
+    lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
+    lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
     if not lst:
-        return make_nop()
+        return util.make_nop()
     if len(lst) == 1:
         return lst[0]
     body = lst[0]
@@ -62,7 +62,7 @@ class HybridParser(ast.NodeVisitor):
     }


-    def __init__(self, args, usage, func_name=None):
+    def __init__(self, args, usage, symbols, func_name=None):
         """
         Parameters
         ----------
@@ -81,32 +81,49 @@ def __init__(self, args, usage, func_name=None):
         self.args = list(args)
         self.usage = usage.copy()
         self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
-        self.alloc_buffers = {} # Buffers formed by allocate instructions
+        self.alloc_buffers = {} # Buffers formed by explicit allocate instructions
         self.loops_above = {} # State variable that indicates loop levels above the current node
-        self.var_consts = {} # Variables that are determined as readonly in previous stage
+        self.variables = {} # The status of defined variables
         self.func_name = func_name # The name of the function to be lowered
         self.outputs = [] # Output tensors' name
         self.side_effect = set() # Tensors with side effects
         self.parsed_body = None # The parsed HalideIR body
-        self.returned = False
+        self.returned = False # If this function has a valid return
+        self.symbols = symbols # The global context


     def wrap_up_realize(self, node, body):
         """Wrap up all the variables which will no longer be used"""
+        pop_buf = []
+        pop_var = []
         for key, val in self.usage.items():
-            if key in self.var_consts.keys():
-                continue
             _, level, _ = val
-            if level == node:
-                if key in self._args.keys():
+            if level != node:
+                continue
+            if key in self._args.keys():
+                continue
+            if key in self.alloc_buffers.keys():
+                _buf, _scope = self.alloc_buffers[key]
+                if _scope == 'output':
                     continue
-                else:
-                    _buf, _scope = self.alloc_buffers[key]
-                    _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
-                    _dtype = _buf.dtype
-                    _true = _api.convert(True)
-                    body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
-                    body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
+                pop_buf.append(key)
+            else:
+                _internal_assert(key in self.variables.keys(),
                                 "Key should be in one of args, buffers, or variables")
+                if not isinstance(self.variables[key], tuple):
+                    continue
+                _buf, _scope = self.variables[key]
+                pop_var.append(key)
+            _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
+            _dtype = _buf.dtype
+            _true = _api.convert(True)
+            body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
+            body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
+
+        for elem in pop_buf:
+            self.alloc_buffers.pop(elem)
+        for elem in pop_var:
+            self.variables.pop(elem)
         return body
@@ -121,7 +138,6 @@ def _get_buffer_from_id(self, s, for_provide=False):
         return self.alloc_buffers[s][0]


-
     #pylint: disable=invalid-name, missing-docstring
     def visit_Module(self, node):
         _internal_assert(len(node.body) == 1, \
@@ -133,13 +149,13 @@ def visit_FunctionDef(self, node):
         _internal_assert(len(node.args.args) == len(self.args), \
                          "The number of arguments passed to the \
                           function should be the same as it is defined!")
+        if self.func_name is None:
+            self.func_name = node.name
         for idx, arg in enumerate(node.args.args):
             _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
             self._args[getattr(arg, _attr)] = self.args[idx]
         res = list_to_block(self.visit, node.body)
         res = self.wrap_up_realize(node, res)
-        if self.func_name is None:
-            self.func_name = node.name
         return res
@@ -148,23 +164,22 @@ def visit_Expr(self, node):


     def visit_Name(self, node):
-        _id = node.id
-        if _id in self._args.keys() and isinstance(self._args[_id], (_expr.Var, _expr.ConstExpr)):
-            return self._args[_id]
-        elif _id in self.loops_above.keys():
-            return self.loops_above[_id]
-        _internal_assert(_id not in self._args.keys(), \
-                         "This id %s should be handled in visit_Subscript!" % _id)
-        _internal_assert(_id in self.usage.keys(), \
-                         "This id %s is expected to be a defined variable!" % _id)
-        # Buffer
-        if _id in self.alloc_buffers.keys():
-            _buf, _ = self.alloc_buffers[_id]
-            return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
-        # Compilation time constant
-        _internal_assert(_id in self.var_consts.keys(),
-                         "This id %s is expected to a compilation time constant!" % _id)
-        return self.var_consts[_id]
+        name = node.id
+        if name in self.loops_above.keys():
+            return self.loops_above[name]
+        elif name in self.variables.keys():
+            res = self.variables[name]
+            if isinstance(res, tuple):
+                buf = res[0]
+                if isinstance(node.ctx, ast.Load):
+                    return _make.Call(buf.dtype, buf.name, [_api.const(0)], \
+                                      _expr.Call.Halide, buf.op, buf.value_index)
+                return buf, [_api.const(0)]
+            if isinstance(node.ctx, ast.Load):
+                return res
+            return None
+        buf = self._get_buffer_from_id(name)
+        return buf
@@ -172,18 +187,36 @@ def visit_Num(self, node):


     def visit_AugAssign(self, node):
-        lhs = self.visit(node.target)
+        buf = self.visit(node.target)
         rhs = self.visit(node.value)
-        rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
-        _internal_assert(isinstance(lhs, _expr.Call), \
-                         "The LHS of an AugAssign is supposed to be a call!")
-        return _make.Provide(lhs.func, 0, rhs, lhs.args)
+        if isinstance(buf, tuple):
+            _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
+            buf, args = buf
+        else:
+            args = [_api.const(0)]
+        _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be a Tensor!")
+
+        read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
+        value = HybridParser._binop_maker[type(node.op)](read, rhs)
+
+        return _make.Provide(buf.op, 0, value, args)


     def visit_Assign(self, node):
+        rhs = self.visit(node.value)
+        if isinstance(rhs, Operation):
+            rmap = {}
+            _internal_assert(len(node.targets) == rhs.num_outputs, \
+                             "Unable to detuple the outs to targets")
+            for i in range(rhs.num_outputs):
+                _internal_assert(isinstance(node.targets[i], ast.Name),
+                                 "You should bind a pure name to the tensors")
+                self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global')
+                rmap[rhs.outputs[i].op] = rhs.output(i)
+            return util.replace_io(rhs.body, rmap)
+
         _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
         lhs = node.targets[0]
-        rhs = self.visit(node.value)
         if isinstance(rhs, _expr.Expr):
             rhs = _ir_pass.Simplify(rhs)
         if isinstance(lhs, ast.Name):
@@ -194,65 +227,63 @@ def visit_Assign(self, node):
                              "Loop variable cannot be overwritten!")
             decl, _, rw = self.usage[lhs]
             if decl == lhs_:
-                _internal_assert(lhs not in self.var_consts.keys(), \
-                                 "A constant cannot be overwritten!")
-                _internal_assert(lhs not in self.alloc_buffers.keys(), \
+                _internal_assert(lhs not in self.variables.keys() and
+                                 lhs not in self.alloc_buffers.keys(), \
                                  "This value should not be defined before this point!")
                 if isinstance(rhs, tuple):
                     shape, dtype, scope = rhs
                     ph = _api.placeholder(shape, dtype=dtype, name=lhs)
-                    if scope != 'output':
-                        self.alloc_buffers[lhs] = (ph, scope)
-                    else:
-                        self._args[lhs] = ph
+                    self.alloc_buffers[lhs] = (ph, scope)
+                    if scope == 'output':
                         self.outputs.append(lhs)
-                    return make_nop()
-                if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
-                    self.var_consts[lhs] = rhs
+                    return util.make_nop()
+                if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
+                    self.variables[lhs] = rhs
                 else:
                     ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
-                    self.alloc_buffers[lhs] = (ph, 'global')
-            if lhs in self.var_consts.keys():
-                return make_nop()
-            _internal_assert(lhs in self.alloc_buffers.keys(), \
-                             "This variable should be defined before!")
-            tgt, _ = self.alloc_buffers[lhs]
-            return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
+                    self.variables[lhs] = (ph, 'global')
+            lhs = self.visit(lhs_)
+            if lhs is not None:
+                buf, args = lhs
+                return _make.Provide(buf.op, 0, rhs, args)
+            return util.make_nop()
         else:
-            lhs = self.visit(lhs)
-            _internal_assert(isinstance(lhs, _expr.Call), \
-                             "An array access's LHS is expected to be a expr.Call!")
-            #TODO: support slice later
-            buf = self._get_buffer_from_id(lhs.name, for_provide=True)
-            return _make.Provide(buf.op, 0, rhs, lhs.args)
+            lhs, args = self.visit(lhs)
+            _internal_assert(isinstance(lhs, Tensor), \
+                             "An array access's LHS is expected to be a Tensor!")
+            res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
+            return res


     def visit_Index(self, node):
         if isinstance(node.value, ast.Tuple):
-            return [self.visit(i) for i in node.value.elts]
+            return self.visit(node.value)
         return [self.visit(node.value)]


+    def visit_Attribute(self, node):
+        _internal_assert(isinstance(node.value, ast.Name), \
+                         "For attribute access, only names are supported so far!")
+        buf = self._get_buffer_from_id(node.value.id)
+        return getattr(buf, node.attr)
+
+
     def visit_Subscript(self, node):
         args = self.visit(node.slice)
         if isinstance(node.value, ast.Name):
-            array = node.value.id
-            _buf = self._get_buffer_from_id(array)
-            return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, _buf.value_index)
-
-        _internal_assert(isinstance(node.value, ast.Attribute), \
-                         "Only variable and attribute's subscript supported so far")
-        _internal_assert(isinstance(node.value.value, ast.Name), \
-                         "The root of array access is expect to be a id!")
-        _internal_assert(node.value.attr == "shape", \
-                         "Attribute access so far only 'shape' is supported!")
+            buf = self.visit(node.value)
+            if isinstance(node.ctx, ast.Load):
+                return _make.Call(buf.dtype, buf.name, args, \
+                                  _expr.Call.Halide, buf.op, buf.value_index)
+            return buf, args
+
+        shape = self.visit(node.value)
         _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
         args = args[0]
         #TODO: maybe support non-constant value later?
         _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
                          "So far only constant shape access supported!")
-        buf = self._get_buffer_from_id(node.value.value.id)
-        return buf.shape[args.value]
+        return shape[args.value]


     def visit_With(self, node):
@@ -275,7 +306,7 @@ def visit_If(self, node):
         if node.orelse:
             else_body = list_to_block(self.visit, node.orelse)
         else:
-            else_body = make_nop()
+            else_body = util.make_nop()
         return _make.IfThenElse(cond, if_body, else_body)
@@ -305,13 +336,10 @@ def visit_BoolOp(self, node):
             _internal_assert(isinstance(node.op, ast.Not), \
                              "Unary is supposed to be not!")
             return operator.not_(self.visit(node.values[0]))
-        elif n == 2:
-            _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
-                             "Binary is supposed to be and/or!")
-            values = [self.visit(i) for i in node.values]
-            return HybridParser._binop_maker[type(node.op)](*values)
-        else:
-            raise ValueError("This Bool Op is not supported yet!")
+        _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
+                         "Binary is supposed to be and/or!")
+        values = [self.visit(i) for i in node.values]
+        return HybridParser._binop_maker[type(node.op)](*values)


     def visit_UnaryOp(self, node):
@@ -329,67 +357,17 @@ def visit_Call(self, node):
         # Yet, no function pointer supported
         _internal_assert(isinstance(node.func, ast.Name), \
                          "Only id-function function call is supported so far!")
+
         func_id = node.func.id
-        n = len(node.args)
-        if func_id in LOOP_INTRIN.keys() and func_id != 'bind':
-            if n == 1:
-                low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0])
-            else:
-                _internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
-                low, ext = self.visit(node.args[0]), self.visit(node.args[1])
-            if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
-                ext = ext - low
-            for_type = LOOP_INTRIN[func_id]
-            iter_var = None
-            return iter_var, low, ext, for_type
-        elif func_id == 'bind':
-            _internal_assert(n == 2, "A loop bind should only have 2 arguments!")
-            _internal_assert(isinstance(node.args[0], ast.Str), \
-                             "A loop bind's first argument should be a string!")
-            _vn = node.args[0].s
-            iter_var = thread_axis(node.args[0].s)
-            low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
-            for_type = None
-            return iter_var, low, ext, for_type
-        elif func_id in MATH_INTRIN:
-            return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
-        elif func_id in ['allocate', 'output_tensor']:
-            _internal_assert(isinstance(node.args[0], ast.Tuple), \
-                             "allocate's first argument should be a tuple of shape!")
-            shape = tuple(self.visit(i) for i in node.args[0].elts)
-            if func_id == 'output_tensor':
-                _internal_assert(not self.loops_above, \
-                                 "Are you sure to allocate a output buffer multiple times?")
-            for i in shape:
-                _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
-            if n > 1:
-                if isinstance(node.args[1], ast.Str):
-                    dtype = node.args[1].s
-                else:
-                    _internal_assert(isinstance(node.args[1], ast.Attribute), \
-                                     "Unable to evaluate to get data type")
-                    to_eval = node.args[1]
-                    _internal_assert(isinstance(to_eval.value, ast.Name), \
-                                     "Unable to evaluate the attribute to get data type")
-                    _internal_assert(to_eval.attr == 'dtype', \
-                                     "Only dtype attribute is supported so far")
-                    dtype = self._get_buffer_from_id(to_eval.value.id).dtype
-            else:
-                dtype = 'float32'
-            if n > 2:
-                _internal_assert(isinstance(node.args[2], ast.Str), \
-                                 "The data scope should be an string")
-                _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
-                scope = node.args[2].s
-            else:
-                scope = 'global' if func_id != 'output_tensor' else 'output'
-            return (shape, dtype, scope)
-        elif func_id == 'max' or func_id == 'min':
-            _internal_assert(n == 2, "Max/Min function should have 2 elements")
-            a, b = self.visit(node.args[0]), self.visit(node.args[1])
-            return getattr(_make, func_id.title())(a, b)
-        else:
-            raise ValueError("Function call not supported yet!")
+        args = [self.visit(i) for i in node.args]
+        try:
+            return getattr(calls, func_id)(func_id, args)
+        except AttributeError:
+            _internal_assert(func_id in self.symbols.keys(), \
+                             "The function called is not in the context either!")
+            outs = self.symbols[func_id](*args)
+            op = outs.op if isinstance(outs, Tensor) else outs[0].op
+            return op


     def visit_For(self, node):
@@ -400,7 +378,7 @@ def visit_For(self, node):
         if iter_var is None:
             _internal_assert(for_type is not None, "The loop bind function parse error!")
             offset = iter_var = _api.var(_name)
-            if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
+            if not _ir_pass.Equal(low, _api.const(0)):
                 offset = iter_var + low
             self.loops_above[_name] = offset
         else:
@@ -411,7 +389,7 @@ def visit_For(self, node):
         if for_type is None:
             res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
         else:
-            res = _make.For(iter_var, _api.const(0, dtype='int32'), ext, for_type, 0, _body)
+            res = _make.For(iter_var, _api.const(0), ext, for_type, 0, _body)
         self.loops_above.pop(_name)
         return res
@@ -428,14 +406,22 @@ def visit_Return(self, node):
             _internal_assert(isinstance(i, ast.Name), "What do you return?")
             ids.append(i.id)
         _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
-        if len(ids) != len(self.outputs):
+        if len(ids) < len(self.outputs):
             logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
-        self.outputs = [self._args[i] for i in ids]
+        self.outputs = [self.alloc_buffers[i][0] for i in ids]
         self.returned = True
-        return make_nop()
+        return util.make_nop()
+
+
+    def visit_Tuple(self, node):
+        return tuple(self.visit(i) for i in node.elts)


-def parse_python(src, args):
+    def visit_Str(self, node):
+        return node.s
+
+
+def parse_python(src, symbols, args):
     """The helper function of calling the AST visitor

     Parameters
@@ -443,6 +429,9 @@
     src : str
         The source code of the function to be parsed.

+    symbols : dict
+        The symbol table of the global context of the function.
+
     args : list of Tensors or Vars
         The argument lists to the function.
         It is NOT encouraged to write a function without arguments.
@@ -454,8 +443,8 @@
         The result Halide IR and the parser class instance.
     """
     root = ast.parse(src)
-    var_usage = determine_variable_usage(root, args)
-    parser = HybridParser(args, var_usage)
+    var_usage = determine_variable_usage(root, args, symbols)
+    parser = HybridParser(args, var_usage, symbols)
     parser.parsed_body = parser.visit(root)
     _internal_assert(parser.returned, 'No valid return found in the function body!')
     return parser
diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py
index 78106838f13e6..aa86d55a6fcf1 100644
--- a/python/tvm/hybrid/util.py
+++ b/python/tvm/hybrid/util.py
@@ -10,6 +10,7 @@
 from .. import api as _api
 from .. import make as _make
 from .. import expr as _expr
+from .. import stmt as _stmt
 from ..tensor import Tensor
@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
         _globals.pop(elem)
     for k, v in intersect:
         _globals[k] = v
+
+
+def replace_io(body, rmap):
+    """Replace tensor reads and writes in a stmt according to the given dict"""
+    from .. import ir_pass
+
+    def replace(op):
+        if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
+            buf = rmap[op.func]
+            return _make.Provide(buf.op, op.value_index, op.value, op.args)
+        elif isinstance(op, _expr.Call) and op.func in rmap.keys():
+            buf = rmap[op.func]
+            return _make.Call(buf.dtype, buf.name, op.args, \
+                              _expr.Call.Halide, buf.op, buf.value_index)
+        return None
+
+    return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py
index 27df878743771..eb893a7f22a1d 100644
--- a/python/tvm/hybrid/var_decl.py
+++ b/python/tvm/hybrid/var_decl.py
@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
     """The visitor class to determine the declaration, r/w status, and last use of each variable"""
     #pylint: disable=invalid-name
     #pylint: disable=missing-docstring
-    def __init__(self, args):
+    def __init__(self, args, symbols):
         self.status = {}
         self.scope_level = []
         self._args = {}
         self.args = args
         self.aug_assign_ = False
+        self.symbols = symbols


     def visit_FunctionDef(self, node):
@@ -43,8 +44,10 @@ def visit_Call(self, node):
         #No function pointer supported so far
         _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
         func_id = node.func.id
-        _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \
-                         "Function call id not in intrinsics' list")
+        _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
+                         ['range', 'max', 'min'] + \
+                         list(self.symbols.keys()), \
+                         "Function call id not in intrinsics' list")
         for elem in node.args:
             self.visit(elem)
@@ -75,11 +78,13 @@ def visit_Name(self, node):
         else:
             decl, loop, usage = self.status[node.id]
             usage.add(type(node.ctx))
+            _internal_assert(loop in self.scope_level,
+                             "%s is used outside the scope where it is defined!" % node.id)
             self.status[node.id] = (decl, loop, usage)


-def determine_variable_usage(root, args):
+def determine_variable_usage(root, args, symbols):
     """The helper function for calling the dedicated visitor."""
-    visitor = PyVariableUsage(args)
+    visitor = PyVariableUsage(args, symbols)
     visitor.visit(root)
     return visitor.status
diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py
index 7efbbe43ee212..f87c75f7929d9 100644
--- a/tests/python/unittest/test_hybrid_script.py
+++ b/tests/python/unittest/test_hybrid_script.py
@@ -270,7 +270,7 @@ def test_bind():
         return
     @script
     def vec_add(a, b):
-        c = output_tensor((1000, ), dtype='float32')
+        c = output_tensor((1000, ), 'float32')
         for tx in bind('threadIdx.x', 1000):
             c[tx] = a[tx] + b[tx]
         return c
@@ -506,7 +506,37 @@ def kernel_b(b, a):
     module(tvm.ndarray.array(np_a), res)
     tvm.testing.assert_allclose(res.asnumpy(), ref)

+def test_func_call():
+    @tvm.hybrid.script
+    def foo(a, b):
+        for i in range(10):
+            a[i] = i + 1.0
+        for i in range(10):
+            b[i] = i + 1.0
+        c = outer_product(10, 10, a, b)
+        d = output_tensor(c.shape, c.dtype)
+        for i in range(10):
+            for j in range(10):
+                d[i, j] = c[i, j] + i * j
+        return d
+    a = tvm.placeholder((10, ), name='a')
+    b = tvm.placeholder((10, ), name='b')
+    run_and_check(foo, [a, b])
+
+def test_bool():
+    @tvm.hybrid.script
+    def foo(a):
+        b = output_tensor(a.shape, a.dtype)
+        b[0] = 1.2
+        for i in range(1, a.shape[0] - 1):
+            if a[i] * a[i - 1] < a[i] or a[i] * a[i - 1] < a[i - 1] or i * a[i] == a[i]:
+                b[i] = a[i]
+            else:
+                b[i] = 0.0
+        return b
+    a = tvm.placeholder((10, ), name='a')
+    run_and_check(foo, [a])

 if __name__ == "__main__":
     test_outer_product()
@@ -521,7 +551,7 @@ def kernel_b(b, a):
     test_downstream()
     test_const_param()
     test_value_index()
+    test_func_call()
+    test_bool()
     # TODO:
     # test_inplace()
-
-
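
Usage sketch (editor's note, not part of the patch): the headline change above is that `wrapped_func` now hands `func.__globals__` to `parse_python`, so `HybridParser.visit_Call` can fall back to `self.symbols` when the callee is not a built-in intrinsic, and `util.replace_io` stitches the callee's outputs into the caller's body. The snippet below illustrates how this reads from user code; the kernel names `add_scalar` and `cascade` are hypothetical, and the schedule/lower driver mirrors what the tests' `run_and_check` helper does.

    import tvm
    from tvm.hybrid import script

    @script
    def add_scalar(a):                 # hypothetical helper kernel
        b = output_tensor(a.shape, a.dtype)
        for i in range(a.shape[0]):
            b[i] = a[i] + 1.0
        return b

    @script
    def cascade(a):
        # The call below is resolved through func.__globals__ (the
        # AttributeError fallback in visit_Call) and merged into this
        # kernel's body via util.replace_io.
        b = add_scalar(a)
        c = output_tensor(b.shape, b.dtype)
        for i in range(b.shape[0]):
            c[i] = b[i] * 2.0
        return c

    a = tvm.placeholder((10, ), dtype='float32', name='a')
    c = cascade(a)                     # the output tensor of a HybridOp
    s = tvm.create_schedule(c.op)
    print(tvm.lower(s, [a, c], simple_mode=True))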