diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 27811a963993..2b9f7f9446ba 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -165,9 +165,18 @@ def copyfrom(self, source_array): source_array.shape, shape ) ) - source_array = np.ascontiguousarray( - source_array, dtype="uint16" if dtype == "bfloat16" else dtype + numpy_str_map = DataType.NUMPY2STR + np_dtype_str = ( + numpy_str_map[source_array.dtype] + if source_array.dtype in numpy_str_map + else str(source_array.dtype) ) + if (not source_array.flags["C_CONTIGUOUS"]) or ( + dtype == "bfloat16" or dtype != np_dtype_str + ): + source_array = np.ascontiguousarray( + source_array, dtype="uint16" if dtype == "bfloat16" else dtype + ) assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)