-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright 2023 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
__array_api_version__ = '2022.12' | ||
|
||
from jax.experimental.array_api._constants import ( | ||
e as e, | ||
inf as inf, | ||
nan as nan, | ||
newaxis as newaxis, | ||
pi as pi, | ||
) | ||
|
||
from ._creation_functions import ( | ||
arange as arange, | ||
asarray as asarray, | ||
empty as empty, | ||
empty_like as empty_like, | ||
eye as eye, | ||
from_dlpack as from_dlpack, | ||
full as full, | ||
full_like as full_like, | ||
linspace as linspace, | ||
meshgrid as meshgrid, | ||
ones as ones, | ||
ones_like as ones_like, | ||
tril as tril, | ||
triu as triu, | ||
zeros as zeros, | ||
zeros_like as zeros_like, | ||
) | ||
|
||
from ._data_type_functions import ( | ||
astype as astype, | ||
can_cast as can_cast, | ||
finfo as finfo, | ||
iinfo as iinfo, | ||
isdtype as isdtype, | ||
result_type as result_type, | ||
) | ||
|
||
from ._dtypes import ( | ||
bool as bool, | ||
int8 as int8, | ||
int16 as int16, | ||
int32 as int32, | ||
int64 as int64, | ||
uint8 as uint8, | ||
uint16 as uint16, | ||
uint32 as uint32, | ||
uint64 as uint64, | ||
float32 as float32, | ||
float64 as float64, | ||
complex64 as complex64, | ||
complex128 as complex128, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import numpy as np | ||
|
||
e = np.e | ||
inf = np.inf | ||
nan = np.nan | ||
newaxis = np.newaxis | ||
pi = np.pi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright 2023 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
|
||
def arange(start, /, stop=None, step=1, *, dtype=None, device=None): | ||
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device) | ||
|
||
def asarray(obj, /, *, dtype=None, device=None, copy=None): | ||
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device) | ||
|
||
def empty(shape, *, dtype=None, device=None): | ||
return jax.device_put(jnp.empty(shape, dtype=dtype), device=device) | ||
|
||
def empty_like(x, /, *, dtype=None, device=None): | ||
return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device) | ||
|
||
def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None): | ||
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device) | ||
|
||
def from_dlpack(x, /): | ||
return jnp.from_dlpack(x) | ||
|
||
def full(shape, fill_value, *, dtype=None, device=None): | ||
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device) | ||
|
||
def full_like(x, /, fill_value, *, dtype=None, device=None): | ||
return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device) | ||
|
||
def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True): | ||
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device) | ||
|
||
def meshgrid(*arrays, indexing='xy'): | ||
return jnp.meshgrid(*arrays, indexing=indexing) | ||
|
||
def ones(shape, *, dtype=None, device=None): | ||
return jax.device_put(jnp.ones(shape, dtype=dtype), device=device) | ||
|
||
def ones_like(x, /, *, dtype=None, device=None): | ||
return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device) | ||
|
||
def tril(x, /, *, k=0): | ||
return jnp.tril(x, k=k) | ||
|
||
def triu(x, /, *, k=0): | ||
return jnp.triu(x, k=k) | ||
|
||
def zeros(shape, *, dtype=None, device=None): | ||
return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device) | ||
|
||
def zeros_like(x, /, *, dtype=None, device=None): | ||
return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
# Copyright 2023 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import functools | ||
from typing import NamedTuple | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
|
||
from jax.experimental.array_api._dtypes import ( | ||
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, | ||
float32, float64, complex64, complex128 | ||
) | ||
|
||
_valid_dtypes = { | ||
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, | ||
float32, float64, complex64, complex128 | ||
} | ||
|
||
_promotion_table = { | ||
(bool, bool): bool, | ||
(int8, int8): int8, | ||
(int8, int16): int16, | ||
(int8, int32): int32, | ||
(int8, int64): int64, | ||
(int8, uint8): int16, | ||
(int8, uint16): int32, | ||
(int8, uint32): int64, | ||
(int16, int8): int16, | ||
(int16, int16): int16, | ||
(int16, int32): int32, | ||
(int16, int64): int64, | ||
(int16, uint8): int32, | ||
(int16, uint16): int32, | ||
(int16, uint32): int64, | ||
(int32, int8): int32, | ||
(int32, int16): int32, | ||
(int32, int32): int32, | ||
(int32, int64): int64, | ||
(int32, uint8): int64, | ||
(int32, uint16): int64, | ||
(int32, uint32): int64, | ||
(int64, int8): int64, | ||
(int64, int16): int64, | ||
(int64, int32): int64, | ||
(int64, int64): int64, | ||
(int64, uint8): int64, | ||
(int64, uint16): int64, | ||
(int64, uint32): int64, | ||
(uint8, int8): int16, | ||
(uint8, int16): int32, | ||
(uint8, int32): int64, | ||
(uint8, uint8): uint8, | ||
(uint8, uint16): uint16, | ||
(uint8, uint32): uint32, | ||
(uint8, uint64): uint64, | ||
(uint16, int8): int32, | ||
(uint16, int16): int32, | ||
(uint16, int32): int64, | ||
(uint16, uint8): uint16, | ||
(uint16, uint16): uint16, | ||
(uint16, uint32): uint32, | ||
(uint16, uint64): uint64, | ||
(uint32, int8): int64, | ||
(uint32, int16): int64, | ||
(uint32, int32): int64, | ||
(uint32, uint8): uint32, | ||
(uint32, uint16): uint32, | ||
(uint32, uint32): uint32, | ||
(uint32, uint64): uint64, | ||
(uint64, uint8): uint64, | ||
(uint64, uint16): uint64, | ||
(uint64, uint32): uint64, | ||
(uint64, uint64): uint64, | ||
(float32, float32): float32, | ||
(float32, float64): float64, | ||
(float32, complex64): complex64, | ||
(float32, complex128): complex128, | ||
(float64, float32): float64, | ||
(float64, float64): float64, | ||
(float64, complex64): complex128, | ||
(float64, complex128): complex128, | ||
(complex64, float32): complex64, | ||
(complex64, float64): complex128, | ||
(complex64, complex64): complex64, | ||
(complex64, complex128): complex128, | ||
(complex128, float32): complex128, | ||
(complex128, float64): complex128, | ||
(complex128, complex64): complex128, | ||
(complex128, complex128): complex128, | ||
} | ||
|
||
|
||
def _is_valid_dtype(t): | ||
try: | ||
return t in _valid_dtypes | ||
except TypeError: | ||
return False | ||
|
||
|
||
def _promote_types(t1, t2): | ||
if not _is_valid_dtype(t1): | ||
raise ValueError(f"{t1} is not a valid dtype") | ||
if not _is_valid_dtype(t2): | ||
raise ValueError(f"{t2} is not a valid dtype") | ||
if result := _promotion_table.get((t1, t2), None): | ||
return result | ||
else: | ||
raise ValueError("No promotion path for {t1} & {t2}") | ||
|
||
|
||
def astype(x, dtype, /, *, copy=True): | ||
return jnp.asarray(x, dtype=dtype, copy=copy) | ||
|
||
|
||
def can_cast(from_, to, /): | ||
if not _is_valid_dtype(from_): | ||
raise ValueError(f"{from_} is not a valid dtype") | ||
if not _is_valid_dtype(to): | ||
raise ValueError(f"{to} is not a valid dtype") | ||
try: | ||
result = _promote_types(from_, to) | ||
except ValueError: | ||
return False | ||
else: | ||
return result == to | ||
|
||
|
||
class FInfo(NamedTuple): | ||
bits: int | ||
eps: float | ||
max: float | ||
min: float | ||
smallest_normal: float | ||
|
||
|
||
class IInfo(NamedTuple): | ||
bits: int | ||
max: int | ||
min: int | ||
|
||
|
||
def finfo(type, /) -> FInfo: | ||
info = jnp.finfo(type) | ||
return FInfo( | ||
bits=info.bits, | ||
eps=float(info.eps), | ||
max=float(info.max), | ||
min=float(info.min), | ||
smallest_normal=float(info.smallest_normal), | ||
) | ||
|
||
|
||
def iinfo(type, /) -> IInfo: | ||
info = jnp.iinfo(type) | ||
return IInfo(bits=info.bits, max=info.max, min=info.min) | ||
|
||
|
||
_dtype_kinds = { | ||
'bool': {bool}, | ||
'signed integer': {int8, int16, int32, int64}, | ||
'unsigned integer': {uint8, uint16, uint32, uint64}, | ||
'integral': {int8, int16, int32, int64, uint8, uint16, uint32, uint64}, | ||
'real floating': {float32, float64}, | ||
'complex floating': {complex64, complex128}, | ||
'numeric': {int8, int16, int32, int64, uint8, uint16, uint32, uint64, | ||
float32, float64, complex64, complex128}, | ||
} | ||
|
||
def isdtype(dtype, kind): | ||
if not _is_valid_dtype(dtype): | ||
raise ValueError(f"{dtype} is not a valid dtype.") | ||
if isinstance(kind, tuple): | ||
return any(_isdtype(dtype, k) for k in kind) | ||
return _isdtype(dtype, kind) | ||
|
||
def _isdtype(dtype, kind): | ||
if isinstance(kind, jnp.dtype): | ||
return dtype == kind | ||
elif isinstance(kind, str): | ||
if kind not in _dtype_kinds: | ||
raise ValueError(f"Unrecognized {kind=!r}") | ||
return dtype in _dtype_kinds[kind] | ||
else: | ||
raise ValueError(f"Invalid kind with {kind}. Expected string or dtype.") | ||
|
||
|
||
def result_type(*arrays_and_dtypes): | ||
dtypes = [] | ||
for val in arrays_and_dtypes: | ||
if isinstance(val, jax.Array): | ||
val = val.dtype | ||
if _is_valid_dtype(val): | ||
dtypes.append(val) | ||
else: | ||
raise ValueError(f"{val} is not a valid dtype") | ||
if len(dtypes) == 0: | ||
raise ValueError("result_type requires at least one argument") | ||
if len(dtypes) == 1: | ||
return dtypes[0] | ||
return functools.reduce(_promote_types, dtypes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright 2023 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
|
||
bool = np.dtype('bool') | ||
int8 = np.dtype('int8') | ||
int16 = np.dtype('int16') | ||
int32 = np.dtype('int32') | ||
int64 = np.dtype('int64') | ||
uint8 = np.dtype('uint8') | ||
uint16 = np.dtype('uint16') | ||
uint32 = np.dtype('uint32') | ||
uint64 = np.dtype('uint64') | ||
float32 = np.dtype('float32') | ||
float64 = np.dtype('float64') | ||
complex64 = np.dtype('complex64') | ||
complex128 = np.dtype('complex128') |