Skip to content

Commit

Permalink
[REFACTOR][PY] Establish tvm.runtime (#4818)
Browse files Browse the repository at this point in the history
* [REFACTOR][PY] Establish tvm.runtime

This PR establishes the tvm.runtime namespace that contains the core runtime data structures.
The top-level API are kept inact for now via re-exporting.

We will followup later to cleanup some of the top-level APIs.

* Fix ndarray name
  • Loading branch information
tqchen authored Feb 5, 2020
1 parent 7d263c3 commit fc7dd6d
Show file tree
Hide file tree
Showing 45 changed files with 722 additions and 726 deletions.
27 changes: 17 additions & 10 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,29 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation."""
"""TVM: Open Deep Learning Compiler Stack."""
import multiprocessing
import sys
import traceback

# import ffi related features
# top-level alias
# tvm._ffi
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object

# top-level alias
# tvm.runtime
from .runtime.object import Object
from .runtime.packed_func import PackedFunc as Function
from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .runtime import module
from .runtime import ndarray
# pylint: disable=reimported
from .runtime import ndarray as nd

# others
from . import tensor
from . import arith
from . import expr
Expand All @@ -37,7 +47,7 @@
from . import codegen
from . import container
from . import schedule
from . import module

from . import attrs
from . import ir_builder
from . import target
Expand All @@ -47,9 +57,6 @@
from . import error
from . import datatype

from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev

from .api import *
from .intrin import *
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from ..runtime_ctypes import DataType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
Expand Down Expand Up @@ -132,7 +132,7 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, TVMContext):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types, py2cerror
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
from ..runtime_ctypes import DataType, TVMContext, TVMByteArray


cdef void tvm_callback_finalize(void* fhandle):
Expand Down Expand Up @@ -129,7 +129,7 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, TVMType):
elif isinstance(arg, DataType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kTVMStr
Expand Down
98 changes: 0 additions & 98 deletions python/tvm/_ffi/module.py

This file was deleted.

10 changes: 5 additions & 5 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TVMByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]

class TVMType(ctypes.Structure):
class DataType(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
Expand All @@ -60,7 +60,7 @@ class TVMType(ctypes.Structure):
4 : 'handle'
}
def __init__(self, type_str):
super(TVMType, self).__init__()
super(DataType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)

Expand Down Expand Up @@ -104,8 +104,8 @@ def __init__(self, type_str):
def __repr__(self):
if self.bits == 1 and self.lanes == 1:
return "bool"
if self.type_code in TVMType.CODE2STR:
type_name = TVMType.CODE2STR[self.type_code]
if self.type_code in DataType.CODE2STR:
type_name = DataType.CODE2STR[self.type_code]
else:
type_name = "custom[%s]" % \
_api_internal._datatype_get_type_name(self.type_code)
Expand Down Expand Up @@ -263,7 +263,7 @@ class TVMArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("dtype", DataType),
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_uint64)]
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

import tvm._ffi

from tvm.runtime import convert, const, DataType
from ._ffi.base import string_types, TVMError
from ._ffi.object_generic import convert, const
from ._ffi.registry import register_func, get_global_func, extract_ext_funcs
from ._ffi.runtime_ctypes import TVMType

from . import _api_internal
from . import make as _make
from . import expr as _expr
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object

from ._ffi.object import Object
from . import _api_internal

class IntSet(Object):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
""" TVM Attribute module, which is mainly used for defining attributes of operators"""
import tvm._ffi

from ._ffi.object import Object
from tvm.runtime import Object
from . import _api_internal


Expand Down
3 changes: 1 addition & 2 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import warnings
import tvm._ffi

from ._ffi.object import Object
from tvm.runtime import Object, ndarray
from . import api
from . import _api_internal
from . import tensor
Expand All @@ -33,7 +33,6 @@
from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target
from . import make

Expand Down
28 changes: 7 additions & 21 deletions python/tvm/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
"""Container data structures used in TVM DSL."""
import tvm._ffi

from tvm import ndarray as _nd
from tvm.runtime import Object, ObjectTypes
from tvm.runtime.container import getitem_helper
from . import _api_internal
from ._ffi.object import Object, getitem_helper


@tvm._ffi.register_object
Expand All @@ -31,23 +31,9 @@ class Array(Object):
to Array during tvm function call.
You may get Array in return values of TVM function call.
"""
def __getitem__(self, i):
if isinstance(i, slice):
start = i.start if i.start is not None else 0
stop = i.stop if i.stop is not None else len(self)
step = i.step if i.step is not None else 1
if start < 0:
start += len(self)
if stop < 0:
stop += len(self)
return [self[idx] for idx in range(start, stop, step)]

if i < -len(self) or i >= len(self):
raise IndexError("Array index out of range. Array size: {}, got index {}"
.format(len(self), i))
if i < 0:
i += len(self)
return _api_internal._ArrayGetItem(self, i)
def __getitem__(self, idx):
return getitem_helper(
self, _api_internal._ArrayGetItem, len(self), idx)

def __len__(self):
return _api_internal._ArraySize(self)
Expand Down Expand Up @@ -133,7 +119,7 @@ class ADT(Object):
"""
def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or " \
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
self.__init_handle_by_constructor__(_ADT, tag, *fields)

Expand Down Expand Up @@ -164,7 +150,7 @@ def tuple_object(fields=None):
"""
fields = fields if fields else []
for f in fields:
assert isinstance(f, (Object, _nd.NDArray)), "Expect object or tvm " \
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
return _Tuple(*fields)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime
from tvm.ndarray import array
from tvm.runtime.ndarray import array
from . import debug_result

_DUMP_ROOT_PREFIX = "tvmdbg_"
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import subprocess
import os
import warnings
from tvm.runtime import ndarray as nd

from . import util
from .. import ndarray as nd
from ..api import register_func
from .._ffi.base import py_str

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import TVMType as _TVMType
from ._ffi.runtime_ctypes import DataType
from . import _api_internal


Expand Down Expand Up @@ -131,7 +131,7 @@ def lower(op):
width as the custom type is returned. Otherwise, the type is
unchanged."""
dtype = op.dtype
t = _TVMType(dtype)
t = DataType(dtype)
if get_type_registered(t.type_code):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
Expand Down
7 changes: 2 additions & 5 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,9 @@
assert(y.a == x)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
import tvm._ffi
from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode

from ._ffi.object import Object
from ._ffi.object_generic import ObjectGeneric
from ._ffi.runtime_ctypes import TVMType, TypeCode
from . import make as _make
from . import generic as _generic
from . import _api_internal
Expand All @@ -52,7 +49,7 @@ def _dtype_is_int(value):
if isinstance(value, int):
return True
return (isinstance(value, ExprOp) and
TVMType(value.dtype).type_code == TypeCode.INT)
DataType(value.dtype).type_code == TypeCode.INT)


class ExprOp(object):
Expand Down
Loading

0 comments on commit fc7dd6d

Please sign in to comment.