Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][PY] Establish tvm.runtime #4818

Merged
merged 2 commits into from
Feb 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,29 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation."""
"""TVM: Open Deep Learning Compiler Stack."""
import multiprocessing
import sys
import traceback

# import ffi related features
# top-level alias
# tvm._ffi
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object

# top-level alias
# tvm.runtime
from .runtime.object import Object
from .runtime.packed_func import PackedFunc as Function
from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import module
from .runtime import ndarray
# pylint: disable=reimported
from .runtime import ndarray as nd

# others
from . import tensor
from . import arith
from . import expr
Expand All @@ -37,7 +47,7 @@
from . import codegen
from . import container
from . import schedule
from . import module

from . import attrs
from . import ir_builder
from . import target
Expand All @@ -47,9 +57,6 @@
from . import error
from . import datatype

from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev

from .api import *
from .intrin import *
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from ..runtime_ctypes import DataType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
Expand Down Expand Up @@ -132,7 +132,7 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
from ..runtime_ctypes import DataType, TVMContext, TVMByteArray


cdef void tvm_callback_finalize(void* fhandle):
Expand Down Expand Up @@ -129,7 +129,7 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kTVMStr
Expand Down
98 changes: 0 additions & 98 deletions python/tvm/_ffi/module.py

This file was deleted.

10 changes: 5 additions & 5 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TVMByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]

class TVMType(ctypes.Structure):
class DataType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
Expand All @@ -60,7 +60,7 @@ class TVMType(ctypes.Structure):
4 : 'handle'
}
def __init__(self, type_str):
super(TVMType, self).__init__()
super(DataType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)

Expand Down Expand Up @@ -104,8 +104,8 @@ def __init__(self, type_str):
def __repr__(self):
if self.bits == 1 and self.lanes == 1:
return "bool"
if self.type_code in TVMType.CODE2STR:
type_name = TVMType.CODE2STR[self.type_code]
if self.type_code in DataType.CODE2STR:
type_name = DataType.CODE2STR[self.type_code]
else:
type_name = "custom[%s]" % \
_api_internal._datatype_get_type_name(self.type_code)
Expand Down Expand Up @@ -263,7 +263,7 @@ class TVMArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("dtype", DataType),
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_uint64)]
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

import tvm._ffi

from tvm.runtime import convert, const, DataType
from ._ffi.base import string_types, TVMError
from ._ffi.object_generic import convert, const
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.runtime_ctypes import TVMType

from . import _api_internal
from . import make as _make
from . import expr as _expr
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object

from ._ffi.object import Object
from . import _api_internal

class IntSet(Object):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
import tvm._ffi

from ._ffi.object import Object
from tvm.runtime import Object
from . import _api_internal


Expand Down
3 changes: 1 addition & 2 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import warnings
import tvm._ffi

from ._ffi.object import Object
from tvm.runtime import Object, ndarray
from . import api
from . import _api_internal
from . import tensor
Expand All @@ -33,7 +33,6 @@
from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target
from . import make

Expand Down
28 changes: 7 additions & 21 deletions python/tvm/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
"""Container data structures used in TVM DSL."""
import tvm._ffi

from tvm import ndarray as _nd
from tvm.runtime import Object, ObjectTypes
from tvm.runtime.container import getitem_helper
from . import _api_internal
from ._ffi.object import Object, getitem_helper


@tvm._ffi.register_object
Expand All @@ -31,23 +31,9 @@ class Array(Object):
to Array during tvm function call.
You may get Array in return values of TVM function call.
"""
def __getitem__(self, i):
if isinstance(i, slice):
start = i.start if i.start is not None else 0
stop = i.stop if i.stop is not None else len(self)
step = i.step if i.step is not None else 1
if start < 0:
start += len(self)
if stop < 0:
stop += len(self)
return [self[idx] for idx in range(start, stop, step)]

if i < -len(self) or i >= len(self):
raise IndexError("Array index out of range. Array size: {}, got index {}"
.format(len(self), i))
if i < 0:
i += len(self)
return _api_internal._ArrayGetItem(self, i)
def __getitem__(self, idx):
return getitem_helper(
self, _api_internal._ArrayGetItem, len(self), idx)

def __len__(self):
return _api_internal._ArraySize(self)
Expand Down Expand Up @@ -133,7 +119,7 @@ class ADT(Object):
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)

Expand Down Expand Up @@ -164,7 +150,7 @@ def tuple_object(fields=None):
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime
from tvm.ndarray import array
from tvm.runtime.ndarray import array
from . import debug_result

_DUMP_ROOT_PREFIX = "tvmdbg_"
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import subprocess
import os
import warnings
from tvm.runtime import ndarray as nd

from . import util
from .. import ndarray as nd
from ..api import register_func
from .._ffi.base import py_str

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import TVMType as _TVMType
from ._ffi.runtime_ctypes import DataType
from . import _api_internal


Expand Down Expand Up @@ -131,7 +131,7 @@ def lower(op):
width as the custom type is returned. Otherwise, the type is
unchanged."""
dtype = op.dtype
t = _TVMType(dtype)
t = DataType(dtype)
if get_type_registered(t.type_code):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
Expand Down
7 changes: 2 additions & 5 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@
assert(y.a == x)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode

from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
from . import _api_internal
Expand All @@ -52,7 +49,7 @@ def _dtype_is_int(value):
if isinstance(value, int):
return True
return (isinstance(value, ExprOp) and
TVMType(value.dtype).type_code == TypeCode.INT)
DataType(value.dtype).type_code == TypeCode.INT)


class ExprOp(object):
Expand Down
Loading