diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 539115d8b6f4..ba10dd8dde3c 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -35,20 +35,21 @@ class HybridParser(ast.NodeVisitor): _binop_maker = { - ast.Add : operator.add, - ast.Sub : operator.sub, - ast.Mult : operator.mul, - ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv, - ast.Mod : operator.mod, - ast.BitOr : operator.or_, - ast.BitAnd: operator.and_, - ast.BitXor: operator.xor, - ast.Gt : operator.gt, - ast.GtE : operator.ge, - ast.Lt : operator.lt, - ast.LtE : operator.le, - ast.Eq : operator.eq, - ast.NotEq : operator.ne, + ast.Add : operator.add, + ast.Sub : operator.sub, + ast.Mult : operator.mul, + ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv, + ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv, + ast.Mod : operator.mod, + ast.BitOr : operator.or_, + ast.BitAnd : operator.and_, + ast.BitXor : operator.xor, + ast.Gt : operator.gt, + ast.GtE : operator.ge, + ast.Lt : operator.lt, + ast.LtE : operator.le, + ast.Eq : operator.eq, + ast.NotEq : operator.ne, ast.And : _all, ast.Or : _any, } @@ -237,7 +238,7 @@ def visit_Subscript(self, node): if isinstance(node.value, ast.Name): array = node.value.id _buf = self._get_buffer_from_id(array) - return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) + return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, _buf.value_index) _internal_assert(isinstance(node.value, ast.Attribute), \ "Only variable and attribute's subscript supported so far") diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index c718fc66899a..7efbbe43ee21 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -1,4 +1,4 @@ -import tvm, inspect, sys, traceback, numpy, nose +import tvm, inspect, sys, traceback, numpy, nose, types from tvm.hybrid import script from tvm.hybrid.intrin import HYBRID_GLOBALS @@ -11,6 +11,10 @@ def tvm_val_2_py_val(val): return val.value ctx = tvm.context(target, 0) + op = None + + outs = func(*args) + op = outs[0].op if isinstance(outs, list) else outs.op emu_args = [] nd_args = [] @@ -24,8 +28,6 @@ def tvm_val_2_py_val(val): emu_args.append(tvm_val_2_py_val(i)) nd_args.append(emu_args[-1]) - outs = func(*args) - op = outs[0].op if isinstance(outs, list) else outs.op sch = tvm.create_schedule(op) module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target) assert module @@ -425,10 +427,12 @@ def downstream(a): for i in range(20): b[i] = a[i] * i return b + a = tvm.placeholder((20, ), 'float32') b = downstream(a) c = tvm.compute((20, ), lambda x: b[x] + 1.0) + sch = tvm.create_schedule(c.op) module = tvm.build(sch, [a, c]) assert module @@ -469,6 +473,40 @@ def add_something(a, b): tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5) +def test_value_index(): + @tvm.hybrid.script + def kernel_a(a): + b = output_tensor((16, ), 'int32') + c = output_tensor((4, 4), 'int32') + for i in range(16): + b[i] = a[i] + 2 + c[i // 4, i % 4] = a[i] + 1 + return b, c + + @tvm.hybrid.script + def kernel_b(b, a): + c = output_tensor((4, 4), 'int32') + for i in range(4): + for j in range(4): + c[i, j] = a[i * 4 + j] * b[i, j] + return c + + a = tvm.placeholder((16, ), 'int32') + b, c = kernel_a(a) + d = kernel_b(c, b) + sch = tvm.create_schedule(d.op) + module = tvm.build(sch, [a, d]) + assert module + + np_a = numpy.arange(16).astype('int32') + np_b, np_c = kernel_a(np_a) + ref = kernel_b(np_c, np_b) + + res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32')) + module(tvm.ndarray.array(np_a), res) + tvm.testing.assert_allclose(res.asnumpy(), ref) + + if __name__ == "__main__": test_outer_product() @@ -479,9 +517,11 @@ def add_something(a, b): test_math_intrin() test_non_zero() test_allocate() - #test_inplace() test_upstream() test_downstream() test_const_param() + test_value_index() + # TODO: + # test_inplace()