Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PY][FFI] Refactor runtime.String to subclass str #5426

Merged
merged 1 commit into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ if(MSVC)
endif()
else(MSVC)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message("Build in Debug mode")
message(STATUS "Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def _return_object(x):
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.handle = handle
return cls.__from_tvm_object__(cls, obj)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand All @@ -64,6 +68,33 @@ def _return_object(x):
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.

Parameters
----------
fconstructor : Function
Constructor function.

args: list of objects
The arguments to the constructor

Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
# pylint: disable=assigning-non-slot
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj



class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_object
from .object import ObjectBase, PyNativeObject, _set_class_object
from . import object as _object

PackedFuncHandle = ctypes.c_void_p
Expand Down Expand Up @@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_HANDLE
if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex))

if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return cls.__from_tvm_object__(cls, obj)
obj = cls.__new__(cls)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)

(<ObjectBase>obj).chandle = chandle
return obj


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.

Parameters
----------
fconstructor : Function
Constructor function.

args: list of objects
The arguments to the constructor

Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj


cdef class ObjectBase:
cdef void* chandle

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kTVMNDArrayHandle if
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
elif isinstance(arg, PyNativeObject):
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
tcode[0] = kTVMObjectHandle
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
Expand Down
83 changes: 23 additions & 60 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api
from .object import Object, PyNativeObject
from .object_generic import ObjectTypes
from . import _ffi_api


def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Expand Down Expand Up @@ -112,64 +113,26 @@ def tuple_object(fields=None):


@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
class String(str, PyNativeObject):
"""TVM runtime.String object, represented as a python str.

Parameters
----------
string : str
The string used to construct a runtime String object

Returns
-------
ret : String
The created object.
content : str
The content string used to construct the object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_ffi_api.String, string)

def __str__(self):
return _ffi_api.GetStdString(self)

def __len__(self):
return _ffi_api.GetStringSize(self)

def __hash__(self):
return _ffi_api.StringHash(self)

def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other

if not isinstance(other, String):
return False

return _ffi_api.CompareString(self, other) == 0

def __ne__(self, other):
return not self.__eq__(other)

def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0

def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0

def __getitem__(self, key):
return self.__str__()[key]

def startswith(self, string):
"""Check if the runtime string starts with a given string

Parameters
----------
string : str
The provided string

Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
__slots__ = ["__tvm_object__"]

def __new__(cls, content):
"""Construct from string content."""
val = str.__new__(cls, content)
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
return val

# pylint: disable=no-self-argument
def __from_tvm_object__(cls, obj):
"""Construct from a given tvm object."""
content = _ffi_api.GetFFIString(obj)
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
14 changes: 6 additions & 8 deletions python/tvm/runtime/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
if _FFI_MODE == "ctypes":
raise ImportError()
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase
from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase
from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject


def _new_object(cls):
Expand All @@ -41,6 +41,7 @@ def _new_object(cls):

class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
__slots__ = []
def __repr__(self):
return _ffi_node_api.AsRepr(self)

Expand Down Expand Up @@ -78,13 +79,10 @@ def __getstate__(self):
def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state['handle']
self.handle = None
if handle is not None:
json_str = handle
other = _ffi_node_api.LoadJSON(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
self.__init_handle_by_constructor__(
_ffi_node_api.LoadJSON, handle)

def _move(self):
"""Create an RValue reference to the object and mark the object as moved.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm._ffi.runtime_ctypes import ObjectRValueRef

from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .object import ObjectBase, PyNativeObject, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import Module
Expand All @@ -34,7 +34,7 @@ def asobject(self):
raise NotImplementedError()


ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)


def convert_to_object(value):
Expand Down
19 changes: 2 additions & 17 deletions src/runtime/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file src/runtime/container.cc
* \brief Implementations of common plain old data (POD) containers.
* \brief Implementations of common containers.
*/
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
Expand Down Expand Up @@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String")
return String(std::move(str));
});

TVM_REGISTER_GLOBAL("runtime.GetStringSize")
.set_body_typed([](String str) {
return static_cast<int64_t>(str.size());
});

TVM_REGISTER_GLOBAL("runtime.GetStdString")
TVM_REGISTER_GLOBAL("runtime.GetFFIString")
.set_body_typed([](String str) {
return std::string(str);
});

TVM_REGISTER_GLOBAL("runtime.CompareString")
.set_body_typed([](String lhs, String rhs) {
return lhs.compare(rhs);
});

TVM_REGISTER_GLOBAL("runtime.StringHash")
.set_body_typed([](String str) {
return static_cast<int64_t>(std::hash<String>()(str));
});

TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
Expand Down
5 changes: 5 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});

TVM_REGISTER_GLOBAL("testing.echo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0];
});

TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0];
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_runtime_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import tvm
import pickle
from tvm import te
from tvm import nd, relay
from tvm.runtime import container as _container
Expand Down Expand Up @@ -56,6 +57,29 @@ def test_tuple_object():
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))


def test_string():
s = tvm.runtime.String("xyz")

assert isinstance(s, tvm.runtime.String)
assert isinstance(s, str)
assert s.startswith("xy")
assert s + "1" == "xyz1"
y = tvm.testing.echo(s)
assert isinstance(y, tvm.runtime.String)
assert s.__tvm_object__.same_as(y.__tvm_object__)
assert s == y

x = tvm.ir.load_json(tvm.ir.save_json(y))
assert isinstance(x, tvm.runtime.String)
assert x == y

# test pickle
z = pickle.loads(pickle.dumps(s))
assert isinstance(z, tvm.runtime.String)
assert s == z


if __name__ == "__main__":
test_string()
test_adt_constructor()
test_tuple_object()