From 1e65ab8016f566452a84d2dec59a3d5ba3b10f9f Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Fri, 8 Jun 2018 13:21:46 -0700 Subject: [PATCH 01/31] cleanup a branch I messed up --- docs/dev/hybrid_script.md | 107 +++++++ python/tvm/build_module.py | 16 +- python/tvm/hybrid/__init__.py | 10 + python/tvm/hybrid/_internal.py | 62 +++++ python/tvm/hybrid/_intrin.py | 69 +++++ python/tvm/hybrid/api.py | 43 +++ python/tvm/hybrid/parser.py | 294 ++++++++++++++++++++ python/tvm/hybrid/var_decl.py | 68 +++++ tests/python/unittest/test_hybrid_script.py | 288 +++++++++++++++++++ 9 files changed, 952 insertions(+), 5 deletions(-) create mode 100644 docs/dev/hybrid_script.md create mode 100644 python/tvm/hybrid/__init__.py create mode 100644 python/tvm/hybrid/_internal.py create mode 100644 python/tvm/hybrid/_intrin.py create mode 100644 python/tvm/hybrid/api.py create mode 100644 python/tvm/hybrid/parser.py create mode 100644 python/tvm/hybrid/var_decl.py create mode 100644 tests/python/unittest/test_hybrid_script.py diff --git a/docs/dev/hybrid_script.md b/docs/dev/hybrid_script.md new file mode 100644 index 000000000000..05212b748196 --- /dev/null +++ b/docs/dev/hybrid_script.md @@ -0,0 +1,107 @@ +# Hybrid Frontend Developer Guide + +This hybrid frontend is aimed at: +1. Building IR in a more intuitive way +2. Writing preliminary versions of some idioms that yet have not been supported by + +## Features + +### Software emulation + +This feature supports both software emulation and compilation of the code. + +To define a function, you need to use `tvm.hybrid.script` decorator to indicate this is a hybrid function: +````Python +@tvm.hybrid.script +def outer_product(a, b, c): + for i in range(a.shape[0]): + for j in range(b.shape[0]): + c[i, j] = a[i] * b[j] +a = numpy.random.rand(100) +b = numpy.random.rand(99) +c = numpy.zeros((100, 99)) +outer_product(a, b) +```` +This decorator will help you to import [key words](#keywords) required spontaneously when software emulation. 
+Every element in the argument list is either a python variable or `numpy` tensor. + +### Backend Compilation + +The current parse interface looks like: +````Python +a = tvm.placeholder((100, ), name='a') +b = tvm.placeholder((99, ), name='b') +c = tvm.placeholder((100, 99), name='c') +tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function +```` +**TODO**: If we pass these tvm tensors to this function, it returns a op node: +````Python +a = tvm.placeholder((100, ), name='a') +b = tvm.placeholder((99, ), name='b') +c = tvm.placeholder((100, 99), name='c') +op = outer_product(a, b, c) # return the corresponding op node +```` +#### Scheduling + +**Under construction, not truly supported yet.** + +Follow up the example above, you can use some tvm like interfaces to manipulate the structure of IR: +````Python +sch = tvm.create_schedule(op) +jo, ji = sch.split(j, 4) +sch.vectorize(ji) +```` +`split`, `reorder`, and loop_annotation will be supported! + +### Attributes +So far, ONLY tensors' `shape` attribute is supported! + +### Loops + +In HalideIR, loops have in total 4 types: `serail`, `unrolled`, `parallel`, and `vectorized`. + +Here we use `range`, `serial`, `unroll`, `parallel`, and `vectorize`, these **5** keywords to annotate the types of for loops. + +**NOTE**: In HalideIR those are enums, they are in passive form. Here we use active form to annotate loops, because they are ready to run. + +**NOTE**: Unlike what that is in HalideIR, in `loop_type(a, b)`, `a` is the starting point and `b` is the trip count of iterations. Here `loop_type(a, b)` indicates `[a, b)`. + +### Variables + +Because there is no variables in `HalideIR`, all the mutatable variables will be lowered to an array with size 1. +It takes the first store of a variable as its declaration. +**NOTE**: Unlike conventional Python, the declared array can only be used in the scope level it is declared. 
+````Python +for i in range(5): + sum = 0 + for j in range(5): + sum += a[i, j] #do something with sum + b[i] = sum #you can still use sum in this level +#you can NEVER use some here, even though it is allowed in conventional Python +a[0] = sum +```` +### Conditional Statement and Expression + +````Python +if condition: + # do something +a = b if condition else c +```` +However, NO `True` and `False` keyword supported yet. + +### Math intrinsics +So far, these math intrinsics, `log`, `exp`, `sigmoid`, `tanh`, `power`, and `popcount`, are supported. No import is required, just use it! +### Array allocation +**TODO**: Use a function call `allocation(shape, type, share/local)` to declare an array buffer. The basic usage is roughly the same as variables +### Thread bind +You can also do loop-thread bind by writing code like this: +````Python +for tx in bind("threadIdx.x", 100): + a[tx] = b[tx] +```` +## Appendix + +### Keywords +- Statement keywords: `for`, `in`, `if`, `else` +- For keywords: `serial`, `range`, `unroll`, `parallel`, `vectorize`, `bind` +- Math keywords: `log`, `exp`, `sigmoid`, `tanh`, `power`, `popcount` diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 72b89af020d7..a8613f548fae 100755 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -332,12 +332,18 @@ def lower(sch, lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - # normalize schedule first - sch = sch.normalize() + # Phase 0 - bounds = schedule.InferBound(sch) - stmt = schedule.ScheduleOps(sch, bounds) - stmt = ir_pass.InjectPrefetch(stmt) + if isinstance(sch, schedule.Schedule): + # normalize schedule first + sch = sch.normalize() + bounds = schedule.InferBound(sch) + stmt = schedule.ScheduleOps(sch, bounds) + stmt = ir_pass.InjectPrefetch(stmt) + else: + #So far there is no op for hybrid script, so a plain ir body is given + 
stmt = sch + for f in lower_phase0: stmt = f(stmt) # Phase 1 diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/hybrid/__init__.py new file mode 100644 index 000000000000..e0a39c562f0f --- /dev/null +++ b/python/tvm/hybrid/__init__.py @@ -0,0 +1,10 @@ +"""Hybrid Programming APIs of TVM Python Package. + +This package maps a subset of python to HalideIR so that: +1. Users can write some preliminary versions of the computation patterns +have not been supported yet and verify it across the real execution and +python semantic emulation. +2. Developers can build HalideIR by writing Python code. +""" + +from .api import script, parse diff --git a/python/tvm/hybrid/_internal.py b/python/tvm/hybrid/_internal.py new file mode 100644 index 000000000000..d2c0b56b4bc9 --- /dev/null +++ b/python/tvm/hybrid/_internal.py @@ -0,0 +1,62 @@ +"""Internal utilities for parsing Python subset to HalideIR""" + +import sys +import inspect +import numpy +from ._intrin import HYBRID_GLOBALS +from .. import api as _api +from .. import make as _make +from .. 
import expr as _expr +from ..tensor import Tensor + +# Useful constants +NOP = _make.Evaluate(_api.const(0, dtype='int32')) +RANGE_ONE = _make.range_by_min_extent(0, 1) +TRUE = _api.convert(True) +ZERO = _api.const(0) + +# Node types represent constants in HalideIR +HALIDE_IMM = (_expr.FloatImm, _expr.IntImm, _expr.UIntImm) + +def _pruned_source(func): + """Prune source code's extra leading spaces""" + lines = inspect.getsource(func).split('\n') + leading_space = len(lines[0]) - len(lines[0].lstrip(' ')) + lines = [line[leading_space:] for line in lines] + return '\n'.join(lines) + +TVM_ARG_TYPES = (_expr.Var, Tensor) +if sys.version_info[0] == 3: + NUMPY_ARG_TYPES = (float, int, numpy.float32, numpy.int32, numpy.ndarray) +else: + NUMPY_ARG_TYPES = (float, int, long, numpy.float32, numpy.int32, numpy.ndarray) + +def _is_tvm_arg_types(args): + """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. + If neither is true, raise a assertion error.""" + if isinstance(args[0], TVM_ARG_TYPES): + for elem in args[1:]: + assert isinstance(elem, TVM_ARG_TYPES) + return True + assert isinstance(args[0], NUMPY_ARG_TYPES) + for elem in args[1:]: + assert isinstance(elem, NUMPY_ARG_TYPES) + return False + +def _enter_hybrid_runtime(func): + """Put hybrid runtime variables into the global scope""" + _globals = func.__globals__ + intersect = [] + for elem in list(HYBRID_GLOBALS.keys()): + if elem in _globals.keys(): + intersect.append((elem, _globals[elem])) + _globals[elem] = HYBRID_GLOBALS[elem] + return intersect + +def _restore_runtime(func, intersect): + """Rollback the modification caused by hybrid runtime""" + _globals = func.__globals__ + for elem in list(HYBRID_GLOBALS.keys()): + _globals.pop(elem) + for k, v in intersect: + _globals[k] = v diff --git a/python/tvm/hybrid/_intrin.py b/python/tvm/hybrid/_intrin.py new file mode 100644 index 000000000000..d039ed8b9b5b --- /dev/null +++ b/python/tvm/hybrid/_intrin.py @@ -0,0 +1,69 @@ 
+"""Intrinsics of Python-Halide DSL for Python runtime""" + +import numpy +from ..stmt import For + +class _range(object): + """Base class of the loop ranges in hybrid script""" + def __init__(self, a, b=None): + if b is None: + self.low = 0 + self.ext = a + else: + self.low = a + self.ext = b + + def __iter__(self): + i = 0 + while i < self.ext: + yield i + self.low + i += 1 + +class bind(_range): #pylint: disable=invalid-name + def __init__(self, tag, ext): + super(bind, self).__init__(ext) + self.tag = tag + +serial = unroll = vectorize = parallel = _range #pylint: disable=invalid-name + +def allocate(shape, dtype=None): + """Allocate a buffer with given shape""" + dtype = 'float32' if dtype is None else dtype + return numpy.zeros(shape).astype(dtype) + +def popcount(x): + cnt = 0 + while x: + x -= x & -x + cnt += 1 + return cnt + +def sigmoid(x): + return 1 / (1 + numpy.exp(-x)) + +HYBRID_GLOBALS = { + 'serial' : serial, + 'unroll' : unroll, + 'vectorize' : vectorize, + 'parallel' : parallel, + 'allocate' : allocate, + 'bind' : bind, + 'sqrt' : numpy.sqrt, + 'log' : numpy.log, + 'tanh' : numpy.tanh, + 'power' : numpy.power, + 'exp' : numpy.exp, + 'sigmoid' : sigmoid, + 'popcount' : popcount +} + +LOOP_INTRIN = { + 'range' : For.Serial, + 'serial' : For.Serial, + 'unroll' : For.Unrolled, + 'parallel' : For.Parallel, + 'vectorize': For.Vectorized, + 'bind' : None +} + +MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount'] diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py new file mode 100644 index 000000000000..7859f7ed194a --- /dev/null +++ b/python/tvm/hybrid/api.py @@ -0,0 +1,43 @@ +"""APIs of lowering the Python subset to HalideIR""" +import types +import decorator +from .parser import parse_python +from ._internal import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types, _pruned_source + +@decorator.decorator +def script(func, *args): + """If the arguments are tvm types, compile it to HalideIR. + O.W. 
return the python emulated result""" + if _is_tvm_arg_types(args): + return parse(func, args) + else: + intersect = _enter_hybrid_runtime(func) + func(*args) + _restore_runtime(func, intersect) + return func + +def parse(func, args): + """Parse a subset of Python to HalideIR + + Parameters + ---------- + func : str or types.FunctionType + If it is a string, parse the source code + If it is a function, parse the function + + args : list of Buffer or Tensor or Var + The argument lists to the function. + Leave it None if no buffer is related to the function to be parsed + + Returns + ------- + (halide_ir, parser) : (Stmt, PyAST2HalideIR) + The result Halide IR and the parser class instance. + TODO: Later we deprecate this return value, use a dedicated OP node type instead + """ + if isinstance(func, str): + src = func + else: + assert isinstance(func, types.FunctionType) + src = _pruned_source(func) + return parse_python(src, args) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py new file mode 100644 index 000000000000..734c8d7acf59 --- /dev/null +++ b/python/tvm/hybrid/parser.py @@ -0,0 +1,294 @@ +"""Compiling a subset of Python to HalideIR""" +#pylint: disable=no-else-return +import ast +import operator +import sys +from ._internal import NOP, TRUE, RANGE_ONE, HALIDE_IMM, ZERO +from ._intrin import LOOP_INTRIN, MATH_INTRIN +from .var_decl import determine_variable_usage +from ..api import thread_axis +from .. import expr as _expr +from .. import stmt as _stmt +from .. import make as _make +from .. import intrin +from .. import api as _api +from .. 
import ir_pass as _ir_pass + +def list_to_block(visit, lst): + """Convert a list of Python IR nodes to HalideIR Block""" + lst = list(map(visit, lst)) + lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, NOP)] + if not lst: + return NOP + if len(lst) == 1: + return lst[0] + body = lst[0] + for i in lst[1:]: + body = _make.Block(body, i) + return body + +class PyAST2HalideIR(ast.NodeVisitor): + """Python AST visitor pass which finally lowers it to HalideIR""" + + _binop_maker = { + ast.Add : operator.add, + ast.Sub : operator.sub, + ast.Mult : operator.mul, + ast.Div : _make.Div, + ast.Mod : operator.mod, + ast.BitOr : operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.Gt : operator.gt, + ast.GtE : operator.ge, + ast.Lt : operator.lt, + ast.LtE : operator.le, + ast.Eq : operator.eq, + ast.NotEq : operator.ne, + } + + _unaryop_maker = { + ast.USub : operator.neg, + ast.Invert : operator.invert, + ast.Not : operator.not_ + } + + def __init__(self, args, usage, func_name=None): + """ + Parameters + ---------- + args: A list of tvm.placeholder or tvm.var + Provided by the user, the argument list of the function to be lowered. 
+ + usage: A dict of variables used in last in this function + Provided by last lower pass, which collects this information + + Returns + ------- + func_name: str + The name of the function to be lowered; if not provided, + the compiler will use the name in the AST + """ + self.args = args[:] + self.usage = usage.copy() + self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer) + self.buffers = {} + self.loops_above = {} # State variable that indicates loop levels above the current node + self.var_consts = {} # Variables that are determined as readonly in previous stage + self.func_name = func_name # The name of the function to be lowered + self.iter_axis = [] + + #pylint: disable=missing-docstring, invalid-name + #pylint: disable=consider-merging-isinstance, no-else-return + #pylint: disable=inconsistent-return-statements + + def wrap_up_realize(self, node, body): + """Wrap up all the variables which will no longer be used""" + for key, val in self.usage.items(): + if key in self.var_consts.keys(): + continue + _, scope, _ = val + if scope == node: + _buf = self.buffers[key] + body = _make.Realize(_buf.op, 0, _buf.dtype, [RANGE_ONE], TRUE, body) + return body + + def visit_Module(self, node): + assert len(node.body) == 1 + return self.visit(node.body[0]) + + def visit_FunctionDef(self, node): + assert len(node.args.args) == len(self.args) + for idx, arg in enumerate(node.args.args): + _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible + self._args[getattr(arg, _attr)] = self.args[idx] + res = list_to_block(self.visit, node.body) + res = self.wrap_up_realize(node, res) + if self.func_name is None: + self.func_name = node.name + return res + + def visit_Expr(self, node): + return self.visit(node.value) + + def visit_Name(self, node): + _id = node.id + if _id in self._args.keys() and isinstance(self._args[_id], _expr.Var): + return self._args[_id] + elif _id in self.loops_above.keys(): + return 
self.loops_above[_id] + # This id cannot be a buffer; buffer will be handled in subscript + assert _id not in self._args.keys() + assert _id in self.usage.keys() + # Buffer + if _id in self.buffers.keys(): + _buf = self.buffers[_id] + return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0) + # Compilation time constant + assert _id in self.var_consts.keys() + return self.var_consts[_id] + + def visit_Num(self, node): + return _api.const(node.n) + + def visit_Assign(self, node): + assert len(node.targets) == 1 + lhs = node.targets[0] + rhs = _ir_pass.Simplify(self.visit(node.value)) + if isinstance(lhs, ast.Name): + #TODO: support defined intermediate buffer later + lhs_ = lhs + lhs = lhs.id + assert lhs not in self.loops_above.keys() + decl, _, rw = self.usage[lhs] + if decl == lhs_: + assert lhs not in self.var_consts.keys() + assert lhs not in self.buffers.keys() + if isinstance(rhs, HALIDE_IMM) and ast.Store not in rw: + self.var_consts[lhs] = rhs + else: + self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) + if lhs in self.var_consts.keys(): + return NOP + else: + assert lhs in self.buffers.keys() + return _make.Provide(self.buffers[lhs].op, 0, rhs, [ZERO]) + else: + lhs = self.visit(lhs) + assert isinstance(lhs, _expr.Call) + #TODO: support slice later + assert lhs.name in self._args.keys() + return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args) + + def visit_Index(self, node): + if isinstance(node.value, ast.Tuple): + return [self.visit(i) for i in node.value.elts] + return [self.visit(node.value)] + + def visit_Subscript(self, node): + #assert isinstance(node.value, ast.Name) or isinstance(node.value, ast.Attribute) + args = self.visit(node.slice) + if isinstance(node.value, ast.Name): + array = node.value.id + assert array in self._args.keys() + _buf = self._args[array] + return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) + elif isinstance(node.value, ast.Attribute): + assert 
isinstance(node.value.value, ast.Name) + assert node.value.attr == "shape" + assert len(args) == 1 + args = args[0] + #TODO: maybe support non-constant value later? + assert isinstance(args, (_expr.IntImm, _expr.UIntImm)) + assert node.value.value.id in self._args.keys() + return self._args[node.value.value.id].shape[args.value] + else: + assert False + + def visit_With(self, node): + if sys.version_info[0] < 3: + context = node.context_expr + option = node.optional_vars + else: + assert len(node.items) == 1 + context = node.items[0].context_expr + option = node.items[0].optional_vars + assert isinstance(context, ast.Call) + assert isinstance(option, ast.Name) + self.annotation[option.id] = context.func.id + return list_to_block(self.visit, node.body) + + def visit_If(self, node): + cond = self.visit(node.test) + if_body = list_to_block(self.visit, node.body) + if node.orelse: + else_body = list_to_block(self.visit, node.orelse) + else: + else_body = NOP + return _make.IfThenElse(cond, if_body, else_body) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if_body = self.visit(node.body) + else_body = self.visit(node.orelse) + return _make.Select(cond, if_body, else_body) + + def visit_Compare(self, node): + lhs = self.visit(node.left) + assert len(node.ops) == 1 + assert len(node.comparators) == 1 + rhs = self.visit(node.comparators[0]) + return PyAST2HalideIR._binop_maker[type(node.ops[0])](lhs, rhs) + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + return PyAST2HalideIR._unaryop_maker[type(node.op)](operand) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + return PyAST2HalideIR._binop_maker[type(node.op)](lhs, rhs) + + def visit_Call(self, node): + # Yet, no function pointer supported + assert isinstance(node.func, ast.Name) + func_id = node.func.id + n = len(node.args) + if func_id in LOOP_INTRIN.keys() and func_id != 'bind': + if n == 1: + low, ext = ZERO, 
self.visit(node.args[0]) + else: + assert n == 2 + low, ext = self.visit(node.args[0]), self.visit(node.args[1]) + if not _ir_pass.Equal(low, ZERO): + ext = ext - low + for_type = LOOP_INTRIN[func_id] + iter_var = None + return iter_var, low, ext, for_type + elif func_id == 'bind': + assert n == 2 + assert isinstance(node.args[0], ast.Str) + _vn = node.args[0].s + iter_var = thread_axis(node.args[0].s) + low, ext = ZERO, self.visit(node.args[1]) + for_type = None + return iter_var, low, ext, for_type + elif func_id in MATH_INTRIN: + return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args]) + elif func_id == 'allocate': + #TODO: Support it later! + if n == 1: + pass + else: + assert n == 2 + pass + else: + assert False and "Not supported yet!" + + def visit_For(self, node): + iter_var, low, ext, for_type = self.visit(node.iter) + assert isinstance(node.target, ast.Name) + _name = node.target.id + if iter_var is None: + assert for_type is not None + iter_var = _api.var(_name) + self.loops_above[_name] = iter_var + else: + self.loops_above[_name] = iter_var.var + assert for_type is None + _body = list_to_block(self.visit, node.body) + _body = self.wrap_up_realize(node, _body) + if for_type is None: + res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) + else: + res = _make.For(iter_var, low, ext, for_type, 0, _body) + self.loops_above.pop(_name) + return res + +def parse_python(src, args): + """ The helper function of calling the AST visitor""" + root = ast.parse(src) + var_usage = determine_variable_usage(root, args) + parser = PyAST2HalideIR(args, var_usage) + halide_ir = parser.visit(root) + return (halide_ir, parser) diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py new file mode 100644 index 000000000000..745a4f614fbd --- /dev/null +++ b/python/tvm/hybrid/var_decl.py @@ -0,0 +1,68 @@ +"""Determines the declaration, r/w status, and last use of each variable""" +import ast +import sys +from ._intrin import 
HYBRID_GLOBALS + +class PyVariableUsage(ast.NodeVisitor): + """The vistor class to determine the declaration, r/w status, and last use of each variable""" + #pylint: disable=invalid-name + #pylint: disable=missing-docstring + def __init__(self, args): + self.status = {} + self.scope_level = [] + self._args = {} + self.args = args + + def visit_FunctionDef(self, node): + self.scope_level.append(node) + assert len(node.args.args) == len(self.args) + for idx, arg in enumerate(node.args.args): + _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible + self._args[getattr(arg, _attr)] = self.args[idx] + for i in node.body: + self.visit(i) + + def visit_For(self, node): + assert isinstance(node.target, ast.Name) + + self.visit(node.iter) + self.scope_level.append(node) + + for i in node.body: + self.visit(i) + + self.scope_level.pop() + + def visit_Call(self, node): + #No function pointer supported so far + assert isinstance(node.func, ast.Name) + assert node.func.id in HYBRID_GLOBALS.keys() or node.func.id == 'range' + for elem in node.args: + self.visit(elem) + + def visit_Name(self, node): + # If it is from the argument list or loop variable, we do not worry about it! 
+ if node.id in self._args.keys(): + return + fors = [loop.target.id for loop in self.scope_level if isinstance(loop, ast.For)] + if node.id in fors: + return + # The loop variable cannot be overwritten when iteration + if isinstance(node.ctx, ast.Store): + assert node.id not in fors + + if node.id not in self.status.keys(): + # In Python, "first store" indicates "declaration" + assert isinstance(node.ctx, ast.Store) + self.status[node.id] = (node, self.scope_level[-1], set()) + else: + decl, loop, usage = self.status[node.id] + loop = self.scope_level[-1] + usage.add(type(node.ctx)) + self.status[node.id] = (decl, loop, usage) + +def determine_variable_usage(root, args): + """The helper function for calling the dedicated visitor.""" + visitor = PyVariableUsage(args) + visitor.visit(root) + return visitor.status diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py new file mode 100644 index 000000000000..261bd3d1bbfe --- /dev/null +++ b/tests/python/unittest/test_hybrid_script.py @@ -0,0 +1,288 @@ +import tvm, inspect, sys, traceback, numpy +from tvm.hybrid import script +from tvm.hybrid._intrin import HYBRID_GLOBALS + +@script +def outer_product(n, m, a, b, c): + for i in serial(n): + for j in range(m): + c[i, j] = a[i] * b[j] + +#Test global function +#Test bridge between frontend and backend +def test_outer_product(): + n = tvm.var('n') + m = tvm.var('m') + a = tvm.placeholder((n, ), name='a') + b = tvm.placeholder((m, ), name='b') + c = tvm.placeholder((n, m), name='c') + ir, _ = outer_product(n, m, a, b, c) + #Check for i in (0, n) + assert isinstance(ir, tvm.stmt.For) + assert ir.loop_var.name == 'i' + assert ir.min.value == 0 + assert ir.extent.name == 'n' + ibody = ir.body + assert isinstance(ibody, tvm.stmt.For) + #Check for j in (0, m) + assert ibody.loop_var.name == 'j' + assert ibody.min.value == 0 + assert ibody.extent.name == 'm' + #Check loop body + jbody = ibody.body + assert isinstance(jbody, 
tvm.stmt.Provide) + assert jbody.func.name == 'c' + assert len(jbody.args) == 2 + assert jbody.args[0].name == 'i' + assert jbody.args[1].name == 'j' + assert isinstance(jbody.value, tvm.expr.Mul) + mul = jbody.value + assert isinstance(mul.a, tvm.expr.Call) + assert mul.a.name == 'a' + assert mul.b.name == 'b' + + func = tvm.lower(ir, [n, m, a, b, c]) + func = tvm.build(func) + + _n = 999 + _m = 1001 + _a = numpy.random.rand(_n).astype('float32') + _b = numpy.random.rand(_m).astype('float32') + c_python = numpy.zeros((_n, _m), dtype='float32') + outer_product(_n, _m, _a, _b, c_python) + + tvm_a = tvm.ndarray.array(_a) + tvm_b = tvm.ndarray.array(_b) + tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32')) + func(_n, _m, tvm_a, tvm_b, tvm_c) + numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5) + for key, _ in HYBRID_GLOBALS.items(): + assert key not in globals().keys() + assert key not in outer_product.__globals__.keys() + +#Test local function +#Test allocation of local variable +def test_fanout(): + @script + def fanout(n, a, b): + three = 3.0 + for i in serial(a.shape[0] - 3): + sigma = 0.0 + for j in serial(3): + sigma = sigma + a[i + j] + sigma = sigma / three + b[i] = sigma + + n = tvm.var('n') + a = tvm.placeholder((n, ), name='a') + b = tvm.placeholder((n-3, ), name='b') + ir, _ = fanout(n, a, b) + + #Check for i in (0, n-3) + assert isinstance(ir, tvm.stmt.For) + assert ir.loop_var.name == 'i' + assert ir.min.value == 0 + assert tvm.ir_pass.Equal(ir.extent, n - 3) + #Check loopbody + ibody = ir.body + assert isinstance(ibody, tvm.stmt.Realize) + assert ibody.bounds[0].min.value == 0 + assert ibody.bounds[0].extent.value == 1 + assert ibody.func.name == 'sigma' + #Check i loop body + rbody = ibody.body + assert isinstance(rbody.first, tvm.stmt.Provide) + assert rbody.first.func.name == 'sigma' + assert len(rbody.first.args) == 1 + assert rbody.first.args[0].value == 0 + #Check fanout loop + jloop = rbody.rest.first + assert 
jloop.loop_var.name == 'j' + assert jloop.min.value == 0 + assert jloop.extent.value == 3 + jbody = jloop.body + assert isinstance(jbody, tvm.stmt.Provide) + assert len(jbody.args) == 1 + assert jbody.args[0].value == 0 + assert jbody.func.name == 'sigma' + assert isinstance(jbody.value, tvm.expr.Add) + value = jbody.value + assert isinstance(value.a, tvm.expr.Call) + assert value.a.name == 'sigma' + assert len(value.a.args) == 1 + assert value.a.args[0].value == 0 + assert value.b.name == 'a' + assert len(value.b.args) == 1 + assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var) + divide= rbody.rest.rest.first + assert isinstance(divide, tvm.stmt.Provide) + assert len(divide.args) == 1 + assert divide.args[0].value == 0 + value = divide.value + assert isinstance(value, tvm.expr.Mul) + assert value.a.name == 'sigma' + assert len(value.a.args) == 1 + assert value.a.args[0].value == 0 + assert abs(value.b.value - (1 / 3.0)) < 1e-5 + write = rbody.rest.rest.rest + assert isinstance(write, tvm.stmt.Provide) + assert write.func.name == 'b' + assert write.value.name == 'sigma' + assert len(write.value.args) == 1 + assert write.value.args[0].value == 0 + +@script +def failure(): + for i in serial(1, 100): + i = 0 + +def test_failure(): + try: + tvm.hybrid.parse(failure, []) + except IOError: + assert sys.version_info[0] == 2 + lineno = inspect.currentframe().f_back.f_lineno + print('[Warning] Python2 cannot do the failure case @line #%d' % lineno) + except AssertionError: + _, _, tb = sys.exc_info() + _, _, func, text = traceback.extract_tb(tb)[-1] + assert func == 'visit_Assign' + assert text == 'assert lhs not in self.loops_above.keys()' + + +def test_looptype(): + @script + def looptype(a): + for i in parallel(6): + a[i] = i + for j in vectorize(6): + a[j] = j + for k in unroll(6): + a[k] = k + a = tvm.placeholder((6, ), name='a') + ir, _ = looptype(a) + iloop = ir.first + jloop = ir.rest.first + kloop = ir.rest.rest + assert iloop.for_type == 
tvm.stmt.For.Parallel + assert jloop.for_type == tvm.stmt.For.Vectorized + assert kloop.for_type == tvm.stmt.For.Unrolled + +def test_if(): + @script + def if_then_else(a, b): + for i in serial(10): + if i % 2 == 0: + a[i] = -1 + else: + a[i] = 1 + for i in unroll(10): + b[i] = -1 if i % 2 == 0 else 1 + + a = tvm.placeholder((10, ), dtype='int32', name='a') + b = tvm.placeholder((10, ), dtype='int32', name='b') + ir, _ = if_then_else(a, b) + func = tvm.lower(ir, [a, b]) + func = tvm.build(func) + assert func + + _a = numpy.zeros((10, ), dtype = 'int32') + _b = numpy.zeros((10, ), dtype = 'int32') + if_then_else(_a, _b) + + tvm_a = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32')) + tvm_b = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32')) + func(tvm_a, tvm_b) + + numpy.testing.assert_allclose(tvm_a.asnumpy(), _a, rtol=1e-5) + numpy.testing.assert_allclose(tvm_b.asnumpy(), _b, rtol=1e-5) + numpy.testing.assert_allclose(tvm_a.asnumpy(), tvm_b.asnumpy(), rtol=1e-5) + +def test_bind(): + @script + def vec_add(a, b, c): + for tx in bind('threadIdx.x', 1000): + c[tx] = b[tx] + c[tx] + + a = tvm.placeholder((1000, ), dtype='float32', name='a') + b = tvm.placeholder((1000, ), dtype='float32', name='b') + c = tvm.placeholder((1000, ), dtype='float32', name='c') + ir, _ = vec_add(a, b, c) + #print(tvm.lower(ir, [a, b, c], simple_mode=True)) + func = tvm.lower(ir, [a, b, c]) + func = tvm.build(func, target = 'cuda') + + _a = numpy.random.rand(1000).astype('float32') + _b = numpy.random.rand(1000).astype('float32') + _c = numpy.zeros((1000, ), dtype = 'float32') + + + tvm_a = tvm.ndarray.array(_a, tvm.gpu(0)) + tvm_b = tvm.ndarray.array(_b, tvm.gpu(0)) + tvm_c = tvm.ndarray.array(_c, tvm.gpu(0)) + + func(tvm_a, tvm_b, tvm_c) + vec_add(_a, _b, _c) + + numpy.testing.assert_allclose(_c, tvm_c.asnumpy(), rtol=1e-5) + +def test_math_intrin(): + @script + def intrin_real(a): + a[0] = sqrt(a[0]) + a[1] = log(a[1]) + a[2] = exp(a[2]) + a[3] = sigmoid(a[3]) + a[4] = power(a[4], 
a[5]) + a[5] = tanh(a[5]) + + a6 = tvm.placeholder((6, ), dtype='float32', name='a') + ir, _ = intrin_real(a6) + func = tvm.build(tvm.lower(ir, [a6])) + assert func + a = numpy.arange(2, 8).astype('float32') + tvm_a = tvm.ndarray.array(a) + func(tvm_a) + intrin_real(a) + numpy.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5) + + @script + def intrin_int(a): + a[0] = popcount(a[0]) + + a1 = tvm.placeholder((1, ), dtype='int32') + ir, _ = intrin_int(a1) + func = tvm.build(tvm.lower(ir, [a1])) + assert func + a = numpy.array([1234567890]).astype('int32') + tvm_a = tvm.ndarray.array(a) + intrin_int(a) + func(tvm_a) + assert tvm_a.asnumpy()[0] == a[0] + +def test_allocate_buffer(): + def blur(a): + for i in serail(32): + h_blur = allocate((4, 36)) + for j in serail(4): + for k in serail(36): + s = allocate((1, ), 'float32') + for dj in serail(4): + s[0] = s[0] + a[i, j + dj] + h_blur[j, k] = s[0] / 4. + for j in serail(32): + s = 0. + for di in serail(4): + s = s + h_blur[di, j] + h_blur[i, j] = s / 4. + + +if __name__ == "__main__": + test_outer_product() + test_fanout() + test_failure() + test_looptype() + test_if() + test_bind() + test_math_intrin() + From 9578c2aeaa6fd28a0e1e5fc84dd9f8ff82b74c88 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Fri, 8 Jun 2018 13:25:46 -0700 Subject: [PATCH 02/31] cleanup unused import --- python/tvm/hybrid/parser.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 734c8d7acf59..ee546da403ab 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -8,7 +8,6 @@ from .var_decl import determine_variable_usage from ..api import thread_axis from .. import expr as _expr -from .. import stmt as _stmt from .. import make as _make from .. import intrin from .. import api as _api @@ -261,7 +260,6 @@ def visit_Call(self, node): pass else: assert n == 2 - pass else: assert False and "Not supported yet!" 
From c37eabc4995b970f6b912104807e17d3acd20d24 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Fri, 8 Jun 2018 13:38:09 -0700 Subject: [PATCH 03/31] gpu skip? --- tests/python/unittest/test_hybrid_script.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 261bd3d1bbfe..96c3bce3e2f2 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -199,6 +199,9 @@ def if_then_else(a, b): numpy.testing.assert_allclose(tvm_a.asnumpy(), tvm_b.asnumpy(), rtol=1e-5) def test_bind(): + if not tvm.gpu(0).exist: + print('No GPU found! Skip this test!') + return @script def vec_add(a, b, c): for tx in bind('threadIdx.x', 1000): @@ -208,7 +211,7 @@ def vec_add(a, b, c): b = tvm.placeholder((1000, ), dtype='float32', name='b') c = tvm.placeholder((1000, ), dtype='float32', name='c') ir, _ = vec_add(a, b, c) - #print(tvm.lower(ir, [a, b, c], simple_mode=True)) + func = tvm.lower(ir, [a, b, c]) func = tvm.build(func, target = 'cuda') From 60ba428e05420d17a34760fb9c787864bd213688 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Sun, 10 Jun 2018 22:10:14 -0700 Subject: [PATCH 04/31] Fix typo in developer tutorial --- docs/dev/hybrid_script.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/hybrid_script.md b/docs/dev/hybrid_script.md index 05212b748196..513b0b713d5d 100644 --- a/docs/dev/hybrid_script.md +++ b/docs/dev/hybrid_script.md @@ -20,7 +20,7 @@ def outer_product(a, b, c): a = numpy.random.rand(100) b = numpy.random.rand(99) c = numpy.zeros((100, 99)) -outer_product(a, b) +outer_product(a, b, c) ```` This decorator will help you to import [key words](#keywords) required spontaneously when software emulation. Every element in the argument list is either a python variable or `numpy` tensor. 
From 9b70244466351d66a89e7efe68141ec550e53b74 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Wed, 13 Jun 2018 10:14:49 -0700 Subject: [PATCH 05/31] Erase TODO, typo fixed. --- docs/dev/hybrid_script.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/docs/dev/hybrid_script.md b/docs/dev/hybrid_script.md index 513b0b713d5d..437db6756e5e 100644 --- a/docs/dev/hybrid_script.md +++ b/docs/dev/hybrid_script.md @@ -34,13 +34,15 @@ b = tvm.placeholder((99, ), name='b') c = tvm.placeholder((100, 99), name='c') tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function ```` -**TODO**: If we pass these tvm tensors to this function, it returns a op node: +If we pass these tvm tensors to this function, it returns a op node: ````Python a = tvm.placeholder((100, ), name='a') b = tvm.placeholder((99, ), name='b') c = tvm.placeholder((100, 99), name='c') op = outer_product(a, b, c) # return the corresponding op node ```` +**This function is still under construction** + #### Scheduling **Under construction, not truly supported yet.** @@ -58,7 +60,7 @@ So far, ONLY tensors' `shape` attribute is supported! ### Loops -In HalideIR, loops have in total 4 types: `serail`, `unrolled`, `parallel`, and `vectorized`. +In HalideIR, loops have in total 4 types: `serial`, `unrolled`, `parallel`, and `vectorized`. Here we use `range`, `serial`, `unroll`, `parallel`, and `vectorize`, these **5** keywords to annotate the types of for loops. @@ -73,12 +75,12 @@ It takes the first store of a variable as its declaration. **NOTE**: Unlike conventional Python, the declared array can only be used in the scope level it is declared. 
````Python for i in range(5): - sum = 0 + s = 0 for j in range(5): - sum += a[i, j] #do something with sum + s += a[i, j] #do something with sum b[i] = sum #you can still use sum in this level #you can NEVER use some here, even though it is allowed in conventional Python -a[0] = sum +a[0] = s ```` ### Conditional Statement and Expression @@ -92,7 +94,9 @@ However, NO `True` and `False` keyword supported yet. ### Math intrinsics So far, these math intrinsics, `log`, `exp`, `sigmoid`, `tanh`, `power`, and `popcount`, are supported. No import is required, just use it! ### Array allocation -**TODO**: Use a function call `allocation(shape, type, share/local)` to declare an array buffer. The basic usage is roughly the same as variables +Use a function call `allocation(shape, type, share/local)` to declare an array buffer. The basic usage is roughly the same as variables. + +**This function is still under construction.** ### Thread bind You can also do loop-thread bind by writing code like this: ````Python From 6227166a941de74a13f16ef5863cc070f42ae59c Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Sun, 17 Jun 2018 20:29:37 -0700 Subject: [PATCH 06/31] move np_type to _ffi.base; get rid of constant ir nodes --- python/tvm/_ffi/base.py | 1 + python/tvm/build_module.py | 2 ++ python/tvm/hybrid/_internal.py | 37 ++++++++++++++++++---------------- python/tvm/hybrid/parser.py | 22 +++++++++++--------- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 49348f3110ad..584899ab6310 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -23,6 +23,7 @@ numeric_types = (float, int, long, np.float32, np.int32) py_str = lambda x: x +np_arg_types = (*numeric_types, np.ndarray) class TVMError(Exception): """Error thrown by TVM function""" diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index a8613f548fae..777654af6619 100755 --- a/python/tvm/build_module.py +++ 
b/python/tvm/build_module.py @@ -342,6 +342,8 @@ def lower(sch, stmt = ir_pass.InjectPrefetch(stmt) else: #So far there is no op for hybrid script, so a plain ir body is given + if not isinstance(sch, _stmt.Stmt): + raise ValueError("sch should be either a Schedule or a Stmt") stmt = sch for f in lower_phase0: diff --git a/python/tvm/hybrid/_internal.py b/python/tvm/hybrid/_internal.py index d2c0b56b4bc9..85aeee081d48 100644 --- a/python/tvm/hybrid/_internal.py +++ b/python/tvm/hybrid/_internal.py @@ -4,19 +4,25 @@ import inspect import numpy from ._intrin import HYBRID_GLOBALS +from .._ffi.base import np_arg_types from .. import api as _api from .. import make as _make from .. import expr as _expr from ..tensor import Tensor +# If it is a +tvm_arg_types = (Tensor, _expr.Var) +halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) + # Useful constants -NOP = _make.Evaluate(_api.const(0, dtype='int32')) -RANGE_ONE = _make.range_by_min_extent(0, 1) -TRUE = _api.convert(True) -ZERO = _api.const(0) +def make_nop(): + return _make.Evaluate(_api.const(0, dtype='int32')) + +def make_range_one(): + return _make.range_by_min_extent(0, 1) -# Node types represent constants in HalideIR -HALIDE_IMM = (_expr.FloatImm, _expr.IntImm, _expr.UIntImm) +def make_const_true(): + return _api.convert(True) def _pruned_source(func): """Prune source code's extra leading spaces""" @@ -25,22 +31,19 @@ def _pruned_source(func): lines = [line[leading_space:] for line in lines] return '\n'.join(lines) -TVM_ARG_TYPES = (_expr.Var, Tensor) -if sys.version_info[0] == 3: - NUMPY_ARG_TYPES = (float, int, numpy.float32, numpy.int32, numpy.ndarray) -else: - NUMPY_ARG_TYPES = (float, int, long, numpy.float32, numpy.int32, numpy.ndarray) - def _is_tvm_arg_types(args): """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. 
- If neither is true, raise a assertion error.""" - if isinstance(args[0], TVM_ARG_TYPES): + If neither is true, raise a value error.""" + if isinstance(args[0], tvm_arg_types): for elem in args[1:]: - assert isinstance(elem, TVM_ARG_TYPES) + if not isinstance(elem, tvm_arg_types): + raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem))) return True - assert isinstance(args[0], NUMPY_ARG_TYPES) + if not isinstance(args[0], np_arg_types): + raise ValueError("Expect a numpy type but % get!" % str(type(elem))) for elem in args[1:]: - assert isinstance(elem, NUMPY_ARG_TYPES) + if not isinstance(elem, np_arg_types): + raise ValueError("Expect a numpy type but % get!" % str(type(elem))) return False def _enter_hybrid_runtime(func): diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index ee546da403ab..88cf2718fe39 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -1,9 +1,9 @@ -"""Compiling a subset of Python to HalideIR""" +"""Compiling a TVM Hybrid Script Python to HalideIR""" #pylint: disable=no-else-return import ast import operator import sys -from ._internal import NOP, TRUE, RANGE_ONE, HALIDE_IMM, ZERO +from ._internal import make_nop, make_const_true, make_range_one, halide_imm_types from ._intrin import LOOP_INTRIN, MATH_INTRIN from .var_decl import determine_variable_usage from ..api import thread_axis @@ -16,9 +16,9 @@ def list_to_block(visit, lst): """Convert a list of Python IR nodes to HalideIR Block""" lst = list(map(visit, lst)) - lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, NOP)] + lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] if not lst: - return NOP + return make_nop() if len(lst) == 1: return lst[0] body = lst[0] @@ -89,7 +89,10 @@ def wrap_up_realize(self, node, body): _, scope, _ = val if scope == node: _buf = self.buffers[key] - body = _make.Realize(_buf.op, 0, _buf.dtype, [RANGE_ONE], TRUE, body) + _dtype = _buf.dtype + _one = 
make_range_one() + _true = make_const_true() + body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body) return body def visit_Module(self, node): @@ -143,15 +146,15 @@ def visit_Assign(self, node): if decl == lhs_: assert lhs not in self.var_consts.keys() assert lhs not in self.buffers.keys() - if isinstance(rhs, HALIDE_IMM) and ast.Store not in rw: + if isinstance(rhs, halide_imm_types) and ast.Store not in rw: self.var_consts[lhs] = rhs else: self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) if lhs in self.var_consts.keys(): - return NOP + return make_nop() else: assert lhs in self.buffers.keys() - return _make.Provide(self.buffers[lhs].op, 0, rhs, [ZERO]) + return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)]) else: lhs = self.visit(lhs) assert isinstance(lhs, _expr.Call) @@ -203,7 +206,7 @@ def visit_If(self, node): if node.orelse: else_body = list_to_block(self.visit, node.orelse) else: - else_body = NOP + else_body = make_nop() return _make.IfThenElse(cond, if_body, else_body) def visit_IfExp(self, node): @@ -234,6 +237,7 @@ def visit_Call(self, node): func_id = node.func.id n = len(node.args) if func_id in LOOP_INTRIN.keys() and func_id != 'bind': + ZERO = _api.const(0, dtype='int32') if n == 1: low, ext = ZERO, self.visit(node.args[0]) else: From ab523d7c2f8e1c935b27b1bc37dbb5d7db8d94d6 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Sun, 17 Jun 2018 21:35:24 -0700 Subject: [PATCH 07/31] move all the assertions to ValueError; adjust the test case! 
--- python/tvm/_ffi/base.py | 1 - python/tvm/hybrid/_intrin.py | 12 +- python/tvm/hybrid/{_internal.py => _util.py} | 11 +- python/tvm/hybrid/api.py | 5 +- python/tvm/hybrid/parser.py | 128 ++++++++++++------- python/tvm/hybrid/var_decl.py | 29 +++-- tests/python/unittest/test_hybrid_script.py | 26 ++-- 7 files changed, 138 insertions(+), 74 deletions(-) rename python/tvm/hybrid/{_internal.py => _util.py} (95%) diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 584899ab6310..49348f3110ad 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -23,7 +23,6 @@ numeric_types = (float, int, long, np.float32, np.int32) py_str = lambda x: x -np_arg_types = (*numeric_types, np.ndarray) class TVMError(Exception): """Error thrown by TVM function""" diff --git a/python/tvm/hybrid/_intrin.py b/python/tvm/hybrid/_intrin.py index d039ed8b9b5b..c386994e98fa 100644 --- a/python/tvm/hybrid/_intrin.py +++ b/python/tvm/hybrid/_intrin.py @@ -1,4 +1,4 @@ -"""Intrinsics of Python-Halide DSL for Python runtime""" +"""Intrinsics of TVM-Python Hybrid Script for Python runtime""" import numpy from ..stmt import For @@ -19,28 +19,36 @@ def __iter__(self): yield i + self.low i += 1 + class bind(_range): #pylint: disable=invalid-name def __init__(self, tag, ext): super(bind, self).__init__(ext) self.tag = tag + serial = unroll = vectorize = parallel = _range #pylint: disable=invalid-name + def allocate(shape, dtype=None): """Allocate a buffer with given shape""" dtype = 'float32' if dtype is None else dtype return numpy.zeros(shape).astype(dtype) + def popcount(x): + """Software emulated popcount function which counts 1's in a number's binary representation.""" cnt = 0 while x: x -= x & -x cnt += 1 return cnt + def sigmoid(x): + """Software emulated sigmoid function, which returns 1/(1+exp(-x)).""" return 1 / (1 + numpy.exp(-x)) + HYBRID_GLOBALS = { 'serial' : serial, 'unroll' : unroll, @@ -57,6 +65,7 @@ def sigmoid(x): 'popcount' : popcount } + LOOP_INTRIN = { 
'range' : For.Serial, 'serial' : For.Serial, @@ -66,4 +75,5 @@ def sigmoid(x): 'bind' : None } + MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount'] diff --git a/python/tvm/hybrid/_internal.py b/python/tvm/hybrid/_util.py similarity index 95% rename from python/tvm/hybrid/_internal.py rename to python/tvm/hybrid/_util.py index 85aeee081d48..f96f9685eca8 100644 --- a/python/tvm/hybrid/_internal.py +++ b/python/tvm/hybrid/_util.py @@ -4,26 +4,32 @@ import inspect import numpy from ._intrin import HYBRID_GLOBALS -from .._ffi.base import np_arg_types +from .._ffi.base import numeric_types from .. import api as _api from .. import make as _make from .. import expr as _expr from ..tensor import Tensor + # If it is a +np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) tvm_arg_types = (Tensor, _expr.Var) halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) + # Useful constants def make_nop(): return _make.Evaluate(_api.const(0, dtype='int32')) + def make_range_one(): return _make.range_by_min_extent(0, 1) + def make_const_true(): return _api.convert(True) + def _pruned_source(func): """Prune source code's extra leading spaces""" lines = inspect.getsource(func).split('\n') @@ -31,6 +37,7 @@ def _pruned_source(func): lines = [line[leading_space:] for line in lines] return '\n'.join(lines) + def _is_tvm_arg_types(args): """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. If neither is true, raise a value error.""" @@ -46,6 +53,7 @@ def _is_tvm_arg_types(args): raise ValueError("Expect a numpy type but % get!" 
% str(type(elem))) return False + def _enter_hybrid_runtime(func): """Put hybrid runtime variables into the global scope""" _globals = func.__globals__ @@ -56,6 +64,7 @@ def _enter_hybrid_runtime(func): _globals[elem] = HYBRID_GLOBALS[elem] return intersect + def _restore_runtime(func, intersect): """Rollback the modification caused by hybrid runtime""" _globals = func.__globals__ diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py index 7859f7ed194a..cc0a324b0785 100644 --- a/python/tvm/hybrid/api.py +++ b/python/tvm/hybrid/api.py @@ -1,8 +1,10 @@ """APIs of lowering the Python subset to HalideIR""" +from __future__ import absolute_import as _abs + import types import decorator from .parser import parse_python -from ._internal import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types, _pruned_source +from ._util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types, _pruned_source @decorator.decorator def script(func, *args): @@ -16,6 +18,7 @@ def script(func, *args): _restore_runtime(func, intersect) return func + def parse(func, args): """Parse a subset of Python to HalideIR diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 88cf2718fe39..c6d34a568793 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -3,7 +3,7 @@ import ast import operator import sys -from ._internal import make_nop, make_const_true, make_range_one, halide_imm_types +from ._util import make_nop, make_const_true, make_range_one, halide_imm_types from ._intrin import LOOP_INTRIN, MATH_INTRIN from .var_decl import determine_variable_usage from ..api import thread_axis @@ -26,9 +26,11 @@ def list_to_block(visit, lst): body = _make.Block(body, i) return body -class PyAST2HalideIR(ast.NodeVisitor): + +class HybridParser(ast.NodeVisitor): """Python AST visitor pass which finally lowers it to HalideIR""" + _binop_maker = { ast.Add : operator.add, ast.Sub : operator.sub, @@ -46,12 +48,14 @@ class 
PyAST2HalideIR(ast.NodeVisitor): ast.NotEq : operator.ne, } + _unaryop_maker = { ast.USub : operator.neg, ast.Invert : operator.invert, ast.Not : operator.not_ } + def __init__(self, args, usage, func_name=None): """ Parameters @@ -77,10 +81,10 @@ def __init__(self, args, usage, func_name=None): self.func_name = func_name # The name of the function to be lowered self.iter_axis = [] - #pylint: disable=missing-docstring, invalid-name + + #pylint: disable=invalid-name #pylint: disable=consider-merging-isinstance, no-else-return #pylint: disable=inconsistent-return-statements - def wrap_up_realize(self, node, body): """Wrap up all the variables which will no longer be used""" for key, val in self.usage.items(): @@ -95,12 +99,22 @@ def wrap_up_realize(self, node, body): body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body) return body + + def _check_id_a_buffer(self, s): + if s not in self._args.keys(): + raise ValueError("This %s is expected to be in argument list or allocated buffer!" 
% s) + + def visit_Module(self, node): - assert len(node.body) == 1 + if len(node.body) != 1: + raise ValueError("Only one-function source code can be fed to this parser!") return self.visit(node.body[0]) + def visit_FunctionDef(self, node): - assert len(node.args.args) == len(self.args) + if len(node.args.args) != len(self.args): + raise ValueError("The number of arguments passed to the function\ + should be the same as it is defined!") for idx, arg in enumerate(node.args.args): _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible self._args[getattr(arg, _attr)] = self.args[idx] @@ -110,42 +124,50 @@ def visit_FunctionDef(self, node): self.func_name = node.name return res + def visit_Expr(self, node): return self.visit(node.value) + def visit_Name(self, node): _id = node.id if _id in self._args.keys() and isinstance(self._args[_id], _expr.Var): return self._args[_id] elif _id in self.loops_above.keys(): return self.loops_above[_id] - # This id cannot be a buffer; buffer will be handled in subscript - assert _id not in self._args.keys() - assert _id in self.usage.keys() + if _id in self._args.keys(): + raise ValueError("This id %s should be handled in visit_Subscript!" % _id) + if _id not in self.usage.keys(): + raise ValueError("This id %s is expected to be a defined variable!" % _id) # Buffer if _id in self.buffers.keys(): _buf = self.buffers[_id] return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0) # Compilation time constant - assert _id in self.var_consts.keys() + if _id not in self.var_consts.keys(): + raise ValueError("This id %s is expected to a compilation time constant!" 
% _id) return self.var_consts[_id] def visit_Num(self, node): return _api.const(node.n) def visit_Assign(self, node): - assert len(node.targets) == 1 + if len(node.targets) != 1: + raise ValueError("So far only one-valued assignment is supported!") lhs = node.targets[0] rhs = _ir_pass.Simplify(self.visit(node.value)) if isinstance(lhs, ast.Name): #TODO: support defined intermediate buffer later lhs_ = lhs lhs = lhs.id - assert lhs not in self.loops_above.keys() + if lhs in self.loops_above.keys(): + raise ValueError("You CAN NEVER overwrite a loop variable!") decl, _, rw = self.usage[lhs] if decl == lhs_: - assert lhs not in self.var_consts.keys() - assert lhs not in self.buffers.keys() + if lhs in self.var_consts.keys(): + raise ValueError("BUG: A constant cannot be overwritten!") + if lhs in self.buffers.keys(): + raise ValueError("BUG: This value should not be defined before this point!") if isinstance(rhs, halide_imm_types) and ast.Store not in rw: self.var_consts[lhs] = rhs else: @@ -153,13 +175,15 @@ def visit_Assign(self, node): if lhs in self.var_consts.keys(): return make_nop() else: - assert lhs in self.buffers.keys() + if lhs not in self.buffers.keys(): + raise ValueError("BUG: This value should be defined before!") return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)]) else: lhs = self.visit(lhs) - assert isinstance(lhs, _expr.Call) + if not isinstance(lhs, _expr.Call): + raise ValueError("An array access's LHS is expected to be a expr.Call!") #TODO: support slice later - assert lhs.name in self._args.keys() + self._check_id_a_buffer(lhs.name) return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args) def visit_Index(self, node): @@ -168,35 +192,41 @@ def visit_Index(self, node): return [self.visit(node.value)] def visit_Subscript(self, node): - #assert isinstance(node.value, ast.Name) or isinstance(node.value, ast.Attribute) args = self.visit(node.slice) if isinstance(node.value, ast.Name): array = node.value.id - 
assert array in self._args.keys() + self._check_id_a_buffer(array) _buf = self._args[array] return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) elif isinstance(node.value, ast.Attribute): - assert isinstance(node.value.value, ast.Name) - assert node.value.attr == "shape" - assert len(args) == 1 + if not isinstance(node.value.value, ast.Name): + raise ValueError("The root of array access is expect to be a id!") + if node.value.attr != "shape": + raise ValueError("Attribute access so far only 'shape' is supported!") + if len(args) != 1: + raise ValueError("For 'shape' access the argument should be only one!") args = args[0] #TODO: maybe support non-constant value later? - assert isinstance(args, (_expr.IntImm, _expr.UIntImm)) - assert node.value.value.id in self._args.keys() + if not isinstance(args, (_expr.IntImm, _expr.UIntImm)): + raise ValueError("So far only constant shape access supported!") + self._check_id_a_buffer(node.value.value.id) return self._args[node.value.value.id].shape[args.value] else: - assert False + raise ValueError("Not supported yet!") def visit_With(self, node): if sys.version_info[0] < 3: context = node.context_expr option = node.optional_vars else: - assert len(node.items) == 1 + if len(node.items) != 1: + raise ValueError("Only one with element is supported so far!") context = node.items[0].context_expr option = node.items[0].optional_vars - assert isinstance(context, ast.Call) - assert isinstance(option, ast.Name) + if not isinstance(context, ast.Call): + raise ValueError("The object must be a Python function call!") + if not isinstance(option, ast.Name): + raise ValueError("The object after 'as' must be an id!") self.annotation[option.id] = context.func.id return list_to_block(self.visit, node.body) @@ -217,23 +247,26 @@ def visit_IfExp(self, node): def visit_Compare(self, node): lhs = self.visit(node.left) - assert len(node.ops) == 1 - assert len(node.comparators) == 1 + if len(node.ops) != 1: + raise 
ValueError("Only one compare op is supported!") + if len(node.comparators) != 1: + raise ValueError("Only one comparator is supported!") rhs = self.visit(node.comparators[0]) - return PyAST2HalideIR._binop_maker[type(node.ops[0])](lhs, rhs) + return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs) def visit_UnaryOp(self, node): operand = self.visit(node.operand) - return PyAST2HalideIR._unaryop_maker[type(node.op)](operand) + return HybridParser._unaryop_maker[type(node.op)](operand) def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) - return PyAST2HalideIR._binop_maker[type(node.op)](lhs, rhs) + return HybridParser._binop_maker[type(node.op)](lhs, rhs) def visit_Call(self, node): # Yet, no function pointer supported - assert isinstance(node.func, ast.Name) + if not isinstance(node.func, ast.Name): + raise ValueError("Only id-function function call is supported so far!") func_id = node.func.id n = len(node.args) if func_id in LOOP_INTRIN.keys() and func_id != 'bind': @@ -241,7 +274,8 @@ def visit_Call(self, node): if n == 1: low, ext = ZERO, self.visit(node.args[0]) else: - assert n == 2 + if n != 2: + raise ValueError("A loop intrinsic should only have 1 or 2 arguments!") low, ext = self.visit(node.args[0]), self.visit(node.args[1]) if not _ir_pass.Equal(low, ZERO): ext = ext - low @@ -249,8 +283,10 @@ def visit_Call(self, node): iter_var = None return iter_var, low, ext, for_type elif func_id == 'bind': - assert n == 2 - assert isinstance(node.args[0], ast.Str) + if n != 2: + raise ValueError("A loop bind should only have 2 arguments!") + if not isinstance(node.args[0], ast.Str): + raise ValueError("A loop bind's first argument should be a string!") _vn = node.args[0].s iter_var = thread_axis(node.args[0].s) low, ext = ZERO, self.visit(node.args[1]) @@ -265,19 +301,22 @@ def visit_Call(self, node): else: assert n == 2 else: - assert False and "Not supported yet!" 
+ raise ValueError("Function call not supported yet!") def visit_For(self, node): iter_var, low, ext, for_type = self.visit(node.iter) - assert isinstance(node.target, ast.Name) + if not isinstance(node.target, ast.Name): + raise ValueError("The loop iterator should be a variable!") _name = node.target.id if iter_var is None: - assert for_type is not None + if for_type is None: + raise ValueError("The loop bind function parse error!") iter_var = _api.var(_name) self.loops_above[_name] = iter_var else: + if for_type is not None: + raise ValueError("The loop iterating function parse error!") self.loops_above[_name] = iter_var.var - assert for_type is None _body = list_to_block(self.visit, node.body) _body = self.wrap_up_realize(node, _body) if for_type is None: @@ -287,10 +326,11 @@ def visit_For(self, node): self.loops_above.pop(_name) return res + def parse_python(src, args): - """ The helper function of calling the AST visitor""" + """The helper function of calling the AST visitor""" root = ast.parse(src) var_usage = determine_variable_usage(root, args) - parser = PyAST2HalideIR(args, var_usage) + parser = HybridParser(args, var_usage) halide_ir = parser.visit(root) - return (halide_ir, parser) + return halide_ir diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py index 745a4f614fbd..fa7daa5d4b45 100644 --- a/python/tvm/hybrid/var_decl.py +++ b/python/tvm/hybrid/var_decl.py @@ -3,6 +3,7 @@ import sys from ._intrin import HYBRID_GLOBALS + class PyVariableUsage(ast.NodeVisitor): """The vistor class to determine the declaration, r/w status, and last use of each variable""" #pylint: disable=invalid-name @@ -13,33 +14,38 @@ def __init__(self, args): self._args = {} self.args = args + def visit_FunctionDef(self, node): self.scope_level.append(node) - assert len(node.args.args) == len(self.args) + if len(node.args.args) != len(self.args): + raise ValueError('#arguments passed should be the same as #arguments defined') for idx, arg in 
enumerate(node.args.args): _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible self._args[getattr(arg, _attr)] = self.args[idx] for i in node.body: self.visit(i) - def visit_For(self, node): - assert isinstance(node.target, ast.Name) + def visit_For(self, node): + if not isinstance(node.target, ast.Name): + raise ValueError("For's iterator should be an id") self.visit(node.iter) self.scope_level.append(node) - for i in node.body: self.visit(i) - self.scope_level.pop() + def visit_Call(self, node): #No function pointer supported so far - assert isinstance(node.func, ast.Name) - assert node.func.id in HYBRID_GLOBALS.keys() or node.func.id == 'range' + if not isinstance(node.func, ast.Name): + raise ValueError("Function call should be an id") + if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range': + raise ValueError("Function call id not in intrinsics' list") for elem in node.args: self.visit(elem) + def visit_Name(self, node): # If it is from the argument list or loop variable, we do not worry about it! 
if node.id in self._args.keys(): @@ -48,12 +54,12 @@ def visit_Name(self, node): if node.id in fors: return # The loop variable cannot be overwritten when iteration - if isinstance(node.ctx, ast.Store): - assert node.id not in fors + if isinstance(node.ctx, ast.Store) and node.id in fors: + raise ValueError("Iter var cannot be overwritten") if node.id not in self.status.keys(): - # In Python, "first store" indicates "declaration" - assert isinstance(node.ctx, ast.Store) + if not isinstance(node.ctx, ast.Store): + raise ValueError('In Python, "first store" indicates "declaration"') self.status[node.id] = (node, self.scope_level[-1], set()) else: decl, loop, usage = self.status[node.id] @@ -61,6 +67,7 @@ def visit_Name(self, node): usage.add(type(node.ctx)) self.status[node.id] = (decl, loop, usage) + def determine_variable_usage(root, args): """The helper function for calling the dedicated visitor.""" visitor = PyVariableUsage(args) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 96c3bce3e2f2..f1b267cd1ece 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -16,7 +16,7 @@ def test_outer_product(): a = tvm.placeholder((n, ), name='a') b = tvm.placeholder((m, ), name='b') c = tvm.placeholder((n, m), name='c') - ir, _ = outer_product(n, m, a, b, c) + ir = outer_product(n, m, a, b, c) #Check for i in (0, n) assert isinstance(ir, tvm.stmt.For) assert ir.loop_var.name == 'i' @@ -76,7 +76,7 @@ def fanout(n, a, b): n = tvm.var('n') a = tvm.placeholder((n, ), name='a') b = tvm.placeholder((n-3, ), name='b') - ir, _ = fanout(n, a, b) + ir = fanout(n, a, b) #Check for i in (0, n-3) assert isinstance(ir, tvm.stmt.For) @@ -139,15 +139,11 @@ def failure(): def test_failure(): try: tvm.hybrid.parse(failure, []) - except IOError: + except IOError as err: assert sys.version_info[0] == 2 - lineno = inspect.currentframe().f_back.f_lineno - print('[Warning] Python2 
cannot do the failure case @line #%d' % lineno) - except AssertionError: - _, _, tb = sys.exc_info() - _, _, func, text = traceback.extract_tb(tb)[-1] - assert func == 'visit_Assign' - assert text == 'assert lhs not in self.loops_above.keys()' + print('[Warning] Python2 cannot do the failure case because "%s"' % str(err)) + except Exception as err: + assert str(err) == 'You CAN NEVER overwrite a loop variable!' def test_looptype(): @@ -160,7 +156,7 @@ def looptype(a): for k in unroll(6): a[k] = k a = tvm.placeholder((6, ), name='a') - ir, _ = looptype(a) + ir = looptype(a) iloop = ir.first jloop = ir.rest.first kloop = ir.rest.rest @@ -181,7 +177,7 @@ def if_then_else(a, b): a = tvm.placeholder((10, ), dtype='int32', name='a') b = tvm.placeholder((10, ), dtype='int32', name='b') - ir, _ = if_then_else(a, b) + ir = if_then_else(a, b) func = tvm.lower(ir, [a, b]) func = tvm.build(func) assert func @@ -210,7 +206,7 @@ def vec_add(a, b, c): a = tvm.placeholder((1000, ), dtype='float32', name='a') b = tvm.placeholder((1000, ), dtype='float32', name='b') c = tvm.placeholder((1000, ), dtype='float32', name='c') - ir, _ = vec_add(a, b, c) + ir = vec_add(a, b, c) func = tvm.lower(ir, [a, b, c]) func = tvm.build(func, target = 'cuda') @@ -240,7 +236,7 @@ def intrin_real(a): a[5] = tanh(a[5]) a6 = tvm.placeholder((6, ), dtype='float32', name='a') - ir, _ = intrin_real(a6) + ir = intrin_real(a6) func = tvm.build(tvm.lower(ir, [a6])) assert func a = numpy.arange(2, 8).astype('float32') @@ -254,7 +250,7 @@ def intrin_int(a): a[0] = popcount(a[0]) a1 = tvm.placeholder((1, ), dtype='int32') - ir, _ = intrin_int(a1) + ir = intrin_int(a1) func = tvm.build(tvm.lower(ir, [a1])) assert func a = numpy.array([1234567890]).astype('int32') From 955c0be69efed1f60da139e60d8ee723ca951271 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Sun, 17 Jun 2018 21:43:13 -0700 Subject: [PATCH 08/31] fix lint --- python/tvm/hybrid/_util.py | 7 +++---- python/tvm/hybrid/parser.py | 27 
+++++++++++++++++---------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/python/tvm/hybrid/_util.py b/python/tvm/hybrid/_util.py index f96f9685eca8..0e0803827223 100644 --- a/python/tvm/hybrid/_util.py +++ b/python/tvm/hybrid/_util.py @@ -1,6 +1,5 @@ """Internal utilities for parsing Python subset to HalideIR""" -import sys import inspect import numpy from ._intrin import HYBRID_GLOBALS @@ -11,13 +10,13 @@ from ..tensor import Tensor -# If it is a +#pylint: disable=invalid-name np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) tvm_arg_types = (Tensor, _expr.Var) halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) -# Useful constants +# Useful constants. In avoid of runtime dependences, we use function calls to return them. def make_nop(): return _make.Evaluate(_api.const(0, dtype='int32')) @@ -47,7 +46,7 @@ def _is_tvm_arg_types(args): raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem))) return True if not isinstance(args[0], np_arg_types): - raise ValueError("Expect a numpy type but % get!" % str(type(elem))) + raise ValueError("Expect a numpy type but % get!" % str(type(args[0]))) for elem in args[1:]: if not isinstance(elem, np_arg_types): raise ValueError("Expect a numpy type but % get!" 
% str(type(elem))) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index c6d34a568793..bbfcba8cee9d 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -1,5 +1,5 @@ """Compiling a TVM Hybrid Script Python to HalideIR""" -#pylint: disable=no-else-return + import ast import operator import sys @@ -69,8 +69,8 @@ def __init__(self, args, usage, func_name=None): Returns ------- func_name: str - The name of the function to be lowered; if not provided, - the compiler will use the name in the AST + The name of the function to be lowered; if not provided, + the compiler will use the name in the AST """ self.args = args[:] self.usage = usage.copy() @@ -82,9 +82,6 @@ def __init__(self, args, usage, func_name=None): self.iter_axis = [] - #pylint: disable=invalid-name - #pylint: disable=consider-merging-isinstance, no-else-return - #pylint: disable=inconsistent-return-statements def wrap_up_realize(self, node, body): """Wrap up all the variables which will no longer be used""" for key, val in self.usage.items(): @@ -105,6 +102,7 @@ def _check_id_a_buffer(self, s): raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s) + #pylint: disable=invalid-name, missing-docstring def visit_Module(self, node): if len(node.body) != 1: raise ValueError("Only one-function source code can be fed to this parser!") @@ -148,9 +146,11 @@ def visit_Name(self, node): raise ValueError("This id %s is expected to a compilation time constant!" 
% _id) return self.var_consts[_id] + def visit_Num(self, node): return _api.const(node.n) + def visit_Assign(self, node): if len(node.targets) != 1: raise ValueError("So far only one-valued assignment is supported!") @@ -186,11 +186,13 @@ def visit_Assign(self, node): self._check_id_a_buffer(lhs.name) return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args) + def visit_Index(self, node): if isinstance(node.value, ast.Tuple): return [self.visit(i) for i in node.value.elts] return [self.visit(node.value)] + def visit_Subscript(self, node): args = self.visit(node.slice) if isinstance(node.value, ast.Name): @@ -214,6 +216,7 @@ def visit_Subscript(self, node): else: raise ValueError("Not supported yet!") + def visit_With(self, node): if sys.version_info[0] < 3: context = node.context_expr @@ -230,6 +233,7 @@ def visit_With(self, node): self.annotation[option.id] = context.func.id return list_to_block(self.visit, node.body) + def visit_If(self, node): cond = self.visit(node.test) if_body = list_to_block(self.visit, node.body) @@ -239,12 +243,14 @@ def visit_If(self, node): else_body = make_nop() return _make.IfThenElse(cond, if_body, else_body) + def visit_IfExp(self, node): cond = self.visit(node.test) if_body = self.visit(node.body) else_body = self.visit(node.orelse) return _make.Select(cond, if_body, else_body) + def visit_Compare(self, node): lhs = self.visit(node.left) if len(node.ops) != 1: @@ -254,15 +260,18 @@ def visit_Compare(self, node): rhs = self.visit(node.comparators[0]) return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs) + def visit_UnaryOp(self, node): operand = self.visit(node.operand) return HybridParser._unaryop_maker[type(node.op)](operand) + def visit_BinOp(self, node): lhs = self.visit(node.left) rhs = self.visit(node.right) return HybridParser._binop_maker[type(node.op)](lhs, rhs) + def visit_Call(self, node): # Yet, no function pointer supported if not isinstance(node.func, ast.Name): @@ -296,13 +305,11 @@ def visit_Call(self, 
node): return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args]) elif func_id == 'allocate': #TODO: Support it later! - if n == 1: - pass - else: - assert n == 2 + return make_nop() else: raise ValueError("Function call not supported yet!") + def visit_For(self, node): iter_var, low, ext, for_type = self.visit(node.iter) if not isinstance(node.target, ast.Name): From c3f4db384f2efa6926ebc8121d401328517c055b Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 12:36:02 -0700 Subject: [PATCH 09/31] commit before rebase --- docs/dev/hybrid_script.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/dev/hybrid_script.md b/docs/dev/hybrid_script.md index 437db6756e5e..d50b54f02fd1 100644 --- a/docs/dev/hybrid_script.md +++ b/docs/dev/hybrid_script.md @@ -1,14 +1,13 @@ # Hybrid Frontend Developer Guide -This hybrid frontend is aimed at: -1. Building IR in a more intuitive way -2. Writing preliminary versions of some idioms that yet have not been supported by +This hybrid frontend is not only aimed at writing preliminary versions of some idioms that yet have +been supported for users. Developers can also use this feature to build IR rapidly. ## Features ### Software emulation -This feature supports both software emulation and compilation of the code. +Both software emulation and compilation are supported. To define a function, you need to use `tvm.hybrid.script` decorator to indicate this is a hybrid function: ````Python @@ -22,7 +21,7 @@ b = numpy.random.rand(99) c = numpy.zeros((100, 99)) outer_product(a, b, c) ```` -This decorator will help you to import [key words](#keywords) required spontaneously when software emulation. +This decorator will import [key words](#keywords) required spontaneously when software emulation. Every element in the argument list is either a python variable or `numpy` tensor. 
### Backend Compilation From c1d48d24dec76d392499fbe7f316f1b066af3641 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 13:48:59 -0700 Subject: [PATCH 10/31] preview of rst --- docs/langref/hybrid_script.rst | 142 +++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 docs/langref/hybrid_script.rst diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst new file mode 100644 index 000000000000..11504db80e4c --- /dev/null +++ b/docs/langref/hybrid_script.rst @@ -0,0 +1,142 @@ +Hybrid Frontend Language Reference +---------------------------------- + +Overview +======== + +This hybrid frontend allows users to write preliminary versions of some idioms that yet have +been supported by TVM officially. + +Features +======== + +#. Software emulation + +Both software emulation and compilation are supported. +To define a function, you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid function: + + +.. code-block:: python + @tvm.hybrid.script + def outer_product(a, b, c): + for i in range(a.shape[0]): + for j in range(b.shape[0]): + c[i, j] = a[i] * b[j] + a = numpy.random.rand(100) + b = numpy.random.rand(99) + c = numpy.zeros((100, 99)) + outer_product(a, b, c) + +This decorator will import [key words](#keywords) required spontaneously when software emulation. +After software emulation is done, the imported keywords will be cleaned up. Users do not need +worry about keyword conflict and pollution. + +Every element passed for software emulation in the argument list is either a python variable +or ``numpy`` numeric type. + +#. Backend Compilation + +The current parse interface looks like: +.. 
code-block:: python + a = tvm.placeholder((100, ), name='a') + b = tvm.placeholder((99, ), name='b') + c = tvm.placeholder((100, 99), name='c') + tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function + +If we pass these tvm tensors to this function, it returns a op node: +**Under construction, we are still deciding what kind of node should be returned.** +.. code-block:: python + a = tvm.placeholder((100, ), name='a') + b = tvm.placeholder((99, ), name='b') + c = tvm.placeholder((100, 99), name='c') + op = outer_product(a, b, c) # return the corresponding op node + +#. Tuning + +**Under construction, not truly supported yet.** +Follow up the example above, you can use some tvm like interfaces to manipulate the structure of IR: +.. code-block:: python + sch = tvm.create_schedule(op) + jo, ji = sch.split(j, 4) + sch.vectorize(ji) + +``split``, ``reorder``, and loop_annotation will be supported! + +#. Loops + +In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. + +Here we use ``range``, ``serial``, ``unroll``, ``parallel``, and ``vectorize``, +these **5** keywords to annotate the types of for loops. The the usage is roughly +the same as Python standard ``range``. + +**NOTE**: In HalideIR those are enums, they are in passive form. + Here we use active form to annotate loops, because they are ready to run. + +**NOTE**: Unlike what that is in HalideIR, in ``loop_type(a, b)``, + ``a`` is the starting point and ``b`` is the trip count of iterations. + Here ``loop_type(a, b)`` indicates ``[a, b)``. + +#. Variables + +All the mutatable variables will be lowered to an array with size 1. +It regards the first store of a variable as its declaration. + +**NOTE**: Unlike conventional Python, in hybrid script, the declared variable + can only be used in the scope level it is declared. + +**NOTE**: Currently, you can ONLY use basic-typed variables, i.e. 
the type of the + variable should be either ``float32``, or ``int32``. + +.. code-block:: python + for i in range(5): + s = 0 # declaration + for j in range(5): + s += a[i, j] # do something with sum + b[i] = sum # you can still use sum in this level + a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python + b = (1, 2) # this has NOT been supported yet! + +#. Attributes + +So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a +tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported. +.. code-block:: python + x = a.shape[2] # OK! + for i in range(3): + for j in a.shape[i]: # BAD! i is not a constant! + # do something + + +#. Conditional Statement and Expression + + +.. code-block:: python + if condition: + # do something + a = b if condition else c + +However, NO ``True`` and ``False`` keyword supported yet. + +#. Math intrinsics + +So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, +``tanh``, ``power``, and ``popcount``, are supported. +No import is required, just as it is mentioned in 1., just use it! + +#. Array allocation +**Under construction, this function will be supported later!** +Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer. +The basic usage is roughly the same as a normal array. + + +#. Thread bind +You can also do loop-thread bind by writing code like this: +.. code-block:: python + for tx in bind("threadIdx.x", 100): + a[tx] = b[tx] + +#. 
Keywords + - Statement keywords: ``for``, ``in``, ``if``, ``else`` + - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` + - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount`` From 504968357e053e335e9e7a20be1943867c0fab71 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 13:50:53 -0700 Subject: [PATCH 11/31] preview code block --- docs/langref/hybrid_script.rst | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 11504db80e4c..1aa0a18a8214 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -10,13 +10,11 @@ been supported by TVM officially. Features ======== -#. Software emulation - -Both software emulation and compilation are supported. -To define a function, you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid function: - +1. Software emulation: Both software emulation and compilation are supported. To define a function, +you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid function: .. code-block:: python + @tvm.hybrid.script def outer_product(a, b, c): for i in range(a.shape[0]): From 382c751e865c4f77daf04330bdb70e80c2cf4fe9 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 13:53:07 -0700 Subject: [PATCH 12/31] dont be too stingy to put blank lines! --- docs/langref/hybrid_script.rst | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 1aa0a18a8214..a00eb9ac39d4 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -10,7 +10,10 @@ been supported by TVM officially. Features ======== -1. Software emulation: Both software emulation and compilation are supported. 
To define a function, +Software emulation +================== + +Both software emulation and compilation are supported. To define a function, you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid function: .. code-block:: python @@ -32,18 +35,23 @@ worry about keyword conflict and pollution. Every element passed for software emulation in the argument list is either a python variable or ``numpy`` numeric type. -#. Backend Compilation +2. Backend Compilation The current parse interface looks like: + .. code-block:: python + a = tvm.placeholder((100, ), name='a') b = tvm.placeholder((99, ), name='b') c = tvm.placeholder((100, 99), name='c') tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function If we pass these tvm tensors to this function, it returns a op node: + **Under construction, we are still deciding what kind of node should be returned.** + .. code-block:: python + a = tvm.placeholder((100, ), name='a') b = tvm.placeholder((99, ), name='b') c = tvm.placeholder((100, 99), name='c') @@ -52,8 +60,11 @@ If we pass these tvm tensors to this function, it returns a op node: #. Tuning **Under construction, not truly supported yet.** + Follow up the example above, you can use some tvm like interfaces to manipulate the structure of IR: + .. code-block:: python + sch = tvm.create_schedule(op) jo, ji = sch.split(j, 4) sch.vectorize(ji) @@ -87,6 +98,7 @@ It regards the first store of a variable as its declaration. variable should be either ``float32``, or ``int32``. .. code-block:: python + for i in range(5): s = 0 # declaration for j in range(5): @@ -99,7 +111,9 @@ It regards the first store of a variable as its declaration. So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported. + .. code-block:: python + x = a.shape[2] # OK! for i in range(3): for j in a.shape[i]: # BAD! 
i is not a constant! @@ -110,6 +124,7 @@ tuple, so you MUST access it as an array. Also, currently, only constant-indexed .. code-block:: python + if condition: # do something a = b if condition else c @@ -123,14 +138,18 @@ So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, No import is required, just as it is mentioned in 1., just use it! #. Array allocation + **Under construction, this function will be supported later!** + Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer. The basic usage is roughly the same as a normal array. #. Thread bind You can also do loop-thread bind by writing code like this: + .. code-block:: python + for tx in bind("threadIdx.x", 100): a[tx] = b[tx] From bfadeb3a02f282c8e81c28118df5d47b232cffd9 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 13:54:51 -0700 Subject: [PATCH 13/31] preview subsubsections --- docs/langref/hybrid_script.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index a00eb9ac39d4..3bf6f9e87651 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -11,7 +11,7 @@ Features ======== Software emulation -================== +^^^^^^^^^^^^^^^^^^ Both software emulation and compilation are supported. To define a function, you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid function: From 7d6bb8d4c94b8419427fcc6611f8e02233264beb Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 13:59:21 -0700 Subject: [PATCH 14/31] keywords label? 
--- docs/langref/hybrid_script.rst | 39 ++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 3bf6f9e87651..4417a26cece6 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -28,14 +28,15 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun c = numpy.zeros((100, 99)) outer_product(a, b, c) -This decorator will import [key words](#keywords) required spontaneously when software emulation. +This decorator will import :ref:`keywords-label` required spontaneously when software emulation. After software emulation is done, the imported keywords will be cleaned up. Users do not need worry about keyword conflict and pollution. Every element passed for software emulation in the argument list is either a python variable or ``numpy`` numeric type. -2. Backend Compilation +Backend Compilation +^^^^^^^^^^^^^^^^^^^ The current parse interface looks like: @@ -57,11 +58,12 @@ If we pass these tvm tensors to this function, it returns a op node: c = tvm.placeholder((100, 99), name='c') op = outer_product(a, b, c) # return the corresponding op node -#. Tuning +Tuning +^^^^^^ **Under construction, not truly supported yet.** -Follow up the example above, you can use some tvm like interfaces to manipulate the structure of IR: +Follow up the example above, you can use some tvm like interfaces to tune the code: .. code-block:: python @@ -71,7 +73,8 @@ Follow up the example above, you can use some tvm like interfaces to manipulate ``split``, ``reorder``, and loop_annotation will be supported! -#. Loops +Loops +^^^^^ In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. @@ -86,7 +89,8 @@ the same as Python standard ``range``. ``a`` is the starting point and ``b`` is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. -#. 
Variables +Variables +^^^^^^^^^ All the mutatable variables will be lowered to an array with size 1. It regards the first store of a variable as its declaration. @@ -100,14 +104,15 @@ It regards the first store of a variable as its declaration. .. code-block:: python for i in range(5): - s = 0 # declaration + s = 0 # declaration, this s will be a 1-array in lowered IR for j in range(5): s += a[i, j] # do something with sum b[i] = sum # you can still use sum in this level a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python b = (1, 2) # this has NOT been supported yet! -#. Attributes +Attributes +^^^^^^^^^^ So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported. @@ -120,8 +125,8 @@ tuple, so you MUST access it as an array. Also, currently, only constant-indexed # do something -#. Conditional Statement and Expression - +Conditional Statement and Expression +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. code-block:: python @@ -131,13 +136,16 @@ tuple, so you MUST access it as an array. Also, currently, only constant-indexed However, NO ``True`` and ``False`` keyword supported yet. -#. Math intrinsics + +Math Intrinsics +^^^^^^^^^^^^^^^ So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported. No import is required, just as it is mentioned in 1., just use it! -#. Array allocation +Array Allocation +^^^^^^^^^^^^^^^^ **Under construction, this function will be supported later!** @@ -145,7 +153,8 @@ Use a function call ``allocation(shape, type, share/local)`` to declare an array The basic usage is roughly the same as a normal array. -#. Thread bind +Thread Bind +^^^^^^^^^^^ You can also do loop-thread bind by writing code like this: .. 
code-block:: python @@ -153,7 +162,9 @@ You can also do loop-thread bind by writing code like this: for tx in bind("threadIdx.x", 100): a[tx] = b[tx] -#. Keywords +.. _keywords-label: +Keywords +-------- - Statement keywords: ``for``, ``in``, ``if``, ``else`` - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount`` From d2a01ff023ad1289ce24dcbcc6b79b4afab0e40c Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 14:01:03 -0700 Subject: [PATCH 15/31] keywords label?? --- docs/langref/hybrid_script.rst | 4 +++- docs/langref/index.rst | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 4417a26cece6..c3f81b805a94 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -28,7 +28,7 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun c = numpy.zeros((100, 99)) outer_product(a, b, c) -This decorator will import :ref:`keywords-label` required spontaneously when software emulation. +This decorator will import :ref:`keywords ` required spontaneously when software emulation. After software emulation is done, the imported keywords will be cleaned up. Users do not need worry about keyword conflict and pollution. @@ -162,7 +162,9 @@ You can also do loop-thread bind by writing code like this: for tx in bind("threadIdx.x", 100): a[tx] = b[tx] + .. _keywords-label: + Keywords -------- - Statement keywords: ``for``, ``in``, ``if``, ``else`` diff --git a/docs/langref/index.rst b/docs/langref/index.rst index dc51c3172c57..65f78d1d278b 100644 --- a/docs/langref/index.rst +++ b/docs/langref/index.rst @@ -2,3 +2,8 @@ Language Reference ================== This document provide references to embedded languages in TVM stack. + +.. 
toctree:: + :maxdepth: 2 + + hybrid_script From 7b76c87632ed65e2728cd0ca8527684a82bad872 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 14:03:07 -0700 Subject: [PATCH 16/31] anchor ref! --- docs/langref/hybrid_script.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index c3f81b805a94..51e13f1acbc7 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -28,7 +28,7 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun c = numpy.zeros((100, 99)) outer_product(a, b, c) -This decorator will import :ref:`keywords ` required spontaneously when software emulation. +This decorator will import `Keywords`_ required spontaneously when software emulation. After software emulation is done, the imported keywords will be cleaned up. Users do not need worry about keyword conflict and pollution. @@ -167,6 +167,5 @@ You can also do loop-thread bind by writing code like this: Keywords -------- - - Statement keywords: ``for``, ``in``, ``if``, ``else`` - - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` - - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount`` +- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` +- Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount`` From f72eb4287bfd186c89afc8c87f2ac25a06955c4a Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 14:04:32 -0700 Subject: [PATCH 17/31] no label required? 
--- docs/langref/hybrid_script.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 51e13f1acbc7..b500bca7571f 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -142,7 +142,7 @@ Math Intrinsics So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported. -No import is required, just as it is mentioned in 1., just use it! +No import is required, just as it is mentioned in `Software Emulation`_, just use it! Array Allocation ^^^^^^^^^^^^^^^^ From 7bb75a92a9f2fc0856f0c2a3d99104b11deac87b Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Mon, 18 Jun 2018 14:05:02 -0700 Subject: [PATCH 18/31] no anchor! --- docs/langref/hybrid_script.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index b500bca7571f..3b8617a3615a 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -163,8 +163,6 @@ You can also do loop-thread bind by writing code like this: a[tx] = b[tx] -.. _keywords-label: - Keywords -------- - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` From a8b85286f7d4fa1a31d853d77c9d6e52f7cc082b Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Tue, 19 Jun 2018 11:33:30 -0700 Subject: [PATCH 19/31] inline ZERO to fit online tests? 
--- python/tvm/hybrid/parser.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index bbfcba8cee9d..74170f73e32c 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -279,14 +279,13 @@ def visit_Call(self, node): func_id = node.func.id n = len(node.args) if func_id in LOOP_INTRIN.keys() and func_id != 'bind': - ZERO = _api.const(0, dtype='int32') if n == 1: - low, ext = ZERO, self.visit(node.args[0]) + low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0]) else: if n != 2: raise ValueError("A loop intrinsic should only have 1 or 2 arguments!") low, ext = self.visit(node.args[0]), self.visit(node.args[1]) - if not _ir_pass.Equal(low, ZERO): + if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): ext = ext - low for_type = LOOP_INTRIN[func_id] iter_var = None @@ -298,7 +297,7 @@ def visit_Call(self, node): raise ValueError("A loop bind's first argument should be a string!") _vn = node.args[0].s iter_var = thread_axis(node.args[0].s) - low, ext = ZERO, self.visit(node.args[1]) + low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1]) for_type = None return iter_var, low, ext, for_type elif func_id in MATH_INTRIN: From 09dd99481b643c497f644f3065f7ac11920880dc Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Tue, 19 Jun 2018 11:53:45 -0700 Subject: [PATCH 20/31] hybrid script dev guide add! 
--- docs/dev/hybrid_script.md | 110 ---------------------------------- docs/dev/hybrid_script.rst | 62 +++++++++++++++++++ docs/dev/index.rst | 1 + python/tvm/hybrid/var_decl.py | 1 + 4 files changed, 64 insertions(+), 110 deletions(-) delete mode 100644 docs/dev/hybrid_script.md create mode 100644 docs/dev/hybrid_script.rst diff --git a/docs/dev/hybrid_script.md b/docs/dev/hybrid_script.md deleted file mode 100644 index d50b54f02fd1..000000000000 --- a/docs/dev/hybrid_script.md +++ /dev/null @@ -1,110 +0,0 @@ -# Hybrid Frontend Developer Guide - -This hybrid frontend is not only aimed at writing preliminary versions of some idioms that yet have -been supported for users. Developers can also use this feature to build IR rapidly. - -## Features - -### Software emulation - -Both software emulation and compilation are supported. - -To define a function, you need to use `tvm.hybrid.script` decorator to indicate this is a hybrid function: -````Python -@tvm.hybrid.script -def outer_product(a, b, c): - for i in range(a.shape[0]): - for j in range(b.shape[0]): - c[i, j] = a[i] * b[j] -a = numpy.random.rand(100) -b = numpy.random.rand(99) -c = numpy.zeros((100, 99)) -outer_product(a, b, c) -```` -This decorator will import [key words](#keywords) required spontaneously when software emulation. -Every element in the argument list is either a python variable or `numpy` tensor. 
- -### Backend Compilation - -The current parse interface looks like: -````Python -a = tvm.placeholder((100, ), name='a') -b = tvm.placeholder((99, ), name='b') -c = tvm.placeholder((100, 99), name='c') -tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function -```` -If we pass these tvm tensors to this function, it returns a op node: -````Python -a = tvm.placeholder((100, ), name='a') -b = tvm.placeholder((99, ), name='b') -c = tvm.placeholder((100, 99), name='c') -op = outer_product(a, b, c) # return the corresponding op node -```` -**This function is still under construction** - -#### Scheduling - -**Under construction, not truly supported yet.** - -Follow up the example above, you can use some tvm like interfaces to manipulate the structure of IR: -````Python -sch = tvm.create_schedule(op) -jo, ji = sch.split(j, 4) -sch.vectorize(ji) -```` -`split`, `reorder`, and loop_annotation will be supported! - -### Attributes -So far, ONLY tensors' `shape` attribute is supported! - -### Loops - -In HalideIR, loops have in total 4 types: `serial`, `unrolled`, `parallel`, and `vectorized`. - -Here we use `range`, `serial`, `unroll`, `parallel`, and `vectorize`, these **5** keywords to annotate the types of for loops. - -**NOTE**: In HalideIR those are enums, they are in passive form. Here we use active form to annotate loops, because they are ready to run. - -**NOTE**: Unlike what that is in HalideIR, in `loop_type(a, b)`, `a` is the starting point and `b` is the trip count of iterations. Here `loop_type(a, b)` indicates `[a, b)`. - -### Variables - -Because there is no variables in `HalideIR`, all the mutatable variables will be lowered to an array with size 1. -It takes the first store of a variable as its declaration. -**NOTE**: Unlike conventional Python, the declared array can only be used in the scope level it is declared. 
-````Python -for i in range(5): - s = 0 - for j in range(5): - s += a[i, j] #do something with sum - b[i] = sum #you can still use sum in this level -#you can NEVER use some here, even though it is allowed in conventional Python -a[0] = s -```` -### Conditional Statement and Expression - -````Python -if condition: - # do something -a = b if condition else c -```` -However, NO `True` and `False` keyword supported yet. - -### Math intrinsics -So far, these math intrinsics, `log`, `exp`, `sigmoid`, `tanh`, `power`, and `popcount`, are supported. No import is required, just use it! -### Array allocation -Use a function call `allocation(shape, type, share/local)` to declare an array buffer. The basic usage is roughly the same as variables. - -**This function is still under construction.** -### Thread bind -You can also do loop-thread bind by writing code like this: -````Python -for tx in bind("threadIdx.x", 100): - a[tx] = b[tx] -```` -## Appendix - -### Keywords -- Statement keywords: `for`, `in`, `if`, `else` -- For keywords: `serial`, `range`, `unroll`, `parallel`, `vectorize`, `bind` -- Math keywords: `log`, `exp`, `sigmoid`, `tanh`, `power`, `popcount` diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst new file mode 100644 index 000000000000..d78f70834ee8 --- /dev/null +++ b/docs/dev/hybrid_script.rst @@ -0,0 +1,62 @@ +Hybrid Frontend Developer Guide +------------------------------- + +If you are a developer: +1. who is trying writing some preliminary patterns that have not been supported by TVM yet, + maybe ``lang_ref/hybrid_script.rst`` is a better place for you. +2. who wants to know the implementing details of this module, you are right here! + +Features +-------- + +Software emulation +^^^^^^^^^^^^^^^^^^ + +In software emulation, the most intresting thing is the decorator ``tvm.hybrid.script``. +This decorator helps 2 things: +1. Importing runtime variables +2. 
Overload the function according to the arguments passed + +Correct me if I am wrong: I believe that how 1. is implemented is dangerous, but I have no +choice. What I did is add those names into python dict ``func.__globals__`` and after +the call to ``func`` is done, those names will be cleaned up. + +Overload is simple: the decorator checks the arguments' types and determines which function +should be actually called. + + +Backend Compilation +^^^^^^^^^^^^^^^^^^^ + +Compilation is a large module, you can see ``python/tvm/hybrid/var_decl.py`` and +``python/tvm/hybrid/parser.py`` for more details. The first stage determines the +usage, or more accurately the declaration of each variable and the second stage does +the actual IR generation. + +Attributes +^^^^^^^^^^ + +So far, ONLY tensors' `shape` attribute is supported. You can see ``visit_Subscript`` +in ``python/tvm/hybrid/parser.py`` for more details. This is a hacky solution, I just +check the attributes when subscript. + +Loops +^^^^^ + +In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. + +**NOTE**: Unlike what that is in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b`` + is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it +to HalideIR, we need to do ``start, extent = a, b - a`` + +Variables +^^^^^^^^^ + +Because there is no variables in ``HalideIR``, all the mutatable variables will be lowered to an array with size 1. +It takes the first store of a variable as its declaration. + + +### Math intrinsics +So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported. +Math intrinsics will be imported by the decorator. Most of the intrinsics are borrowed by library implementation +except ``popcount`` and ``sigmoid``. I implemented them manually.
diff --git a/docs/dev/index.rst b/docs/dev/index.rst index 3fb052938689..f3ab322bfe53 100644 --- a/docs/dev/index.rst +++ b/docs/dev/index.rst @@ -10,3 +10,4 @@ In this part of documentation, we share the rationale for the specific choices m runtime nnvm_json_spec nnvm_overview + hybrid_script diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py index fa7daa5d4b45..d4717544d652 100644 --- a/python/tvm/hybrid/var_decl.py +++ b/python/tvm/hybrid/var_decl.py @@ -1,4 +1,5 @@ """Determines the declaration, r/w status, and last use of each variable""" + import ast import sys from ._intrin import HYBRID_GLOBALS From 9c3c545ed340c9fef8a3c5225f50b6550f8bb8ae Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Tue, 19 Jun 2018 11:55:27 -0700 Subject: [PATCH 21/31] I think I need a rst IDE --- docs/dev/hybrid_script.rst | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index d78f70834ee8..f17e77c0250f 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -2,8 +2,10 @@ Hybrid Frontend Developer Guide ------------------------------- If you are a developer: + 1. who is trying writing some preliminary patterns that have not been supported by TVM yet, - maybe ``lang_ref/hybrid_script.rst`` is a better place for you. +maybe ``lang_ref/hybrid_script.rst`` is a better place for you. + 2. who wants to know the implementing details of this module, you are right here! Features @@ -14,7 +16,9 @@ Software emulation In software emulation, the most intresting thing is the decorator ``tvm.hybrid.script``. This decorator helps 2 things: + 1. Importing runtime variables + 2. Overload the function according to the arguments passed Correct me if I am wrong: I believe that how 1. is implemented is dangerous, but I have no @@ -46,7 +50,7 @@ Loops In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. 
**NOTE**: Unlike what that is in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b`` - is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it +is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it to HalideIR, we need to do ``start, extent = a, b - a`` Variables @@ -55,8 +59,8 @@ Variables Because there is no variables in ``HalideIR``, all the mutatable variables will be lowered to an array with size 1. It takes the first store of a variable as its declaration. - -### Math intrinsics +Math intrinsics +^^^^^^^^^^^^^^^ So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported. Math intrinsics will be imported by the decorator. Most of the intrinsics are borrowed by library implementation except ``popcount`` and ``sigmoid``. I implemented them manually. From 742ffd33f8b4afc1254cbeaab62592a2f355005f Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 12:15:53 -0700 Subject: [PATCH 22/31] update the docs --- docs/dev/hybrid_script.rst | 28 ++++++++++++++--------- docs/langref/hybrid_script.rst | 42 +++++++++++++++------------------- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index f17e77c0250f..7324224194a7 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -1,10 +1,10 @@ Hybrid Frontend Developer Guide -------------------------------- +=============================== If you are a developer: 1. who is trying writing some preliminary patterns that have not been supported by TVM yet, -maybe ``lang_ref/hybrid_script.rst`` is a better place for you. +maybe `../lang_ref/hybrid_script.rst`_ is a better place for you. 2. who wants to know the implementing details of this module, you are right here! 
@@ -12,7 +12,7 @@ Features -------- Software emulation -^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~ In software emulation, the most intresting thing is the decorator ``tvm.hybrid.script``. This decorator helps 2 things: @@ -30,7 +30,7 @@ should be actually called. Backend Compilation -^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~ Compilation is a large module, you can see ``python/tvm/hybrid/var_decl.py`` and ``python/tvm/hybrid/parser.py`` for more details. The first stage determines the @@ -38,29 +38,35 @@ usage, or more accurately the declaration of each variable and the second stage the actual IR generation. Attributes -^^^^^^^^^^ +~~~~~~~~~~ So far, ONLY tensors' `shape` attribute is supported. You can see ``visit_Subscript`` in ``python/tvm/hybrid/parser.py`` for more details. This is a hacky solution, I just check the attributes when subscript. Loops -^^^^^ +~~~~~ In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. -**NOTE**: Unlike what that is in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b`` -is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it -to HalideIR, we need to do ``start, extent = a, b - a`` + + Unlike what that is in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b`` + is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it + to HalideIR, we need to do ``start, extent = a, b - a`` + + + In HalideIR those are enums, they are in passive form. + Here we use active form to annotate loops, because they are ready to run. + Variables -^^^^^^^^^ +~~~~~~~~~ Because there is no variables in ``HalideIR``, all the mutatable variables will be lowered to an array with size 1. It takes the first store of a variable as its declaration. 
Math intrinsics -^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~ So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported. Math intrinsics will be imported by the decorator. Most of the intrinsics are borrowed by library implementation except ``popcount`` and ``sigmoid``. I implemented them manually. diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 3b8617a3615a..2488d8734616 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -10,8 +10,8 @@ been supported by TVM officially. Features ======== -Software emulation -^^^^^^^^^^^^^^^^^^ +Software Emulation +~~~~~~~~~~~~~~~~~~ Both software emulation and compilation are supported. To define a function, you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid function: @@ -28,7 +28,7 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun c = numpy.zeros((100, 99)) outer_product(a, b, c) -This decorator will import `Keywords`_ required spontaneously when software emulation. +This decorator will import `keywords `_ required spontaneously when software emulation. After software emulation is done, the imported keywords will be cleaned up. Users do not need worry about keyword conflict and pollution. @@ -36,7 +36,7 @@ Every element passed for software emulation in the argument list is either a pyt or ``numpy`` numeric type. Backend Compilation -^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~ The current parse interface looks like: @@ -59,7 +59,7 @@ If we pass these tvm tensors to this function, it returns a op node: op = outer_product(a, b, c) # return the corresponding op node Tuning -^^^^^^ +~~~~~~ **Under construction, not truly supported yet.** @@ -74,23 +74,16 @@ Follow up the example above, you can use some tvm like interfaces to tune the co ``split``, ``reorder``, and loop_annotation will be supported! 
Loops -^^^^^ +~~~~~ In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. -Here we use ``range``, ``serial``, ``unroll``, ``parallel``, and ``vectorize``, -these **5** keywords to annotate the types of for loops. The the usage is roughly -the same as Python standard ``range``. - -**NOTE**: In HalideIR those are enums, they are in passive form. - Here we use active form to annotate loops, because they are ready to run. - -**NOTE**: Unlike what that is in HalideIR, in ``loop_type(a, b)``, - ``a`` is the starting point and ``b`` is the trip count of iterations. - Here ``loop_type(a, b)`` indicates ``[a, b)``. +Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize``, +these **4** keywords to annotate the corresponding types of for loops. +The usage is roughly the same as Python standard ``range``. Variables -^^^^^^^^^ +~~~~~~~~~ All the mutatable variables will be lowered to an array with size 1. It regards the first store of a variable as its declaration. @@ -111,8 +104,9 @@ It regards the first store of a variable as its declaration. a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python b = (1, 2) # this has NOT been supported yet! + Attributes -^^^^^^^^^^ +~~~~~~~~~~ So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported. @@ -126,7 +120,7 @@ tuple, so you MUST access it as an array. Also, currently, only constant-indexed Conditional Statement and Expression -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code-block:: python @@ -138,14 +132,14 @@ However, NO ``True`` and ``False`` keyword supported yet. Math Intrinsics -^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~ So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported.
No import is required, just as it is mentioned in `Software Emulation`_, just use it! Array Allocation -^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~ **Under construction, this function will be supported later!** @@ -154,7 +148,9 @@ The basic usage is roughly the same as a normal array. Thread Bind -^^^^^^^^^^^ +~~~~~~~~~~~ + + You can also do loop-thread bind by writing code like this: .. code-block:: python @@ -164,6 +160,6 @@ You can also do loop-thread bind by writing code like this: Keywords --------- +~~~~~~~~ - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount`` From f40cf575a45fa5dad66329398a50a15c73af06d5 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 12:18:52 -0700 Subject: [PATCH 23/31] note! --- docs/dev/hybrid_script.rst | 4 ++++ docs/langref/hybrid_script.rst | 13 +++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index 7324224194a7..5eaff04ea405 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -50,11 +50,15 @@ Loops In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``. +.. note:: + Unlike what that is in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b`` is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it to HalideIR, we need to do ``start, extent = a, b - a`` +.. note:: + In HalideIR those are enums, they are in passive form. Here we use active form to annotate loops, because they are ready to run. diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 2488d8734616..8e7ba0bdde6e 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -88,11 +88,16 @@ Variables All the mutatable variables will be lowered to an array with size 1. 
It regards the first store of a variable as its declaration. -**NOTE**: Unlike conventional Python, in hybrid script, the declared variable - can only be used in the scope level it is declared. +.. note:: -**NOTE**: Currently, you can ONLY use basic-typed variables, i.e. the type of the - variable should be either ``float32``, or ``int32``. + Unlike conventional Python, in hybrid script, the declared variable + can only be used in the scope level it is declared. + + +.. note:: + + Currently, you can ONLY use basic-typed variables, i.e. the type of the + variable should be either ``float32``, or ``int32``. .. code-block:: python From 799f1d34f2b8a1589e1d51a4a6ee8bdd727764ce Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 12:22:41 -0700 Subject: [PATCH 24/31] link? --- docs/dev/hybrid_script.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index 5eaff04ea405..28aec8bf4a0b 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -4,7 +4,7 @@ Hybrid Frontend Developer Guide If you are a developer: 1. who is trying writing some preliminary patterns that have not been supported by TVM yet, -maybe `../lang_ref/hybrid_script.rst`_ is a better place for you. +maybe `language ref <../lang_ref/hybrid_script>`_ is a better place for you. 2. who wants to know the implementing details of this module, you are right here! From 91969477b6c40c15ced2fa7dd5e707f6c714346a Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 12:24:09 -0700 Subject: [PATCH 25/31] link! --- docs/dev/hybrid_script.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index 28aec8bf4a0b..b8514e8cbb9b 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -4,7 +4,7 @@ Hybrid Frontend Developer Guide If you are a developer: 1. 
who is trying writing some preliminary patterns that have not been supported by TVM yet, -maybe `language ref <../lang_ref/hybrid_script>`_ is a better place for you. +maybe `language ref <../lang_ref/hybrid_script.rst>`_ is a better place for you. 2. who wants to know the implementing details of this module, you are right here! From 42d5615662c677b0b807e411ddfdcad04bc592fe Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 12:25:03 -0700 Subject: [PATCH 26/31] no underscore! --- docs/dev/hybrid_script.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index b8514e8cbb9b..370a61bcab2b 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -4,7 +4,7 @@ Hybrid Frontend Developer Guide If you are a developer: 1. who is trying writing some preliminary patterns that have not been supported by TVM yet, -maybe `language ref <../lang_ref/hybrid_script.rst>`_ is a better place for you. +maybe `language ref <../langref/hybrid_script.rst>`_ is a better place for you. 2. who wants to know the implementing details of this module, you are right here! From d66818802617dd7a4d91be47c2a44a0939609b5b Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 12:27:01 -0700 Subject: [PATCH 27/31] inner anchor --- docs/langref/hybrid_script.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 8e7ba0bdde6e..2d3d1d2be48a 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -28,7 +28,7 @@ you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid fun c = numpy.zeros((100, 99)) outer_product(a, b, c) -This decorator will import `keywords `_ required spontaneously when software emulation. +This decorator will import `Keywords`_ required spontaneously when software emulation. 
After software emulation is done, the imported keywords will be cleaned up. Users do not need worry about keyword conflict and pollution. From 41cc88c4a598185330505cfb365e2536df4e542d Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 14:10:08 -0700 Subject: [PATCH 28/31] follow up tianqis suggestion --- docs/api/python/hybrid.rst | 15 +++++++ docs/api/python/index.rst | 1 + docs/langref/hybrid_script.rst | 6 +-- python/tvm/hybrid/_intrin.py | 47 ++++++++++++++++++--- python/tvm/hybrid/_util.py | 3 ++ python/tvm/hybrid/api.py | 3 +- python/tvm/hybrid/parser.py | 2 +- tests/python/unittest/test_hybrid_script.py | 10 ++--- 8 files changed, 69 insertions(+), 18 deletions(-) create mode 100644 docs/api/python/hybrid.rst diff --git a/docs/api/python/hybrid.rst b/docs/api/python/hybrid.rst new file mode 100644 index 000000000000..3b4c598d82dd --- /dev/null +++ b/docs/api/python/hybrid.rst @@ -0,0 +1,15 @@ +tvm.hybrid +---------- +.. automodule:: tvm.hybrid + +.. autosummary:: + + tvm.hybrid.parse + tvm.hybrid.script + tvm.hybrid.popcount + tvm.hybrid.sigmoid + +.. autofunction:: tvm.hybrid.parse +.. autofunction:: tvm.hybrid.script +.. autofunction:: tvm.hybrid.popcount +.. autofunction:: tvm.hybrid.sigmoid diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index a6bed557dd3b..bab29b82f473 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -21,3 +21,4 @@ Python API dev topi nnvm/index + hybrid diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index 2d3d1d2be48a..fc2e3c23e45b 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -1,14 +1,14 @@ Hybrid Frontend Language Reference ----------------------------------- +================================== Overview -======== +-------- This hybrid frontend allows users to write preliminary versions of some idioms that yet have been supported by TVM officially. 
Features -======== +-------- Software Emulation ~~~~~~~~~~~~~~~~~~ diff --git a/python/tvm/hybrid/_intrin.py b/python/tvm/hybrid/_intrin.py index c386994e98fa..4fb40a9e3f92 100644 --- a/python/tvm/hybrid/_intrin.py +++ b/python/tvm/hybrid/_intrin.py @@ -29,14 +29,38 @@ def __init__(self, tag, ext): serial = unroll = vectorize = parallel = _range #pylint: disable=invalid-name -def allocate(shape, dtype=None): - """Allocate a buffer with given shape""" - dtype = 'float32' if dtype is None else dtype +def allocate(shape, dtype='float32'): + """Allocate a buffer with given shape + + Parameters + ---------- + shape: Tuple + The shape of the tensor to be allocated + dtype: string + The data type of the tensor + + Returns + ------- + tensor: numpy.array + The tensor allocated + """ return numpy.zeros(shape).astype(dtype) def popcount(x): - """Software emulated popcount function which counts 1's in a number's binary representation.""" + """ + Count ones in the binary representation of number x + + Parameters + ---------- + x: Integer + The number to be counted + + Returns + ------- + cnt: Integer + The number of ones in the binary representation of number x + """ cnt = 0 while x: x -= x & -x @@ -45,12 +69,22 @@ def popcount(x): def sigmoid(x): - """Software emulated sigmoid function, which returns 1/(1+exp(-x)).""" + """ + Sigmoid function of x, aka 1/(1+exp(-x)). 
+ + Parameters + ---------- + x: a real number + + Returns + ------- + res: a real number + The result of sigmoid function + """ return 1 / (1 + numpy.exp(-x)) HYBRID_GLOBALS = { - 'serial' : serial, 'unroll' : unroll, 'vectorize' : vectorize, 'parallel' : parallel, @@ -68,7 +102,6 @@ def sigmoid(x): LOOP_INTRIN = { 'range' : For.Serial, - 'serial' : For.Serial, 'unroll' : For.Unrolled, 'parallel' : For.Parallel, 'vectorize': For.Vectorized, diff --git a/python/tvm/hybrid/_util.py b/python/tvm/hybrid/_util.py index 0e0803827223..3aa315205212 100644 --- a/python/tvm/hybrid/_util.py +++ b/python/tvm/hybrid/_util.py @@ -18,14 +18,17 @@ # Useful constants. In avoid of runtime dependences, we use function calls to return them. def make_nop(): + """Returns a 'no operation' node in HalideIR.""" return _make.Evaluate(_api.const(0, dtype='int32')) def make_range_one(): + """Returns a [0, 1) range node in HalideIR.""" return _make.range_by_min_extent(0, 1) def make_const_true(): + """Returns a constant True node in HalideIR.""" return _api.convert(True) diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py index cc0a324b0785..4d8774bc339e 100644 --- a/python/tvm/hybrid/api.py +++ b/python/tvm/hybrid/api.py @@ -34,9 +34,8 @@ def parse(func, args): Returns ------- - (halide_ir, parser) : (Stmt, PyAST2HalideIR) + root : Stmt The result Halide IR and the parser class instance.
- TODO: Later we deprecate this return value, use a dedicated OP node type instead """ if isinstance(func, str): src = func diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 74170f73e32c..dbfe798c680d 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -1,4 +1,4 @@ -"""Compiling a TVM Hybrid Script Python to HalideIR""" +"""Hybrid Script Parser""" import ast import operator diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index f1b267cd1ece..93667f2752ea 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -4,7 +4,7 @@ @script def outer_product(n, m, a, b, c): - for i in serial(n): + for i in range(n): for j in range(m): c[i, j] = a[i] * b[j] @@ -66,9 +66,9 @@ def test_fanout(): @script def fanout(n, a, b): three = 3.0 - for i in serial(a.shape[0] - 3): + for i in range(a.shape[0] - 3): sigma = 0.0 - for j in serial(3): + for j in range(3): sigma = sigma + a[i + j] sigma = sigma / three b[i] = sigma @@ -133,7 +133,7 @@ def fanout(n, a, b): @script def failure(): - for i in serial(1, 100): + for i in range(1, 100): i = 0 def test_failure(): @@ -167,7 +167,7 @@ def looptype(a): def test_if(): @script def if_then_else(a, b): - for i in serial(10): + for i in range(10): if i % 2 == 0: a[i] = -1 else: From 81b0b325e53cebee13cf688d01439c44b3c5bcd5 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 17:16:24 -0700 Subject: [PATCH 29/31] erase serail --- python/tvm/hybrid/_intrin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/hybrid/_intrin.py b/python/tvm/hybrid/_intrin.py index 4fb40a9e3f92..a748e265f61a 100644 --- a/python/tvm/hybrid/_intrin.py +++ b/python/tvm/hybrid/_intrin.py @@ -26,7 +26,7 @@ def __init__(self, tag, ext): self.tag = tag -serial = unroll = vectorize = parallel = _range #pylint: disable=invalid-name +unroll = vectorize = parallel = _range 
#pylint: disable=invalid-name def allocate(shape, dtype='float32'): From 3e6e39c496538a565cb6d75ac5543e96904e68f4 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 17:20:55 -0700 Subject: [PATCH 30/31] use :ref: --- docs/dev/hybrid_script.rst | 2 +- docs/langref/hybrid_script.rst | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst index 370a61bcab2b..0af02a56e72c 100644 --- a/docs/dev/hybrid_script.rst +++ b/docs/dev/hybrid_script.rst @@ -4,7 +4,7 @@ Hybrid Frontend Developer Guide If you are a developer: 1. who is trying writing some preliminary patterns that have not been supported by TVM yet, -maybe `language ref <../langref/hybrid_script.rst>`_ is a better place for you. +maybe :ref:`hybrid-langref-label` is a better place for you. 2. who wants to know the implementing details of this module, you are right here! diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst index fc2e3c23e45b..fdaed2b5be40 100644 --- a/docs/langref/hybrid_script.rst +++ b/docs/langref/hybrid_script.rst @@ -1,3 +1,5 @@ +.. 
_hybrid-langref-label: + Hybrid Frontend Language Reference ================================== From 45a2d4cc2e74a5ceb02e46c8c681da04e14060ef Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 21 Jun 2018 22:08:02 -0700 Subject: [PATCH 31/31] move _xx to xx; fix lint --- python/tvm/hybrid/api.py | 3 ++- python/tvm/hybrid/{_intrin.py => intrin.py} | 2 +- python/tvm/hybrid/parser.py | 4 ++-- python/tvm/hybrid/{_util.py => util.py} | 2 +- python/tvm/hybrid/var_decl.py | 2 +- tests/python/unittest/test_hybrid_script.py | 2 +- 6 files changed, 8 insertions(+), 7 deletions(-) rename python/tvm/hybrid/{_intrin.py => intrin.py} (99%) rename python/tvm/hybrid/{_util.py => util.py} (98%) diff --git a/python/tvm/hybrid/api.py b/python/tvm/hybrid/api.py index 4d8774bc339e..bc5376509522 100644 --- a/python/tvm/hybrid/api.py +++ b/python/tvm/hybrid/api.py @@ -4,12 +4,12 @@ import types import decorator from .parser import parse_python -from ._util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types, _pruned_source @decorator.decorator def script(func, *args): """If the arguments are tvm types, compile it to HalideIR. O.W. return the python emulated result""" + from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types if _is_tvm_arg_types(args): return parse(func, args) else: @@ -37,6 +37,7 @@ def parse(func, args): root : Stmt The result Halide IR and the parser class instance. """ + from .util import _pruned_source if isinstance(func, str): src = func else: diff --git a/python/tvm/hybrid/_intrin.py b/python/tvm/hybrid/intrin.py similarity index 99% rename from python/tvm/hybrid/_intrin.py rename to python/tvm/hybrid/intrin.py index a748e265f61a..93517fef4d1d 100644 --- a/python/tvm/hybrid/_intrin.py +++ b/python/tvm/hybrid/intrin.py @@ -71,7 +71,7 @@ def popcount(x): def sigmoid(x): """ Sigmoid function of x, aka 1/(1+exp(-x)). 
- + Parameters ---------- x: a real number diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index dbfe798c680d..7d4c40e8c7e9 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -3,8 +3,8 @@ import ast import operator import sys -from ._util import make_nop, make_const_true, make_range_one, halide_imm_types -from ._intrin import LOOP_INTRIN, MATH_INTRIN +from .util import make_nop, make_const_true, make_range_one, halide_imm_types +from .intrin import LOOP_INTRIN, MATH_INTRIN from .var_decl import determine_variable_usage from ..api import thread_axis from .. import expr as _expr diff --git a/python/tvm/hybrid/_util.py b/python/tvm/hybrid/util.py similarity index 98% rename from python/tvm/hybrid/_util.py rename to python/tvm/hybrid/util.py index 3aa315205212..8a5f4a62768d 100644 --- a/python/tvm/hybrid/_util.py +++ b/python/tvm/hybrid/util.py @@ -2,7 +2,7 @@ import inspect import numpy -from ._intrin import HYBRID_GLOBALS +from .intrin import HYBRID_GLOBALS from .._ffi.base import numeric_types from .. import api as _api from .. import make as _make diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py index d4717544d652..940b8c088df3 100644 --- a/python/tvm/hybrid/var_decl.py +++ b/python/tvm/hybrid/var_decl.py @@ -2,7 +2,7 @@ import ast import sys -from ._intrin import HYBRID_GLOBALS +from .intrin import HYBRID_GLOBALS class PyVariableUsage(ast.NodeVisitor): diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 93667f2752ea..fda4f52c1f19 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -1,6 +1,6 @@ import tvm, inspect, sys, traceback, numpy from tvm.hybrid import script -from tvm.hybrid._intrin import HYBRID_GLOBALS +from tvm.hybrid.intrin import HYBRID_GLOBALS @script def outer_product(n, m, a, b, c):