From f6fb0ad4e0f9bdbc5d4c2acac18b55dd228fb14d Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 19 Dec 2019 16:26:33 -0800 Subject: [PATCH] Cythonize the shape property --- python/tvm/_ffi/_ctypes/ndarray.py | 5 +++++ python/tvm/_ffi/_cython/ndarray.pxi | 5 +++++ python/tvm/_ffi/ndarray.py | 5 +---- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index 97ac8a823a48c..af59de6eee1d7 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -90,6 +90,11 @@ def _copyto(self, target_nd): check_call(_LIB.TVMArrayCopyFromTo(self.handle, target_nd.handle, None)) return target_nd + @property + def shape(self): + """Shape of this array""" + return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim)) + def to_dlpack(self): """Produce an array from a DLPack Tensor without copying memory diff --git a/python/tvm/_ffi/_cython/ndarray.pxi b/python/tvm/_ffi/_cython/ndarray.pxi index fdb66535c13c5..25d4538074e29 100644 --- a/python/tvm/_ffi/_cython/ndarray.pxi +++ b/python/tvm/_ffi/_cython/ndarray.pxi @@ -68,6 +68,11 @@ cdef class NDArrayBase: def __set__(self, value): self._set_handle(value) + property shape: + """Shape of this array""" + def __get__(self): + return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim)) + def __init__(self, handle, is_view): self._set_handle(handle) self.c_is_view = is_view diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 8520fc3ff035d..56bf4a00080cb 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -157,10 +157,6 @@ def from_dlpack(dltensor): class NDArrayBase(_NDArrayBase): """A simple Device/CPU Array object in runtime.""" - @property - def shape(self): - """Shape of this array""" - return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim)) @property def dtype(self): @@ -240,6 +236,7 @@ def copyfrom(self, source_array): except: raise TypeError('array must be an array_like data,' + 'type %s is not supported' % str(type(source_array))) + t = TVMType(self.dtype) shape, dtype = self.shape, self.dtype if t.lanes > 1: