diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 9327c0865689..7d9cc1a918d1 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -33,36 +33,40 @@
 from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
 from .runtime import ndarray as nd
 
+# tvm.error
+from . import error
+
 # tvm.ir
 from .ir import IRModule
 from .ir import transform
 from .ir import container
 from . import ir
 
+# tvm.tir
+from . import tir
+
+# tvm.target
+from . import target
+
 # others
 from . import tensor
 from . import arith
-from . import expr
-from . import stmt
 from . import make
-from . import ir_pass
 from . import schedule
-
-from . import ir_builder
-from . import target
-from . import generic
 from . import hybrid
 from . import testing
-from . import error
-
 from .api import *
-from .intrin import *
 from .tensor_intrin import decl_tensor_intrin
 from .schedule import create_schedule
 from .build_module import build, lower, build_config
 from .tag import tag_scope
 
+# backward compat for topi, to be removed later
+from .tir import expr, stmt, ir_builder, ir_pass, generic
+from .tir.op import *
+from . import intrin
+
 # Contrib initializers
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
diff --git a/python/tvm/api.py b/python/tvm/api.py
index e7778d6cc5df..3a8eedcf8e2e 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -23,13 +23,17 @@
 from tvm.runtime import convert, const, DataType
 from tvm.ir import container as _container
+from tvm.tir import expr as _expr
+from tvm.tir import stmt as _stmt
+from tvm.tir import decl_buffer, layout, bijective_layout
+from tvm.tir import min_value, max_value, indexdiv, indexmod
+import tvm.tir._ffi_api
 
 from ._ffi.base import string_types, TVMError
 from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
 from . import _api_internal
 from . import make as _make
-from . import expr as _expr
 from . import tensor as _tensor
 from . import schedule as _schedule
 from . import tag as _tag
@@ -40,37 +44,6 @@
 handle = "handle"
 
 
-def min_value(dtype):
-    """minimum value of dtype
-
-    Parameters
-    ----------
-    dtype : str
-        The data type.
-
-    Returns
-    -------
-    value : tvm.Expr
-        The minimum value of dtype.
-    """
-    return _api_internal._min_value(dtype)
-
-
-def max_value(dtype):
-    """maximum value of dtype
-
-    Parameters
-    ----------
-    dtype : str
-        The data type.
-
-    Returns
-    -------
-    value : tvm.Expr
-        The maximum value of dtype.
-    """
-    return _api_internal._max_value(dtype)
-
 
 def var(name="tindex", dtype=int32):
     """Create a new variable with specified name and dtype
@@ -87,7 +60,7 @@
     var : Var
         The result symbolic variable.
     """
-    return _api_internal._Var(name, dtype)
+    return _expr.Var(name, dtype)
 
 
 def size_var(name="size", dtype=int32):
@@ -106,7 +79,7 @@
     var : SizeVar
        The result symbolic shape variable.
""" - return _api_internal._SizeVar(name, dtype) + return _expr.SizeVar(name, dtype) def any(*args): @@ -126,9 +99,9 @@ def any(*args): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _make._OpOr(args[0], args[1]) + ret = tvm.tir._ffi_api._OpOr(args[0], args[1]) for i in range(2, len(args)): - ret = _make._OpOr(ret, args[i]) + ret = tvm.tir._ffi_api._OpOr(ret, args[i]) return ret @@ -150,9 +123,9 @@ def all(*args): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _make._OpAnd(args[0], args[1]) + ret = tvm.tir._ffi_api._OpAnd(args[0], args[1]) for i in range(2, len(args)): - ret = _make._OpAnd(ret, args[i]) + ret = tvm.tir._ffi_api._OpAnd(ret, args[i]) return ret @@ -438,7 +411,7 @@ def extern(shape, output_placeholders.append(decl_buffer(shp, dt, name)) body = fcompute(input_placeholders, output_placeholders) if isinstance(body, _expr.PrimExpr): - body = _make.Evaluate(body) + body = _stmt.Evaluate(body) op = _api_internal._ExternOp(name, tag, attrs, inputs, input_placeholders, @@ -447,159 +420,6 @@ def extern(shape, return res[0] if len(res) == 1 else res -def decl_buffer(shape, - dtype=None, - name="buffer", - data=None, - strides=None, - elem_offset=None, - scope="", - data_alignment=-1, - offset_factor=0, - buffer_type=""): - """Declare a new symbolic buffer. - - Normally buffer is created automatically during lower and build. - This is only needed if user want to specify their own buffer layout. - - See the note below for detailed discussion on usage of buffer. - - Parameters - ---------- - shape : tuple of Expr - The shape of the buffer. - - dtype : str, optional - The data type of the buffer. - - name : str, optional - The name of the buffer. - - data : Var, optional - The data pointer in the buffer. - - strides: array of Expr - The stride of the buffer. - - elem_offset: Expr, optional - The beginning offset of the array to data. - In terms of number of elements of dtype. - - scope: str, optional - The storage scope of the buffer, if not global. - If scope equals empty string, it means it is global memory. - - data_alignment: int, optional - The alignment of data pointer in bytes. - If -1 is passed, the alignment will be set to TVM's internal default. - - offset_factor: int, optional - The factor of elem_offset field, when set, - elem_offset is required to be multiple of offset_factor. - If 0 is pssed, the alignment will be set to 1. - if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None. - - buffer_type: str, optional, {"", "auto_broadcast"} - auto_broadcast buffer allows one to implement broadcast computation - without considering whether dimension size equals to one. - TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1. - - Returns - ------- - buffer : Buffer - The created buffer - - Example - ------- - Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation, - - .. 
code-block:: python - - m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") - n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") - o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") - A = tvm.placeholder((m0, m1, m2), name='A') - B = tvm.placeholder((n0, n1, n2), name='B') - C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') - Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast") - Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast") - s = tvm.create_schedule(C.op) - fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) - ctx = tvm.cpu(0) - a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) - fadd(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - - Note - ---- - Buffer data structure reflects the DLTensor structure in dlpack. - While DLTensor data structure is very general, it is usually helpful - to create function that only handles specific case of data structure - and make compiled function benefit from it. - - If user pass strides and elem_offset is passed as None - when constructing the function, then the function will be specialized - for the DLTensor that is compact and aligned. - If user pass a fully generic symbolic array to the strides, - then the resulting function becomes fully generic. - """ - shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape - dtype = float32 if dtype is None else dtype - strides = () if strides is None else strides - if offset_factor != 0 and elem_offset is None: - shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" - elem_offset = var('%s_elem_offset' % name, shape_dtype) - if data is None: - data = var(name, "handle") - return _api_internal._Buffer( - data, dtype, shape, strides, elem_offset, name, scope, - data_alignment, offset_factor, buffer_type) - -def layout(layout_str): - """Create a layout node from a string. - - Parameters - ---------- - layout_str : str - A layout representation is composed of upper cases, lower cases and numbers, - where upper case indicates a primal axis and - the corresponding lower case with factor size indicates the subordinate axis. - For example, NCHW16c can describe a 5-D tensor of - [batch_size, channel, height, width, channel_block]. - Here subordinate axis channel_block=16 is the factor size of - the primal axis C (channel). - - Returns - ------- - layout : Layout - The created layout - """ - return _api_internal._Layout(layout_str) - -def bijective_layout(src_layout, dst_layout): - """Create a bijective layout mapping. - - Parameters - ---------- - src_layout : str or Layout - source layout. - - dst_layout : str or Layout - destination layout. 
- - Returns - ------- - bijective_layout : BijectiveLayout - The created bijective layout - """ - if isinstance(src_layout, str): - src_layout = layout(src_layout) - if isinstance(dst_layout, str): - dst_layout = layout(dst_layout) - return _api_internal._BijectiveLayout(src_layout, dst_layout) - def _IterVar(dom, name, iter_type, thread_tag=''): """Internal function to create IterVar @@ -758,7 +578,7 @@ def _make_reduce(expr, axis, where=None): expr = convert([expr]) result = convert(result) id_elem = convert(id_elem) - combiner = _make.CommReducer(lhs, rhs, result, id_elem) + combiner = _expr.CommReducer(lhs, rhs, result, id_elem) axis = convert(axis if isinstance(axis, (list, tuple)) else [axis]) if where is None: where = convert(True) @@ -810,164 +630,9 @@ def reducer(expr, axis, where=None, *args): reducer.__doc__ = doc_str.format(name) return reducer -def div(a, b): - """Compute a / b as in C/C++ semantics. - - Parameters - ---------- - a : Expr - The left hand operand, known to be non-negative. - - b : Expr - The right hand operand, known to be non-negative. - - Returns - ------- - res : Expr - The result expression. - Note - ---- - When operands are integers, returns truncdiv(a, b). - """ - return _make._OpDiv(a, b) - - -def indexdiv(a, b): - """Compute floor(a / b) where a and b are non-negative. - - Parameters - ---------- - a : Expr - The left hand operand, known to be non-negative. - - b : Expr - The right hand operand, known to be non-negative. - - Returns - ------- - res : Expr - The result expression. - - Note - ---- - Use this function to split non-negative indices. - This function may take advantage of operands' - non-negativeness. - """ - return _make._OpIndexDiv(a, b) - - -def indexmod(a, b): - """Compute the remainder of indexdiv. a and b are non-negative. - - Parameters - ---------- - a : Expr - The left hand operand, known to be non-negative. - - b : Expr - The right hand operand, known to be non-negative. - - Returns - ------- - res : Expr - The result expression. - - Note - ---- - Use this function to split non-negative indices. - This function may take advantage of operands' - non-negativeness. - """ - return _make._OpIndexMod(a, b) - - -def truncdiv(a, b): - """Compute the truncdiv of two expressions. - - Parameters - ---------- - a : Expr - The left hand operand - - b : Expr - The right hand operand - - Returns - ------- - res : Expr - The result expression. - - Note - ---- - This is the default integer division behavior in C. - """ - return _make._OpTruncDiv(a, b) - - -def truncmod(a, b): - """Compute the truncmod of two expressions. - - Parameters - ---------- - a : Expr - The left hand operand - - b : Expr - The right hand operand - - Returns - ------- - res : Expr - The result expression. - - Note - ---- - This is the default integer division behavior in C. - """ - return _make._OpTruncMod(a, b) - - -def floordiv(a, b): - """Compute the floordiv of two expressions. - - Parameters - ---------- - a : Expr - The left hand operand - - b : Expr - The right hand operand - - Returns - ------- - res : Expr - The result expression. - """ - return _make._OpFloorDiv(a, b) - - -def floormod(a, b): - """Compute the floormod of two expressions. - - Parameters - ---------- - a : Expr - The left hand operand - - b : Expr - The right hand operand - - Returns - ------- - res : Expr - The result expression. 
- """ - return _make._OpFloorMod(a, b) - -#pylint: disable=unnecessary-lambda +# pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") -min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min') -max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max') +min = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMin(x, y), max_value, name='min') +max = comm_reducer(lambda x, y: tvm.tir._ffi_api._OpMax(x, y), min_value, name='max') tvm._ffi._init_api("tvm.api") diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 5067277d32a8..9ff8b24fcb5d 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -227,7 +227,7 @@ def args_to_workload(x, topi_compute_func=None): workload = 0 else: raise RuntimeError('Do not support type "%s" in argument. Consider to use' - 'primitive types or tvm.expr.Var only' % type(x)) + 'primitive types or tvm.tir.Var only' % type(x)) return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload def template(func): diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index c2993ac27819..de78a3ee2700 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -26,17 +26,19 @@ from tvm.runtime import Object, ndarray from tvm.ir import container from tvm.target import codegen +from tvm.tir import expr +from tvm.tir import ir_pass +from tvm.tir import Stmt +from tvm.tir.stmt import LoweredFunc + +from . import target as _target from . import api from . import _api_internal from . import tensor from . import schedule -from . import expr -from . import ir_pass -from . import stmt as _stmt -from . import target as _target from . import make -from .stmt import LoweredFunc + class DumpIR(object): @@ -61,7 +63,7 @@ def decorate(self, func): def dump(*args, **kwargs): """dump function""" retv = func(*args, **kwargs) - if not isinstance(retv, (_stmt.Stmt, LoweredFunc, container.Array)): + if not isinstance(retv, (Stmt, LoweredFunc, container.Array)): return retv fname = func.func_name if hasattr(func, 'func_name') else func.__name__ pname = str(self._pass_id) + "_" + fname + "_ir.cc" diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index 7c024b792867..cdd4ce22c82d 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -15,9 +15,8 @@ # specific language governing permissions and limitations # under the License. """External function interface to BLAS libraries.""" -from __future__ import absolute_import as _abs - -from .. import api as _api, intrin as _intrin +import tvm +from .. 
import api as _api def matmul(lhs, rhs, transa=False, transb=False, **kwargs): @@ -46,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): return _api.extern( (n, m), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb ), name="C", @@ -78,7 +77,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs return _api.extern( (b, n, m), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cblas.batch_matmul" if not iterative else "tvm.contrib.cblas.batch_matmul_iterative", diff --git a/python/tvm/contrib/cublas.py b/python/tvm/contrib/cublas.py index cf4e1f5e49bc..75290a8f6402 100644 --- a/python/tvm/contrib/cublas.py +++ b/python/tvm/contrib/cublas.py @@ -15,10 +15,8 @@ # specific language governing permissions and limitations # under the License. """External function interface to cuBLAS libraries.""" -from __future__ import absolute_import as _abs - +import tvm from .. import api as _api -from .. import intrin as _intrin def matmul(lhs, rhs, transa=False, transb=False, dtype=None): """Create an extern op that compute matrix mult of A and rhs with cuBLAS @@ -44,7 +42,7 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None): dtype = dtype if dtype is not None else lhs.dtype return _api.extern( (n, m), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cublas.matmul", ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C") @@ -73,6 +71,6 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None): dtype = dtype if dtype is not None else lhs.dtype return _api.extern( (b, n, m), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cublas.batch_matmul", ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C") diff --git a/python/tvm/contrib/cublaslt.py b/python/tvm/contrib/cublaslt.py index 5470fd0b4c18..1000ede1379d 100644 --- a/python/tvm/contrib/cublaslt.py +++ b/python/tvm/contrib/cublaslt.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. """External function interface to cuBLASlt libraries.""" -from __future__ import absolute_import as _abs - +import tvm from .. import api as _api -from .. import intrin as _intrin + def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None): """Create an extern op that compute matrix mult of A and rhs with cuBLAS @@ -46,6 +45,6 @@ def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None): dtype = dtype if dtype is not None else lhs.dtype return _api.extern( (n, m), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cublaslt.matmul", ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C") diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 1b5caca699e5..20b42d79d27e 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -18,8 +18,8 @@ # pylint: disable-msg=C0103 import ctypes import numpy as np +import tvm from .. import api as _api -from .. import intrin as _intrin from .. 
import get_global_func as _get_global_func # algos can be read from cudnn.h @@ -365,7 +365,7 @@ def conv_forward(x, if dims == 4: return _api.extern( oshape, [x, w], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cudnn.conv2d.forward", conv_mode, tensor_format, @@ -383,7 +383,7 @@ def conv_forward(x, return _api.extern( oshape, [x, w], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.cudnn.conv3d.forward", conv_mode, tensor_format, diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py index e062ac1e735e..7f024f70b21a 100644 --- a/python/tvm/contrib/miopen.py +++ b/python/tvm/contrib/miopen.py @@ -18,8 +18,8 @@ # pylint: disable-msg=C0103 import ctypes import numpy as np +import tvm from .. import api as _api -from .. import intrin as _intrin from .. import get_global_func as _get_global_func @@ -113,7 +113,7 @@ def conv2d_forward(x, return _api.extern( list(oshape), [x, w], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.miopen.conv2d.forward", conv_mode, data_type, diff --git a/python/tvm/contrib/mps.py b/python/tvm/contrib/mps.py index d9cab2471c6f..5d84e892ec74 100644 --- a/python/tvm/contrib/mps.py +++ b/python/tvm/contrib/mps.py @@ -15,9 +15,8 @@ # specific language governing permissions and limitations # under the License. """External function interface to MPS libraries.""" -from __future__ import absolute_import as _abs +import tvm from .. import api as _api -from .. import intrin as _intrin # pylint: disable=C0103,W0612 @@ -50,7 +49,7 @@ def matmul(lhs, rhs, transa=False, transb=False): n = c return _api.extern( (m, n), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.mps.matmul", ins[0], ins[1], outs[0], transa, transb), name="C") @@ -82,6 +81,6 @@ def conv2d(data, weight, pad='SAME', stride=1): return _api.extern( (n, ho, wo, co), [data, weight], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.mps.conv2d", ins[0], ins[1], outs[0], padding, stride), name="C") diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index 3e2132eb5067..a55a344b6410 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. """External function interface to NNPACK libraries.""" +import tvm import tvm._ffi - from .. import api as _api -from .. 
import intrin as _intrin def is_available(): @@ -46,7 +45,7 @@ def fully_connected_inference(lhs, rhs, nthreads=1): m = rhs.shape[0] return _api.extern( (m, ), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.nnpack.fully_connected_inference", ins[0], ins[1], outs[0], nthreads), name="C") @@ -110,7 +109,7 @@ def convolution_inference( return _api.extern( (batch, output_channels, output_height, output_width), [data, kernel, bias] if bias is not None else [data, kernel], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], @@ -163,7 +162,7 @@ def convolution_inference_without_weight_transform( return _api.extern( (batch, output_channels, output_height, output_width), [data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.nnpack.convolution_inference_without_weight_transform", ins[0], ins[1], @@ -198,7 +197,7 @@ def convolution_inference_weight_transform( return _api.extern( (output_channels, input_channels, transform_tile_size, transform_tile_size), [kernel], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.nnpack.convolution_inference_weight_transform", ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype) diff --git a/python/tvm/contrib/random.py b/python/tvm/contrib/random.py index 059bf2344e6b..bcc9b1703386 100644 --- a/python/tvm/contrib/random.py +++ b/python/tvm/contrib/random.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. """External function interface to random library.""" +import tvm import tvm._ffi - from .. import api as _api -from .. import intrin as _intrin def randint(low, high, size, dtype='int32'): @@ -39,7 +38,7 @@ def randint(low, high, size, dtype='int32'): A tensor with specified size and dtype """ assert 'int' in dtype, "the type of randint output must be int or uint" - return _api.extern(size, [], lambda ins, outs: _intrin.call_packed( + return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.random.randint", int(low), int(high), outs[0]), dtype=dtype) @@ -67,7 +66,7 @@ def uniform(low, high, size): out : Tensor A tensor with specified size and dtype. """ - return _api.extern(size, [], lambda ins, outs: _intrin.call_packed( + return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.random.uniform", float(low), float(high), outs[0]), dtype='float32') @@ -91,7 +90,7 @@ def normal(loc, scale, size): out : Tensor A tensor with specified size and dtype """ - return _api.extern(size, [], lambda ins, outs: _intrin.call_packed( + return _api.extern(size, [], lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.random.normal", float(loc), float(scale), outs[0]), dtype='float32') diff --git a/python/tvm/contrib/rocblas.py b/python/tvm/contrib/rocblas.py index bdd6146e2ecf..e11be5a1d973 100644 --- a/python/tvm/contrib/rocblas.py +++ b/python/tvm/contrib/rocblas.py @@ -15,10 +15,8 @@ # specific language governing permissions and limitations # under the License. """External function interface to rocBLAS libraries.""" -from __future__ import absolute_import as _abs - +import tvm from .. import api as _api -from .. 
import intrin as _intrin def matmul(lhs, rhs, transa=False, transb=False): """Create an extern op that compute matrix mult of A and rhs with rocBLAS @@ -43,6 +41,6 @@ def matmul(lhs, rhs, transa=False, transb=False): m = rhs.shape[0] if transb else rhs.shape[1] return _api.extern( (n, m), [lhs, rhs], - lambda ins, outs: _intrin.call_packed( + lambda ins, outs: tvm.tir.call_packed( "tvm.contrib.rocblas.matmul", ins[0], ins[1], outs[0], transa, transb), name="C") diff --git a/python/tvm/generic.py b/python/tvm/generic.py index b7bea7fa4236..7c46312c2ea5 100644 --- a/python/tvm/generic.py +++ b/python/tvm/generic.py @@ -14,117 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Generic opertors in TVM. -We follow the numpy naming convention for this interface -(e.g., tvm.generic.multitply ~ numpy.multiply). -The default implementation is used by tvm.ExprOp. -""" -# pylint: disable=unused-argument -from . import make as _make - -#Operator precedence used when overloading. -__op_priority__ = 0 - - -def add(lhs, rhs): - """Generic add operator. - - Parameters - ---------- - lhs : object - The left operand. - rhs : object - The right operand. - - Returns - ------- - op : tvm.Expr - The result Expr of add operaton. - """ - return _make._OpAdd(lhs, rhs) - - -def subtract(lhs, rhs): - """Generic subtract operator. - - Parameters - ---------- - lhs : object - The left operand. - rhs : object - The right operand. - - Returns - ------- - op : tvm.Expr - The result Expr of subtract operaton. - """ - return _make._OpSub(lhs, rhs) - - -def multiply(lhs, rhs): - """Generic multiply operator. - - Parameters - ---------- - lhs : object - The left operand. - rhs : object - The right operand. - - Returns - ------- - op : tvm.Expr - The result Expr of multiply operaton. - """ - return _make._OpMul(lhs, rhs) - -def divide(lhs, rhs): - """Generic divide operator. - - Parameters - ---------- - lhs : object - The left operand. - rhs : object - The right operand. - - Returns - ------- - op : tvm.Expr - The result Expr of divide operaton. - """ - return _make._OpDiv(lhs, rhs) - -def floordiv(lhs, rhs): - """Generic floordiv operator. - - Parameters - ---------- - lhs : object - The left operand. - rhs : object - The right operand. - - Returns - ------- - op : tvm.Expr - The result Expr of divide operaton. - """ - return _make._OpFloorDiv(lhs, rhs) - - -def cast(src, dtype): - """Generic cast operator. - - Parameters - ---------- - src : object - The source operand. - - Returns - ------- - op : tvm.Expr - The result Expr of divide operaton. - """ - return _make._cast(dtype, src) +"""Generic operators.""" +# pylint:disable=unused-wildcard-import, wildcard-import +from .tir.generic import * diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 78ce2e20e1fd..0933628a9943 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -17,16 +17,17 @@ """Intrinsics of TVM-Python Hybrid Script for Python compilation time semantic support.""" from tvm.ir.container import Array +from tvm import target as _tgt +from tvm.tir import expr as _expr +from tvm.tir import ir_pass +from tvm.tir import call_pure_intrin +from tvm.tir.stmt import For + from .. import api as _api -from .. import expr as _expr -from .. import make as _make -from .. import target as _tgt -from .. 
import ir_pass -from ..stmt import For + from .util import _internal_assert -from ..intrin import call_pure_intrin -#pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin LOOP_INTRIN = { 'range' : For.Serial, @@ -69,15 +70,15 @@ def bind(func_id, args): def _math_intrin(func_id, args): # pylint: disable=import-outside-toplevel - from .. import intrin - return getattr(intrin, func_id)(*args) + import tvm.tir.op + return getattr(tvm.tir.op, func_id)(*args) sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name def _min_max(func_id, args): _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements") - return getattr(_make, func_id.title())(args[0], args[1]) + return getattr(_expr, func_id.title())(args[0], args[1]) min = max = _min_max #pylint: disable=invalid-name @@ -127,7 +128,7 @@ def len(func_id, args): def _cast(func_id, args): _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \ "Only one expression can be cast") - return _make.Cast(func_id, args[0]) + return _expr.Cast(func_id, args[0]) float16 = float32 = float64 = _cast #pylint: disable=invalid-name int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index a0b2dfea6062..6be0006a851f 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -24,7 +24,11 @@ import numbers from enum import Enum -from tvm.ir.container import Array +from tvm.ir import Array, Range +import tvm.tir +from tvm.tir import expr as _expr +from tvm.tir import stmt as _stmt +from tvm.tir import ir_pass as _ir_pass from .util import _internal_assert from . import calls @@ -35,12 +39,7 @@ from ..tensor import Tensor, Operation from .. import _api_internal as _tvm_internal -from .. import expr as _expr -from .. import make as _make -from .. import stmt as _stmt - from .. import api as _api -from .. 
import ir_pass as _ir_pass def concat_list_to_block(lst): @@ -79,13 +78,13 @@ class Symbol(Enum): def _floordiv(x, y): if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp): - return _api.floordiv(x, y) + return tvm.tir.floordiv(x, y) return operator.floordiv(x, y) def _floormod(x, y): if isinstance(x, _expr.ExprOp) or isinstance(y, _expr.ExprOp): - return _api.floormod(x, y) + return tvm.tir.floormod(x, y) return operator.mod(x, y) @@ -208,11 +207,11 @@ def wrap_up_realize(self, node, body): if _scope == 'global': body = self.wrap_up_binds(body) - _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape] + _domain = [Range.make_by_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = _api.convert(True) - body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body) - body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body) + body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body) + body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body) for elem in to_pop: self.symbols.pop(elem) @@ -223,7 +222,7 @@ def wrap_up_realize(self, node, body): def wrap_up_binds(self, body): for _, iter_var in self.binds.items(): ext = iter_var.dom.extent - body = _make.AttrStmt(iter_var, 'thread_extent', ext, body) + body = tvm.tir.AttrStmt(iter_var, 'thread_extent', ext, body) self.binds = {} return body @@ -271,7 +270,7 @@ def visit_Name(self, node): return entry if isinstance(node.ctx, ast.Load) else None if ty is Symbol.BufferVar: if isinstance(node.ctx, ast.Load): - return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \ + return tvm.tir.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \ _expr.Call.Halide, entry.op, entry.value_index) return entry, [_api.const(0, 'int32')] # Do I need any assertion here? 
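
The hybrid-parser hunks above and below all follow one mechanical rule: every removed `_make.X(...)` builder call is re-spelled as the `tvm.tir.X(...)` node constructor with identical arguments. A minimal sketch of the pattern outside diff context, assuming this patch is applied (the expressions below are illustrative, not taken from the patch):

.. code-block:: python

    import tvm
    from tvm import tir

    x = tir.IntImm("int32", 0)
    # formerly: _make.Evaluate(x) and _make.IfThenElse(cond, then_stmt, None)
    nop = tir.Evaluate(x)
    branch = tir.IfThenElse(tir.IntImm("uint1", 1), nop, None)
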
@@ -304,10 +303,10 @@ def visit_AugAssign(self, node): args = [_api.const(0, 'int32')] _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!") - read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index) + read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index) value = HybridParser._binop_maker[type(node.op)](read, rhs) - return _make.Provide(buf.op, 0, value, args) + return tvm.tir.Provide(buf.op, 0, value, args) def visit_Assign(self, node): @@ -358,13 +357,13 @@ def visit_Assign(self, node): lhs = self.visit(lhs_) if lhs is not None: buf, args = lhs - return _make.Provide(buf.op, 0, rhs, args) + return tvm.tir.Provide(buf.op, 0, rhs, args) return util.make_nop() lhs, args = self.visit(lhs) _internal_assert(isinstance(lhs, Tensor), \ "An array access's LHS is expected to be a expr.Call!") - res = _make.Provide(lhs.op, lhs.value_index, rhs, args) + res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args) return res @@ -391,8 +390,8 @@ def visit_Subscript(self, node): arr = arr[i.value] return arr if isinstance(node.ctx, ast.Load): - return _make.Call(arr.dtype, arr.name, args, - _expr.Call.Halide, arr.op, arr.value_index) + return tvm.tir.Call(arr.dtype, arr.name, args, + _expr.Call.Halide, arr.op, arr.value_index) return arr, args def visit_With(self, node): @@ -426,14 +425,14 @@ def visit_If(self, node): else_body = visit_list_to_block(self.visit, node.orelse) else: else_body = None - return _make.IfThenElse(cond, if_body, else_body) + return tvm.tir.IfThenElse(cond, if_body, else_body) def visit_IfExp(self, node): cond = self.visit(node.test) if_body = self.visit(node.body) else_body = self.visit(node.orelse) - return _make.Select(cond, if_body, else_body) + return tvm.tir.Select(cond, if_body, else_body) def visit_Compare(self, node): @@ -543,7 +542,7 @@ def visit_For(self, node): else: _internal_assert(not isinstance(for_type, tuple), \ "Micro expansion should be handled before!") - res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body) + res = tvm.tir.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body) self.symbols.pop(_name) return res @@ -580,7 +579,7 @@ def visit_Str(self, node): def visit_Assert(self, node): test = self.visit(node.test) mesg = _api.convert(self.visit(node.msg)) - return _make.AssertStmt(test, mesg, util.make_nop()) + return tvm.tir.AssertStmt(test, mesg, util.make_nop()) def parse_python(src, args, symbols, closure_vars): diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 8ef200a02b2c..2f449dc8f69e 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -22,12 +22,13 @@ import sys import numpy +from tvm._ffi.base import numeric_types from tvm.ir.container import Array + +from tvm.tir import expr as _expr +from tvm.tir import stmt as _stmt + from .. import api as _api -from .. import make as _make -from .. import expr as _expr -from .. import stmt as _stmt -from .._ffi.base import numeric_types from ..tensor import Tensor @@ -46,7 +47,7 @@ def _internal_assert(cond, err): # Useful constants. In avoid of runtime dependences, we use function calls to return them. 
def make_nop(): """Returns a 'no operation' node in HalideIR.""" - return _make.Evaluate(_api.const(0, dtype='int32')) + return _stmt.Evaluate(_api.const(0, dtype='int32')) def is_docstring(node): @@ -77,10 +78,10 @@ def replace_io(body, rmap): def replace(op): if isinstance(op, _stmt.Provide) and op.func in rmap.keys(): buf = rmap[op.func] - return _make.Provide(buf.op, op.value_index, op.value, op.args) + return _stmt.Provide(buf.op, op.value_index, op.value, op.args) if isinstance(op, _expr.Call) and op.func in rmap.keys(): buf = rmap[op.func] - return _make.Call(buf.dtype, buf.name, op.args, \ + return _expr.Call(buf.dtype, buf.name, op.args, \ _expr.Call.Halide, buf.op, buf.value_index) return None diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 04cbf9ee86c7..93e8fcb3f140 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -14,678 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Expression Intrinsics and math functions in TVM.""" -# pylint: disable=redefined-builtin -import tvm._ffi -import tvm.target.codegen - -from . import make as _make -from .api import convert, const -from .expr import Call as _Call -from .schedule import Buffer as _Buffer - -def _pack_buffer(buf): - """Build intrinsics that packs the buffer. - """ - assert buf.shape - shape = _make.Call("handle", "tvm_stack_make_shape", buf.shape, - _Call.Intrinsic, None, 0) - strides = _make.Call("handle", "tvm_stack_make_shape", buf.strides, - _Call.Intrinsic, None, 0) if buf.strides else 0 - pack_args = [buf.data, - shape, - strides, - len(buf.shape), - const(0, dtype=buf.dtype), - buf.elem_offset] - return _make.Call("handle", "tvm_stack_make_array", - pack_args, _Call.Intrinsic, None, 0) - -def call_packed(*args): - """Build expression by call an external packed function. - - The argument to packed function can be Expr or Buffer. - The argument is the corresponding POD type when Expr is presented. - - When the argument is Buffer, the corresponding PackedFunc - will recieve an TVMArrayHandle whose content is valid during the callback period. - If the PackedFunc is a python callback, then the corresponding argument is NDArray. - - Parameters - ---------- - args : list of Expr or Buffer. - Positional arguments. - - Returns - ------- - call : Expr - The call expression. - - See Also - -------- - tvm.extern : Create tensor with extern function call. - """ - call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args] - return _make.Call( - "int32", "tvm_call_packed", call_args, _Call.Intrinsic, None, 0) - - -def call_pure_intrin(dtype, func_name, *args): - """Build expression by calling a pure intrinsic function. - - Intrinsics can be overloaded with multiple data types via - the intrinsic translation rule. - - Parameters - ---------- - dtype : str - The data type of the result. - - func_name: str - The intrinsic function name. - - args : list - Positional arguments. - - Returns - ------- - call : Expr - The call expression. - """ - args = convert(args) - return _make.Call( - dtype, func_name, convert(args), _Call.PureIntrinsic, None, 0) - - -def call_intrin(dtype, func_name, *args): - """Build expression by calling an intrinsic function. - - Intrinsics can be overloaded with multiple data types via - the intrinsic translation rule. - - Parameters - ---------- - dtype : str - The data type of the result. - - func_name: str - The intrinsic function name. 
- - args : list - Positional arguments. - - Returns - ------- - call : Expr - The call expression. - """ - args = convert(args) - return _make.Call( - dtype, func_name, convert(args), _Call.Intrinsic, None, 0) - - -def call_pure_extern(dtype, func_name, *args): - """Build expression by calling a pure extern function. - - Parameters - ---------- - dtype : str - The data type of the result. - - func_name: str - The extern function name. - - args : list - Positional arguments. - - Returns - ------- - call : Expr - The call expression. - """ - return _make.Call( - dtype, func_name, convert(args), _Call.PureExtern, None, 0) - - -def call_extern(dtype, func_name, *args): - """Build expression by calling a extern function. - - Parameters - ---------- - dtype : str - The data type of the result. - - func_name: str - The extern function name. - - args : list - Positional arguments. - - Returns - ------- - call : Expr - The call expression. - """ - return _make.Call( - dtype, func_name, convert(args), _Call.Extern, None, 0) - - -def call_llvm_intrin(dtype, name, *args): - """Build expression by calling an llvm intrinsic function - - Parameters - ---------- - dtype : str - The data type of the result. - - name : str - The name of the llvm intrinsic function. - - args : list - Poistional arguments. - - Returns - ------- - call : Expr - The call expression. - """ - llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(name) - assert llvm_id != 0, "%s is not an LLVM intrinsic" % name - return call_pure_intrin(dtype, 'llvm_intrin', tvm.const(llvm_id, 'uint32'), *args) - - -def exp(x): - """Take exponetial of input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "exp", x) - - -def erf(x): - """Take gauss error function of the input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "erf", x) - - -def tanh(x): - """Take hyperbolic tanh of input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "tanh", x) - - -def sigmoid(x): - """Quick function to get sigmoid - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "sigmoid", x) - - -def log(x): - """Take log of input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "log", x) - -def cos(x): - """Take cos of input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "cos", x) - -def sin(x): - """Take sin of input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "sin", x) - -def atan(x): - """Take atan of input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "atan", x) - -def sqrt(x): - """Take square root of input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "sqrt", x) - - -def rsqrt(x): - """Take reciprocal of square root of input x. - - Parameters - ---------- - x : Expr - Input argument. 
- - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "rsqrt", x) - - -def floor(x): - """Take floor of float input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return _make.floor(x) - - -def ceil(x): - """Take ceil of float input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return _make.ceil(x) - - -def trunc(x): - """Get truncated value of the input. - - The truncated value of the scalar x is the - nearest integer i which is closer to zero than x is. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return _make.trunc(x) - - -def abs(x): - """Get absolute value of the input element-wise. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return _make.abs(x) - - -def round(x): - """Round elements of the array to the nearest integer. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return _make.round(x) - - -def nearbyint(x): - """Round elements of the array to the nearest integer. - This intrinsic uses llvm.nearbyint instead of llvm.round - which is faster but will results different from tvm.round. - Notably nearbyint rounds according to the rounding mode, - whereas tvm.round (llvm.round) ignores that. - For differences between the two see: - https://en.cppreference.com/w/cpp/numeric/math/round - https://en.cppreference.com/w/cpp/numeric/math/nearbyint - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return _make.nearbyint(x) - - -def isnan(x): - """Check if input value is Nan. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return _make.isnan(x) - - -def power(x, y): - """x power y - - Parameters - ---------- - x : Expr - Input argument. - - y : Expr - The exponent - - Returns - ------- - z : Expr - The result. - """ - return _make._OpPow(convert(x), convert(y)) - - -def popcount(x): - """Count the number of set bits in input x. - - Parameters - ---------- - x : Expr - Input argument. - - Returns - ------- - y : Expr - The result. - """ - return call_pure_intrin(x.dtype, "popcount", x) - -def fmod(x, y): - """Return the remainder of x divided by y with the same sign as x. - - Parameters - ---------- - x : Expr - Input argument. - y : Expr - Input argument. - - Returns - ------- - z : Expr - The result. - """ - return call_pure_intrin(x.dtype, "fmod", x, y) - - -def if_then_else(cond, t, f): - """Conditional selection expression. - - Parameters - ---------- - cond : Expr - The condition - - t : Expr - The result expression if cond is true. - - f : Expr - The result expression if cond is false. - - Returns - ------- - result : Node - The result of conditional expression. - - Note - ---- - Unlike Select, if_then_else will not execute - the branch that does not satisfy the condition. - You can use it to guard against out of bound access. - Unlike Select, if_then_else cannot be vectorized - if some lanes in the vector have different conditions. - """ - return _make._OpIfThenElse(convert(cond), convert(t), convert(f)) - - -# Intrinsic rule related code -def register_intrin_rule(target, intrin, f=None, override=False): - """Register an intrinsic function generation rule. 
- - Intrinsic generation rules are callback functions for - code generator to get device specific calls. - This function simply translates to. - - :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)` - - TVM may already pre-register intrinsic rules in the backend. - However, user can use this function to change the intrinsic translation - behavior or add new intrinsic rules during runtime. - - Parameters - ---------- - target : str - The name of codegen target. - - intrin : str - The name of the intrinsic. - - f : function, optional - The function to be registered. - - override: boolean optional - Whether override existing entry. - - Returns - ------- - fregister : function - Register function if f is not specified. - - Examples - -------- - The following code registers exp expansion rule for opencl. - - .. code-block:: python - - register_intrin_rule("opencl", "exp", my_exp_rule, override=True) - """ - return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override) - - -def _rule_float_suffix(op): - """Intrinsic rule: Add float suffix if it is float32. - - This is an example intrinsic generation rule. - - Parameters - ---------- - op : Expr - The call expression of original intrinsic. - - Returns - ------- - ret : Expr - The translated intrinsic rule. - Return same op if no translation is possible. - - See Also - -------- - register_intrin_rule : The registeration function for intrin rule. - """ - if op.dtype == "float32": - return call_pure_extern(op.dtype, "%sf" % op.name, *op.args) - if op.dtype == "float64": - return call_pure_extern(op.dtype, op.name, *op.args) - return op - - -def _rule_float_direct(op): - """Intrinsic rule: Directly call pure extern function for floats. - - This is an example intrinsic generation rule. - - Parameters - ---------- - op : Expr - The call expression of original intrinsic. - - Returns - ------- - ret : Expr - The translated intrinsic rule. - Return same op if no translation is possible. - - See Also - -------- - register_intrin_rule : The registeration function for intrin rule. - """ - if str(op.dtype).startswith("float"): - return call_pure_extern(op.dtype, op.name, *op.args) - return None - -@tvm._ffi.register_func("tvm.default_trace_action") -def _tvm_default_trace_action(*args): - print(list(args)) - -def trace(args, trace_action="tvm.default_trace_action"): - """Trace tensor data at the runtime. - - The trace function allows to trace specific tensor at the - runtime. The tracing value should come as last argument. - The trace action should be specified, by default - tvm.default_trace_action is used. - - Parameters - ---------- - args : list of Expr or Buffers. - Positional arguments. - - trace_action : str. - The name of the trace action. - - Returns - ------- - call : Expr - The call expression. - - See Also - -------- - tvm.call_packed : Creates packed function. 
- """ - if not isinstance(args, list): - raise Exception("tvm.trace consumes the args as list type") - call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args] - call_args.insert(0, trace_action) - return _make.Call( - args[-1].dtype, "tvm_call_trace_packed", call_args, _Call.Intrinsic, None, 0) - -# opencl pattern for exp -register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) -# default pattern for exp -register_intrin_rule("default", "exp", _rule_float_suffix, override=True) +# pylint:disable=unused-wildcard-import, wildcard-import, redefined-builtin +"""Backwared compatible layer for intrin.""" +from .tir.op import * diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index d47e24054d02..f122956de9c0 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -24,7 +24,7 @@ from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range from .adt import Constructor, TypeData from .module import IRModule -from .attrs import Attrs +from .attrs import Attrs, make_node from .container import Array, Map from . import transform diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index a5967394e72e..f30a18f6aee2 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -18,6 +18,7 @@ import tvm._ffi from tvm.runtime import Object +import tvm.runtime._ffi_node_api from . import _ffi_api @@ -91,3 +92,40 @@ def get_str(self, key): def __getitem__(self, item): return self.__getattr__(item) + +def make_node(type_key, **kwargs): + """Make a new IR node by its type key and fields + + Parameters + ---------- + type_key : str + The type key of the node. + + **kwargs : dict + The fields of the node. + + Returns + ------- + node : Node + The corresponding IR Node + + Note + ---- + If the created node is instance of AttrsNode, then + the creator function will also run bound checks and + default value setup as supported by Attrs. + + Example + ------- + The following code constructs a IntImm object + + .. code-block:: python + + x = tvm.ir.make_node("IntImm", dtype="int32", value=10) + assert isinstance(x, tvm.tir.IntImm) + assert x.value == 10 + """ + args = [type_key] + for k, v in kwargs.items(): + args += [k, v] + return tvm.runtime._ffi_node_api.MakeNode(*args) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 3314ef130e25..07ed8e8f8de1 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -53,7 +53,7 @@ def astext(self, show_meta_data=True, annotate=None): return _ffi_api.AsText(self, show_meta_data, annotate) def __str__(self): - return self.astext(show_meta_data=False) + return _ffi_api.PrettyPrint(self) @tvm._ffi.register_object("relay.SourceName") diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 46acd16b8031..d29e73a9b10e 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -99,3 +99,23 @@ class Range(Node): You do not need to create a Range explicitly. Python lists and tuples will be converted automatically to a Range in API functions. """ + @staticmethod + def make_by_min_extent(min_value, extent): + """Construct a Range by min and extent. + + This constructs a range in [min_value, min_value + extent) + + Parameters + ---------- + min_value : PrimExpr + The minimum value of the range. + + extent : PrimExpr + The extent of the range. + + Returns + ------- + rng : Range + The constructed range. 
+ """ + return _ffi_api.range_by_min_extent(min_value, extent) diff --git a/python/tvm/make.py b/python/tvm/make.py index 7f94d1031d9a..089c3938723b 100644 --- a/python/tvm/make.py +++ b/python/tvm/make.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-import """namespace of IR node builder make function This namespace is used for developers. While you do not see any declarations. @@ -23,19 +24,22 @@ You can use make function to build the IR node. """ import tvm._ffi +import tvm.ir +from tvm.ir import make_node as node +from tvm.tir import Call -def range_by_min_extent(min_value, extent): +def make_by_min_extent(min_value, extent): """Construct a Range by min and extent. This constructs a range in [min_value, min_value + extent) Parameters ---------- - min_value : Expr + min_value : PrimExpr The minimum value of the range. - extent : Expr + extent : PrimExpr The extent of the range. Returns @@ -43,45 +47,6 @@ def range_by_min_extent(min_value, extent): rng : Range The constructed range. """ - return _range_by_min_extent(min_value, extent) - - -def node(type_key, **kwargs): - """Make a new DSL node by its type key and fields - - Parameters - ---------- - type_key : str - The type key of the node. - - **kwargs : dict - The fields of the node. - - Returns - ------- - node : Node - The corresponding DSL Node - - Note - ---- - If the created node is instance of AttrsNode, then - the creator function will also run bound checks and - default value setup as supported by Attrs. - - Example - ------- - The following code constructs a IntImm object - - .. code-block:: python - - x = tvm.make.node("IntImm", dtype="int32", value=10) - assert isinstance(x, tvm.expr.IntImm) - assert x.value == 10 - """ - args = [type_key] - for k, v in kwargs.items(): - args += [k, v] - return _Node(*args) - + return tvm.ir.Range.make_by_min_extent(min_value, extent) tvm._ffi._init_api("tvm.make") diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index c1a413098e84..78ab8ff63ec3 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -509,7 +509,7 @@ def mk_func( _, type_params = zip(*type_params) self.exit_var_scope() - attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None + attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None return expr.Function(var_list, body, ret_type, type_params, attrs) @spanify diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 68b2b1c97c03..6929be0726f2 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -181,11 +181,11 @@ def _update_target(self, target): raise ValueError("Target is not set in env or passed as argument.") tgts = {} if isinstance(target, (str, tvm.target.Target)): - dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type) + dev_type = tvm.tir.IntImm("int32", tvm.nd.context(str(target)).device_type) tgts[dev_type] = tvm.target.create(target) elif isinstance(target, dict): for dev, tgt in target.items(): - dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type) + dev_type = tvm.tir.IntImm("int32", tvm.nd.context(dev).device_type) tgts[dev_type] = tvm.target.create(tgt) else: raise TypeError("target is expected to be str, tvm.target.Target, " + diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 
ac2ea9d0b1bb..f920682de2c7 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -932,7 +932,7 @@ def _shape(): def _impl(inputs, attr, params): is_symbolic_shape = False for axis in attr['_input_shapes'][inputs[0]]: - if not isinstance(axis, (int, tvm.expr.IntImm)): + if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True break diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 8a4cb2f632c6..e6053b887d38 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -557,7 +557,7 @@ def split_shape_func(attrs, inputs, _): """ Shape function for split op. """ - if isinstance(attrs.indices_or_sections, (int, tvm.expr.IntImm)): + if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)): indices_or_sections = get_const_int(attrs.indices_or_sections) else: indices_or_sections = get_const_tuple(attrs.indices_or_sections) diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 650bf9d1aab1..cd48d495a74a 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -14,126 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-import """The computation schedule api of TVM.""" import tvm._ffi - from tvm._ffi.base import string_types + from tvm.runtime import Object, convert from tvm.ir import container as _container +from tvm.tir import expr as _expr, Buffer from . import _api_internal from . import tensor as _tensor -from . import expr as _expr - - -@tvm._ffi.register_object -class Buffer(Object): - """Symbolic data buffer in TVM. - - Buffer provide a way to represent data layout - specialization of data structure in TVM. - - Do not construct directly, use :any:`decl_buffer` instead. - See the documentation of :any:`decl_buffer` for more details. - - See Also - -------- - decl_buffer : Declare a buffer - """ - READ = 1 - WRITE = 2 - - def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): - """Get an access pointer to the head of buffer. - - This is the recommended method to get buffer data - ptress when interacting with external functions. - - Parameters - ---------- - access_mask : int - The access pattern MASK. Indicate whether the - access will read or write to the data content. - - ptr_type : str, optional - The data type of the result pointer. Do not specify - unless we want to cast pointer to specific type. - - content_lanes: int, optional - The number of lanes for the data type. This value - is greater than one for vector types. - - offset: Expr, optional - The offset of pointer. We can use it to offset by - the number of elements from the address of ptr. - - Examples - -------- - .. 
code-block:: python - - import tvm.schedule.Buffer - # Get access ptr for read - buffer.access_ptr("r") - # Get access ptr for read/write with bitmask - buffer.access_ptr(Buffer.READ | Buffer.WRITE) - # Get access ptr for read/write with str flag - buffer.access_ptr("rw") - # Get access ptr for read with offset - buffer.access_ptr("r", offset = 100) - """ - if isinstance(access_mask, string_types): - mask = 0 - for value in access_mask: - if value == "r": - mask = mask | Buffer.READ - elif value == "w": - mask = mask | Buffer.WRITE - else: - raise ValueError("Unknown access_mask %s" % access_mask) - access_mask = mask - offset = convert(offset) - return _api_internal._BufferAccessPtr(self, access_mask, ptr_type, - content_lanes, offset) - - def vload(self, begin, dtype=None): - """Generate an Expr that loads dtype from begin index. - - Parameters - ---------- - begin : Array of Expr - The beginning index in unit of Buffer.dtype - - dtype : str - The data type to be loaded, - can be vector type which have lanes that is multiple of Buffer.dtype - - Returns - ------- - load : Expr - The corresponding load expression. - """ - begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin - dtype = dtype if dtype else self.dtype - return _api_internal._BufferVLoad(self, begin, dtype) - - def vstore(self, begin, value): - """Generate a Stmt that store value into begin index. - - Parameters - ---------- - begin : Array of Expr - The beginning index in unit of Buffer.dtype - - value : Expr - The value to be stored. - - Returns - ------- - store : Stmt - The corresponding store stmt. - """ - begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin - return _api_internal._BufferVStore(self, begin, value) @tvm._ffi.register_object diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index abe8436a55ba..287649670fb0 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -60,3 +60,4 @@ from .generic_func import generic_func, get_native_generic_func, override_native_generic_func from . import datatype from . import codegen +from .intrin import register_intrin_rule diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index a9506b3339cb..328568a360bc 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -19,7 +19,7 @@ import tvm.runtime._ffi_api from tvm.runtime import convert, DataType -from tvm.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm +from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm def register(type_name, type_code): diff --git a/python/tvm/target/intrin.py b/python/tvm/target/intrin.py new file mode 100644 index 000000000000..acb0efe0ea64 --- /dev/null +++ b/python/tvm/target/intrin.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Target dependent intrinsic registration.""" +import tvm._ffi +from tvm.tir import call_pure_extern + + +# Intrinsic rule related code +def register_intrin_rule(target, intrin, f=None, override=False): + """Register an intrinsic function generation rule. + + Intrinsic generation rules are callback functions for + the code generator to get device specific calls. + This function simply translates to + + :code:`register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override)` + + TVM may already pre-register intrinsic rules in the backend. + However, users can use this function to change the intrinsic translation + behavior or add new intrinsic rules during runtime. + + Parameters + ---------- + target : str + The name of codegen target. + + intrin : str + The name of the intrinsic. + + f : function, optional + The function to be registered. + + override : bool, optional + Whether to override the existing entry. + + Returns + ------- + fregister : function + The register function if f is not specified. + + Examples + -------- + The following code registers an exp expansion rule for opencl. + + .. code-block:: python + + register_intrin_rule("opencl", "exp", my_exp_rule, override=True) + """ + return tvm._ffi.register_func("tvm.intrin.rule.%s.%s" % (target, intrin), f, override) + + +def _rule_float_suffix(op): + """Intrinsic rule: Add float suffix if it is float32. + + This is an example intrinsic generation rule. + + Parameters + ---------- + op : PrimExpr + The call expression of the original intrinsic. + + Returns + ------- + ret : PrimExpr + The translated intrinsic rule. + Returns the same op if no translation is possible. + + See Also + -------- + register_intrin_rule : The registration function for intrin rules. + """ + if op.dtype == "float32": + return call_pure_extern(op.dtype, "%sf" % op.name, *op.args) + if op.dtype == "float64": + return call_pure_extern(op.dtype, op.name, *op.args) + return op + + +def _rule_float_direct(op): + """Intrinsic rule: Directly call pure extern function for floats. + + This is an example intrinsic generation rule. + + Parameters + ---------- + op : PrimExpr + The call expression of the original intrinsic. + + Returns + ------- + ret : PrimExpr + The translated intrinsic rule. + Returns None if no translation is possible. + + See Also + -------- + register_intrin_rule : The registration function for intrin rules. + """ + if str(op.dtype).startswith("float"): + return call_pure_extern(op.dtype, op.name, *op.args) + return None + +# opencl pattern for exp +register_intrin_rule("opencl", "exp", _rule_float_direct, override=True) +# default pattern for exp +register_intrin_rule("default", "exp", _rule_float_suffix, override=True) diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index bd25f845abb1..00bd9d146b36 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -19,10 +19,9 @@ import tvm._ffi from tvm.runtime import Object, ObjectGeneric, convert_to_object +from tvm.tir import expr as _expr from . import _api_internal -from . import make as _make -from .
import expr as _expr class TensorSlice(ObjectGeneric, _expr.ExprOp): @@ -74,7 +73,7 @@ def __call__(self, *indices): else: raise ValueError("The indices must be expression") - return _make.Call(self.dtype, self.op.name, + return _expr.Call(self.dtype, self.op.name, args, _expr.Call.Halide, self.op, self.value_index) @@ -207,136 +206,3 @@ class HybridOp(Operation): def axis(self): """Represent the IterVar axis, also defined when it is a HybridOp""" return self.__getattr__("axis") - - -@tvm._ffi.register_object -class Layout(Object): - """Layout is composed of upper cases, lower cases and numbers, - where upper case indicates a primal axis and - the corresponding lower case with factor size indicates the subordinate axis. - For example, NCHW16c can describe a 5-D tensor of - [batch_size, channel, height, width, channel_block]. - Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). - - Do not construct directly, use :any:`layout` instead. - See the documentation of :any:`layout` for more details. - - See Also - -------- - layout : Declare a layout - """ - def __len__(self): - return _api_internal._LayoutNdim(self) - - def __contains__(self, axis): - return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name - - def __getitem__(self, index): - if index >= len(self): - raise IndexError("Layout index out of range") - return _api_internal._LayoutGetItem(self, index) - - def index_of(self, axis): - """Get the index of an axis - - Parameters - ---------- - axis : str - The axis name, need to be [a-z,A-Z] - - Returns - ------- - index : int - The index of the axis, -1 if not found. - """ - return _api_internal._LayoutIndexOf(self, axis) - - def factor_of(self, axis): - """Get the factor size of the subordinate axis. - - Parameters - ---------- - axis : str - The axis name, need to be [a-z,A-Z] - - Returns - ------- - factor : int - the size of the subordinate-axis of axis (if axis is a primal-axis), - or the size of axis itself (if axis is a subordinate-axis). - Return -1 if axis is not in the layout. - """ - return _api_internal._LayoutFactorOf(self, axis) - - -@tvm._ffi.register_object -class BijectiveLayout(Object): - """Bijective mapping for two layouts (src-layout and dst-layout). - It provides shape and index conversion between each other. - - Do not construct directly, use :any:`bijective_layout` instead. - See the documentation of :any:`bijective_layout` for more details. - - See Also - -------- - bijective_layout : Declare a bijective layout converter - """ - def forward_index(self, index): - """Given the indices of the src-layout, infer the dst index. - - Parameters - ---------- - index: Array of Expr - The indices in src-layout. - - Returns - ------- - dst_index: Array of Expr - The inferred indices in dst-layout. - """ - return _api_internal._BijectiveLayoutForwardIndex(self, index) - - def backward_index(self, index): - """Given the indices of the dst-layout, infer the src index. - - Parameters - ---------- - index: Array of Expr - The indices in dst-layout. - - Returns - ------- - src_index: Array of Expr - The inferred indices in src-layout. - """ - return _api_internal._BijectiveLayoutBackwardIndex(self, index) - - def forward_shape(self, shape): - """Given the shape of the src-layout, infer the dst shape. - - Parameters - ---------- - shape: Array of Expr - The shape in src-layout. - - Returns - ------- - dst_shape: Array of Expr - The inferred shape in dst-layout. 
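The Layout and BijectiveLayout classes removed here reappear under python/tvm/tir/data_layout.py later in this diff, and downstream call sites switch spellings the same way the relay and frontend hunks above do. A minimal sketch of the new spelling, assuming this refactor is applied:

.. code-block:: python

    import tvm
    # tvm.expr is kept only as a backward-compat alias for topi;
    # tvm.tir.IntImm is the canonical class after this refactor.
    axis = tvm.tir.IntImm("int32", 4)
    assert isinstance(axis, (int, tvm.tir.IntImm))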
- """ - return _api_internal._BijectiveLayoutForwardShape(self, shape) - - def backward_shape(self, shape): - """Given the shape of the dst-layout, infer the src shape. - - Parameters - ---------- - shape: Array of Expr - The shape in dst-layout. - - Returns - ------- - src_shape: Array of Expr - The inferred shape in src-layout. - """ - return _api_internal._BijectiveLayoutBackwardShape(self, shape) diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 2a01fe773f94..1fd8bee720ba 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -18,11 +18,12 @@ import tvm._ffi from tvm.runtime import Object +from tvm.ir import Range +from tvm.tir import expr as _expr +from tvm.tir import stmt as _stmt + from . import _api_internal from . import api as _api -from . import expr as _expr -from . import stmt as _stmt -from . import make as _make from . import tensor as _tensor from . import schedule as _schedule from .build_module import current_build_config @@ -39,7 +40,7 @@ def _get_region(tslice): begin = idx.var else: begin = idx - region.append(_make.range_by_min_extent(begin, 1)) + region.append(Range.make_by_min_extent(begin, 1)) return region @tvm._ffi.register_object @@ -136,7 +137,7 @@ def decl_tensor_intrin(op, scalar_params = [] if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)): body = [body] - body = [_make.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body] + body = [_stmt.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body] if len(body) < 3: body += [None] * (3 - len(body)) return _api_internal._TensorIntrin( diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py new file mode 100644 index 000000000000..8621540f7223 --- /dev/null +++ b/python/tvm/tir/__init__.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-import, redefined-builtin +"""Namespace for Tensor-level IR""" +from tvm.ir import PrimExpr +from .buffer import Buffer, decl_buffer +from .data_layout import Layout, BijectiveLayout, bijective_layout, layout +from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast +from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod +from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not +from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let + +from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For +from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt +from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list + +from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern +from .op import call_llvm_intrin, min_value, max_value +from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil +from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else +from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod + +from . import ir_builder +from . import ir_pass diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tir/_ffi_api.py new file mode 100644 index 000000000000..1b60b8c81c6d --- /dev/null +++ b/python/tvm/tir/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.tir""" +import tvm._ffi + + +tvm._ffi._init_api("tir", __name__) diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py new file mode 100644 index 000000000000..d0d01d7479be --- /dev/null +++ b/python/tvm/tir/buffer.py @@ -0,0 +1,247 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Abstraction for array data structures.""" +from numbers import Integral +import tvm._ffi + +from tvm._ffi.base import string_types +from tvm.runtime import Object, convert +from tvm.ir import PrimExpr +from . 
import _ffi_api + + +@tvm._ffi.register_object +class Buffer(Object): + """Symbolic data buffer in TVM. + + Buffer provides a way to represent data layout + specialization of data structure in TVM. + + Do not construct directly, use :py:func:`~decl_buffer` instead. + See the documentation of :py:func:`decl_buffer` for more details. + + See Also + -------- + decl_buffer : Declare a buffer + """ + READ = 1 + WRITE = 2 + + def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0): + """Get an access pointer to the head of buffer. + + This is the recommended method to get the buffer data + address when interacting with external functions. + + Parameters + ---------- + access_mask : int + The access pattern MASK. Indicates whether the + access will read or write to the data content. + + ptr_type : str, optional + The data type of the result pointer. Do not specify + unless we want to cast the pointer to a specific type. + + content_lanes: int, optional + The number of lanes for the data type. This value + is greater than one for vector types. + + offset: Expr, optional + The offset of the pointer. We can use it to offset by + the number of elements from the address of ptr. + + Examples + -------- + .. code-block:: python + + # Get access ptr for read + buffer.access_ptr("r") + # Get access ptr for read/write with bitmask + buffer.access_ptr(Buffer.READ | Buffer.WRITE) + # Get access ptr for read/write with str flag + buffer.access_ptr("rw") + # Get access ptr for read with offset + buffer.access_ptr("r", offset = 100) + """ + if isinstance(access_mask, string_types): + mask = 0 + for value in access_mask: + if value == "r": + mask = mask | Buffer.READ + elif value == "w": + mask = mask | Buffer.WRITE + else: + raise ValueError("Unknown access_mask %s" % access_mask) + access_mask = mask + offset = convert(offset) + return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type, + content_lanes, offset) + + def vload(self, begin, dtype=None): + """Generate an Expr that loads dtype from begin index. + + Parameters + ---------- + begin : Array of Expr + The beginning index in unit of Buffer.dtype + + dtype : str + The data type to be loaded; + can be a vector type whose lanes are a multiple of the lanes of Buffer.dtype + + Returns + ------- + load : Expr + The corresponding load expression. + """ + begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin + dtype = dtype if dtype else self.dtype + return _ffi_api.BufferVLoad(self, begin, dtype) + + def vstore(self, begin, value): + """Generate a Stmt that stores value into begin index. + + Parameters + ---------- + begin : Array of Expr + The beginning index in unit of Buffer.dtype + + value : Expr + The value to be stored. + + Returns + ------- + store : Stmt + The corresponding store stmt. + """ + begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin + return _ffi_api.BufferVStore(self, begin, value) + + +def decl_buffer(shape, + dtype=None, + name="buffer", + data=None, + strides=None, + elem_offset=None, + scope="", + data_alignment=-1, + offset_factor=0, + buffer_type=""): + """Declare a new symbolic buffer. + + Normally buffer is created automatically during lower and build. + This is only needed if users want to specify their own buffer layout. + + See the note below for detailed discussion on usage of buffer. + + Parameters + ---------- + shape : tuple of Expr + The shape of the buffer. + + dtype : str, optional + The data type of the buffer. + + name : str, optional + The name of the buffer.
+ + data : Var, optional + The data pointer in the buffer. + + strides: array of Expr + The stride of the buffer. + + elem_offset: Expr, optional + The beginning offset of the array to data. + In terms of number of elements of dtype. + + scope: str, optional + The storage scope of the buffer, if not global. + If scope equals empty string, it means it is global memory. + + data_alignment: int, optional + The alignment of data pointer in bytes. + If -1 is passed, the alignment will be set to TVM's internal default. + + offset_factor: int, optional + The factor of the elem_offset field; when set, + elem_offset is required to be a multiple of offset_factor. + If 0 is passed, the factor will be treated as 1. + If non-zero is passed, we will create a Var for elem_offset if elem_offset is None. + + buffer_type: str, optional, {"", "auto_broadcast"} + auto_broadcast buffer allows one to implement broadcast computation + without considering whether dimension size equals to one. + TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1. + + Returns + ------- + buffer : Buffer + The created buffer + + Example + ------- + Here's an example of how a broadcast buffer can be used to define a symbolic broadcast operation, + + .. code-block:: python + + m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") + n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") + o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") + A = tvm.placeholder((m0, m1, m2), name='A') + B = tvm.placeholder((n0, n1, n2), name='B') + C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') + Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast") + Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast") + s = tvm.create_schedule(C.op) + fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) + ctx = tvm.cpu(0) + a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) + fadd(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + + Note + ---- + Buffer data structure reflects the DLTensor structure in dlpack. + While the DLTensor data structure is very general, it is usually helpful + to create a function that only handles a specific case of data structure + and let the compiled function benefit from it. + + If the user passes strides and elem_offset as None + when constructing the function, then the function will be specialized + for the DLTensor that is compact and aligned. + If the user passes a fully generic symbolic array as the strides, + then the resulting function becomes fully generic.
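As a companion to the broadcast example above, a minimal sketch (buffer name hypothetical) of symbolic access through the vload/vstore methods defined earlier in this file:

.. code-block:: python

    import tvm
    A = tvm.tir.decl_buffer((16,), "float32", name="A")
    ld = A.vload([0], "float32x4")                 # 4-lane vectorized load
    st = A.vstore([4], tvm.const(1.0, "float32"))  # store statement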
+ """ + # pylint: disable=import-outside-toplevel + from .expr import Var + + shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape + dtype = "float32" if dtype is None else dtype + strides = () if strides is None else strides + if offset_factor != 0 and elem_offset is None: + shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" + elem_offset = Var('%s_elem_offset' % name, shape_dtype) + if data is None: + data = Var(name, "handle") + return _ffi_api.Buffer( + data, dtype, shape, strides, elem_offset, name, scope, + data_alignment, offset_factor, buffer_type) diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py new file mode 100644 index 000000000000..fd8c7a942297 --- /dev/null +++ b/python/tvm/tir/data_layout.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Data layout.""" +import tvm._ffi + +from tvm.runtime import Object +from . import _ffi_api + +@tvm._ffi.register_object +class Layout(Object): + """Layout is composed of upper cases, lower cases and numbers, + where upper case indicates a primal axis and + the corresponding lower case with factor size indicates the subordinate axis. + For example, NCHW16c can describe a 5-D tensor of + [batch_size, channel, height, width, channel_block]. + Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). + + See Also + -------- + layout : Declare a layout + """ + def __len__(self): + return _ffi_api.LayoutNdim(self) + + def __contains__(self, axis): + return len(axis) == 1 and axis[0].isalpha() and axis[0] in self.name + + def __getitem__(self, index): + if index >= len(self): + raise IndexError("Layout index out of range") + return _ffi_api.LayoutGetItem(self, index) + + def index_of(self, axis): + """Get the index of an axis + + Parameters + ---------- + axis : str + The axis name, need to be [a-z,A-Z] + + Returns + ------- + index : int + The index of the axis, -1 if not found. + """ + return _ffi_api.LayoutIndexOf(self, axis) + + def factor_of(self, axis): + """Get the factor size of the subordinate axis. + + Parameters + ---------- + axis : str + The axis name, need to be [a-z,A-Z] + + Returns + ------- + factor : int + the size of the subordinate-axis of axis (if axis is a primal-axis), + or the size of axis itself (if axis is a subordinate-axis). + Return -1 if axis is not in the layout. + """ + return _ffi_api.LayoutFactorOf(self, axis) + + +@tvm._ffi.register_object +class BijectiveLayout(Object): + """Bijective mapping for two layouts (src-layout and dst-layout). + It provides shape and index conversion between each other. + + Do not construct directly, use :any:`bijective_layout` instead. + See the documentation of :any:`bijective_layout` for more details. 
+ + Parameters + ---------- + src_layout : str or Layout + source layout. + + dst_layout : str or Layout + destination layout. + + See Also + -------- + bijective_layout : Declare a layout + """ + def forward_index(self, index): + """Given the indices of the src-layout, infer the dst index. + + Parameters + ---------- + index: Array of Expr + The indices in src-layout. + + Returns + ------- + dst_index: Array of Expr + The inferred indices in dst-layout. + """ + return _ffi_api.BijectiveLayoutForwardIndex(self, index) + + def backward_index(self, index): + """Given the indices of the dst-layout, infer the src index. + + Parameters + ---------- + index: Array of Expr + The indices in dst-layout. + + Returns + ------- + src_index: Array of Expr + The inferred indices in src-layout. + """ + return _ffi_api.BijectiveLayoutBackwardIndex(self, index) + + def forward_shape(self, shape): + """Given the shape of the src-layout, infer the dst shape. + + Parameters + ---------- + shape: Array of Expr + The shape in src-layout. + + Returns + ------- + dst_shape: Array of Expr + The inferred shape in dst-layout. + """ + return _ffi_api.BijectiveLayoutForwardShape(self, shape) + + def backward_shape(self, shape): + """Given the shape of the dst-layout, infer the src shape. + + Parameters + ---------- + shape: Array of Expr + The shape in dst-layout. + + Returns + ------- + src_shape: Array of Expr + The inferred shape in src-layout. + """ + return _ffi_api.BijectiveLayoutBackwardShape(self, shape) + + +def layout(layout_str): + """Create a layout node from a string. + + Parameters + ---------- + layout_str : str + A layout representation is composed of upper cases, lower cases and numbers, + where upper case indicates a primal axis and + the corresponding lower case with factor size indicates the subordinate axis. + For example, NCHW16c can describe a 5-D tensor of + [batch_size, channel, height, width, channel_block]. + Here subordinate axis channel_block=16 is the factor size of + the primal axis C (channel). + + Returns + ------- + layout : Layout + The created layout + """ + return _ffi_api.Layout(layout_str) + + +def bijective_layout(src_layout, dst_layout): + """Create a bijective layout mapping. + + Parameters + ---------- + src_layout : str or Layout + source layout. + + dst_layout : str or Layout + destination layout. + + Returns + ------- + bijective_layout : BijectiveLayout + The created bijective layout + """ + if isinstance(src_layout, str): + src_layout = layout(src_layout) + if isinstance(dst_layout, str): + dst_layout = layout(dst_layout) + return _ffi_api.BijectiveLayout(src_layout, dst_layout) diff --git a/python/tvm/expr.py b/python/tvm/tir/expr.py similarity index 77% rename from python/tvm/expr.py rename to python/tvm/tir/expr.py index 1ff069720ed7..92d6fbe42f17 100644 --- a/python/tvm/expr.py +++ b/python/tvm/tir/expr.py @@ -27,16 +27,16 @@ x = tvm.var("n") y = x + 2 - assert(isinstance(y, tvm.expr.Add)) + assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ -# pylint: disable=missing-docstring import tvm._ffi -from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const -from . import make as _make +from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const +from tvm.ir import PrimExpr +import tvm.ir._ffi_api from . import generic as _generic -from . import _api_internal +from . 
import _ffi_api def div_ambiguity_error(): @@ -45,6 +45,7 @@ def div_ambiguity_error(): "please call div, indexdiv/indexmod, floordiv/floormod " + " or truncdiv/truncmod directly to avoid ambiguity in the code.") + def _dtype_is_int(value): if isinstance(value, int): return True @@ -53,6 +54,7 @@ def _dtype_is_int(value): class ExprOp(object): + """Operator overloading for Expr like expressions.""" def __add__(self, other): return _generic.add(self, other) @@ -98,44 +100,44 @@ def __rfloordiv__(self, other): return _generic.floordiv(other, self) def __mod__(self, other): - return _make._OpFloorMod(self, other) + return _ffi_api._OpFloorMod(self, other) def __neg__(self): neg_one = const(-1, self.dtype) return self.__mul__(neg_one) def __lshift__(self, other): - return _make.left_shift(self, other) + return _ffi_api.left_shift(self, other) def __rshift__(self, other): - return _make.right_shift(self, other) + return _ffi_api.right_shift(self, other) def __and__(self, other): - return _make.bitwise_and(self, other) + return _ffi_api.bitwise_and(self, other) def __rand__(self, other): - return _make.bitwise_and(other, self) + return _ffi_api.bitwise_and(other, self) def __or__(self, other): - return _make.bitwise_or(self, other) + return _ffi_api.bitwise_or(self, other) def __ror__(self, other): - return _make.bitwise_or(other, self) + return _ffi_api.bitwise_or(other, self) def __xor__(self, other): - return _make.bitwise_xor(self, other) + return _ffi_api.bitwise_xor(self, other) def __rxor__(self, other): - return _make.bitwise_xor(other, self) + return _ffi_api.bitwise_xor(other, self) def __invert__(self): - return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) + return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) def __lt__(self, other): - return _make._OpLT(self, other) + return _ffi_api._OpLT(self, other) def __le__(self, other): - return _make._OpLE(self, other) + return _ffi_api._OpLE(self, other) def __eq__(self, other): return EqualOp(self, other) @@ -144,10 +146,10 @@ def __ne__(self, other): return NotEqualOp(self, other) def __gt__(self, other): - return _make._OpGT(self, other) + return _ffi_api._OpGT(self, other) def __ge__(self, other): - return _make._OpGE(self, other) + return _ffi_api._OpGE(self, other) def __nonzero__(self): raise ValueError("Cannot use and / or / not operator to Expr, hint: " + @@ -161,15 +163,15 @@ def equal(self, other): Parameters ---------- - other : Expr + other : PrimExpr The other expression Returns ------- - ret : Expr + ret : PrimExpr The equality expression. """ - return _make._OpEQ(self, other) + return _ffi_api._OpEQ(self, other) def astype(self, dtype): """Cast the expression to other type. @@ -181,7 +183,7 @@ def astype(self, dtype): Returns ------- - expr : Expr + expr : PrimExpr Expression with new type """ return _generic.cast(self, dtype) @@ -195,10 +197,10 @@ class EqualOp(ObjectGeneric, ExprOp): Parameters ---------- - a : Expr + a : PrimExpr Left operand. - b : Expr + b : PrimExpr Right operand. """ # This class is not manipulated by C++. So use python's identity check function is sufficient @@ -216,7 +218,7 @@ def __bool__(self): def asobject(self): """Convert object.""" - return _make._OpEQ(self.a, self.b) + return _ffi_api._OpEQ(self.a, self.b) class NotEqualOp(ObjectGeneric, ExprOp): @@ -227,10 +229,10 @@ class NotEqualOp(ObjectGeneric, ExprOp): Parameters ---------- - a : Expr + a : PrimExpr Left operand. - b : Expr + b : PrimExpr Right operand. 
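Taken together, the overloads above keep the user-facing behavior unchanged while routing through tir's _ffi_api. A minimal sketch, assuming this PR is applied:

.. code-block:: python

    import tvm
    x = tvm.var("n")
    y = x + 2                       # ExprOp builds an Add node
    assert isinstance(y, tvm.tir.Add)
    # plain / on int dtypes raises the ambiguity error above;
    # spell the intent explicitly instead:
    q = tvm.tir.indexdiv(x, 8)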
""" # This class is not manipulated by C++. So use python's identity check function is sufficient @@ -248,30 +250,30 @@ def __bool__(self): def asobject(self): """Convert object.""" - return _make._OpNE(self.a, self.b) + return _ffi_api._OpNE(self.a, self.b) -class PrimExpr(ExprOp, Object): - """Base class of all tvm Expressions""" +class PrimExprWithOp(ExprOp, PrimExpr): + """Helper base class to inherit from PrimExpr.""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ - __hash__ = Object.__hash__ + __hash__ = PrimExpr.__hash__ -class ConstExpr(PrimExpr): +class ConstExpr(PrimExprWithOp): pass -class BinaryOpExpr(PrimExpr): +class BinaryOpExpr(PrimExprWithOp): pass -class CmpExpr(PrimExpr): +class CmpExpr(PrimExprWithOp): pass -class LogicalExpr(PrimExpr): +class LogicalExpr(PrimExprWithOp): pass @tvm._ffi.register_object("Variable") -class Var(PrimExpr): +class Var(PrimExprWithOp): """Symbolic variable. Parameters @@ -279,18 +281,18 @@ class Var(PrimExpr): name : str The name - dtype : int + dtype : str The data type """ def __init__(self, name, dtype): self.__init_handle_by_constructor__( - _api_internal._Var, name, dtype) + _ffi_api.Var, name, dtype) @tvm._ffi.register_object class SizeVar(Var): """Symbolic variable to represent a tensor index size - which is greater or equal to zero + which is greater or equal to zero. Parameters ---------- @@ -303,11 +305,34 @@ class SizeVar(Var): # pylint: disable=super-init-not-called def __init__(self, name, dtype): self.__init_handle_by_constructor__( - _api_internal._SizeVar, name, dtype) + _ffi_api.SizeVar, name, dtype) + + +@tvm._ffi.register_object +class CommReducer(Object): + """Communicative reduce operator + + Parameters + ---------- + lhs : List[Var] + The left arguments of the reducer. + + rhs : List[Var] + The right arguments of the reducer. + + result : List[PrimExpr] + The reduction results. + + identity_element : List[PrimExpr] + The identity elements. + """ + def __init__(self, lhs, rhs, result, identity_element): + self.__init_handle_by_constructor__( + _ffi_api.CommReducer, lhs, rhs, result, identity_element) @tvm._ffi.register_object -class Reduce(PrimExpr): +class Reduce(PrimExprWithOp): """Reduce node. Parameters @@ -321,7 +346,7 @@ class Reduce(PrimExpr): rdom : list of IterVar The iteration domain - condition : Expr + condition : PrimExpr The reduce condition. 
value_index : int @@ -329,7 +354,7 @@ class Reduce(PrimExpr): """ def __init__(self, combiner, src, rdom, condition, value_index): self.__init_handle_by_constructor__( - _make.Reduce, combiner, src, rdom, + _ffi_api.Reduce, combiner, src, rdom, condition, value_index) @@ -347,7 +372,7 @@ class FloatImm(ConstExpr): """ def __init__(self, dtype, value): self.__init_handle_by_constructor__( - _make.FloatImm, dtype, value) + tvm.ir._ffi_api.FloatImm, dtype, value) @tvm._ffi.register_object class IntImm(ConstExpr): @@ -363,7 +388,7 @@ class IntImm(ConstExpr): """ def __init__(self, dtype, value): self.__init_handle_by_constructor__( - _make.IntImm, dtype, value) + tvm.ir._ffi_api.IntImm, dtype, value) def __int__(self): return self.value @@ -380,7 +405,7 @@ class StringImm(ConstExpr): """ def __init__(self, value): self.__init_handle_by_constructor__( - _make.StringImm, value) + _ffi_api.StringImm, value) def __eq__(self, other): if isinstance(other, ConstExpr): @@ -394,7 +419,7 @@ def __ne__(self, other): @tvm._ffi.register_object -class Cast(PrimExpr): +class Cast(PrimExprWithOp): """Cast expression. Parameters @@ -402,12 +427,12 @@ class Cast(PrimExpr): dtype : str The data type - value : Expr + value : PrimExpr The value of the function. """ def __init__(self, dtype, value): self.__init_handle_by_constructor__( - _make.Cast, dtype, value) + _ffi_api.Cast, dtype, value) @tvm._ffi.register_object @@ -416,15 +441,15 @@ class Add(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Add, a, b) + _ffi_api.Add, a, b) @tvm._ffi.register_object @@ -433,15 +458,15 @@ class Sub(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Sub, a, b) + _ffi_api.Sub, a, b) @tvm._ffi.register_object @@ -450,15 +475,15 @@ class Mul(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Mul, a, b) + _ffi_api.Mul, a, b) @tvm._ffi.register_object @@ -467,15 +492,15 @@ class Div(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Div, a, b) + _ffi_api.Div, a, b) @tvm._ffi.register_object @@ -484,15 +509,15 @@ class Mod(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Mod, a, b) + _ffi_api.Mod, a, b) @tvm._ffi.register_object @@ -501,15 +526,15 @@ class FloorDiv(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.FloorDiv, a, b) + _ffi_api.FloorDiv, a, b) @tvm._ffi.register_object @@ -518,15 +543,15 @@ class FloorMod(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. 
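Since the FloorDiv/FloorMod pair differs from the trunc family on negative operands, a short sketch of the distinction (semantics only; constant folding at construction time is not assumed):

.. code-block:: python

    import tvm
    a = tvm.const(-7, "int32")
    b = tvm.const(4, "int32")
    fd = tvm.tir.floordiv(a, b)   # floor semantics: -7 // 4 == -2
    td = tvm.tir.truncdiv(a, b)   # C semantics, rounds toward zero: -1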
""" def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.FloorMod, a, b) + _ffi_api.FloorMod, a, b) @tvm._ffi.register_object @@ -535,15 +560,15 @@ class Min(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Min, a, b) + _ffi_api.Min, a, b) @tvm._ffi.register_object @@ -552,15 +577,15 @@ class Max(BinaryOpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Max, a, b) + _ffi_api.Max, a, b) @tvm._ffi.register_object @@ -569,15 +594,15 @@ class EQ(CmpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.EQ, a, b) + _ffi_api.EQ, a, b) @tvm._ffi.register_object @@ -586,15 +611,15 @@ class NE(CmpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.NE, a, b) + _ffi_api.NE, a, b) @tvm._ffi.register_object @@ -603,15 +628,15 @@ class LT(CmpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.LT, a, b) + _ffi_api.LT, a, b) @tvm._ffi.register_object @@ -620,15 +645,15 @@ class LE(CmpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.LE, a, b) + _ffi_api.LE, a, b) @tvm._ffi.register_object @@ -637,15 +662,15 @@ class GT(CmpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.GT, a, b) + _ffi_api.GT, a, b) @tvm._ffi.register_object @@ -654,15 +679,15 @@ class GE(CmpExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.GE, a, b) + _ffi_api.GE, a, b) @tvm._ffi.register_object @@ -671,15 +696,15 @@ class And(LogicalExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.And, a, b) + _ffi_api.And, a, b) @tvm._ffi.register_object @@ -688,15 +713,15 @@ class Or(LogicalExpr): Parameters ---------- - a : Expr + a : PrimExpr The left hand operand. - b : Expr + b : PrimExpr The right hand operand. """ def __init__(self, a, b): self.__init_handle_by_constructor__( - _make.Or, a, b) + _ffi_api.Or, a, b) @tvm._ffi.register_object @@ -705,16 +730,16 @@ class Not(LogicalExpr): Parameters ---------- - a : Expr + a : PrimExpr The input value """ def __init__(self, a): self.__init_handle_by_constructor__( - _make.Not, a) + _ffi_api.Not, a) @tvm._ffi.register_object -class Select(PrimExpr): +class Select(PrimExprWithOp): """Select node. 
Note @@ -726,23 +751,23 @@ class Select(PrimExpr): Parameters ---------- - condition : Expr + condition : PrimExpr The condition expression. - true_value : Expr + true_value : PrimExpr The value to take when condition is true. - false_value : Expr + false_value : PrimExpr The value to take when condition is false. """ def __init__(self, condition, true_value, false_value): self.__init_handle_by_constructor__( - _make.Select, condition, true_value, false_value) + _ffi_api.Select, condition, true_value, false_value) @tvm._ffi.register_object -class Load(PrimExpr): +class Load(PrimExprWithOp): """Load node. Parameters @@ -753,24 +778,25 @@ class Load(PrimExpr): buffer_var : Var The buffer variable in the load expression. - index : Expr + index : PrimExpr The index in the load. - predicate : Expr + predicate : PrimExpr The load predicate. """ - def __init__(self, dtype, buffer_var, index, predicate): + def __init__(self, dtype, buffer_var, index, predicate=None): + args = [] if predicate is None else [predicate] self.__init_handle_by_constructor__( - _make.Load, dtype, buffer_var, index, predicate) + _ffi_api.Load, dtype, buffer_var, index, *args) @tvm._ffi.register_object -class Ramp(PrimExpr): +class Ramp(PrimExprWithOp): """Ramp node. Parameters ---------- - base : Expr + base : PrimExpr The base expression. stride : ramp stride @@ -781,16 +807,16 @@ class Ramp(PrimExpr): """ def __init__(self, base, stride, lanes): self.__init_handle_by_constructor__( - _make.Ramp, base, stride, lanes) + _ffi_api.Ramp, base, stride, lanes) @tvm._ffi.register_object -class Broadcast(PrimExpr): +class Broadcast(PrimExprWithOp): """Broadcast node. Parameters ---------- - value : Expr + value : PrimExpr The value of the expression. lanes : int @@ -798,11 +824,11 @@ class Broadcast(PrimExpr): """ def __init__(self, value, lanes): self.__init_handle_by_constructor__( - _make.Broadcast, value, lanes) + _ffi_api.Broadcast, value, lanes) @tvm._ffi.register_object -class Shuffle(PrimExpr): +class Shuffle(PrimExprWithOp): """Shuffle node. Parameters @@ -815,11 +841,11 @@ class Shuffle(PrimExpr): """ def __init__(self, vectors, indices): self.__init_handle_by_constructor__( - _make.Shuffle, vectors, indices) + _ffi_api.Shuffle, vectors, indices) @tvm._ffi.register_object -class Call(PrimExpr): +class Call(PrimExprWithOp): """Call node. Parameters @@ -850,11 +876,11 @@ class Call(PrimExpr): PureIntrinsic = 5 def __init__(self, dtype, name, args, call_type, func, value_index): self.__init_handle_by_constructor__( - _make.Call, dtype, name, args, call_type, func, value_index) + _ffi_api.Call, dtype, name, args, call_type, func, value_index) @tvm._ffi.register_object -class Let(PrimExpr): +class Let(PrimExprWithOp): """Let node. Parameters @@ -862,12 +888,12 @@ class Let(PrimExpr): var : Var The variable in the binding. - value : Expr + value : PrimExpr The value in to be binded. - body : Expr + body : PrimExpr The body expression. """ def __init__(self, var, value, body): self.__init_handle_by_constructor__( - _make.Let, var, value, body) + _ffi_api.Let, var, value, body) diff --git a/python/tvm/tir/generic.py b/python/tvm/tir/generic.py new file mode 100644 index 000000000000..8a9cf8eeb50d --- /dev/null +++ b/python/tvm/tir/generic.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Generic operators in TVM. +We follow the numpy naming convention for this interface +(e.g., tvm.generic.multiply ~ numpy.multiply). +The default implementation is used by tvm.ExprOp. +""" +# pylint: disable=unused-argument +from . import _ffi_api + +# Operator precedence used when overloading. +__op_priority__ = 0 + + +def add(lhs, rhs): + """Generic add operator. + + Parameters + ---------- + lhs : object + The left operand. + rhs : object + The right operand. + + Returns + ------- + op : tvm.Expr + The result Expr of the add operation. + """ + return _ffi_api._OpAdd(lhs, rhs) + + +def subtract(lhs, rhs): + """Generic subtract operator. + + Parameters + ---------- + lhs : object + The left operand. + rhs : object + The right operand. + + Returns + ------- + op : tvm.Expr + The result Expr of the subtract operation. + """ + return _ffi_api._OpSub(lhs, rhs) + + +def multiply(lhs, rhs): + """Generic multiply operator. + + Parameters + ---------- + lhs : object + The left operand. + rhs : object + The right operand. + + Returns + ------- + op : tvm.Expr + The result Expr of the multiply operation. + """ + return _ffi_api._OpMul(lhs, rhs) + +def divide(lhs, rhs): + """Generic divide operator. + + Parameters + ---------- + lhs : object + The left operand. + rhs : object + The right operand. + + Returns + ------- + op : tvm.Expr + The result Expr of the divide operation. + """ + return _ffi_api._OpDiv(lhs, rhs) + +def floordiv(lhs, rhs): + """Generic floordiv operator. + + Parameters + ---------- + lhs : object + The left operand. + rhs : object + The right operand. + + Returns + ------- + op : tvm.Expr + The result Expr of the floordiv operation. + """ + return _ffi_api._OpFloorDiv(lhs, rhs) + + +def cast(src, dtype): + """Generic cast operator. + + Parameters + ---------- + src : object + The source operand. + + dtype : str + The target data type. + + Returns + ------- + op : tvm.Expr + The result Expr of the cast operation. + """ + return _ffi_api._cast(dtype, src) diff --git a/python/tvm/ir_builder.py b/python/tvm/tir/ir_builder.py similarity index 89% rename from python/tvm/ir_builder.py rename to python/tvm/tir/ir_builder.py index 4cc7f4f8082d..b56e15377358 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -16,15 +16,13 @@ # under the License. """Developer API of IR node builder make function.""" from tvm._ffi.base import string_types -from tvm.runtime import ObjectGeneric, DataType +from tvm.runtime import ObjectGeneric, DataType, convert, const from tvm.ir import container as _container -from . import api as _api from . import stmt as _stmt from . import expr as _expr -from . import make as _make from . import ir_pass as _pass -from .expr import Call as _Call + class WithScope(object): """Auxiliary scope with""" @@ -53,7 +51,7 @@ class BufferVar(ObjectGeneric): ..
code-block:: python # The following code generate IR for x[0] = x[ - ib = tvm.ir_builder.create() + ib = tvm.tir.ir_builder.create() x = ib.pointer("float32") x[0] = x[10] + 1 @@ -78,19 +76,19 @@ def dtype(self): def __getitem__(self, index): t = DataType(self._content_type) if t.lanes > 1: - index = _make.Ramp(index * t.lanes, 1, t.lanes) - return _make.Load(self._content_type, self._buffer_var, index) + index = _expr.Ramp(index * t.lanes, 1, t.lanes) + return _expr.Load(self._content_type, self._buffer_var, index) def __setitem__(self, index, value): - value = _api.convert(value) + value = convert(value) if value.dtype != self._content_type: raise ValueError( "data type does not match content type %s vs %s" % ( value.dtype, self._content_type)) t = DataType(self._content_type) if t.lanes > 1: - index = _make.Ramp(index * t.lanes, 1, t.lanes) - self._builder.emit(_make.Store(self._buffer_var, value, index)) + index = _expr.Ramp(index * t.lanes, 1, t.lanes) + self._builder.emit(_stmt.Store(self._buffer_var, value, index)) class IRBuilder(object): @@ -117,7 +115,7 @@ def _pop_seq(self): """Pop sequence from stack""" seq = self._seq_stack.pop() if not seq or callable(seq[-1]): - seq.append(_make.Evaluate(0)) + seq.append(_stmt.Evaluate(0)) seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x))) ret_seq = [seq[-1]] @@ -138,7 +136,7 @@ def emit(self, stmt): The statement to be emitted or callable that build stmt given body. """ if isinstance(stmt, _expr.Call): - stmt = _make.Evaluate(stmt) + stmt = _stmt.Evaluate(stmt) assert isinstance(stmt, _stmt.Stmt) or callable(stmt) self._seq_stack[-1].append(stmt) @@ -167,10 +165,10 @@ def scope_attr(self, node, attr_key, value): x[i] = x[i - 1] + 1 """ if isinstance(node, string_types): - node = _make.StringImm(node) + node = _expr.StringImm(node) if isinstance(value, string_types): - value = _make.StringImm(value) - self.emit(lambda x: _make.AttrStmt(node, attr_key, value, x)) + value = _expr.StringImm(value) + self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x)) def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): """Create a for iteration scope. @@ -211,7 +209,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): name = chr(ord(name) + self.nidx) if self.nidx < 3 else name + "_" + str(self.nidx - 3) self.nidx += 1 self._seq_stack.append([]) - loop_var = _api.var(name, dtype=dtype) + loop_var = _expr.Var(name, dtype=dtype) extent = end if begin == 0 else _pass.Simplify(end - begin) def _exit_cb(): if for_type == "serial": @@ -224,7 +222,7 @@ def _exit_cb(): for_type_id = 3 else: raise ValueError("Unknown for_type") - self.emit(_make.For( + self.emit(_stmt.For( loop_var, begin, extent, for_type_id, 0, self._pop_seq())) return WithScope(loop_var, _exit_cb) @@ -253,7 +251,7 @@ def if_scope(self, cond): """ self._seq_stack.append([]) def _exit_cb(): - self.emit(_make.IfThenElse(cond, self._pop_seq(), None)) + self.emit(_stmt.IfThenElse(cond, self._pop_seq(), None)) return WithScope(None, _exit_cb) def else_scope(self): @@ -286,7 +284,7 @@ def else_scope(self): self._seq_stack[-1].pop() self._seq_stack.append([]) def _exit_cb(): - self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq())) + self.emit(_stmt.IfThenElse(prev.condition, prev.then_case, self._pop_seq())) return WithScope(None, _exit_cb) def new_scope(self): @@ -326,13 +324,13 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. 
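A minimal end-to-end sketch of the relocated builder API (variable names are hypothetical):

.. code-block:: python

    import tvm
    ib = tvm.tir.ir_builder.create()
    n = tvm.var("n")
    A = ib.allocate("float32", n, name="A", scope="local")
    with ib.for_range(0, n, name="i") as i:
        with ib.if_scope(i > 0):
            A[i] = A[i - 1] + 1.0
    stmt = ib.get()   # the built statement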
""" - buffer_var = _api.var(name, dtype="handle") + buffer_var = _expr.Var(name, dtype="handle") if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: self.scope_attr(buffer_var, "storage_scope", scope) - self.emit(lambda x: _make.Allocate( - buffer_var, dtype, shape, _api.const(1, dtype="uint1"), x)) + self.emit(lambda x: _stmt.Allocate( + buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, dtype) def pointer(self, content_type, name="ptr"): @@ -351,7 +349,7 @@ def pointer(self, content_type, name="ptr"): ptr : BufferVar The buffer var representing the buffer. """ - buffer_var = _api.var(name, dtype="handle") + buffer_var = _expr.Var(name, dtype="handle") return BufferVar(self, buffer_var, content_type) def buffer_ptr(self, buf): @@ -380,7 +378,8 @@ def likely(self, expr): expr : Expr The expression will likely tag. """ - return _make.Call(expr.dtype, "likely", [expr], _Call.PureIntrinsic, None, 0) + return _expr.Call(expr.dtype, "likely", [expr], + _expr.Call.PureIntrinsic, None, 0) def get(self): """Return the builded IR. diff --git a/python/tvm/ir_pass.py b/python/tvm/tir/ir_pass.py similarity index 96% rename from python/tvm/ir_pass.py rename to python/tvm/tir/ir_pass.py index 9d7f340310e7..239b1fb98dd0 100644 --- a/python/tvm/ir_pass.py +++ b/python/tvm/tir/ir_pass.py @@ -25,4 +25,4 @@ """ import tvm._ffi -tvm._ffi._init_api("tvm.ir_pass") +tvm._ffi._init_api("tvm.ir_pass", __name__) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py new file mode 100644 index 000000000000..a10fe695c245 --- /dev/null +++ b/python/tvm/tir/op.py @@ -0,0 +1,782 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""Operators used in TIR expression.""" +import tvm._ffi +from tvm.runtime import convert, const +from tvm.schedule import Buffer + +from .expr import Call +from . import _ffi_api + + +def _pack_buffer(buf): + """Build intrinsics that packs the buffer. + """ + assert buf.shape + shape = Call("handle", "tvm_stack_make_shape", buf.shape, + Call.Intrinsic, None, 0) + strides = Call("handle", "tvm_stack_make_shape", buf.strides, + Call.Intrinsic, None, 0) if buf.strides else 0 + pack_args = [buf.data, + shape, + strides, + len(buf.shape), + const(0, dtype=buf.dtype), + buf.elem_offset] + return Call("handle", "tvm_stack_make_array", + pack_args, Call.Intrinsic, None, 0) + +def call_packed(*args): + """Build expression by call an external packed function. + + The argument to packed function can be Expr or Buffer. + The argument is the corresponding POD type when Expr is presented. + + When the argument is Buffer, the corresponding PackedFunc + will recieve an TVMArrayHandle whose content is valid during the callback period. 
+ If the PackedFunc is a python callback, then the corresponding argument is NDArray. + + Parameters + ---------- + args : list of Expr or Buffer. + Positional arguments. + + Returns + ------- + call : PrimExpr + The call expression. + + See Also + -------- + tvm.extern : Create tensor with extern function call. + """ + call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] + return Call( + "int32", "tvm_call_packed", call_args, Call.Intrinsic, None, 0) + + +def call_pure_intrin(dtype, func_name, *args): + """Build expression by calling a pure intrinsic function. + + Intrinsics can be overloaded with multiple data types via + the intrinsic translation rule. + + Parameters + ---------- + dtype : str + The data type of the result. + + func_name: str + The intrinsic function name. + + args : list + Positional arguments. + + Returns + ------- + call : PrimExpr + The call expression. + """ + args = convert(args) + return Call( + dtype, func_name, args, Call.PureIntrinsic, None, 0) + + +def call_intrin(dtype, func_name, *args): + """Build expression by calling an intrinsic function. + + Intrinsics can be overloaded with multiple data types via + the intrinsic translation rule. + + Parameters + ---------- + dtype : str + The data type of the result. + + func_name: str + The intrinsic function name. + + args : list + Positional arguments. + + Returns + ------- + call : PrimExpr + The call expression. + """ + args = convert(args) + return Call( + dtype, func_name, args, Call.Intrinsic, None, 0) + + +def call_pure_extern(dtype, func_name, *args): + """Build expression by calling a pure extern function. + + Parameters + ---------- + dtype : str + The data type of the result. + + func_name: str + The extern function name. + + args : list + Positional arguments. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return Call( + dtype, func_name, convert(args), Call.PureExtern, None, 0) + + +def call_extern(dtype, func_name, *args): + """Build expression by calling an extern function. + + Parameters + ---------- + dtype : str + The data type of the result. + + func_name: str + The extern function name. + + args : list + Positional arguments. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return Call( + dtype, func_name, convert(args), Call.Extern, None, 0) + + +def call_llvm_intrin(dtype, name, *args): + """Build expression by calling an llvm intrinsic function + + Parameters + ---------- + dtype : str + The data type of the result. + + name : str + The name of the llvm intrinsic function. + + args : list + Positional arguments. + + Returns + ------- + call : PrimExpr + The call expression. + """ + # pylint: disable=import-outside-toplevel + from tvm.target import codegen + llvm_id = codegen.llvm_lookup_intrinsic_id(name) + assert llvm_id != 0, "%s is not an LLVM intrinsic" % name + return call_pure_intrin(dtype, 'llvm_intrin', const(llvm_id, 'uint32'), *args) + + +@tvm._ffi.register_func("tvm.default_trace_action") +def _tvm_default_trace_action(*args): + print(list(args)) + +def trace(args, trace_action="tvm.default_trace_action"): + """Trace tensor data at the runtime. + + The trace function allows tracing a specific tensor at + runtime. The tracing value should come as the last argument. + The trace action should be specified; by default + tvm.default_trace_action is used. + + Parameters + ---------- + args : list of Expr or Buffers. + Positional arguments. + + trace_action : str. + The name of the trace action.
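A small sketch of call_packed in its usual habitat, an extern op; "my_packed_func" is a hypothetical PackedFunc assumed to be registered globally elsewhere:

.. code-block:: python

    import tvm
    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    # Buffers are packed to TVMArrayHandle automatically by _pack_buffer.
    C = tvm.extern(A.shape, [A],
                   lambda ins, outs: tvm.tir.call_packed(
                       "my_packed_func", ins[0], outs[0]),
                   name="C")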
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+
+    See Also
+    --------
+    tvm.tir.call_packed : Build a packed function call.
+    """
+    if not isinstance(args, list):
+        raise Exception("tvm.tir.trace requires args to be a list")
+    call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
+    call_args.insert(0, trace_action)
+    return Call(
+        args[-1].dtype, "tvm_call_trace_packed", call_args, Call.Intrinsic, None, 0)
+
+
+def min_value(dtype):
+    """minimum value of dtype
+
+    Parameters
+    ----------
+    dtype : str
+        The data type.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The minimum value of dtype.
+    """
+    return _ffi_api.min_value(dtype)
+
+
+def max_value(dtype):
+    """maximum value of dtype
+
+    Parameters
+    ----------
+    dtype : str
+        The data type.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The maximum value of dtype.
+    """
+    return _ffi_api.max_value(dtype)
+
+
+def exp(x):
+    """Take the exponential of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "exp", x)
+
+
+def erf(x):
+    """Take the Gauss error function of the input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "erf", x)
+
+
+def tanh(x):
+    """Take the hyperbolic tangent of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "tanh", x)
+
+
+def sigmoid(x):
+    """Compute the sigmoid of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "sigmoid", x)
+
+
+def log(x):
+    """Take the log of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "log", x)
+
+def cos(x):
+    """Take the cos of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "cos", x)
+
+def sin(x):
+    """Take the sin of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "sin", x)
+
+def atan(x):
+    """Take the atan of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "atan", x)
+
+def sqrt(x):
+    """Take the square root of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "sqrt", x)
+
+
+def rsqrt(x):
+    """Take the reciprocal of the square root of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "rsqrt", x)
+
+
+def floor(x):
+    """Take the floor of float input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.floor(x)
+
+
+def ceil(x):
+    """Take the ceil of float input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.ceil(x)
+
+
+def trunc(x):
+    """Get the truncated value of the input.
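# A hedged sketch, not part of this diff: tracing tensor values with the
# default trace action, assuming trace is re-exported as tvm.tir.trace.
# The value to be traced must come last in the argument list.
import tvm

x = tvm.placeholder((4,), name="x", dtype="float32")
y = tvm.compute(x.shape, lambda i: tvm.tir.trace([i, x[i]]), name="y")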
+
+    The truncated value of the scalar x is the
+    nearest integer i which is closer to zero than x is.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.trunc(x)
+
+
+def abs(x):
+    """Get the absolute value of the input element-wise.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.abs(x)
+
+
+def round(x):
+    """Round elements of the array to the nearest integer.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.round(x)
+
+
+def nearbyint(x):
+    """Round elements of the array to the nearest integer.
+    This intrinsic uses llvm.nearbyint instead of llvm.round,
+    which is faster but may give results different from tvm.round.
+    Notably, nearbyint rounds according to the rounding mode,
+    whereas tvm.round (llvm.round) ignores that.
+    For differences between the two see:
+    https://en.cppreference.com/w/cpp/numeric/math/round
+    https://en.cppreference.com/w/cpp/numeric/math/nearbyint
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.nearbyint(x)
+
+
+def isnan(x):
+    """Check if input value is NaN.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.isnan(x)
+
+
+def power(x, y):
+    """Compute x raised to the power of y.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    y : PrimExpr
+        The exponent
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return _ffi_api._OpPow(convert(x), convert(y))
+
+
+def popcount(x):
+    """Count the number of set bits in input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "popcount", x)
+
+def fmod(x, y):
+    """Return the remainder of x divided by y with the same sign as x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+    y : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    z : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "fmod", x, y)
+
+
+def if_then_else(cond, t, f):
+    """Conditional selection expression.
+
+    Parameters
+    ----------
+    cond : PrimExpr
+        The condition
+
+    t : PrimExpr
+        The result expression if cond is true.
+
+    f : PrimExpr
+        The result expression if cond is false.
+
+    Returns
+    -------
+    result : PrimExpr
+        The result of the conditional expression.
+
+    Note
+    ----
+    Unlike Select, if_then_else will not execute
+    the branch that does not satisfy the condition.
+    You can use it to guard against out-of-bound access.
+    Unlike Select, if_then_else cannot be vectorized
+    if some lanes in the vector have different conditions.
+    """
+    return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f))
+
+
+def div(a, b):
+    """Compute a / b as in C/C++ semantics.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand.
+
+    b : PrimExpr
+        The right hand operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    When operands are integers, returns truncdiv(a, b).
+    """
+    return _ffi_api._OpDiv(a, b)
+
+
+def indexdiv(a, b):
+    """Compute floor(a / b) where a and b are non-negative.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand, known to be non-negative.
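# A minimal sketch, not part of this diff: the out-of-bound guard pattern
# described in the if_then_else note, used here to zero-pad a 1-D tensor by
# one element on each side (all names are illustrative; assumes the tvm.tir
# re-export of if_then_else).
import tvm

n = 16
A = tvm.placeholder((n,), name="A", dtype="float32")
P = tvm.compute((n + 2,),
                lambda i: tvm.tir.if_then_else(
                    tvm.all(i >= 1, i < n + 1),
                    A[i - 1], tvm.const(0.0, "float32")),
                name="P")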
+
+    b : PrimExpr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativeness.
+    """
+    return _ffi_api._OpIndexDiv(a, b)
+
+
+def indexmod(a, b):
+    """Compute the remainder of indexdiv. a and b are non-negative.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand, known to be non-negative.
+
+    b : PrimExpr
+        The right hand operand, known to be non-negative.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    Use this function to split non-negative indices.
+    This function may take advantage of operands'
+    non-negativeness.
+    """
+    return _ffi_api._OpIndexMod(a, b)
+
+
+def truncdiv(a, b):
+    """Compute the truncdiv of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    This is the default integer division behavior in C.
+    """
+    return _ffi_api._OpTruncDiv(a, b)
+
+
+def truncmod(a, b):
+    """Compute the truncmod of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+
+    Note
+    ----
+    This is the default integer remainder behavior in C.
+    """
+    return _ffi_api._OpTruncMod(a, b)
+
+
+def floordiv(a, b):
+    """Compute the floordiv of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api._OpFloorDiv(a, b)
+
+
+def floormod(a, b):
+    """Compute the floormod of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api._OpFloorMod(a, b)
diff --git a/python/tvm/stmt.py b/python/tvm/tir/stmt.py
similarity index 85%
rename from python/tvm/stmt.py
rename to python/tvm/tir/stmt.py
index e5feb50ddf6f..bc02b7d23ead 100644
--- a/python/tvm/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -25,18 +25,19 @@
   x = tvm.var("n")
   a = tvm.var("array", tvm.handle)
-  st = tvm.make.Store(a, x + 1, 1)
-  assert isinstance(st, tvm.stmt.Store)
+  st = tvm.tir.stmt.Store(a, x + 1, 1)
+  assert isinstance(st, tvm.tir.stmt.Store)
   assert(st.buffer_var == a)
 """
 import tvm._ffi

 from tvm.runtime import Object
-from . import make as _make
+from . import _ffi_api


 class Stmt(Object):
-    pass
+    """Base class of all the statements."""
+

@tvm._ffi.register_object
 class LetStmt(Stmt):
@@ -47,7 +48,7 @@ class LetStmt(Stmt):
     var : Var
         The variable in the binding.

-    value : Expr
+    value : PrimExpr
         The value to be bound.

     body : Stmt
@@ -55,7 +56,7 @@
     """
     def __init__(self, var, value, body):
         self.__init_handle_by_constructor__(
-            _make.LetStmt, var, value, body)
+            _ffi_api.LetStmt, var, value, body)


@tvm._ffi.register_object
@@ -64,10 +65,10 @@ class AssertStmt(Stmt):

     Parameters
     ----------
-    condition : Expr
+    condition : PrimExpr
         The assert condition.

-    message : Expr
+    message : PrimExpr
         The error message.
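# A worked sketch, not part of this diff: the division families above only
# disagree on negative operands. truncdiv rounds toward zero, floordiv
# toward negative infinity, and the matching mods keep div(a, b) * b +
# mod(a, b) == a: truncdiv(-7, 4) == -1 with truncmod(-7, 4) == -3, while
# floordiv(-7, 4) == -2 with floormod(-7, 4) == 1.
import tvm

a = tvm.const(-7, "int32")
b = tvm.const(4, "int32")
print(tvm.ir_pass.Simplify(tvm.truncdiv(a, b)))  # -1
print(tvm.ir_pass.Simplify(tvm.floordiv(a, b)))  # -2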
body : Stmt
@@ -75,7 +76,7 @@
     """
     def __init__(self, condition, message, body):
         self.__init_handle_by_constructor__(
-            _make.AssertStmt, condition, message, body)
+            _ffi_api.AssertStmt, condition, message, body)


@tvm._ffi.register_object
@@ -95,7 +96,7 @@ class ProducerConsumer(Stmt):
     """
     def __init__(self, func, is_producer, body):
         self.__init_handle_by_constructor__(
-            _make.ProducerConsumer, func, is_producer, body)
+            _ffi_api.ProducerConsumer, func, is_producer, body)


@tvm._ffi.register_object
@@ -107,10 +108,10 @@ class For(Stmt):
     loop_var : Var
         The loop variable.

-    min_val : Expr
+    min_val : PrimExpr
         The beginning value.

-    extent : Expr
+    extent : PrimExpr
         The length of the loop.

     for_type : int
@@ -134,7 +135,7 @@ def __init__(self,
                  device_api,
                  body):
         self.__init_handle_by_constructor__(
-            _make.For, loop_var, min_val, extent,
+            _ffi_api.For, loop_var, min_val, extent,
             for_type, device_api, body)


@@ -147,18 +148,19 @@ class Store(Stmt):
     buffer_var : Var
         The buffer Variable.

-    value : Expr
+    value : PrimExpr
         The value we want to store.

-    index : Expr
+    index : PrimExpr
         The index in the store expression.

-    predicate : Expr
+    predicate : PrimExpr, optional
         The store predicate.
     """
-    def __init__(self, buffer_var, value, index, predicate):
+    def __init__(self, buffer_var, value, index, predicate=None):
+        args = [] if predicate is None else [predicate]
         self.__init_handle_by_constructor__(
-            _make.Store, buffer_var, value, index, predicate)
+            _ffi_api.Store, buffer_var, value, index, *args)


@tvm._ffi.register_object
@@ -173,7 +175,7 @@ class Provide(Stmt):

     value_index : int
         The output value index

-    value : Expr
+    value : PrimExpr
         The value to be stored.

     args : list of Expr
@@ -181,7 +183,7 @@
     """
     def __init__(self, func, value_index, value, args):
         self.__init_handle_by_constructor__(
-            _make.Provide, func, value_index, value, args)
+            _ffi_api.Provide, func, value_index, value, args)


@tvm._ffi.register_object
@@ -199,7 +201,7 @@ class Allocate(Stmt):

     extents : list of Expr
         The extents of the allocation

-    condition : Expr
+    condition : PrimExpr
         The condition.

     body : Stmt
@@ -212,7 +214,7 @@ def __init__(self,
                  condition,
                  body):
         self.__init_handle_by_constructor__(
-            _make.Allocate, buffer_var, dtype,
+            _ffi_api.Allocate, buffer_var, dtype,
             extents, condition, body)


@@ -228,7 +230,7 @@ class AttrStmt(Stmt):

     attr_key : str
         Attribute type key.

-    value : Expr
+    value : PrimExpr
         The value of the attribute

     body : Stmt
@@ -236,7 +238,7 @@
     """
     def __init__(self, node, attr_key, value, body):
         self.__init_handle_by_constructor__(
-            _make.AttrStmt, node, attr_key, value, body)
+            _ffi_api.AttrStmt, node, attr_key, value, body)


@tvm._ffi.register_object
@@ -250,7 +252,7 @@ class Free(Stmt):
     """
     def __init__(self, buffer_var):
         self.__init_handle_by_constructor__(
-            _make.Free, buffer_var)
+            _ffi_api.Free, buffer_var)


@tvm._ffi.register_object
@@ -271,7 +273,7 @@ class Realize(Stmt):

     bounds : list of range
         The bounds of the realized region.

-    condition : Expr
+    condition : PrimExpr
         The realize condition.
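# A hedged sketch, not part of this diff: constructing statements through
# the classes above. Store's predicate argument is optional after this
# change, and For takes (loop_var, min, extent, for_type, device_api, body).
import tvm

i = tvm.var("i")
buf = tvm.var("buf", dtype="handle")
store = tvm.tir.stmt.Store(buf, i + 1, i)       # predicate defaults to None
loop = tvm.tir.stmt.For(i, 0, 10, 0, 0, store)  # for_type=0 is a serial loop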
body : Stmt
@@ -285,7 +287,7 @@ def __init__(self,
                  condition,
                  body):
         self.__init_handle_by_constructor__(
-            _make.Realize, func, value_index, dtype,
+            _ffi_api.Realize, func, value_index, dtype,
             bounds, condition, body)

@@ -300,7 +302,7 @@ class SeqStmt(Stmt):
     """
     def __init__(self, seq):
         self.__init_handle_by_constructor__(
-            _make.SeqStmt, seq)
+            _ffi_api.SeqStmt, seq)

     def __getitem__(self, i):
         return self.seq[i]
@@ -315,7 +317,7 @@ class IfThenElse(Stmt):

     Parameters
     ----------
-    condition : Expr
+    condition : PrimExpr
         The condition expression.

     then_case : Stmt
@@ -326,7 +328,7 @@ class IfThenElse(Stmt):
     """
     def __init__(self, condition, then_case, else_case):
         self.__init_handle_by_constructor__(
-            _make.IfThenElse, condition, then_case, else_case)
+            _ffi_api.IfThenElse, condition, then_case, else_case)


@tvm._ffi.register_object
@@ -335,12 +337,12 @@ class Evaluate(Stmt):

     Parameters
     ----------
-    value : Expr
+    value : PrimExpr
         The expression to be evaluated.
     """
     def __init__(self, value):
         self.__init_handle_by_constructor__(
-            _make.Evaluate, value)
+            _ffi_api.Evaluate, value)


@tvm._ffi.register_object
@@ -363,7 +365,7 @@ class Prefetch(Stmt):
     """
     def __init__(self, func, value_index, dtype, bounds):
         self.__init_handle_by_constructor__(
-            _make.Prefetch, func, value_index, dtype, bounds)
+            _ffi_api.Prefetch, func, value_index, dtype, bounds)


@tvm._ffi.register_object
@@ -417,6 +419,3 @@ def stmt_list(stmt):
     if isinstance(stmt, ProducerConsumer):
         return stmt_list(stmt.body)
     return [stmt]
-
-
-_make.stmt_list = stmt_list
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
index 35810cbc23cf..1e71baf305d4 100644
--- a/src/api/api_ir.cc
+++ b/src/api/api_ir.cc
@@ -30,50 +30,50 @@
 namespace tvm {
 namespace tir {

-TVM_REGISTER_GLOBAL("_Var")
+TVM_REGISTER_GLOBAL("tir.Var")
 .set_body_typed([](std::string s, DataType t) {
     return Var(s, t);
   });

-TVM_REGISTER_GLOBAL("_SizeVar")
+TVM_REGISTER_GLOBAL("tir.SizeVar")
 .set_body_typed([](std::string s, DataType t) {
     return SizeVar(s, t);
   });

-TVM_REGISTER_GLOBAL("make.abs")
+TVM_REGISTER_GLOBAL("tir.abs")
 .set_body_typed(tvm::abs);

-TVM_REGISTER_GLOBAL("make.isnan")
+TVM_REGISTER_GLOBAL("tir.isnan")
 .set_body_typed(tvm::isnan);

-TVM_REGISTER_GLOBAL("make.floor")
+TVM_REGISTER_GLOBAL("tir.floor")
 .set_body_typed(tvm::floor);

-TVM_REGISTER_GLOBAL("make.ceil")
+TVM_REGISTER_GLOBAL("tir.ceil")
 .set_body_typed(tvm::ceil);

-TVM_REGISTER_GLOBAL("make.round")
+TVM_REGISTER_GLOBAL("tir.round")
 .set_body_typed(tvm::round);

-TVM_REGISTER_GLOBAL("make.nearbyint")
+TVM_REGISTER_GLOBAL("tir.nearbyint")
 .set_body_typed(tvm::nearbyint);

-TVM_REGISTER_GLOBAL("make.trunc")
+TVM_REGISTER_GLOBAL("tir.trunc")
 .set_body_typed(tvm::trunc);

-TVM_REGISTER_GLOBAL("make._cast")
+TVM_REGISTER_GLOBAL("tir._cast")
 .set_body_typed(tvm::cast);

-TVM_REGISTER_GLOBAL("make._range_by_min_extent")
+TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
 .set_body_typed(Range::make_by_min_extent);

-TVM_REGISTER_GLOBAL("make.SeqStmt")
+TVM_REGISTER_GLOBAL("tir.SeqStmt")
 .set_body_typed([](Array seq) {
   return SeqStmt(std::move(seq));
 });

-TVM_REGISTER_GLOBAL("make.For")
+TVM_REGISTER_GLOBAL("tir.For")
 .set_body_typed([](
   Var loop_var, PrimExpr min, PrimExpr extent,
   int for_type, int device_api, Stmt body) {
@@ -85,7 +85,7 @@ TVM_REGISTER_GLOBAL("make.For")
                    body);
 });

-TVM_REGISTER_GLOBAL("make.Load")
+TVM_REGISTER_GLOBAL("tir.Load")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     DataType t = args[0];
     if (args.size() == 3) {
    } else {
    }
  });
@@ -95,7 +95,7 @@ TVM_REGISTER_GLOBAL("make.Load")
   });

-TVM_REGISTER_GLOBAL("make.Store")
+TVM_REGISTER_GLOBAL("tir.Store") .set_body([](TVMArgs args, TVMRetValue *ret) { PrimExpr value = args[1]; if (args.size() == 3) { @@ -105,10 +105,10 @@ TVM_REGISTER_GLOBAL("make.Store") } }); -TVM_REGISTER_GLOBAL("make.Realize") +TVM_REGISTER_GLOBAL("tir.Realize") .set_body_typed(RealizeNode::make); -TVM_REGISTER_GLOBAL("make.Call") +TVM_REGISTER_GLOBAL("tir.Call") .set_body_typed([]( DataType type, std::string name, Array args, int call_type, @@ -122,12 +122,12 @@ TVM_REGISTER_GLOBAL("make.Call") value_index); }); -TVM_REGISTER_GLOBAL("make.CommReducer") +TVM_REGISTER_GLOBAL("tir.CommReducer") .set_body_typed(CommReducerNode::make); // make from two arguments #define REGISTER_MAKE(NodeName) \ - TVM_REGISTER_GLOBAL("make."#NodeName) \ + TVM_REGISTER_GLOBAL("tir."#NodeName) \ .set_body_typed(NodeName ## Node::make); \ @@ -172,7 +172,7 @@ REGISTER_MAKE(Evaluate); // overloaded, needs special handling // has default args -TVM_REGISTER_GLOBAL("make.Allocate") +TVM_REGISTER_GLOBAL("tir.Allocate") .set_body_typed([]( Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body ){ @@ -180,14 +180,14 @@ TVM_REGISTER_GLOBAL("make.Allocate") }); // operator overloading, smarter than make -#define REGISTER_MAKE_BINARY_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("make."#Node) \ +#define REGISTER_MAKE_BINARY_OP(Node, Func) \ + TVM_REGISTER_GLOBAL("tir."#Node) \ .set_body_typed([](PrimExpr a, PrimExpr b) { \ return (Func(a, b)); \ }) -#define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("make."#Node) \ +#define REGISTER_MAKE_BIT_OP(Node, Func) \ + TVM_REGISTER_GLOBAL("tir."#Node) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ bool lhs_is_int = args[0].type_code() == kDLInt; \ bool rhs_is_int = args[1].type_code() == kDLInt; \ @@ -228,7 +228,7 @@ REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); -TVM_REGISTER_GLOBAL("make._OpIfThenElse") +TVM_REGISTER_GLOBAL("tir._OpIfThenElse") .set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { return if_then_else(cond, true_value, false_value); }); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 591869e49883..d2f2cb69b721 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -34,10 +34,10 @@ namespace tvm { -TVM_REGISTER_GLOBAL("_min_value") +TVM_REGISTER_GLOBAL("tir.min_value") .set_body_typed(min_value); -TVM_REGISTER_GLOBAL("_max_value") +TVM_REGISTER_GLOBAL("tir.max_value") .set_body_typed(max_value); TVM_REGISTER_GLOBAL("Range") @@ -49,66 +49,6 @@ TVM_REGISTER_GLOBAL("Range") } }); -namespace tir { - -TVM_REGISTER_GLOBAL("_Buffer") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size(), 10); - auto buffer_type = args[9].operator std::string(); - BufferType type = (buffer_type == "auto_broadcast") ? 
kAutoBroadcast : kDefault;
-  *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
-                          args[5], args[6], args[7], args[8], type);
-  });
-
-TVM_REGISTER_GLOBAL("_BufferAccessPtr")
-.set_body_method(&Buffer::access_ptr);
-
-TVM_REGISTER_GLOBAL("_BufferVLoad")
-.set_body_method(&Buffer::vload);
-
-TVM_REGISTER_GLOBAL("_BufferVStore")
-.set_body_method(&Buffer::vstore);
-
-TVM_REGISTER_GLOBAL("_Layout")
-.set_body_typed(LayoutNode::make);
-
-TVM_REGISTER_GLOBAL("_LayoutIndexOf")
-.set_body_typed([](Layout layout, std::string axis) -> int {
-  return layout.IndexOf(LayoutAxis::make(axis));
-});
-
-TVM_REGISTER_GLOBAL("_LayoutFactorOf")
-.set_body_typed([](Layout layout, std::string axis) -> int {
-  return layout.FactorOf(LayoutAxis::make(axis));
-});
-
-TVM_REGISTER_GLOBAL("_LayoutNdim")
-.set_body_typed([](Layout layout) -> int {
-  return layout.ndim();
-});
-
-TVM_REGISTER_GLOBAL("_LayoutGetItem")
-.set_body_typed([](Layout layout, int idx) -> std::string {
-  const LayoutAxis& axis = layout[idx];
-  return axis.name();
-});
-
-TVM_REGISTER_GLOBAL("_BijectiveLayout")
-.set_body_typed(BijectiveLayoutNode::make);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardIndex")
-.set_body_method(&BijectiveLayout::ForwardIndex);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardIndex")
-.set_body_method(&BijectiveLayout::BackwardIndex);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutForwardShape")
-.set_body_method(&BijectiveLayout::ForwardShape);
-
-TVM_REGISTER_GLOBAL("_BijectiveLayoutBackwardShape")
-.set_body_method(&BijectiveLayout::BackwardShape);
-}  // namespace tir
-
 namespace te {
 TVM_REGISTER_GLOBAL("_Tensor")
 .set_body_typed(TensorNode::make);
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 78c6879d8ced..4feabeb8e505 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -71,7 +71,7 @@ IntImm::IntImm(DataType dtype, int64_t value) {
   data_ = std::move(node);
 }

-TVM_REGISTER_GLOBAL("make.IntImm")
+TVM_REGISTER_GLOBAL("ir.IntImm")
 .set_body_typed([](DataType dtype, int64_t value) {
   return IntImm(dtype, value);
 });
@@ -97,7 +97,7 @@ FloatImm::FloatImm(DataType dtype, double value) {
   data_ = std::move(node);
 }

-TVM_REGISTER_GLOBAL("make.FloatImm")
+TVM_REGISTER_GLOBAL("ir.FloatImm")
 .set_body_typed([](DataType dtype, double value) {
   return FloatImm(dtype, value);
 });
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index d61d72b82b2f..183079ffc82a 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -304,6 +304,6 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr")
 TVM_REGISTER_GLOBAL("node.NodeListAttrNames")
 .set_body(NodeListAttrNames);

-TVM_REGISTER_GLOBAL("make._Node")
+TVM_REGISTER_GLOBAL("node.MakeNode")
 .set_body(MakeNode);
 }  // namespace tvm
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index 00bf70b5d289..86362e0a0419 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -906,7 +906,9 @@ static const char* kSemVer = "v0.0.4";
 // - relay_text_printer.cc (specific printing logic for relay)
 // - tir_text_printer.cc (specific printing logic for TIR)
 std::string PrettyPrint(const ObjectRef& node) {
-  return AsText(node, false, nullptr);
+  Doc doc;
+  doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
+  return doc.str();
 }

 std::string AsText(const ObjectRef& node,
@@ -918,6 +920,10 @@ std::string AsText(const ObjectRef& node,
   return doc.str();
 }

+
+TVM_REGISTER_GLOBAL("ir.PrettyPrint")
+.set_body_typed(PrettyPrint);
+
 TVM_REGISTER_GLOBAL("ir.AsText")
 .set_body_typed(AsText);
 }  // namespace tvm
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index ff67e8d9cbc2..19e32d6681ae 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -20,6 +20,7 @@
 /*!
  * \file buffer.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -460,5 +461,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
   });

 TVM_REGISTER_NODE_TYPE(BufferNode);
+
+
+TVM_REGISTER_GLOBAL("tir.Buffer")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 10);
+    auto buffer_type = args[9].operator std::string();
+    BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
+    *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4],
+                            args[5], args[6], args[7], args[8], type);
+  });
+
+TVM_REGISTER_GLOBAL("tir.BufferAccessPtr")
+.set_body_method(&Buffer::access_ptr);
+
+TVM_REGISTER_GLOBAL("tir.BufferVLoad")
+.set_body_method(&Buffer::vload);
+
+TVM_REGISTER_GLOBAL("tir.BufferVStore")
+.set_body_method(&Buffer::vstore);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc
index 8a5125bca193..9cc07a8a3b81 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/tir/ir/data_layout.cc
@@ -21,6 +21,7 @@
 * \file src/lang/data_layout.cc
 * \brief Data Layout expression.
 */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -371,5 +372,44 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "BijectiveLayout(" << b->src_layout.name() << "->"
                 << b->dst_layout.name() << ")";
     });
+
+TVM_REGISTER_GLOBAL("tir.Layout")
+.set_body_typed(LayoutNode::make);
+
+TVM_REGISTER_GLOBAL("tir.LayoutIndexOf")
+.set_body_typed([](Layout layout, std::string axis) -> int {
+  return layout.IndexOf(LayoutAxis::make(axis));
+});
+
+TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
+.set_body_typed([](Layout layout, std::string axis) -> int {
+  return layout.FactorOf(LayoutAxis::make(axis));
+});
+
+TVM_REGISTER_GLOBAL("tir.LayoutNdim")
+.set_body_typed([](Layout layout) -> int {
+  return layout.ndim();
+});
+
+TVM_REGISTER_GLOBAL("tir.LayoutGetItem")
+.set_body_typed([](Layout layout, int idx) -> std::string {
+  const LayoutAxis& axis = layout[idx];
+  return axis.name();
+});
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
+.set_body_typed(BijectiveLayoutNode::make);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
+.set_body_method(&BijectiveLayout::ForwardIndex);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
+.set_body_method(&BijectiveLayout::BackwardIndex);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
+.set_body_method(&BijectiveLayout::ForwardShape);
+
+TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
+.set_body_method(&BijectiveLayout::BackwardShape);
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index 454354e6e68b..62c029043084 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -24,7 +24,7 @@ def test_prim(reducer, np_reducer):
     n = tvm.size_var('n')
     m = tvm.size_var('m')
     A = tvm.placeholder((n, m), name='A')
-    R = tvm.compute((n, ), lambda i: tvm.expr.Select((i > 1), 1, 0), name='R')
+    R = tvm.compute((n, ), lambda i: tvm.tir.Select((i > 1), 1, 0), name='R')
     k = tvm.reduce_axis((0, m))
     B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
     # schedule
@@ -232,8 +232,8 @@ def check_target(device, host="stackvm"):

 def test_argmax():
     def fcombine(x, y):
-        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
-        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
         return lhs, rhs

     def fidentity(t0, t1):
@@ -279,8 +279,8 @@ def check_target():

 def test_rfactor_argmax():
     def fcombine(x, y):
-        lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0])
-        rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1])
+        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
         return lhs, rhs

     def fidentity(t0, t1):
diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py
index 640da0bd2ebe..fd7ec188611f 100644
--- a/tests/python/relay/test_backend_compile_engine.py
+++ b/tests/python/relay/test_backend_compile_engine.py
@@ -82,10 +82,10 @@ def test_compile_tuple_dup():

 def test_compile_full():
     # Shape calculations can happen in int64. The test checks that the full
     # operator can handle shapes that are not int32
-    shape = (tvm.expr.IntImm('int32', 1),
-             tvm.expr.IntImm('int64', 16),
-             tvm.expr.IntImm('int64', 16),
-             tvm.expr.IntImm('int32', 64))
+    shape = (tvm.tir.IntImm('int32', 1),
+             tvm.tir.IntImm('int64', 16),
+             tvm.tir.IntImm('int64', 16),
+             tvm.tir.IntImm('int32', 64))
     output = relay.full(relay.const(0, 'int32'), shape=shape, dtype='int32')
     f = relay.Function([], output)
     mod = tvm.IRModule.from_expr(f)
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index 2af4a2030f4f..674e214df058 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -41,7 +41,7 @@ def test_basic_build():
     }
     # build
     targets = {
-        tvm.expr.IntImm("int32", ctx.device_type): tgt
+        tvm.tir.IntImm("int32", ctx.device_type): tgt
     }
     g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params)
diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py
index 3735259280a4..608bc2a77bb0 100644
--- a/tests/python/relay/test_external_codegen.py
+++ b/tests/python/relay/test_external_codegen.py
@@ -77,9 +77,9 @@ def check_graph_runtime_result():

 def set_external_func_attr(func, compiler, ext_symbol):
-    func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1))
-    func = func.set_attribute("Compiler", tvm.expr.StringImm(compiler))
-    func = func.set_attribute("ExternalSymbol", tvm.expr.StringImm(ext_symbol))
+    func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+    func = func.set_attribute("Compiler", tvm.tir.StringImm(compiler))
+    func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
     return func
diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py
index 7e8c83217451..713aca918883 100644
--- a/tests/python/relay/test_external_runtime.py
+++ b/tests/python/relay/test_external_runtime.py
@@ -307,7 +307,7 @@ def get_synthetic_lib():
     subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2,
                                 gcc_input3], relay.copy(gcc_input0))
     subgraph0 = subgraph0.set_attribute(
-        "Primitive", tvm.expr.IntImm("int32", 1))
+        "Primitive", tvm.tir.IntImm("int32", 1))

     # Call subgraph0
     subgraph0_ret = relay.Call(subgraph0, [x, w0, w1, w2])
@@ -320,7 +320,7 @@ def get_synthetic_lib():
     subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6,
                                 gcc_input7], relay.copy(gcc_input4))
     subgraph1 = subgraph1.set_attribute(
-        "Primitive", tvm.expr.IntImm("int32", 1))
+        "Primitive", tvm.tir.IntImm("int32", 1))
     # Call subgraph1
     subgraph1_ret = relay.Call(subgraph1, [x, w3, w4, w5])
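The same rename applies wherever tests previously spelled IR nodes through tvm.expr or tvm.make; a small hedged sketch of the new spellings (values are illustrative):

import tvm

dim = tvm.tir.IntImm("int64", 16)        # was tvm.expr.IntImm
tag = tvm.tir.StringImm("ccompiler")     # was tvm.expr.StringImm
sel = tvm.tir.Select(dim > 0, dim, tvm.tir.IntImm("int64", 0))  # was tvm.make.Select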
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 29e578b08b11..ad1525576d08 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -17,7 +17,7 @@ """ test ir""" import tvm from tvm import relay -from tvm.expr import * +from tvm.tir.expr import * from tvm.relay import op from tvm.relay.analysis import graph_equal import numpy as np @@ -110,7 +110,7 @@ def test_type_relation(): num_inputs = 2 func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") - attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4)) tr = relay.TypeRelation(func, args, num_inputs, attrs) assert tr.args == args diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 261cbb97c4af..bcce9b4ba5dd 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -69,7 +69,7 @@ """ def roundtrip(expr): - x = relay.fromtext(str(expr)) + x = relay.fromtext(expr.astext()) assert_graph_equal(x, expr) @@ -343,7 +343,7 @@ def test_func(): # attributes assert parses_as( "fn (n=5) { () }", - relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5))) + relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5))) ) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 5876a7052a2d..0d3fd4b3f829 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -630,8 +630,8 @@ def test_upsampling_infer_type(): y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear") "method=\"BINLINEAR\"" in y.astext() yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(h*scale)), - tvm.expr.Cast("int32", tvm.round(w*scale))), + assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(h*scale)), + tvm.tir.Cast("int32", tvm.round(w*scale))), "float32") n, c = tvm.size_var("n"), tvm.size_var("c") x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32")) @@ -647,9 +647,9 @@ def test_upsampling3d_infer_type(): y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear") yy = run_infer_type(y) - assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(d*scale)), - tvm.expr.Cast("int32", tvm.round(h*scale)), - tvm.expr.Cast("int32", tvm.round(w*scale))), + assert yy.checked_type == relay.TensorType((n, c, tvm.tir.Cast("int32", tvm.round(d*scale)), + tvm.tir.Cast("int32", tvm.round(h*scale)), + tvm.tir.Cast("int32", tvm.round(w*scale))), "float32") n, c = tvm.size_var("n"), tvm.size_var("c") x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32")) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 9c5dfacf62a2..c5f340a843a3 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -517,7 +517,7 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"): alpha_shape = (data[axis],) assert zz.args[1].checked_type == relay.TensorType(alpha_shape, "float32") - if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or not alpha: + if all(isinstance(v, tvm.tir.Var) == 1 for v in data) or not alpha: return func = relay.Function([x, y], z) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py 
index 0243adc59319..c5cd70818795 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -154,7 +154,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype assert zz.checked_type == relay.ty.TensorType(output, out_type) - if all(isinstance(v, tvm.expr.Var) == 1 for v in data): + if all(isinstance(v, tvm.tir.Var) == 1 for v in data): return func = relay.Function([x], z) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index bdc032e8b65c..5985273ce6de 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -160,9 +160,9 @@ def test_type_relation_alpha_equal(): broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity") - attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) - attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) - attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4)) + attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4)) tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) @@ -322,7 +322,7 @@ def test_multi_node_subgraph(): p00 = relay.subtract(z00, w01) q00 = relay.multiply(p00, w02) func0 = relay.Function([x0, w00, w01, w02], q00) - func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a")) + func0 = func0.set_attribute("FuncName", tvm.tir.StringImm("a")) x1 = relay.var('x1', shape=(10, 10)) w10 = relay.var('w10', shape=(10, 10)) @@ -332,7 +332,7 @@ def test_multi_node_subgraph(): p10 = relay.subtract(z10, w11) q10 = relay.multiply(p10, w12) func1 = relay.Function([x1, w10, w11, w12], q10) - func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b")) + func1 = func1.set_attribute("FuncName", tvm.tir.StringImm("b")) assert not alpha_equal(func0, func1) @@ -413,9 +413,9 @@ def test_call_alpha_equal(): v1 = relay.Var("v1") v2 = relay.Var("v2") - attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) - attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) - attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4)) + attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4,4)) tt1 = relay.TensorType((1, 2, 3), "float32") tt2 = relay.TensorType((), "int8") diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 27a143bc455a..6f20278133d9 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -303,11 +303,11 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([x0, y0], add) - func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1)) + func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func = func.set_attribute("Compiler", - tvm.expr.StringImm("ccompiler")) + tvm.tir.StringImm("ccompiler")) func = func.set_attribute("ExternalSymbol", - 
tvm.expr.StringImm("ccompiler_0")) + tvm.tir.StringImm("ccompiler_0")) add_call = relay.Call(func, [x, y]) # Function that uses default compiler. Ops are fused in this function. p0 = relay.var("p0", shape=(8, 8)) @@ -316,7 +316,7 @@ def expected(): concat = relay.concatenate([log, exp], axis=0) fused_func = relay.Function([p0], concat) fused_func = fused_func.set_attribute("Primitive", - tvm.expr.IntImm("int32", 1)) + tvm.tir.IntImm("int32", 1)) fused_call = relay.Call(fused_func, [add_call]) main = relay.Function([x, y], fused_call) mod = tvm.IRModule() diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py index b84a49ae3fda..854301bf714a 100644 --- a/tests/python/relay/test_type_functor.py +++ b/tests/python/relay/test_type_functor.py @@ -65,7 +65,7 @@ def test_tuple_type(): def test_type_relation(): func = tvm.ir.EnvFunc.get('tvm.relay.type_relation.Broadcast') - attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4)) + attrs = tvm.ir.make_node('attrs.TestAttrs', name='attr', padding=(3,4)) tp = TypeVar('tp') tf = FuncType([], TupleType([]), [], []) tt = TensorType([1, 2, 3], 'float32') diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 28a43884affd..35822d240b04 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -151,9 +151,9 @@ def test_reduce_combiner_simplify(): prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0)) sum_or_prod = comm_reducer( - lambda x, y: tvm.expr.Select(dummy < 0, + lambda x, y: tvm.tir.Select(dummy < 0, x + y, x*y), - lambda t0: tvm.expr.Select(dummy < 0, + lambda t0: tvm.tir.Select(dummy < 0, tvm.const(0, t0), tvm.const(1, t0))) sum_and_prod = comm_reducer( lambda x, y: (x[0] + y[0], @@ -199,7 +199,7 @@ def test_reduce_combiner_simplify(): assert tvm.ir_pass.Equal(lhs, rhs) # Test that components with side effects are not removed - side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0) + side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0) ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], @@ -211,7 +211,7 @@ def test_reduce_simplify(): k = tvm.reduce_axis((0, 10), name="k") j = tvm.reduce_axis((-5, 3), name="j") A = tvm.placeholder((10,), name='A') - ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]), + ck.verify(tvm.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]), tvm.sum(k + j, [k, j])) ck.verify(tvm.sum(A[3], []), A[3]) # The rule below is not typical, removed for now @@ -235,23 +235,23 @@ def test_simplify_if_then_else(): tmod(tmod(((x*4) + y) - 466036, 24528) -24512, 16), x), y) expected = tvm.if_then_else( - tvm.expr.LE(466036, (x * 4 + y)), - tvm.if_then_else(tvm.expr.LE(24512, tmod(((x*4) + y) - 4, 24528)), + tvm.tir.LE(466036, (x * 4 + y)), + tvm.if_then_else(tvm.tir.LE(24512, tmod(((x*4) + y) - 4, 24528)), tmod(((x*4) + y) - 4, 16), x), y) ck.verify(res, expected) ck.verify(res2, expected) # can only simplify if condition - res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3)) - expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3)) + res = tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3)) + expected = 
tvm.tir.Select(tvm.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3)) ck.verify(res, ck.analyzer.canonical_simplify(expected)) - res = tvm.expr.Select(x >= 10, + res = tvm.tir.Select(x >= 10, tvm.if_then_else(tdiv(x, 3) > 2, x, 0), 0) - expected = tvm.expr.Select(x >= 10, x, 0) + expected = tvm.tir.Select(x >= 10, x, 0) ck.verify(res, ck.analyzer.canonical_simplify(expected)) - res = tvm.expr.Select(x >= 10, + res = tvm.tir.Select(x >= 10, tvm.if_then_else(tdiv(x, 3) < 2, x, 0), 0) ck.verify(res, 0) diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py index ae2837d6446f..aba56ac6c0c5 100644 --- a/tests/python/unittest/test_arith_const_int_bound.py +++ b/tests/python/unittest/test_arith_const_int_bound.py @@ -228,7 +228,7 @@ def test_select_bound(): analyzer.update(y, tvm.arith.ConstIntBound(4, 10)) bd = analyzer.const_int_bound( - tvm.expr.Select(x > 1, (y < 0).astype("int32"), y + 1)) + tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1)) assert bd.min_value == 0 assert bd.max_value == 11 diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 33e31c766950..787dfe80d536 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -19,7 +19,7 @@ def assert_expr_equal(a, b): res = tvm.ir_pass.Simplify(a - b) - equal = isinstance(res, tvm.expr.IntImm) and res.value == 0 + equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 if not equal: raise ValueError("{} and {} are not equal".format(a, b)) diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 9139cc2ce29f..3e45d4e5fd93 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -23,14 +23,14 @@ def test_domain_touched(): m = tvm.var('m') a = tvm.placeholder((n, m), name = 'a') b = tvm.placeholder((n, m), name = 'b') - ir = tvm.make.For( + ir = tvm.tir.For( i, 0, n, 0, 0, - tvm.make.For(j, 0, m, 0, 0, - tvm.make.Provide( + tvm.tir.For(j, 0, m, 0, 0, + tvm.tir.Provide( a.op, 0, - tvm.make.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) + - tvm.make.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0), + tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) + + tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0), [i, j] ) ) @@ -51,7 +51,7 @@ def test_domain_touched(): assert a_domain_rw[0].min.value == -1 assert a_domain_rw[0].extent.value == 101 assert a_domain_rw[1].min.value == -1 - assert isinstance(a_domain_rw[1].extent, tvm.expr.Add) + assert isinstance(a_domain_rw[1].extent, tvm.tir.Add) assert a_domain_rw[1].extent.a.name == 'm' assert a_domain_rw[1].extent.b.value == 1 diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 20e3f573776e..d83d33db5c1b 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -41,7 +41,7 @@ def test_vector(): base = 10 stride = 3 lanes = 2 - s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes)) + s = tvm.arith.intset_vector(tvm.tir.Ramp(base, stride, lanes)) assert s.min_value.value == base assert s.max_value.value == base + stride * lanes - 1 @@ -99,7 +99,7 @@ def test_max_min(): def test_select(): ck = IntSetChecker() x, y = tvm.var("x"), tvm.var("y") - ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1), + ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x : 
tvm.arith.IntervalSet(0, 10)}, (-1, 11))
diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py
index 1ce7197706ca..6bb86e4c4717 100644
--- a/tests/python/unittest/test_arith_modular_set.py
+++ b/tests/python/unittest/test_arith_modular_set.py
@@ -84,7 +84,7 @@ def test_min_max_select():
     assert m.coeff == 3
     assert m.base == 1

-    m = analyzer.modular_set(tvm.expr.Select(x > 0, x * 3 + 1, y * 9 + 2))
+    m = analyzer.modular_set(tvm.tir.Select(x > 0, x * 3 + 1, y * 9 + 2))
     assert m.coeff == 1
     assert m.base == 0

diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index 99c2942cd470..84560e8c1f9d 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -29,31 +29,31 @@ def test_vector_simplify():
     ck = RewriteChecker()
     x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z")
     # Add rules
-    ck.verify(tvm.expr.Ramp(x, 1, 4) + tvm.expr.Ramp(y, 2, 4),
-              tvm.expr.Ramp(x + y, 3, 4))
-    ck.verify(tvm.expr.Ramp(x, 1, 2) + y,
-              tvm.expr.Ramp(x + y, 1, 2))
-    ck.verify(y + tvm.expr.Ramp(x, 1, 2) ,
-              tvm.expr.Ramp(y + x, 1, 2))
+    ck.verify(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4),
+              tvm.tir.Ramp(x + y, 3, 4))
+    ck.verify(tvm.tir.Ramp(x, 1, 2) + y,
+              tvm.tir.Ramp(x + y, 1, 2))
+    ck.verify(y + tvm.tir.Ramp(x, 1, 2) ,
+              tvm.tir.Ramp(y + x, 1, 2))
     ck.verify(y.astype("int32x2") + x.astype("int32x2"),
               (y + x).astype("int32x2"))
     # Sub rules
-    ck.verify(tvm.expr.Ramp(x, 4, 4) - tvm.expr.Ramp(y, 2, 4),
-              tvm.expr.Ramp(x - y, 2, 4))
-    ck.verify(tvm.expr.Ramp(x, 1, 2) - y,
-              tvm.expr.Ramp(x - y, 1, 2))
-    ck.verify(y - tvm.expr.Ramp(x, 1, 2) ,
-              tvm.expr.Ramp(y - x, -1, 2))
+    ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4),
+              tvm.tir.Ramp(x - y, 2, 4))
+    ck.verify(tvm.tir.Ramp(x, 1, 2) - y,
+              tvm.tir.Ramp(x - y, 1, 2))
+    ck.verify(y - tvm.tir.Ramp(x, 1, 2) ,
+              tvm.tir.Ramp(y - x, -1, 2))
     ck.verify(y.astype("int32x2") - x.astype("int32x2"),
               (y - x).astype("int32x2"))
     # Mul rules
     ck.verify(y.astype("int32x2") * x.astype("int32x2"),
               (y * x).astype("int32x2"))
-    ck.verify(tvm.expr.Ramp(x, 4, 4) * 2,
-              tvm.expr.Ramp(x * 2, 8, 4))
-    ck.verify(2 * tvm.expr.Ramp(x, 4, 4),
-              tvm.expr.Ramp(x * 2, 8, 4))
+    ck.verify(tvm.tir.Ramp(x, 4, 4) * 2,
+              tvm.tir.Ramp(x * 2, 8, 4))
+    ck.verify(2 * tvm.tir.Ramp(x, 4, 4),
+              tvm.tir.Ramp(x * 2, 8, 4))

     ## DivMod rules
     tdiv = tvm.truncdiv
@@ -61,21 +61,21 @@ def test_vector_simplify():
     # trunc div
     ck.verify(tdiv(y.astype("int32x2"), x.astype("int32x2")),
               tdiv(y, x).astype("int32x2"))
-    ck.verify(tdiv(tvm.expr.Ramp(x, 4, 4), 2),
-              tvm.expr.Ramp(tdiv(x, 2), 2, 4))
+    ck.verify(tdiv(tvm.tir.Ramp(x, 4, 4), 2),
+              tvm.tir.Ramp(tdiv(x, 2), 2, 4))
     ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000), override=True)
-    ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
+    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8),
               (x).astype("int32x4"))
-    ck.verify(tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8),
-              tdiv(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8))
+    ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
+              tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
     ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")),
               tmod(y, x).astype("int32x2"))
-    ck.verify(tmod(tvm.expr.Ramp(x, 4, 4), 2),
-              tvm.expr.Broadcast(tmod(x, 2), 4))
-    ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8),
-              tvm.expr.Ramp(1, 1, 4))
-    ck.verify(tmod(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8),
-              tmod(tvm.expr.Ramp(1, 15, 4), 8))
+    ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2),
tvm.tir.Broadcast(tmod(x, 2), 4)) + ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), + tvm.tir.Ramp(1, 1, 4)) + ck.verify(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), + tmod(tvm.tir.Ramp(1, 15, 4), 8)) # floor div fld = tvm.floordiv @@ -83,20 +83,20 @@ def test_vector_simplify(): ck.analyzer.update(x, tvm.arith.ConstIntBound(-10, 1000), override=True) ck.verify(fld(y.astype("int32x2"), x.astype("int32x2")), fld(y, x).astype("int32x2")) - ck.verify(fld(tvm.expr.Ramp(x, 4, 4), 2), - tvm.expr.Ramp(fld(x, 2), 2, 4)) - ck.verify(fld(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8), + ck.verify(fld(tvm.tir.Ramp(x, 4, 4), 2), + tvm.tir.Ramp(fld(x, 2), 2, 4)) + ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")) - ck.verify(fld(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8), - fld(tvm.expr.Ramp(x * 8 + 15, 1, 4), 8)) + ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), + fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)) ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")) - ck.verify(flm(tvm.expr.Ramp(x, 4, 4), 2), - tvm.expr.Broadcast(flm(x, 2), 4)) - ck.verify(flm(tvm.expr.Ramp(x * 8 + 1, 1, 4), 8), - tvm.expr.Ramp(1, 1, 4)) - ck.verify(flm(tvm.expr.Ramp(x * 8 + 1, 15, 4), 8), - flm(tvm.expr.Ramp(1, 15, 4), 8)) + ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), + tvm.tir.Broadcast(flm(x, 2), 4)) + ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), + tvm.tir.Ramp(1, 1, 4)) + ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), + flm(tvm.tir.Ramp(1, 15, 4), 8)) # Min/Max rules vx = tvm.var("vx", dtype="int32x2") @@ -113,8 +113,8 @@ def test_vector_simplify(): ## Logical rules ck.verify(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("uint1x2")) - ck.verify(tvm.expr.NE(y.astype("int32x2"), (x.astype("int32x2"))), - (tvm.expr.NE(y, x)).astype("uint1x2")) + ck.verify(tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))), + (tvm.tir.NE(y, x)).astype("uint1x2")) ck.verify(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("uint1x2")) ck.verify(y.astype("int32x2") >= x.astype("int32x2"), @@ -123,32 +123,32 @@ def test_vector_simplify(): (y < x).astype("uint1x2")) ck.verify(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("uint1x2")) - ck.verify(tvm.expr.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")), - (tvm.expr.And(y <= x, vc)).astype("uint1x2")) - ck.verify(tvm.expr.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")), - (tvm.expr.Or(y <= x, vc)).astype("uint1x2")) + ck.verify(tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")), + (tvm.tir.And(y <= x, vc)).astype("uint1x2")) + ck.verify(tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("uint1x2")), + (tvm.tir.Or(y <= x, vc)).astype("uint1x2")) def test_select_simplify(): ck = RewriteChecker() x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") # Add rules - ck.verify(tvm.expr.Select(x < 0, y, 0) + tvm.expr.Select(x < 0, 1, z), - tvm.expr.Select(x < 0, y + 1, z)) - ck.verify(tvm.expr.Select(x < 0, y, 1) - tvm.expr.Select(x < 0, 1, z), - tvm.expr.Select(x < 0, y + (-1), 1 - z)) - ck.verify(tvm.expr.Select(x < 0, y, z) - y, - tvm.expr.Select(x < 0, 0, z - y)) - ck.verify(tvm.expr.Select(x < 0, y, z) - z, - tvm.expr.Select(x < 0, y - z, 0)) - ck.verify(tvm.min(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)), - tvm.expr.Select(x < 0, tvm.min(y, 1), tvm.min(0, z))) - ck.verify(tvm.max(tvm.expr.Select(x < 0, y, 0), tvm.expr.Select(x < 0, 1, z)), - tvm.expr.Select(x < 0, tvm.max(y, 1), tvm.max(0, z))) - - 
ck.verify(tvm.expr.Select(x * 3 + 1 != 0, y, z), y) - ck.verify(tvm.expr.Select(x * 3 + 1 == 0, y, z), z) - ck.verify(tvm.expr.Select(x > 0, y + 1, y + 1), y + 1) + ck.verify(tvm.tir.Select(x < 0, y, 0) + tvm.tir.Select(x < 0, 1, z), + tvm.tir.Select(x < 0, y + 1, z)) + ck.verify(tvm.tir.Select(x < 0, y, 1) - tvm.tir.Select(x < 0, 1, z), + tvm.tir.Select(x < 0, y + (-1), 1 - z)) + ck.verify(tvm.tir.Select(x < 0, y, z) - y, + tvm.tir.Select(x < 0, 0, z - y)) + ck.verify(tvm.tir.Select(x < 0, y, z) - z, + tvm.tir.Select(x < 0, y - z, 0)) + ck.verify(tvm.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)), + tvm.tir.Select(x < 0, tvm.min(y, 1), tvm.min(0, z))) + ck.verify(tvm.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)), + tvm.tir.Select(x < 0, tvm.max(y, 1), tvm.max(0, z))) + + ck.verify(tvm.tir.Select(x * 3 + 1 != 0, y, z), y) + ck.verify(tvm.tir.Select(x * 3 + 1 == 0, y, z), z) + ck.verify(tvm.tir.Select(x > 0, y + 1, y + 1), y + 1) def test_add_index_simplify(): @@ -633,7 +633,7 @@ def test_cmp_simplify(): tmod = tvm.truncmod # const int bound ck.verify((tmod(x, 2) + 10).equal(0), tvm.const(0, "bool")) - ck.verify(tvm.expr.NE(tmod(x, 2) + 10, 0), tvm.const(1, "bool")) + ck.verify(tvm.tir.NE(tmod(x, 2) + 10, 0), tvm.const(1, "bool")) ck.verify(tmod(x, 2) + 10 > 1, tvm.const(1, "bool")) ck.verify(tmod(x, 2) + 10 <= 1, tvm.const(0, "bool")) ck.verify(flm(x, 2) + 2 > 1, tvm.const(1, "bool")) @@ -645,7 +645,7 @@ def test_cmp_simplify(): # canonicalization ck.verify((x - 10).equal(0), x.equal(10)) ck.verify((10 - x).equal(0), x.equal(10)) - ck.verify((x * y).equal(0), tvm.expr.Or(x.equal(0), y.equal(0))) + ck.verify((x * y).equal(0), tvm.tir.Or(x.equal(0), y.equal(0))) # cmp bound ck.verify(x + y < x + z, y < z) @@ -655,104 +655,104 @@ def test_cmp_simplify(): ck.verify(y - x < z - x, y < z) ck.verify(x - y < x - z, z < y) - ck.verify(x < z + x, tvm.expr.LT(0, z)) - ck.verify(x < x + z, tvm.expr.LT(0, z)) + ck.verify(x < z + x, tvm.tir.LT(0, z)) + ck.verify(x < x + z, tvm.tir.LT(0, z)) - ck.verify(100 < x + 1, tvm.expr.LT(99, x)) - ck.verify(1 < 100 - x, tvm.expr.LT(x, 99)) + ck.verify(100 < x + 1, tvm.tir.LT(99, x)) + ck.verify(1 < 100 - x, tvm.tir.LT(x, 99)) ck.verify(x * 3 < y * 3, x < y) ck.verify(x * (-3) < y * (-3), y < x) ck.verify(x * 3 >= y * 3, y <= x) - ck.verify(x * 4 >= 2, tvm.expr.LE(1, x)) - ck.verify(x * 2 >= 50, tvm.expr.LE(25, x)) + ck.verify(x * 4 >= 2, tvm.tir.LE(1, x)) + ck.verify(x * 2 >= 50, tvm.tir.LE(25, x)) ck.verify(x * 4 <= 2, x <= 0) - ck.verify((0 - x * 3) <= 0, tvm.expr.LE(0, x)) - ck.verify((0 - x * 3) >= 0, tvm.expr.LE(x, 0)) + ck.verify((0 - x * 3) <= 0, tvm.tir.LE(0, x)) + ck.verify((0 - x * 3) >= 0, tvm.tir.LE(x, 0)) ck.verify(2 * x <= 0, x <= 0) - ck.verify(x * 2 >= 3, tvm.expr.LE(2, x)) - ck.verify(x * 2 >= 2, tvm.expr.LE(1, x)) - ck.verify(x * 2 >= 1, tvm.expr.LE(1, x)) - ck.verify(x * 2 >= 0, tvm.expr.LE(0, x)) - ck.verify(x * 2 >= -1, tvm.expr.LE(0, x)) - ck.verify(x * 2 >= -2, tvm.expr.LE(-1, x)) - ck.verify(x * 2 >= -3, tvm.expr.LE(-1, x)) - - ck.verify(x * 2 <= 3, tvm.expr.LE(x, 1)) - ck.verify(x * 2 <= 2, tvm.expr.LE(x, 1)) - ck.verify(x * 2 <= 1, tvm.expr.LE(x, 0)) - ck.verify(x * 2 <= 0, tvm.expr.LE(x, 0)) - ck.verify(x * 2 <= -1, tvm.expr.LE(x, -1)) - ck.verify(x * 2 <= -2, tvm.expr.LE(x, -1)) - ck.verify(x * 2 <= -3, tvm.expr.LE(x, -2)) - - ck.verify(x * (-2) >= 3, tvm.expr.LE(x, -2)) - ck.verify(x * (-2) >= 2, tvm.expr.LE(x, -1)) - ck.verify(x * (-2) >= 1, tvm.expr.LE(x, -1)) - ck.verify(x * (-2) >= 0, tvm.expr.LE(x, 
0))
-    ck.verify(x * (-2) >= -1, tvm.expr.LE(x, 0))
-    ck.verify(x * (-2) >= -2, tvm.expr.LE(x, 1))
-    ck.verify(x * (-2) >= -3, tvm.expr.LE(x, 1))
-
-    ck.verify(x * (-2) <= 3, tvm.expr.LE(-1, x))
-    ck.verify(x * (-2) <= 2, tvm.expr.LE(-1, x))
-    ck.verify(x * (-2) <= 1, tvm.expr.LE(0, x))
-    ck.verify(x * (-2) <= 0, tvm.expr.LE(0, x))
-    ck.verify(x * (-2) <= -1, tvm.expr.LE(1, x))
-    ck.verify(x * (-2) <= -2, tvm.expr.LE(1, x))
-    ck.verify(x * (-2) <= -3, tvm.expr.LE(2, x))
+    ck.verify(x * 2 >= 3, tvm.tir.LE(2, x))
+    ck.verify(x * 2 >= 2, tvm.tir.LE(1, x))
+    ck.verify(x * 2 >= 1, tvm.tir.LE(1, x))
+    ck.verify(x * 2 >= 0, tvm.tir.LE(0, x))
+    ck.verify(x * 2 >= -1, tvm.tir.LE(0, x))
+    ck.verify(x * 2 >= -2, tvm.tir.LE(-1, x))
+    ck.verify(x * 2 >= -3, tvm.tir.LE(-1, x))
+
+    ck.verify(x * 2 <= 3, tvm.tir.LE(x, 1))
+    ck.verify(x * 2 <= 2, tvm.tir.LE(x, 1))
+    ck.verify(x * 2 <= 1, tvm.tir.LE(x, 0))
+    ck.verify(x * 2 <= 0, tvm.tir.LE(x, 0))
+    ck.verify(x * 2 <= -1, tvm.tir.LE(x, -1))
+    ck.verify(x * 2 <= -2, tvm.tir.LE(x, -1))
+    ck.verify(x * 2 <= -3, tvm.tir.LE(x, -2))
+
+    ck.verify(x * (-2) >= 3, tvm.tir.LE(x, -2))
+    ck.verify(x * (-2) >= 2, tvm.tir.LE(x, -1))
+    ck.verify(x * (-2) >= 1, tvm.tir.LE(x, -1))
+    ck.verify(x * (-2) >= 0, tvm.tir.LE(x, 0))
+    ck.verify(x * (-2) >= -1, tvm.tir.LE(x, 0))
+    ck.verify(x * (-2) >= -2, tvm.tir.LE(x, 1))
+    ck.verify(x * (-2) >= -3, tvm.tir.LE(x, 1))
+
+    ck.verify(x * (-2) <= 3, tvm.tir.LE(-1, x))
+    ck.verify(x * (-2) <= 2, tvm.tir.LE(-1, x))
+    ck.verify(x * (-2) <= 1, tvm.tir.LE(0, x))
+    ck.verify(x * (-2) <= 0, tvm.tir.LE(0, x))
+    ck.verify(x * (-2) <= -1, tvm.tir.LE(1, x))
+    ck.verify(x * (-2) <= -2, tvm.tir.LE(1, x))
+    ck.verify(x * (-2) <= -3, tvm.tir.LE(2, x))

     # DivMod rules
     # trunc div
     ck.verify(tdiv(x, 2) < 3, x < 6)
-    ck.verify(3 < tdiv(x, 2), tvm.expr.LT(7, x))
-    ck.verify(tdiv(x, 3) >= 0, tvm.expr.LE(-2, x))
-    ck.verify(tdiv(x, 2) >= 1, tvm.expr.LE(2, x))
-    ck.verify(tdiv(x, 2) >= 0, tvm.expr.LE(-1, x))
-    ck.verify(tdiv(x, 2) >= -1, tvm.expr.LE(-3, x))
+    ck.verify(3 < tdiv(x, 2), tvm.tir.LT(7, x))
+    ck.verify(tdiv(x, 3) >= 0, tvm.tir.LE(-2, x))
+    ck.verify(tdiv(x, 2) >= 1, tvm.tir.LE(2, x))
+    ck.verify(tdiv(x, 2) >= 0, tvm.tir.LE(-1, x))
+    ck.verify(tdiv(x, 2) >= -1, tvm.tir.LE(-3, x))

-    ck.verify(tdiv(x, 2) <= 1, tvm.expr.LE(x, 3))
-    ck.verify(tdiv(x, 2) <= 0, tvm.expr.LE(x, 1))
-    ck.verify(tdiv(x, 2) <= -1, tvm.expr.LE(x, -2))
+    ck.verify(tdiv(x, 2) <= 1, tvm.tir.LE(x, 3))
+    ck.verify(tdiv(x, 2) <= 0, tvm.tir.LE(x, 1))
+    ck.verify(tdiv(x, 2) <= -1, tvm.tir.LE(x, -2))

-    ck.verify(tdiv(x, 4) * 4 < x, tvm.expr.LT(0, tmod(x, 4)))
-    ck.verify(tdiv(x, 4) * 4 >= x, tvm.expr.LE(tmod(x, 4), 0))
+    ck.verify(tdiv(x, 4) * 4 < x, tvm.tir.LT(0, tmod(x, 4)))
+    ck.verify(tdiv(x, 4) * 4 >= x, tvm.tir.LE(tmod(x, 4), 0))

-    ck.verify(tdiv(x, 4) * 4 < x + y, tvm.expr.LT(0, tmod(x, 4) + y))
-    ck.verify(tdiv(x, 4) * 4 < x - y, tvm.expr.LT(y, tmod(x, 4)))
+    ck.verify(tdiv(x, 4) * 4 < x + y, tvm.tir.LT(0, tmod(x, 4) + y))
+    ck.verify(tdiv(x, 4) * 4 < x - y, tvm.tir.LT(y, tmod(x, 4)))

-    ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.expr.LE(tmod(x + 2, 4), 2))
-    ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.expr.LE(tmod(x + 2, 4) + y, 2))
-    ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.expr.LE(tmod(x + 2, 4) + (-2), y))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x, tvm.tir.LE(tmod(x + 2, 4), 2))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x + y, tvm.tir.LE(tmod(x + 2, 4) + y, 2))
+    ck.verify(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4) + (-2), y))

     # floor div
     ck.verify(fld(x, 2) < 3, x < 6)
-
ck.verify(3 < fld(x, 2), tvm.expr.LT(7, x)) - ck.verify(-3 < fld(x, 2), tvm.expr.LT(-5, x)) - ck.verify(fld(x, 3) >= 0, tvm.expr.LE(0, x)) - ck.verify(fld(x, 2) >= 1, tvm.expr.LE(2, x)) - ck.verify(fld(x, 2) >= 0, tvm.expr.LE(0, x)) - ck.verify(fld(x, 2) >= -1, tvm.expr.LE(-2, x)) - - ck.verify(fld(x, 2) <= 1, tvm.expr.LE(x, 3)) - ck.verify(fld(x, 2) <= 0, tvm.expr.LE(x, 1)) - ck.verify(fld(x, 2) <= -1, tvm.expr.LE(x, -1)) - - ck.verify(fld(x, 4) * 4 < x, tvm.expr.LT(0, flm(x, 4))) - ck.verify(fld(x, 4) * 4 >= x, tvm.expr.LE(flm(x, 4), 0)) - - ck.verify(fld(x, 4) * 4 < x + y, tvm.expr.LT(0, flm(x, 4) + y)) - ck.verify(fld(x, 4) * 4 < x - y, tvm.expr.LT(y, flm(x, 4))) - - ck.verify(fld(x + 2, 4) * 4 >= x, tvm.expr.LE(flm(x + 2, 4), 2)) - ck.verify(fld(x + 2, 4) * 4 >= x + y, tvm.expr.LE(flm(x + 2, 4) + y, 2)) - ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.expr.LE(flm(x + 2, 4) + (-2), y)) + ck.verify(3 < fld(x, 2), tvm.tir.LT(7, x)) + ck.verify(-3 < fld(x, 2), tvm.tir.LT(-5, x)) + ck.verify(fld(x, 3) >= 0, tvm.tir.LE(0, x)) + ck.verify(fld(x, 2) >= 1, tvm.tir.LE(2, x)) + ck.verify(fld(x, 2) >= 0, tvm.tir.LE(0, x)) + ck.verify(fld(x, 2) >= -1, tvm.tir.LE(-2, x)) + + ck.verify(fld(x, 2) <= 1, tvm.tir.LE(x, 3)) + ck.verify(fld(x, 2) <= 0, tvm.tir.LE(x, 1)) + ck.verify(fld(x, 2) <= -1, tvm.tir.LE(x, -1)) + + ck.verify(fld(x, 4) * 4 < x, tvm.tir.LT(0, flm(x, 4))) + ck.verify(fld(x, 4) * 4 >= x, tvm.tir.LE(flm(x, 4), 0)) + + ck.verify(fld(x, 4) * 4 < x + y, tvm.tir.LT(0, flm(x, 4) + y)) + ck.verify(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4))) + + ck.verify(fld(x + 2, 4) * 4 >= x, tvm.tir.LE(flm(x + 2, 4), 2)) + ck.verify(fld(x + 2, 4) * 4 >= x + y, tvm.tir.LE(flm(x + 2, 4) + y, 2)) + ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4) + (-2), y)) # End DivMod Rules ck.verify(tvm.min(x, 11) < 10, x < 10) ck.verify(tvm.min(x, 8) < 10, tvm.const(1, "bool")) - ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x)) + ck.verify(tvm.max(8, x) > 10, tvm.tir.LT(10, x)) ck.verify(x + 1 < tvm.max(8, x), x < 7) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True) @@ -777,48 +777,48 @@ def test_logical_simplify(): ck = RewriteChecker() x, y, z = tvm.var("x"), tvm.var("y"), tvm.var("z") - ck.verify(tvm.expr.And(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)), + ck.verify(tvm.tir.And(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)), + ck.verify(tvm.tir.And(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x > 1, tvm.expr.Not(x > 1)), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x <= y, y < x), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(y < x, x <= y), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x < 1, 0 < x), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x < 0, 1 < x), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x < 1, 1 <= x), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x <= 1, 1 < x), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(1 <= x, x < 1), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(1 < x, x <= 1), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x <= 1, 2 <= x), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(2 <= x, x <= 1), tvm.const(False, "bool")) - ck.verify(tvm.expr.And(x == 1, x != 2), x == 1) - - - ck.verify(tvm.expr.Or(tvm.expr.EQ(x, y), tvm.expr.NE(x, y)), + ck.verify(tvm.tir.And(x > 1, tvm.tir.Not(x > 1)), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(x <= y, y < x), tvm.const(False, "bool")) + 
ck.verify(tvm.tir.And(y < x, x <= y), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(x < 1, 0 < x), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(x < 0, 1 < x), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(x < 1, 1 <= x), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(x <= 1, 1 < x), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(1 <= x, x < 1), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(1 < x, x <= 1), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(x <= 1, 2 <= x), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(2 <= x, x <= 1), tvm.const(False, "bool")) + ck.verify(tvm.tir.And(x == 1, x != 2), x == 1) + + + ck.verify(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(tvm.expr.NE(x, y), tvm.expr.EQ(x, y)), + ck.verify(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(x > y, tvm.expr.Not(x > y)), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(x <= y, y < x), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(y < x, y >= x), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(x <= y, y < x), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(y < x, y >= x), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(x < 1, 0 < x), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(0 < x, x < 1), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(x < 1, 0 < x), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(0 < x, x < 1), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(x < 1, 1 <= x), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(x <= 1, 1 < x), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(1 <= x, x < 1), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(1 < x, x <= 1), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(x <= 1, 2 <= x), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(2 <= x, x <= 1), tvm.const(True, "bool")) - ck.verify(tvm.expr.Or(x != 1, x == 2), x != 1) + ck.verify(tvm.tir.Or(x < 1, 1 <= x), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(x <= 1, 1 < x), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(1 <= x, x < 1), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(1 < x, x <= 1), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(x <= 1, 2 <= x), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(2 <= x, x <= 1), tvm.const(True, "bool")) + ck.verify(tvm.tir.Or(x != 1, x == 2), x != 1) def test_let_simplify(): ck = RewriteChecker() x, y = tvm.var("x"), tvm.var("y") - z = tvm.expr.Let(x, 1, x + 1) + z = tvm.tir.Let(x, 1, x + 1) ck.verify(z + z, 4) def test_cast_simplify(): @@ -827,11 +827,11 @@ def test_cast_simplify(): dtypes = ["float32", "float16", "int32", "int8", "bool"] for dtype1 in dtypes: - ck.verify(tvm.expr.Cast(dtype1, x - x), tvm.const(0, dtype1)) - ck.verify(tvm.expr.Cast(dtype1, x == x), tvm.const(1, dtype1)) + ck.verify(tvm.tir.Cast(dtype1, x - x), tvm.const(0, dtype1)) + ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.const(1, dtype1)) for dtype2 in dtypes: for i in [0, 1, 2, 3]: - ck.verify(tvm.expr.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1)) + ck.verify(tvm.tir.Cast(dtype1, tvm.const(i, dtype2)), tvm.const(i, dtype1)) if __name__ == "__main__": test_floordiv_index_simplify() diff --git a/tests/python/unittest/test_arith_stmt_simplify.py b/tests/python/unittest/test_arith_stmt_simplify.py index 9e0b47749fee..58b60836539f 100644 --- a/tests/python/unittest/test_arith_stmt_simplify.py +++ b/tests/python/unittest/test_arith_stmt_simplify.py @@ -25,9 +25,9 @@ def 
test_stmt_simplify(): with ib.if_scope(i < 12): A[i] = C[i] - body = tvm.stmt.LetStmt(n, 10, ib.get()) + body = tvm.tir.LetStmt(n, 10, ib.get()) body = tvm.ir_pass.CanonicalSimplify(body) - assert isinstance(body.body, tvm.stmt.Store) + assert isinstance(body.body, tvm.tir.Store) def test_thread_extent_simplify(): @@ -42,9 +42,9 @@ def test_thread_extent_simplify(): ib.scope_attr(ty, "thread_extent", 1) with ib.if_scope(tx + ty < 12): A[tx] = C[tx + ty] - body = tvm.stmt.LetStmt(n, 10, ib.get()) + body = tvm.tir.LetStmt(n, 10, ib.get()) body = tvm.ir_pass.CanonicalSimplify(body) - assert isinstance(body.body.body.body, tvm.stmt.Store) + assert isinstance(body.body.body.body, tvm.tir.Store) def test_basic_likely_elimination(): diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py index bfeb6520e3ee..79b3544f46eb 100644 --- a/tests/python/unittest/test_codegen_cuda.py +++ b/tests/python/unittest/test_codegen_cuda.py @@ -185,19 +185,19 @@ def test_cuda_shuffle(): def my_vectorize(stmt): def vectorizer(op): - if op.for_type == tvm.stmt.For.Vectorized: + if op.for_type == tvm.tir.For.Vectorized: four = tvm.const(4, 'int32') - idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4) + idx = tvm.tir.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4) all_ones = tvm.const(1, 'int32x4') store = op.body value = store.value - new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones) + new_a = tvm.tir.Load('int32x4', value.a.buffer_var, idx, all_ones) bs, ids = [], [] for i in range(4): - bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32'))) + bs.append(tvm.tir.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32'))) ids.append(tvm.const(3 - i, 'int32')) - new_b = tvm.make.Shuffle(bs, ids) - return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones) + new_b = tvm.tir.Shuffle(bs, ids) + return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return None return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index c60f3816722c..ca3229389c27 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -29,9 +29,9 @@ def test_llvm_intrin(): tvm.call_pure_intrin("handle", "tvm_address_of", A[0]), 0, 3, 1 ] - ib.emit(tvm.make.Evaluate( - tvm.make.Call( - "int32", "prefetch", args, tvm.expr.Call.Intrinsic, None, 0))) + ib.emit(tvm.tir.Evaluate( + tvm.tir.Call( + "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0))) body = ib.get() func = tvm.ir_pass.MakeAPI(body, "prefetch", [A], 0, True) fcode = tvm.build(func, None, "llvm") @@ -643,14 +643,14 @@ def my_vectorize(stmt): def vectorizer(op): store = op.body - idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8) + idx = tvm.tir.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8) all_ones = tvm.const(1, 'int32x8') value = store.value - b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)]) - new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones) - new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones) + b_idx = tvm.tir.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)]) + new_a = tvm.tir.Load('int32x8', value.a.buffer_var, idx, all_ones) + new_b = tvm.tir.Load('int32x8', value.b.buffer_var, b_idx, all_ones) value = new_a + new_b - return tvm.make.Store(store.buffer_var, 
new_a + new_b, idx, all_ones) + return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) diff --git a/tests/python/unittest/test_codegen_opencl.py b/tests/python/unittest/test_codegen_opencl.py index cf89608c1302..3b9b4a73c52d 100644 --- a/tests/python/unittest/test_codegen_opencl.py +++ b/tests/python/unittest/test_codegen_opencl.py @@ -40,7 +40,7 @@ def check_select(ctx, n, dtype): true_value = tvm.const(1, dtype=dtype) false_value = tvm.const(3, dtype=dtype) max_lhs = tvm.const(2, dtype=dtype) - max_rhs = tvm.expr.Select(A[0] > 0, true_value, false_value) + max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value) C = tvm.compute((n,), lambda i: tvm.max(max_lhs, max_rhs), name='C') s = tvm.create_schedule(C.op) s[C].bind(s[C].op.axis[0], tvm.thread_axis("threadIdx.x")) diff --git a/tests/python/unittest/test_codegen_static_init.py b/tests/python/unittest/test_codegen_static_init.py index 3bfe01319a3a..4d71cb3929a7 100644 --- a/tests/python/unittest/test_codegen_static_init.py +++ b/tests/python/unittest/test_codegen_static_init.py @@ -26,7 +26,7 @@ def test_static_callback(): ib = tvm.ir_builder.create() A = ib.buffer_ptr(Ab) cp = tvm.thread_axis((0, 1), "cop") - finit = tvm.make.StringImm("TVMBackendRunOnce") + finit = tvm.tir.StringImm("TVMBackendRunOnce") ib.scope_attr(cp, "coproc_uop_scope", finit) with ib.for_range(0, n, "i", for_type="parallel") as i: A[i] = A[i] + 1 diff --git a/tests/python/unittest/test_codegen_vm_basic.py b/tests/python/unittest/test_codegen_vm_basic.py index d477983b0979..7f08c75366e6 100644 --- a/tests/python/unittest/test_codegen_vm_basic.py +++ b/tests/python/unittest/test_codegen_vm_basic.py @@ -34,7 +34,7 @@ def tvm_call_back_get_shape(shape0): n = tvm.size_var('n') Ab = tvm.decl_buffer((n, ), tvm.float32) - stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0])) + stmt = tvm.tir.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0])) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm") @@ -75,7 +75,7 @@ def test_stack_vm_cond(): ib = tvm.ir_builder.create() A = ib.buffer_ptr(Ab) with ib.for_range(0, n - 1, "i") as i: - with ib.if_scope(tvm.make.EQ(i, 4)): + with ib.if_scope(tvm.tir.EQ(i, 4)): A[i + 1] = A[i] + 1 with ib.else_scope(): A[i + 1] = A[i] + 2 diff --git a/tests/python/unittest/test_codegen_vulkan.py b/tests/python/unittest/test_codegen_vulkan.py index d9e3c4399675..d480a0f6ead8 100644 --- a/tests/python/unittest/test_codegen_vulkan.py +++ b/tests/python/unittest/test_codegen_vulkan.py @@ -31,7 +31,7 @@ def check_correct_assembly(dtype): A = tvm.placeholder(n, dtype=dtype, name='A') B = tvm.compute( A.shape, - lambda i: tvm.expr.Select( + lambda i: tvm.tir.Select( A[i] >= 0, A[i] + tvm.const(1, dtype), tvm.const(0, dtype)), name='B') s = tvm.create_schedule(B.op) diff --git a/tests/python/unittest/test_custom_datatypes_mybfloat16.py b/tests/python/unittest/test_custom_datatypes_mybfloat16.py index 00f9b3329835..cae481353d6b 100644 --- a/tests/python/unittest/test_custom_datatypes_mybfloat16.py +++ b/tests/python/unittest/test_custom_datatypes_mybfloat16.py @@ -18,7 +18,7 @@ import tvm from ctypes import * import topi -import tvm.ir_pass as ir_pass +import tvm.tir.ir_pass as ir_pass import numpy as np tgt = "llvm" @@ -126,7 +126,7 @@ def test_bfloat_add_and_cast_FloatImm(): Z = topi.cast( topi.add( topi.cast(X, 
dtype="custom[bfloat]16"), - tvm.expr.FloatImm("custom[bfloat]16", 1.5)), + tvm.tir.FloatImm("custom[bfloat]16", 1.5)), dtype="float") s = tvm.create_schedule([Z.op]) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 87e5a26cd443..311dae803dba 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): val = tvm.ir_pass.Substitute(val, var_dict) val = tvm.ir_pass.Simplify(val) - assert isinstance(val, (tvm.expr.IntImm,)) + assert isinstance(val, (tvm.tir.IntImm,)) return val.value ctx = tvm.context(target, 0) @@ -46,14 +46,14 @@ def tvm_val_2_py_val(val): shape = [tvm_val_2_py_val(j) for j in i.shape] emu_args.append(numpy.random.randn(*shape).astype(i.dtype)) nd_args.append(tvm.nd.array(emu_args[-1], ctx)) - elif isinstance(i, tvm.expr.Var): + elif isinstance(i, tvm.tir.Var): emu_args.append(tvm_val_2_py_val(i)) nd_args.append(emu_args[-1]) else: assert isinstance(i, list) emu_args.append(numpy.array(i)) - compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \ + compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))] + \ (outs if isinstance(outs, list) else [outs]) module = tvm.build(sch, compile_args, @@ -76,7 +76,7 @@ def tvm_val_2_py_val(val): for nd, np in zip(out_tensors, ref_data): tvm.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) - module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + module_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.tir.Var))] module_outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs h_module = tvm.hybrid.build(sch, module_args, module_outs) @@ -111,32 +111,32 @@ def test_outer_product(): return #Check for i in (0, n) - assert isinstance(ir, tvm.stmt.For) + assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'i' assert ir.min.value == 0 assert ir.extent.name == 'n' ibody = ir.body - assert isinstance(ibody, tvm.stmt.For) + assert isinstance(ibody, tvm.tir.For) #Check for j in (0, m) assert ibody.loop_var.name == 'j' assert ibody.min.value == 0 assert ibody.extent.name == 'm' #Check loop body jblock = ibody.body - assert isinstance(jblock, tvm.stmt.SeqStmt) + assert isinstance(jblock, tvm.tir.SeqStmt) jbody = jblock[0] - assert isinstance(jbody, tvm.stmt.AssertStmt) - assert isinstance(jbody.message, tvm.expr.StringImm) + assert isinstance(jbody, tvm.tir.AssertStmt) + assert isinstance(jbody.message, tvm.tir.StringImm) assert jbody.message.value == "index out of range!" 
jbody = jblock[1] - assert isinstance(jbody, tvm.stmt.Provide) + assert isinstance(jbody, tvm.tir.Provide) assert jbody.func.name == 'c' assert len(jbody.args) == 2 assert jbody.args[0].name == 'i' assert jbody.args[1].name == 'j' - assert isinstance(jbody.value, tvm.expr.Mul) + assert isinstance(jbody.value, tvm.tir.Mul) mul = jbody.value - assert isinstance(mul.a, tvm.expr.Call) + assert isinstance(mul.a, tvm.tir.Call) assert mul.a.name == 'a' assert mul.b.name == 'b' @@ -177,21 +177,21 @@ def fanout(n, a): return #Check for i in (0, n-3) - assert isinstance(ir, tvm.stmt.For) + assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'i' assert ir.min.value == 0 assert tvm.ir_pass.Equal(ir.extent, n - 3) #Check loopbody ibody = ir.body - assert isinstance(ibody, tvm.stmt.AttrStmt) + assert isinstance(ibody, tvm.tir.AttrStmt) abody = ibody.body - assert isinstance(abody, tvm.stmt.Realize) + assert isinstance(abody, tvm.tir.Realize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 assert abody.func.name == 'sigma' #Check i loop body rbody = abody.body - assert isinstance(rbody[0], tvm.stmt.Provide) + assert isinstance(rbody[0], tvm.tir.Provide) assert rbody[0].func.name == 'sigma' assert len(rbody[0].args) == 1 assert rbody[0].args[0].value == 0 @@ -201,13 +201,13 @@ def fanout(n, a): assert jloop.min.value == 0 assert jloop.extent.value == 3 jbody = jloop.body - assert isinstance(jbody, tvm.stmt.Provide) + assert isinstance(jbody, tvm.tir.Provide) assert len(jbody.args) == 1 assert jbody.args[0].value == 0 assert jbody.func.name == 'sigma' - assert isinstance(jbody.value, tvm.expr.Add) + assert isinstance(jbody.value, tvm.tir.Add) value = jbody.value - assert isinstance(value.a, tvm.expr.Call) + assert isinstance(value.a, tvm.tir.Call) assert value.a.name == 'sigma' assert len(value.a.args) == 1 assert value.a.args[0].value == 0 @@ -215,17 +215,17 @@ def fanout(n, a): assert len(value.b.args) == 1 assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var) divide= rbody[2] - assert isinstance(divide, tvm.stmt.Provide) + assert isinstance(divide, tvm.tir.Provide) assert len(divide.args) == 1 assert divide.args[0].value == 0 value = divide.value - assert isinstance(value, tvm.expr.Mul) + assert isinstance(value, tvm.tir.Mul) assert value.a.name == 'sigma' assert len(value.a.args) == 1 assert value.a.args[0].value == 0 assert abs(value.b.value - (1 / 3.0)) < 1e-5 write = rbody[3] - assert isinstance(write, tvm.stmt.Provide) + assert isinstance(write, tvm.tir.Provide) assert write.func.name == 'b' assert write.value.name == 'sigma' assert len(write.value.args) == 1 @@ -260,9 +260,9 @@ def looptype(a, b, c): iloop = ir[0] jloop = ir[1] kloop = ir[2] - assert iloop.for_type == tvm.stmt.For.Parallel - assert jloop.for_type == tvm.stmt.For.Vectorized - assert kloop.for_type == tvm.stmt.For.Unrolled + assert iloop.for_type == tvm.tir.For.Parallel + assert jloop.for_type == tvm.tir.For.Vectorized + assert kloop.for_type == tvm.tir.For.Unrolled func, ins, outs = run_and_check(looptype, [a, b, c]) run_and_check(func, ins, outs=outs) @@ -364,7 +364,7 @@ def foo(a): c = foo(a) s = tvm.create_schedule(c.op) ir = tvm.lower(s, [a, c], simple_mode=True) - assert not isinstance(ir, tvm.stmt.AttrStmt) + assert not isinstance(ir, tvm.tir.AttrStmt) func, ins, outs = run_and_check(foo, [a], target='cuda') run_and_check(func, ins, outs=outs, target='cuda') @@ -729,20 +729,20 @@ def outer_product(a, b): sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) ir = 
tvm.lower(sch, [a, b, c], simple_mode=True) - assert isinstance(ir, tvm.stmt.ProducerConsumer) + assert isinstance(ir, tvm.tir.ProducerConsumer) ir = ir.body - assert isinstance(ir, tvm.stmt.AttrStmt) + assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body - assert isinstance(ir, tvm.stmt.For) + assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'i.inner' ir = ir.body - assert isinstance(ir, tvm.stmt.For) + assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'i.outer' ir = ir.body - assert isinstance(ir, tvm.stmt.For) + assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'j.outer.outer' ir = ir.body - assert isinstance(ir, tvm.stmt.For) + assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'j.outer.inner' ir = ir.body func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c]) @@ -752,11 +752,11 @@ def outer_product(a, b): sch = tvm.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) ir = tvm.lower(sch, [a, b, c], simple_mode=True) - assert isinstance(ir, tvm.stmt.ProducerConsumer) + assert isinstance(ir, tvm.tir.ProducerConsumer) ir = ir.body - assert isinstance(ir, tvm.stmt.AttrStmt) + assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body - assert isinstance(ir, tvm.stmt.For) + assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == 'i.j.fused' func, ins, outs = run_and_check(outer_product, [a, b], sch=sch, outs=[c]) run_and_check(func, ins, outs=outs) diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index 748662918edb..5679625e7799 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -28,14 +28,14 @@ def test_for(): body = ib.get() print(body) - assert isinstance(body, tvm.stmt.AttrStmt) + assert isinstance(body, tvm.tir.AttrStmt) body = body.body - assert isinstance(body, tvm.stmt.Allocate) + assert isinstance(body, tvm.tir.Allocate) body = body.body - assert isinstance(body, tvm.stmt.For) + assert isinstance(body, tvm.tir.For) body = body.body - assert isinstance(body, tvm.stmt.SeqStmt) - assert isinstance(body[1], tvm.stmt.For) + assert isinstance(body, tvm.tir.SeqStmt) + assert isinstance(body[1], tvm.tir.For) def test_if(): ib = tvm.ir_builder.create() @@ -50,11 +50,11 @@ def test_if(): body = ib.get() assert A == A - assert isinstance(body, tvm.stmt.For) + assert isinstance(body, tvm.tir.For) body = body.body - assert isinstance(body, tvm.stmt.IfThenElse) - assert isinstance(body.condition, tvm.expr.EQ) - assert isinstance(body.then_case.index, tvm.expr.Var) + assert isinstance(body, tvm.tir.IfThenElse) + assert isinstance(body.condition, tvm.tir.EQ) + assert isinstance(body.then_case.index, tvm.tir.Var) assert body.else_case.index.value == 0 def test_prefetch(): @@ -64,10 +64,10 @@ def test_prefetch(): with ib.for_range(0, n, name="i") as i: ib.emit( - tvm.make.Prefetch( + tvm.tir.Prefetch( A.op, A.value_index, A.dtype, - [tvm.make.range_by_min_extent(i+1, 2), - tvm.make.range_by_min_extent(0, 20)])) + [tvm.ir.Range.make_by_min_extent(i+1, 2), + tvm.ir.Range.make_by_min_extent(0, 20)])) body = ib.get() assert body.body.bounds[0].extent.value == 2 diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index 0015d6d2cd8d..733992595562 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -22,7 +22,7 @@ def test_const(): x = tvm.const(1, "int32") print(x.dtype) assert x.dtype == tvm.int32 - assert isinstance(x, tvm.expr.IntImm) 
+ assert isinstance(x, tvm.tir.IntImm) def test_scalar_dtype_inference(): @@ -45,47 +45,47 @@ def test_make(): x = tvm.const(1, "int32") y = tvm.var("x") z = x + y - assert isinstance(tvm.max(x, y), tvm.expr.Max) - assert isinstance(tvm.min(x, y), tvm.expr.Min) + assert isinstance(tvm.max(x, y), tvm.tir.Max) + assert isinstance(tvm.min(x, y), tvm.tir.Min) def test_ir(): x = tvm.const(1, "int32") - y = tvm.make.IntImm('int32', 1) + y = tvm.tir.IntImm('int32', 1) z = x + y - stmt = tvm.make.Evaluate(z) - assert isinstance(stmt, tvm.stmt.Evaluate) + stmt = tvm.tir.Evaluate(z) + assert isinstance(stmt, tvm.tir.Evaluate) def test_ir2(): x = tvm.var("n") a = tvm.var("array", tvm.handle) - st = tvm.make.Store(a, x + 1, 1) - assert isinstance(st, tvm.stmt.Store) + st = tvm.tir.Store(a, x + 1, 1) + assert isinstance(st, tvm.tir.Store) assert(st.buffer_var == a) def test_let(): x = tvm.var('x') y = tvm.var('y') - stmt = tvm.make.LetStmt( - x, 10, tvm.make.Evaluate(x + 1)); + stmt = tvm.tir.LetStmt( + x, 10, tvm.tir.Evaluate(x + 1)); def test_cast(): x = tvm.var('x', dtype="float32") y = x.astype("int32") z = x.astype("float32x4") - assert isinstance(y, tvm.expr.Cast) - assert isinstance(z, tvm.expr.Broadcast) + assert isinstance(y, tvm.tir.Cast) + assert isinstance(z, tvm.tir.Broadcast) assert z.lanes == 4 def test_attr(): x = tvm.var('x') y = tvm.var('y') - stmt = tvm.make.AttrStmt( - y, "stride", 10, tvm.make.Evaluate(x + 1)); + stmt = tvm.tir.AttrStmt( + y, "stride", 10, tvm.tir.Evaluate(x + 1)); assert stmt.node == y a = tvm.convert(1) @@ -105,9 +105,9 @@ def test_basic(): def test_stmt(): - x = tvm.make.Evaluate(0) - tvm.make.For(tvm.var('i'), 0, 1, - tvm.stmt.For.Serial, 0, + x = tvm.tir.Evaluate(0) + tvm.tir.For(tvm.var('i'), 0, 1, + tvm.tir.For.Serial, 0, x) @@ -207,7 +207,7 @@ def test_equality(): def test_equality_string_imm(): x = 'a' - y = tvm.make.StringImm(x) + y = tvm.tir.StringImm(x) x == y.value x == y diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py index c4187858a8a8..4ce7e872dc36 100644 --- a/tests/python/unittest/test_lang_constructor.py +++ b/tests/python/unittest/test_lang_constructor.py @@ -17,50 +17,50 @@ import tvm def test_expr_constructor(): - x = tvm.expr.Var("xx", "float32") - assert isinstance(x, tvm.expr.Var) + x = tvm.tir.Var("xx", "float32") + assert isinstance(x, tvm.tir.Var) assert x.name == "xx" - x = tvm.expr.Reduce(None, [1], + x = tvm.tir.Reduce(None, [1], [tvm.api._IterVar((0, 1), "x", 2)], None, 0) - assert isinstance(x, tvm.expr.Reduce) + assert isinstance(x, tvm.tir.Reduce) assert x.combiner == None assert x.value_index == 0 - x = tvm.expr.FloatImm("float32", 1.0) - assert isinstance(x, tvm.expr.FloatImm) + x = tvm.tir.FloatImm("float32", 1.0) + assert isinstance(x, tvm.tir.FloatImm) assert x.value == 1.0 assert x.dtype == "float32" - x = tvm.expr.IntImm("int64", 2) - assert isinstance(x, tvm.expr.IntImm) + x = tvm.tir.IntImm("int64", 2) + assert isinstance(x, tvm.tir.IntImm) assert x.value == 2 assert x.dtype == "int64" - x = tvm.expr.StringImm("xyza") - assert isinstance(x, tvm.expr.StringImm) + x = tvm.tir.StringImm("xyza") + assert isinstance(x, tvm.tir.StringImm) assert x.value == "xyza" - x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1)) - assert isinstance(x, tvm.expr.Cast) + x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1)) + assert isinstance(x, tvm.tir.Cast) assert x.dtype == "float32" assert x.value.value == 1 a = tvm.const(1.0, dtype="float32") b = tvm.var("x", 
dtype="float32") - for cls in [tvm.expr.Add, - tvm.expr.Sub, - tvm.expr.Mul, - tvm.expr.Div, - tvm.expr.Mod, - tvm.expr.Min, - tvm.expr.Max, - tvm.expr.LT, - tvm.expr.LE, - tvm.expr.GT, - tvm.expr.GE]: + for cls in [tvm.tir.Add, + tvm.tir.Sub, + tvm.tir.Mul, + tvm.tir.Div, + tvm.tir.Mod, + tvm.tir.Min, + tvm.tir.Max, + tvm.tir.LT, + tvm.tir.LE, + tvm.tir.GT, + tvm.tir.GE]: x = cls(a, b) assert isinstance(x, cls) assert x.a == a @@ -70,58 +70,58 @@ def test_expr_constructor(): a = tvm.convert(tvm.var("x") > 1) b = tvm.convert(tvm.var("x") == 1) - for cls in [tvm.expr.And, - tvm.expr.Or]: + for cls in [tvm.tir.And, + tvm.tir.Or]: x = cls(a, b) assert isinstance(x, cls) assert x.a == a assert x.b.same_as(b) - x = tvm.expr.Not(a) - assert isinstance(x, tvm.expr.Not) + x = tvm.tir.Not(a) + assert isinstance(x, tvm.tir.Not) assert x.a == a - x = tvm.expr.Select(a, a, b) - assert isinstance(x, tvm.expr.Select) + x = tvm.tir.Select(a, a, b) + assert isinstance(x, tvm.tir.Select) assert x.true_value == a assert x.false_value == b assert x.condition == a buffer_var = tvm.var("x", dtype="handle") - x = tvm.expr.Load("float32", buffer_var, 1, a) - assert isinstance(x, tvm.expr.Load) + x = tvm.tir.Load("float32", buffer_var, 1, a) + assert isinstance(x, tvm.tir.Load) assert x.dtype == "float32" assert x.buffer_var == buffer_var assert x.index.value == 1 assert x.predicate == a - x = tvm.expr.Ramp(1, 2, 10) - assert isinstance(x, tvm.expr.Ramp) + x = tvm.tir.Ramp(1, 2, 10) + assert isinstance(x, tvm.tir.Ramp) assert x.base.value == 1 assert x.stride.value == 2 assert x.lanes == 10 - x = tvm.expr.Broadcast(a, 10) - assert isinstance(x, tvm.expr.Broadcast) + x = tvm.tir.Broadcast(a, 10) + assert isinstance(x, tvm.tir.Broadcast) assert x.value == a assert x.lanes == 10 - x = tvm.expr.Shuffle([a], [0]) - assert isinstance(x, tvm.expr.Shuffle) + x = tvm.tir.Shuffle([a], [0]) + assert isinstance(x, tvm.tir.Shuffle) assert x.vectors[0] == a assert x.indices[0].value == 0 - x = tvm.expr.Call("float32", "xyz", [a], tvm.expr.Call.Extern, None, 0) - assert isinstance(x, tvm.expr.Call) + x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern, None, 0) + assert isinstance(x, tvm.tir.Call) assert x.dtype == "float32" assert x.name == "xyz" assert x.args[0] == a - assert x.call_type == tvm.expr.Call.Extern + assert x.call_type == tvm.tir.Call.Extern assert x.func == None assert x.value_index == 0 v = tvm.var("aa") - x = tvm.expr.Let(v, 1, v) + x = tvm.tir.Let(v, 1, v) assert x.var == v assert x.value.value == 1 assert x.body == v @@ -130,75 +130,75 @@ def test_expr_constructor(): def test_stmt_constructor(): v = tvm.var("aa") buffer_var = tvm.var("buf", dtype="handle") - nop = tvm.stmt.Evaluate(1) - x = tvm.stmt.LetStmt(v, 1, tvm.stmt.Evaluate(1)) - assert isinstance(x, tvm.stmt.LetStmt) + nop = tvm.tir.Evaluate(1) + x = tvm.tir.LetStmt(v, 1, tvm.tir.Evaluate(1)) + assert isinstance(x, tvm.tir.LetStmt) assert x.var == v assert x.value.value == 1 - assert isinstance(x.body, tvm.stmt.Evaluate) + assert isinstance(x.body, tvm.tir.Evaluate) - x = tvm.stmt.AttrStmt(v == 1, "xx", 1, tvm.stmt.Evaluate(1)) - assert isinstance(x, tvm.stmt.AttrStmt) + x = tvm.tir.AttrStmt(v == 1, "xx", 1, tvm.tir.Evaluate(1)) + assert isinstance(x, tvm.tir.AttrStmt) assert x.value.value == 1 - x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"), + x = tvm.tir.AssertStmt(tvm.const(1, "uint1"), tvm.convert("hellow"), nop) - assert isinstance(x, tvm.stmt.AssertStmt) + assert isinstance(x, tvm.tir.AssertStmt) assert x.body == nop - x = 
tvm.stmt.ProducerConsumer(None, True, nop) - assert isinstance(x, tvm.stmt.ProducerConsumer) + x = tvm.tir.ProducerConsumer(None, True, nop) + assert isinstance(x, tvm.tir.ProducerConsumer) assert x.body == nop - x = tvm.stmt.For(tvm.var("x"), 0, 10, 0, 0, nop) - assert isinstance(x, tvm.stmt.For) + x = tvm.tir.For(tvm.var("x"), 0, 10, 0, 0, nop) + assert isinstance(x, tvm.tir.For) assert x.min.value == 0 assert x.extent.value == 10 assert x.body == nop - x = tvm.stmt.Store(buffer_var, 1, 10, tvm.const(1, "uint1")) - assert isinstance(x, tvm.stmt.Store) + x = tvm.tir.Store(buffer_var, 1, 10, tvm.const(1, "uint1")) + assert isinstance(x, tvm.tir.Store) assert x.buffer_var == buffer_var assert x.index.value == 10 assert x.value.value == 1 tensor = tvm.placeholder((), dtype="float32") - x = tvm.stmt.Provide(tensor.op, 0, 10, []) - assert isinstance(x, tvm.stmt.Provide) + x = tvm.tir.Provide(tensor.op, 0, 10, []) + assert isinstance(x, tvm.tir.Provide) assert x.value_index == 0 assert x.value.value == 10 - x = tvm.stmt.Allocate(buffer_var, "float32", [10], + x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.const(1, "uint1"), nop) - assert isinstance(x, tvm.stmt.Allocate) + assert isinstance(x, tvm.tir.Allocate) assert x.dtype == "float32" assert x.buffer_var == buffer_var assert x.body == nop - x = tvm.stmt.AttrStmt(buffer_var, "xyz", 1, nop) - assert isinstance(x, tvm.stmt.AttrStmt) + x = tvm.tir.AttrStmt(buffer_var, "xyz", 1, nop) + assert isinstance(x, tvm.tir.AttrStmt) assert x.node == buffer_var assert x.attr_key == "xyz" assert x.body == nop - x = tvm.stmt.Free(buffer_var) - assert isinstance(x, tvm.stmt.Free) + x = tvm.tir.Free(buffer_var) + assert isinstance(x, tvm.tir.Free) assert x.buffer_var == buffer_var - x = tvm.stmt.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop) - assert isinstance(x, tvm.stmt.Realize) + x = tvm.tir.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop) + assert isinstance(x, tvm.tir.Realize) assert x.body == nop - x = tvm.stmt.IfThenElse(tvm.const(1, "uint1"), - tvm.stmt.Evaluate(11), + x = tvm.tir.IfThenElse(tvm.const(1, "uint1"), + tvm.tir.Evaluate(11), nop) - assert isinstance(x, tvm.stmt.IfThenElse) + assert isinstance(x, tvm.tir.IfThenElse) assert x.then_case.value.value == 11 assert x.else_case == nop - x = tvm.stmt.Prefetch(None, 1, "float32", []) - assert isinstance(x, tvm.stmt.Prefetch) + x = tvm.tir.Prefetch(None, 1, "float32", []) + assert isinstance(x, tvm.tir.Prefetch) assert x.value_index == 1 diff --git a/tests/python/unittest/test_lang_container.py b/tests/python/unittest/test_lang_container.py index 4f8a93b8fbd3..0b9fad9a2d20 100644 --- a/tests/python/unittest/test_lang_container.py +++ b/tests/python/unittest/test_lang_container.py @@ -69,7 +69,7 @@ def test_map_save_load_json(): def test_in_container(): arr = tvm.convert(['a', 'b', 'c']) assert 'a' in arr - assert tvm.make.StringImm('a') in arr + assert tvm.tir.StringImm('a') in arr assert 'd' not in arr def test_ndarray_container(): diff --git a/tests/python/unittest/test_lang_data_layout.py b/tests/python/unittest/test_lang_data_layout.py index cde4a813d89a..4c1cafcf3d67 100644 --- a/tests/python/unittest/test_lang_data_layout.py +++ b/tests/python/unittest/test_lang_data_layout.py @@ -20,9 +20,9 @@ from topi.util import get_const_tuple def test_layout(): - layout = tvm.layout("NCHW16c") + layout = tvm.tir.layout("NCHW16c") assert layout is not None - assert isinstance(layout, tvm.tensor.Layout) + assert isinstance(layout, tvm.tir.Layout) assert layout.factor_of("c") == 
16 assert layout.factor_of("C") == 16 @@ -63,7 +63,7 @@ def test_bilayout_convertible(): def test_bilayout_shape(): bilayout = tvm.bijective_layout("NCHW", "NCHW16c") - assert isinstance(bilayout, tvm.tensor.BijectiveLayout) + assert isinstance(bilayout, tvm.tir.BijectiveLayout) dst_shape = bilayout.forward_shape((1, 32, 7, 7)) assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16) diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index 26783e62db13..d32b4c51ef69 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -29,7 +29,7 @@ def test_const_fold(): def check(f, *args): x = f(*[tvm.const(x, "int32") for x in args]) y = f(*args) - if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y): + if not isinstance(x, (tvm.tir.IntImm,)) or x.value != int(y): raise ValueError("check error: %s vs %s " % (x, y)) tmod = tvm.truncmod @@ -56,7 +56,7 @@ def test_const_fold2(): assert tmod(x, 1).value == 0 assert (x * 1).same_as(x) assert (1 * x).same_as(x) - assert isinstance(tdiv(1, x), tvm.expr.Div) + assert isinstance(tdiv(1, x), tvm.tir.Div) def test_const_fold3(): # Test that using ints with logic operations is forbidden @@ -92,17 +92,17 @@ def test_const_fold4(): x1 = tvm.const(4, "int32") x2 = x1 + 5 tdiv = tvm.truncdiv - assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9 + assert isinstance(x2, tvm.tir.IntImm) and x2.value == 9 x3 = tdiv(x2, 3) - assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3 + assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3 x4 = x3 + 0.55 - assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6 + assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6 x5 = tvm.ceil(x4) - assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4 + assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4 x6 = x5.astype('int') - assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4, "x6={}".format(x6) + assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, "x6={}".format(x6) y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int') - assert isinstance(y, tvm.expr.IntImm) and y.value == 6 + assert isinstance(y, tvm.tir.IntImm) and y.value == 6 def test_binary_dtype_match(): diff --git a/tests/python/unittest/test_lang_reflection.py b/tests/python/unittest/test_lang_reflection.py index b971e386cfc7..e97e73a1d1cc 100644 --- a/tests/python/unittest/test_lang_reflection.py +++ b/tests/python/unittest/test_lang_reflection.py @@ -31,7 +31,7 @@ def test_make_smap(): # save load json x = tvm.const(1, "int32") y = tvm.const(10, "int32") - z = tvm.expr.Add(x, y) + z = tvm.tir.Add(x, y) smap = tvm.convert({"z": z, "x": x}) json_str = tvm.ir.save_json(tvm.convert([smap])) arr = tvm.ir.load_json(json_str) @@ -40,11 +40,11 @@ def test_make_smap(): def test_make_node(): - x = tvm.make.node("IntImm", dtype="int32", value=10) - assert isinstance(x, tvm.expr.IntImm) + x = tvm.ir.make_node("IntImm", dtype="int32", value=10) + assert isinstance(x, tvm.tir.IntImm) assert x.value == 10 A = tvm.placeholder((10, ), name='A') - AA = tvm.make.node("Tensor", + AA = tvm.ir.make_node("Tensor", shape=A.shape, dtype=A.dtype, op=A.op, @@ -55,25 +55,25 @@ def test_make_node(): def test_make_attrs(): try: - x = tvm.make.node("attrs.TestAttrs", unknown_key=1, name="xx") + x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx") assert False except tvm.error.TVMError as e: assert str(e).find("unknown_key") != -1 try: - x = 
tvm.make.node("attrs.TestAttrs", axis=100, name="xx") + x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx") assert False except tvm.error.TVMError as e: assert str(e).find("upper bound") != -1 - x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4)) + x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4)) assert x.name == "xx" assert x.padding[0].value == 3 assert x.padding[1].value == 4 assert x.axis == 10 - dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) + dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) assert dattr.x.value == 1 datrr = tvm.ir.load_json(tvm.ir.save_json(dattr)) assert dattr.name.value == "xyz" @@ -104,7 +104,7 @@ def test(x): assert y(1) == 2 assert y.func(1) == 2 - x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4), func=y) + x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4), func=y) assert x.name == "xx" assert x.padding[0].value == 3 assert x.padding[1].value == 4 diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index 6b5b7fa2be67..10843f993d06 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -240,7 +240,7 @@ def intrin_func(ins, outs, sp): C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C") s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, C], simple_mode=True) - assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.body.value.args) == 5 assert str(stmt.body.body.body.value.args[3]) == "(i*i)" assert str(stmt.body.body.body.value.args[4]) == "(i + j)" diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index a8b5fc094cca..2de5e19c9e36 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -128,7 +128,7 @@ def intrin_func(ins, outs): s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - assert isinstance(stmt.body.body, tvm.stmt.Evaluate) + assert isinstance(stmt.body.body, tvm.tir.Evaluate) def test_tensor_compute2(): M = 2048 @@ -171,8 +171,8 @@ def intrin_func(ins, outs): s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - assert isinstance(stmt.body.body.body[0], tvm.stmt.Evaluate) - assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body[0], tvm.tir.Evaluate) + assert isinstance(stmt.body.body.body[1].body, tvm.tir.Evaluate) def test_tensor_scan(): m = tvm.size_var("m") @@ -259,7 +259,7 @@ def test_tuple_with_different_deps(): stmt = tvm.schedule.ScheduleOps(sch, bounds) def get_B1_realize(x): - if isinstance(x, tvm.stmt.Realize) and \ + if isinstance(x, tvm.tir.Realize) and \ x.func == B1.op and x.value_index == 1: ret.append(x) ret = [] diff --git a/tests/python/unittest/test_lang_tensor_overload_op.py b/tests/python/unittest/test_lang_tensor_overload_op.py index d205b55a5156..e1a17c9203ef 100644 --- a/tests/python/unittest/test_lang_tensor_overload_op.py +++ b/tests/python/unittest/test_lang_tensor_overload_op.py @@ -29,8 +29,8 @@ def test_operator_type_and_tags(): B1 = B[0] B2 = B[0,0] - assert isinstance(k + n, tvm.expr.PrimExpr) - assert isinstance(n + n, tvm.expr.PrimExpr) + assert isinstance(k + n, tvm.tir.PrimExpr) + assert isinstance(n + n, tvm.tir.PrimExpr) assert isinstance(k + A, tvm.tensor.Tensor) assert isinstance(A 
+ k, tvm.tensor.Tensor) assert isinstance(n + A, tvm.tensor.Tensor) @@ -53,11 +53,11 @@ def test_operator_type_and_tags(): assert (B + A).op.tag == topi.tag.BROADCAST assert (B + B).op.tag == topi.tag.BROADCAST - assert isinstance(k + B2, tvm.expr.PrimExpr) - assert isinstance(B2 + k, tvm.expr.PrimExpr) - assert isinstance(n + B2, tvm.expr.PrimExpr) - assert isinstance(B2 + n, tvm.expr.PrimExpr) - assert isinstance(B2 + B2, tvm.expr.PrimExpr) + assert isinstance(k + B2, tvm.tir.PrimExpr) + assert isinstance(B2 + k, tvm.tir.PrimExpr) + assert isinstance(n + B2, tvm.tir.PrimExpr) + assert isinstance(B2 + n, tvm.tir.PrimExpr) + assert isinstance(B2 + B2, tvm.tir.PrimExpr) assert isinstance(B2 + A, tvm.tensor.Tensor) assert isinstance(A + B2, tvm.tensor.Tensor) assert isinstance(B2 + B, tvm.tensor.Tensor) diff --git a/tests/python/unittest/test_pass_attrs_hash_equal.py b/tests/python/unittest/test_pass_attrs_hash_equal.py index bb4c196ddc71..2bd94e0d5cab 100644 --- a/tests/python/unittest/test_pass_attrs_hash_equal.py +++ b/tests/python/unittest/test_pass_attrs_hash_equal.py @@ -17,15 +17,15 @@ import tvm def test_attrs_equal(): - x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4)) - y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4)) - z = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4,1)) + x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) + y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) + z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1)) assert tvm.ir_pass.AttrsEqual(x, y) assert not tvm.ir_pass.AttrsEqual(x, z) - dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) + dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) assert not tvm.ir_pass.AttrsEqual(dattr, x) - dattr2 = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) + dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) assert tvm.ir_pass.AttrsEqual(dattr, dattr2) assert tvm.ir_pass.AttrsEqual({"x": x}, {"x": y}) @@ -42,8 +42,8 @@ def test_attrs_equal(): def test_attrs_hash(): fhash = tvm.ir_pass.AttrsHash - x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4)) - y = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3, 4)) + x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) + y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) assert fhash({"x": x}) == fhash({"x": y}) assert fhash({"x": x}) != fhash({"x": [y, 1]}) assert fhash({"x": [x, 1]}) == fhash({"x": [y, 1]}) diff --git a/tests/python/unittest/test_pass_basic.py b/tests/python/unittest/test_pass_basic.py index 8f35611c8c8b..93c815a4a21b 100644 --- a/tests/python/unittest/test_pass_basic.py +++ b/tests/python/unittest/test_pass_basic.py @@ -31,16 +31,16 @@ def test_simplify(): def test_verify_ssa(): x = tvm.var('x') y = tvm.var() - z = tvm.make.Evaluate(x + y) + z = tvm.tir.Evaluate(x + y) assert(tvm.ir_pass.VerifySSA(z)) def test_convert_ssa(): x = tvm.var('x') y = tvm.var() - let1 = tvm.make.Let(x, 1, x + 1) - let2 = tvm.make.Let(x, 1, x + y) - z = tvm.make.Evaluate(let1 + let2) + let1 = tvm.tir.Let(x, 1, x + 1) + let2 = tvm.tir.Let(x, 1, x + y) + z = tvm.tir.Evaluate(let1 + let2) assert(not tvm.ir_pass.VerifySSA(z)) z_ssa = tvm.ir_pass.ConvertSSA(z) assert(tvm.ir_pass.VerifySSA(z_ssa)) diff --git a/tests/python/unittest/test_pass_bound_checkers.py b/tests/python/unittest/test_pass_bound_checkers.py index e62e539178eb..6b959e0d8da7 100644 --- 
a/tests/python/unittest/test_pass_bound_checkers.py +++ b/tests/python/unittest/test_pass_bound_checkers.py @@ -166,12 +166,12 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b): def test_in_bounds_const_loop_partition_ir(): def check_attr_stmt (x): - if isinstance(x, tvm.stmt.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n): + if isinstance(x, tvm.tir.AttrStmt) and x.attr_key == "buffer_bound" and str(x.value) == str(n): return True return False def check_branch_stmt (x): - if isinstance(x, tvm.stmt.IfThenElse): + if isinstance(x, tvm.tir.IfThenElse): return True return False @@ -183,7 +183,7 @@ def assert_bound_instrumentation(stmt, f, nums): assert (count == nums) def collect_branch_stmt (x): - if isinstance(x, tvm.stmt.IfThenElse): + if isinstance(x, tvm.tir.IfThenElse): branch_collector.append(x) n = 21 diff --git a/tests/python/unittest/test_pass_combine_context_call.py b/tests/python/unittest/test_pass_combine_context_call.py index a25568f719d2..ef741a4bff7b 100644 --- a/tests/python/unittest/test_pass_combine_context_call.py +++ b/tests/python/unittest/test_pass_combine_context_call.py @@ -20,8 +20,8 @@ def test_for(): dev_type = tvm.var("dev_type") def device_context(dev_id): ctx = tvm.call_extern("handle", "device_context", dev_type, dev_id) - return tvm.make.Call( - "handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0) + return tvm.tir.Call( + "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic, None, 0) ib = tvm.ir_builder.create() n = tvm.var("n") diff --git a/tests/python/unittest/test_pass_decorate_device_scope.py b/tests/python/unittest/test_pass_decorate_device_scope.py index 9ffd56544ebc..b464354e008a 100644 --- a/tests/python/unittest/test_pass_decorate_device_scope.py +++ b/tests/python/unittest/test_pass_decorate_device_scope.py @@ -33,7 +33,7 @@ def test_decorate_device(): stmt = tvm.schedule.ScheduleOps(s, bounds) stmt1 = tvm.ir_pass.Simplify(stmt) stmt2 = tvm.ir_pass.DecorateDeviceScope(stmt1) - assert isinstance(stmt2, tvm.stmt.AttrStmt) + assert isinstance(stmt2, tvm.tir.AttrStmt) assert stmt2.attr_key == "device_scope" assert stmt1 == stmt2.body diff --git a/tests/python/unittest/test_pass_hoist_if.py b/tests/python/unittest/test_pass_hoist_if.py index 4a28cf6b318a..2eb641b0cd90 100644 --- a/tests/python/unittest/test_pass_hoist_if.py +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -24,19 +24,19 @@ def verify_structure(stmt, expected_struct): struct = {} def _extract_vars(op): global var_list - if isinstance(op, tvm.expr.Var): + if isinstance(op, tvm.tir.Var): var_list.append(op.name) def _visit(op): key = op - if isinstance(op, tvm.stmt.IfThenElse): + if isinstance(op, tvm.tir.IfThenElse): global var_list tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars) val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] var_list.clear() - elif isinstance(op, tvm.stmt.For): + elif isinstance(op, tvm.tir.For): val = [(op.body,), ("For", op.loop_var.name)] - elif isinstance(op, tvm.stmt.AttrStmt): + elif isinstance(op, tvm.tir.AttrStmt): val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))] else: return @@ -61,9 +61,9 @@ def test_basic(): with ib.for_range(0, m, "j") as j: with ib.for_range(0, n, "k") as k: with ib.if_scope(ib.likely(i < 2)): - ib.emit(tvm.make.Evaluate(m)) + ib.emit(tvm.tir.Evaluate(m)) with ib.else_scope(): - ib.emit(tvm.make.Evaluate(n)) + ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) @@ -82,7 +82,7 @@ def 
test_no_else(): with ib.for_range(0, m, "j") as j: with ib.for_range(0, n, "k") as k: with ib.if_scope(ib.likely(i < 2)): - ib.emit(tvm.make.Evaluate(m)) + ib.emit(tvm.tir.Evaluate(m)) stmt = ib.get() new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index 858b1e8a9153..f49388db3eb2 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -33,7 +33,7 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert dst.strides[1].value == 1 assert src.strides[0] == l assert tuple(src.shape) == (m, l) - return tvm.make.Evaluate(0) + return tvm.tir.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) def test_copy_pad(): @@ -57,7 +57,7 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert pad_after[0].value == 1 assert pad_after[1].value == 0 assert pad_value.value == 1.0 - return tvm.make.Evaluate(0) + return tvm.tir.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) def test_single_point_test(): @@ -76,7 +76,7 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0 assert tvm.ir_pass.Simplify(src.strides[0]).value == 1 assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1 - return tvm.make.Evaluate(0) + return tvm.tir.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) def assert_expr_equal(a, b): @@ -109,7 +109,7 @@ def cb(src, dst, pad_before, pad_after, pad_value): assert_expr_equal(pad_before[0], rpad_before) assert_expr_equal(pad_after[0], rpad_after) assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) - return tvm.make.Evaluate(0) + return tvm.tir.Evaluate(0) stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) diff --git a/tests/python/unittest/test_pass_inject_double_buffer.py b/tests/python/unittest/test_pass_inject_double_buffer.py index aa569cea8665..cf8f78c8090d 100644 --- a/tests/python/unittest/test_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_pass_inject_double_buffer.py @@ -37,13 +37,13 @@ def test_double_buffer(): stmt = ib.get() stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2) stmt = tvm.ir_pass.Simplify(stmt) - assert isinstance(stmt.body.body, tvm.stmt.Allocate) + assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.ir_pass.ThreadSync(f, "shared") count = [0] def count_sync(op): - if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync": + if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 tvm.ir_pass.PostOrderVisit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_pass_inline.py b/tests/python/unittest/test_pass_inline.py index e87353ed98a1..521a6f99e026 100644 --- a/tests/python/unittest/test_pass_inline.py +++ b/tests/python/unittest/test_pass_inline.py @@ -20,7 +20,7 @@ def test_inline(): m = tvm.size_var('m') A = tvm.placeholder((m,), name='A') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T') - stmt = tvm.make.Evaluate(T[10] + 11 * T[100]) + stmt = tvm.tir.Evaluate(T[10] + 11 * T[100]) stmt = tvm.ir_pass.Inline( stmt, T.op, [x.var for x in T.op.axis], T.op.body[0]) print(stmt) @@ -39,11 +39,11 @@ def test_inline2(): m = tvm.size_var('m') A = tvm.placeholder((m,), name='A') T = tvm.compute((m,), lambda i,: A[i] + 10, name='T') - stmt = 
tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100]) + stmt = tvm.tir.Evaluate(tvm.exp(T[10]) + 11 * T[100]) stmt = tvm.ir_pass.Inline( stmt, T.op, [x.var for x in T.op.axis], T.op.body[0]) def check(op): - if isinstance(op, tvm.expr.Call): + if isinstance(op, tvm.tir.Call): assert op.func != T.op tvm.ir_pass.PostOrderVisit(stmt, check) diff --git a/tests/python/unittest/test_pass_ir_transform.py b/tests/python/unittest/test_pass_ir_transform.py index 098e0d7700cb..b024a3c8d5b9 100644 --- a/tests/python/unittest/test_pass_ir_transform.py +++ b/tests/python/unittest/test_pass_ir_transform.py @@ -32,12 +32,12 @@ def preorder(op): return None def postorder(op): - assert isinstance(op, tvm.expr.Call) + assert isinstance(op, tvm.tir.Call) if op.name == "TestA": return tvm.call_extern("int32", "TestB", op.args[0] + 1) return op body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"]) - stmt_list = tvm.make.stmt_list(body.body.body) + stmt_list = tvm.tir.stmt_list(body.body.body) assert stmt_list[0].value.args[0].name == "TestB" assert stmt_list[1].value.value == 0 diff --git a/tests/python/unittest/test_pass_lift_attr_scope.py b/tests/python/unittest/test_pass_lift_attr_scope.py index b281e17bc633..181f4ef57a4f 100644 --- a/tests/python/unittest/test_pass_lift_attr_scope.py +++ b/tests/python/unittest/test_pass_lift_attr_scope.py @@ -20,7 +20,7 @@ def test_coproc_lift(): ib = tvm.ir_builder.create() n = tvm.var("n") cp = tvm.thread_axis((0, 1), "cop") - value = tvm.make.StringImm("xxx") + value = tvm.tir.StringImm("xxx") A = ib.allocate("float32", n, name="A", scope="global") with ib.for_range(0, n, name="i") as i: diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 9812660d2ad1..e9df98e43d79 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -24,7 +24,7 @@ def collect_visit(stmt, f): def find_top_produce(stmt): def f(x, ret): - if isinstance(x, tvm.stmt.ProducerConsumer): + if isinstance(x, tvm.tir.ProducerConsumer): ret.append(x) ret = [] tvm.ir_pass.PostOrderVisit(stmt, lambda x : f(x, ret)) @@ -90,13 +90,13 @@ def test_multi_loop(): with ib.for_range(0, n, "j") as j: with ib.for_range(0, m, "k") as k: with ib.if_scope(ib.likely(i*m+j+k < n)): - ib.emit(tvm.make.Evaluate(m)) + ib.emit(tvm.tir.Evaluate(m)) with ib.else_scope(): - ib.emit(tvm.make.Evaluate(n)) + ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) def test_multi_if(): ib = tvm.ir_builder.create() @@ -106,13 +106,13 @@ def test_multi_if(): with ib.for_range(0, n, 'j') as j: with ib.for_range(0, m, 'k') as k: with ib.if_scope(ib.likely(i*m+j+k < n)): - ib.emit(tvm.make.Evaluate(m)) + ib.emit(tvm.tir.Evaluate(m)) with ib.else_scope(): - ib.emit(tvm.make.Evaluate(n)) + ib.emit(tvm.tir.Evaluate(n)) with ib.if_scope(ib.likely(i*m+j-k < n)): - ib.emit(tvm.make.Evaluate(m)) + ib.emit(tvm.tir.Evaluate(m)) with ib.else_scope(): - ib.emit(tvm.make.Evaluate(n)) + ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) @@ -157,7 +157,7 @@ def test_vectorize(): stmt = lower(s, [A, B]) body = stmt.body.body.body.body.body assert(x.var.name not in str(body.condition)) - 
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp)))) + assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp)))) def test_condition(): ib = tvm.ir_builder.create() @@ -165,24 +165,24 @@ n = tvm.size_var('n') with ib.for_range(0, tvm.truncdiv(n+3,4), 'i') as i: with ib.for_range(0, 4, 'j') as j: - ib.emit(tvm.make.Evaluate( - tvm.make.Select(ib.likely(i*4+j<n), m, n))) + ib.emit(tvm.tir.Evaluate( + tvm.tir.Select(ib.likely(i*4+j<n), m, n))) diff --git a/tests/python/unittest/test_pass_lower_intrin.py b/tests/python/unittest/test_pass_lower_intrin.py --- a/tests/python/unittest/test_pass_lower_intrin.py +++ b/tests/python/unittest/test_pass_lower_intrin.py def test_lower_floordiv(): # rhs >= 0 - res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floordiv(x, y), zero)) + res = lower_intrin(tvm.tir.Select(y >= 0, tvm.floordiv(x, y), zero)) check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0) # involves max - res = lower_intrin(tvm.expr.Select(y >= 0, tvm.max(tvm.floordiv(x, y), zero), zero)) + res = lower_intrin(tvm.tir.Select(y >= 0, tvm.max(tvm.floordiv(x, y), zero), zero)) check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0) # lhs >= 0 - res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floordiv(x, y), zero)) + res = lower_intrin(tvm.tir.Select(tvm.all(y >= 0, x >= 0), tvm.floordiv(x, y), zero)) check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0) # const power of two res = lower_intrin(tvm.floordiv(x, tvm.const(8, dtype=dtype))) @@ -95,10 +95,10 @@ def test_lower_floormod(): res = lower_intrin(tvm.floormod(x, y)) check_value(res, x, y, data, lambda a, b: a % b) # rhs >= 0 - res = lower_intrin(tvm.expr.Select(y >= 0, tvm.floormod(x, y), zero)) + res = lower_intrin(tvm.tir.Select(y >= 0, tvm.floormod(x, y), zero)) check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0) # lhs >= 0 - res = lower_intrin(tvm.expr.Select(tvm.all(y >= 0, x >= 0), tvm.floormod(x, y), zero)) + res = lower_intrin(tvm.tir.Select(tvm.all(y >= 0, x >= 0), tvm.floormod(x, y), zero)) check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0) # const power of two res = lower_intrin(tvm.floormod(x, tvm.const(8, dtype=dtype))) diff --git a/tests/python/unittest/test_pass_remove_no_op.py b/tests/python/unittest/test_pass_remove_no_op.py index d287b8591fb3..a3927f7db49d 100644 --- a/tests/python/unittest/test_pass_remove_no_op.py +++ b/tests/python/unittest/test_pass_remove_no_op.py @@ -17,7 +17,7 @@ import tvm def nop(): - return tvm.stmt.Evaluate(0) + return tvm.tir.Evaluate(0) def test_remove_no_op(): i = tvm.var('i') @@ -27,25 +27,25 @@ def test_remove_no_op(): n = tvm.var('n') dtype = 'int64' Ab = tvm.decl_buffer((n, ), dtype) - stmt = tvm.make.For( + stmt = tvm.tir.For( i, 0, 4, 0, 0, - tvm.make.For( + tvm.tir.For( j, 0, n, 0, 0, - tvm.make.For( + tvm.tir.For( k, 0, m, 0, 0, - tvm.make.IfThenElse( - (i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))))) + tvm.tir.IfThenElse( + (i*m+j+k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n))))) ret = tvm.ir_pass.RemoveNoOp(stmt) - assert(isinstance(ret, tvm.stmt.Evaluate)) - store = tvm.make.Store(Ab.data, - tvm.make.Load(dtype, Ab.data, i) + 1, + assert(isinstance(ret, tvm.tir.Evaluate)) + store = tvm.tir.Store(Ab.data, + tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1) - stmt2 = tvm.stmt.SeqStmt([nop(), tvm.stmt.SeqStmt([store, nop()])]) + stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])]) assert(tvm.ir_pass.RemoveNoOp(stmt2) == store) # remove zero extent loop - stmt3 = tvm.make.For(i, 0, 0, 0, 0, store) + stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store) ret = tvm.ir_pass.RemoveNoOp(stmt3) - assert(isinstance(ret, tvm.stmt.Evaluate)) + assert(isinstance(ret, tvm.tir.Evaluate)) if __name__ == "__main__": diff --git 
a/tests/python/unittest/test_pass_rewrite_unsafe_select.py b/tests/python/unittest/test_pass_rewrite_unsafe_select.py index 4c42899be62a..dc6ae8286213 100644 --- a/tests/python/unittest/test_pass_rewrite_unsafe_select.py +++ b/tests/python/unittest/test_pass_rewrite_unsafe_select.py @@ -21,18 +21,18 @@ def test_rewrite_Select(): ib = tvm.ir_builder.create() A = ib.allocate("float32", 100, name="A", scope="global") i = tvm.var("i") - y = tvm.expr.Select(i > 1, A[i-1], 1.0) - yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(y)).value + y = tvm.tir.Select(i > 1, A[i-1], 1.0) + yy = tvm.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value - z = tvm.expr.Select( - tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1) - zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value + z = tvm.tir.Select( + tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1) + zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value - a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z) - aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value + a = tvm.tir.Select(tvm.floordiv(i, 4) > 10, y, z) + aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value assert yy.name == "tvm_if_then_else" assert zz.name == "tvm_if_then_else" - assert isinstance(aa, tvm.expr.Select) + assert isinstance(aa, tvm.tir.Select) if __name__ == "__main__": diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py index 2bee66c0a42e..47a43c7ac2a0 100644 --- a/tests/python/unittest/test_pass_storage_flatten.py +++ b/tests/python/unittest/test_pass_storage_flatten.py @@ -40,12 +40,12 @@ def test_flatten_prefetch(): _A= tvm.decl_buffer(A.shape, A.dtype, name = 'A'); i = tvm.size_var('i') j = tvm.size_var('j') - region = [tvm.make.range_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]] - stmt = tvm.make.Prefetch(A.op, 0, A.dtype, region) + region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]] + stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: _A}, 64) stmt = tvm.ir_pass.Simplify(stmt) assert stmt.extent.value == 2 - assert isinstance(stmt.body, tvm.stmt.For) + assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -89,13 +89,13 @@ def test_flatten_double_buffer(): stmt = tvm.ir_pass.StorageFlatten(stmt, {}, 64) stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2) stmt = tvm.ir_pass.Simplify(stmt) - assert isinstance(stmt.body.body, tvm.stmt.Allocate) + assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.ir_pass.ThreadSync(f, "shared") count = [0] def count_sync(op): - if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync": + if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 tvm.ir_pass.PostOrderVisit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 6fd6f8b1ce52..d4125d093198 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -39,7 +39,7 @@ def test_storage_share(): # verify inplace folding works num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 tvm.ir_pass.PostOrderVisit(stmt, verify) assert num_alloc[0] == 1 @@ -48,7 
+48,7 @@ def register_mem(scope_tb, max_bits): #Register mem @tvm.register_func("tvm.info.mem.%s" % scope_tb) def mem_info_inp_buffer(): - return tvm.make.node("MemoryInfo", + return tvm.ir.make_node("MemoryInfo", unit_bits= 16, max_simd_bits=32, max_num_bits=max_bits, @@ -74,7 +74,7 @@ def test_alloc_seq(): body = tvm.ir_pass.StorageRewrite(body) num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 200 tvm.ir_pass.PostOrderVisit(body, verify) @@ -123,7 +123,7 @@ def offset_generater(dtype_list, length): def dtype_test(dtype_list, length): def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == offset body = stmt_generater(dtype_list, length) @@ -166,7 +166,7 @@ def test_inplace_rule(): # verify inplace folding works num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 tvm.ir_pass.PostOrderVisit(stmt, verify) assert num_alloc[0] == 2 @@ -196,7 +196,7 @@ def test_storage_combine(): stmt = tvm.ir_pass.StorageRewrite(stmt) num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert (n.extents[0].value == 16) tvm.ir_pass.PostOrderVisit(stmt, verify) @@ -231,7 +231,7 @@ def test_storage_share_gpu(): alloc_stats = {"global": 0, "shared": 0} def verify(n): - if isinstance(n, tvm.stmt.AttrStmt): + if isinstance(n, tvm.tir.AttrStmt): if n.attr_key == "storage_scope": alloc_stats[n.value.value] += 1 tvm.ir_pass.PostOrderVisit(stmt, verify) @@ -248,14 +248,14 @@ def test_parallel_alloc(): body = ib.get() body = tvm.ir_pass.StorageRewrite(body) - assert (isinstance(body.body.body, tvm.stmt.Allocate)) + assert (isinstance(body.body.body, tvm.tir.Allocate)) ib = tvm.ir_builder.create() n = tvm.var("n") with ib.for_range(0, n, name="t") as i: ib.scope_attr( tvm.const(1, "int32") , "pragma_scope", - tvm.make.StringImm("parallel_launch_point")) + tvm.tir.StringImm("parallel_launch_point")) with ib.for_range(0, n, name="i", for_type="parallel") as i: with ib.for_range(0, 10, name="j") as j: A = ib.allocate("float32", n, name="A", scope="global") @@ -263,7 +263,7 @@ def test_parallel_alloc(): body = ib.get() body = tvm.ir_pass.StorageRewrite(body) - assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate)) + assert(isinstance(body.body.body.body.body, tvm.tir.Allocate)) def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): #Test Buffer @@ -295,7 +295,7 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): # verify inplace folding works num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 tvm.ir_pass.PostOrderVisit(stmt, verify) assert num_alloc[0] == 2 @@ -387,7 +387,7 @@ def test_inplace_rule3(): # verify only have one allocations. 
# verify inplace folding works def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 70 tvm.ir_pass.PostOrderVisit(stmt, verify) @@ -413,7 +413,7 @@ def test_alloc_seq_type(): body = tvm.ir_pass.StorageRewrite(body) num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 500 tvm.ir_pass.PostOrderVisit(body, verify) @@ -442,7 +442,7 @@ def test_alloc_seq_type2(): body = tvm.ir_pass.StorageRewrite(body) num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 200 tvm.ir_pass.PostOrderVisit(body, verify) @@ -473,7 +473,7 @@ def test_reuse_small_buffer(): num_alloc = [0] def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 800 tvm.ir_pass.PostOrderVisit(body, verify) @@ -512,7 +512,7 @@ def compute(a, b): s = tvm.create_schedule(c.op) stmt = tvm.lower(s, [a, b, c], simple_mode=True) def verify(n): - if isinstance(n, tvm.stmt.Allocate): + if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 268435456 tvm.ir_pass.PostOrderVisit(stmt, verify) diff --git a/tests/python/unittest/test_pass_storage_sync.py b/tests/python/unittest/test_pass_storage_sync.py index 55596eea4579..0ed0c993ac55 100644 --- a/tests/python/unittest/test_pass_storage_sync.py +++ b/tests/python/unittest/test_pass_storage_sync.py @@ -40,14 +40,14 @@ def test_storage_sync(): flist = tvm.ir_pass.SplitHostDevice(f) f = flist[1] f = tvm.ir_pass.ThreadSync(f, "shared") - body_list = tvm.make.stmt_list(f.body.body.body.body) + body_list = tvm.tir.stmt_list(f.body.body.body.body) assert(body_list[1].value.name == "tvm_storage_sync") def test_coproc_sync(): @tvm.register_func("tvm.info.mem.global.cache") def meminfo_cache(): - return tvm.make.node( + return tvm.ir.make_node( "MemoryInfo", unit_bits=8, max_simd_bits=32, @@ -66,7 +66,7 @@ def meminfo_cache(): stmt = ib.get() stmt = tvm.ir_pass.CoProcSync(stmt) body = stmt.body.body.body - blist = tvm.make.stmt_list(body) + blist = tvm.tir.stmt_list(body) assert(blist[1].value.name == "cop.coproc_read_barrier") assert(blist[1].value.args[3].value == 80) assert(blist[-2].value.name == "cop.coproc_sync") @@ -119,9 +119,9 @@ def __check_list(tvm_array, py_list): stmt = ib.get() stmt = tvm.ir_pass.CoProcSync(stmt) - slist = tvm.make.stmt_list(stmt[0].body.body) + slist = tvm.tir.stmt_list(stmt[0].body.body) push_st = slist[2] - slist = tvm.make.stmt_list(slist[-1]) + slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] assert(push_st.value.name == "cop.coproc_dep_push") diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py index e5ef9d0aa2f4..c6b536bf970e 100644 --- a/tests/python/unittest/test_pass_unroll.py +++ b/tests/python/unittest/test_pass_unroll.py @@ -30,26 +30,26 @@ def test_unroll_loop(): Aptr[j + 1] = Aptr[i] + 1 stmt = ib.get() - assert isinstance(stmt, tvm.stmt.For) + assert isinstance(stmt, tvm.tir.For) ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, True) - assert not isinstance(ret, tvm.stmt.For) + assert not isinstance(ret, tvm.tir.For) ret = tvm.ir_pass.UnrollLoop(stmt, 15, 8, 0, True) - assert isinstance(ret, tvm.stmt.For) + assert isinstance(ret, tvm.tir.For) ret = tvm.ir_pass.UnrollLoop(stmt, 16, 8, 0, False) - assert isinstance(ret, tvm.stmt.For) - assert 
ret.for_type == tvm.stmt.For.Unrolled + assert isinstance(ret, tvm.tir.For) + assert ret.for_type == tvm.tir.For.Unrolled ib = tvm.ir_builder.create() ib.scope_attr(tvm.const(0, "int32"), "pragma_auto_unroll_max_step", 16) ib.emit(stmt) wrapped = ib.get() - wrapped = tvm.stmt.SeqStmt([wrapped, stmt]) - assert isinstance(ret, tvm.stmt.For) + wrapped = tvm.tir.SeqStmt([wrapped, stmt]) + assert isinstance(ret, tvm.tir.For) ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False) - assert isinstance(ret[0], tvm.stmt.For) - assert ret[0].for_type == tvm.stmt.For.Unrolled - assert isinstance(ret[1], tvm.stmt.For) - assert ret[1].for_type != tvm.stmt.For.Unrolled + assert isinstance(ret[0], tvm.tir.For) + assert ret[0].for_type == tvm.tir.For.Unrolled + assert isinstance(ret[1], tvm.tir.For) + assert ret[1].for_type != tvm.tir.For.Unrolled def test_unroll_fake_loop(): ib = tvm.ir_builder.create() @@ -65,7 +65,7 @@ def test_unroll_fake_loop(): stmt = ib.get() ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) - assert isinstance(ret[0], tvm.stmt.Store) + assert isinstance(ret[0], tvm.tir.Store) def test_unroll_single_count_loops(): n = tvm.size_var('n') diff --git a/tests/python/unittest/test_pass_vectorize.py b/tests/python/unittest/test_pass_vectorize.py index fca22a1eca30..d1cd2d46074a 100644 --- a/tests/python/unittest/test_pass_vectorize.py +++ b/tests/python/unittest/test_pass_vectorize.py @@ -26,12 +26,12 @@ def test_vectorize_loop(): A[j] = tvm.const(1, A.dtype) stmt = ib.get() - assert isinstance(stmt.body, tvm.stmt.For) + assert isinstance(stmt.body, tvm.tir.For) stmt = tvm.ir_pass.VectorizeLoop(stmt) - assert isinstance(stmt, tvm.stmt.For) - assert not isinstance(stmt.body, tvm.stmt.For) - assert isinstance(stmt.body.index, tvm.expr.Ramp) - assert isinstance(stmt.body.value, tvm.expr.Broadcast) + assert isinstance(stmt, tvm.tir.For) + assert not isinstance(stmt.body, tvm.tir.For) + assert isinstance(stmt.body.index, tvm.tir.Ramp) + assert isinstance(stmt.body.value, tvm.tir.Broadcast) def test_vectorize_vector(): dtype = 'int64' @@ -42,12 +42,12 @@ def test_vectorize_vector(): with ib.for_range(0, 4, for_type="vectorize") as j: A[j] = tvm.const(1, A.dtype) stmt = ib.get() - assert isinstance(stmt.body, tvm.stmt.For) + assert isinstance(stmt.body, tvm.tir.For) stmt = tvm.ir_pass.VectorizeLoop(stmt) - assert isinstance(stmt, tvm.stmt.For) - assert not isinstance(stmt.body, tvm.stmt.For) - assert isinstance(stmt.body.index, tvm.expr.Ramp) - assert isinstance(stmt.body.value, tvm.expr.Broadcast) + assert isinstance(stmt, tvm.tir.For) + assert not isinstance(stmt.body, tvm.tir.For) + assert isinstance(stmt.body.index, tvm.tir.Ramp) + assert isinstance(stmt.body.value, tvm.tir.Broadcast) def test_vectorize_with_if(): @@ -63,11 +63,11 @@ def test_vectorize_with_if(): A[i] = 2.0 stmt = ib.get() stmt = tvm.ir_pass.VectorizeLoop(stmt) - assert isinstance(stmt, tvm.stmt.IfThenElse) - assert isinstance(stmt.then_case.index, tvm.expr.Ramp) - assert isinstance(stmt.then_case.value, tvm.expr.Add) + assert isinstance(stmt, tvm.tir.IfThenElse) + assert isinstance(stmt.then_case.index, tvm.tir.Ramp) + assert isinstance(stmt.then_case.value, tvm.tir.Add) assert stmt.then_case.value.dtype == "float32x4" - assert isinstance(stmt.else_case, tvm.stmt.For) + assert isinstance(stmt.else_case, tvm.tir.For) def test_vectorize_with_le_cond(): n = tvm.var('n') @@ -78,7 +78,7 @@ def test_vectorize_with_le_cond(): A[i] = A[i] + 1 stmt = ib.get() stmt = tvm.ir_pass.VectorizeLoop(stmt) - assert isinstance(stmt, 
tvm.stmt.For) + assert isinstance(stmt, tvm.tir.For) def test_vectorize_with_ge_cond(): n = tvm.var('n') @@ -89,7 +89,7 @@ def test_vectorize_with_ge_cond(): A[i] = A[i] + 1 stmt = ib.get() stmt = tvm.ir_pass.VectorizeLoop(stmt) - assert isinstance(stmt, tvm.stmt.For) + assert isinstance(stmt, tvm.tir.For) def test_vectorize_if_then_else(): n = tvm.var('n') @@ -102,7 +102,7 @@ def test_vectorize_if_then_else(): A[i] + 1, A[i]) stmt = ib.get() stmt = tvm.ir_pass.VectorizeLoop(stmt) - assert isinstance(stmt, tvm.stmt.For) + assert isinstance(stmt, tvm.tir.For) ib = tvm.ir_builder.create() @@ -113,10 +113,10 @@ def test_vectorize_if_then_else(): k > 0, A[k * 4 + i], 0) stmt = ib.get() - assert isinstance(stmt.body, tvm.stmt.For) + assert isinstance(stmt.body, tvm.tir.For) stmt = tvm.ir_pass.VectorizeLoop(stmt) - assert not isinstance(stmt.body, tvm.stmt.For) - assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast) + assert not isinstance(stmt.body, tvm.tir.For) + assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast) if __name__ == "__main__": diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index b1a784bc48b6..1cbc157a154c 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -50,10 +50,10 @@ def save_object(names): Ab = tvm.decl_buffer((n, ), dtype) i = tvm.var('i') # for i in 0 to n-1: - stmt = tvm.make.For( + stmt = tvm.tir.For( i, 0, n - 1, 0, 0, - tvm.make.Store(Ab.data, - tvm.make.Load(dtype, Ab.data, i) + 1, + tvm.tir.Store(Ab.data, + tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1)) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 1030a9995aec..2fc84bb43b16 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -79,8 +79,8 @@ def test_schedule_scan(): def test_inline_multi_reduce(): def argmax_comp(x, y): - idx = tvm.expr.Select((x[1] >= y[1]), x[0], y[0]) - val = tvm.expr.Select((x[1] >= y[1]), x[1], y[1]) + idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) + val = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) return idx, val def argmax_init(idx_typ, val_typ): return tvm.const(-1, idx_typ), tvm.min_value(val_typ) @@ -142,7 +142,7 @@ def test_inline_mixed(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) def check(x): - if isinstance(x, tvm.expr.Call): + if isinstance(x, tvm.tir.Call): assert x.func != A2 tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check) @@ -426,7 +426,7 @@ def test_loop_dep_reduce_cache_write(): X = tvm.placeholder(shape=(10,), name="x") def f(n): rv = tvm.reduce_axis((0, n)) - init = lambda dtype: tvm.expr.Select(n > 1, tvm.const(0, dtype), n.astype(dtype)) + init = lambda dtype: tvm.tir.Select(n > 1, tvm.const(0, dtype), n.astype(dtype)) sum = tvm.comm_reducer(lambda x, y: tvm.max(x + y, n.astype('float32')), init, name='sum') return sum(X[rv], axis=rv) Y = tvm.compute(X.shape, f, name="y") diff --git a/tests/python/unittest/test_schedule_tensorize.py b/tests/python/unittest/test_schedule_tensorize.py index 59adf0cc7e99..ac60c2d34ebd 100644 --- a/tests/python/unittest/test_schedule_tensorize.py +++ b/tests/python/unittest/test_schedule_tensorize.py @@ -321,7 +321,7 @@ def intrin_func(ins, outs): stmt = tvm.schedule.ScheduleOps(s, dom_map) # The loop that 
we tried to tensorize still exists in the code # That means tensorize didn't work as expected - assert isinstance(stmt.body.body.body, tvm.stmt.For) + assert isinstance(stmt.body.body.body, tvm.tir.For) assert stmt.body.body.body.loop_var.name == C.op.axis[0].var.name diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 84c9da37ff7a..c3e1a102471e 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -429,7 +429,9 @@ def cast(x, dtype): if isinstance(x, tvm.tensor.Tensor): return tvm.compute( x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE) - return tvm.make._cast(dtype, x) + # pylint: disable=import-outside-toplevel + from tvm.tir import _ffi_api + return _ffi_api._cast(dtype, x) def reinterpret(x, dtype): diff --git a/topi/python/topi/vision/rcnn/roi_pool.py b/topi/python/topi/vision/rcnn/roi_pool.py index 702f551e35eb..53ffe35e7e1b 100644 --- a/topi/python/topi/vision/rcnn/roi_pool.py +++ b/topi/python/topi/vision/rcnn/roi_pool.py @@ -85,7 +85,7 @@ def _pool(i, c, ph, pw): min_value = lambda dtype: tvm.if_then_else(non_empty, tvm.min_value(dtype), tvm.const(0.0, dtype)) # pylint: disable=unnecessary-lambda - _max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max') + _max = tvm.comm_reducer(lambda x, y: tvm.max(x, y), min_value, name='max') rh = tvm.reduce_axis((0, hend - hstart), 'rh') rw = tvm.reduce_axis((0, wend - wstart), 'rw') return _max(data[batch_index, c, hstart+rh, wstart+rw], axis=[rh, rw]) diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index eb672bbe151a..97c4a1f49a9a 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -86,8 +86,8 @@ loops = [] def find_width8(op): """ Find all the 'For' nodes whose extent can be divided by 8. """ - if isinstance(op, tvm.stmt.For): - if isinstance(op.extent, tvm.expr.IntImm): + if isinstance(op, tvm.tir.For): + if isinstance(op.extent, tvm.tir.IntImm): if op.extent.value % 8 == 0: loops.append(op) @@ -113,8 +113,8 @@ def vectorize8(op): name = op.loop_var.name lo, li = tvm.var(name + '.outer'), tvm.var(name + '.inner') body = tvm.ir_pass.Substitute(op.body, {op.loop_var: lo * 8 + li}) - body = tvm.make.For(li, 0, 8, tvm.stmt.For.Vectorized, 0, body) - body = tvm.make.For(lo, 0, extent // 8, tvm.stmt.For.Serial, 0, body) + body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body) + body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body) return body return None diff --git a/tutorials/language/intrin_math.py b/tutorials/language/intrin_math.py index f338e7004761..c1af984a09a1 100644 --- a/tutorials/language/intrin_math.py +++ b/tutorials/language/intrin_math.py @@ -96,7 +96,7 @@ # def my_cuda_math_rule(op): """Customized CUDA intrinsic lowering rule""" - assert isinstance(op, tvm.expr.Call) + assert isinstance(op, tvm.tir.Call) if op.dtype == "float32": # call float function return tvm.call_pure_extern("float32", "%sf" % op.name, op.args[0]) @@ -106,7 +106,7 @@ def my_cuda_math_rule(op): else: # cannot do translation, return self. return op -tvm.register_intrin_rule("cuda", "exp", my_cuda_math_rule, override=True) +tvm.target.register_intrin_rule("cuda", "exp", my_cuda_math_rule, override=True) ###################################################################### # Register the rule to TVM with override option to override existing rule. 
# Notice the difference between the printed code from previous one: @@ -135,7 +135,7 @@ def my_cuda_mylog_rule(op): return tvm.call_pure_extern("float64", "log", op.args[0]) else: return op -tvm.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True) +tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True) n = tvm.var("n") A = tvm.placeholder((n,), name='A') diff --git a/tutorials/language/tuple_inputs.py b/tutorials/language/tuple_inputs.py index 8fb8083c3480..0c5c85ca585a 100644 --- a/tutorials/language/tuple_inputs.py +++ b/tutorials/language/tuple_inputs.py @@ -61,8 +61,8 @@ # x and y are the operands of reduction, both of them is a tuple of index # and value. def fcombine(x, y): - lhs = tvm.expr.Select((x[1] >= y[1]), x[0], y[0]) - rhs = tvm.expr.Select((x[1] >= y[1]), x[1], y[1]) + lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) return lhs, rhs # our identity element also need to be a tuple, so `fidentity` accepts diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index df67faaac2bf..f3683626384f 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -68,7 +68,7 @@ def add_debug(stmt): env.dev.command_handle, debug_flag) - return tvm.make.stmt_seq(debug, stmt) + return tvm.tir.stmt_seq(debug, stmt) pass_list = [(0, ir_pass.inject_conv2d_transpose_skip), (1, ir_pass.inject_dma_intrin), (1, ir_pass.inject_skip_copy), diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index 83db6121ed55..8d58958410e5 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -62,11 +62,11 @@ class DevContext(object): def __init__(self, env): self.vta_axis = tvm.thread_axis("vta") - self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp") + self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp") ctx = tvm.call_extern("handle", "VTATLSCommandHandle") - self.command_handle = tvm.make.Call( + self.command_handle = tvm.tir.Call( "handle", "tvm_thread_context", [ctx], - tvm.expr.Call.Intrinsic, None, 0) + tvm.tir.Call.Intrinsic, None, 0) self.DEBUG_NO_SYNC = False env._dev_ctx = self self.gemm = intrin.gemm(env, env.mock_mode) @@ -256,29 +256,29 @@ def get_env(): @tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope) def mem_info_inp_buffer(): spec = get_env() - return tvm.make.node("MemoryInfo", - unit_bits=spec.INP_ELEM_BITS, - max_simd_bits=spec.INP_ELEM_BITS, - max_num_bits=spec.INP_BUFF_SIZE * 8, - head_address=None) + return tvm.ir.make_node("MemoryInfo", + unit_bits=spec.INP_ELEM_BITS, + max_simd_bits=spec.INP_ELEM_BITS, + max_num_bits=spec.INP_BUFF_SIZE * 8, + head_address=None) @tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope) def mem_info_wgt_buffer(): spec = get_env() - return tvm.make.node("MemoryInfo", - unit_bits=spec.WGT_ELEM_BITS, - max_simd_bits=spec.WGT_ELEM_BITS, - max_num_bits=spec.WGT_BUFF_SIZE * 8, - head_address=None) + return tvm.ir.make_node("MemoryInfo", + unit_bits=spec.WGT_ELEM_BITS, + max_simd_bits=spec.WGT_ELEM_BITS, + max_num_bits=spec.WGT_BUFF_SIZE * 8, + head_address=None) @tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope) def mem_info_acc_buffer(): spec = get_env() - return tvm.make.node("MemoryInfo", - unit_bits=spec.ACC_ELEM_BITS, - max_simd_bits=spec.ACC_ELEM_BITS, - max_num_bits=spec.ACC_BUFF_SIZE * 8, - head_address=None) + return tvm.ir.make_node("MemoryInfo", + unit_bits=spec.ACC_ELEM_BITS, + max_simd_bits=spec.ACC_ELEM_BITS, + 
max_num_bits=spec.ACC_BUFF_SIZE * 8, + head_address=None) # TVM related registration @tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync") diff --git a/vta/python/vta/intrin.py b/vta/python/vta/intrin.py index 77c7ff2e8315..a43fc75a92d0 100644 --- a/vta/python/vta/intrin.py +++ b/vta/python/vta/intrin.py @@ -98,7 +98,7 @@ def instr(index): 0, 0, 0)) return irb.get() # return a triple of normal-set, reset, update - nop = tvm.make.Evaluate(0) + nop = tvm.tir.Evaluate(0) if mock: return (nop, nop, nop) return (instr(0), instr(1), instr(2)) diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index e42e3a0751dd..8b8a2f06b498 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -59,8 +59,8 @@ def fold_uop_loop(stmt_in): def _fold_outermost_loop(body): stmt = body - while not isinstance(stmt, tvm.stmt.For): - if isinstance(stmt, (tvm.stmt.ProducerConsumer,)): + while not isinstance(stmt, tvm.tir.For): + if isinstance(stmt, (tvm.tir.ProducerConsumer,)): stmt = stmt.body else: return None, body, None @@ -70,7 +70,7 @@ def _fold_outermost_loop(body): fail = [False] def _post_order(op): - assert isinstance(op, tvm.expr.Call) + assert isinstance(op, tvm.tir.Call) base_args = 2 if op.name == "VTAUopPush": args = [] @@ -112,7 +112,7 @@ def _visit(op): def _do_fold(stmt): if (stmt.attr_key == "coproc_uop_scope" and - isinstance(stmt.value, tvm.expr.StringImm) and + isinstance(stmt.value, tvm.tir.StringImm) and stmt.value.value == env.dev.vta_push_uop.value): body = stmt.body begins = [] @@ -133,8 +133,8 @@ def _do_fold(stmt): if body == stmt.body: return stmt ends = list(reversed(ends)) - body = tvm.stmt.stmt_seq(*(begins + [body] + ends)) - return tvm.make.AttrStmt( + body = tvm.tir.stmt_seq(*(begins + [body] + ends)) + return tvm.tir.AttrStmt( stmt.node, stmt.attr_key, stmt.value, body) return None out = tvm.ir_pass.IRTransform( @@ -163,40 +163,40 @@ def cpu_access_rewrite(stmt_in): env = get_env() rw_info = {} def _post_order(op): - if isinstance(op, tvm.stmt.Allocate): + if isinstance(op, tvm.tir.Allocate): buffer_var = op.buffer_var if not buffer_var in rw_info: return None new_var = rw_info[buffer_var] - let_stmt = tvm.make.LetStmt( + let_stmt = tvm.tir.LetStmt( new_var, tvm.call_extern( "handle", "VTABufferCPUPtr", env.dev.command_handle, buffer_var), op.body) - alloc = tvm.make.Allocate( + alloc = tvm.tir.Allocate( buffer_var, op.dtype, op.extents, op.condition, let_stmt) del rw_info[buffer_var] return alloc - if isinstance(op, tvm.expr.Load): + if isinstance(op, tvm.tir.Load): buffer_var = op.buffer_var if not buffer_var in rw_info: rw_info[buffer_var] = tvm.var( buffer_var.name + "_ptr", "handle") new_var = rw_info[buffer_var] - return tvm.make.Load(op.dtype, new_var, op.index) - if isinstance(op, tvm.stmt.Store): + return tvm.tir.Load(op.dtype, new_var, op.index) + if isinstance(op, tvm.tir.Store): buffer_var = op.buffer_var if not buffer_var in rw_info: rw_info[buffer_var] = tvm.var( buffer_var.name + "_ptr", "handle") new_var = rw_info[buffer_var] - return tvm.make.Store(new_var, op.value, op.index) + return tvm.tir.Store(new_var, op.value, op.index) raise RuntimeError("not reached") stmt = tvm.ir_pass.IRTransform( stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) for buffer_var, new_var in rw_info.items(): - stmt = tvm.make.LetStmt( + stmt = tvm.tir.LetStmt( new_var, tvm.call_extern( "handle", "VTABufferCPUPtr", env.dev.command_handle, @@ -222,15 +222,15 @@ def _merge_block(slist, body): for op in slist: if op.body == body: body = op - 
elif isinstance(op, tvm.stmt.Allocate): - body = tvm.make.Allocate( + elif isinstance(op, tvm.tir.Allocate): + body = tvm.tir.Allocate( op.buffer_var, op.dtype, op.extents, op.condition, body) - elif isinstance(op, tvm.stmt.AttrStmt): - body = tvm.make.AttrStmt( + elif isinstance(op, tvm.tir.AttrStmt): + body = tvm.tir.AttrStmt( op.node, op.attr_key, op.value, body) - elif isinstance(op, tvm.stmt.For): - body = tvm.make.For( + elif isinstance(op, tvm.tir.For): + body = tvm.tir.For( op.loop_var, op.min, op.extent, op.for_type, op.device_api, body) else: @@ -239,24 +239,24 @@ def _merge_block(slist, body): return body def _pre_order(op): - if isinstance(op, tvm.stmt.For): + if isinstance(op, tvm.tir.For): lift_stmt.append([]) - elif isinstance(op, tvm.stmt.AttrStmt): + elif isinstance(op, tvm.tir.AttrStmt): if op.attr_key == "virtual_thread": lift_stmt.append([]) def _post_order(op): - if isinstance(op, tvm.stmt.Allocate): + if isinstance(op, tvm.tir.Allocate): lift_stmt[-1].append(op) return op.body - if isinstance(op, tvm.stmt.AttrStmt): + if isinstance(op, tvm.tir.AttrStmt): if op.attr_key == "storage_scope": lift_stmt[-1].append(op) return op.body if op.attr_key == "virtual_thread": return _merge_block(lift_stmt.pop() + [op], op.body) return op - if isinstance(op, tvm.stmt.For): + if isinstance(op, tvm.tir.For): return _merge_block(lift_stmt.pop() + [op], op.body) raise RuntimeError("not reached") stmt = tvm.ir_pass.IRTransform( @@ -280,7 +280,7 @@ def inject_skip_copy(stmt_in): """ def _do_fold(stmt): if _match_pragma(stmt, "skip_dma_copy"): - return tvm.make.Evaluate(0) + return tvm.tir.Evaluate(0) return None return tvm.ir_pass.IRTransform( stmt_in, _do_fold, None, ["AttrStmt"]) @@ -303,13 +303,13 @@ def inject_coproc_sync(stmt_in): def _do_fold(stmt): if _match_pragma(stmt, "coproc_sync"): success[0] = True - sync = tvm.make.Call( - "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0) - return tvm.stmt.SeqStmt([stmt.body, tvm.make.Evaluate(sync)]) + sync = tvm.tir.Call( + "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0) + return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) if _match_pragma(stmt, "trim_loop"): op = stmt.body - assert isinstance(op, tvm.stmt.For) - return tvm.make.For( + assert isinstance(op, tvm.tir.For) + return tvm.tir.For( op.loop_var, op.min, 2, op.for_type, op.device_api, op.body) return None @@ -640,9 +640,9 @@ def inject_conv2d_transpose_skip(stmt_in): selects = [] def _find_basics(op): - if isinstance(op, tvm.expr.Call): + if isinstance(op, tvm.tir.Call): calls.append(op) - elif isinstance(op, tvm.expr.Select): + elif isinstance(op, tvm.tir.Select): selects.append(op) def _do_fold(op): @@ -665,7 +665,7 @@ def _do_fold(op): args = op.body.body.args res_tensor = op.body.body.func.output(0) tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) - inner = tvm.make.AttrStmt( + inner = tvm.tir.AttrStmt( [dout, res_tensor], 'buffer_bind_scope', tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) return inner @@ -697,19 +697,19 @@ def _do_fold(op): args = conv_call.args tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) - inner = tvm.make.AttrStmt( + inner = tvm.tir.AttrStmt( [dout, res_tensor], 'buffer_bind_scope', tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) args = kernel_call.args tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) - inner = tvm.make.AttrStmt( + inner = tvm.tir.AttrStmt( [dwgt, kernel_tensor], 'buffer_bind_scope', 
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) args = data_call.args tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_IN) - inner = tvm.make.AttrStmt( + inner = tvm.tir.AttrStmt( [dinp, pad_data_tensor], 'buffer_bind_scope', tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner) return inner @@ -739,11 +739,11 @@ def _do_fold(stmt): irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(env.dev.QID_COMPUTE)) irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope", - tvm.make.StringImm("VTAPushALUOp")) + tvm.tir.StringImm("VTAPushALUOp")) irb.emit(stmt) return irb.get() if _match_pragma(stmt, "skip_alu"): - return tvm.make.Evaluate(0) + return tvm.tir.Evaluate(0) return stmt stmt_out = tvm.ir_pass.IRTransform( @@ -810,7 +810,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): # Get to the innermost loop body loop_body = stmt.body nest_size = 0 - while isinstance(loop_body, tvm.stmt.For): + while isinstance(loop_body, tvm.tir.For): loop_body = loop_body.body nest_size += 1 # Get the src/dst arguments @@ -825,27 +825,27 @@ def _flatten_loop(src_coeff, dst_coeff, extents): extents.append(tmp_body.extent) tmp_body = tmp_body.body # Derive opcode - if isinstance(loop_body.value, tvm.expr.Add): + if isinstance(loop_body.value, tvm.tir.Add): alu_opcode = env.dev.ALU_OPCODE_ADD lhs = loop_body.value.a rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.expr.Sub): + elif isinstance(loop_body.value, tvm.tir.Sub): alu_opcode = env.dev.ALU_OPCODE_SUB lhs = loop_body.value.a rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.expr.Mul): + elif isinstance(loop_body.value, tvm.tir.Mul): alu_opcode = env.dev.ALU_OPCODE_MUL lhs = loop_body.value.a rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.expr.Min): + elif isinstance(loop_body.value, tvm.tir.Min): alu_opcode = env.dev.ALU_OPCODE_MIN lhs = loop_body.value.a rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.expr.Max): + elif isinstance(loop_body.value, tvm.tir.Max): alu_opcode = env.dev.ALU_OPCODE_MAX lhs = loop_body.value.a rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.expr.Call): + elif isinstance(loop_body.value, tvm.tir.Call): if loop_body.value.name == 'shift_left': alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value.args[0] @@ -857,7 +857,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): else: raise RuntimeError( "Function call not recognized %s" % (loop_body.value.name)) - elif isinstance(loop_body.value, tvm.expr.Load): + elif isinstance(loop_body.value, tvm.tir.Load): alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value rhs = tvm.const(0, "int32") @@ -871,12 +871,12 @@ def _flatten_loop(src_coeff, dst_coeff, extents): # Check if lhs/rhs is immediate use_imm = False imm_val = None - if isinstance(rhs, tvm.expr.IntImm): + if isinstance(rhs, tvm.tir.IntImm): assert lhs.buffer_var.same_as(dst_var) src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices) use_imm = True imm_val = rhs - if isinstance(lhs, tvm.expr.IntImm): + if isinstance(lhs, tvm.tir.IntImm): assert rhs.buffer_var.same_as(dst_var) src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices) use_imm = True
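
Across all of the files above, the rewrite follows one rule: IR node classes and constructors formerly spelled tvm.make.*, tvm.expr.*, and tvm.stmt.* are now reached through tvm.tir, generic node construction moves from tvm.make.node to tvm.ir.make_node, and intrinsic lowering rules register via tvm.target.register_intrin_rule rather than tvm.register_intrin_rule. The sketch below shows the new-style isinstance dispatch in the spirit of the VTA ALU-opcode derivation above; it is illustrative only, with made-up opcode strings and a hypothetical classify helper, and is not part of the patch:

import tvm

# Binary expression node classes now live in tvm.tir (formerly tvm.expr).
_ALU_OPS = [
    (tvm.tir.Add, "ADD"),
    (tvm.tir.Sub, "SUB"),
    (tvm.tir.Mul, "MUL"),
    (tvm.tir.Min, "MIN"),
    (tvm.tir.Max, "MAX"),
]

def classify(value):
    """Map a binary TIR expression to (opcode, lhs, rhs)."""
    for node_cls, opcode in _ALU_OPS:
        if isinstance(value, node_cls):
            # Binary TIR nodes expose their operands as .a and .b,
            # just as the pass above reads loop_body.value.a / .b.
            return opcode, value.a, value.b
    raise RuntimeError("expression not recognized: %s" % type(value))

x = tvm.var("x")
y = tvm.var("y")
opcode, lhs, rhs = classify(tvm.max(x + y, x))
assert opcode == "MAX"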