From 5beec9cb14a560a04fd4af7ae69583f4cfa3545a Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Thu, 1 Sep 2022 09:42:14 -0600 Subject: [PATCH] ENH: support int8 (#43) * ENH: support int8 * add support for 8-bit signed integers with values that should exist on the interval `[-128, 127]` for conformance with the array API: https://data-apis.org/array-api/latest/API_specification/data_types.html https://github.com/kokkos/pykokkos/pull/57 * add a regression test for kokkos<->NumPy type equivalencies that includes the new `int8` type; I was hoping to do some simple arithmetic testing as well, but not sure how practical that is within `pykokkos-base` proper (maybe defer that kind of thing to `pykokkos`?) * MAINT: PR 43 revisions * format with `black` because of a failing CI check * use `signed_char` as the shorthand for `int8`, for similarity with: https://numpy.org/doc/stable/user/basics.types.html --- include/fwd.hpp | 3 ++- include/traits.hpp | 1 + kokkos/__init__.py.in | 3 ++- kokkos/test/test_types.py | 34 ++++++++++++++++++++++++++++++++++ kokkos/utility.py | 2 ++ src/variants/CMakeLists.txt | 2 +- 6 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 kokkos/test/test_types.py diff --git a/include/fwd.hpp b/include/fwd.hpp index 71199e7..5e91f5c 100644 --- a/include/fwd.hpp +++ b/include/fwd.hpp @@ -123,7 +123,8 @@ struct Device; //----------------------------------------------------------------------------// enum KokkosViewDataType { - Int16 = 0, + Int8 = 0, + Int16, Int32, Int64, Uint16, diff --git a/include/traits.hpp b/include/traits.hpp index f7365e6..887884b 100644 --- a/include/traits.hpp +++ b/include/traits.hpp @@ -75,6 +75,7 @@ VIEW_DATA_DIMS(8, T ********) // the first string identifier is the "canonical name" (i.e. what gets encoded) // and the remaining string entries are used to generate aliases // +VIEW_DATA_TYPE(int8_t, Int8, "int8", "signed_char") VIEW_DATA_TYPE(int16_t, Int16, "int16", "short") VIEW_DATA_TYPE(int32_t, Int32, "int32", "int") VIEW_DATA_TYPE(int64_t, Int64, "int64", "long") diff --git a/kokkos/__init__.py.in b/kokkos/__init__.py.in index 2ed710e..aa8675b 100644 --- a/kokkos/__init__.py.in +++ b/kokkos/__init__.py.in @@ -147,7 +147,8 @@ try: "read_dtype", "initialize", # bindings "finalize", - "int16", # data types + "int8", # data types + "int16", "int32", "int64", "uint16", diff --git a/kokkos/test/test_types.py b/kokkos/test/test_types.py new file mode 100644 index 0000000..1ba6571 --- /dev/null +++ b/kokkos/test/test_types.py @@ -0,0 +1,34 @@ +import kokkos + +import numpy as np +import pytest + + +@pytest.mark.parametrize( + "type_val, expected_np_type", + [ + (kokkos.int8, np.int8), + (kokkos.int16, np.int16), + (kokkos.int32, np.int32), + (kokkos.int64, np.int64), + (kokkos.uint16, np.uint16), + (kokkos.uint32, np.uint32), + (kokkos.uint64, np.uint64), + (kokkos.float32, np.float32), + (kokkos.float64, np.float64), + (kokkos.float, np.float32), + (kokkos.double, np.float64), + (kokkos.short, np.int16), + (kokkos.int, np.int32), + (kokkos.long, np.int64), + ], +) +def test_basic_type_equiv(type_val, expected_np_type): + # test some view to NumPy array type equivalencies + kokkos.initialize() + view = kokkos.array([2], dtype=type_val, space=kokkos.DefaultHostMemorySpace) + + # NOTE: copy can still happen (attempt no copy, + # not guarantee) + arr = np.array(view, copy=False) + assert arr.dtype == expected_np_type diff --git a/kokkos/utility.py b/kokkos/utility.py index 7faa2f9..851914d 100644 --- a/kokkos/utility.py +++ b/kokkos/utility.py @@ -81,6 +81,8 @@ def read_dtype(_dtype): try: import numpy as np + if _dtype == np.int8: + return lib.int8 if _dtype == np.int16: return lib.int16 elif _dtype == np.int32: diff --git a/src/variants/CMakeLists.txt b/src/variants/CMakeLists.txt index 14cbf38..eb30eac 100644 --- a/src/variants/CMakeLists.txt +++ b/src/variants/CMakeLists.txt @@ -26,7 +26,7 @@ TARGET_LINK_LIBRARIES(libpykokkos-variants PUBLIC SET(_types concrete dynamic) SET(_variants layout memory_trait) -SET(_data_types Int16 Int32 Int64 Uint16 Uint32 Uint64 Float32 Float64) +SET(_data_types Int8 Int16 Int32 Int64 Uint16 Uint32 Uint64 Float32 Float64) SET(layout_enums Right) SET(memory_trait_enums Managed)