Skip to content

Commit

Permalink
[array API] clean up some superseded definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 25, 2024
1 parent 9ea79c6 commit 8c36f90
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 43 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
expand_dims as expand_dims,
expm1 as expm1,
eye as eye,
finfo as finfo,
flip as flip,
float32 as float32,
float64 as float64,
Expand Down Expand Up @@ -193,7 +194,6 @@

from jax.experimental.array_api._data_type_functions import (
astype as astype,
finfo as finfo,
)

from jax.experimental.array_api._elementwise_functions import (
Expand Down
43 changes: 1 addition & 42 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,15 @@

from __future__ import annotations

import builtins
from typing import NamedTuple
import numpy as np

import jax.numpy as jnp

from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src import dtypes as _dtypes

# TODO(micky774): Update jax.numpy dtypes to dtype *objects*
bool = np.dtype('bool')
int8 = np.dtype('int8')
int16 = np.dtype('int16')
int32 = np.dtype('int32')
int64 = np.dtype('int64')
uint8 = np.dtype('uint8')
uint16 = np.dtype('uint16')
uint32 = np.dtype('uint32')
uint64 = np.dtype('uint64')
float32 = np.dtype('float32')
float64 = np.dtype('float64')
complex64 = np.dtype('complex64')
complex128 = np.dtype('complex128')


# TODO(micky774): Remove when jax.numpy.astype is deprecation is completed
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
def astype(x, dtype, /, *, copy: bool = True, device: xc.Device | Sharding | None = None):
src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
if (
src_dtype is not None
Expand All @@ -54,25 +35,3 @@ def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Shard
"your input."
)
return jnp.astype(x, dtype, copy=copy, device=device)


class FInfo(NamedTuple):
bits: int
eps: float
max: float
min: float
smallest_normal: float
dtype: jnp.dtype

# TODO(micky774): Update jax.numpy.finfo so that its attributes are python
# floats
def finfo(type, /) -> FInfo:
info = jnp.finfo(type)
return FInfo(
bits=info.bits,
eps=float(info.eps),
max=float(info.max),
min=float(info.min),
smallest_normal=float(info.smallest_normal),
dtype=jnp.dtype(type)
)
3 changes: 3 additions & 0 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Known failures for the array api tests.

# finfo return type misalignment (https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# Test suite attempts in-place mutation:
array_api_tests/test_special_cases.py::test_iop
array_api_tests/test_special_cases.py::test_nan_propagation
Expand Down

0 comments on commit 8c36f90

Please sign in to comment.