diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py
index 93517fef4d1d..b3fb64579b60 100644
--- a/python/tvm/hybrid/intrin.py
+++ b/python/tvm/hybrid/intrin.py
@@ -29,7 +29,7 @@ def __init__(self, tag, ext):
 unroll = vectorize = parallel = _range #pylint: disable=invalid-name
 
 
-def allocate(shape, dtype='float32'):
+def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
     """Allocate a buffer with given shape
 
     Parameters
@@ -38,6 +38,8 @@ def allocate(shape, dtype='float32'):
         The shape of the tensor to be allocated
     dtype: string
         The data type of the tensor
+    scope: string
+        The storage scope of the tensor
 
     Returns
     -------
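Reviewer note: with this change `allocate` takes an optional storage scope. In the pure-Python emulation above the argument is accepted but unused (hence the pylint pragma); the parser below turns it into a 'realize_scope' attribute. A minimal usage sketch, with a hypothetical kernel `scale` that is not part of the patch (only the three-argument call is what the diff defines):

    @tvm.hybrid.script
    def scale(a, b):
        for i in range(16):
            tmp = allocate((1, ), 'float32', 'global')  # shape, dtype, scope
            tmp[0] = a[i] * 2.0
            b[i] = tmp[0]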
diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py
index b3f076094a12..1e532367a321 100644
--- a/python/tvm/hybrid/parser.py
+++ b/python/tvm/hybrid/parser.py
@@ -3,7 +3,7 @@
 import ast
 import operator
 import sys
-from .util import make_nop, make_const_true, make_range_one, halide_imm_types
+from .util import make_nop, halide_imm_types
 from .intrin import LOOP_INTRIN, MATH_INTRIN
 from .var_decl import determine_variable_usage
 from ..api import thread_axis
@@ -75,7 +75,8 @@ def __init__(self, args, usage, func_name=None):
         self.args = args[:]
         self.usage = usage.copy()
         self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
-        self.buffers = {}
+        self.var_buffers = {} # Buffers formed by mutable variables
+        self.alloc_buffers = {} # Buffers formed by allocate instructions
         self.loops_above = {} # State variable that indicates loop levels above the current node
         self.var_consts = {} # Variables that are determined as readonly in previous stage
         self.func_name = func_name # The name of the function to be lowered
@@ -87,19 +88,30 @@ def wrap_up_realize(self, node, body):
         for key, val in self.usage.items():
             if key in self.var_consts.keys():
                 continue
-            _, scope, _ = val
-            if scope == node:
-                _buf = self.buffers[key]
+            _, level, _ = val
+            if level == node:
+                if key in self.var_buffers.keys():
+                    _buf = self.var_buffers[key]
+                    _scope = 'global'
+                else:
+                    _buf, _scope = self.alloc_buffers[key]
+                _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
                 _dtype = _buf.dtype
-                _one = make_range_one()
-                _true = make_const_true()
-                body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body)
+                _true = _api.convert(True)
+                body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
+                body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
         return body
 
-    def _check_id_a_buffer(self, s):
-        if s not in self._args.keys():
+    def _get_buffer_from_id(self, s):
+        if s not in self._args.keys() and s not in self.alloc_buffers.keys():
             raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s)
+        if s in self._args.keys() and s in self.alloc_buffers.keys():
+            raise ValueError("Buffer %s cannot be both an argument and allocated!" % s)
+        if s in self._args.keys():
+            return self._args[s]
+        return self.alloc_buffers[s][0]
+
 
     #pylint: disable=invalid-name, missing-docstring
@@ -138,8 +150,8 @@ def visit_Name(self, node):
         if _id not in self.usage.keys():
             raise ValueError("This id %s is expected to be a defined variable!" % _id)
         # Buffer
-        if _id in self.buffers.keys():
-            _buf = self.buffers[_id]
+        if _id in self.var_buffers.keys():
+            _buf = self.var_buffers[_id]
             return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
         # Compilation time constant
         if _id not in self.var_consts.keys():
@@ -155,7 +167,9 @@ def visit_Assign(self, node):
         if len(node.targets) != 1:
             raise ValueError("So far only one-valued assignment is supported!")
         lhs = node.targets[0]
-        rhs = _ir_pass.Simplify(self.visit(node.value))
+        rhs = self.visit(node.value)
+        if isinstance(rhs, _expr.Expr):
+            rhs = _ir_pass.Simplify(rhs)
         if isinstance(lhs, ast.Name):
             #TODO: support defined intermediate buffer later
             lhs_ = lhs
@@ -166,25 +180,31 @@ def visit_Assign(self, node):
             if decl == lhs_:
                 if lhs in self.var_consts.keys():
                     raise ValueError("BUG: A constant cannot be overwritten!")
-                if lhs in self.buffers.keys():
+                if lhs in self.var_buffers.keys() or lhs in self.alloc_buffers.keys():
                     raise ValueError("BUG: This value should not be defined before this point!")
+                if isinstance(rhs, tuple):
+                    shape, dtype, scope = rhs
+                    ph = _api.placeholder(shape, dtype=dtype, name=lhs)
+                    self.alloc_buffers[lhs] = (ph, scope)
+                    return make_nop()
                 if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
                     self.var_consts[lhs] = rhs
                 else:
-                    self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
+                    self.var_buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
             if lhs in self.var_consts.keys():
                 return make_nop()
             else:
-                if lhs not in self.buffers.keys():
-                    raise ValueError("BUG: This value should be defined before!")
-                return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
+                if lhs not in self.var_buffers.keys():
+                    raise ValueError("BUG: This variable should be defined before!")
+                tgt = self.var_buffers[lhs]
+                return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
         else:
             lhs = self.visit(lhs)
             if not isinstance(lhs, _expr.Call):
                 raise ValueError("An array access's LHS is expected to be a expr.Call!")
             #TODO: support slice later
-            self._check_id_a_buffer(lhs.name)
-            return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args)
+            buf = self._get_buffer_from_id(lhs.name)
+            return _make.Provide(buf.op, 0, rhs, lhs.args)
 
 
     def visit_Index(self, node):
@@ -197,8 +217,7 @@ def visit_Subscript(self, node):
         args = self.visit(node.slice)
         if isinstance(node.value, ast.Name):
             array = node.value.id
-            self._check_id_a_buffer(array)
-            _buf = self._args[array]
+            _buf = self._get_buffer_from_id(array)
             return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
         elif isinstance(node.value, ast.Attribute):
             if not isinstance(node.value.value, ast.Name):
@@ -211,8 +230,8 @@ def visit_Subscript(self, node):
             #TODO: maybe support non-constant value later?
             if not isinstance(args, (_expr.IntImm, _expr.UIntImm)):
                 raise ValueError("So far only constant shape access supported!")
-            self._check_id_a_buffer(node.value.value.id)
-            return self._args[node.value.value.id].shape[args.value]
+            buf = self._get_buffer_from_id(node.value.value.id)
+            return buf.shape[args.value]
         else:
             raise ValueError("Not supported yet!")
@@ -303,8 +322,30 @@ def visit_Call(self, node):
         elif func_id in MATH_INTRIN:
             return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
         elif func_id == 'allocate':
-            #TODO: Support it later!
-            return make_nop()
+            if not isinstance(node.args[0], ast.Tuple):
+                raise ValueError("allocate's first argument should be a shape tuple!")
+            shape = tuple(self.visit(i) for i in node.args[0].elts)
+            for i in shape:
+                if not isinstance(i, _expr.Expr):
+                    raise ValueError("The shape should be an expression")
+            if n > 1:
+                if not isinstance(node.args[1], ast.Str):
+                    raise ValueError("The data type should be a string")
+                dtype = node.args[1].s
+            else:
+                dtype = 'float32'
+            if n > 2:
+                if not isinstance(node.args[2], ast.Str):
+                    raise ValueError("The storage scope should be a string")
+                scope = node.args[2].s
+            else:
+                scope = 'global'
+            return (shape, dtype, scope)
+        elif func_id == 'max' or func_id == 'min':
+            if n != 2:
+                raise ValueError("Max/Min function should have exactly 2 arguments")
+            a, b = self.visit(node.args[0]), self.visit(node.args[1])
+            return getattr(_make, func_id.title())(a, b)
         else:
             raise ValueError("Function call not supported yet!")
@@ -317,8 +358,10 @@ def visit_For(self, node):
         if iter_var is None:
             if for_type is None:
                 raise ValueError("The loop bind function parse error!")
-            iter_var = _api.var(_name)
-            self.loops_above[_name] = iter_var
+            offset = iter_var = _api.var(_name)
+            if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
+                offset = iter_var + low
+            self.loops_above[_name] = offset
         else:
             if for_type is not None:
                 raise ValueError("The loop iterating function parse error!")
@@ -328,7 +371,7 @@ def visit_For(self, node):
         if for_type is None:
             res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
         else:
-            res = _make.For(iter_var, low, ext, for_type, 0, _body)
+            res = _make.For(iter_var, _api.const(0, dtype='int32'), ext, for_type, 0, _body)
         self.loops_above.pop(_name)
         return res
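Reviewer note: the visit_For change normalizes every loop so the emitted For node starts at zero, and loops_above maps the loop name to iter_var + low, which offsets each use of the variable inside the body. A standalone sketch of that rewrite in plain Python (names are illustrative, no TVM required):

    def normalized_loop(low, extent, body):
        # For(iter_var, 0, extent, ...): the IR loop always starts at zero...
        for iter_var in range(extent):
            # ...and every use of the source loop variable is replaced by
            # (iter_var + low), i.e. loops_above[name] = iter_var + low.
            body(iter_var + low)

    seen = []
    normalized_loop(2, 30, seen.append)  # behaves like range(2, 32)
    assert seen == list(range(2, 32))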
diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py
index 8a5f4a62768d..43d26e859560 100644
--- a/python/tvm/hybrid/util.py
+++ b/python/tvm/hybrid/util.py
@@ -22,16 +22,6 @@ def make_nop():
     return _make.Evaluate(_api.const(0, dtype='int32'))
 
 
-def make_range_one():
-    """Returns a [0, 1] range node in HalideIR."""
-    return _make.range_by_min_extent(0, 1)
-
-
-def make_const_true():
-    """Returns a constant True node in HalideIR."""
-    return _api.convert(True)
-
-
 def _pruned_source(func):
     """Prune source code's extra leading spaces"""
     lines = inspect.getsource(func).split('\n')
diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py
index 940b8c088df3..df38bac1acba 100644
--- a/python/tvm/hybrid/var_decl.py
+++ b/python/tvm/hybrid/var_decl.py
@@ -41,7 +41,8 @@ def visit_Call(self, node):
         #No function pointer supported so far
         if not isinstance(node.func, ast.Name):
             raise ValueError("Function call should be an id")
-        if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range':
+        func_id = node.func.id
+        if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
             raise ValueError("Function call id not in intrinsics' list")
         for elem in node.args:
             self.visit(elem)
@@ -64,7 +65,6 @@ def visit_Name(self, node):
             self.status[node.id] = (node, self.scope_level[-1], set())
         else:
             decl, loop, usage = self.status[node.id]
-            loop = self.scope_level[-1]
             usage.add(type(node.ctx))
             self.status[node.id] = (decl, loop, usage)
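Reviewer note: the visit_Call check in var_decl.py is a plain whitelist, so max and min have to be registered here as well as lowered in parser.py. A standalone sketch with a stand-in table (the real HYBRID_GLOBALS lives in intrin.py):

    HYBRID_GLOBALS = {'allocate': None, 'unroll': None}  # stand-in for the real table

    def check_call(func_id):
        if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
            raise ValueError("Function call id not in intrinsics' list")

    check_call('max')       # accepted after this patch
    check_call('allocate')  # accepted, as before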
diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py
index fda4f52c1f19..06f24f2adaf5 100644
--- a/tests/python/unittest/test_hybrid_script.py
+++ b/tests/python/unittest/test_hybrid_script.py
@@ -2,6 +2,7 @@
 from tvm.hybrid import script
 from tvm.hybrid.intrin import HYBRID_GLOBALS
 
+
 @script
 def outer_product(n, m, a, b, c):
     for i in range(n):
@@ -56,6 +57,7 @@ def test_outer_product():
     tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32'))
     func(_n, _m, tvm_a, tvm_b, tvm_c)
     numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5)
+
     for key, _ in HYBRID_GLOBALS.items():
         assert key not in globals().keys()
         assert key not in outer_product.__globals__.keys()
@@ -74,8 +76,8 @@ def fanout(n, a, b):
         b[i] = sigma
 
     n = tvm.var('n')
-    a = tvm.placeholder((n, ), name='a')
-    b = tvm.placeholder((n-3, ), name='b')
+    a = tvm.placeholder((n, ), 'float32', name='a')
+    b = tvm.placeholder((n-3, ), 'float32', name='b')
     ir = fanout(n, a, b)
 
     #Check for i in (0, n-3)
@@ -85,12 +87,14 @@ def fanout(n, a, b):
     assert tvm.ir_pass.Equal(ir.extent, n - 3)
     #Check loopbody
     ibody = ir.body
-    assert isinstance(ibody, tvm.stmt.Realize)
-    assert ibody.bounds[0].min.value == 0
-    assert ibody.bounds[0].extent.value == 1
-    assert ibody.func.name == 'sigma'
+    assert isinstance(ibody, tvm.stmt.AttrStmt)
+    abody = ibody.body
+    assert isinstance(abody, tvm.stmt.Realize)
+    assert abody.bounds[0].min.value == 0
+    assert abody.bounds[0].extent.value == 1
+    assert abody.func.name == 'sigma'
     #Check i loop body
-    rbody = ibody.body
+    rbody = abody.body
     assert isinstance(rbody.first, tvm.stmt.Provide)
     assert rbody.first.func.name == 'sigma'
     assert len(rbody.first.args) == 1
@@ -131,6 +135,21 @@ def fanout(n, a, b):
     assert len(write.value.args) == 1
     assert write.value.args[0].value == 0
 
+    func = tvm.build(tvm.lower(ir, [n, a, b]))
+    assert func
+
+    np_a = numpy.random.randn(10).astype('float32')
+    np_b = numpy.zeros(7).astype('float32')
+
+    nd_a = tvm.ndarray.array(np_a)
+    nd_b = tvm.ndarray.array(np_b)
+
+    fanout(10, np_a, np_b)
+    func(10, nd_a, nd_b)
+
+    numpy.testing.assert_allclose(nd_b.asnumpy(), np_b, rtol=1e-5, atol=1e-5)
+
+
 @script
 def failure():
     for i in range(1, 100):
@@ -148,15 +167,18 @@ def test_failure():
 
 def test_looptype():
     @script
-    def looptype(a):
-        for i in parallel(6):
+    def looptype(a, b, c):
+        for i in parallel(8):
             a[i] = i
-        for j in vectorize(6):
-            a[j] = j
-        for k in unroll(6):
-            a[k] = k
-    a = tvm.placeholder((6, ), name='a')
-    ir = looptype(a)
+        for j in vectorize(8):
+            b[j] = j
+        for k in unroll(8):
+            c[k] = k
+
+    a = tvm.placeholder((8, ), name='a', dtype='int32')
+    b = tvm.placeholder((8, ), name='b', dtype='int32')
+    c = tvm.placeholder((8, ), name='c', dtype='int32')
+    ir = looptype(a, b, c)
     iloop = ir.first
     jloop = ir.rest.first
     kloop = ir.rest.rest
@@ -164,6 +186,24 @@ def looptype(a):
     assert jloop.for_type == tvm.stmt.For.Vectorized
     assert kloop.for_type == tvm.stmt.For.Unrolled
 
+    func = tvm.build(tvm.lower(ir, [a, b, c]))
+
+    np_a = numpy.zeros((8, )).astype('int32')
+    np_b = numpy.zeros((8, )).astype('int32')
+    np_c = numpy.zeros((8, )).astype('int32')
+
+    nd_a = tvm.ndarray.array(np_a)
+    nd_b = tvm.ndarray.array(np_b)
+    nd_c = tvm.ndarray.array(np_c)
+
+    looptype(np_a, np_b, np_c)
+    func(nd_a, nd_b, nd_c)
+
+    numpy.testing.assert_allclose(np_a, nd_a.asnumpy())
+    numpy.testing.assert_allclose(np_b, nd_b.asnumpy())
+    numpy.testing.assert_allclose(np_c, nd_c.asnumpy())
+
+
 def test_if():
     @script
     def if_then_else(a, b):
@@ -234,12 +274,14 @@ def intrin_real(a):
         a[3] = sigmoid(a[3])
         a[4] = power(a[4], a[5])
         a[5] = tanh(a[5])
+        a[6] = min(a[4], a[5])
+        a[7] = max(a[5], a[6])
 
-    a6 = tvm.placeholder((6, ), dtype='float32', name='a')
+    a6 = tvm.placeholder((8, ), dtype='float32', name='a')
     ir = intrin_real(a6)
     func = tvm.build(tvm.lower(ir, [a6]))
     assert func
-    a = numpy.arange(2, 8).astype('float32')
+    a = numpy.arange(2, 10).astype('float32')
     tvm_a = tvm.ndarray.array(a)
     func(tvm_a)
     intrin_real(a)
@@ -259,22 +301,87 @@ def intrin_int(a):
     func(tvm_a)
     assert tvm_a.asnumpy()[0] == a[0]
 
-def test_allocate_buffer():
-    def blur(a):
-        for i in serail(32):
-            h_blur = allocate((4, 36))
-            for j in serail(4):
-                for k in serail(36):
-                    s = allocate((1, ), 'float32')
-                    for dj in serail(4):
-                        s[0] = s[0] + a[i, j + dj]
-                    h_blur[j, k] = s[0] / 4.
-            for j in serail(32):
-                s = 0.
-                for di in serail(4):
-                    s = s + h_blur[di, j]
-                h_blur[i, j] = s / 4.
-
+def test_non_zero():
+    @tvm.hybrid.script
+    def blur(a, b):
+        for i in range(2, 32):
+            for j in range(2, 32):
+                s = 0.0
+                for di in range(3):
+                    for dj in range(3):
+                        s = s + a[i-di, j-dj]
+                b[i-2, j-2] = s / 9.0
+    try:
+        np_a = numpy.random.randn(32, 32).astype('float32')
+        np_b = numpy.zeros((30, 30), dtype='float32')
+        blur(np_a, np_b)
+
+        ph_a = tvm.placeholder((32, 32), 'float32', 'a')
+        ph_b = tvm.placeholder((30, 30), 'float32', 'b')
+        ir = tvm.hybrid.parse(blur, [ph_a, ph_b])
+        func = tvm.lower(ir, [ph_a, ph_b])
+        func = tvm.build(func)
+
+        nd_a = tvm.ndarray.array(np_a)
+        nd_b = tvm.ndarray.array(np_b)
+        func(nd_a, nd_b)
+
+        numpy.testing.assert_allclose(np_b, nd_b.asnumpy(), atol=1e-5, rtol=1e-5)
+    except IOError:
+        print('[Warning] Non-zero lower bound test skipped under Python 2')
+
+    @tvm.hybrid.script
+    def triangle(a, b, c):
+        for i in range(10):
+            for j in range(i, 10):
+                c[i, j] = a[i] * b[j]
+
+    a = tvm.placeholder((10, ), dtype='float32', name='a')
+    b = tvm.placeholder((10, ), dtype='float32', name='b')
+    c = tvm.placeholder((10, 10), dtype='float32', name='c')
+
+    np_a = numpy.random.randn(10).astype('float32')
+    np_b = numpy.random.randn(10).astype('float32')
+    np_c = numpy.zeros((10, 10)).astype('float32')
+
+    nd_a = tvm.ndarray.array(np_a)
+    nd_b = tvm.ndarray.array(np_b)
+    nd_c = tvm.ndarray.array(np_c)
+
+    triangle(np_a, np_b, np_c)
+
+    func = tvm.build(tvm.lower(triangle(a, b, c), [a, b, c]))
+    assert func
+    func(nd_a, nd_b, nd_c)
+    numpy.testing.assert_allclose(nd_c.asnumpy(), np_c)
+
+def test_allocate():
+    @tvm.hybrid.script
+    def blur2d(a, b):
+        for i in range(30):
+            ha = allocate((3, 30), 'float32')
+            for j in range(3):
+                for k in range(30):
+                    ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2]
+            for j in range(30):
+                b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0
+
+    a = tvm.placeholder((32, 32), 'float32', 'a')
+    b = tvm.placeholder((30, 30), 'float32', 'b')
+
+    func = tvm.build(tvm.lower(blur2d(a, b), [a, b]))
+    assert func
+
+    np_a = numpy.random.randn(32, 32).astype('float32')
+    np_b = numpy.zeros((30, 30)).astype('float32')
+
+    nd_a = tvm.ndarray.array(np_a)
+    nd_b = tvm.ndarray.array(np_b)
+
+    func(nd_a, nd_b)
+    blur2d(np_a, np_b)
+
+    numpy.testing.assert_allclose(nd_b.asnumpy(), np_b, atol=1e-5, rtol=1e-5)
 
 if __name__ == "__main__":
     test_outer_product()
@@ -284,4 +391,6 @@ def blur(a):
     test_if()
     test_bind()
     test_math_intrin()
+    test_non_zero()
+    test_allocate()
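Reviewer note: to see the AttrStmt('realize_scope') over Realize nesting that the fanout test asserts, printing the parsed HalideIR is enough. A hedged one-off, assuming a hybrid-decorated fanout like the one in the test:

    import tvm

    n = tvm.var('n')
    a = tvm.placeholder((n, ), 'float32', name='a')
    b = tvm.placeholder((n - 3, ), 'float32', name='b')
    # Each mutable variable (e.g. sigma) is realized under a 'realize_scope'
    # attribute, which is 'global' for plain variables.
    print(fanout(n, a, b))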