Skip to content

Commit

Permalink
[PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass…
Browse files Browse the repository at this point in the history
… str (apache#5426)

To make runtime.String to work as naturally as possible in the python side,
we make it sub-class the python's str object. Note that however, we cannot
sub-class Object at the same time due to python's type layout constraint.

We introduce a PyNativeObject class to handle this kind of object sub-classing
and updated the FFI to handle PyNativeObject classes.

[Relay][Frontend][TFLite] Add parser support for shape and range

Signed-off-by: Dhruva Ray <[email protected]>
  • Loading branch information
tqchen authored and dhruvaray committed Apr 24, 2020
1 parent 6c77195 commit 6de8c1f
Show file tree
Hide file tree
Showing 13 changed files with 338 additions and 112 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ if(MSVC)
endif()
else(MSVC)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message("Build in Debug mode")
message(STATUS "Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def _return_object(x):
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.handle = handle
return cls.__from_tvm_object__(cls, obj)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand All @@ -64,6 +68,33 @@ def _return_object(x):
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
# pylint: disable=assigning-non-slot
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj



class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_object
from .object import ObjectBase, PyNativeObject, _set_class_object
from . import object as _object

PackedFuncHandle = ctypes.c_void_p
Expand Down Expand Up @@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_HANDLE
if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex))

if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return cls.__from_tvm_object__(cls, obj)
obj = cls.__new__(cls)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)

(<ObjectBase>obj).chandle = chandle
return obj


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj


cdef class ObjectBase:
cdef void* chandle

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kTVMNDArrayHandle if
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
elif isinstance(arg, PyNativeObject):
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
tcode[0] = kTVMObjectHandle
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
Expand Down
60 changes: 60 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .common import ExprTable
from .common import infer_shape as _infer_shape


__all__ = ['from_tflite']

class TensorWrapper(object):
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(self, model, subgraph, exp_tab):
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'RANGE': self.convert_range,
'REDUCE_ANY': self._convert_reduce_any,
'REDUCE_MAX': self._convert_reduce_max,
'REDUCE_MIN': self._convert_reduce_min,
Expand All @@ -115,6 +117,7 @@ def __init__(self, model, subgraph, exp_tab):
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'SHAPE': self.convert_shape,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
Expand Down Expand Up @@ -552,6 +555,63 @@ def convert_tanh(self, op):

return out

def convert_range(self, op):
"""Convert TFLite Range"""
try:
from tflite.Operator import Operator
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized RANGE operator is not supported yet.')

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"

start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
expressions = []

for t in [start, limit, delta]:
if self.has_expr(t.tensor_idx):
expressions.append(self.get_expr(t.tensor_idx))
else:
tensor_type = self.get_tensor_type_str(t.tensor.Type())
tensor_value = self.get_tensor_value(t)
expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))

#out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())

#put type here form op
out = _op.arange(expressions[0], expressions[1], expressions[2], out_type)

return out

def convert_shape(self, op):
"""Convert TFLite Shape"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized SHAPE operator is not supported yet.')

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

out = _op.shape_of(self.get_expr(input_tensors[0].tensor_idx))

return out

def convert_relu(self, op):
"""Convert TFLite ReLU"""
input_tensors = self.get_input_tensors(op)
Expand Down
83 changes: 23 additions & 60 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api
from .object import Object, PyNativeObject
from .object_generic import ObjectTypes
from . import _ffi_api


def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Expand Down Expand Up @@ -112,64 +113,26 @@ def tuple_object(fields=None):


@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
class String(str, PyNativeObject):
"""TVM runtime.String object, represented as a python str.
Parameters
----------
string : str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
content : str
The content string used to construct the object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_ffi_api.String, string)

def __str__(self):
return _ffi_api.GetStdString(self)

def __len__(self):
return _ffi_api.GetStringSize(self)

def __hash__(self):
return _ffi_api.StringHash(self)

def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other

if not isinstance(other, String):
return False

return _ffi_api.CompareString(self, other) == 0

def __ne__(self, other):
return not self.__eq__(other)

def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0

def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0

def __getitem__(self, key):
return self.__str__()[key]

def startswith(self, string):
"""Check if the runtime string starts with a given string
Parameters
----------
string : str
The provided string
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
__slots__ = ["__tvm_object__"]

def __new__(cls, content):
"""Construct from string content."""
val = str.__new__(cls, content)
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
return val

# pylint: disable=no-self-argument
def __from_tvm_object__(cls, obj):
"""Construct from a given tvm object."""
content = _ffi_api.GetFFIString(obj)
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
14 changes: 6 additions & 8 deletions python/tvm/runtime/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
if _FFI_MODE == "ctypes":
raise ImportError()
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase
from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase
from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject


def _new_object(cls):
Expand All @@ -41,6 +41,7 @@ def _new_object(cls):

class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
__slots__ = []
def __repr__(self):
return _ffi_node_api.AsRepr(self)

Expand Down Expand Up @@ -78,13 +79,10 @@ def __getstate__(self):
def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state['handle']
self.handle = None
if handle is not None:
json_str = handle
other = _ffi_node_api.LoadJSON(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
self.__init_handle_by_constructor__(
_ffi_node_api.LoadJSON, handle)

def _move(self):
"""Create an RValue reference to the object and mark the object as moved.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm._ffi.runtime_ctypes import ObjectRValueRef

from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .object import ObjectBase, PyNativeObject, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import Module
Expand All @@ -34,7 +34,7 @@ def asobject(self):
raise NotImplementedError()


ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)


def convert_to_object(value):
Expand Down
Loading

0 comments on commit 6de8c1f

Please sign in to comment.