Skip to content

Commit

Permalink
[PYTHON][FFI] Skip numpy.ascontiguousarray if C_CONTIGUOUS == True (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiang2713 authored Sep 24, 2021
1 parent 7c6a334 commit d3d7e8e
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,18 @@ def copyfrom(self, source_array):
source_array.shape, shape
)
)
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
numpy_str_map = DataType.NUMPY2STR
np_dtype_str = (
numpy_str_map[source_array.dtype]
if source_array.dtype in numpy_str_map
else str(source_array.dtype)
)
if (not source_array.flags["C_CONTIGUOUS"]) or (
dtype == "bfloat16" or dtype != np_dtype_str
):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
Expand Down

0 comments on commit d3d7e8e

Please sign in to comment.