Skip to content

Commit

Permalink
Merge pull request #223 from NaderAlAwar/array_api_fixes
Browse files Browse the repository at this point in the history
Fix Array API CI failures
  • Loading branch information
tylerjereddy authored Dec 7, 2023
2 parents 247f3ec + ffec823 commit 2db7c16
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 92 deletions.
30 changes: 20 additions & 10 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
int16, int32, int64,
uint8,
uint16, uint32, uint64,
double, float64,
double, float32, float64,
)
from .data_types import float as pk_float
from .layout import get_default_layout, Layout
from .memory_space import get_default_memory_space, MemorySpace
from .hierarchical import TeamMember
Expand Down Expand Up @@ -320,17 +321,24 @@ def _init_view(
is_cpu: bool = self.space is MemorySpace.HostSpace
kokkos_lib: ModuleType = km.get_kokkos_module(is_cpu)

if self.dtype == pk.float:
self.dtype = DataType.float
elif self.dtype == pk.double:
self.dtype = DataType.double
if self.dtype in {DataType.float, pk_float}:
self.dtype = float32
elif self.dtype in {DataType.double, double}:
self.dtype = float64
if trait is trait.Unmanaged:
if array is not None and array.ndim == 0:
# TODO: we don't really support 0-D under the hood--use
# NumPy for now...
self.array = array
else:
if array.dtype == np.bool_:
array = array.astype(np.uint8)
self.array = kokkos_lib.unmanaged_array(array, dtype=self.dtype.value, space=self.space.value, layout=self.layout.value)
# Store a reference here in case the array goes out of
# scope and gets garbage collected, which would
# invalidate the data. Currently, this happens when
# calling asarray()
self.orig_array = array
else:
if len(self.shape) == 0:
shape = [1]
Expand Down Expand Up @@ -361,9 +369,9 @@ def _get_type(self, dtype: Union[DataType, type]) -> Optional[DataType]:
return dtype

if dtype is int:
return DataType["int32"]
return int32
if dtype is float:
return DataType["double"]
return double

return None

Expand Down Expand Up @@ -652,7 +660,6 @@ def from_numpy(array: np.ndarray, space: Optional[MemorySpace] = None, layout: O
else:
ret_list = list((array.shape))


return View(ret_list, dtype, space=space, trait=Trait.Unmanaged, array=array, layout=layout)

def from_array(array) -> ViewType:
Expand Down Expand Up @@ -684,6 +691,8 @@ def from_array(array) -> ViewType:
ctype = ctypes.c_float
elif np_dtype is np.float64:
ctype = ctypes.c_double
elif np_dtype is np.bool_:
ctype = ctypes.c_uint8
else:
raise RuntimeError(f"ERROR: unsupported numpy datatype {np_dtype}")

Expand Down Expand Up @@ -742,7 +751,7 @@ def array(array, space: Optional[MemorySpace] = None, layout: Optional[Layout] =
"""

# if numpy array, use from_numpy()
if isinstance(array, np.ndarray):
if isinstance(array, np.ndarray) or np.isscalar(array):
return from_numpy(array, space, layout)
# test if the input array can duck-type to a numpy-like array
# and run from_array to preprocess the array to numpy
Expand All @@ -758,7 +767,8 @@ def asarray(obj, /, *, dtype=None, device=None, copy=None):
# TODO: proper implementation/design
# for now, let's cheat and use NumPy asarray() followed
# by pykokkos from_numpy()
if obj in {pk.e, pk.pi, pk.inf, pk.nan}:

if not isinstance(obj, list) and obj in {pk.e, pk.pi, pk.inf, pk.nan}:
if dtype is None:
dtype = pk.float64
view = pk.View([1], dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions pykokkos/lib/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ravel_C_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.View1D[

def ravel(view, order="C"):
if view.rank() == 2:
if str(view.dtype) == "DataType.double":
if view.dtype.__name__ == "float64":
out = pk.View([view.shape[0] * view.shape[1]], pk.double)
if order == "F":
pk.parallel_for(view.shape[1], ravel_F_impl_2d_double, view=view, out=out)
Expand Down Expand Up @@ -59,7 +59,7 @@ def expand_dims_1_impl_2d_double(tid: int, view: pk.View2D[pk.double], out: pk.V


def expand_dims(view, axis=0):
if str(view.dtype) != "DataType.double":
if view.dtype.__name__ == "float64":
raise RuntimeError("expand_dims supports views of type double only")

if view.rank() == 1:
Expand Down
1 change: 1 addition & 0 deletions pykokkos/lib/ufunc_workunits.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def tanh_impl_2d_float(tid: int, view: pk.View2D[pk.float], out: pk.View2D[pk.fl
out[tid][i] = tanh(view[tid][i])


@pk.workunit
def equal_impl_5d_int8(tid: int,
view1: pk.View5D[pk.int8],
view2: pk.View5D[pk.int8],
Expand Down
Loading

0 comments on commit 2db7c16

Please sign in to comment.