Skip to content

Commit

Permalink
[PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass…
Browse files Browse the repository at this point in the history
… str (apache#5426)

To make runtime.String to work as naturally as possible in the python side,
we make it sub-class the python's str object. Note that however, we cannot
sub-class Object at the same time due to python's type layout constraint.

We introduce a PyNativeObject class to handle this kind of object sub-classing
and updated the FFI to handle PyNativeObject classes.
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 18, 2020
1 parent 9d883b8 commit ffa31d2
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 89 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ else(MSVC)
endif(USE_TF_COMPILE_FLAGS)

if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message("Build in Debug mode")
message(STATUS "Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def _return_object(x):
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.handle = handle
return cls.__from_tvm_object__(cls, obj)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand All @@ -64,6 +68,33 @@ def _return_object(x):
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
# pylint: disable=assigning-non-slot
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj



class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_object
from .object import ObjectBase, PyNativeObject, _set_class_object
from . import object as _object

PackedFuncHandle = ctypes.c_void_p
Expand Down Expand Up @@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_HANDLE
if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex))

if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return cls.__from_tvm_object__(cls, obj)
obj = cls.__new__(cls)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)

(<ObjectBase>obj).chandle = chandle
return obj


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj


cdef class ObjectBase:
cdef void* chandle

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kTVMNDArrayHandle if
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
elif isinstance(arg, PyNativeObject):
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
tcode[0] = kTVMObjectHandle
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
Expand Down
83 changes: 23 additions & 60 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api
from .object import Object, PyNativeObject
from .object_generic import ObjectTypes
from . import _ffi_api


def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Expand Down Expand Up @@ -112,64 +113,26 @@ def tuple_object(fields=None):


@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
class String(str, PyNativeObject):
"""TVM runtime.String object, represented as a python str.
Parameters
----------
string : str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
content : str
The content string used to construct the object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_ffi_api.String, string)

def __str__(self):
return _ffi_api.GetStdString(self)

def __len__(self):
return _ffi_api.GetStringSize(self)

def __hash__(self):
return _ffi_api.StringHash(self)

def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other

if not isinstance(other, String):
return False

return _ffi_api.CompareString(self, other) == 0

def __ne__(self, other):
return not self.__eq__(other)

def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0

def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0

def __getitem__(self, key):
return self.__str__()[key]

def startswith(self, string):
"""Check if the runtime string starts with a given string
Parameters
----------
string : str
The provided string
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
__slots__ = ["__tvm_object__"]

def __new__(cls, content):
"""Construct from string content."""
val = str.__new__(cls, content)
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
return val

# pylint: disable=no-self-argument
def __from_tvm_object__(cls, obj):
"""Construct from a given tvm object."""
content = _ffi_api.GetFFIString(obj)
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
14 changes: 6 additions & 8 deletions python/tvm/runtime/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
if _FFI_MODE == "ctypes":
raise ImportError()
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase
from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase
from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject


def _new_object(cls):
Expand All @@ -41,6 +41,7 @@ def _new_object(cls):

class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
__slots__ = []
def __repr__(self):
return _ffi_node_api.AsRepr(self)

Expand Down Expand Up @@ -78,13 +79,10 @@ def __getstate__(self):
def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state['handle']
self.handle = None
if handle is not None:
json_str = handle
other = _ffi_node_api.LoadJSON(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
self.__init_handle_by_constructor__(
_ffi_node_api.LoadJSON, handle)

def _move(self):
"""Create an RValue reference to the object and mark the object as moved.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm._ffi.runtime_ctypes import ObjectRValueRef

from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .object import ObjectBase, PyNativeObject, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import Module
Expand All @@ -34,7 +34,7 @@ def asobject(self):
raise NotImplementedError()


ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)


def convert_to_object(value):
Expand Down
19 changes: 2 additions & 17 deletions src/runtime/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file src/runtime/container.cc
* \brief Implementations of common plain old data (POD) containers.
* \brief Implementations of common containers.
*/
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
Expand Down Expand Up @@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String")
return String(std::move(str));
});

TVM_REGISTER_GLOBAL("runtime.GetStringSize")
.set_body_typed([](String str) {
return static_cast<int64_t>(str.size());
});

TVM_REGISTER_GLOBAL("runtime.GetStdString")
TVM_REGISTER_GLOBAL("runtime.GetFFIString")
.set_body_typed([](String str) {
return std::string(str);
});

TVM_REGISTER_GLOBAL("runtime.CompareString")
.set_body_typed([](String lhs, String rhs) {
return lhs.compare(rhs);
});

TVM_REGISTER_GLOBAL("runtime.StringHash")
.set_body_typed([](String str) {
return static_cast<int64_t>(std::hash<String>()(str));
});

TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
Expand Down
5 changes: 5 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});

TVM_REGISTER_GLOBAL("testing.echo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0];
});

TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0];
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_runtime_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import tvm
import pickle
from tvm import te
from tvm import nd, relay
from tvm.runtime import container as _container
Expand Down Expand Up @@ -56,6 +57,29 @@ def test_tuple_object():
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))


def test_string():
s = tvm.runtime.String("xyz")

assert isinstance(s, tvm.runtime.String)
assert isinstance(s, str)
assert s.startswith("xy")
assert s + "1" == "xyz1"
y = tvm.testing.echo(s)
assert isinstance(y, tvm.runtime.String)
assert s.__tvm_object__.same_as(y.__tvm_object__)
assert s == y

x = tvm.ir.load_json(tvm.ir.save_json(y))
assert isinstance(x, tvm.runtime.String)
assert x == y

# test pickle
z = pickle.loads(pickle.dumps(s))
assert isinstance(z, tvm.runtime.String)
assert s == z


if __name__ == "__main__":
test_string()
test_adt_constructor()
test_tuple_object()

0 comments on commit ffa31d2

Please sign in to comment.