diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index 9367160b811b..af59de6eee1d 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -85,6 +85,16 @@ def __del__(self): def _tvm_handle(self): return ctypes.cast(self.handle, ctypes.c_void_p).value + def _copyto(self, target_nd): + """Internal function that implements copy to target ndarray.""" + check_call(_LIB.TVMArrayCopyFromTo(self.handle, target_nd.handle, None)) + return target_nd + + @property + def shape(self): + """Shape of this array""" + return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim)) + def to_dlpack(self): """Produce an array from a DLPack Tensor without copying memory diff --git a/python/tvm/_ffi/_cython/ndarray.pxi b/python/tvm/_ffi/_cython/ndarray.pxi index 402c9de24ebc..5682ae619a46 100644 --- a/python/tvm/_ffi/_cython/ndarray.pxi +++ b/python/tvm/_ffi/_cython/ndarray.pxi @@ -68,6 +68,11 @@ cdef class NDArrayBase: def __set__(self, value): self._set_handle(value) + @property + def shape(self): + """Shape of this array""" + return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim)) + def __init__(self, handle, is_view): self._set_handle(handle) self.c_is_view = is_view @@ -76,6 +81,11 @@ cdef class NDArrayBase: if self.c_is_view == 0: CALL(TVMArrayFree(self.chandle)) + def _copyto(self, target_nd): + """Internal function that implements copy to target ndarray.""" + CALL(TVMArrayCopyFromTo(self.chandle, (target_nd).chandle, NULL)) + return target_nd + def to_dlpack(self): """Produce an array from a DLPack Tensor without copying memory diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index da0783e10410..56bf4a00080c 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -157,10 +157,6 @@ def from_dlpack(dltensor): class NDArrayBase(_NDArrayBase): """A simple Device/CPU Array object in runtime.""" - @property - def shape(self): - """Shape of this array""" - return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim)) @property def dtype(self): @@ -240,6 +236,7 @@ def copyfrom(self, source_array): except: raise TypeError('array must be an array_like data,' + 'type %s is not supported' % str(type(source_array))) + t = TVMType(self.dtype) shape, dtype = self.shape, self.dtype if t.lanes > 1: @@ -294,14 +291,12 @@ def copyto(self, target): target : NDArray The target array to be copied, must have same shape as this array. """ - if isinstance(target, TVMContext): - target = empty(self.shape, self.dtype, target) if isinstance(target, NDArrayBase): - check_call(_LIB.TVMArrayCopyFromTo( - self.handle, target.handle, None)) - else: - raise ValueError("Unsupported target type %s" % str(type(target))) - return target + return self._copyto(target) + elif isinstance(target, TVMContext): + res = empty(self.shape, self.dtype, target) + return self._copyto(res) + raise ValueError("Unsupported target type %s" % str(type(target))) def free_extension_handle(handle, type_code):