From ebb7eecef53fafe8cc425b39ccea6e8f9131b2db Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 23 Apr 2020 22:06:05 -0700 Subject: [PATCH] [PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass str (#5426) To make runtime.String to work as naturally as possible in the python side, we make it sub-class the python's str object. Note that however, we cannot sub-class Object at the same time due to python's type layout constraint. We introduce a PyNativeObject class to handle this kind of object sub-classing and updated the FFI to handle PyNativeObject classes. --- CMakeLists.txt | 2 +- python/tvm/_ffi/_ctypes/object.py | 31 +++++++ python/tvm/_ffi/_ctypes/packed_func.py | 5 +- python/tvm/_ffi/_cython/object.pxi | 31 +++++++ python/tvm/_ffi/_cython/packed_func.pxi | 3 + python/tvm/runtime/container.py | 83 +++++-------------- python/tvm/runtime/object.py | 14 ++-- python/tvm/runtime/object_generic.py | 4 +- src/runtime/container.cc | 19 +---- src/support/ffi_testing.cc | 5 ++ .../python/unittest/test_runtime_container.py | 24 ++++++ 11 files changed, 132 insertions(+), 89 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ce646fed2d82..87818ef26d11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -123,7 +123,7 @@ else(MSVC) endif(USE_TF_COMPILE_FLAGS) if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - message("Build in Debug mode") + message(STATUS "Build in Debug mode") set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}") set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}") diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index b5dc65fd5e79..3dbb60715703 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -50,6 +50,10 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + if issubclass(cls, PyNativeObject): + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.handle = handle + return cls.__from_tvm_object__(cls, obj) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) @@ -64,6 +68,33 @@ def _return_object(x): _return_object, TypeCode.OBJECT_RVALUE_REF_ARG) +class PyNativeObject: + """Base class of all TVM objects that also subclass python's builtin types.""" + __slots__ = [] + + def __init_tvm_object_by_constructor__(self, fconstructor, *args): + """Initialize the internal tvm_object by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return object is directly set into the object + """ + # pylint: disable=assigning-non-slot + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.__init_handle_by_constructor__(fconstructor, *args) + self.__tvm_object__ = obj + + + class ObjectBase(object): """Base object for all object types""" __slots__ = ["handle"] diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 11bb65504c61..dc2dc1944f30 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -29,7 +29,7 @@ from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .object import ObjectBase, _set_class_object +from .object import ObjectBase, PyNativeObject, _set_class_object from . import object as _object PackedFuncHandle = ctypes.c_void_p @@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args): values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) type_codes[i] = (TypeCode.NDARRAY_HANDLE if not arg.is_view else TypeCode.DLTENSOR_HANDLE) + elif isinstance(arg, PyNativeObject): + values[i].v_handle = arg.__tvm_object__.handle + type_codes[i] = TypeCode.OBJECT_HANDLE elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index f2b5cc172d45..371cbbb0a4a2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle): object_type = OBJECT_TYPE handle = ctypes_handle(chandle) CALL(TVMObjectGetTypeIndex(chandle, &tindex)) + if tindex < len(OBJECT_TYPE): cls = OBJECT_TYPE[tindex] if cls is not None: + if issubclass(cls, PyNativeObject): + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = chandle + return cls.__from_tvm_object__(cls, obj) obj = cls.__new__(cls) else: obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) else: obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = chandle return obj +class PyNativeObject: + """Base class of all TVM objects that also subclass python's builtin types.""" + __slots__ = [] + + def __init_tvm_object_by_constructor__(self, fconstructor, *args): + """Initialize the internal tvm_object by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return object is directly set into the object + """ + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.__init_handle_by_constructor__(fconstructor, *args) + self.__tvm_object__ = obj + + cdef class ObjectBase: cdef void* chandle diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 6977e108bf88..1f68df1885db 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -109,6 +109,9 @@ cdef inline int make_arg(object arg, value[0].v_handle = (arg).chandle tcode[0] = (kTVMNDArrayHandle if not (arg).c_is_view else kTVMDLTensorHandle) + elif isinstance(arg, PyNativeObject): + value[0].v_handle = ((arg.__tvm_object__)).chandle + tcode[0] = kTVMObjectHandle elif isinstance(arg, _TVM_COMPATS): ptr = arg._tvm_handle value[0].v_handle = (ptr) diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index a719dcd4eaf0..392365ca55c7 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -16,9 +16,10 @@ # under the License. """Runtime container structures.""" import tvm._ffi -from tvm._ffi.base import string_types -from tvm.runtime import Object, ObjectTypes -from tvm.runtime import _ffi_api +from .object import Object, PyNativeObject +from .object_generic import ObjectTypes +from . import _ffi_api + def getitem_helper(obj, elem_getter, length, idx): """Helper function to implement a pythonic getitem function. @@ -112,64 +113,26 @@ def tuple_object(fields=None): @tvm._ffi.register_object("runtime.String") -class String(Object): - """The string object. +class String(str, PyNativeObject): + """TVM runtime.String object, represented as a python str. Parameters ---------- - string : str - The string used to construct a runtime String object - - Returns - ------- - ret : String - The created object. + content : str + The content string used to construct the object. """ - def __init__(self, string): - self.__init_handle_by_constructor__(_ffi_api.String, string) - - def __str__(self): - return _ffi_api.GetStdString(self) - - def __len__(self): - return _ffi_api.GetStringSize(self) - - def __hash__(self): - return _ffi_api.StringHash(self) - - def __eq__(self, other): - if isinstance(other, string_types): - return self.__str__() == other - - if not isinstance(other, String): - return False - - return _ffi_api.CompareString(self, other) == 0 - - def __ne__(self, other): - return not self.__eq__(other) - - def __gt__(self, other): - return _ffi_api.CompareString(self, other) > 0 - - def __lt__(self, other): - return _ffi_api.CompareString(self, other) < 0 - - def __getitem__(self, key): - return self.__str__()[key] - - def startswith(self, string): - """Check if the runtime string starts with a given string - - Parameters - ---------- - string : str - The provided string - - Returns - ------- - ret : boolean - Return true if the runtime string starts with the given string, - otherwise, false. - """ - return self.__str__().startswith(string) + __slots__ = ["__tvm_object__"] + + def __new__(cls, content): + """Construct from string content.""" + val = str.__new__(cls, content) + val.__init_tvm_object_by_constructor__(_ffi_api.String, content) + return val + + # pylint: disable=no-self-argument + def __from_tvm_object__(cls, obj): + """Construct from a given tvm object.""" + content = _ffi_api.GetFFIString(obj) + val = str.__new__(cls, content) + val.__tvm_object__ = obj + return val diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index a55eeb0cb3ee..dd1bcde38c95 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -27,11 +27,11 @@ if _FFI_MODE == "ctypes": raise ImportError() from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic - from tvm._ffi._cy3.core import ObjectBase + from tvm._ffi._cy3.core import ObjectBase, PyNativeObject except (RuntimeError, ImportError): # pylint: disable=wrong-import-position,unused-import from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic - from tvm._ffi._ctypes.object import ObjectBase + from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject def _new_object(cls): @@ -41,6 +41,7 @@ def _new_object(cls): class Object(ObjectBase): """Base class for all tvm's runtime objects.""" + __slots__ = [] def __repr__(self): return _ffi_node_api.AsRepr(self) @@ -78,13 +79,10 @@ def __getstate__(self): def __setstate__(self, state): # pylint: disable=assigning-non-slot, assignment-from-no-return handle = state['handle'] + self.handle = None if handle is not None: - json_str = handle - other = _ffi_node_api.LoadJSON(json_str) - self.handle = other.handle - other.handle = None - else: - self.handle = None + self.__init_handle_by_constructor__( + _ffi_node_api.LoadJSON, handle) def _move(self): """Create an RValue reference to the object and mark the object as moved. diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index ac20b67e8299..cc21450e25c1 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -21,7 +21,7 @@ from tvm._ffi.runtime_ctypes import ObjectRValueRef from . import _ffi_node_api, _ffi_api -from .object import ObjectBase, _set_class_object_generic +from .object import ObjectBase, PyNativeObject, _set_class_object_generic from .ndarray import NDArrayBase from .packed_func import PackedFuncBase, convert_to_tvm_func from .module import Module @@ -34,7 +34,7 @@ def asobject(self): raise NotImplementedError() -ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef) +ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject) def convert_to_object(value): diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 81dfd3d4e252..614592610584 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -19,7 +19,7 @@ /*! * \file src/runtime/container.cc - * \brief Implementations of common plain old data (POD) containers. + * \brief Implementations of common containers. */ #include #include @@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String") return String(std::move(str)); }); -TVM_REGISTER_GLOBAL("runtime.GetStringSize") -.set_body_typed([](String str) { - return static_cast(str.size()); -}); - -TVM_REGISTER_GLOBAL("runtime.GetStdString") +TVM_REGISTER_GLOBAL("runtime.GetFFIString") .set_body_typed([](String str) { return std::string(str); }); -TVM_REGISTER_GLOBAL("runtime.CompareString") -.set_body_typed([](String lhs, String rhs) { - return lhs.compare(rhs); -}); - -TVM_REGISTER_GLOBAL("runtime.StringHash") -.set_body_typed([](String str) { - return static_cast(std::hash()(str)); -}); - TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(StringObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 90fcfff0eef3..622e28ef170a 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop") .set_body([](TVMArgs args, TVMRetValue *ret) { }); +TVM_REGISTER_GLOBAL("testing.echo") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = args[0]; + }); + TVM_REGISTER_GLOBAL("testing.test_wrap_callback") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc pf = args[0]; diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 84b26be6cbc1..5ecc21e520af 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -17,6 +17,7 @@ import numpy as np import tvm +import pickle from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -56,6 +57,29 @@ def test_tuple_object(): tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) +def test_string(): + s = tvm.runtime.String("xyz") + + assert isinstance(s, tvm.runtime.String) + assert isinstance(s, str) + assert s.startswith("xy") + assert s + "1" == "xyz1" + y = tvm.testing.echo(s) + assert isinstance(y, tvm.runtime.String) + assert s.__tvm_object__.same_as(y.__tvm_object__) + assert s == y + + x = tvm.ir.load_json(tvm.ir.save_json(y)) + assert isinstance(x, tvm.runtime.String) + assert x == y + + # test pickle + z = pickle.loads(pickle.dumps(s)) + assert isinstance(z, tvm.runtime.String) + assert s == z + + if __name__ == "__main__": + test_string() test_adt_constructor() test_tuple_object()