-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* support the bare minimum required types per: https://data-apis.org/array-api/2021.12/API_specification/data_types.html (otherwise, we can't even import the array API tests, let alone run them--for now, some types are basically just "stubs" and/or aliases so we can get things rolling) * support `iinfo` and `finfo` functions required by the standard: https://data-apis.org/array-api/2021.12/API_specification/data_type_functions.html#objects-in-api * needed to add a (hacky) implemention of `asarray()` to comply with the array API standard per: https://data-apis.org/array-api/2021.12/API_specification/creation_functions.html#objects-in-api Note that the positional-only and keyword-only function signature is also mandatory. * the array API test suite detected multiple issue in our data type system; these mostly seemed to stem from having both `DataType(Enum)` and `DataTypeClass`, which is a model that is not consistent with the mappings expected by the standard; it is also mandatory to support more types than we currently do, so I've hacked around some of these issues for now, but with boolean type we'll currently fail the array API tests * we'll need to double check what to do for `0` dimensional arrays, but they are used in the array API test suite, so I've added a temporary hack around those until we support them "natively?" * add a CI job that starts to test for array API compliance--it will fail, but it should at least start running rather than erroring out at the import stage, which is a step forward from `develop` branch * the usual mypy ignores, at least for now..
- Loading branch information
1 parent
6f6d199
commit a949831
Showing
8 changed files
with
162 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
name: Array API Testing | ||
|
||
on: | ||
push: | ||
branches: | ||
- develop | ||
pull_request: | ||
branches: | ||
- develop | ||
|
||
jobs: | ||
test_array_api: | ||
strategy: | ||
matrix: | ||
platform: [ubuntu-latest] | ||
python-version: ["3.10"] | ||
runs-on: ${{ matrix.platform }} | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install --upgrade numpy mypy cmake pytest pybind11 scikit-build patchelf | ||
- name: Install pykokkos-base | ||
run: | | ||
cd /tmp | ||
git clone https://github.com/kokkos/pykokkos-base.git | ||
cd pykokkos-base | ||
python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF | ||
- name: Install pykokkos | ||
run: | | ||
python -m pip install . | ||
- name: Check Array API conformance | ||
run: | | ||
cd /tmp | ||
git clone https://github.com/data-apis/array-api-tests.git | ||
cd array-api-tests | ||
git submodule update --init | ||
pip install -r requirements.txt | ||
export ARRAY_API_TESTS_MODULE=pykokkos | ||
# only run a subset of the conformance tests to get started | ||
pytest array_api_tests/test_array_object.py::test_getitem |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from dataclasses import dataclass | ||
|
||
from pykokkos.bindings import kokkos | ||
|
||
# the integer and float type information functions appear | ||
# to be required by the array API standard | ||
# i.e., | ||
# https://data-apis.org/array-api/2021.12/API_specification/data_type_functions.html#objects-in-api | ||
|
||
|
||
@dataclass | ||
class info_type_attrs: | ||
""" | ||
Store machine limits for numeric data types. | ||
""" | ||
bits: int | ||
max: int | ||
min: int | ||
|
||
def iinfo(type_or_arr): | ||
# TODO: more correct implementation | ||
# this is really just an initial hack | ||
# so we can run the array API tests, | ||
# and effectively just copies return | ||
# values from the NumPy equivalent | ||
if "int32" in str(type_or_arr): | ||
return info_type_attrs(bits=32, | ||
min=2147483647, | ||
max=-2147483648) | ||
elif "int64" in str(type_or_arr): | ||
return info_type_attrs(bits=64, | ||
min=-9223372036854775808, | ||
max=9223372036854775807) | ||
|
||
|
||
def finfo(type_or_arr): | ||
# TODO: more correct implementation | ||
# this is really just an initial hack | ||
# so we can run the array API tests, | ||
# and effectively just copies return | ||
# values from the NumPy equivalent | ||
if "float" in str(type_or_arr) and not "float64" in str(type_or_arr): | ||
return info_type_attrs(bits=32, | ||
min=-3.4028235e+38, | ||
max=3.4028235e+38,) | ||
elif "double" in str(type_or_arr) or "float64" in str(type_or_arr): | ||
return info_type_attrs(bits=64, | ||
min=-1.7976931348623157e+308, | ||
max=1.7976931348623157e+308) | ||
|