diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ebf7bf4bd9f..fc7c67c83a48 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,7 +108,7 @@ if(MSVC) endif() else(MSVC) if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - message("Build in Debug mode") + message(STATUS "Build in Debug mode") set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}") set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}") diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index b5dc65fd5e79..3dbb60715703 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -50,6 +50,10 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + if issubclass(cls, PyNativeObject): + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.handle = handle + return cls.__from_tvm_object__(cls, obj) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) @@ -64,6 +68,33 @@ def _return_object(x): _return_object, TypeCode.OBJECT_RVALUE_REF_ARG) +class PyNativeObject: + """Base class of all TVM objects that also subclass python's builtin types.""" + __slots__ = [] + + def __init_tvm_object_by_constructor__(self, fconstructor, *args): + """Initialize the internal tvm_object by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return object is directly set into the object + """ + # pylint: disable=assigning-non-slot + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.__init_handle_by_constructor__(fconstructor, *args) + self.__tvm_object__ = obj + + + class ObjectBase(object): """Base object for all object types""" __slots__ = ["handle"] diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 11bb65504c61..dc2dc1944f30 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -29,7 +29,7 @@ from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .object import ObjectBase, _set_class_object +from .object import ObjectBase, PyNativeObject, _set_class_object from . import object as _object PackedFuncHandle = ctypes.c_void_p @@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args): values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) type_codes[i] = (TypeCode.NDARRAY_HANDLE if not arg.is_view else TypeCode.DLTENSOR_HANDLE) + elif isinstance(arg, PyNativeObject): + values[i].v_handle = arg.__tvm_object__.handle + type_codes[i] = TypeCode.OBJECT_HANDLE elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index f2b5cc172d45..371cbbb0a4a2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle): object_type = OBJECT_TYPE handle = ctypes_handle(chandle) CALL(TVMObjectGetTypeIndex(chandle, &tindex)) + if tindex < len(OBJECT_TYPE): cls = OBJECT_TYPE[tindex] if cls is not None: + if issubclass(cls, PyNativeObject): + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = chandle + return cls.__from_tvm_object__(cls, obj) obj = cls.__new__(cls) else: obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) else: obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = chandle return obj +class PyNativeObject: + """Base class of all TVM objects that also subclass python's builtin types.""" + __slots__ = [] + + def __init_tvm_object_by_constructor__(self, fconstructor, *args): + """Initialize the internal tvm_object by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return object is directly set into the object + """ + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.__init_handle_by_constructor__(fconstructor, *args) + self.__tvm_object__ = obj + + cdef class ObjectBase: cdef void* chandle diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 6977e108bf88..1f68df1885db 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -109,6 +109,9 @@ cdef inline int make_arg(object arg, value[0].v_handle = (arg).chandle tcode[0] = (kTVMNDArrayHandle if not (arg).c_is_view else kTVMDLTensorHandle) + elif isinstance(arg, PyNativeObject): + value[0].v_handle = ((arg.__tvm_object__)).chandle + tcode[0] = kTVMObjectHandle elif isinstance(arg, _TVM_COMPATS): ptr = arg._tvm_handle value[0].v_handle = (ptr) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 8c335c9eb017..a955ba2858ca 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -32,6 +32,7 @@ from .common import ExprTable from .common import infer_shape as _infer_shape + __all__ = ['from_tflite'] class TensorWrapper(object): @@ -105,6 +106,7 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, + 'RANGE': self.convert_range, 'REDUCE_ANY': self._convert_reduce_any, 'REDUCE_MAX': self._convert_reduce_max, 'REDUCE_MIN': self._convert_reduce_min, @@ -115,6 +117,7 @@ def __init__(self, model, subgraph, exp_tab): 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'SHAPE': self.convert_shape, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, @@ -552,6 +555,63 @@ def convert_tanh(self, op): return out + def convert_range(self, op): + """Convert TFLite Range""" + try: + from tflite.Operator import Operator + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized RANGE operator is not supported yet.') + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be 3" + + start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2] + expressions = [] + + for t in [start, limit, delta]: + if self.has_expr(t.tensor_idx): + expressions.append(self.get_expr(t.tensor_idx)) + else: + tensor_type = self.get_tensor_type_str(t.tensor.Type()) + tensor_value = self.get_tensor_value(t) + expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type)) + + #out type inference + if delta.tensor.Type() == TensorType.FLOAT32: + out_type = self.get_tensor_type_str(delta.tensor.Type()) + else: + out_type = self.get_tensor_type_str(start.tensor.Type()) + + #put type here form op + out = _op.arange(expressions[0], expressions[1], expressions[2], out_type) + + return out + + def convert_shape(self, op): + """Convert TFLite Shape""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized SHAPE operator is not supported yet.') + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + out = _op.shape_of(self.get_expr(input_tensors[0].tensor_idx)) + + return out + def convert_relu(self, op): """Convert TFLite ReLU""" input_tensors = self.get_input_tensors(op) diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index a719dcd4eaf0..392365ca55c7 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -16,9 +16,10 @@ # under the License. """Runtime container structures.""" import tvm._ffi -from tvm._ffi.base import string_types -from tvm.runtime import Object, ObjectTypes -from tvm.runtime import _ffi_api +from .object import Object, PyNativeObject +from .object_generic import ObjectTypes +from . import _ffi_api + def getitem_helper(obj, elem_getter, length, idx): """Helper function to implement a pythonic getitem function. @@ -112,64 +113,26 @@ def tuple_object(fields=None): @tvm._ffi.register_object("runtime.String") -class String(Object): - """The string object. +class String(str, PyNativeObject): + """TVM runtime.String object, represented as a python str. Parameters ---------- - string : str - The string used to construct a runtime String object - - Returns - ------- - ret : String - The created object. + content : str + The content string used to construct the object. """ - def __init__(self, string): - self.__init_handle_by_constructor__(_ffi_api.String, string) - - def __str__(self): - return _ffi_api.GetStdString(self) - - def __len__(self): - return _ffi_api.GetStringSize(self) - - def __hash__(self): - return _ffi_api.StringHash(self) - - def __eq__(self, other): - if isinstance(other, string_types): - return self.__str__() == other - - if not isinstance(other, String): - return False - - return _ffi_api.CompareString(self, other) == 0 - - def __ne__(self, other): - return not self.__eq__(other) - - def __gt__(self, other): - return _ffi_api.CompareString(self, other) > 0 - - def __lt__(self, other): - return _ffi_api.CompareString(self, other) < 0 - - def __getitem__(self, key): - return self.__str__()[key] - - def startswith(self, string): - """Check if the runtime string starts with a given string - - Parameters - ---------- - string : str - The provided string - - Returns - ------- - ret : boolean - Return true if the runtime string starts with the given string, - otherwise, false. - """ - return self.__str__().startswith(string) + __slots__ = ["__tvm_object__"] + + def __new__(cls, content): + """Construct from string content.""" + val = str.__new__(cls, content) + val.__init_tvm_object_by_constructor__(_ffi_api.String, content) + return val + + # pylint: disable=no-self-argument + def __from_tvm_object__(cls, obj): + """Construct from a given tvm object.""" + content = _ffi_api.GetFFIString(obj) + val = str.__new__(cls, content) + val.__tvm_object__ = obj + return val diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index a55eeb0cb3ee..dd1bcde38c95 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -27,11 +27,11 @@ if _FFI_MODE == "ctypes": raise ImportError() from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic - from tvm._ffi._cy3.core import ObjectBase + from tvm._ffi._cy3.core import ObjectBase, PyNativeObject except (RuntimeError, ImportError): # pylint: disable=wrong-import-position,unused-import from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic - from tvm._ffi._ctypes.object import ObjectBase + from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject def _new_object(cls): @@ -41,6 +41,7 @@ def _new_object(cls): class Object(ObjectBase): """Base class for all tvm's runtime objects.""" + __slots__ = [] def __repr__(self): return _ffi_node_api.AsRepr(self) @@ -78,13 +79,10 @@ def __getstate__(self): def __setstate__(self, state): # pylint: disable=assigning-non-slot, assignment-from-no-return handle = state['handle'] + self.handle = None if handle is not None: - json_str = handle - other = _ffi_node_api.LoadJSON(json_str) - self.handle = other.handle - other.handle = None - else: - self.handle = None + self.__init_handle_by_constructor__( + _ffi_node_api.LoadJSON, handle) def _move(self): """Create an RValue reference to the object and mark the object as moved. diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index ac20b67e8299..cc21450e25c1 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -21,7 +21,7 @@ from tvm._ffi.runtime_ctypes import ObjectRValueRef from . import _ffi_node_api, _ffi_api -from .object import ObjectBase, _set_class_object_generic +from .object import ObjectBase, PyNativeObject, _set_class_object_generic from .ndarray import NDArrayBase from .packed_func import PackedFuncBase, convert_to_tvm_func from .module import Module @@ -34,7 +34,7 @@ def asobject(self): raise NotImplementedError() -ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef) +ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject) def convert_to_object(value): diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 81dfd3d4e252..614592610584 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -19,7 +19,7 @@ /*! * \file src/runtime/container.cc - * \brief Implementations of common plain old data (POD) containers. + * \brief Implementations of common containers. */ #include #include @@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String") return String(std::move(str)); }); -TVM_REGISTER_GLOBAL("runtime.GetStringSize") -.set_body_typed([](String str) { - return static_cast(str.size()); -}); - -TVM_REGISTER_GLOBAL("runtime.GetStdString") +TVM_REGISTER_GLOBAL("runtime.GetFFIString") .set_body_typed([](String str) { return std::string(str); }); -TVM_REGISTER_GLOBAL("runtime.CompareString") -.set_body_typed([](String lhs, String rhs) { - return lhs.compare(rhs); -}); - -TVM_REGISTER_GLOBAL("runtime.StringHash") -.set_body_typed([](String str) { - return static_cast(std::hash()(str)); -}); - TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(StringObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 90fcfff0eef3..622e28ef170a 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop") .set_body([](TVMArgs args, TVMRetValue *ret) { }); +TVM_REGISTER_GLOBAL("testing.echo") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = args[0]; + }); + TVM_REGISTER_GLOBAL("testing.test_wrap_callback") .set_body([](TVMArgs args, TVMRetValue *ret) { PackedFunc pf = args[0]; diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index b7ba9c2908a7..f45dc7c45777 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -73,8 +73,34 @@ def get_real_image(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data +def vmobj_to_list(o): + if isinstance(o, tvm.nd.NDArray): + return [o.asnumpy().tolist()] + elif isinstance(o, tvm.runtime.container.ADT): + result = [] + for f in o: + result.extend(vmobj_to_list(f)) + return result + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == 'Cons': + tl = vmobj_to_list(o.fields[1]) + hd = vmobj_to_list(o.fields[0]) + hd.extend(tl) + return hd + elif o.constructor.name_hint == 'Nil': + return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % + o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', - out_names=None): + out_names=None, mode='graph_runtime'): """ Generic function to compile on relay and execute on tvm """ try: import tflite.Model @@ -96,27 +122,44 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) - with relay.build_config(opt_level=3): - graph, lib, params = relay.build(mod, target, params=params) - ctx = tvm.context(target, 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - # set inputs - for i, e in enumerate(input_node): - m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) - - m.set_input(**params) - # execute - m.run() - # get outputs - assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( - out_names, num_output) - tvm_output_list = [] - for i in range(0, num_output): - tvm_output = m.get_output(i) - tvm_output_list.append(tvm_output.asnumpy()) - return tvm_output_list + if mode in ['debug', 'vm']: + ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm") + inputs = [] + for param in mod['main'].params: + found = False + for i, n in enumerate(input_node): + if n == param.name_hint: + found = True + inputs.append(tvm.nd.array(input_data[i])) + break + # Interpreter doesn't bind constants, so still need to find in params + if not found: + inputs.append(tvm.nd.array(params[param.name_hint])) + result = ex.evaluate()(*inputs) + return vmobj_to_list(result) + else: + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, target, params=params) + + ctx = tvm.context(target, 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + # set inputs + for i, e in enumerate(input_node): + m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + + m.set_input(**params) + # execute + m.run() + # get outputs + assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( + out_names, num_output) + tvm_output_list = [] + for i in range(0, num_output): + tvm_output = m.get_output(i) + tvm_output_list.append(tvm_output.asnumpy()) + return tvm_output_list def run_tflite_graph(tflite_model_buf, input_data): @@ -147,7 +190,7 @@ def run_tflite_graph(tflite_model_buf, input_data): def compare_tflite_with_tvm(in_data, in_name, input_tensors, output_tensors, init_global_variables=False, - out_names=None, quantized=False, input_range=None): + out_names=None, quantized=False, input_range=None, mode='graph_runtime'): """Generic function to generate and compare TFLite and TVM output""" in_data = convert_to_list(in_data) in_name = convert_to_list(in_name) @@ -189,7 +232,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, continue tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device, - num_output=len(out_names), out_names=out_names) + num_output=len(out_names), out_names=out_names,mode=mode) # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output # range for the specific operator. While adding test ensure that we aren't getting only clipped values @@ -651,6 +694,82 @@ def test_all_resize(): _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False) +####################################################################### +# Range +# ----- +def _test_range(start, limit, delta): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + start_scalar, limit_scalar, delta_scalar = \ + tf.placeholder(dtype=start.dtype, shape=(), name="start"), \ + tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \ + tf.placeholder(dtype=delta.dtype, shape=(), name="delta") + + out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range") + + compare_tflite_with_tvm( + [start, limit, delta], + ["start", "limit", "delta"], + [start_scalar, limit_scalar, delta_scalar], + [out], + mode="vm", + quantized=False + ) + +def _test_range_default(): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + + inputs = [ + tf.placeholder(dtype=tf.int32, shape=(), name="p1"), + tf.placeholder(dtype=tf.int32, shape=(), name="p2") + ] + leaves = [ + tf.range(start = inputs[0], limit = inputs[1]), #use default delta + tf.range(start = inputs[1]) #use start as limit with 0 as the first item in the range + ] + + compare_tflite_with_tvm( + [np.int32(1), np.int32(18)], + ["p1", "p2"], + inputs, + leaves, + mode="vm", + quantized=False + ) + +def test_forward_range(): + _test_range(np.int32(1), np.int32(18), np.int32(3)) + _test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float + _test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float + _test_range_default() + +####################################################################### +# Shape +# ----- +def test_forward_shape(): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + data = np.array([1, 18, 3], dtype=np.int32) + start = tf.placeholder(dtype=tf.int32, shape=[], name="start") + limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit") + delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta") + r = tf.range(start, limit, delta, tf.int32, name="range") + out = tf.shape(r, out_type=tf.dtypes.int32) + compare_tflite_with_tvm( + [x for x in np.nditer(data)], + ["start", "limit", "delta"], + [start, limit, delta], + [out], + mode="vm", + quantized=False + ) ####################################################################### # Concatenation # ------------- @@ -1845,6 +1964,9 @@ def test_forward_mediapipe_hand_landmark(): # Tile test_forward_tile() + # Query + test_forward_shape() + # Transforms test_forward_concatenation() test_forward_pad() @@ -1852,6 +1974,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_unpack() test_forward_reshape() test_all_resize() + test_forward_range() test_forward_squeeze() test_forward_slice() test_forward_topk() diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 84b26be6cbc1..5ecc21e520af 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -17,6 +17,7 @@ import numpy as np import tvm +import pickle from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -56,6 +57,29 @@ def test_tuple_object(): tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) +def test_string(): + s = tvm.runtime.String("xyz") + + assert isinstance(s, tvm.runtime.String) + assert isinstance(s, str) + assert s.startswith("xy") + assert s + "1" == "xyz1" + y = tvm.testing.echo(s) + assert isinstance(y, tvm.runtime.String) + assert s.__tvm_object__.same_as(y.__tvm_object__) + assert s == y + + x = tvm.ir.load_json(tvm.ir.save_json(y)) + assert isinstance(x, tvm.runtime.String) + assert x == y + + # test pickle + z = pickle.loads(pickle.dumps(s)) + assert isinstance(z, tvm.runtime.String) + assert s == z + + if __name__ == "__main__": + test_string() test_adt_constructor() test_tuple_object()