diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml new file mode 100644 index 000000000000..9597af64b9af --- /dev/null +++ b/.github/workflows/jax-array-api.yml @@ -0,0 +1,49 @@ +name: JAX Array API + +on: + workflow_dispatch: # allows triggering the workflow run manually + pull_request: # Automatically trigger on pull requests affecting particular files + branches: + - main + paths: + - '**workflows/jax-array-api.yml' + - '**experimental/array_api/**' + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.11] + + steps: + - name: Checkout jax + uses: actions/checkout@v3 + - name: Checkout array-api-tests + uses: actions/checkout@v3 + with: + repository: data-apis/array-api-tests + ref: '83f0bcdcc5286250dbb26be5d37511702970b4dc' # Latest commit as of 2023-11-15 + submodules: 'true' + path: 'array-api-tests' + - name: Fix array-apis bug + # Temporary workaround for https://github.com/data-apis/array-api/issues/631 + run: | + sed -i -e 's/\\/\\\\/g' array-api-tests/array-api/spec/API_specification/signatures/*.py + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install .[cpu] + python -m pip install "hypothesis<6.88.4" # 6.88.4 breaks with a Return-type annotation warning + python -m pip install -r array-api-tests/requirements.txt + - name: Run the test suite + env: + ARRAY_API_TESTS_MODULE: jax.experimental.array_api + JAX_ENABLE_X64: 'true' + run: | + cd ${GITHUB_WORKSPACE}/array-api-tests + pytest --ci array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt diff --git a/docs/jax.experimental.array_api.rst b/docs/jax.experimental.array_api.rst new file mode 100644 index 000000000000..e5fa25f90b18 --- /dev/null +++ b/docs/jax.experimental.array_api.rst @@ -0,0 +1,4 @@ +``jax.experimental.array_api`` module +===================================== + +.. automodule:: jax.experimental.array_api diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index e6d4ed0e5911..adb6ea993eb6 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -14,6 +14,7 @@ Experimental Modules .. toctree:: :maxdepth: 1 + jax.experimental.array_api jax.experimental.checkify jax.experimental.host_callback jax.experimental.maps diff --git a/jax/BUILD b/jax/BUILD index 3239e36c24ce..55abbaa3a5fe 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -841,6 +841,17 @@ pytype_library( deps = [":jax"], ) +pytype_library( + name = "experimental_array_api", + srcs = glob( + [ + "experimental/array_api/*.py", + ], + ), + visibility = [":internal"], + deps = [":jax"], +) + pytype_library( name = "experimental_sparse", srcs = glob( @@ -873,7 +884,7 @@ pytype_library( "example_libraries/optimizers.py", ], visibility = ["//visibility:public"], - deps = [":jax"], + deps = [":jax"] + py_deps("numpy"), ) pytype_library( diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py new file mode 100644 index 000000000000..3169f9667256 --- /dev/null +++ b/jax/experimental/array_api/__init__.py @@ -0,0 +1,220 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module includes experimental JAX support for the `Python array API standard`_. +Support for this is currently experimental and not fully complete. + +Example Usage:: + + >>> from jax.experimental import array_api as xp + + >>> xp.__array_api_version__ + '2022.12' + + >>> arr = xp.arange(1000) + + >>> arr.sum() + Array(499500, dtype=int32) + +The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`, +and implements most of the API listed in the standard. + +.. _Python array API standard: https://data-apis.org/array-api/latest/ +""" + +from __future__ import annotations + +from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__ + +from jax.experimental.array_api import ( + fft as fft, + linalg as linalg, +) + +from jax.experimental.array_api._constants import ( + e as e, + inf as inf, + nan as nan, + newaxis as newaxis, + pi as pi, +) + +from jax.experimental.array_api._creation_functions import ( + arange as arange, + asarray as asarray, + empty as empty, + empty_like as empty_like, + eye as eye, + from_dlpack as from_dlpack, + full as full, + full_like as full_like, + linspace as linspace, + meshgrid as meshgrid, + ones as ones, + ones_like as ones_like, + tril as tril, + triu as triu, + zeros as zeros, + zeros_like as zeros_like, +) + +from jax.experimental.array_api._data_type_functions import ( + astype as astype, + can_cast as can_cast, + finfo as finfo, + iinfo as iinfo, + isdtype as isdtype, + result_type as result_type, +) + +from jax.experimental.array_api._dtypes import ( + bool as bool, + int8 as int8, + int16 as int16, + int32 as int32, + int64 as int64, + uint8 as uint8, + uint16 as uint16, + uint32 as uint32, + uint64 as uint64, + float32 as float32, + float64 as float64, + complex64 as complex64, + complex128 as complex128, +) + +from jax.experimental.array_api._elementwise_functions import ( + abs as abs, + acos as acos, + acosh as acosh, + add as add, + asin as asin, + asinh as asinh, + atan as atan, + atan2 as atan2, + atanh as atanh, + bitwise_and as bitwise_and, + bitwise_invert as bitwise_invert, + bitwise_left_shift as bitwise_left_shift, + bitwise_or as bitwise_or, + bitwise_right_shift as bitwise_right_shift, + bitwise_xor as bitwise_xor, + ceil as ceil, + conj as conj, + cos as cos, + cosh as cosh, + divide as divide, + equal as equal, + exp as exp, + expm1 as expm1, + floor as floor, + floor_divide as floor_divide, + greater as greater, + greater_equal as greater_equal, + imag as imag, + isfinite as isfinite, + isinf as isinf, + isnan as isnan, + less as less, + less_equal as less_equal, + log as log, + log10 as log10, + log1p as log1p, + log2 as log2, + logaddexp as logaddexp, + logical_and as logical_and, + logical_not as logical_not, + logical_or as logical_or, + logical_xor as logical_xor, + multiply as multiply, + negative as negative, + not_equal as not_equal, + positive as positive, + pow as pow, + real as real, + remainder as remainder, + round as round, + sign as sign, + sin as sin, + sinh as sinh, + sqrt as sqrt, + square as square, + subtract as subtract, + tan as tan, + tanh as tanh, + trunc as trunc, +) + +from jax.experimental.array_api._indexing_functions import ( + take as take, +) + +from jax.experimental.array_api._manipulation_functions import ( + broadcast_arrays as broadcast_arrays, + broadcast_to as broadcast_to, + concat as concat, + expand_dims as expand_dims, + flip as flip, + permute_dims as permute_dims, + reshape as reshape, + roll as roll, + squeeze as squeeze, + stack as stack, +) + +from jax.experimental.array_api._searching_functions import ( + argmax as argmax, + argmin as argmin, + nonzero as nonzero, + where as where, +) + +from jax.experimental.array_api._set_functions import ( + unique_all as unique_all, + unique_counts as unique_counts, + unique_inverse as unique_inverse, + unique_values as unique_values, +) + +from jax.experimental.array_api._sorting_functions import ( + argsort as argsort, + sort as sort, +) + +from jax.experimental.array_api._statistical_functions import ( + max as max, + mean as mean, + min as min, + prod as prod, + std as std, + sum as sum, + var as var +) + +from jax.experimental.array_api._utility_functions import ( + all as all, + any as any, +) + +from jax.experimental.array_api._linear_algebra_functions import ( + matmul as matmul, + matrix_transpose as matrix_transpose, + tensordot as tensordot, + vecdot as vecdot, +) + +from jax.experimental.array_api import _array_methods +_array_methods.add_array_object_methods() +del _array_methods diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py new file mode 100644 index 000000000000..5eedd6a151f2 --- /dev/null +++ b/jax/experimental/array_api/_array_methods.py @@ -0,0 +1,45 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Callable, Optional, Union + +import jax +from jax._src.array import ArrayImpl +from jax.experimental.array_api._version import __array_api_version__ + +from jax._src.lib import xla_extension as xe + + +def _array_namespace(self, /, *, api_version: None | str = None): + if api_version is not None and api_version != __array_api_version__: + raise ValueError(f"{api_version=!r} is not available; " + f"available versions are: {[__array_api_version__]}") + return jax.experimental.array_api + + +def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *, + stream: Optional[Union[int, Any]] = None): + if stream is not None: + raise NotImplementedError("stream argument of array.to_device()") + # The type of device is defined by Array.device. In JAX, this is a callable that + # returns a device, so we must handle this case to satisfy the API spec. + return jax.device_put(self, device() if callable(device) else device) + + +def add_array_object_methods(): + # TODO(jakevdp): set on tracers as well? + setattr(ArrayImpl, "__array_namespace__", _array_namespace) + setattr(ArrayImpl, "to_device", _to_device) diff --git a/jax/experimental/array_api/_constants.py b/jax/experimental/array_api/_constants.py new file mode 100644 index 000000000000..e6f0d542ae79 --- /dev/null +++ b/jax/experimental/array_api/_constants.py @@ -0,0 +1,21 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +e = np.e +inf = np.inf +nan = np.nan +newaxis = np.newaxis +pi = np.pi diff --git a/jax/experimental/array_api/_creation_functions.py b/jax/experimental/array_api/_creation_functions.py new file mode 100644 index 000000000000..0fcde42d58bb --- /dev/null +++ b/jax/experimental/array_api/_creation_functions.py @@ -0,0 +1,65 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp + + +def arange(start, /, stop=None, step=1, *, dtype=None, device=None): + return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device) + +def asarray(obj, /, *, dtype=None, device=None, copy=None): + return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device) + +def empty(shape, *, dtype=None, device=None): + return jax.device_put(jnp.empty(shape, dtype=dtype), device=device) + +def empty_like(x, /, *, dtype=None, device=None): + return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device) + +def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None): + return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device) + +def from_dlpack(x, /): + return jnp.from_dlpack(x) + +def full(shape, fill_value, *, dtype=None, device=None): + return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device) + +def full_like(x, /, fill_value, *, dtype=None, device=None): + return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device) + +def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True): + return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device) + +def meshgrid(*arrays, indexing='xy'): + return jnp.meshgrid(*arrays, indexing=indexing) + +def ones(shape, *, dtype=None, device=None): + return jax.device_put(jnp.ones(shape, dtype=dtype), device=device) + +def ones_like(x, /, *, dtype=None, device=None): + return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device) + +def tril(x, /, *, k=0): + return jnp.tril(x, k=k) + +def triu(x, /, *, k=0): + return jnp.triu(x, k=k) + +def zeros(shape, *, dtype=None, device=None): + return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device) + +def zeros_like(x, /, *, dtype=None, device=None): + return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py new file mode 100644 index 000000000000..7403136cfff1 --- /dev/null +++ b/jax/experimental/array_api/_data_type_functions.py @@ -0,0 +1,236 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +from typing import NamedTuple +import jax +import jax.numpy as jnp + + +from jax.experimental.array_api._dtypes import ( + bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, + float32, float64, complex64, complex128 +) + +_valid_dtypes = { + bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, + float32, float64, complex64, complex128 +} + +_promotion_table = { + (bool, bool): bool, + (int8, int8): int8, + (int8, int16): int16, + (int8, int32): int32, + (int8, int64): int64, + (int8, uint8): int16, + (int8, uint16): int32, + (int8, uint32): int64, + (int16, int8): int16, + (int16, int16): int16, + (int16, int32): int32, + (int16, int64): int64, + (int16, uint8): int16, + (int16, uint16): int32, + (int16, uint32): int64, + (int32, int8): int32, + (int32, int16): int32, + (int32, int32): int32, + (int32, int64): int64, + (int32, uint8): int32, + (int32, uint16): int32, + (int32, uint32): int64, + (int64, int8): int64, + (int64, int16): int64, + (int64, int32): int64, + (int64, int64): int64, + (int64, uint8): int64, + (int64, uint16): int64, + (int64, uint32): int64, + (uint8, int8): int16, + (uint8, int16): int16, + (uint8, int32): int32, + (uint8, int64): int64, + (uint8, uint8): uint8, + (uint8, uint16): uint16, + (uint8, uint32): uint32, + (uint8, uint64): uint64, + (uint16, int8): int32, + (uint16, int16): int32, + (uint16, int32): int32, + (uint16, int64): int64, + (uint16, uint8): uint16, + (uint16, uint16): uint16, + (uint16, uint32): uint32, + (uint16, uint64): uint64, + (uint32, int8): int64, + (uint32, int16): int64, + (uint32, int32): int64, + (uint32, int64): int64, + (uint32, uint8): uint32, + (uint32, uint16): uint32, + (uint32, uint32): uint32, + (uint32, uint64): uint64, + (uint64, uint8): uint64, + (uint64, uint16): uint64, + (uint64, uint32): uint64, + (uint64, uint64): uint64, + (float32, float32): float32, + (float32, float64): float64, + (float32, complex64): complex64, + (float32, complex128): complex128, + (float64, float32): float64, + (float64, float64): float64, + (float64, complex64): complex128, + (float64, complex128): complex128, + (complex64, float32): complex64, + (complex64, float64): complex128, + (complex64, complex64): complex64, + (complex64, complex128): complex128, + (complex128, float32): complex128, + (complex128, float64): complex128, + (complex128, complex64): complex128, + (complex128, complex128): complex128, +} + + +def _is_valid_dtype(t): + try: + return t in _valid_dtypes + except TypeError: + return False + + +def _promote_types(t1, t2): + if not _is_valid_dtype(t1): + raise ValueError(f"{t1} is not a valid dtype") + if not _is_valid_dtype(t2): + raise ValueError(f"{t2} is not a valid dtype") + if result := _promotion_table.get((t1, t2), None): + return result + else: + raise ValueError("No promotion path for {t1} & {t2}") + + +def astype(x, dtype, /, *, copy=True): + return jnp.array(x, dtype=dtype, copy=copy) + + +def can_cast(from_, to, /): + if isinstance(from_, jax.Array): + from_ = from_.dtype + if not _is_valid_dtype(from_): + raise ValueError(f"{from_} is not a valid dtype") + if not _is_valid_dtype(to): + raise ValueError(f"{to} is not a valid dtype") + try: + result = _promote_types(from_, to) + except ValueError: + return False + else: + return result == to + + +class FInfo(NamedTuple): + bits: int + eps: float + max: float + min: float + smallest_normal: float + dtype: jnp.dtype + + +class IInfo(NamedTuple): + bits: int + max: int + min: int + dtype: jnp.dtype + + +def finfo(type, /) -> FInfo: + info = jnp.finfo(type) + return FInfo( + bits=info.bits, + eps=float(info.eps), + max=float(info.max), + min=float(info.min), + smallest_normal=float(info.smallest_normal), + dtype=jnp.dtype(type) + ) + + +def iinfo(type, /) -> IInfo: + info = jnp.iinfo(type) + return IInfo(bits=info.bits, max=info.max, min=info.min, dtype=jnp.dtype(type)) + + +_dtype_kinds = { + 'bool': {bool}, + 'signed integer': {int8, int16, int32, int64}, + 'unsigned integer': {uint8, uint16, uint32, uint64}, + 'integral': {int8, int16, int32, int64, uint8, uint16, uint32, uint64}, + 'real floating': {float32, float64}, + 'complex floating': {complex64, complex128}, + 'numeric': {int8, int16, int32, int64, uint8, uint16, uint32, uint64, + float32, float64, complex64, complex128}, +} + +def isdtype(dtype, kind): + if not _is_valid_dtype(dtype): + raise ValueError(f"{dtype} is not a valid dtype.") + if isinstance(kind, tuple): + return any(_isdtype(dtype, k) for k in kind) + return _isdtype(dtype, kind) + +def _isdtype(dtype, kind): + if isinstance(kind, jnp.dtype): + return dtype == kind + elif isinstance(kind, str): + if kind not in _dtype_kinds: + raise ValueError(f"Unrecognized {kind=!r}") + return dtype in _dtype_kinds[kind] + else: + raise ValueError(f"Invalid kind with {kind}. Expected string or dtype.") + + +def result_type(*arrays_and_dtypes): + dtypes = [] + for val in arrays_and_dtypes: + if isinstance(val, jax.Array): + val = val.dtype + if _is_valid_dtype(val): + dtypes.append(val) + else: + raise ValueError(f"{val} is not a valid dtype") + if len(dtypes) == 0: + raise ValueError("result_type requires at least one argument") + if len(dtypes) == 1: + return dtypes[0] + return functools.reduce(_promote_types, dtypes) + + +def _promote_to_default_dtype(x): + if x.dtype.kind == 'b': + return x + elif x.dtype.kind == 'i': + return x.astype(jnp.int_) + elif x.dtype.kind == 'u': + return x.astype(jnp.uint) + elif x.dtype.kind == 'f': + return x.astype(jnp.float_) + elif x.dtype.kind == 'c': + return x.astype(jnp.complex_) + else: + raise ValueError(f"Unrecognized {x.dtype=}") diff --git a/jax/experimental/array_api/_dtypes.py b/jax/experimental/array_api/_dtypes.py new file mode 100644 index 000000000000..72229bfc28af --- /dev/null +++ b/jax/experimental/array_api/_dtypes.py @@ -0,0 +1,29 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +bool = np.dtype('bool') +int8 = np.dtype('int8') +int16 = np.dtype('int16') +int32 = np.dtype('int32') +int64 = np.dtype('int64') +uint8 = np.dtype('uint8') +uint16 = np.dtype('uint16') +uint32 = np.dtype('uint32') +uint64 = np.dtype('uint64') +float32 = np.dtype('float32') +float64 = np.dtype('float64') +complex64 = np.dtype('complex64') +complex128 = np.dtype('complex128') diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py new file mode 100644 index 000000000000..373d29098a16 --- /dev/null +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -0,0 +1,388 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax.experimental.array_api._data_type_functions import ( + result_type as _result_type, + isdtype as _isdtype, +) +import numpy as np + + +def _promote_dtypes(name, *args): + assert isinstance(name, str) + if not all(isinstance(arg, jax.Array) for arg in args): + raise ValueError(f"{name}: inputs must be arrays; got types {[type(arg) for arg in args]}") + dtype = _result_type(*args) + return [arg.astype(dtype) for arg in args] + + +def abs(x, /): + """Calculates the absolute value for each element x_i of the input array x.""" + x, = _promote_dtypes("abs", x) + if _isdtype(x.dtype, "unsigned integer"): + return x + return jax.lax.abs(x) + + +def acos(x, /): + """Calculates an implementation-dependent approximation of the principal value of the inverse cosine for each element x_i of the input array x.""" + x, = _promote_dtypes("acos", x) + return jax.lax.acos(x) + +def acosh(x, /): + """Calculates an implementation-dependent approximation to the inverse hyperbolic cosine for each element x_i of the input array x.""" + x, = _promote_dtypes("acos", x) + return jax.lax.acosh(x) + + +def add(x1, x2, /): + """Calculates the sum for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("add", x1, x2) + return jax.numpy.add(x1, x2) + + +def asin(x, /): + """Calculates an implementation-dependent approximation of the principal value of the inverse sine for each element x_i of the input array x.""" + x, = _promote_dtypes("asin", x) + return jax.lax.asin(x) + + +def asinh(x, /): + """Calculates an implementation-dependent approximation to the inverse hyperbolic sine for each element x_i in the input array x.""" + x, = _promote_dtypes("asinh", x) + return jax.lax.asinh(x) + + +def atan(x, /): + """Calculates an implementation-dependent approximation of the principal value of the inverse tangent for each element x_i of the input array x.""" + x, = _promote_dtypes("atan", x) + return jax.lax.atan(x) + + +def atan2(x1, x2, /): + """Calculates an implementation-dependent approximation of the inverse tangent of the quotient x1/x2, having domain [-infinity, +infinity] x [-infinity, +infinity] (where the x notation denotes the set of ordered pairs of elements (x1_i, x2_i)) and codomain [-π, +π], for each pair of elements (x1_i, x2_i) of the input arrays x1 and x2, respectively.""" + x1, x2 = _promote_dtypes("atan2", x1, x2) + return jax.numpy.arctan2(x1, x2) + + +def atanh(x, /): + """Calculates an implementation-dependent approximation to the inverse hyperbolic tangent for each element x_i of the input array x.""" + x, = _promote_dtypes("atanh", x) + return jax.lax.atanh(x) + + +def bitwise_and(x1, x2, /): + """Computes the bitwise AND of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("bitwise_and", x1, x2) + return jax.numpy.bitwise_and(x1, x2) + + +def bitwise_left_shift(x1, x2, /): + """Shifts the bits of each element x1_i of the input array x1 to the left by appending x2_i (i.e., the respective element in the input array x2) zeros to the right of x1_i.""" + x1, x2 = _promote_dtypes("bitwise_left_shift", x1, x2) + return jax.numpy.left_shift(x1, x2) + + +def bitwise_invert(x, /): + """Inverts (flips) each bit for each element x_i of the input array x.""" + x, = _promote_dtypes("bitwise_invert", x) + return jax.numpy.bitwise_not(x) + + +def bitwise_or(x1, x2, /): + """Computes the bitwise OR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("bitwise_or", x1, x2) + return jax.numpy.bitwise_or(x1, x2) + + +def bitwise_right_shift(x1, x2, /): + """Shifts the bits of each element x1_i of the input array x1 to the right according to the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("bitwise_right_shift", x1, x2) + return jax.numpy.right_shift(x1, x2) + + +def bitwise_xor(x1, x2, /): + """Computes the bitwise XOR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("bitwise_xor", x1, x2) + return jax.numpy.bitwise_xor(x1, x2) + + +def ceil(x, /): + """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i.""" + x, = _promote_dtypes("ceil", x) + if _isdtype(x.dtype, "integral"): + return x + return jax.lax.ceil(x) + + +def conj(x, /): + """Returns the complex conjugate for each element x_i of the input array x.""" + x, = _promote_dtypes("conj", x) + return jax.lax.conj(x) + + +def cos(x, /): + """Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x.""" + x, = _promote_dtypes("cos", x) + return jax.lax.cos(x) + + +def cosh(x, /): + """Calculates an implementation-dependent approximation to the hyperbolic cosine for each element x_i in the input array x.""" + x, = _promote_dtypes("cosh", x) + return jax.lax.cosh(x) + + +def divide(x1, x2, /): + """Calculates the division of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("divide", x1, x2) + return jax.numpy.divide(x1, x2) + + +def equal(x1, x2, /): + """Computes the truth value of x1_i == x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("equal", x1, x2) + return jax.numpy.equal(x1, x2) + + +def exp(x, /): + """Calculates an implementation-dependent approximation to the exponential function for each element x_i of the input array x (e raised to the power of x_i, where e is the base of the natural logarithm).""" + x, = _promote_dtypes("exp", x) + return jax.lax.exp(x) + + +def expm1(x, /): + """Calculates an implementation-dependent approximation to exp(x)-1 for each element x_i of the input array x.""" + x, = _promote_dtypes("expm1", x) + return jax.lax.expm1(x) + + +def floor(x, /): + """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i.""" + x, = _promote_dtypes("floor", x) + if _isdtype(x.dtype, "integral"): + return x + return jax.lax.floor(x) + + +def floor_divide(x1, x2, /): + """Rounds the result of dividing each element x1_i of the input array x1 by the respective element x2_i of the input array x2 to the greatest (i.e., closest to +infinity) integer-value number that is not greater than the division result.""" + x1, x2 = _promote_dtypes("floor_divide", x1, x2) + return jax.numpy.floor_divide(x1, x2) + + +def greater(x1, x2, /): + """Computes the truth value of x1_i > x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("greater", x1, x2) + return jax.numpy.greater(x1, x2) + + +def greater_equal(x1, x2, /): + """Computes the truth value of x1_i >= x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("greater_equal", x1, x2) + return jax.numpy.greater_equal(x1, x2) + + +def imag(x, /): + """Returns the imaginary component of a complex number for each element x_i of the input array x.""" + x, = _promote_dtypes("imag", x) + return jax.lax.imag(x) + + +def isfinite(x, /): + """Tests each element x_i of the input array x to determine if finite.""" + x, = _promote_dtypes("isfinite", x) + return jax.numpy.isfinite(x) + + +def isinf(x, /): + """Tests each element x_i of the input array x to determine if equal to positive or negative infinity.""" + x, = _promote_dtypes("isinf", x) + return jax.numpy.isinf(x) + + +def isnan(x, /): + """Tests each element x_i of the input array x to determine whether the element is NaN.""" + x, = _promote_dtypes("isnan", x) + return jax.numpy.isnan(x) + + +def less(x1, x2, /): + """Computes the truth value of x1_i < x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("less", x1, x2) + return jax.numpy.less(x1, x2) + + +def less_equal(x1, x2, /): + """Computes the truth value of x1_i <= x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("less_equal", x1, x2) + return jax.numpy.less_equal(x1, x2) + + +def log(x, /): + """Calculates an implementation-dependent approximation to the natural (base e) logarithm for each element x_i of the input array x.""" + x, = _promote_dtypes("log", x) + return jax.lax.log(x) + +def log1p(x, /): + """Calculates an implementation-dependent approximation to log(1+x), where log refers to the natural (base e) logarithm, for each element x_i of the input array x.""" + x, = _promote_dtypes("log", x) + return jax.lax.log1p(x) + + +def log2(x, /): + """Calculates an implementation-dependent approximation to the base 2 logarithm for each element x_i of the input array x.""" + x, = _promote_dtypes("log2", x) + return jax.lax.div(jax.lax.log(x), jax.lax.log(np.array(2, dtype=x.dtype))) + + +def log10(x, /): + """Calculates an implementation-dependent approximation to the base 10 logarithm for each element x_i of the input array x.""" + x, = _promote_dtypes("log2", x) + return jax.lax.div(jax.lax.log(x), jax.lax.log(np.array(10, dtype=x.dtype))) + + +def logaddexp(x1, x2, /): + """Calculates the logarithm of the sum of exponentiations log(exp(x1) + exp(x2)) for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("logaddexp", x1, x2) + return jax.numpy.logaddexp(x1, x2) + + +def logical_and(x1, x2, /): + """Computes the logical AND for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("logical_and", x1, x2) + return jax.numpy.logical_and(x1, x2) + + +def logical_not(x, /): + """Computes the logical NOT for each element x_i of the input array x.""" + x, = _promote_dtypes("logical_not", x) + return jax.numpy.logical_not(x) + + +def logical_or(x1, x2, /): + """Computes the logical OR for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("logical_or", x1, x2) + return jax.numpy.logical_or(x1, x2) + + +def logical_xor(x1, x2, /): + """Computes the logical XOR for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("logical_xor", x1, x2) + return jax.numpy.logical_xor(x1, x2) + + +def multiply(x1, x2, /): + """Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("multiply", x1, x2) + return jax.numpy.multiply(x1, x2) + + +def negative(x, /): + """Computes the numerical negative of each element x_i (i.e., y_i = -x_i) of the input array x.""" + x, = _promote_dtypes("negative", x) + return jax.lax.neg(x) + + +def not_equal(x1, x2, /): + """Computes the truth value of x1_i != x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("not_equal", x1, x2) + return jax.numpy.not_equal(x1, x2) + + +def positive(x, /): + """Computes the numerical positive of each element x_i (i.e., y_i = +x_i) of the input array x.""" + x, = _promote_dtypes("positive", x) + return x + + +def pow(x1, x2, /): + """Calculates an implementation-dependent approximation of exponentiation by raising each element x1_i (the base) of the input array x1 to the power of x2_i (the exponent), where x2_i is the corresponding element of the input array x2.""" + x1, x2 = _promote_dtypes("pow", x1, x2) + return jax.numpy.power(x1, x2) + + +def real(x, /): + """Returns the real component of a complex number for each element x_i of the input array x.""" + x, = _promote_dtypes("real", x) + return jax.lax.real(x) + + +def remainder(x1, x2, /): + """Returns the remainder of division for each element x1_i of the input array x1 and the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("remainder", x1, x2) + return jax.numpy.remainder(x1, x2) + + +def round(x, /): + """Rounds each element x_i of the input array x to the nearest integer-valued number.""" + x, = _promote_dtypes("round", x) + return jax.numpy.round(x) + + +def sign(x, /): + """Returns an indication of the sign of a number for each element x_i of the input array x.""" + x, = _promote_dtypes("sign", x) + return jax.lax.sign(x) + + +def sin(x, /): + """Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x.""" + x, = _promote_dtypes("sin", x) + return jax.lax.sin(x) + + +def sinh(x, /): + """Calculates an implementation-dependent approximation to the hyperbolic sine for each element x_i of the input array x.""" + x, = _promote_dtypes("sin", x) + return jax.lax.sinh(x) + + +def square(x, /): + """Squares each element x_i of the input array x.""" + x, = _promote_dtypes("square", x) + return jax.lax.integer_pow(x, 2) + + +def sqrt(x, /): + """Calculates the principal square root for each element x_i of the input array x.""" + x, = _promote_dtypes("sqrt", x) + return jax.lax.sqrt(x) + + +def subtract(x1, x2, /): + """Calculates the difference for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("subtract", x1, x2) + return jax.numpy.subtract(x1, x2) + + +def tan(x, /): + """Calculates an implementation-dependent approximation to the tangent for each element x_i of the input array x.""" + x, = _promote_dtypes("tan", x) + return jax.lax.tan(x) + + +def tanh(x, /): + """Calculates an implementation-dependent approximation to the hyperbolic tangent for each element x_i of the input array x.""" + x, = _promote_dtypes("tanh", x) + return jax.lax.tanh(x) + + +def trunc(x, /): + """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i.""" + x, = _promote_dtypes("trunc", x) + if _isdtype(x.dtype, "integral"): + return x + return jax.numpy.trunc(x) diff --git a/jax/experimental/array_api/_fft_functions.py b/jax/experimental/array_api/_fft_functions.py new file mode 100644 index 000000000000..d1e737a424ac --- /dev/null +++ b/jax/experimental/array_api/_fft_functions.py @@ -0,0 +1,72 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax.numpy as jnp + + +def fft(x, /, *, n=None, axis=-1, norm='backward'): + """Computes the one-dimensional discrete Fourier transform.""" + return jnp.fft.fft(x, n=n, axis=axis, norm=norm) + +def ifft(x, /, *, n=None, axis=-1, norm='backward'): + """Computes the one-dimensional inverse discrete Fourier transform.""" + return jnp.fft.ifft(x, n=n, axis=axis, norm=norm) + +def fftn(x, /, *, s=None, axes=None, norm='backward'): + """Computes the n-dimensional discrete Fourier transform.""" + return jnp.fft.fftn(x, s=s, axes=axes, norm=norm) + +def ifftn(x, /, *, s=None, axes=None, norm='backward'): + """Computes the n-dimensional inverse discrete Fourier transform.""" + return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm) + +def rfft(x, /, *, n=None, axis=-1, norm='backward'): + """Computes the one-dimensional discrete Fourier transform for real-valued input.""" + return jnp.fft.rfft(x, n=n, axis=axis, norm=norm) + +def irfft(x, /, *, n=None, axis=-1, norm='backward'): + """Computes the one-dimensional inverse of rfft for complex-valued input.""" + return jnp.fft.irfft(x, n=n, axis=axis, norm=norm) + +def rfftn(x, /, *, s=None, axes=None, norm='backward'): + """Computes the n-dimensional discrete Fourier transform for real-valued input.""" + return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm) + +def irfftn(x, /, *, s=None, axes=None, norm='backward'): + """Computes the n-dimensional inverse of rfftn for complex-valued input.""" + return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm) + +def hfft(x, /, *, n=None, axis=-1, norm='backward'): + """Computes the one-dimensional discrete Fourier transform of a signal with Hermitian symmetry.""" + return jnp.fft.hfft(x, n=n, axis=axis, norm=norm) + +def ihfft(x, /, *, n=None, axis=-1, norm='backward'): + """Computes the one-dimensional inverse discrete Fourier transform of a signal with Hermitian symmetry.""" + return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm) + +def fftfreq(n, /, *, d=1.0, device=None): + """Returns the discrete Fourier transform sample frequencies.""" + return jnp.fft.fftfreq(n, d=d).to_device(device) + +def rfftfreq(n, /, *, d=1.0, device=None): + """Returns the discrete Fourier transform sample frequencies (for rfft and irfft).""" + return jnp.fft.rfftfreq(n, d=d).to_device(device) + +def fftshift(x, /, *, axes=None): + """Shift the zero-frequency component to the center of the spectrum.""" + return jnp.fft.fftshift(x, axes=axes) + +def ifftshift(x, /, *, axes=None): + """Inverse of fftshift.""" + return jnp.fft.ifftshift(x, axes=axes) diff --git a/jax/experimental/array_api/_indexing_functions.py b/jax/experimental/array_api/_indexing_functions.py new file mode 100644 index 000000000000..261c81b20351 --- /dev/null +++ b/jax/experimental/array_api/_indexing_functions.py @@ -0,0 +1,18 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax + +def take(x, indices, /, *, axis): + return jax.numpy.take(x, indices, axis=axis) diff --git a/jax/experimental/array_api/_linear_algebra_functions.py b/jax/experimental/array_api/_linear_algebra_functions.py new file mode 100644 index 000000000000..2ce616afd3b7 --- /dev/null +++ b/jax/experimental/array_api/_linear_algebra_functions.py @@ -0,0 +1,192 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import NamedTuple + +import jax +from jax.experimental.array_api._data_type_functions import ( + _promote_to_default_dtype, +) + +class EighResult(NamedTuple): + eigenvalues: jax.Array + eigenvectors: jax.Array + +class QRResult(NamedTuple): + Q: jax.Array + R: jax.Array + +class SlogdetResult(NamedTuple): + sign: jax.Array + logabsdet: jax.Array + +class SVDResult(NamedTuple): + U: jax.Array + S: jax.Array + Vh: jax.Array + +def cholesky(x, /, *, upper=False): + """ + Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix x. + """ + return jax.numpy.linalg.cholesky(jax.numpy.matrix_transpose(x) if upper else x) + +def cross(x1, x2, /, *, axis=-1): + """ + Returns the cross product of 3-element vectors. + """ + return jax.numpy.cross(x1, x2, axis=axis) + +def det(x, /): + """ + Returns the determinant of a square matrix (or a stack of square matrices) x. + """ + return jax.numpy.linalg.det(x) + +def diagonal(x, /, *, offset=0): + """ + Returns the specified diagonals of a matrix (or a stack of matrices) x. + """ + f = partial(jax.numpy.diagonal, offset=offset) + for _ in range(x.ndim - 2): + f = jax.vmap(f) + return f(x) + +def eigh(x, /): + """ + Returns an eigenvalue decomposition of a complex Hermitian or real symmetric matrix (or a stack of matrices) x. + """ + eigenvalues, eigenvectors = jax.numpy.linalg.eigh(x) + return EighResult(eigenvalues=eigenvalues, eigenvectors=eigenvectors) + +def eigvalsh(x, /): + """ + Returns the eigenvalues of a complex Hermitian or real symmetric matrix (or a stack of matrices) x. + """ + return jax.numpy.linalg.eigvalsh(x) + +def inv(x, /): + """ + Returns the multiplicative inverse of a square matrix (or a stack of square matrices) x. + """ + return jax.numpy.linalg.inv(x) + +def matmul(x1, x2, /): + """Computes the matrix product.""" + return jax.numpy.matmul(x1, x2) + +def matrix_norm(x, /, *, keepdims=False, ord='fro'): + """ + Computes the matrix norm of a matrix (or a stack of matrices) x. + """ + return jax.numpy.linalg.norm(x, ord=ord, keepdims=keepdims, axis=(-1, -2)) + +def matrix_power(x, n, /): + """ + Raises a square matrix (or a stack of square matrices) x to an integer power n. + """ + return jax.numpy.linalg.matrix_power(x, n) + +def matrix_rank(x, /, *, rtol=None): + """ + Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices). + """ + return jax.numpy.linalg.matrix_rank(x, tol=rtol) + +def matrix_transpose(x, /): + """Transposes a matrix (or a stack of matrices) x.""" + if x.ndim < 2: + raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {x.ndim=}") + return jax.lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) + +def outer(x1, x2, /): + """ + Returns the outer product of two vectors x1 and x2. + """ + return jax.numpy.outer(x1, x2) + +def pinv(x, /, *, rtol=None): + """ + Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x. + """ + return jax.numpy.linalg.pinv(x, rcond=rtol) + +def qr(x, /, *, mode='reduced'): + """ + Returns the QR decomposition of a full column rank matrix (or a stack of matrices). + """ + Q, R = jax.numpy.linalg.qr(x, mode=mode) + return QRResult(Q=Q, R=R) + +def slogdet(x, /): + """ + Returns the sign and the natural logarithm of the absolute value of the determinant of a square matrix (or a stack of square matrices) x. + """ + sign, logabsdet = jax.numpy.linalg.slogdet(x) + return SlogdetResult(sign, logabsdet) + +def solve(x1, x2, /): + """ + Returns the solution of a square system of linear equations with a unique solution. + """ + if x2.ndim == 1: + x2 = x2.reshape(*x1.shape[:-2], *x2.shape, 1) + return jax.numpy.linalg.solve(x1, x2)[..., 0] + if x2.ndim > x1.ndim: + x1 = x1.reshape(*x2.shape[:-2], *x1.shape) + elif x1.ndim > x2.ndim: + x2 = x2.reshape(*x1.shape[:-2], *x2.shape) + return jax.numpy.linalg.solve(x1, x2) + + +def svd(x, /, *, full_matrices=True): + """ + Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) x. + """ + U, S, Vh = jax.numpy.linalg.svd(x, full_matrices=full_matrices) + return SVDResult(U=U, S=S, Vh=Vh) + +def svdvals(x, /): + """ + Returns the singular values of a matrix (or a stack of matrices) x. + """ + return jax.numpy.linalg.svd(x, compute_uv=False) + +def tensordot(x1, x2, /, *, axes=2): + """Returns a tensor contraction of x1 and x2 over specific axes.""" + return jax.numpy.tensordot(x1, x2, axes=axes) + +def trace(x, /, *, offset=0, dtype=None): + """ + Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x. + """ + x = _promote_to_default_dtype(x) + return jax.numpy.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1) + +def vecdot(x1, x2, /, *, axis=-1): + """Computes the (vector) dot product of two arrays.""" + rank = max(x1.ndim, x2.ndim) + x1 = jax.lax.broadcast_to_rank(x1, rank) + x2 = jax.lax.broadcast_to_rank(x2, rank) + if x1.shape[axis] != x2.shape[axis]: + raise ValueError("x1 and x2 must have the same size along specified axis.") + x1, x2 = jax.numpy.broadcast_arrays(x1, x2) + x1 = jax.numpy.moveaxis(x1, axis, -1) + x2 = jax.numpy.moveaxis(x2, axis, -1) + return jax.numpy.matmul(x1[..., None, :], x2[..., None])[..., 0, 0] + +def vector_norm(x, /, *, axis=None, keepdims=False, ord=2): + """Computes the vector norm of a vector (or batch of vectors) x.""" + return jax.numpy.linalg.norm(x, axis=axis, keepdims=keepdims, ord=ord) diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py new file mode 100644 index 000000000000..411476f229d7 --- /dev/null +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -0,0 +1,79 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +from typing import List, Optional, Tuple, Union + +import jax +from jax import Array +from jax.experimental.array_api._data_type_functions import result_type as _result_type + + +def broadcast_arrays(*arrays: Array) -> List[Array]: + """Broadcasts one or more arrays against one another.""" + return jax.numpy.broadcast_arrays(*arrays) + + +def broadcast_to(x: Array, /, shape: Tuple[int]) -> Array: + """Broadcasts an array to a specified shape.""" + return jax.numpy.broadcast_to(x, shape=shape) + + +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0) -> Array: + """Joins a sequence of arrays along an existing axis.""" + dtype = _result_type(*arrays) + if axis is None: + arrays = [reshape(arr, (arr.size,)) for arr in arrays] + axis = 0 + return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype) + + +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: + """Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis.""" + if axis < -x.ndim - 1 or axis > x.ndim: + raise IndexError(f"{axis=} is out of bounds for array of dimension {x.ndim}") + return jax.numpy.expand_dims(x, axis=axis) + + +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """Reverses the order of elements in an array along the given axis.""" + return jax.numpy.flip(x, axis=axis) + + +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: + """Permutes the axes (dimensions) of an array x.""" + return jax.lax.transpose(x, axes) + + +def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array: + """Reshapes an array without changing its data.""" + del copy # unused + return jax.numpy.reshape(x, shape) + + +def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """Rolls array elements along a specified axis.""" + return jax.numpy.roll(x, shift=shift, axis=axis) + + +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: + """Removes singleton dimensions (axes) from x.""" + dimensions = axis if isinstance(axis, tuple) else (axis,) + return jax.lax.squeeze(x, dimensions=dimensions) + + +def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: + """Joins a sequence of arrays along a new axis.""" + dtype = _result_type(*arrays) + return jax.numpy.stack(arrays, axis=axis, dtype=dtype) diff --git a/jax/experimental/array_api/_searching_functions.py b/jax/experimental/array_api/_searching_functions.py new file mode 100644 index 000000000000..698a9c2df50d --- /dev/null +++ b/jax/experimental/array_api/_searching_functions.py @@ -0,0 +1,37 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax.experimental.array_api._data_type_functions import result_type as _result_type + + +def argmax(x, /, *, axis=None, keepdims=False): + """Returns the indices of the maximum values along a specified axis.""" + return jax.numpy.argmax(x, axis=axis, keepdims=keepdims) + + +def argmin(x, /, *, axis=None, keepdims=False): + """Returns the indices of the minimum values along a specified axis.""" + return jax.numpy.argmin(x, axis=axis, keepdims=keepdims) + + +def nonzero(x, /): + """Returns the indices of the array elements which are non-zero.""" + return jax.numpy.nonzero(x) + + +def where(condition, x1, x2, /): + """Returns elements chosen from x1 or x2 depending on condition.""" + dtype = _result_type(x1, x2) + return jax.numpy.where(condition, x1.astype(dtype), x2.astype(dtype)) diff --git a/jax/experimental/array_api/_set_functions.py b/jax/experimental/array_api/_set_functions.py new file mode 100644 index 000000000000..95043790c37c --- /dev/null +++ b/jax/experimental/array_api/_set_functions.py @@ -0,0 +1,60 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import NamedTuple +import jax + + +class UniqueAllResult(NamedTuple): + values: jax.Array + indices: jax.Array + inverse_indices: jax.Array + counts: jax.Array + + +class UniqueCountsResult(NamedTuple): + values: jax.Array + counts: jax.Array + + +class UniqueInverseResult(NamedTuple): + values: jax.Array + inverse_indices: jax.Array + + +def unique_all(x, /): + """Returns the unique elements of an input array x, the first occurring indices for each unique element in x, the indices from the set of unique elements that reconstruct x, and the corresponding counts for each unique element in x.""" + values, indices, inverse_indices, counts = jax.numpy.unique( + x, return_index=True, return_inverse=True, return_counts=True) + # jnp.unique() flattens inverse indices + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts) + + +def unique_counts(x, /): + """Returns the unique elements of an input array x and the corresponding counts for each unique element in x.""" + values, counts = jax.numpy.unique(x, return_counts=True) + return UniqueCountsResult(values=values, counts=counts) + + +def unique_inverse(x, /): + """Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.""" + values, inverse_indices = jax.numpy.unique(x, return_inverse=True) + inverse_indices = inverse_indices.reshape(x.shape) + return UniqueInverseResult(values=values, inverse_indices=inverse_indices) + + +def unique_values(x, /): + """Returns the unique elements of an input array x.""" + return jax.numpy.unique(x) diff --git a/jax/experimental/array_api/_sorting_functions.py b/jax/experimental/array_api/_sorting_functions.py new file mode 100644 index 000000000000..139593f203cf --- /dev/null +++ b/jax/experimental/array_api/_sorting_functions.py @@ -0,0 +1,36 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax import Array + + +def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, + stable: bool = True) -> Array: + """Returns the indices that sort an array x along a specified axis.""" + del stable # unused + if descending: + return jax.numpy.argsort(-x, axis=axis) + else: + return jax.numpy.argsort(x, axis=axis) + + +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, + stable: bool = True) -> Array: + """Returns a sorted copy of an input array x.""" + del stable # unused + result = jax.numpy.sort(x, axis=axis) + if descending: + return jax.lax.rev(result, dimensions=[axis + x.ndim if axis < 0 else axis]) + return result diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py new file mode 100644 index 000000000000..2e1333317605 --- /dev/null +++ b/jax/experimental/array_api/_statistical_functions.py @@ -0,0 +1,55 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax.experimental.array_api._data_type_functions import ( + _promote_to_default_dtype, +) + + +def max(x, /, *, axis=None, keepdims=False): + """Calculates the maximum value of the input array x.""" + return jax.numpy.max(x, axis=axis, keepdims=keepdims) + + +def mean(x, /, *, axis=None, keepdims=False): + """Calculates the arithmetic mean of the input array x.""" + return jax.numpy.mean(x, axis=axis, keepdims=keepdims) + + +def min(x, /, *, axis=None, keepdims=False): + """Calculates the minimum value of the input array x.""" + return jax.numpy.min(x, axis=axis, keepdims=keepdims) + + +def prod(x, /, *, axis=None, dtype=None, keepdims=False): + """Calculates the product of input array x elements.""" + x = _promote_to_default_dtype(x) + return jax.numpy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) + + +def std(x, /, *, axis=None, correction=0.0, keepdims=False): + """Calculates the standard deviation of the input array x.""" + return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims) + + +def sum(x, /, *, axis=None, dtype=None, keepdims=False): + """Calculates the sum of the input array x.""" + x = _promote_to_default_dtype(x) + return jax.numpy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) + + +def var(x, /, *, axis=None, correction=0.0, keepdims=False): + """Calculates the variance of the input array x.""" + return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims) diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py new file mode 100644 index 000000000000..60d739277627 --- /dev/null +++ b/jax/experimental/array_api/_utility_functions.py @@ -0,0 +1,25 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax + + +def all(x, /, *, axis=None, keepdims=False): + """Tests whether all input array elements evaluate to True along a specified axis.""" + return jax.numpy.all(x, axis=axis, keepdims=keepdims) + + +def any(x, /, *, axis=None, keepdims=False): + """Tests whether any input array element evaluates to True along a specified axis.""" + return jax.numpy.any(x, axis=axis, keepdims=keepdims) diff --git a/jax/experimental/array_api/_version.py b/jax/experimental/array_api/_version.py new file mode 100644 index 000000000000..4936af86da4c --- /dev/null +++ b/jax/experimental/array_api/_version.py @@ -0,0 +1,15 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__array_api_version__ = '2022.12' diff --git a/jax/experimental/array_api/fft.py b/jax/experimental/array_api/fft.py new file mode 100644 index 000000000000..f83d45401d20 --- /dev/null +++ b/jax/experimental/array_api/fft.py @@ -0,0 +1,30 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax.experimental.array_api._fft_functions import ( + fft as fft, + fftfreq as fftfreq, + fftn as fftn, + fftshift as fftshift, + hfft as hfft, + ifft as ifft, + ifftn as ifftn, + ifftshift as ifftshift, + ihfft as ihfft, + irfft as irfft, + irfftn as irfftn, + rfft as rfft, + rfftfreq as rfftfreq, + rfftn as rfftn, +) diff --git a/jax/experimental/array_api/linalg.py b/jax/experimental/array_api/linalg.py new file mode 100644 index 000000000000..30b531f502bc --- /dev/null +++ b/jax/experimental/array_api/linalg.py @@ -0,0 +1,40 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax.experimental.array_api._linear_algebra_functions import ( + cholesky as cholesky, + cross as cross, + det as det, + diagonal as diagonal, + eigh as eigh, + eigvalsh as eigvalsh, + inv as inv, + jax as jax, + matmul as matmul, + matrix_norm as matrix_norm, + matrix_power as matrix_power, + matrix_rank as matrix_rank, + matrix_transpose as matrix_transpose, + outer as outer, + pinv as pinv, + qr as qr, + slogdet as slogdet, + solve as solve, + svd as svd, + svdvals as svdvals, + tensordot as tensordot, + trace as trace, + vecdot as vecdot, + vector_norm as vector_norm, +) diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt new file mode 100644 index 000000000000..dfc3d383cd10 --- /dev/null +++ b/jax/experimental/array_api/skips.txt @@ -0,0 +1,46 @@ +# Known failures for the array api tests. + +# JAX doesn't yet support scalar boolean indexing +array_api_tests/test_array_object.py::test_getitem_masking + +# Hypothesis warning +array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices + +# Test suite attempts in-place mutation: +array_api_tests/test_special_cases.py::test_binary +array_api_tests/test_special_cases.py::test_iop +array_api_tests/test_special_cases.py::test_nan_propagation +array_api_tests/test_special_cases.py::test_unary +array_api_tests/test_array_object.py::test_setitem +array_api_tests/test_creation_functions.py::test_asarray_arrays +array_api_tests/test_linalg.py::test_matrix_power +array_api_tests/test_linalg.py::test_solve + +# Overflow errors due to hypothesis generating integers that overflow int64 +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_square +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__iand__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__ior__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__ixor__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)] + +# JAX's NaN sorting doesn't match specification +array_api_tests/test_set_functions.py::test_unique_all +array_api_tests/test_set_functions.py::test_unique_counts +array_api_tests/test_set_functions.py::test_unique_inverse +array_api_tests/test_set_functions.py::test_unique_values +array_api_tests/test_sorting_functions.py::test_argsort + +# fft test suite is buggy as of 83f0bcdc +array_api_tests/test_fft.py diff --git a/pyproject.toml b/pyproject.toml index 0c17d2eb3e36..00719b262c57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,8 @@ markers = [ ] filterwarnings = [ "error", + "ignore:The numpy.array_api submodule is still experimental.:UserWarning", + "ignore:The hookimpl.*:DeprecationWarning", "ignore:No GPU/TPU found, falling back to CPU.:UserWarning", "ignore:xmap is an experimental feature and probably has bugs!", "ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",